Arnab Das committed on
Commit
bd282c4
·
1 Parent(s): 49185b7

Add model for audio manipulation detection

app.py CHANGED
@@ -1,4 +1,8 @@
1
  import gradio as gr
2
+ from manipulate_model.utils import get_config_and_model
3
+
4
+
5
+ manipulate_config, manipulate_model = get_config_and_model()
6

7
  def process(filepath):
8
      return filepath
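
Note: in this commit, process() still just echoes the uploaded path. Below is a minimal, hypothetical sketch of how the loaded model could eventually score a file; the torchaudio loading step, the 16 kHz mono input shape, and the model's call/output format are assumptions, not part of this commit.

    import torch
    import torchaudio
    from manipulate_model.utils import get_config_and_model

    manipulate_config, manipulate_model = get_config_and_model()

    def process(filepath):
        # hypothetical wiring: load, downmix, resample to 16 kHz, then score
        wav, sr = torchaudio.load(filepath)
        wav = torchaudio.functional.resample(wav, sr, 16000).mean(0, keepdim=True)
        with torch.no_grad():
            logits = manipulate_model(wav)   # assumed to return class logits
        return torch.softmax(logits, dim=-1).tolist()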
manipulate_model/decoder/aasist/aasist.py ADDED
@@ -0,0 +1,617 @@
1
+ import random
2
+ from typing import Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from torch import Tensor
9
+
10
+ # import fairseq
11
+
12
+
13
+ __author__ = "Hemlata Tak"
14
+ __email__ = "[email protected]"
15
+
16
+ ############################
17
+ ## FOR fine-tuned SSL MODEL
18
+ ############################
19
+
20
+
21
+ # class SSLModel(nn.Module):
22
+ # def __init__(self, device):
23
+ # super(SSLModel, self).__init__()
24
+
25
+ # cp_path = "xlsr2_300m.pt" # Change the pre-trained XLSR model path.
26
+ # model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
27
+ # [cp_path]
28
+ # )
29
+ # self.model = model[0]
30
+ # self.device = device
31
+ # self.out_dim = 1024
32
+ # return
33
+
34
+ # def extract_feat(self, input_data):
35
+
36
+ # # put the model to GPU if it not there
37
+ # if (
38
+ # next(self.model.parameters()).device != input_data.device
39
+ # or next(self.model.parameters()).dtype != input_data.dtype
40
+ # ):
41
+ # self.model.to(input_data.device, dtype=input_data.dtype)
42
+ # self.model.train()
43
+
44
+ # if True:
45
+ # # input should be in shape (batch, length)
46
+ # if input_data.ndim == 3:
47
+ # input_tmp = input_data[:, :, 0]
48
+ # else:
49
+ # input_tmp = input_data
50
+
51
+ # # [batch, length, dim]
52
+ # emb = self.model(input_tmp, mask=False, features_only=True)["x"]
53
+ # return emb
54
+
55
+
56
+ # ---------AASIST back-end------------------------#
57
+ """ Jee-weon Jung, Hee-Soo Heo, Hemlata Tak, Hye-jin Shim, Joon Son Chung, Bong-Jin Lee, Ha-Jin Yu and Nicholas Evans.
58
+ AASIST: Audio Anti-Spoofing Using Integrated Spectro-Temporal Graph Attention Networks.
59
+ In Proc. ICASSP 2022, pp: 6367--6371."""
60
+
61
+
62
+ class GraphAttentionLayer(nn.Module):
63
+ def __init__(self, in_dim, out_dim, **kwargs):
64
+ super().__init__()
65
+
66
+ # attention map
67
+ self.att_proj = nn.Linear(in_dim, out_dim)
68
+ self.att_weight = self._init_new_params(out_dim, 1)
69
+
70
+ # project
71
+ self.proj_with_att = nn.Linear(in_dim, out_dim)
72
+ self.proj_without_att = nn.Linear(in_dim, out_dim)
73
+
74
+ # batch norm
75
+ self.bn = nn.BatchNorm1d(out_dim)
76
+
77
+ # dropout for inputs
78
+ self.input_drop = nn.Dropout(p=0.2)
79
+
80
+ # activate
81
+ self.act = nn.SELU(inplace=True)
82
+
83
+ # temperature
84
+ self.temp = 1.0
85
+ if "temperature" in kwargs:
86
+ self.temp = kwargs["temperature"]
87
+
88
+ def forward(self, x):
89
+ """
90
+ x :(#bs, #node, #dim)
91
+ """
92
+ # apply input dropout
93
+ x = self.input_drop(x)
94
+
95
+ # derive attention map
96
+ att_map = self._derive_att_map(x)
97
+
98
+ # projection
99
+ x = self._project(x, att_map)
100
+
101
+ # apply batch norm
102
+ x = self._apply_BN(x)
103
+ x = self.act(x)
104
+ return x
105
+
106
+ def _pairwise_mul_nodes(self, x):
107
+ """
108
+ Calculates pairwise multiplication of nodes.
109
+ - for attention map
110
+ x :(#bs, #node, #dim)
111
+ out_shape :(#bs, #node, #node, #dim)
112
+ """
113
+
114
+ nb_nodes = x.size(1)
115
+ x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1)
116
+ x_mirror = x.transpose(1, 2)
117
+
118
+ return x * x_mirror
119
+
120
+ def _derive_att_map(self, x):
121
+ """
122
+ x :(#bs, #node, #dim)
123
+ out_shape :(#bs, #node, #node, 1)
124
+ """
125
+ att_map = self._pairwise_mul_nodes(x)
126
+ # size: (#bs, #node, #node, #dim_out)
127
+ att_map = torch.tanh(self.att_proj(att_map))
128
+ # size: (#bs, #node, #node, 1)
129
+ att_map = torch.matmul(att_map, self.att_weight)
130
+
131
+ # apply temperature
132
+ att_map = att_map / self.temp
133
+
134
+ att_map = F.softmax(att_map, dim=-2)
135
+
136
+ return att_map
137
+
138
+ def _project(self, x, att_map):
139
+ x1 = self.proj_with_att(torch.matmul(att_map.squeeze(-1), x))
140
+ x2 = self.proj_without_att(x)
141
+
142
+ return x1 + x2
143
+
144
+ def _apply_BN(self, x):
145
+ org_size = x.size()
146
+ x = x.view(-1, org_size[-1])
147
+ x = self.bn(x)
148
+ x = x.view(org_size)
149
+
150
+ return x
151
+
152
+ def _init_new_params(self, *size):
153
+ out = nn.Parameter(torch.FloatTensor(*size))
154
+ nn.init.xavier_normal_(out)
155
+ return out
156
+
157
+
158
+ class HtrgGraphAttentionLayer(nn.Module):
159
+ def __init__(self, in_dim, out_dim, **kwargs):
160
+ super().__init__()
161
+
162
+ self.proj_type1 = nn.Linear(in_dim, in_dim)
163
+ self.proj_type2 = nn.Linear(in_dim, in_dim)
164
+
165
+ # attention map
166
+ self.att_proj = nn.Linear(in_dim, out_dim)
167
+ self.att_projM = nn.Linear(in_dim, out_dim)
168
+
169
+ self.att_weight11 = self._init_new_params(out_dim, 1)
170
+ self.att_weight22 = self._init_new_params(out_dim, 1)
171
+ self.att_weight12 = self._init_new_params(out_dim, 1)
172
+ self.att_weightM = self._init_new_params(out_dim, 1)
173
+
174
+ # project
175
+ self.proj_with_att = nn.Linear(in_dim, out_dim)
176
+ self.proj_without_att = nn.Linear(in_dim, out_dim)
177
+
178
+ self.proj_with_attM = nn.Linear(in_dim, out_dim)
179
+ self.proj_without_attM = nn.Linear(in_dim, out_dim)
180
+
181
+ # batch norm
182
+ self.bn = nn.BatchNorm1d(out_dim)
183
+
184
+ # dropout for inputs
185
+ self.input_drop = nn.Dropout(p=0.2)
186
+
187
+ # activate
188
+ self.act = nn.SELU(inplace=True)
189
+
190
+ # temperature
191
+ self.temp = 1.0
192
+ if "temperature" in kwargs:
193
+ self.temp = kwargs["temperature"]
194
+
195
+ def forward(self, x1, x2, master=None):
196
+ """
197
+ x1 :(#bs, #node, #dim)
198
+ x2 :(#bs, #node, #dim)
199
+ """
200
+ # print('x1',x1.shape)
201
+ # print('x2',x2.shape)
202
+ num_type1 = x1.size(1)
203
+ num_type2 = x2.size(1)
204
+ # print('num_type1',num_type1)
205
+ # print('num_type2',num_type2)
206
+ x1 = self.proj_type1(x1)
207
+ # print('proj_type1',x1.shape)
208
+ x2 = self.proj_type2(x2)
209
+ # print('proj_type2',x2.shape)
210
+ x = torch.cat([x1, x2], dim=1)
211
+ # print('Concat x1 and x2',x.shape)
212
+
213
+ if master is None:
214
+ master = torch.mean(x, dim=1, keepdim=True)
215
+ # print('master',master.shape)
216
+ # apply input dropout
217
+ x = self.input_drop(x)
218
+
219
+ # derive attention map
220
+ att_map = self._derive_att_map(x, num_type1, num_type2)
221
+ # print('master',master.shape)
222
+ # directional edge for master node
223
+ master = self._update_master(x, master)
224
+ # print('master',master.shape)
225
+ # projection
226
+ x = self._project(x, att_map)
227
+ # print('proj x',x.shape)
228
+ # apply batch norm
229
+ x = self._apply_BN(x)
230
+ x = self.act(x)
231
+
232
+ x1 = x.narrow(1, 0, num_type1)
233
+ # print('x1',x1.shape)
234
+ x2 = x.narrow(1, num_type1, num_type2)
235
+ # print('x2',x2.shape)
236
+ return x1, x2, master
237
+
238
+ def _update_master(self, x, master):
239
+
240
+ att_map = self._derive_att_map_master(x, master)
241
+ master = self._project_master(x, master, att_map)
242
+
243
+ return master
244
+
245
+ def _pairwise_mul_nodes(self, x):
246
+ """
247
+ Calculates pairwise multiplication of nodes.
248
+ - for attention map
249
+ x :(#bs, #node, #dim)
250
+ out_shape :(#bs, #node, #node, #dim)
251
+ """
252
+
253
+ nb_nodes = x.size(1)
254
+ x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1)
255
+ x_mirror = x.transpose(1, 2)
256
+
257
+ return x * x_mirror
258
+
259
+ def _derive_att_map_master(self, x, master):
260
+ """
261
+ x :(#bs, #node, #dim)
262
+ out_shape :(#bs, #node, #node, 1)
263
+ """
264
+ att_map = x * master
265
+ att_map = torch.tanh(self.att_projM(att_map))
266
+
267
+ att_map = torch.matmul(att_map, self.att_weightM)
268
+
269
+ # apply temperature
270
+ att_map = att_map / self.temp
271
+
272
+ att_map = F.softmax(att_map, dim=-2)
273
+
274
+ return att_map
275
+
276
+ def _derive_att_map(self, x, num_type1, num_type2):
277
+ """
278
+ x :(#bs, #node, #dim)
279
+ out_shape :(#bs, #node, #node, 1)
280
+ """
281
+ att_map = self._pairwise_mul_nodes(x)
282
+ # size: (#bs, #node, #node, #dim_out)
283
+ att_map = torch.tanh(self.att_proj(att_map))
284
+ # size: (#bs, #node, #node, 1)
285
+
286
+ att_board = torch.zeros_like(att_map[:, :, :, 0]).unsqueeze(-1)
287
+
288
+ att_board[:, :num_type1, :num_type1, :] = torch.matmul(
289
+ att_map[:, :num_type1, :num_type1, :], self.att_weight11
290
+ )
291
+ att_board[:, num_type1:, num_type1:, :] = torch.matmul(
292
+ att_map[:, num_type1:, num_type1:, :], self.att_weight22
293
+ )
294
+ att_board[:, :num_type1, num_type1:, :] = torch.matmul(
295
+ att_map[:, :num_type1, num_type1:, :], self.att_weight12
296
+ )
297
+ att_board[:, num_type1:, :num_type1, :] = torch.matmul(
298
+ att_map[:, num_type1:, :num_type1, :], self.att_weight12
299
+ )
300
+
301
+ att_map = att_board
302
+
303
+ # apply temperature
304
+ att_map = att_map / self.temp
305
+
306
+ att_map = F.softmax(att_map, dim=-2)
307
+
308
+ return att_map
309
+
310
+ def _project(self, x, att_map):
311
+ x1 = self.proj_with_att(torch.matmul(att_map.squeeze(-1), x))
312
+ x2 = self.proj_without_att(x)
313
+
314
+ return x1 + x2
315
+
316
+ def _project_master(self, x, master, att_map):
317
+
318
+ x1 = self.proj_with_attM(torch.matmul(att_map.squeeze(-1).unsqueeze(1), x))
319
+ x2 = self.proj_without_attM(master)
320
+
321
+ return x1 + x2
322
+
323
+ def _apply_BN(self, x):
324
+ org_size = x.size()
325
+ x = x.view(-1, org_size[-1])
326
+ x = self.bn(x)
327
+ x = x.view(org_size)
328
+
329
+ return x
330
+
331
+ def _init_new_params(self, *size):
332
+ out = nn.Parameter(torch.FloatTensor(*size))
333
+ nn.init.xavier_normal_(out)
334
+ return out
335
+
336
+
337
+ class GraphPool(nn.Module):
338
+ def __init__(self, k: float, in_dim: int, p: Union[float, int]):
339
+ super().__init__()
340
+ self.k = k
341
+ self.sigmoid = nn.Sigmoid()
342
+ self.proj = nn.Linear(in_dim, 1)
343
+ self.drop = nn.Dropout(p=p) if p > 0 else nn.Identity()
344
+ self.in_dim = in_dim
345
+
346
+ def forward(self, h):
347
+ Z = self.drop(h)
348
+ weights = self.proj(Z)
349
+ scores = self.sigmoid(weights)
350
+ new_h = self.top_k_graph(scores, h, self.k)
351
+
352
+ return new_h
353
+
354
+ def top_k_graph(self, scores, h, k):
355
+ """
356
+ args
357
+ =====
358
+ scores: attention-based weights (#bs, #node, 1)
359
+ h: graph data (#bs, #node, #dim)
360
+ k: ratio of remaining nodes, (float)
361
+ returns
362
+ =====
363
+ h: graph pool applied data (#bs, #node', #dim)
364
+ """
365
+ _, n_nodes, n_feat = h.size()
366
+ n_nodes = max(int(n_nodes * k), 1)
367
+ _, idx = torch.topk(scores, n_nodes, dim=1)
368
+ idx = idx.expand(-1, -1, n_feat)
369
+
370
+ h = h * scores
371
+ h = torch.gather(h, 1, idx)
372
+
373
+ return h
374
+
375
+
376
+ class Residual_block(nn.Module):
377
+ def __init__(self, nb_filts, first=False):
378
+ super().__init__()
379
+ self.first = first
380
+
381
+ if not self.first:
382
+ self.bn1 = nn.BatchNorm2d(num_features=nb_filts[0])
383
+ self.conv1 = nn.Conv2d(
384
+ in_channels=nb_filts[0],
385
+ out_channels=nb_filts[1],
386
+ kernel_size=(2, 3),
387
+ padding=(1, 1),
388
+ stride=1,
389
+ )
390
+ self.selu = nn.SELU(inplace=True)
391
+
392
+ self.bn2 = nn.BatchNorm2d(num_features=nb_filts[1])
393
+ self.conv2 = nn.Conv2d(
394
+ in_channels=nb_filts[1],
395
+ out_channels=nb_filts[1],
396
+ kernel_size=(2, 3),
397
+ padding=(0, 1),
398
+ stride=1,
399
+ )
400
+
401
+ if nb_filts[0] != nb_filts[1]:
402
+ self.downsample = True
403
+ self.conv_downsample = nn.Conv2d(
404
+ in_channels=nb_filts[0],
405
+ out_channels=nb_filts[1],
406
+ padding=(0, 1),
407
+ kernel_size=(1, 3),
408
+ stride=1,
409
+ )
410
+
411
+ else:
412
+ self.downsample = False
413
+
414
+ def forward(self, x):
415
+ identity = x
416
+ if not self.first:
417
+ out = self.bn1(x)
418
+ out = self.selu(out)
419
+ else:
420
+ out = x
421
+
422
+ # print('out',out.shape)
423
+ out = self.conv1(x)
424
+
425
+ # print('aft conv1 out',out.shape)
426
+ out = self.bn2(out)
427
+ out = self.selu(out)
428
+ # print('out',out.shape)
429
+ out = self.conv2(out)
430
+ # print('conv2 out',out.shape)
431
+
432
+ if self.downsample:
433
+ identity = self.conv_downsample(identity)
434
+
435
+ out += identity
436
+ # out = self.mp(out)
437
+ return out
438
+
439
+
440
+ class AASIST(nn.Module):
441
+ def __init__(self, config):
442
+ super().__init__()
443
+ self.config = config
444
+
445
+ # AASIST parameters
446
+ filts = [128, [1, 32], [32, 32], [32, 64], [64, 64]]
447
+ gat_dims = [64, 32]
448
+ pool_ratios = [0.5, 0.5, 0.5, 0.5]
449
+ temperatures = [2.0, 2.0, 100.0, 100.0]
450
+
451
+ ####
452
+ # create network wav2vec 2.0
453
+ ####
454
+ # self.ssl_model = SSLModel(self.device)
455
+ self.LL = nn.Linear(self.config.model.decoder.encoding_dim, 128)
456
+
457
+ self.first_bn = nn.BatchNorm2d(num_features=1)
458
+ self.first_bn1 = nn.BatchNorm2d(num_features=64)
459
+ self.drop = nn.Dropout(0.5, inplace=True)
460
+ self.drop_way = nn.Dropout(0.2, inplace=True)
461
+ self.selu = nn.SELU(inplace=True)
462
+
463
+ # RawNet2 encoder
464
+ self.encoder = nn.Sequential(
465
+ nn.Sequential(Residual_block(nb_filts=filts[1], first=True)),
466
+ nn.Sequential(Residual_block(nb_filts=filts[2])),
467
+ nn.Sequential(Residual_block(nb_filts=filts[3])),
468
+ nn.Sequential(Residual_block(nb_filts=filts[4])),
469
+ nn.Sequential(Residual_block(nb_filts=filts[4])),
470
+ nn.Sequential(Residual_block(nb_filts=filts[4])),
471
+ )
472
+
473
+ self.attention = nn.Sequential(
474
+ nn.Conv2d(64, 128, kernel_size=(1, 1)),
475
+ nn.SELU(inplace=True),
476
+ nn.BatchNorm2d(128),
477
+ nn.Conv2d(128, 64, kernel_size=(1, 1)),
478
+ )
479
+ # position encoding
480
+ self.pos_S = nn.Parameter(torch.randn(1, 42, filts[-1][-1]))
481
+
482
+ self.master1 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))
483
+ self.master2 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))
484
+
485
+ # Graph module
486
+ self.GAT_layer_S = GraphAttentionLayer(
487
+ filts[-1][-1], gat_dims[0], temperature=temperatures[0]
488
+ )
489
+ self.GAT_layer_T = GraphAttentionLayer(
490
+ filts[-1][-1], gat_dims[0], temperature=temperatures[1]
491
+ )
492
+ # HS-GAL layer
493
+ self.HtrgGAT_layer_ST11 = HtrgGraphAttentionLayer(
494
+ gat_dims[0], gat_dims[1], temperature=temperatures[2]
495
+ )
496
+ self.HtrgGAT_layer_ST12 = HtrgGraphAttentionLayer(
497
+ gat_dims[1], gat_dims[1], temperature=temperatures[2]
498
+ )
499
+ self.HtrgGAT_layer_ST21 = HtrgGraphAttentionLayer(
500
+ gat_dims[0], gat_dims[1], temperature=temperatures[2]
501
+ )
502
+ self.HtrgGAT_layer_ST22 = HtrgGraphAttentionLayer(
503
+ gat_dims[1], gat_dims[1], temperature=temperatures[2]
504
+ )
505
+
506
+ # Graph pooling layers
507
+ self.pool_S = GraphPool(pool_ratios[0], gat_dims[0], 0.3)
508
+ self.pool_T = GraphPool(pool_ratios[1], gat_dims[0], 0.3)
509
+ self.pool_hS1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
510
+ self.pool_hT1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
511
+
512
+ self.pool_hS2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
513
+ self.pool_hT2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
514
+
515
+ if self.config.model.task == "audio-video":
516
+ self.out_layer = nn.Linear(5 * gat_dims[1], 4)
517
+ else:
518
+ self.out_layer = nn.Linear(5 * gat_dims[1], 2)
519
+
520
+ def forward(self, x):
521
+ # -------pre-trained Wav2vec model fine-tuning ------------------------##
522
+ # x_ssl_feat = self.ssl_model.extract_feat(x.squeeze(-1))
523
+ x = self.LL(x) # (bs,frame_number,feat_out_dim)
524
+
525
+ # post-processing on front-end features
526
+ x = x.transpose(1, 2) # (bs,feat_out_dim,frame_number)
527
+ x = x.unsqueeze(dim=1) # add channel
528
+ x = F.max_pool2d(x, (3, 3))
529
+ x = self.first_bn(x)
530
+ x = self.selu(x)
531
+
532
+ # RawNet2-based encoder
533
+ x = self.encoder(x)
534
+ x = self.first_bn1(x)
535
+ x = self.selu(x)
536
+
537
+ w = self.attention(x)
538
+
539
+ # ------------SA for spectral feature-------------#
540
+ w1 = F.softmax(w, dim=-1)
541
+ m = torch.sum(x * w1, dim=-1)
542
+ e_S = m.transpose(1, 2) + self.pos_S
543
+
544
+ # graph module layer
545
+ gat_S = self.GAT_layer_S(e_S)
546
+ out_S = self.pool_S(gat_S) # (#bs, #node, #dim)
547
+
548
+ # ------------SA for temporal feature-------------#
549
+ w2 = F.softmax(w, dim=-2)
550
+ m1 = torch.sum(x * w2, dim=-2)
551
+
552
+ e_T = m1.transpose(1, 2)
553
+
554
+ # graph module layer
555
+ gat_T = self.GAT_layer_T(e_T)
556
+ out_T = self.pool_T(gat_T)
557
+
558
+ # learnable master node
559
+ master1 = self.master1.expand(x.size(0), -1, -1)
560
+ master2 = self.master2.expand(x.size(0), -1, -1)
561
+
562
+ # inference 1
563
+ out_T1, out_S1, master1 = self.HtrgGAT_layer_ST11(
564
+ out_T, out_S, master=self.master1
565
+ )
566
+
567
+ out_S1 = self.pool_hS1(out_S1)
568
+ out_T1 = self.pool_hT1(out_T1)
569
+
570
+ out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST12(
571
+ out_T1, out_S1, master=master1
572
+ )
573
+ out_T1 = out_T1 + out_T_aug
574
+ out_S1 = out_S1 + out_S_aug
575
+ master1 = master1 + master_aug
576
+
577
+ # inference 2
578
+ out_T2, out_S2, master2 = self.HtrgGAT_layer_ST21(
579
+ out_T, out_S, master=self.master2
580
+ )
581
+ out_S2 = self.pool_hS2(out_S2)
582
+ out_T2 = self.pool_hT2(out_T2)
583
+
584
+ out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST22(
585
+ out_T2, out_S2, master=master2
586
+ )
587
+ out_T2 = out_T2 + out_T_aug
588
+ out_S2 = out_S2 + out_S_aug
589
+ master2 = master2 + master_aug
590
+
591
+ out_T1 = self.drop_way(out_T1)
592
+ out_T2 = self.drop_way(out_T2)
593
+ out_S1 = self.drop_way(out_S1)
594
+ out_S2 = self.drop_way(out_S2)
595
+ master1 = self.drop_way(master1)
596
+ master2 = self.drop_way(master2)
597
+
598
+ out_T = torch.max(out_T1, out_T2)
599
+ out_S = torch.max(out_S1, out_S2)
600
+ master = torch.max(master1, master2)
601
+
602
+ # Readout operation
603
+ T_max, _ = torch.max(torch.abs(out_T), dim=1)
604
+ T_avg = torch.mean(out_T, dim=1)
605
+
606
+ S_max, _ = torch.max(torch.abs(out_S), dim=1)
607
+ S_avg = torch.mean(out_S, dim=1)
608
+
609
+ last_hidden = torch.cat([T_max, T_avg, S_max, S_avg, master.squeeze(1)], dim=1)
610
+
611
+ last_hidden = self.drop(last_hidden)
612
+ output = self.out_layer(last_hidden)
613
+
614
+ if self.config.model.task == "audio-video":
615
+ output = output.view(-1, 2, 2)
616
+
617
+ return output
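
Note: a minimal, hypothetical shape check for the AASIST decoder above (not part of the commit). It assumes an OmegaConf-style config with dotted attribute access and 768-dimensional encoder features (e.g. WavLM base).

    import torch
    from omegaconf import OmegaConf
    from manipulate_model.decoder.aasist.aasist import AASIST

    # only the fields AASIST actually reads: the task and the encoder feature size
    config = OmegaConf.create({"model": {"task": "audio", "decoder": {"encoding_dim": 768}}})
    decoder = AASIST(config).eval()

    feats = torch.randn(2, 199, 768)          # (batch, frames, encoder dim)
    with torch.no_grad():
        logits = decoder(feats)
    print(logits.shape)                       # torch.Size([2, 2]) for the audio-only task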
manipulate_model/decoder/decoder.py ADDED
@@ -0,0 +1,20 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class Decoder(nn.Module):
6
+ def __init__(self, config):
7
+ super(Decoder, self).__init__()
8
+ self.config = config
9
+
10
+ self.decoder = None
11
+
12
+ if config.model.decoder.name.lower() == "aasist":
13
+ from manipulate_model.decoder.aasist.aasist import AASIST
14
+
15
+ self.decoder = AASIST(config)
16
+ else:
17
+ raise ValueError("Invalid decoder name")
18
+
19
+ def forward(self, x):
20
+ return self.decoder(x)
manipulate_model/demo-model/audio/config.yaml ADDED
@@ -0,0 +1,44 @@
1
+ model:
2
+ task: audio
3
+ encoder:
4
+ name: wavlm
5
+ version: base
6
+ pretrained: true
7
+ pretrained_path: manipulate_model/encoder_checkpoints/wavlm/WavLM-Base+.pt
8
+ output_layer: 3
9
+ encoder_freeze: false
10
+ decoder:
11
+ name: aasist
12
+ version: default
13
+ output_size: 2
14
+ online_encoding: true
15
+ data:
16
+ name: av1m
17
+ train_parts: all
18
+ val_parts: all
19
+ test_parts: all
20
+ train_size: -1
21
+ val_size: -1
22
+ test_size: -1
23
+ shape:
24
+ - 3
25
+ - 224
26
+ - 224
27
+ sr: 16000
28
+ fps: 25
29
+ center_transition: true
30
+ window_size: 4
31
+ sliding_window: false
32
+ train:
33
+ num_workers: 16
34
+ batch_size: 64
35
+ num_epochs: 15
36
+ optimizer: adam
37
+ scheduler: step
38
+ lr: 0.0001
39
+ step_size: 1
40
+ gamma: 0.1
41
+ loss: bce
42
+ log_interval: 100
43
+ shuffle: true
44
+ debug: false
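
Note: the demo config uses plain dotted keys, so any dotted-access YAML loader works. A small sketch (OmegaConf assumed) showing the values the pipeline relies on; a 4 s window at 16 kHz corresponds to 64000 raw samples.

    from omegaconf import OmegaConf

    config = OmegaConf.load("manipulate_model/demo-model/audio/config.yaml")
    window_samples = config.data.window_size * config.data.sr   # 4 * 16000 = 64000
    print(config.model.encoder.name, config.model.decoder.name, window_samples)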
manipulate_model/encoder/encoder.py ADDED
@@ -0,0 +1,40 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from torch.nn.functional import pad
5
+ from collections import OrderedDict
6
+
7
+
8
+ class Encoder(nn.Module):
9
+ def __init__(self, config):
10
+ super(Encoder, self).__init__()
11
+ self.config = config
12
+
13
+ self.encoder = None
14
+ self.succeeding_layers = None
15
+
16
+ # AUDIO
17
+ if self.config.model.task == "audio":
18
+ if self.config.model.encoder.name.lower() == "wavlm":
19
+ from manipulate_model.encoder.wavlm.WavLM import WavLM, WavLMConfig
20
+
21
+ ckpt = torch.load(
22
+ config.model.encoder.pretrained_path, map_location="cpu"
23
+ )
24
+ cfg = WavLMConfig(ckpt["cfg"])
25
+ self.encoder = WavLM(cfg)
26
+
27
+
28
+ def forward(self, x):
29
+ if self.config.model.encoder.name.lower() == "wavlm":
30
+ return self.encoder(x, output_layer=self.config.model.encoder.output_layer)
31
+ elif self.config.model.encoder.name.lower() == "videomamba":
32
+ return self.encoder(x)
33
+
34
+ return self.encoder(x)
35
+
36
+ def get_encoding_dim(self):
37
+ return self.encoder.get_encoding_dim()
38
+
39
+ def get_temporal_dim(self):
40
+ return self.encoder.get_temporal_dim(window_size=self.config.data.window_size)
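
Note: a hypothetical construction of the Encoder from the demo config. It assumes the WavLM-Base+ checkpoint referenced by model.encoder.pretrained_path is present on disk; the Encoder only reads the checkpoint's "cfg" entry here, so the transformer starts with random weights unless they are loaded elsewhere.

    import torch
    from omegaconf import OmegaConf
    from manipulate_model.encoder.encoder import Encoder

    config = OmegaConf.load("manipulate_model/demo-model/audio/config.yaml")
    encoder = Encoder(config).eval()             # builds WavLM from the checkpoint's cfg
    with torch.no_grad():
        feats = encoder(torch.randn(1, 64000))   # 4 s of 16 kHz audio
    print(feats.shape)                           # roughly (1, 199, 768): layer-3 features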
manipulate_model/encoder/wavlm/WavLM.py ADDED
@@ -0,0 +1,795 @@
1
+ # --------------------------------------------------------
2
+ # WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/wavlm
4
+ # Copyright (c) 2021 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+ import math
11
+ import logging
12
+ from typing import List, Optional, Tuple
13
+
14
+ import numpy as np
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+ from torch.nn import LayerNorm
20
+ from manipulate_model.encoder.wavlm.modules import (
21
+ Fp32GroupNorm,
22
+ Fp32LayerNorm,
23
+ GradMultiply,
24
+ MultiheadAttention,
25
+ SamePad,
26
+ init_bert_params,
27
+ get_activation_fn,
28
+ TransposeLast,
29
+ GLU_Linear,
30
+ )
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
+ def compute_mask_indices(
36
+ shape: Tuple[int, int],
37
+ padding_mask: Optional[torch.Tensor],
38
+ mask_prob: float,
39
+ mask_length: int,
40
+ mask_type: str = "static",
41
+ mask_other: float = 0.0,
42
+ min_masks: int = 0,
43
+ no_overlap: bool = False,
44
+ min_space: int = 0,
45
+ ) -> np.ndarray:
46
+ """
47
+ Computes random mask spans for a given shape
48
+
49
+ Args:
50
+ shape: the shape for which to compute masks.
51
+ should be of size 2 where first element is batch size and 2nd is timesteps
52
+ padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
53
+ mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
54
+ number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
55
+ however due to overlaps, the actual number will be smaller (unless no_overlap is True)
56
+ mask_type: how to compute mask lengths
57
+ static = fixed size
58
+ uniform = sample from uniform distribution [mask_other, mask_length*2]
59
+ normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
60
+ poisson = sample from Poisson distribution with lambda = mask length
61
+ min_masks: minimum number of masked spans
62
+ no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping
63
+ min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
64
+ """
65
+
66
+ bsz, all_sz = shape
67
+ mask = np.full((bsz, all_sz), False)
68
+
69
+ all_num_mask = int(
70
+ # add a random number for probabilistic rounding
71
+ mask_prob * all_sz / float(mask_length)
72
+ + np.random.rand()
73
+ )
74
+
75
+ all_num_mask = max(min_masks, all_num_mask)
76
+
77
+ mask_idcs = []
78
+ for i in range(bsz):
79
+ if padding_mask is not None:
80
+ sz = all_sz - padding_mask[i].long().sum().item()
81
+ num_mask = int(
82
+ # add a random number for probabilistic rounding
83
+ mask_prob * sz / float(mask_length)
84
+ + np.random.rand()
85
+ )
86
+ num_mask = max(min_masks, num_mask)
87
+ else:
88
+ sz = all_sz
89
+ num_mask = all_num_mask
90
+
91
+ if mask_type == "static":
92
+ lengths = np.full(num_mask, mask_length)
93
+ elif mask_type == "uniform":
94
+ lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
95
+ elif mask_type == "normal":
96
+ lengths = np.random.normal(mask_length, mask_other, size=num_mask)
97
+ lengths = [max(1, int(round(x))) for x in lengths]
98
+ elif mask_type == "poisson":
99
+ lengths = np.random.poisson(mask_length, size=num_mask)
100
+ lengths = [int(round(x)) for x in lengths]
101
+ else:
102
+ raise Exception("unknown mask selection " + mask_type)
103
+
104
+ if sum(lengths) == 0:
105
+ lengths[0] = min(mask_length, sz - 1)
106
+
107
+ if no_overlap:
108
+ mask_idc = []
109
+
110
+ def arrange(s, e, length, keep_length):
111
+ span_start = np.random.randint(s, e - length)
112
+ mask_idc.extend(span_start + i for i in range(length))
113
+
114
+ new_parts = []
115
+ if span_start - s - min_space >= keep_length:
116
+ new_parts.append((s, span_start - min_space + 1))
117
+ if e - span_start - keep_length - min_space > keep_length:
118
+ new_parts.append((span_start + length + min_space, e))
119
+ return new_parts
120
+
121
+ parts = [(0, sz)]
122
+ min_length = min(lengths)
123
+ for length in sorted(lengths, reverse=True):
124
+ lens = np.fromiter(
125
+ (e - s if e - s >= length + min_space else 0 for s, e in parts),
126
+ int,  # np.int was removed in NumPy 1.24; the builtin int is equivalent
127
+ )
128
+ l_sum = np.sum(lens)
129
+ if l_sum == 0:
130
+ break
131
+ probs = lens / np.sum(lens)
132
+ c = np.random.choice(len(parts), p=probs)
133
+ s, e = parts.pop(c)
134
+ parts.extend(arrange(s, e, length, min_length))
135
+ mask_idc = np.asarray(mask_idc)
136
+ else:
137
+ min_len = min(lengths)
138
+ if sz - min_len <= num_mask:
139
+ min_len = sz - num_mask - 1
140
+
141
+ mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
142
+
143
+ mask_idc = np.asarray(
144
+ [
145
+ mask_idc[j] + offset
146
+ for j in range(len(mask_idc))
147
+ for offset in range(lengths[j])
148
+ ]
149
+ )
150
+
151
+ mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
152
+
153
+ min_len = min([len(m) for m in mask_idcs])
154
+ for i, mask_idc in enumerate(mask_idcs):
155
+ if len(mask_idc) > min_len:
156
+ mask_idc = np.random.choice(mask_idc, min_len, replace=False)
157
+ mask[i, mask_idc] = True
158
+
159
+ return mask
160
+
161
+
162
+ class WavLMConfig:
163
+ def __init__(self, cfg=None):
164
+ self.extractor_mode: str = (
165
+ "default" # mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True)
166
+ )
167
+ self.encoder_layers: int = 12 # num encoder layers in the transformer
168
+
169
+ self.encoder_embed_dim: int = 768 # encoder embedding dimension
170
+ self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
171
+ self.encoder_attention_heads: int = 12 # num encoder attention heads
172
+ self.activation_fn: str = "gelu" # activation function to use
173
+
174
+ self.layer_norm_first: bool = False # apply layernorm first in the transformer
175
+ self.conv_feature_layers: str = (
176
+ "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]
177
+ )
178
+ self.conv_bias: bool = False # include bias in conv encoder
179
+ self.feature_grad_mult: float = (
180
+ 1.0 # multiply feature extractor var grads by this
181
+ )
182
+
183
+ self.normalize: bool = (
184
+ False # normalize input to have 0 mean and unit variance during training
185
+ )
186
+
187
+ # dropouts
188
+ self.dropout: float = 0.1 # dropout probability for the transformer
189
+ self.attention_dropout: float = 0.1 # dropout probability for attention weights
190
+ self.activation_dropout: float = (
191
+ 0.0 # dropout probability after activation in FFN
192
+ )
193
+ self.encoder_layerdrop: float = (
194
+ 0.0 # probability of dropping a transformer layer
195
+ )
196
+ self.dropout_input: float = (
197
+ 0.0 # dropout to apply to the input (after feat extr)
198
+ )
199
+ self.dropout_features: float = (
200
+ 0.0 # dropout to apply to the features (after feat extr)
201
+ )
202
+
203
+ # masking
204
+ self.mask_length: int = 10 # mask length
205
+ self.mask_prob: float = 0.65 # probability of replacing a token with mask
206
+ self.mask_selection: str = "static" # how to choose mask length
207
+ self.mask_other: float = (
208
+ 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
209
+ )
210
+ self.no_mask_overlap: bool = False # whether to allow masks to overlap
211
+ self.mask_min_space: int = (
212
+ 1 # min space between spans (if no overlap is enabled)
213
+ )
214
+
215
+ # channel masking
216
+ self.mask_channel_length: int = 10 # length of the mask for features (channels)
217
+ self.mask_channel_prob: float = 0.0 # probability of replacing a feature with 0
218
+ self.mask_channel_selection: str = (
219
+ "static" # how to choose mask length for channel masking
220
+ )
221
+ self.mask_channel_other: float = (
222
+ 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
223
+ )
224
+ self.no_mask_channel_overlap: bool = (
225
+ False # whether to allow channel masks to overlap
226
+ )
227
+ self.mask_channel_min_space: int = (
228
+ 1 # min space between spans (if no overlap is enabled)
229
+ )
230
+
231
+ # positional embeddings
232
+ self.conv_pos: int = (
233
+ 128 # number of filters for convolutional positional embeddings
234
+ )
235
+ self.conv_pos_groups: int = (
236
+ 16 # number of groups for convolutional positional embedding
237
+ )
238
+
239
+ # relative position embedding
240
+ self.relative_position_embedding: bool = (
241
+ False # apply relative position embedding
242
+ )
243
+ self.num_buckets: int = 320 # number of buckets for relative position embedding
244
+ self.max_distance: int = (
245
+ 1280 # maximum distance for relative position embedding
246
+ )
247
+ self.gru_rel_pos: bool = False # apply gated relative position embedding
248
+
249
+ if cfg is not None:
250
+ self.update(cfg)
251
+
252
+ def update(self, cfg: dict):
253
+ self.__dict__.update(cfg)
254
+
255
+
256
+ class WavLM(nn.Module):
257
+ def __init__(
258
+ self,
259
+ cfg: WavLMConfig,
260
+ ) -> None:
261
+ super().__init__()
262
+ logger.info(f"WavLM Config: {cfg.__dict__}")
263
+
264
+ self.cfg = cfg
265
+ feature_enc_layers = eval(cfg.conv_feature_layers)
266
+ self.embed = feature_enc_layers[-1][0]
267
+
268
+ self.feature_extractor = ConvFeatureExtractionModel(
269
+ conv_layers=feature_enc_layers,
270
+ dropout=0.0,
271
+ mode=cfg.extractor_mode,
272
+ conv_bias=cfg.conv_bias,
273
+ )
274
+
275
+ self.post_extract_proj = (
276
+ nn.Linear(self.embed, cfg.encoder_embed_dim)
277
+ if self.embed != cfg.encoder_embed_dim
278
+ else None
279
+ )
280
+
281
+ self.mask_prob = cfg.mask_prob
282
+ self.mask_selection = cfg.mask_selection
283
+ self.mask_other = cfg.mask_other
284
+ self.mask_length = cfg.mask_length
285
+ self.no_mask_overlap = cfg.no_mask_overlap
286
+ self.mask_min_space = cfg.mask_min_space
287
+
288
+ self.mask_channel_prob = cfg.mask_channel_prob
289
+ self.mask_channel_selection = cfg.mask_channel_selection
290
+ self.mask_channel_other = cfg.mask_channel_other
291
+ self.mask_channel_length = cfg.mask_channel_length
292
+ self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
293
+ self.mask_channel_min_space = cfg.mask_channel_min_space
294
+
295
+ self.dropout_input = nn.Dropout(cfg.dropout_input)
296
+ self.dropout_features = nn.Dropout(cfg.dropout_features)
297
+
298
+ self.feature_grad_mult = cfg.feature_grad_mult
299
+
300
+ self.mask_emb = nn.Parameter(
301
+ torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
302
+ )
303
+
304
+ self.encoder = TransformerEncoder(cfg)
305
+ self.layer_norm = LayerNorm(self.embed)
306
+
307
+ def apply_mask(self, x, padding_mask):
308
+ B, T, C = x.shape
309
+ if self.mask_prob > 0:
310
+ mask_indices = compute_mask_indices(
311
+ (B, T),
312
+ padding_mask,
313
+ self.mask_prob,
314
+ self.mask_length,
315
+ self.mask_selection,
316
+ self.mask_other,
317
+ min_masks=2,
318
+ no_overlap=self.no_mask_overlap,
319
+ min_space=self.mask_min_space,
320
+ )
321
+ mask_indices = torch.from_numpy(mask_indices).to(x.device)
322
+ x[mask_indices] = self.mask_emb
323
+ else:
324
+ mask_indices = None
325
+
326
+ if self.mask_channel_prob > 0:
327
+ mask_channel_indices = compute_mask_indices(
328
+ (B, C),
329
+ None,
330
+ self.mask_channel_prob,
331
+ self.mask_channel_length,
332
+ self.mask_channel_selection,
333
+ self.mask_channel_other,
334
+ no_overlap=self.no_mask_channel_overlap,
335
+ min_space=self.mask_channel_min_space,
336
+ )
337
+ mask_channel_indices = (
338
+ torch.from_numpy(mask_channel_indices)
339
+ .to(x.device)
340
+ .unsqueeze(1)
341
+ .expand(-1, T, -1)
342
+ )
343
+ x[mask_channel_indices] = 0
344
+
345
+ return x, mask_indices
346
+
347
+ def forward_padding_mask(
348
+ self,
349
+ features: torch.Tensor,
350
+ padding_mask: torch.Tensor,
351
+ ) -> torch.Tensor:
352
+ extra = padding_mask.size(1) % features.size(1)
353
+ if extra > 0:
354
+ padding_mask = padding_mask[:, :-extra]
355
+ padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
356
+ padding_mask = padding_mask.all(-1)
357
+ return padding_mask
358
+
359
+ def extract_features(
360
+ self,
361
+ source: torch.Tensor,
362
+ padding_mask: Optional[torch.Tensor] = None,
363
+ mask: bool = False,
364
+ ret_conv: bool = False,
365
+ output_layer: Optional[int] = None,
366
+ ret_layer_results: bool = False,
367
+ ):
368
+
369
+ if self.feature_grad_mult > 0:
370
+ features = self.feature_extractor(source)
371
+ if self.feature_grad_mult != 1.0:
372
+ features = GradMultiply.apply(features, self.feature_grad_mult)
373
+ else:
374
+ with torch.no_grad():
375
+ features = self.feature_extractor(source)
376
+
377
+ features = features.transpose(1, 2)
378
+ features = self.layer_norm(features)
379
+
380
+ if padding_mask is not None:
381
+ padding_mask = self.forward_padding_mask(features, padding_mask)
382
+
383
+ if self.post_extract_proj is not None:
384
+ features = self.post_extract_proj(features)
385
+
386
+ features = self.dropout_input(features)
387
+
388
+ if mask:
389
+ x, mask_indices = self.apply_mask(features, padding_mask)
390
+ else:
391
+ x = features
392
+
393
+ # feature: (B, T, D), float
394
+ # target: (B, T), long
395
+ # x: (B, T, D), float
396
+ # padding_mask: (B, T), bool
397
+ # mask_indices: (B, T), bool
398
+ x, layer_results = self.encoder(
399
+ x,
400
+ padding_mask=padding_mask,
401
+ layer=None if output_layer is None else output_layer - 1,
402
+ )
403
+
404
+ res = {
405
+ "x": x,
406
+ "padding_mask": padding_mask,
407
+ "features": features,
408
+ "layer_results": layer_results,
409
+ }
410
+
411
+ feature = res["features"] if ret_conv else res["x"]
412
+ if ret_layer_results:
413
+ feature = (feature, res["layer_results"])
414
+ return feature, res["padding_mask"]
415
+
416
+ def forward(self, x, output_layer=None):
417
+ return self.extract_features(x, output_layer=output_layer)[0]
418
+
419
+ def get_encoding_dim(self):
420
+ return self.cfg.encoder_embed_dim
421
+
422
+ def get_temporal_dim(self, window_size):
423
+ return 2 * window_size - 1
424
+
425
+
426
+ class ConvFeatureExtractionModel(nn.Module):
427
+ def __init__(
428
+ self,
429
+ conv_layers: List[Tuple[int, int, int]],
430
+ dropout: float = 0.0,
431
+ mode: str = "default",
432
+ conv_bias: bool = False,
433
+ conv_type: str = "default",
434
+ ):
435
+ super().__init__()
436
+
437
+ assert mode in {"default", "layer_norm"}
438
+
439
+ def block(
440
+ n_in,
441
+ n_out,
442
+ k,
443
+ stride,
444
+ is_layer_norm=False,
445
+ is_group_norm=False,
446
+ conv_bias=False,
447
+ ):
448
+ def make_conv():
449
+ conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
450
+ nn.init.kaiming_normal_(conv.weight)
451
+ return conv
452
+
453
+ assert (
454
+ is_layer_norm and is_group_norm
455
+ ) == False, "layer norm and group norm are exclusive"
456
+
457
+ if is_layer_norm:
458
+ return nn.Sequential(
459
+ make_conv(),
460
+ nn.Dropout(p=dropout),
461
+ nn.Sequential(
462
+ TransposeLast(),
463
+ Fp32LayerNorm(dim, elementwise_affine=True),
464
+ TransposeLast(),
465
+ ),
466
+ nn.GELU(),
467
+ )
468
+ elif is_group_norm:
469
+ return nn.Sequential(
470
+ make_conv(),
471
+ nn.Dropout(p=dropout),
472
+ Fp32GroupNorm(dim, dim, affine=True),
473
+ nn.GELU(),
474
+ )
475
+ else:
476
+ return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())
477
+
478
+ self.conv_type = conv_type
479
+ if self.conv_type == "default":
480
+ in_d = 1
481
+ self.conv_layers = nn.ModuleList()
482
+ for i, cl in enumerate(conv_layers):
483
+ assert len(cl) == 3, "invalid conv definition: " + str(cl)
484
+ (dim, k, stride) = cl
485
+
486
+ self.conv_layers.append(
487
+ block(
488
+ in_d,
489
+ dim,
490
+ k,
491
+ stride,
492
+ is_layer_norm=mode == "layer_norm",
493
+ is_group_norm=mode == "default" and i == 0,
494
+ conv_bias=conv_bias,
495
+ )
496
+ )
497
+ in_d = dim
498
+ elif self.conv_type == "conv2d":
499
+ in_d = 1
500
+ self.conv_layers = nn.ModuleList()
501
+ for i, cl in enumerate(conv_layers):
502
+ assert len(cl) == 3
503
+ (dim, k, stride) = cl
504
+
505
+ self.conv_layers.append(torch.nn.Conv2d(in_d, dim, k, stride))
506
+ self.conv_layers.append(torch.nn.ReLU())
507
+ in_d = dim
508
+ elif self.conv_type == "custom":
509
+ in_d = 1
510
+ idim = 80
511
+ self.conv_layers = nn.ModuleList()
512
+ for i, cl in enumerate(conv_layers):
513
+ assert len(cl) == 3
514
+ (dim, k, stride) = cl
515
+ self.conv_layers.append(
516
+ torch.nn.Conv2d(in_d, dim, k, stride, padding=1)
517
+ )
518
+ self.conv_layers.append(torch.nn.LayerNorm([dim, idim]))
519
+ self.conv_layers.append(torch.nn.ReLU())
520
+ in_d = dim
521
+ if (i + 1) % 2 == 0:
522
+ self.conv_layers.append(
523
+ torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)
524
+ )
525
+ idim = int(math.ceil(idim / 2))
526
+ else:
527
+ pass
528
+
529
+ def forward(self, x, mask=None):
530
+
531
+ # BxT -> BxCxT
532
+ x = x.unsqueeze(1)
533
+ if self.conv_type == "custom":
534
+ for conv in self.conv_layers:
535
+ if isinstance(conv, nn.LayerNorm):
536
+ x = x.transpose(1, 2)
537
+ x = conv(x).transpose(1, 2)
538
+ else:
539
+ x = conv(x)
540
+ x = x.transpose(2, 3).contiguous()
541
+ x = x.view(x.size(0), -1, x.size(-1))
542
+ else:
543
+ for conv in self.conv_layers:
544
+ x = conv(x)
545
+ if self.conv_type == "conv2d":
546
+ b, c, t, f = x.size()
547
+ x = x.transpose(2, 3).contiguous().view(b, c * f, t)
548
+ return x
549
+
550
+
551
+ class TransformerEncoder(nn.Module):
552
+ def __init__(self, args):
553
+ super().__init__()
554
+
555
+ self.dropout = args.dropout
556
+ self.embedding_dim = args.encoder_embed_dim
557
+
558
+ self.pos_conv = nn.Conv1d(
559
+ self.embedding_dim,
560
+ self.embedding_dim,
561
+ kernel_size=args.conv_pos,
562
+ padding=args.conv_pos // 2,
563
+ groups=args.conv_pos_groups,
564
+ )
565
+ dropout = 0
566
+ std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
567
+ nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
568
+ nn.init.constant_(self.pos_conv.bias, 0)
569
+
570
+ self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
571
+ self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
572
+
573
+ if hasattr(args, "relative_position_embedding"):
574
+ self.relative_position_embedding = args.relative_position_embedding
575
+ self.num_buckets = args.num_buckets
576
+ self.max_distance = args.max_distance
577
+ else:
578
+ self.relative_position_embedding = False
579
+ self.num_buckets = 0
580
+ self.max_distance = 0
581
+
582
+ self.layers = nn.ModuleList(
583
+ [
584
+ TransformerSentenceEncoderLayer(
585
+ embedding_dim=self.embedding_dim,
586
+ ffn_embedding_dim=args.encoder_ffn_embed_dim,
587
+ num_attention_heads=args.encoder_attention_heads,
588
+ dropout=self.dropout,
589
+ attention_dropout=args.attention_dropout,
590
+ activation_dropout=args.activation_dropout,
591
+ activation_fn=args.activation_fn,
592
+ layer_norm_first=args.layer_norm_first,
593
+ has_relative_attention_bias=(
594
+ self.relative_position_embedding and i == 0
595
+ ),
596
+ num_buckets=self.num_buckets,
597
+ max_distance=self.max_distance,
598
+ gru_rel_pos=args.gru_rel_pos,
599
+ )
600
+ for i in range(args.encoder_layers)
601
+ ]
602
+ )
603
+
604
+ self.layer_norm_first = args.layer_norm_first
605
+ self.layer_norm = LayerNorm(self.embedding_dim)
606
+ self.layerdrop = args.encoder_layerdrop
607
+
608
+ self.apply(init_bert_params)
609
+
610
+ def forward(self, x, padding_mask=None, streaming_mask=None, layer=None):
611
+ x, layer_results = self.extract_features(x, padding_mask, streaming_mask, layer)
612
+
613
+ if self.layer_norm_first and layer is None:
614
+ x = self.layer_norm(x)
615
+
616
+ return x, layer_results
617
+
618
+ def extract_features(
619
+ self, x, padding_mask=None, streaming_mask=None, tgt_layer=None
620
+ ):
621
+
622
+ if padding_mask is not None:
623
+ x[padding_mask] = 0
624
+
625
+ x_conv = self.pos_conv(x.transpose(1, 2))
626
+ x_conv = x_conv.transpose(1, 2)
627
+ x = x + x_conv
628
+
629
+ if not self.layer_norm_first:
630
+ x = self.layer_norm(x)
631
+
632
+ x = F.dropout(x, p=self.dropout, training=self.training)
633
+
634
+ # B x T x C -> T x B x C
635
+ x = x.transpose(0, 1)
636
+
637
+ layer_results = []
638
+ z = None
639
+ if tgt_layer is not None:
640
+ layer_results.append((x, z))
641
+ r = None
642
+ pos_bias = None
643
+ for i, layer in enumerate(self.layers):
644
+ dropout_probability = np.random.random()
645
+ if not self.training or (dropout_probability > self.layerdrop):
646
+ x, z, pos_bias = layer(
647
+ x,
648
+ self_attn_padding_mask=padding_mask,
649
+ need_weights=False,
650
+ self_attn_mask=streaming_mask,
651
+ pos_bias=pos_bias,
652
+ )
653
+ if tgt_layer is not None:
654
+ layer_results.append((x, z))
655
+ if i == tgt_layer:
656
+ r = x
657
+ break
658
+
659
+ if r is not None:
660
+ x = r
661
+
662
+ # T x B x C -> B x T x C
663
+ x = x.transpose(0, 1)
664
+
665
+ return x, layer_results
666
+
667
+
668
+ class TransformerSentenceEncoderLayer(nn.Module):
669
+ """
670
+ Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
671
+ models.
672
+ """
673
+
674
+ def __init__(
675
+ self,
676
+ embedding_dim: float = 768,
677
+ ffn_embedding_dim: float = 3072,
678
+ num_attention_heads: float = 8,
679
+ dropout: float = 0.1,
680
+ attention_dropout: float = 0.1,
681
+ activation_dropout: float = 0.1,
682
+ activation_fn: str = "relu",
683
+ layer_norm_first: bool = False,
684
+ has_relative_attention_bias: bool = False,
685
+ num_buckets: int = 0,
686
+ max_distance: int = 0,
687
+ rescale_init: bool = False,
688
+ gru_rel_pos: bool = False,
689
+ ) -> None:
690
+
691
+ super().__init__()
692
+ # Initialize parameters
693
+ self.embedding_dim = embedding_dim
694
+ self.dropout = dropout
695
+ self.activation_dropout = activation_dropout
696
+
697
+ # Initialize blocks
698
+ self.activation_name = activation_fn
699
+ self.activation_fn = get_activation_fn(activation_fn)
700
+ self.self_attn = MultiheadAttention(
701
+ self.embedding_dim,
702
+ num_attention_heads,
703
+ dropout=attention_dropout,
704
+ self_attention=True,
705
+ has_relative_attention_bias=has_relative_attention_bias,
706
+ num_buckets=num_buckets,
707
+ max_distance=max_distance,
708
+ rescale_init=rescale_init,
709
+ gru_rel_pos=gru_rel_pos,
710
+ )
711
+
712
+ self.dropout1 = nn.Dropout(dropout)
713
+ self.dropout2 = nn.Dropout(self.activation_dropout)
714
+ self.dropout3 = nn.Dropout(dropout)
715
+
716
+ self.layer_norm_first = layer_norm_first
717
+
718
+ # layer norm associated with the self attention layer
719
+ self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
720
+
721
+ if self.activation_name == "glu":
722
+ self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
723
+ else:
724
+ self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
725
+ self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
726
+
727
+ # layer norm associated with the position wise feed-forward NN
728
+ self.final_layer_norm = LayerNorm(self.embedding_dim)
729
+
730
+ def forward(
731
+ self,
732
+ x: torch.Tensor,
733
+ self_attn_mask: torch.Tensor = None,
734
+ self_attn_padding_mask: torch.Tensor = None,
735
+ need_weights: bool = False,
736
+ pos_bias=None,
737
+ ):
738
+ """
739
+ LayerNorm is applied either before or after the self-attention/ffn
740
+ modules similar to the original Transformer implementation.
741
+ """
742
+ residual = x
743
+
744
+ if self.layer_norm_first:
745
+ x = self.self_attn_layer_norm(x)
746
+ x, attn, pos_bias = self.self_attn(
747
+ query=x,
748
+ key=x,
749
+ value=x,
750
+ key_padding_mask=self_attn_padding_mask,
751
+ need_weights=False,
752
+ attn_mask=self_attn_mask,
753
+ position_bias=pos_bias,
754
+ )
755
+ x = self.dropout1(x)
756
+ x = residual + x
757
+
758
+ residual = x
759
+ x = self.final_layer_norm(x)
760
+ if self.activation_name == "glu":
761
+ x = self.fc1(x)
762
+ else:
763
+ x = self.activation_fn(self.fc1(x))
764
+ x = self.dropout2(x)
765
+ x = self.fc2(x)
766
+ x = self.dropout3(x)
767
+ x = residual + x
768
+ else:
769
+ x, attn, pos_bias = self.self_attn(
770
+ query=x,
771
+ key=x,
772
+ value=x,
773
+ key_padding_mask=self_attn_padding_mask,
774
+ need_weights=need_weights,
775
+ attn_mask=self_attn_mask,
776
+ position_bias=pos_bias,
777
+ )
778
+
779
+ x = self.dropout1(x)
780
+ x = residual + x
781
+
782
+ x = self.self_attn_layer_norm(x)
783
+
784
+ residual = x
785
+ if self.activation_name == "glu":
786
+ x = self.fc1(x)
787
+ else:
788
+ x = self.activation_fn(self.fc1(x))
789
+ x = self.dropout2(x)
790
+ x = self.fc2(x)
791
+ x = self.dropout3(x)
792
+ x = residual + x
793
+ x = self.final_layer_norm(x)
794
+
795
+ return x, attn, pos_bias
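
Note: the vendored WavLM can also be instantiated standalone with its default (base) configuration and random weights; a quick, hypothetical shape check.

    import torch
    from manipulate_model.encoder.wavlm.WavLM import WavLM, WavLMConfig

    cfg = WavLMConfig()                       # 12 layers, 768-dim embeddings, conv front end
    model = WavLM(cfg).eval()
    wav = torch.randn(1, 64000)               # 4 s of 16 kHz audio
    with torch.no_grad():
        feats = model(wav, output_layer=3)    # hidden states after transformer layer 3
    print(feats.shape)                        # roughly (1, 199, 768), ~50 frames per second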
manipulate_model/encoder/wavlm/modules.py ADDED
@@ -0,0 +1,827 @@
1
+ # --------------------------------------------------------
2
+ # WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/wavlm
4
+ # Copyright (c) 2021 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+ import math
11
+ import warnings
12
+ from typing import Dict, Optional, Tuple
13
+ import torch
14
+ from torch import Tensor, nn
15
+ from torch.nn import Parameter
16
+ import torch.nn.functional as F
17
+
18
+
19
+ class TransposeLast(nn.Module):
20
+ def __init__(self, deconstruct_idx=None):
21
+ super().__init__()
22
+ self.deconstruct_idx = deconstruct_idx
23
+
24
+ def forward(self, x):
25
+ if self.deconstruct_idx is not None:
26
+ x = x[self.deconstruct_idx]
27
+ return x.transpose(-2, -1)
28
+
29
+
30
+ class Fp32LayerNorm(nn.LayerNorm):
31
+ def __init__(self, *args, **kwargs):
32
+ super().__init__(*args, **kwargs)
33
+
34
+ def forward(self, input):
35
+ output = F.layer_norm(
36
+ input.float(),
37
+ self.normalized_shape,
38
+ self.weight.float() if self.weight is not None else None,
39
+ self.bias.float() if self.bias is not None else None,
40
+ self.eps,
41
+ )
42
+ return output.type_as(input)
43
+
44
+
45
+ class Fp32GroupNorm(nn.GroupNorm):
46
+ def __init__(self, *args, **kwargs):
47
+ super().__init__(*args, **kwargs)
48
+
49
+ def forward(self, input):
50
+ output = F.group_norm(
51
+ input.float(),
52
+ self.num_groups,
53
+ self.weight.float() if self.weight is not None else None,
54
+ self.bias.float() if self.bias is not None else None,
55
+ self.eps,
56
+ )
57
+ return output.type_as(input)
58
+
59
+
60
+ class GradMultiply(torch.autograd.Function):
61
+ @staticmethod
62
+ def forward(ctx, x, scale):
63
+ ctx.scale = scale
64
+ res = x.new(x)
65
+ return res
66
+
67
+ @staticmethod
68
+ def backward(ctx, grad):
69
+ return grad * ctx.scale, None
70
+
71
+
72
+ class SamePad(nn.Module):
73
+ def __init__(self, kernel_size, causal=False):
74
+ super().__init__()
75
+ if causal:
76
+ self.remove = kernel_size - 1
77
+ else:
78
+ self.remove = 1 if kernel_size % 2 == 0 else 0
79
+
80
+ def forward(self, x):
81
+ if self.remove > 0:
82
+ x = x[:, :, : -self.remove]
83
+ return x
84
+
85
+
86
+ class Swish(nn.Module):
87
+ """Swish function
88
+ """
89
+
90
+ def __init__(self):
91
+ """Construct an MultiHeadedAttention object."""
92
+ super(Swish, self).__init__()
93
+ self.act = torch.nn.Sigmoid()
94
+
95
+ def forward(self, x):
96
+ return x * self.act(x)
97
+
98
+
99
+ class GLU_Linear(nn.Module):
100
+ def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
101
+ super(GLU_Linear, self).__init__()
102
+
103
+ self.glu_type = glu_type
104
+ self.output_dim = output_dim
105
+
106
+ if glu_type == "sigmoid":
107
+ self.glu_act = torch.nn.Sigmoid()
108
+ elif glu_type == "swish":
109
+ self.glu_act = Swish()
110
+ elif glu_type == "relu":
111
+ self.glu_act = torch.nn.ReLU()
112
+ elif glu_type == "gelu":
113
+ self.glu_act = torch.nn.GELU()
114
+
115
+ if bias_in_glu:
116
+ self.linear = nn.Linear(input_dim, output_dim * 2, True)
117
+ else:
118
+ self.linear = nn.Linear(input_dim, output_dim * 2, False)
119
+
120
+ def forward(self, x):
121
+ # to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case
122
+ x = self.linear(x)
123
+
124
+ if self.glu_type == "bilinear":
125
+ x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2])
126
+ else:
127
+ x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2]))
128
+
129
+ return x
130
+
131
+
132
+ def gelu_accurate(x):
133
+ if not hasattr(gelu_accurate, "_a"):
134
+ gelu_accurate._a = math.sqrt(2 / math.pi)
135
+ return (
136
+ 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
137
+ )
138
+
139
+
140
+ def gelu(x: torch.Tensor) -> torch.Tensor:
141
+ return torch.nn.functional.gelu(x.float()).type_as(x)
142
+
143
+
144
+ def get_activation_fn(activation: str):
145
+ """Returns the activation function corresponding to `activation`"""
146
+
147
+ if activation == "relu":
148
+ return F.relu
149
+ elif activation == "gelu":
150
+ return gelu
151
+ elif activation == "gelu_fast":
152
+ warnings.warn(
153
+ "--activation-fn=gelu_fast has been renamed to gelu_accurate"
154
+ )
155
+ return gelu_accurate
156
+ elif activation == "gelu_accurate":
157
+ return gelu_accurate
158
+ elif activation == "tanh":
159
+ return torch.tanh
160
+ elif activation == "linear":
161
+ return lambda x: x
162
+ elif activation == "glu":
163
+ return lambda x: x
164
+ else:
165
+ raise RuntimeError("--activation-fn {} not supported".format(activation))
166
+
167
+
168
+ def init_bert_params(module):
169
+ """
170
+ Initialize the weights specific to the BERT Model.
171
+ This overrides the default initializations depending on the specified arguments.
172
+ 1. If normal_init_linear_weights is set then weights of linear
173
+ layer will be initialized using the normal distribution and
174
+ bias will be set to the specified value.
175
+ 2. If normal_init_embed_weights is set then weights of embedding
176
+ layer will be initialized using the normal distribution.
177
+ 3. If normal_init_proj_weights is set then weights of
178
+ in_project_weight for MultiHeadAttention initialized using
179
+ the normal distribution (to be validated).
180
+ """
181
+
182
+ def normal_(data):
183
+ # with FSDP, module params will be on CUDA, so we cast them back to CPU
184
+ # so that the RNG is consistent with and without FSDP
185
+ data.copy_(
186
+ data.cpu().normal_(mean=0.0, std=0.02).to(data.device)
187
+ )
188
+
189
+ if isinstance(module, nn.Linear):
190
+ normal_(module.weight.data)
191
+ if module.bias is not None:
192
+ module.bias.data.zero_()
193
+ if isinstance(module, nn.Embedding):
194
+ normal_(module.weight.data)
195
+ if module.padding_idx is not None:
196
+ module.weight.data[module.padding_idx].zero_()
197
+ if isinstance(module, MultiheadAttention):
198
+ normal_(module.q_proj.weight.data)
199
+ normal_(module.k_proj.weight.data)
200
+ normal_(module.v_proj.weight.data)
201
+
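# init_bert_params is meant to be handed to nn.Module.apply, which calls it on
# every submodule: Linear and Embedding weights are re-drawn from N(0, 0.02),
# biases and padding rows are zeroed, and the MultiheadAttention class defined
# below gets its q/k/v projections re-initialised. The toy stack is purely
# illustrative and assumes init_bert_params from this file is in scope.
import torch.nn as nn

toy = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 2))
toy.apply(init_bert_params)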
202
+
203
+ def quant_noise(module, p, block_size):
204
+ """
205
+ Wraps modules and applies quantization noise to the weights for
206
+ subsequent quantization with Iterative Product Quantization as
207
+ described in "Training with Quantization Noise for Extreme Model Compression"
208
+
209
+ Args:
210
+ - module: nn.Module
211
+ - p: amount of Quantization Noise
212
+ - block_size: size of the blocks for subsequent quantization with iPQ
213
+
214
+ Remarks:
215
+ - Module weights must have the right sizes wrt the block size
216
+ - Only Linear, Embedding and Conv2d modules are supported for the moment
217
+ - For more detail on how to quantize by blocks with convolutional weights,
218
+ see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
219
+ - We implement the simplest form of noise here as stated in the paper
220
+ which consists of randomly dropping blocks
221
+ """
222
+
223
+ # if no quantization noise, don't register hook
224
+ if p <= 0:
225
+ return module
226
+
227
+ # supported modules
228
+ assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
229
+
230
+ # test whether module.weight has the right sizes wrt block_size
231
+ is_conv = module.weight.ndim == 4
232
+
233
+ # 2D matrix
234
+ if not is_conv:
235
+ assert (
236
+ module.weight.size(1) % block_size == 0
237
+ ), "Input features must be a multiple of block sizes"
238
+
239
+ # 4D matrix
240
+ else:
241
+ # 1x1 convolutions
242
+ if module.kernel_size == (1, 1):
243
+ assert (
244
+ module.in_channels % block_size == 0
245
+ ), "Input channels must be a multiple of block sizes"
246
+ # regular convolutions
247
+ else:
248
+ k = module.kernel_size[0] * module.kernel_size[1]
249
+ assert k % block_size == 0, "Kernel size must be a multiple of block size"
250
+
251
+ def _forward_pre_hook(mod, input):
252
+ # no noise for evaluation
253
+ if mod.training:
254
+ if not is_conv:
255
+ # gather weight and sizes
256
+ weight = mod.weight
257
+ in_features = weight.size(1)
258
+ out_features = weight.size(0)
259
+
260
+ # split weight matrix into blocks and randomly drop selected blocks
261
+ mask = torch.zeros(
262
+ in_features // block_size * out_features, device=weight.device
263
+ )
264
+ mask.bernoulli_(p)
265
+ mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
266
+
267
+ else:
268
+ # gather weight and sizes
269
+ weight = mod.weight
270
+ in_channels = mod.in_channels
271
+ out_channels = mod.out_channels
272
+
273
+ # split weight matrix into blocks and randomly drop selected blocks
274
+ if mod.kernel_size == (1, 1):
275
+ mask = torch.zeros(
276
+ int(in_channels // block_size * out_channels),
277
+ device=weight.device,
278
+ )
279
+ mask.bernoulli_(p)
280
+ mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
281
+ else:
282
+ mask = torch.zeros(
283
+ weight.size(0), weight.size(1), device=weight.device
284
+ )
285
+ mask.bernoulli_(p)
286
+ mask = (
287
+ mask.unsqueeze(2)
288
+ .unsqueeze(3)
289
+ .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
290
+ )
291
+
292
+ # scale weights and apply mask
293
+ mask = mask.to(
294
+ torch.bool
295
+ ) # x.bool() is not currently supported in TorchScript
296
+ s = 1 / (1 - p)
297
+ mod.weight.data = s * weight.masked_fill(mask, 0)
298
+
299
+ module.register_forward_pre_hook(_forward_pre_hook)
300
+ return module
301
+
302
+
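# quant_noise registers a forward pre-hook that, in training mode only, zeroes
# random weight blocks of size block_size and rescales the remainder by
# 1 / (1 - p). A small usage sketch (sizes illustrative; in_features must be a
# multiple of block_size, and quant_noise from this file is assumed in scope):
import torch
import torch.nn as nn

layer = quant_noise(nn.Linear(64, 128), p=0.1, block_size=8)
layer.train()
_ = layer(torch.randn(4, 64))   # hook drops random weight blocks, rescales by 1/(1-p)
layer.eval()
_ = layer(torch.randn(4, 64))   # hook does nothing outside training mode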
303
+ class MultiheadAttention(nn.Module):
304
+ """Multi-headed attention.
305
+
306
+ See "Attention Is All You Need" for more details.
307
+ """
308
+
309
+ def __init__(
310
+ self,
311
+ embed_dim,
312
+ num_heads,
313
+ kdim=None,
314
+ vdim=None,
315
+ dropout=0.0,
316
+ bias=True,
317
+ add_bias_kv=False,
318
+ add_zero_attn=False,
319
+ self_attention=False,
320
+ encoder_decoder_attention=False,
321
+ q_noise=0.0,
322
+ qn_block_size=8,
323
+ has_relative_attention_bias=False,
324
+ num_buckets=32,
325
+ max_distance=128,
326
+ gru_rel_pos=False,
327
+ rescale_init=False,
328
+ ):
329
+ super().__init__()
330
+ self.embed_dim = embed_dim
331
+ self.kdim = kdim if kdim is not None else embed_dim
332
+ self.vdim = vdim if vdim is not None else embed_dim
333
+ self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
334
+
335
+ self.num_heads = num_heads
336
+ self.dropout_module = nn.Dropout(dropout)
337
+
338
+ self.has_relative_attention_bias = has_relative_attention_bias
339
+ self.num_buckets = num_buckets
340
+ self.max_distance = max_distance
341
+ if self.has_relative_attention_bias:
342
+ self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
343
+
344
+ self.head_dim = embed_dim // num_heads
345
+ self.q_head_dim = self.head_dim
346
+ self.k_head_dim = self.head_dim
347
+ assert (
348
+ self.head_dim * num_heads == self.embed_dim
349
+ ), "embed_dim must be divisible by num_heads"
350
+ self.scaling = self.head_dim ** -0.5
351
+
352
+ self.self_attention = self_attention
353
+ self.encoder_decoder_attention = encoder_decoder_attention
354
+
355
+ assert not self.self_attention or self.qkv_same_dim, (
356
+ "Self-attention requires query, key and " "value to be of the same size"
357
+ )
358
+
359
+ k_bias = True
360
+ if rescale_init:
361
+ k_bias = False
362
+
363
+ k_embed_dim = embed_dim
364
+ q_embed_dim = embed_dim
365
+
366
+ self.k_proj = quant_noise(
367
+ nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size
368
+ )
369
+ self.v_proj = quant_noise(
370
+ nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
371
+ )
372
+ self.q_proj = quant_noise(
373
+ nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size
374
+ )
375
+
376
+ self.out_proj = quant_noise(
377
+ nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
378
+ )
379
+
380
+ if add_bias_kv:
381
+ self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
382
+ self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
383
+ else:
384
+ self.bias_k = self.bias_v = None
385
+
386
+ self.add_zero_attn = add_zero_attn
387
+
388
+ self.gru_rel_pos = gru_rel_pos
389
+ if self.gru_rel_pos:
390
+ self.grep_linear = nn.Linear(self.q_head_dim, 8)
391
+ self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))
392
+
393
+ self.reset_parameters()
394
+
395
+ def reset_parameters(self):
396
+ if self.qkv_same_dim:
397
+ # Empirically observed the convergence to be much better with
398
+ # the scaled initialization
399
+ nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
400
+ nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
401
+ nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
402
+ else:
403
+ nn.init.xavier_uniform_(self.k_proj.weight)
404
+ nn.init.xavier_uniform_(self.v_proj.weight)
405
+ nn.init.xavier_uniform_(self.q_proj.weight)
406
+
407
+ nn.init.xavier_uniform_(self.out_proj.weight)
408
+ if self.out_proj.bias is not None:
409
+ nn.init.constant_(self.out_proj.bias, 0.0)
410
+ if self.bias_k is not None:
411
+ nn.init.xavier_normal_(self.bias_k)
412
+ if self.bias_v is not None:
413
+ nn.init.xavier_normal_(self.bias_v)
414
+ if self.has_relative_attention_bias:
415
+ nn.init.xavier_normal_(self.relative_attention_bias.weight)
416
+
417
+ def _relative_positions_bucket(self, relative_positions, bidirectional=True):
418
+ num_buckets = self.num_buckets
419
+ max_distance = self.max_distance
420
+ relative_buckets = 0
421
+
422
+ if bidirectional:
423
+ num_buckets = num_buckets // 2
424
+ relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
425
+ relative_positions = torch.abs(relative_positions)
426
+ else:
427
+ relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))
428
+
429
+ max_exact = num_buckets // 2
430
+ is_small = relative_positions < max_exact
431
+
432
+ relative_postion_if_large = max_exact + (
433
+ torch.log(relative_positions.float() / max_exact)
434
+ / math.log(max_distance / max_exact)
435
+ * (num_buckets - max_exact)
436
+ ).to(torch.long)
437
+ relative_postion_if_large = torch.min(
438
+ relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
439
+ )
440
+
441
+ relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large)
442
+ return relative_buckets
443
+
444
+ def compute_bias(self, query_length, key_length):
445
+ context_position = torch.arange(query_length, dtype=torch.long)[:, None]
446
+ memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
447
+ relative_position = memory_position - context_position
448
+ relative_position_bucket = self._relative_positions_bucket(
449
+ relative_position,
450
+ bidirectional=True
451
+ )
452
+ relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
453
+ values = self.relative_attention_bias(relative_position_bucket)
454
+ values = values.permute([2, 0, 1])
455
+ return values
456
+
457
+ def forward(
458
+ self,
459
+ query,
460
+ key: Optional[Tensor],
461
+ value: Optional[Tensor],
462
+ key_padding_mask: Optional[Tensor] = None,
463
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
464
+ need_weights: bool = True,
465
+ static_kv: bool = False,
466
+ attn_mask: Optional[Tensor] = None,
467
+ before_softmax: bool = False,
468
+ need_head_weights: bool = False,
469
+ position_bias: Optional[Tensor] = None
470
+ ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
471
+ """Input shape: Time x Batch x Channel
472
+
473
+ Args:
474
+ key_padding_mask (ByteTensor, optional): mask to exclude
475
+ keys that are pads, of shape `(batch, src_len)`, where
476
+ padding elements are indicated by 1s.
477
+ need_weights (bool, optional): return the attention weights,
478
+ averaged over heads (default: True).
479
+ attn_mask (ByteTensor, optional): typically used to
480
+ implement causal attention, where the mask prevents the
481
+ attention from looking forward in time (default: None).
482
+ before_softmax (bool, optional): return the raw attention
483
+ weights and values before the attention softmax.
484
+ need_head_weights (bool, optional): return the attention
485
+ weights for each head. Implies *need_weights*. Default:
486
+ return the average attention weights over all heads.
487
+ """
488
+ if need_head_weights:
489
+ need_weights = True
490
+
491
+ is_tpu = query.device.type == "xla"
492
+
493
+ tgt_len, bsz, embed_dim = query.size()
494
+ src_len = tgt_len
495
+ assert embed_dim == self.embed_dim
496
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
497
+ if key is not None:
498
+ src_len, key_bsz, _ = key.size()
499
+ if not torch.jit.is_scripting():
500
+ assert key_bsz == bsz
501
+ assert value is not None
502
+ assert (src_len, bsz) == value.shape[:2]
503
+
504
+ if self.has_relative_attention_bias and position_bias is None:
505
+ position_bias = self.compute_bias(tgt_len, src_len)
506
+ position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)
507
+
508
+ if (
509
+ not is_tpu # don't use PyTorch version on TPUs
510
+ and incremental_state is None
511
+ and not static_kv
512
+ # A workaround for quantization to work. Otherwise JIT compilation
513
+ # treats bias in linear module as method.
514
+ and not torch.jit.is_scripting()
515
+ and self.q_head_dim == self.head_dim
516
+ ):
517
+ assert key is not None and value is not None
518
+ assert attn_mask is None
519
+
520
+ attn_mask_rel_pos = None
521
+ if position_bias is not None:
522
+ attn_mask_rel_pos = position_bias
523
+ if self.gru_rel_pos:
524
+ query_layer = query.transpose(0, 1)
525
+ new_x_shape = query_layer.size()[:-1] + (self.num_heads, -1)
526
+ query_layer = query_layer.view(*new_x_shape)
527
+ query_layer = query_layer.permute(0, 2, 1, 3)
528
+ _B, _H, _L, __ = query_layer.size()
529
+
530
+ gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
531
+ _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
532
+ gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
533
+ attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
534
+
535
+ attn_mask_rel_pos = attn_mask_rel_pos.view((-1, tgt_len, tgt_len))
536
+ k_proj_bias = self.k_proj.bias
537
+ if k_proj_bias is None:
538
+ k_proj_bias = torch.zeros_like(self.q_proj.bias)
539
+
540
+ x, attn = F.multi_head_attention_forward(
541
+ query,
542
+ key,
543
+ value,
544
+ self.embed_dim,
545
+ self.num_heads,
546
+ torch.empty([0]),
547
+ torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
548
+ self.bias_k,
549
+ self.bias_v,
550
+ self.add_zero_attn,
551
+ self.dropout_module.p,
552
+ self.out_proj.weight,
553
+ self.out_proj.bias,
554
+ self.training,
555
+ # self.training or self.dropout_module.apply_during_inference,
556
+ key_padding_mask,
557
+ need_weights,
558
+ attn_mask_rel_pos,
559
+ use_separate_proj_weight=True,
560
+ q_proj_weight=self.q_proj.weight,
561
+ k_proj_weight=self.k_proj.weight,
562
+ v_proj_weight=self.v_proj.weight,
563
+ )
564
+ return x, attn, position_bias
565
+
566
+ if incremental_state is not None:
567
+ saved_state = self._get_input_buffer(incremental_state)
568
+ if saved_state is not None and "prev_key" in saved_state:
569
+ # previous time steps are cached - no need to recompute
570
+ # key and value if they are static
571
+ if static_kv:
572
+ assert self.encoder_decoder_attention and not self.self_attention
573
+ key = value = None
574
+ else:
575
+ saved_state = None
576
+
577
+ if self.self_attention:
578
+ q = self.q_proj(query)
579
+ k = self.k_proj(query)
580
+ v = self.v_proj(query)
581
+ elif self.encoder_decoder_attention:
582
+ # encoder-decoder attention
583
+ q = self.q_proj(query)
584
+ if key is None:
585
+ assert value is None
586
+ k = v = None
587
+ else:
588
+ k = self.k_proj(key)
589
+ v = self.v_proj(key)
590
+
591
+ else:
592
+ assert key is not None and value is not None
593
+ q = self.q_proj(query)
594
+ k = self.k_proj(key)
595
+ v = self.v_proj(value)
596
+ q *= self.scaling
597
+
598
+ if self.bias_k is not None:
599
+ assert self.bias_v is not None
600
+ k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
601
+ v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
602
+ if attn_mask is not None:
603
+ attn_mask = torch.cat(
604
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
605
+ )
606
+ if key_padding_mask is not None:
607
+ key_padding_mask = torch.cat(
608
+ [
609
+ key_padding_mask,
610
+ key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
611
+ ],
612
+ dim=1,
613
+ )
614
+
615
+ q = (
616
+ q.contiguous()
617
+ .view(tgt_len, bsz * self.num_heads, self.q_head_dim)
618
+ .transpose(0, 1)
619
+ )
620
+ if k is not None:
621
+ k = (
622
+ k.contiguous()
623
+ .view(-1, bsz * self.num_heads, self.k_head_dim)
624
+ .transpose(0, 1)
625
+ )
626
+ if v is not None:
627
+ v = (
628
+ v.contiguous()
629
+ .view(-1, bsz * self.num_heads, self.head_dim)
630
+ .transpose(0, 1)
631
+ )
632
+
633
+ if saved_state is not None:
634
+ # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
635
+ if "prev_key" in saved_state:
636
+ _prev_key = saved_state["prev_key"]
637
+ assert _prev_key is not None
638
+ prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
639
+ if static_kv:
640
+ k = prev_key
641
+ else:
642
+ assert k is not None
643
+ k = torch.cat([prev_key, k], dim=1)
644
+ src_len = k.size(1)
645
+ if "prev_value" in saved_state:
646
+ _prev_value = saved_state["prev_value"]
647
+ assert _prev_value is not None
648
+ prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
649
+ if static_kv:
650
+ v = prev_value
651
+ else:
652
+ assert v is not None
653
+ v = torch.cat([prev_value, v], dim=1)
654
+ prev_key_padding_mask: Optional[Tensor] = None
655
+ if "prev_key_padding_mask" in saved_state:
656
+ prev_key_padding_mask = saved_state["prev_key_padding_mask"]
657
+ assert k is not None and v is not None
658
+ key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
659
+ key_padding_mask=key_padding_mask,
660
+ prev_key_padding_mask=prev_key_padding_mask,
661
+ batch_size=bsz,
662
+ src_len=k.size(1),
663
+ static_kv=static_kv,
664
+ )
665
+
666
+ saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
667
+ saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
668
+ saved_state["prev_key_padding_mask"] = key_padding_mask
669
+ # In this branch incremental_state is never None
670
+ assert incremental_state is not None
671
+ incremental_state = self._set_input_buffer(incremental_state, saved_state)
672
+ assert k is not None
673
+ assert k.size(1) == src_len
674
+
675
+ # This is part of a workaround to get around fork/join parallelism
676
+ # not supporting Optional types.
677
+ if key_padding_mask is not None and key_padding_mask.dim() == 0:
678
+ key_padding_mask = None
679
+
680
+ if key_padding_mask is not None:
681
+ assert key_padding_mask.size(0) == bsz
682
+ assert key_padding_mask.size(1) == src_len
683
+
684
+ if self.add_zero_attn:
685
+ assert v is not None
686
+ src_len += 1
687
+ k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
688
+ v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
689
+ if attn_mask is not None:
690
+ attn_mask = torch.cat(
691
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
692
+ )
693
+ if key_padding_mask is not None:
694
+ key_padding_mask = torch.cat(
695
+ [
696
+ key_padding_mask,
697
+ torch.zeros(key_padding_mask.size(0), 1).type_as(
698
+ key_padding_mask
699
+ ),
700
+ ],
701
+ dim=1,
702
+ )
703
+
704
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
705
+ attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
706
+
707
+ assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
708
+
709
+ if attn_mask is not None:
710
+ attn_mask = attn_mask.unsqueeze(0)
711
+ attn_weights += attn_mask
712
+
713
+ if key_padding_mask is not None:
714
+ # don't attend to padding symbols
715
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
716
+ if not is_tpu:
717
+ attn_weights = attn_weights.masked_fill(
718
+ key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
719
+ float("-inf"),
720
+ )
721
+ else:
722
+ attn_weights = attn_weights.transpose(0, 2)
723
+ attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
724
+ attn_weights = attn_weights.transpose(0, 2)
725
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
726
+
727
+ if before_softmax:
728
+ return attn_weights, v, position_bias
729
+
730
+ if position_bias is not None:
731
+ if self.gru_rel_pos == 1:
732
+ query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim)
733
+ _B, _H, _L, __ = query_layer.size()
734
+ gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
735
+ _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
736
+ gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
737
+ position_bias = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
738
+
739
+ position_bias = position_bias.view(attn_weights.size())
740
+
741
+ attn_weights = attn_weights + position_bias
742
+
743
+ attn_weights_float = F.softmax(
744
+ attn_weights, dim=-1
745
+ )
746
+ attn_weights = attn_weights_float.type_as(attn_weights)
747
+ attn_probs = self.dropout_module(attn_weights)
748
+
749
+ assert v is not None
750
+ attn = torch.bmm(attn_probs, v)
751
+ assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
752
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
753
+ attn = self.out_proj(attn)
754
+ attn_weights: Optional[Tensor] = None
755
+ if need_weights:
756
+ attn_weights = attn_weights_float.view(
757
+ bsz, self.num_heads, tgt_len, src_len
758
+ ).transpose(1, 0)
759
+ if not need_head_weights:
760
+ # average attention weights over heads
761
+ attn_weights = attn_weights.mean(dim=0)
762
+
763
+ return attn, attn_weights, position_bias
764
+
765
+ @staticmethod
766
+ def _append_prev_key_padding_mask(
767
+ key_padding_mask: Optional[Tensor],
768
+ prev_key_padding_mask: Optional[Tensor],
769
+ batch_size: int,
770
+ src_len: int,
771
+ static_kv: bool,
772
+ ) -> Optional[Tensor]:
773
+ # saved key padding masks have shape (bsz, seq_len)
774
+ if prev_key_padding_mask is not None and static_kv:
775
+ new_key_padding_mask = prev_key_padding_mask
776
+ elif prev_key_padding_mask is not None and key_padding_mask is not None:
777
+ new_key_padding_mask = torch.cat(
778
+ [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
779
+ )
780
+ # During incremental decoding, as the padding token enters and
781
+ # leaves the frame, there will be a time when prev or current
782
+ # is None
783
+ elif prev_key_padding_mask is not None:
784
+ if src_len > prev_key_padding_mask.size(1):
785
+ filler = torch.zeros(
786
+ (batch_size, src_len - prev_key_padding_mask.size(1)),
787
+ device=prev_key_padding_mask.device,
788
+ )
789
+ new_key_padding_mask = torch.cat(
790
+ [prev_key_padding_mask.float(), filler.float()], dim=1
791
+ )
792
+ else:
793
+ new_key_padding_mask = prev_key_padding_mask.float()
794
+ elif key_padding_mask is not None:
795
+ if src_len > key_padding_mask.size(1):
796
+ filler = torch.zeros(
797
+ (batch_size, src_len - key_padding_mask.size(1)),
798
+ device=key_padding_mask.device,
799
+ )
800
+ new_key_padding_mask = torch.cat(
801
+ [filler.float(), key_padding_mask.float()], dim=1
802
+ )
803
+ else:
804
+ new_key_padding_mask = key_padding_mask.float()
805
+ else:
806
+ new_key_padding_mask = prev_key_padding_mask
807
+ return new_key_padding_mask
808
+
809
+ def _get_input_buffer(
810
+ self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
811
+ ) -> Dict[str, Optional[Tensor]]:
812
+ result = self.get_incremental_state(incremental_state, "attn_state")
813
+ if result is not None:
814
+ return result
815
+ else:
816
+ empty_result: Dict[str, Optional[Tensor]] = {}
817
+ return empty_result
818
+
819
+ def _set_input_buffer(
820
+ self,
821
+ incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
822
+ buffer: Dict[str, Optional[Tensor]],
823
+ ):
824
+ return self.set_incremental_state(incremental_state, "attn_state", buffer)
825
+
826
+ def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
827
+ return attn_weights
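Since this attention module departs from torch.nn.MultiheadAttention (relative position buckets, optional gated relative position bias), a short usage sketch may help. It assumes the class above is importable, uses arbitrary dimensions, and follows the Time x Batch x Channel convention stated in the forward docstring.

import torch

mha = MultiheadAttention(
    embed_dim=256,
    num_heads=8,
    self_attention=True,
    has_relative_attention_bias=True,
    num_buckets=32,
    max_distance=128,
)
x = torch.randn(50, 2, 256)                       # (tgt_len, bsz, embed_dim)
attn_out, attn_weights, pos_bias = mha(query=x, key=x, value=x)
print(attn_out.shape)                             # torch.Size([50, 2, 256])
# pos_bias has shape (bsz * num_heads, tgt_len, src_len) and can be passed back
# in as position_bias= so later layers reuse the same relative-position bias.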
manipulate_model/model.py ADDED
@@ -0,0 +1,25 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from manipulate_model.encoder.encoder import Encoder
5
+ from manipulate_model.decoder.decoder import Decoder
6
+
7
+
8
+ class Model(nn.Module):
9
+ def __init__(self, config):
10
+ super(Model, self).__init__()
11
+ self.config = config
12
+
13
+ self.encoder = Encoder(self.config)
14
+ self.config.model.decoder.temporal_dim = self.encoder.get_temporal_dim()
15
+ self.config.model.decoder.encoding_dim = self.encoder.get_encoding_dim()
16
+ self.decoder = Decoder(self.config)
17
+
18
+ def forward(self, x):
19
+ if self.config.model.encoder_freeze:
20
+ with torch.no_grad():
21
+ x = self.encoder(x)
22
+ else:
23
+ x = self.encoder(x)
24
+ x = self.decoder(x)
25
+ return x
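The encoder_freeze branch is the usual frozen-feature-extractor pattern: the forward pass simply runs the encoder without tracking gradients. A stricter variant (not what the code above does, shown only for contrast) also switches off requires_grad on the encoder parameters so the optimiser never sees them:

import torch.nn as nn

def freeze(module: nn.Module) -> None:
    # Detach the module's parameters from optimisation entirely,
    # e.g. freeze(model.encoder) before building the optimiser.
    for p in module.parameters():
        p.requires_grad = False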
manipulate_model/utils.py ADDED
@@ -0,0 +1,139 @@
1
+ import os
2
+ import torch
3
+ import torchaudio
4
+ import numpy as np
5
+ from omegaconf import OmegaConf
6
+ from torchvision.io import read_video
7
+ from torch.nn.functional import pad, normalize, softmax
8
+
9
+ from manipulate_model.model import Model
10
+
11
+
12
+
13
+ def get_config_and_model(model_root="manipulate_model/demo-model/audio"):
14
+ config_path = os.path.join(model_root, "config.yaml")
15
+ config = OmegaConf.load(config_path)
16
+ if isinstance(config.model.encoder, str):
17
+ config.model.encoder = OmegaConf.load(config.model.encoder)
18
+ if isinstance(config.model.decoder, str):
19
+ config.model.decoder = OmegaConf.load(config.model.decoder)
20
+
21
+ model = Model(config)
22
+ #weights = torch.load(os.path.join(model_root, "weights.pt"))
23
+ #model.load_state_dict(weights["model_state_dict"])
24
+
25
+ return config, model
26
+
27
+
28
+ def load_audio(file_path, config):
29
+ # Load audio
30
+ # Parameters
31
+ # ----------
32
+ # file_path : str
33
+ # Path to audio file
34
+ # Returns
35
+ # -------
36
+ # torch.Tensor
37
+
38
+ audio = None
39
+
40
+ if file_path.endswith(".wav") or file_path.endswith(".flac"):
41
+ audio, sample_rate = torchaudio.load(file_path)
42
+ elif file_path.endswith(".mp3"):
43
+ pass
44
+ elif file_path.endswith(".mp4"):
45
+ #_, audio, _ = read_video(file_path)
46
+ pass
47
+
48
+ return preprocess_audio(audio, config)
49
+
50
+
51
+ def preprocess_audio(audio, config, step_size=1):
52
+ # Preprocess audio
53
+ # Parameters
54
+ # ----------
55
+ # audio : torch.Tensor
56
+ # Audio signal
57
+ # config : OmegaConf
58
+ # Configuration object
59
+ # Returns
60
+ # -------
61
+ # torch.Tensor : Stack of normalized, fixed-length audio windows
62
+
63
+ window_size = config.data.window_size
64
+ sr = config.data.sr
65
+ fps = config.data.fps
66
+
67
+ audio_len = audio.shape[1]
68
+ step_size = step_size * (sr // fps)
69
+ window_size = window_size * (sr // fps)
70
+ audio = pad(audio, (window_size, window_size), "constant", 0)
71
+
72
+ sliced_audio = []
73
+
74
+ for i in range(0, audio_len + window_size, step_size):
75
+ audio_slice = audio[:, i : i + window_size]
76
+
77
+ if audio_slice.shape[1] < window_size:
78
+ audio_slice = pad(
79
+ audio_slice, (0, window_size - audio_slice.shape[1]), "constant", 0
80
+ )
81
+
82
+ audio_slice = normalize(audio_slice, dim=1)
83
+ sliced_audio.append(audio_slice)
84
+
85
+ sliced_audio = torch.stack(sliced_audio).squeeze()
86
+
87
+ return sliced_audio
88
+
89
+
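# The frame-to-sample conversion above is easy to misread, so here is the
# arithmetic spelled out with hypothetical numbers (the real sr, fps and
# window_size come from the demo model's config.yaml):
sr, fps = 16000, 25                 # audio sample rate, video frame rate
samples_per_frame = sr // fps       # 640 audio samples correspond to one frame
window = 8 * samples_per_frame      # window_size of 8 frames -> 5120 samples
hop = 1 * samples_per_frame         # step_size of 1 frame   -> 640 samples
# The waveform is padded by one full window on each side and then sliced every
# `hop` samples, so each stacked window scores one frame-aligned position.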
90
+ def infere(model, x, config, device="cpu", bs=8):
91
+ # x is a path to the input audio file (see load_audio above)
92
+ model.eval()
93
+
94
+ x = load_audio(x, config)
95
+
96
+ # Inference (x is a stack of windows)
97
+ frame_predictions = []
98
+
99
+ with torch.no_grad():
100
+ n_iter = x.shape[0]
101
+
102
+ for i in range(0, n_iter, bs):
103
+ input_batch = x[i: i + bs]
104
+ input_batch = input_batch.to(device)
105
+
106
+ output = softmax(model(input_batch), dim=1)
107
+ frame_predictions.append(output.cpu().numpy())
108
+
109
+ frame_predictions = np.concatenate(frame_predictions, axis=0)[:,0]
110
+
111
+
112
+ return frame_predictions
113
+
114
+ def convert_frame_predictions_to_timestamps(frame_predictions, fps, window_size):
115
+ # Convert frame predictions to timestamps
116
+ # Parameters
117
+ # ----------
118
+ # frame_predictions : np.ndarray
119
+ # Frame predictions
120
+ # fps : int
121
+ # Frames per second
122
+ # Returns
123
+ # -------
124
+ # list : Timestamps (in seconds) of frames predicted as manipulated
125
+
126
+ frame_predictions = (
127
+ frame_predictions[
128
+ int(window_size / 2) : -int(window_size / 2), 0
129
+ ] # removes the padding, does not consider step size as of now
130
+ .round()
131
+ .astype(int)
132
+ )
133
+ timestamps = []
134
+
135
+ for i, frame_prediction in enumerate(frame_predictions):
136
+ if frame_prediction == 1:
137
+ timestamps.append(i / fps)
138
+
139
+ return timestamps
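Taken together, a minimal end-to-end sketch of these helpers looks as follows. The audio path is a placeholder, the file must already be at the sample rate the config expects (load_audio does not resample), and, because the weight-loading lines in get_config_and_model are commented out, the model comes back without trained weights unless they are restored.

from manipulate_model.utils import (
    get_config_and_model,
    infere,
    convert_frame_predictions_to_timestamps,
)

config, model = get_config_and_model()
frame_scores = infere(model, "sample.wav", config, device="cpu", bs=8)

# infere returns a 1-D array (it already keeps class column 0), while
# convert_frame_predictions_to_timestamps slices with [..., 0] and therefore
# expects a 2-D per-class matrix; adding a trailing axis bridges the two.
timestamps = convert_frame_predictions_to_timestamps(
    frame_scores[:, None],
    fps=config.data.fps,
    window_size=config.data.window_size,
)
print(timestamps)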