theolepage committed on
Commit
430712c
1 Parent(s): 1e28455

initial commit

.gitignore ADDED
@@ -0,0 +1,6 @@
1
+ __pycache__
2
+
3
+ exp/
4
+ data/
5
+
6
+ WavLM-Base+.pt
DatasetLoader.py ADDED
@@ -0,0 +1,309 @@
1
+ #! /usr/bin/python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ import torch
5
+ import numpy
6
+ import random
7
+ import pdb
8
+ import os
9
+ import threading
10
+ import time
11
+ import math
12
+ import glob
13
+ # import soundfile
14
+ from scipy import signal
15
+ import soundfile
16
+ from torch.utils.data import Dataset, DataLoader
17
+ import torch.distributed as dist
18
+
19
+ def round_down(num, divisor):
20
+ return num - (num%divisor)
21
+
22
+ def worker_init_fn(worker_id):
23
+ numpy.random.seed(numpy.random.get_state()[1][0] + worker_id)
24
+
25
+
26
+ def loadWAV(filename, max_frames, evalmode=True, num_eval=5):
27
+
28
+ # Maximum audio length
29
+ max_audio = max_frames * 160 + 240
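+ # max_frames counts 10 ms frames (160 samples at 16 kHz); the extra 240 samples
+ # presumably cover the analysis-window overhang of the feature front-end.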
30
+
31
+ # Read wav file and convert to torch tensor
32
+ audio, sample_rate = soundfile.read(filename)
33
+
34
+
35
+ audiosize = audio.shape[0]
36
+
37
+ if audiosize <= max_audio:
38
+ shortage = max_audio - audiosize + 1
39
+ audio = numpy.pad(audio, (0, shortage), 'wrap')
40
+ audiosize = audio.shape[0]
41
+
42
+ if evalmode:
43
+ startframe = numpy.linspace(0,audiosize-max_audio,num=num_eval)
44
+ else:
45
+ startframe = numpy.array([numpy.int64(random.random()*(audiosize-max_audio))])
46
+
47
+ feats = []
48
+ if evalmode and max_frames == 0:
49
+ feats.append(audio)
50
+ else:
51
+ for asf in startframe:
52
+ feats.append(audio[int(asf):int(asf)+max_audio])
53
+
54
+ feat = numpy.stack(feats,axis=0).astype(float)
55
+
56
+ return feat;
57
+
58
+ class AugmentWAV(object):
59
+
60
+ def __init__(self, musan_path, rir_path, max_frames):
61
+
62
+ self.max_frames = max_frames
63
+ self.max_audio = max_audio = max_frames * 160 + 240
64
+
65
+ self.noisetypes = ['noise','speech','music']
66
+
67
+ self.noisesnr = {'noise':[0,15],'speech':[13,20],'music':[5,15]}
68
+ self.numnoise = {'noise':[1,1], 'speech':[3,8], 'music':[1,1] }
69
+ self.noiselist = {}
70
+
71
+ augment_files = glob.glob(os.path.join(musan_path,'*/*/*.wav'));
72
+
73
+ for file in augment_files:
74
+ if not file.split('/')[-3] in self.noiselist:
75
+ self.noiselist[file.split('/')[-3]] = []
76
+ self.noiselist[file.split('/')[-3]].append(file)
77
+
78
+ self.rir_files = glob.glob(os.path.join(rir_path,'*/*/*.wav'));
79
+
80
+ def additive_noise(self, noisecat, audio):
81
+
82
+ clean_db = 10 * numpy.log10(numpy.mean(audio ** 2)+1e-4)
83
+
84
+ numnoise = self.numnoise[noisecat]
85
+ noiselist = random.sample(self.noiselist[noisecat], random.randint(numnoise[0],numnoise[1]))
86
+
87
+ noises = []
88
+
89
+ for noise in noiselist:
90
+
91
+ noiseaudio = loadWAV(noise, self.max_frames, evalmode=False)
92
+ noise_snr = random.uniform(self.noisesnr[noisecat][0],self.noisesnr[noisecat][1])
93
+ noise_db = 10 * numpy.log10(numpy.mean(noiseaudio[0] ** 2)+1e-4)
94
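+ # Scale the noise so that the clean-to-noise level difference equals the sampled SNR;
+ # the exponent is a power ratio in dB and the square root converts it to an amplitude gain.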
+ noises.append(numpy.sqrt(10 ** ((clean_db - noise_db - noise_snr) / 10)) * noiseaudio)
95
+
96
+ return numpy.sum(numpy.concatenate(noises,axis=0),axis=0,keepdims=True) + audio
97
+
98
+ def reverberate(self, audio):
99
+
100
+ rir_file = random.choice(self.rir_files)
101
+
102
+ rir, fs = soundfile.read(rir_file)
103
+ rir = numpy.expand_dims(rir.astype(float),0)
104
+ rir = rir / numpy.sqrt(numpy.sum(rir**2))
105
+
106
+ return signal.convolve(audio, rir, mode='full')[:,:self.max_audio]
107
+
108
+
109
+ class train_dataset_loader(Dataset):
110
+ def __init__(self, train_list, augment, musan_path, rir_path, max_frames, train_path, **kwargs):
111
+
112
+ self.augment_wav = AugmentWAV(musan_path=musan_path, rir_path=rir_path, max_frames = max_frames)
113
+
114
+ self.train_list = train_list
115
+ self.max_frames = max_frames;
116
+ self.musan_path = musan_path
117
+ self.rir_path = rir_path
118
+ self.augment = augment
119
+
120
+ # Read training files
121
+ with open(train_list) as dataset_file:
122
+ lines = dataset_file.readlines();
123
+
124
+ # Make a dictionary of ID names and ID indices
125
+ dictkeys = list(set([x.split()[0] for x in lines]))
126
+ dictkeys.sort()
127
+ dictkeys = { key : ii for ii, key in enumerate(dictkeys) }
128
+
129
+ # Parse the training list into file names and ID indices
130
+ self.data_list = []
131
+ self.data_label = []
132
+
133
+ for lidx, line in enumerate(lines):
134
+ data = line.strip().split();
135
+
136
+ speaker_label = dictkeys[data[0]];
137
+ filename = os.path.join(train_path,data[1]);
138
+
139
+ self.data_label.append(speaker_label)
140
+ self.data_list.append(filename)
141
+
142
+
143
+ def __getitem__(self, indices):
144
+
145
+ feat_clean = []
146
+ feat = []
147
+
148
+ for index in indices:
149
+ try:
150
+ audio_clean = loadWAV(self.data_list[index], self.max_frames, evalmode=False)
151
+ except:
152
+ print(self.data_list[index])
153
+
154
+ if len(audio_clean.shape) == 3:
155
+ print(self.data_list[index])
156
+
157
+ if self.augment:
158
+ augtype = random.randint(0,5)
159
+ if augtype == 0:
160
+ audio = audio_clean
161
+ elif augtype == 1:
162
+ audio = self.augment_wav.reverberate(audio_clean)
163
+ elif augtype == 2:
164
+ audio = self.augment_wav.additive_noise('music',audio_clean)
165
+ elif augtype == 3:
166
+ audio = self.augment_wav.additive_noise('speech',audio_clean)
167
+ elif augtype == 4:
168
+ audio = self.augment_wav.additive_noise('noise',audio_clean)
169
+ elif augtype == 5:
170
+ audio = self.augment_wav.additive_noise('speech',audio_clean)
171
+ audio = self.augment_wav.additive_noise('music',audio)
172
+
173
+ feat_clean.append(audio_clean)
174
+ feat.append(audio)
175
+
176
+ feat_clean = numpy.concatenate(feat_clean, axis=0)
177
+ feat = numpy.concatenate(feat, axis=0)
178
+
179
+ return torch.FloatTensor(feat_clean), torch.FloatTensor(feat), self.data_label[index], self.data_list[index]
180
+
181
+ def __len__(self):
182
+ return len(self.data_list)
183
+
184
+
185
+
186
+ class test_dataset_loader(Dataset):
187
+ def __init__(self, test_list, test_path, eval_frames, num_eval, **kwargs):
188
+ self.max_frames = eval_frames;
189
+ self.num_eval = num_eval
190
+ self.test_path = test_path
191
+ self.test_list = test_list
192
+
193
+ def __getitem__(self, index):
194
+ # print(self.test_list[index])
195
+ audio = loadWAV(os.path.join(self.test_path,self.test_list[index]), self.max_frames, evalmode=True, num_eval=self.num_eval)
196
+
197
+ audio2 = loadWAV(os.path.join(self.test_path,self.test_list[index]), 0, evalmode=True, num_eval=self.num_eval)
198
+
199
+ return torch.FloatTensor(audio), torch.FloatTensor(audio2), self.test_list[index]
200
+ # return torch.FloatTensor(audio2), self.test_list[index]
201
+
202
+ def __len__(self):
203
+ return len(self.test_list)
204
+
205
+
206
+ class train_dataset_sampler(torch.utils.data.Sampler):
207
+ def __init__(self, data_source, nPerSpeaker, max_seg_per_spk, batch_size, distributed, seed, **kwargs):
208
+
209
+ self.data_label = data_source.data_label;
210
+ self.nPerSpeaker = nPerSpeaker;
211
+ self.max_seg_per_spk = max_seg_per_spk;
212
+ self.batch_size = batch_size;
213
+ self.epoch = 0;
214
+ self.seed = seed;
215
+ self.distributed = distributed;
216
+
217
+ def __iter__(self):
218
+
219
+ g = torch.Generator()
220
+ g.manual_seed(self.seed + self.epoch)
221
+ indices = torch.randperm(len(self.data_label), generator=g).tolist()
222
+
223
+ data_dict = {}
224
+
225
+ # Sort into dictionary of file indices for each ID
226
+ for index in indices:
227
+ speaker_label = self.data_label[index]
228
+ if not (speaker_label in data_dict):
229
+ data_dict[speaker_label] = [];
230
+ data_dict[speaker_label].append(index);
231
+
232
+
233
+ ## Group file indices for each class
234
+ dictkeys = list(data_dict.keys());
235
+ dictkeys.sort()
236
+
237
+ lol = lambda lst, sz: [lst[i:i+sz] for i in range(0, len(lst), sz)]
238
+
239
+ flattened_list = []
240
+ flattened_label = []
241
+
242
+ for findex, key in enumerate(dictkeys):
243
+ data = data_dict[key]
244
+ numSeg = round_down(min(len(data),self.max_seg_per_spk),self.nPerSpeaker)
245
+
246
+ rp = lol(numpy.arange(numSeg),self.nPerSpeaker)
247
+ flattened_label.extend([findex] * (len(rp)))
248
+ for indices in rp:
249
+ flattened_list.append([data[i] for i in indices])
250
+
251
+ ## Mix data in random order
252
+ mixid = torch.randperm(len(flattened_label), generator=g).tolist()
253
+ mixlabel = []
254
+ mixmap = []
255
+
256
+ ## Prevent two pairs of the same speaker in the same batch
257
+ for ii in mixid:
258
+ startbatch = round_down(len(mixlabel), self.batch_size)
259
+ if flattened_label[ii] not in mixlabel[startbatch:]:
260
+ mixlabel.append(flattened_label[ii])
261
+ mixmap.append(ii)
262
+
263
+ mixed_list = [flattened_list[i] for i in mixmap]
264
+
265
+ ## Divide data to each GPU
266
+ if self.distributed:
267
+ total_size = round_down(len(mixed_list), self.batch_size * dist.get_world_size())
268
+ start_index = int ( ( dist.get_rank() ) / dist.get_world_size() * total_size )
269
+ end_index = int ( ( dist.get_rank() + 1 ) / dist.get_world_size() * total_size )
270
+ self.num_samples = end_index - start_index
271
+ return iter(mixed_list[start_index:end_index])
272
+ else:
273
+ total_size = round_down(len(mixed_list), self.batch_size)
274
+ self.num_samples = total_size
275
+ return iter(mixed_list[:total_size])
276
+
277
+
278
+ def __len__(self) -> int:
279
+ return self.num_samples
280
+
281
+ def set_epoch(self, epoch: int) -> None:
282
+ self.epoch = epoch
283
+
284
+
285
+ if __name__ == '__main__':
286
+ train_dataset = train_dataset_loader(train_list='/mnt/proj3/open-24-5/pengjy_new/WavLM_Adapter/CNCeleb_lst/CNCeleb_trainlist_200spk.txt',
287
+ augment=False,
288
+ musan_path='/mnt/proj3/open-24-5/pengjy_new/musan_split/',
289
+ rir_path='/mnt/proj3/open-24-5/plchot/data_augment/16kHz/simulated_rirs/',
290
+ max_frames=300,
291
+ train_path='/mnt/proj3/open-24-5/pengjy_new/Data/CN-Celeb_flac/data',
292
+ )
293
+
294
+ train_sampler = train_dataset_sampler(train_dataset, nPerSpeaker=1, max_seg_per_spk=500, batch_size=100, distributed=False,seed=120)
295
+ # train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
296
+
297
+ train_loader = torch.utils.data.DataLoader(
298
+ train_dataset,
299
+ batch_size=100,
300
+ num_workers=10,
301
+ sampler=train_sampler,
302
+ pin_memory=True,
303
+ drop_last=True,
304
+ )
305
+ for data_clean, data, data_label, data_path in train_loader:
306
+ print(data.shape)
307
+ data = data.transpose(1,0)
308
+ print(data.shape)
309
+ quit()
README.md ADDED
@@ -0,0 +1,96 @@
1
+ # wavlm_ssl_sv
2
+
3
+ This repository contains the source code of the article **Towards Supervised Performance on Speaker Verification with Self-Supervised Learning by Leveraging Large-Scale ASR Models** (INTERSPEECH 2024) [[arXiv]](https://arxiv.org/pdf/2406.02285).
4
+
5
+ The proposed framework fine-tunes a pre-trained **WavLM** for **Speaker Verification** (SV) using pseudo-labels generated through **Self-Supervised Learning** (SSL). Initial pseudo-labels are derived from an SSL DINO-based model and are iteratively refined by clustering the model embeddings.
6
+
7
+ <p align="center">
8
+ <img src="training_framework.svg" width=900 />
9
+ </p>
10
+
11
+ Our method achieves **0.99% EER on VoxCeleb1-O**, establishing a new state of the art for Speaker Verification with SSL.
12
+
13
+ *Please refer to the article for more details on the implementation and a comparative study with other works.*
14
+
15
+ ---
16
+
17
+ ## Usage
18
+
19
+ ### Installation
20
+
21
+ - Install dependencies with `pip install -r requirements.txt`.
22
+ - Prepare data for VoxCeleb, MUSAN, and RIR datasets following [voxceleb_trainer](https://github.com/clovaai/voxceleb_trainer#data-preparation).
23
+ - Download the [WavLM-Base+ model](https://github.com/microsoft/unilm/tree/master/wavlm) and place `WavLM-Base+.pt` in the root folder.
24
+
25
+ ### Training
26
+
27
+ #### Step 1: Extract DINO speaker embeddings
28
+
29
+ The code to train the DINO model is not currently provided. We recommend using [sslsv](https://github.com/theolepage/sslsv) or [3D-Speaker](https://github.com/modelscope/3D-Speaker) to extract initial speaker embeddings.
30
+
31
+ Alternatively, you can directly download the DINO embeddings we used for our system: [dino_vox2_embeddings.pt](https://drive.google.com/file/d/1YnxrMIgrr6NQgZ3Hv2_5YdP5W8xfdyLH/view?usp=sharing).
32
+
33
+ *Note: the embeddings file must be a `Dict[str, torch.Tensor]` representing all VoxCeleb2 samples with the following format for keys: `id00012/21Uxsk56VDQ/00001.wav`.*
34
+
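+ For illustration, a minimal sketch of how such a file could be built (the dataset root and `extract_dino_embedding` are placeholders for your own DINO extractor):
+
+ ```python
+ import glob, os, torch
+
+ embeddings = {}  # Dict[str, torch.Tensor]
+ for path in glob.glob(os.path.join('voxceleb2_root', '*', '*', '*.wav')):  # hypothetical dataset root
+     key = '/'.join(path.replace(os.sep, '/').split('/')[-3:])  # e.g. 'id00012/21Uxsk56VDQ/00001.wav'
+     embeddings[key] = extract_dino_embedding(path)  # placeholder: returns a 1-D torch.Tensor
+ torch.save(embeddings, 'dino_vox2_embeddings.pt')
+ ```
+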
35
+ #### Step 2: Generate pseudo-labels
36
+
37
+ ```bash
38
+ python pseudo_labeling.py PATH_TO_EMBEDDINGS_FILE PATH_TO_PL_FILE
39
+ ```
40
+
41
+ #### Step 3: Fine-tune WavLM MHFA
42
+
43
+ ```bash
44
+ python trainSpeakerNet.py --config configs/wavlm_mhfa_dlg_lc.yaml --train_list PATH_TO_PL_FILE --distributed
45
+ ```
46
+
47
+ #### Iterative process
48
+
49
+ 1. Extract embeddings from the WavLM MHFA model:
50
+ `python trainSpeakerNet_Eval.py --config configs/wavlm_mhfa_dlg_lc.yaml --generate_embeddings --embeddings_path PATH_TO_EMBEDDINGS_FILE`.
51
+
52
+ 2. Repeat steps 2 and 3, as sketched below. *Make sure to change `save_path` in the config to avoid overwriting the existing model.*
53
+
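+ The full loop can be scripted along these lines (a rough sketch; paths, the number of iterations, and the per-iteration `save_path` handling are illustrative):
+
+ ```python
+ import subprocess
+
+ def run(*cmd):
+     subprocess.run(cmd, check=True)
+
+ emb, pl = 'exp/dino_embeddings.pt', 'exp/pseudo_labels_iter0.txt'
+ run('python', 'pseudo_labeling.py', emb, pl)  # Step 2
+ for it in range(1, 3):  # e.g. two refinement iterations
+     # Step 3 -- point save_path in the config to a new directory each iteration
+     run('python', 'trainSpeakerNet.py', '--config', 'configs/wavlm_mhfa_dlg_lc.yaml',
+         '--train_list', pl, '--distributed')
+     emb = f'exp/wavlm_embeddings_iter{it}.pt'
+     run('python', 'trainSpeakerNet_Eval.py', '--config', 'configs/wavlm_mhfa_dlg_lc.yaml',
+         '--generate_embeddings', '--embeddings_path', emb)
+     pl = f'exp/pseudo_labels_iter{it}.txt'
+     run('python', 'pseudo_labeling.py', emb, pl)  # regenerate pseudo-labels
+ ```
+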
54
+ #### Step 4: Large-Margin Fine-Tuning
55
+
56
+ 1. Copy the latest model checkpoint to `exp/wavlm_mhfa_dlg_lc_lmft/model` to resume training.
57
+
58
+ 2. Start training: `python trainSpeakerNet.py --config configs/wavlm_mhfa_dlg_lc_lmft.yaml --train_list PATH_TO_PL_FILE --distributed`.
59
+
60
+ ### Evaluation
61
+
62
+ ```bash
63
+ python trainSpeakerNet_Eval.py --config configs/wavlm_mhfa_dlg_lc_lmft.yaml --eval
64
+ ```
65
+
66
+ ### Model weights
67
+
68
+ The checkpoint of our best model reaching 0.99% EER on VoxCeleb1-O is available for download: [`wavlm_mhfa_dlg_lc_lmft`](https://drive.google.com/drive/folders/1ygZPvdGwepWDDfIQp6aPRktt2QxLt6cE?usp=drive_link).
69
+
70
+ ---
71
+
72
+ ## Acknowledgements
73
+
74
+ This repository contains third-party components and code adapted from other open-source projects, including: [SLT22_MultiHead-Factorized-Attentive-Pooling](https://github.com/JunyiPeng00/SLT22_MultiHead-Factorized-Attentive-Pooling) and [Loss-Gated-Learning](https://github.com/TaoRuijie/Loss-Gated-Learning).
75
+
76
+ ---
77
+
78
+ ## Citation
79
+
80
+ If you use this project, please consider starring this repository on GitHub and citing the following paper.
81
+
82
+ ```BibTeX
83
+ @InProceedings{miara2024WavLMSSLSV,
84
+ author = {Miara, Victor and Lepage, Théo and Dehak, Réda},
85
+ booktitle = {INTERSPEECH},
86
+ title = {Towards Supervised Performance on Speaker Verification with Self-Supervised Learning by Leveraging Large-Scale ASR Models},
87
+ year = {2024},
88
+ url = {https://arxiv.org/abs/2406.02285},
89
+ }
90
+ ```
91
+
92
+ ---
93
+
94
+ ## License
95
+
96
+ This project is released under the [MIT License](https://github.com/theolepage/wavlm_ssl_sv/blob/main/LICENSE.md).
SpeakerNet.py ADDED
@@ -0,0 +1,372 @@
1
+ #!/usr/bin/python
2
+ #-*- coding: utf-8 -*-
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import numpy, math, pdb, sys, random
8
+ import time, os, itertools, shutil, importlib
9
+ from tuneThreshold import tuneThresholdfromScore
10
+ from DatasetLoader import test_dataset_loader, loadWAV
11
+ import pickle
12
+ import numpy as np
13
+ import time
14
+ from tqdm import tqdm
15
+ import soundfile
16
+
17
+
18
+ class WrappedModel(nn.Module):
19
+
20
+ ## The purpose of this wrapper is to make the model structure consistent between single and multi-GPU
21
+
22
+ def __init__(self, model):
23
+ super(WrappedModel, self).__init__()
24
+ self.module = model
25
+
26
+ def forward(self, x, x_clean=None, label=None,l2_reg_dict=None, epoch=-1):
27
+ return self.module(x, x_clean, label, epoch=epoch)
28
+
29
+
30
+ class SpeakerNet(nn.Module):
31
+
32
+ def __init__(self, model, optimizer, trainfunc, nPerSpeaker, **kwargs):
33
+ super(SpeakerNet, self).__init__()
34
+
35
+ SpeakerNetModel = importlib.import_module('models.'+model).__getattribute__('MainModel')
36
+ self.__S__ = SpeakerNetModel(**kwargs);
37
+
38
+ LossFunction = importlib.import_module('loss.'+trainfunc).__getattribute__('LossFunction')
39
+ self.__L__ = LossFunction(**kwargs);
40
+
41
+ self.nPerSpeaker = nPerSpeaker
42
+ self.weight_finetuning_reg = kwargs['weight_finetuning_reg']
43
+
44
+
45
+ def forward(self, data, data_clean=None, label=None, l2_reg_dict=None, epoch=-1):
46
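+ # Three call modes:
+ # - label is None: extract an embedding for a single utterance (evaluation)
+ # - data[2] == "gen_ps": return pseudo-labels predicted by the loss head
+ # - otherwise: training step, returns (total loss, accuracy, speaker loss, per-sample CE)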
+ if label is None:
47
+ data_reshape = data[0].cuda()
48
+ outp = self.__S__.forward([data_reshape, data[1]])
49
+ return outp
50
+ elif len(data) == 3 and data[2] == "gen_ps":
51
+ data_reshape = data[0].reshape(-1,data[0].size()[-1]).cuda()
52
+ outp = self.__S__.forward([data_reshape, data[1]])
53
+ pseudo_labels = self.__L__.get_pseudo_labels(outp, label)
54
+ return pseudo_labels
55
+ else:
56
+ data_reshape = data[0].reshape(-1,data[0].size()[-1]).cuda()
57
+ data_clean_reshape = data_clean.reshape(-1,data_clean.size()[-1]).cuda()
58
+ outp = self.__S__.forward([data_reshape, data[1]])
59
+ outp_clean = self.__S__.forward([data_clean_reshape, data[1]])
60
+ nloss, prec1, ce = self.__L__.forward(outp, outp_clean, label, epoch)
61
+
62
+ if l2_reg_dict is not None:
63
+ Learned_dict = l2_reg_dict
64
+ l2_reg = 0
65
+ for name,param in self.__S__.model.named_parameters():
66
+ if name in Learned_dict:
67
+ l2_reg = l2_reg + torch.norm(param-Learned_dict[name].cuda(),2)
68
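+ # Pull the fine-tuned WavLM towards the pre-trained checkpoint; dividing each term
+ # by its detached value rescales the task loss and the L2 penalty to comparable
+ # magnitudes before weighting by weight_finetuning_reg.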
+ tloss = nloss/nloss.detach() + self.weight_finetuning_reg*l2_reg/(l2_reg.detach()+1e-5)
69
+ else:
70
+ tloss = nloss
71
+ print("Without L2 Reg")
72
+
73
+ return tloss, prec1, nloss, ce
74
+
75
+
76
+
77
+
78
+ class ModelTrainer(object):
79
+
80
+ def __init__(self, speaker_model, optimizer, scheduler, gpu, mixedprec, **kwargs):
81
+
82
+ self.__model__ = speaker_model
83
+
84
+ WavLM_params = list(map(id, self.__model__.module.__S__.model.parameters()))
85
+ Backend_params = filter(lambda p: id(p) not in WavLM_params, self.__model__.module.parameters())
86
+ self.path = kwargs['pretrained_model_path']
87
+
88
+ Optimizer = importlib.import_module('optimizer.'+optimizer).__getattribute__('Optimizer')
89
+
90
+ # Define the initial param groups
91
+ param_groups = [{'params': Backend_params, 'lr': kwargs['LR_MHFA']}]
92
+
93
+ # Extract the encoder layers
94
+ encoder_layers = self.__model__.module.__S__.model.encoder.layers
95
+
96
+ # Iterate over the encoder layers to create param groups
97
+ for i in range(12): # 12 layers (0 to 11) for the BASE model; use 24 for the LARGE model
98
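+ # Layer-wise learning-rate decay: layer i gets LR_Transformer * LLRD_factor**i
+ # (the provided configs set LLRD_factor to 1.0, i.e. a uniform transformer LR).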
+ lr = kwargs['LR_Transformer'] * (kwargs['LLRD_factor'] ** i)
99
+ param_groups.append({'params': encoder_layers[i].parameters(), 'lr': lr})
100
+
101
+ # Initialize the optimizer with these param groups
102
+ self.__optimizer__ = Optimizer(param_groups, **kwargs)
103
+
104
+ # self.__optimizer__ = Optimizer(self.__model__.parameters(), **kwargs)
105
+ # print('scheduler.'+scheduler)
106
+ Scheduler = importlib.import_module('scheduler.'+scheduler).__getattribute__('Scheduler')
107
+ # print(kwargs)
108
+ try:
109
+ self.__scheduler__, self.lr_step = Scheduler(self.__optimizer__, **kwargs)
110
+ except:
111
+ self.__scheduler__, self.lr_step = Scheduler(self.__optimizer__, lr_decay=0.9, **kwargs)
112
+
113
+ # self.scaler = GradScaler()
114
+
115
+ self.gpu = gpu
116
+
117
+ self.mixedprec = mixedprec
118
+ print("Mix prec: %s"%(self.mixedprec))
119
+
120
+ assert self.lr_step in ['epoch', 'iteration']
121
+
122
+ # ## ===== ===== ===== ===== ===== ===== ===== =====
123
+ # ## Train network
124
+ # ## ===== ===== ===== ===== ===== ===== ===== =====
125
+
126
+ def update_lgl_threshold(self, lgl_threshold):
127
+ self.__model__.module.__L__.lgl_threshold = lgl_threshold
128
+
129
+ # """
130
+ def train_network(self, loader, loss_vals_path, epoch, verbose):
131
+ if torch.distributed.is_initialized():
132
+ rank = torch.distributed.get_rank()
133
+ unique_loss_vals_path = f"{loss_vals_path.split('.')[0]}_rank{rank}.txt"
134
+ else:
135
+ unique_loss_vals_path = loss_vals_path
136
+
137
+ self.__model__.train();
138
+
139
+ stepsize = loader.batch_size;
140
+
141
+ counter = 0;
142
+ index = 0;
143
+ loss = 0;
144
+ top1 = 0 # EER or accuracy
145
+
146
+ tstart = time.time()
147
+ Learned_dict = {}
148
+ checkpoint = torch.load(self.path)
149
+ for name, param in checkpoint['model'].items():
150
+ if 'w2v_encoder.w2v_model.' in name:
151
+ newname = name.replace('w2v_encoder.w2v_model.', '')
152
+ else:
153
+ newname = name
154
+ Learned_dict[newname] = param;
155
+
156
+ # for data_clean, data, data_label, data_path in loader:
157
+ # telapsed = time.time() - tstart
158
+ # tstart = time.time()
159
+ # counter += 1;
160
+ # index += stepsize
161
+ # sys.stdout.write("\rProcessing (%d) "%(index));
162
+ # sys.stdout.write("Loss %f TEER/TAcc %2.3f%% - %.2f Hz "%(loss/counter, top1/counter, stepsize/telapsed));
163
+ # if counter % 100 == 0:
164
+ # sys.stdout.flush()
165
+
166
+ with open(unique_loss_vals_path, 'w') as loss_vals_file:
167
+ for data_clean, data, data_label, data_path in loader:
168
+ data_clean = data_clean.transpose(1,0)
169
+ data = data.transpose(1,0)
170
+ self.__model__.zero_grad()
171
+ label = torch.LongTensor(data_label).cuda()
172
+
173
+ nloss, prec1, spkloss, ce = self.__model__([data,"train"], data_clean, label, Learned_dict, epoch=epoch)
174
+
175
+ for ce_val, path in zip(ce.detach().cpu().numpy(), data_path):
176
+ loss_vals_file.write(f'{ce_val} {"/".join(path.split("/")[5:])}\n')
177
+
178
+ nloss.backward()
179
+
180
+ self.__optimizer__.step();
181
+
182
+ loss += spkloss.detach().cpu()
183
+ top1 += prec1.detach().cpu()
184
+
185
+
186
+ counter += 1;
187
+ index += stepsize;
188
+
189
+
190
+
191
+ telapsed = time.time() - tstart
192
+ tstart = time.time()
193
+
194
+ if verbose:
195
+ sys.stdout.write("\rProcessing (%d) "%(index));
196
+ sys.stdout.write("Loss %f TEER/TAcc %2.3f%% - %.2f Hz "%(loss/counter, top1/counter, stepsize/telapsed));
197
+ sys.stdout.flush();
198
+
199
+ if self.lr_step == 'iteration': self.__scheduler__.step()
200
+
201
+ if self.lr_step == 'epoch': self.__scheduler__.step()
202
+
203
+ sys.stdout.write("\n");
204
+ return (loss/counter, top1/counter);
205
+ # """
206
+
207
+ ## ===== ===== ===== ===== ===== ===== ===== =====
208
+ ## Evaluate from list
209
+ ## ===== ===== ===== ===== ===== ===== ===== =====
210
+
211
+ def evaluateFromList(self, test_list, test_path, nDataLoaderThread, print_interval=10, num_eval=15, **kwargs):
212
+
213
+ self.__model__.eval();
214
+
215
+ lines = []
216
+ files = []
217
+ feats = {}
218
+ tstart = time.time()
219
+
220
+ ## Read all lines
221
+ with open(test_list) as f:
222
+ lines = f.readlines()
223
+
224
+ ## Get a list of unique file names
225
+ files = sum([x.strip().split()[-2:] for x in lines],[])
226
+ setfiles = list(set(files))
227
+ setfiles.sort()
228
+
229
+ ## Define test data loader
230
+ test_dataset = test_dataset_loader(setfiles, test_path, num_eval=num_eval, **kwargs)
231
+ test_loader = torch.utils.data.DataLoader(
232
+ test_dataset,
233
+ batch_size=1,
234
+ shuffle=False,
235
+ num_workers=nDataLoaderThread,
236
+ drop_last=False,
237
+ )
238
+ ref_feat_list = []
239
+ ref_feat_2_list = []
240
+ max_len = 0
241
+ forward = 0
242
+ ## Extract features for every utterance
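+ # Two feature sets per utterance: ref_feat from num_eval fixed-length crops and
+ # ref_feat_2 from the (truncated) full signal; trial scores average both views.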
243
+ for idx, data in enumerate(test_loader):
244
+
245
+
246
+ inp1 = data[0][0].cuda()
247
+ inp2 = data[1][0].cuda()
248
+ telapsed_2 = time.time()
249
+ b,utt_l = inp2.shape
250
+ if utt_l > max_len:
251
+ max_len = utt_l
252
+ ref_feat = self.__model__([inp1, "test"]).cuda()
253
+ ref_feat = ref_feat.detach().cpu()
254
+ ref_feat_2 = self.__model__([inp2[:,:700000], "test"]).cuda() # Truncated to 700000 samples to fit within GPU memory.
255
+ ref_feat_2 = ref_feat_2.detach().cpu()
256
+
257
+ feats[data[2][0]] = [ref_feat,ref_feat_2]
258
+
259
+ ref_feat_list.extend(ref_feat.numpy())
260
+ ref_feat_2_list.extend(ref_feat_2.numpy())
261
+
262
+ telapsed = time.time() - tstart
263
+ forward = forward + time.time() - telapsed_2
264
+
265
+ if idx % print_interval == 0:
266
+ sys.stdout.write("\rReading %d of %d: %.2f Hz, forward speed: %.2f Hz, embedding size %d, max_len %d"%(idx,len(setfiles),idx/telapsed,idx/forward, ref_feat.size()[-1],max_len));
267
+
268
+ print('')
269
+ all_scores = [];
270
+ all_labels = [];
271
+ all_trials = [];
272
+ all_scores_1 = [];
273
+ all_scores_2 = [];
274
+
275
+ tstart = time.time()
276
+
277
+ ref_feat_list = numpy.array(ref_feat_list)
278
+ ref_feat_2_list = numpy.array(ref_feat_2_list)
279
+
280
+ ref_feat_list_mean = 0
281
+ ref_feat_2_list_mean = 0
282
+
283
+
284
+ ## Read files and compute all scores
285
+ for idx, line in enumerate(lines):
286
+
287
+ data = line.split();
288
+
289
+ ## Append random label if missing
290
+ if len(data) == 2: data = [random.randint(0,1)] + data
291
+
292
+ ref_feat,ref_feat_2 = feats[data[1]]
293
+ com_feat,com_feat_2 = feats[data[2]]
294
+
295
+ # if self.__model__.module.__L__.test_normalize:
296
+ ref_feat = F.normalize(ref_feat-ref_feat_list_mean, p=2, dim=1) # B, D
297
+ com_feat = F.normalize(com_feat-ref_feat_list_mean, p=2, dim=1)
298
+ ref_feat_2 = F.normalize(ref_feat_2-ref_feat_2_list_mean, p=2, dim=1) # B, D
299
+ com_feat_2 = F.normalize(com_feat_2-ref_feat_2_list_mean, p=2, dim=1)
300
+
301
+ score_1 = torch.mean(torch.matmul(ref_feat, com_feat.T)) # higher is positive
302
+ score_2 = torch.mean(torch.matmul(ref_feat_2, com_feat_2.T))
303
+ score = (score_1 + score_2) / 2
304
+ score = score.detach().cpu().numpy()
305
+
306
+ all_scores.append(score);
307
+ all_scores_1.append(score_1);
308
+ all_scores_2.append(score_2);
309
+
310
+ all_labels.append(int(data[0]));
311
+ all_trials.append(data[1]+" "+data[2])
312
+
313
+ if idx % (10*print_interval) == 0:
314
+ telapsed = time.time() - tstart
315
+ sys.stdout.write("\rComputing %d of %d: %.2f Hz"%(idx,len(lines),idx/telapsed));
316
+ sys.stdout.flush();
317
+
318
+ print('')
319
+
320
+ return (all_scores, all_labels, all_trials,all_scores_1,all_scores_2);
321
+
322
+ def generate_embeddings(self, wav_files, output, device):
323
+ res = {}
324
+
325
+ for file in tqdm(wav_files):
326
+ wav, sr = soundfile.read(file)
327
+ wav = torch.from_numpy(wav).float().to(device)
328
+
329
+ with torch.no_grad():
330
+ embedding = self.__model__([wav.unsqueeze(0), "test"]).detach().cpu()
331
+
332
+ key = '/'.join(file.split('/')[-3:])
333
+ res[key] = embedding
334
+
335
+ torch.save(res, output)
336
+
337
+ def saveParameters(self, path):
338
+ torch.save(self.__model__.module.state_dict(), path);
339
+
340
+
341
+ ## ===== ===== ===== ===== ===== ===== ===== =====
342
+ ## Load parameters
343
+ ## ===== ===== ===== ===== ===== ===== ===== =====
344
+
345
+ def loadParameters(self, path):
346
+
347
+ self_state = self.__model__.module.state_dict();
348
+ loaded_state = torch.load(path, map_location="cuda:%d"%self.gpu);
349
+ # loaded_state = torch.load(path, map_location="cpu");
350
+
351
+
352
+
353
+ for name, param in loaded_state.items():
354
+ origname = name;
355
+
356
+ if name not in self_state:
357
+ name = name.replace("module.", "");
358
+
359
+ if name not in self_state:
360
+ print("%s is not in the model."%origname);
361
+ continue;
362
+
363
+ if self_state[name].size() != loaded_state[origname].size():
364
+ print("Wrong parameter length: %s, model: %s, loaded: %s"%(origname, self_state[name].size(), loaded_state[origname].size()));
365
+ continue;
366
+
367
+ self_state[name].copy_(param);
368
+
369
+
370
+
371
+
372
+
configs/wavlm_mhfa_dlg_lc.yaml ADDED
@@ -0,0 +1,33 @@
1
+ max_frames: 300
2
+ max_epoch: 15
3
+ batch_size: 120
4
+ margin: 0.2
5
+
6
+ eval_frames: 400
7
+ augment: True
8
+
9
+ ## Training details
10
+ trainfunc: aamsoftmax
11
+
12
+ scale: 30
13
+
14
+ lr_decay: 0.95
15
+
16
+ pretrained_model_path: WavLM-Base+.pt
17
+ weight_finetuning_reg: 0.01
18
+ LLRD_factor: 1.0
19
+ LR_Transformer: 2e-5
20
+ LR_MHFA: 5e-3
21
+
22
+ ## Loss functions
23
+ nClasses: 7500
24
+
25
+ ## Load and save
26
+ save_path: exp/wavlm_mhfa_dlg_lc
27
+ # save_path: exp/wavlm_mhfa_dlg_lc_iter2
28
+
29
+ ## Model definition
30
+ model: Baseline.Spk_Encoder
31
+
32
+ nOut: 256
33
+ port: 6754
configs/wavlm_mhfa_dlg_lc_lmft.yaml ADDED
@@ -0,0 +1,33 @@
1
+ max_frames: 600
2
+ max_epoch: 5
3
+ batch_size: 50
4
+ margin: 0.5
5
+
6
+ eval_frames: 400
7
+ augment: True
8
+
9
+ ## Training details
10
+ trainfunc: aamsoftmax
11
+
12
+ scale: 30
13
+
14
+ lr: 5e-4
15
+ lr_decay: 0.95
16
+
17
+ pretrained_model_path: WavLM-Base+.pt
18
+ weight_finetuning_reg: 0.01
19
+ LLRD_factor: 1.0
20
+ LR_Transformer: 2e-5
21
+ LR_MHFA: 5e-3
22
+
23
+ ## Loss functions
24
+ nClasses: 7500
25
+
26
+ ## Load and save
27
+ save_path: exp/wavlm_mhfa_dlg_lc_lmft
28
+
29
+ ## Model definition
30
+ model: Baseline.Spk_Encoder
31
+
32
+ nOut: 256
33
+ port: 6754
loss/aamsoftmax.py ADDED
@@ -0,0 +1,140 @@
1
+ #! /usr/bin/python
2
+ # -*- encoding: utf-8 -*-
3
+ # Adapted from https://github.com/wujiyang/Face_Pytorch (Apache License)
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import time, pdb, numpy, math
9
+ from utils import accuracy
10
+ import numpy as np
11
+
12
+ class LossFunction(nn.Module):
13
+ def __init__(self, nOut, nClasses, margin=0.3, scale=15, easy_margin=False, **kwargs):
14
+ super(LossFunction, self).__init__()
15
+
16
+ self.test_normalize = True
17
+
18
+ self.m = margin
19
+ self.s = scale
20
+ self.in_feats = nOut
21
+ self.weight = torch.nn.Parameter(torch.FloatTensor(nClasses, nOut), requires_grad=True)
22
+ # self.ce = nn.CrossEntropyLoss()
23
+ self.ce = nn.CrossEntropyLoss(reduction='none') # return loss per sample
24
+ nn.init.xavier_normal_(self.weight, gain=1)
25
+
26
+ self.easy_margin = easy_margin
27
+ self.cos_m = math.cos(self.m)
28
+ self.sin_m = math.sin(self.m)
29
+
30
+ # make the function cos(theta+m) monotonic decreasing while theta in [0°,180°]
31
+ self.th = math.cos(math.pi - self.m)
32
+ self.mm = math.sin(math.pi - self.m) * self.m
33
+
34
+ self.lgl_threshold = 1e6
35
+ self.lc_threshold = 0.5
36
+
37
+ print('Initialised AAMSoftmax margin %.3f scale %.3f'%(self.m,self.s))
38
+
39
+ def _forward(self, x, label):
40
+ # cos(theta)
41
+ cosine = F.linear(F.normalize(x), F.normalize(self.weight))
42
+ # cos(theta + m)
43
+ sine = torch.sqrt((1.0 - torch.mul(cosine, cosine)).clamp(0, 1))
44
+ phi = cosine * self.cos_m - sine * self.sin_m
45
+
46
+ if self.easy_margin:
47
+ phi = torch.where(cosine > 0, phi, cosine)
48
+ else:
49
+ phi = torch.where((cosine - self.th) > 0, phi, cosine - self.mm)
50
+
51
+ #one_hot = torch.zeros(cosine.size(), device='cuda' if torch.cuda.is_available() else 'cpu')
52
+ one_hot = torch.zeros_like(cosine)
53
+ one_hot.scatter_(1, label.view(-1, 1), 1)
54
+ output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
55
+ output = output * self.s
56
+
57
+ return output
58
+
59
+ def _forward_softmax_sharpened(self, x, e=0.1):
60
+ # regular softmax
61
+ output = F.linear(x, self.weight)
62
+ probas = F.softmax(output / e, dim=1)
63
+ return probas
64
+
65
+ def forward(self, x, x_clean, label=None, epoch=-1):
66
+ assert x.size()[0] == label.size()[0]
67
+ assert x.size()[1] == self.in_feats
68
+
69
+ output = self._forward(x, label)
70
+ output_clean = self._forward_softmax_sharpened(x_clean)
71
+
72
+ ce = self.ce(output, label)
73
+
74
+ # No LGL
75
+ # prec1 = accuracy(output.detach(), label.detach(), topk=(1,))[0]
76
+ # return ce, prec1, None
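+ # Loss-Gated Learning (LGL): keep only the samples whose log cross-entropy falls below
+ # lgl_threshold, i.e. pseudo-labels the model already fits reasonably well.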
77
+
78
+ mask = (torch.log(ce) <= self.lgl_threshold).detach()
79
+
80
+ if epoch <= 8:
81
+ # LGL only
82
+ nselect = torch.clamp(sum(mask), min=1).item()
83
+ loss = torch.sum(ce * mask, dim=-1) / nselect
84
+ prec1 = accuracy(output.detach(), label * mask.detach(), topk=(1,))[0]
85
+ return loss, prec1, ce
86
+
87
+ # LGL + LC
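+ # Label Correction (LC): samples gated out by LGL are re-labelled with the sharpened
+ # softmax prediction on the clean (non-augmented) view, but only when that prediction
+ # is more confident than lc_threshold.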
88
+
89
+ label_LC = output_clean.argmax(dim=1)
90
+
91
+ max_vals = torch.gather(output_clean, 1, label_LC.unsqueeze(1)).squeeze(1)
92
+ mask_LC = (max_vals > self.lc_threshold).detach()
93
+
94
+ ce_LC = self.ce(output, label_LC)
95
+
96
+ mask_LGL_LC = ~mask & mask_LC
97
+ loss = torch.mean(ce * mask + ce_LC * mask_LGL_LC, dim=-1)
98
+ prec1 = accuracy(output.detach(), label * mask.detach() + label_LC * mask_LGL_LC.detach(), topk=(1,))[0]
99
+
100
+ return loss, prec1, ce
101
+
102
+ def get_pseudo_labels(self, x, label):
103
+ output = self._forward_softmax_sharpened(x)
104
+ return output.argmax(dim=1)
105
+
106
+ """
107
+ def forward(self, x, x_clean, label=None):
108
+
109
+ assert x.size()[0] == label.size()[0]
110
+ assert x.size()[1] == self.in_feats
111
+
112
+ P_aam = self._forward(x, label)
113
+
114
+ P_softmax = self._forward_softmax_sharpened(x)
115
+ P_clean_softmax = self._forward_softmax_sharpened(x_clean)
116
+
117
+ ce = self.ce(P_aam, label)
118
+
119
+ # No LGL
120
+ # prec1 = accuracy(output.detach(), label.detach(), topk=(1,))[0]
121
+ # return ce, prec1, None
122
+
123
+ mask = (torch.log(ce) <= self.lgl_threshold).detach()
124
+
125
+ # LGL only
126
+ # nselect = torch.clamp(sum(mask), min=1).item()
127
+ # loss = torch.sum(ce * mask, dim=-1) / nselect
128
+ # prec1 = accuracy(output.detach(), label * mask.detach(), topk=(1,))[0]
129
+ # return loss, prec1, ce
130
+
131
+ # LGL + LC
132
+ label_LC = P_clean_softmax.argmax(dim=1)
133
+ ce_LC = self.ce(P_softmax, label_LC)
134
+
135
+ inverted_mask = ~mask
136
+ loss = torch.mean(ce * mask + ce_LC * inverted_mask, dim=-1)
137
+ prec1 = accuracy(P_softmax.detach(), label * mask.detach() + label_LC * inverted_mask.detach(), topk=(1,))[0]
138
+
139
+ return loss, prec1, ce
140
+ """
models/Baseline/Spk_Encoder.py ADDED
@@ -0,0 +1,112 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from torch.nn import LayerNorm
6
+ from .WavLM import *
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ class MHFA(nn.Module):
12
+ def __init__(self, head_nb=8, inputs_dim=768, compression_dim=128, outputs_dim=256):
13
+ super(MHFA, self).__init__()
14
+
15
+ # Define learnable weights for key and value computations across layers
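+ # (13 layer representations are weighted: the CNN front-end output plus the 12
+ # transformer layers of WavLM-Base+)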
16
+ self.weights_k = nn.Parameter(data=torch.ones(13), requires_grad=True)
17
+ self.weights_v = nn.Parameter(data=torch.ones(13), requires_grad=True)
18
+
19
+ # Initialize given parameters
20
+ self.head_nb = head_nb
21
+ self.ins_dim = inputs_dim
22
+ self.cmp_dim = compression_dim
23
+ self.ous_dim = outputs_dim
24
+
25
+ # Define compression linear layers for keys and values
26
+ self.cmp_linear_k = nn.Linear(self.ins_dim, self.cmp_dim)
27
+ self.cmp_linear_v = nn.Linear(self.ins_dim, self.cmp_dim)
28
+
29
+ # Define linear layer to compute multi-head attention weights
30
+ self.att_head = nn.Linear(self.cmp_dim, self.head_nb)
31
+
32
+ # Define a fully connected layer for final output
33
+ self.pooling_fc = nn.Linear(self.head_nb * self.cmp_dim, self.ous_dim)
34
+
35
+ def forward(self, x):
36
+ # Input x has shape: [Batch, Dim, Frame_len, Nb_Layer]
37
+
38
+ # Compute the key by taking a weighted sum of input across layers
39
+ k = torch.sum(x.mul(nn.functional.softmax(self.weights_k, dim=-1)), dim=-1).transpose(1, 2)
40
+
41
+ # Compute the value in a similar fashion
42
+ v = torch.sum(x.mul(nn.functional.softmax(self.weights_v, dim=-1)), dim=-1).transpose(1, 2)
43
+
44
+ # Pass the keys and values through compression linear layers
45
+ k = self.cmp_linear_k(k)
46
+ v = self.cmp_linear_v(v)
47
+
48
+ # Compute attention weights using compressed keys
49
+ att_k = self.att_head(k)
50
+
51
+ # Adjust dimensions for computing attention output
52
+ v = v.unsqueeze(-2)
53
+
54
+ # Compute attention output by taking weighted sum of values using softmaxed attention weights
55
+ pooling_outs = torch.sum(v.mul(nn.functional.softmax(att_k, dim=1).unsqueeze(-1)), dim=1)
56
+
57
+ # Reshape the tensor before passing through the fully connected layer
58
+ b, h, f = pooling_outs.shape
59
+ pooling_outs = pooling_outs.reshape(b, -1)
60
+
61
+ # Pass through fully connected layer to get the final output
62
+ outs = self.pooling_fc(pooling_outs)
63
+
64
+ return outs
65
+
66
+
67
+ class spk_extractor(nn.Module):
68
+ def __init__(self,**kwargs):
69
+ super(spk_extractor, self).__init__()
70
+ # checkpoint = torch.load('/mnt/proj3/open-24-5/pengjy_new/WavLM/Pretrained_model/WavLM-Base+.pt')
71
+ print("Pre-trained Model: {}".format(kwargs['pretrained_model_path']))
72
+ checkpoint = torch.load(kwargs['pretrained_model_path'])
73
+ cfg = WavLMConfig(checkpoint['cfg'])
74
+ self.model = WavLM(cfg)
75
+ self.loadParameters(checkpoint['model'])
76
+ self.backend = MHFA(head_nb=64)
77
+
78
+
79
+ def forward(self,wav_and_flag):
80
+
81
+ x = wav_and_flag[0]
82
+
83
+ cnn_outs, layer_results = self.model.extract_features(x, output_layer=13)
84
+ layer_reps = [x.transpose(0, 1) for x, _ in layer_results]
85
+ x = torch.stack(layer_reps).transpose(0,-1).transpose(0,1)
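+ # x now has shape [Batch, Dim, Frame_len, Nb_Layer], as expected by the MHFA backend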
86
+
87
+ out = self.backend(x)
88
+ return out
89
+
90
+ def loadParameters(self, param):
91
+
92
+ self_state = self.model.state_dict();
93
+ loaded_state = param
94
+
95
+ for name, param in loaded_state.items():
96
+ origname = name;
97
+
98
+
99
+ if name not in self_state:
100
+ # print("%s is not in the model."%origname);
101
+ continue;
102
+
103
+ if self_state[name].size() != loaded_state[origname].size():
104
+ print("Wrong parameter length: %s, model: %s, loaded: %s"%(origname, self_state[name].size(), loaded_state[origname].size()));
105
+ continue;
106
+
107
+ self_state[name].copy_(param);
108
+
109
+
110
+ def MainModel(**kwargs):
111
+ model = spk_extractor(**kwargs)
112
+ return model
models/Baseline/WavLM.py ADDED
@@ -0,0 +1,749 @@
1
+ # --------------------------------------------------------
2
+ # WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/wavlm
4
+ # Copyright (c) 2021 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+ import math
11
+ import logging
12
+ from typing import List, Optional, Tuple
13
+
14
+ import numpy as np
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+ from torch.nn import LayerNorm
20
+ from .modules import (
21
+ Fp32GroupNorm,
22
+ Fp32LayerNorm,
23
+ GradMultiply,
24
+ MultiheadAttention,
25
+ SamePad,
26
+ init_bert_params,
27
+ get_activation_fn,
28
+ TransposeLast,
29
+ GLU_Linear,
30
+ )
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
+ def compute_mask_indices(
36
+ shape: Tuple[int, int],
37
+ padding_mask: Optional[torch.Tensor],
38
+ mask_prob: float,
39
+ mask_length: int,
40
+ mask_type: str = "static",
41
+ mask_other: float = 0.0,
42
+ min_masks: int = 0,
43
+ no_overlap: bool = False,
44
+ min_space: int = 0,
45
+ ) -> np.ndarray:
46
+ """
47
+ Computes random mask spans for a given shape
48
+
49
+ Args:
50
+ shape: the shape for which to compute masks.
51
+ should be of size 2 where first element is batch size and 2nd is timesteps
52
+ padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
53
+ mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
54
+ number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
55
+ however due to overlaps, the actual number will be smaller (unless no_overlap is True)
56
+ mask_type: how to compute mask lengths
57
+ static = fixed size
58
+ uniform = sample from uniform distribution [mask_other, mask_length*2]
59
+ normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
60
+ poisson = sample from Poisson distribution with lambda = mask length
61
+ min_masks: minimum number of masked spans
62
+ no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping
63
+ min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
64
+ """
65
+
66
+ bsz, all_sz = shape
67
+ mask = np.full((bsz, all_sz), False)
68
+
69
+ all_num_mask = int(
70
+ # add a random number for probabilistic rounding
71
+ mask_prob * all_sz / float(mask_length)
72
+ + np.random.rand()
73
+ )
74
+
75
+ all_num_mask = max(min_masks, all_num_mask)
76
+
77
+ mask_idcs = []
78
+ for i in range(bsz):
79
+ if padding_mask is not None:
80
+ sz = all_sz - padding_mask[i].long().sum().item()
81
+ num_mask = int(
82
+ # add a random number for probabilistic rounding
83
+ mask_prob * sz / float(mask_length)
84
+ + np.random.rand()
85
+ )
86
+ num_mask = max(min_masks, num_mask)
87
+ else:
88
+ sz = all_sz
89
+ num_mask = all_num_mask
90
+
91
+ if mask_type == "static":
92
+ lengths = np.full(num_mask, mask_length)
93
+ elif mask_type == "uniform":
94
+ lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
95
+ elif mask_type == "normal":
96
+ lengths = np.random.normal(mask_length, mask_other, size=num_mask)
97
+ lengths = [max(1, int(round(x))) for x in lengths]
98
+ elif mask_type == "poisson":
99
+ lengths = np.random.poisson(mask_length, size=num_mask)
100
+ lengths = [int(round(x)) for x in lengths]
101
+ else:
102
+ raise Exception("unknown mask selection " + mask_type)
103
+
104
+ if sum(lengths) == 0:
105
+ lengths[0] = min(mask_length, sz - 1)
106
+
107
+ if no_overlap:
108
+ mask_idc = []
109
+
110
+ def arrange(s, e, length, keep_length):
111
+ span_start = np.random.randint(s, e - length)
112
+ mask_idc.extend(span_start + i for i in range(length))
113
+
114
+ new_parts = []
115
+ if span_start - s - min_space >= keep_length:
116
+ new_parts.append((s, span_start - min_space + 1))
117
+ if e - span_start - keep_length - min_space > keep_length:
118
+ new_parts.append((span_start + length + min_space, e))
119
+ return new_parts
120
+
121
+ parts = [(0, sz)]
122
+ min_length = min(lengths)
123
+ for length in sorted(lengths, reverse=True):
124
+ lens = np.fromiter(
125
+ (e - s if e - s >= length + min_space else 0 for s, e in parts),
126
+ np.int,
127
+ )
128
+ l_sum = np.sum(lens)
129
+ if l_sum == 0:
130
+ break
131
+ probs = lens / np.sum(lens)
132
+ c = np.random.choice(len(parts), p=probs)
133
+ s, e = parts.pop(c)
134
+ parts.extend(arrange(s, e, length, min_length))
135
+ mask_idc = np.asarray(mask_idc)
136
+ else:
137
+ min_len = min(lengths)
138
+ if sz - min_len <= num_mask:
139
+ min_len = sz - num_mask - 1
140
+
141
+ mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
142
+
143
+ mask_idc = np.asarray(
144
+ [
145
+ mask_idc[j] + offset
146
+ for j in range(len(mask_idc))
147
+ for offset in range(lengths[j])
148
+ ]
149
+ )
150
+
151
+ mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
152
+
153
+ min_len = min([len(m) for m in mask_idcs])
154
+ for i, mask_idc in enumerate(mask_idcs):
155
+ if len(mask_idc) > min_len:
156
+ mask_idc = np.random.choice(mask_idc, min_len, replace=False)
157
+ mask[i, mask_idc] = True
158
+
159
+ return mask
160
+
161
+
162
+ class WavLMConfig:
163
+ def __init__(self, cfg=None):
164
+ self.extractor_mode: str = "default" # mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True)
165
+ self.encoder_layers: int = 12 # num encoder layers in the transformer
166
+
167
+ self.encoder_embed_dim: int = 768 # encoder embedding dimension
168
+ self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
169
+ self.encoder_attention_heads: int = 12 # num encoder attention heads
170
+ self.activation_fn: str = "gelu" # activation function to use
171
+
172
+ self.layer_norm_first: bool = False # apply layernorm first in the transformer
173
+ self.conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]
174
+ self.conv_bias: bool = False # include bias in conv encoder
175
+ self.feature_grad_mult: float = 1.0 # multiply feature extractor var grads by this
176
+
177
+ self.normalize: bool = False # normalize input to have 0 mean and unit variance during training
178
+
179
+ # dropouts
180
+ self.dropout: float = 0.1 # dropout probability for the transformer
181
+ self.attention_dropout: float = 0.1 # dropout probability for attention weights
182
+ self.activation_dropout: float = 0.0 # dropout probability after activation in FFN
183
+ self.encoder_layerdrop: float = 0.0 # probability of dropping a transformer layer
184
+ self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr)
185
+ self.dropout_features: float = 0.0 # dropout to apply to the features (after feat extr)
186
+
187
+ # masking
188
+ self.mask_length: int = 10 # mask length
189
+ self.mask_prob: float = 0.65 # probability of replacing a token with mask
190
+ self.mask_selection: str = "static" # how to choose mask length
191
+ self.mask_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indicesh
192
+ self.no_mask_overlap: bool = False # whether to allow masks to overlap
193
+ self.mask_min_space: int = 1 # min space between spans (if no overlap is enabled)
194
+
195
+ # channel masking
196
+ self.mask_channel_length: int = 10 # length of the mask for features (channels)
197
+ self.mask_channel_prob: float = 0.0 # probability of replacing a feature with 0
198
+ self.mask_channel_selection: str = "static" # how to choose mask length for channel masking
199
+ self.mask_channel_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
200
+ self.no_mask_channel_overlap: bool = False # whether to allow channel masks to overlap
201
+ self.mask_channel_min_space: int = 1 # min space between spans (if no overlap is enabled)
202
+
203
+ # positional embeddings
204
+ self.conv_pos: int = 128 # number of filters for convolutional positional embeddings
205
+ self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding
206
+
207
+ # relative position embedding
208
+ self.relative_position_embedding: bool = False # apply relative position embedding
209
+ self.num_buckets: int = 320 # number of buckets for relative position embedding
210
+ self.max_distance: int = 1280 # maximum distance for relative position embedding
211
+ self.gru_rel_pos: bool = False # apply gated relative position embedding
212
+
213
+ if cfg is not None:
214
+ self.update(cfg)
215
+
216
+ def update(self, cfg: dict):
217
+ self.__dict__.update(cfg)
218
+
219
+
220
+ class WavLM(nn.Module):
221
+ def __init__(
222
+ self,
223
+ cfg: WavLMConfig,
224
+ ) -> None:
225
+ super().__init__()
226
+ logger.info(f"WavLM Config: {cfg.__dict__}")
227
+
228
+ self.cfg = cfg
229
+ feature_enc_layers = eval(cfg.conv_feature_layers)
230
+ self.embed = feature_enc_layers[-1][0]
231
+
232
+ self.feature_extractor = ConvFeatureExtractionModel(
233
+ conv_layers=feature_enc_layers,
234
+ dropout=0.0,
235
+ mode=cfg.extractor_mode,
236
+ conv_bias=cfg.conv_bias,
237
+ )
238
+
239
+ self.post_extract_proj = (
240
+ nn.Linear(self.embed, cfg.encoder_embed_dim)
241
+ if self.embed != cfg.encoder_embed_dim
242
+ else None
243
+ )
244
+
245
+ self.mask_prob = cfg.mask_prob
246
+ self.mask_selection = cfg.mask_selection
247
+ self.mask_other = cfg.mask_other
248
+ self.mask_length = cfg.mask_length
249
+ self.no_mask_overlap = cfg.no_mask_overlap
250
+ self.mask_min_space = cfg.mask_min_space
251
+
252
+ self.mask_channel_prob = cfg.mask_channel_prob
253
+ self.mask_channel_selection = cfg.mask_channel_selection
254
+ self.mask_channel_other = cfg.mask_channel_other
255
+ self.mask_channel_length = cfg.mask_channel_length
256
+ self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
257
+ self.mask_channel_min_space = cfg.mask_channel_min_space
258
+
259
+ self.dropout_input = nn.Dropout(cfg.dropout_input)
260
+ self.dropout_features = nn.Dropout(cfg.dropout_features)
261
+
262
+ self.feature_grad_mult = cfg.feature_grad_mult
263
+
264
+ self.mask_emb = nn.Parameter(
265
+ torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
266
+ )
267
+
268
+ self.encoder = TransformerEncoder(cfg)
269
+ self.layer_norm = LayerNorm(self.embed)
270
+
271
+ def apply_mask(self, x, padding_mask):
272
+ B, T, C = x.shape
273
+ if self.mask_prob > 0:
274
+ mask_indices = compute_mask_indices(
275
+ (B, T),
276
+ padding_mask,
277
+ self.mask_prob,
278
+ self.mask_length,
279
+ self.mask_selection,
280
+ self.mask_other,
281
+ min_masks=2,
282
+ no_overlap=self.no_mask_overlap,
283
+ min_space=self.mask_min_space,
284
+ )
285
+ mask_indices = torch.from_numpy(mask_indices).to(x.device)
286
+ x[mask_indices] = self.mask_emb
287
+ else:
288
+ mask_indices = None
289
+
290
+ if self.mask_channel_prob > 0:
291
+ mask_channel_indices = compute_mask_indices(
292
+ (B, C),
293
+ None,
294
+ self.mask_channel_prob,
295
+ self.mask_channel_length,
296
+ self.mask_channel_selection,
297
+ self.mask_channel_other,
298
+ no_overlap=self.no_mask_channel_overlap,
299
+ min_space=self.mask_channel_min_space,
300
+ )
301
+ mask_channel_indices = (
302
+ torch.from_numpy(mask_channel_indices)
303
+ .to(x.device)
304
+ .unsqueeze(1)
305
+ .expand(-1, T, -1)
306
+ )
307
+ x[mask_channel_indices] = 0
308
+
309
+ return x, mask_indices
310
+
311
+ def forward_padding_mask(
312
+ self, features: torch.Tensor, padding_mask: torch.Tensor,
313
+ ) -> torch.Tensor:
314
+ extra = padding_mask.size(1) % features.size(1)
315
+ if extra > 0:
316
+ padding_mask = padding_mask[:, :-extra]
317
+ padding_mask = padding_mask.view(
318
+ padding_mask.size(0), features.size(1), -1
319
+ )
320
+ padding_mask = padding_mask.all(-1)
321
+ return padding_mask
322
+
323
+ def extract_features(
324
+ self,
325
+ source: torch.Tensor,
326
+ padding_mask: Optional[torch.Tensor] = None,
327
+ mask: bool = False,
328
+ ret_conv: bool = False,
329
+ output_layer: Optional[int] = None,
330
+ ret_layer_results: bool = False,
331
+ ):
332
+
333
+
334
+ with torch.no_grad():
335
+ features = self.feature_extractor(source)
336
+
337
+ cnn_outs = features
338
+ features = features[-1].transpose(1, 2)
339
+ features = self.layer_norm(features)
340
+
341
+ if padding_mask is not None:
342
+ padding_mask = self.forward_padding_mask(features, padding_mask)
343
+
344
+ if self.post_extract_proj is not None:
345
+ features = self.post_extract_proj(features)
346
+
347
+ features = self.dropout_input(features)
348
+
349
+ if mask:
350
+ x, mask_indices = self.apply_mask(
351
+ features, padding_mask
352
+ )
353
+ else:
354
+ x = features
355
+
356
+ # feature: (B, T, D), float
357
+ # target: (B, T), long
358
+ # x: (B, T, D), float
359
+ # padding_mask: (B, T), bool
360
+ # mask_indices: (B, T), bool
361
+ x, layer_results = self.encoder(
362
+ x,
363
+ padding_mask=padding_mask,
364
+ layer=None if output_layer is None else output_layer - 1
365
+ )
366
+ return cnn_outs, layer_results
367
+ # res = {"x": x, "padding_mask": padding_mask, "features": features, "layer_results": layer_results}
368
+
369
+ # feature = res["features"] if ret_conv else res["x"]
370
+ # if ret_layer_results:
371
+ # feature = (feature, res["layer_results"])
372
+ # return feature, res["padding_mask"]
373
+
374
+
375
+ class ConvFeatureExtractionModel(nn.Module):
376
+ def __init__(
377
+ self,
378
+ conv_layers: List[Tuple[int, int, int]],
379
+ dropout: float = 0.0,
380
+ mode: str = "default",
381
+ conv_bias: bool = False,
382
+ conv_type: str = "default"
383
+ ):
384
+ super().__init__()
385
+
386
+ assert mode in {"default", "layer_norm"}
387
+
388
+ def block(
389
+ n_in,
390
+ n_out,
391
+ k,
392
+ stride,
393
+ is_layer_norm=False,
394
+ is_group_norm=False,
395
+ conv_bias=False,
396
+ ):
397
+ def make_conv():
398
+ conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
399
+ nn.init.kaiming_normal_(conv.weight)
400
+ return conv
401
+
402
+ assert (
403
+ is_layer_norm and is_group_norm
404
+ ) == False, "layer norm and group norm are exclusive"
405
+
406
+ if is_layer_norm:
407
+ return nn.Sequential(
408
+ make_conv(),
409
+ nn.Dropout(p=dropout),
410
+ nn.Sequential(
411
+ TransposeLast(),
412
+ Fp32LayerNorm(dim, elementwise_affine=True),
413
+ TransposeLast(),
414
+ ),
415
+ nn.GELU(),
416
+ )
417
+ elif is_group_norm:
418
+ return nn.Sequential(
419
+ make_conv(),
420
+ nn.Dropout(p=dropout),
421
+ Fp32GroupNorm(dim, dim, affine=True),
422
+ nn.GELU(),
423
+ )
424
+ else:
425
+ return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())
426
+
427
+ self.conv_type = conv_type
428
+ if self.conv_type == "default":
429
+ in_d = 1
430
+ self.conv_layers = nn.ModuleList()
431
+ for i, cl in enumerate(conv_layers):
432
+ assert len(cl) == 3, "invalid conv definition: " + str(cl)
433
+ (dim, k, stride) = cl
434
+
435
+ self.conv_layers.append(
436
+ block(
437
+ in_d,
438
+ dim,
439
+ k,
440
+ stride,
441
+ is_layer_norm=mode == "layer_norm",
442
+ is_group_norm=mode == "default" and i == 0,
443
+ conv_bias=conv_bias,
444
+ )
445
+ )
446
+ in_d = dim
447
+ elif self.conv_type == "conv2d":
448
+ in_d = 1
449
+ self.conv_layers = nn.ModuleList()
450
+ for i, cl in enumerate(conv_layers):
451
+ assert len(cl) == 3
452
+ (dim, k, stride) = cl
453
+
454
+ self.conv_layers.append(
455
+ torch.nn.Conv2d(in_d, dim, k, stride)
456
+ )
457
+ self.conv_layers.append(torch.nn.ReLU())
458
+ in_d = dim
459
+ elif self.conv_type == "custom":
460
+ in_d = 1
461
+ idim = 80
462
+ self.conv_layers = nn.ModuleList()
463
+ for i, cl in enumerate(conv_layers):
464
+ assert len(cl) == 3
465
+ (dim, k, stride) = cl
466
+ self.conv_layers.append(
467
+ torch.nn.Conv2d(in_d, dim, k, stride, padding=1)
468
+ )
469
+ self.conv_layers.append(
470
+ torch.nn.LayerNorm([dim, idim])
471
+ )
472
+ self.conv_layers.append(torch.nn.ReLU())
473
+ in_d = dim
474
+ if (i + 1) % 2 == 0:
475
+ self.conv_layers.append(
476
+ torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)
477
+ )
478
+ idim = int(math.ceil(idim / 2))
479
+ else:
480
+ pass
481
+
482
+ def forward(self, x, mask=None):
483
+
484
+ # BxT -> BxCxT
485
+ x_lst = []
486
+ x = x.unsqueeze(1)
487
+ if self.conv_type == "custom":
488
+ for conv in self.conv_layers:
489
+ if isinstance(conv, nn.LayerNorm):
490
+ x = x.transpose(1, 2)
491
+ x = conv(x).transpose(1, 2)
492
+ else:
493
+ x = conv(x)
494
+ x = x.transpose(2, 3).contiguous()
495
+ x = x.view(x.size(0), -1, x.size(-1))
496
+ else:
497
+ for conv in self.conv_layers:
498
+ x = conv(x)
499
+ x_lst.append(x)
500
+ if self.conv_type == "conv2d":
501
+ b, c, t, f = x.size()
502
+ x = x.transpose(2, 3).contiguous().view(b, c * f, t)
503
+ return x_lst
504
+
505
+
506
+ class TransformerEncoder(nn.Module):
507
+ def __init__(self, args):
508
+ super().__init__()
509
+
510
+ self.dropout = args.dropout
511
+ self.embedding_dim = args.encoder_embed_dim
512
+
513
+ self.pos_conv = nn.Conv1d(
514
+ self.embedding_dim,
515
+ self.embedding_dim,
516
+ kernel_size=args.conv_pos,
517
+ padding=args.conv_pos // 2,
518
+ groups=args.conv_pos_groups,
519
+ )
520
+ dropout = 0
521
+ std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
522
+ nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
523
+ nn.init.constant_(self.pos_conv.bias, 0)
524
+
525
+ self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
526
+ self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
527
+
528
+ if hasattr(args, "relative_position_embedding"):
529
+ self.relative_position_embedding = args.relative_position_embedding
530
+ self.num_buckets = args.num_buckets
531
+ self.max_distance = args.max_distance
532
+ else:
533
+ self.relative_position_embedding = False
534
+ self.num_buckets = 0
535
+ self.max_distance = 0
536
+
537
+ self.layers = nn.ModuleList(
538
+ [
539
+ TransformerSentenceEncoderLayer(
540
+ embedding_dim=self.embedding_dim,
541
+ ffn_embedding_dim=args.encoder_ffn_embed_dim,
542
+ num_attention_heads=args.encoder_attention_heads,
543
+ dropout=self.dropout,
544
+ attention_dropout=args.attention_dropout,
545
+ activation_dropout=args.activation_dropout,
546
+ activation_fn=args.activation_fn,
547
+ layer_norm_first=args.layer_norm_first,
548
+ has_relative_attention_bias=(self.relative_position_embedding and i == 0),
549
+ num_buckets=self.num_buckets,
550
+ max_distance=self.max_distance,
551
+ gru_rel_pos=args.gru_rel_pos,
552
+ )
553
+ for i in range(args.encoder_layers)
554
+ ]
555
+ )
556
+
557
+ self.layer_norm_first = args.layer_norm_first
558
+ self.layer_norm = LayerNorm(self.embedding_dim)
559
+ self.layerdrop = args.encoder_layerdrop
560
+
561
+ self.apply(init_bert_params)
562
+
563
+ def forward(self, x, padding_mask=None, streaming_mask=None, layer=None):
564
+ x, layer_results = self.extract_features(x, padding_mask, streaming_mask, layer)
565
+
566
+ if self.layer_norm_first and layer is None:
567
+ x = self.layer_norm(x)
568
+
569
+ return x, layer_results
570
+
571
+ def extract_features(self, x, padding_mask=None, streaming_mask=None, tgt_layer=None):
572
+
573
+ if padding_mask is not None:
574
+ x[padding_mask] = 0
575
+
576
+ x_conv = self.pos_conv(x.transpose(1, 2))
577
+ x_conv = x_conv.transpose(1, 2)
578
+ x = x + x_conv
579
+
580
+ if not self.layer_norm_first:
581
+ x = self.layer_norm(x)
582
+
583
+ x = F.dropout(x, p=self.dropout, training=self.training)
584
+
585
+ # B x T x C -> T x B x C
586
+ x = x.transpose(0, 1)
587
+
588
+ layer_results = []
589
+ z = None
590
+ if tgt_layer is not None:
591
+ layer_results.append((x, z))
592
+ r = None
593
+ pos_bias = None
594
+ for i, layer in enumerate(self.layers):
595
+ dropout_probability = np.random.random()
596
+ if not self.training or (dropout_probability > self.layerdrop):
597
+ x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False,
598
+ self_attn_mask=streaming_mask, pos_bias=pos_bias)
599
+ if tgt_layer is not None:
600
+ layer_results.append((x, z))
601
+ if i == tgt_layer:
602
+ r = x
603
+ break
604
+
605
+ if r is not None:
606
+ x = r
607
+
608
+ # T x B x C -> B x T x C
609
+ x = x.transpose(0, 1)
610
+
611
+ return x, layer_results
612
+
613
+
614
+ class TransformerSentenceEncoderLayer(nn.Module):
615
+ """
616
+ Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
617
+ models.
618
+ """
619
+
620
+ def __init__(
621
+ self,
622
+ embedding_dim: float = 768,
623
+ ffn_embedding_dim: float = 3072,
624
+ num_attention_heads: float = 8,
625
+ dropout: float = 0.1,
626
+ attention_dropout: float = 0.1,
627
+ activation_dropout: float = 0.1,
628
+ activation_fn: str = "relu",
629
+ layer_norm_first: bool = False,
630
+ has_relative_attention_bias: bool = False,
631
+ num_buckets: int = 0,
632
+ max_distance: int = 0,
633
+ rescale_init: bool = False,
634
+ gru_rel_pos: bool = False,
635
+ ) -> None:
636
+
637
+ super().__init__()
638
+ # Initialize parameters
639
+ self.embedding_dim = embedding_dim
640
+ self.dropout = dropout
641
+ self.activation_dropout = activation_dropout
642
+
643
+ # Initialize blocks
644
+ self.activation_name = activation_fn
645
+ self.activation_fn = get_activation_fn(activation_fn)
646
+ self.self_attn = MultiheadAttention(
647
+ self.embedding_dim,
648
+ num_attention_heads,
649
+ dropout=attention_dropout,
650
+ self_attention=True,
651
+ has_relative_attention_bias=has_relative_attention_bias,
652
+ num_buckets=num_buckets,
653
+ max_distance=max_distance,
654
+ rescale_init=rescale_init,
655
+ gru_rel_pos=gru_rel_pos,
656
+ )
657
+
658
+ self.dropout1 = nn.Dropout(dropout)
659
+ self.dropout2 = nn.Dropout(self.activation_dropout)
660
+ self.dropout3 = nn.Dropout(dropout)
661
+
662
+ self.layer_norm_first = layer_norm_first
663
+
664
+ # layer norm associated with the self attention layer
665
+ self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
666
+
667
+ if self.activation_name == "glu":
668
+ self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
669
+ else:
670
+ self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
671
+ self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
672
+
673
+
675
+
676
+ # layer norm associated with the position wise feed-forward NN
677
+ self.final_layer_norm = LayerNorm(self.embedding_dim)
678
+
679
+ def forward(
680
+ self,
681
+ x: torch.Tensor,
682
+ self_attn_mask: torch.Tensor = None,
683
+ self_attn_padding_mask: torch.Tensor = None,
684
+ need_weights: bool = False,
685
+ pos_bias=None
686
+ ):
687
+ """
688
+ LayerNorm is applied either before or after the self-attention/ffn
689
+ modules similar to the original Transformer implementation.
690
+ """
691
+ residual = x
692
+
693
+ if self.layer_norm_first:
694
+ x = self.self_attn_layer_norm(x)
695
+ x, attn, pos_bias = self.self_attn(
696
+ query=x,
697
+ key=x,
698
+ value=x,
699
+ key_padding_mask=self_attn_padding_mask,
700
+ need_weights=False,
701
+ attn_mask=self_attn_mask,
702
+ position_bias=pos_bias
703
+ )
704
+ x = self.dropout1(x)
705
+ x = residual + x
706
+
707
+ residual = x
708
+
709
+ x = self.final_layer_norm(x)
710
+ if self.activation_name == "glu":
711
+ x = self.fc1(x)
712
+ else:
713
+ x = self.activation_fn(self.fc1(x))
714
+ x = self.dropout2(x)
715
+ x = self.fc2(x)
716
+ x = self.dropout3(x)
717
+ x = residual + x
718
+ else:
719
+
720
+
721
+ x, attn, pos_bias = self.self_attn(
722
+ query=x,
723
+ key=x,
724
+ value=x,
725
+ key_padding_mask=self_attn_padding_mask,
726
+ need_weights=need_weights,
727
+ attn_mask=self_attn_mask,
728
+ position_bias=pos_bias
729
+ )
730
+
731
+ x = self.dropout1(x)
732
+ x = residual + x
733
+
734
+ x = self.self_attn_layer_norm(x)
735
+
736
+ residual = x
737
+ if self.activation_name == "glu":
738
+ x = self.fc1(x)
739
+ else:
740
+ x = self.activation_fn(self.fc1(x))
741
+ x = self.dropout2(x)
742
+ x = self.fc2(x)
743
+ x = self.dropout3(x)
744
+ x = residual + x
745
+
746
+
747
+ x = self.final_layer_norm(x)
748
+
749
+ return x, attn, pos_bias
models/Baseline/modules.py ADDED
@@ -0,0 +1,827 @@
1
+ # --------------------------------------------------------
2
+ # WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/wavlm
4
+ # Copyright (c) 2021 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+ import math
11
+ import warnings
12
+ from typing import Dict, Optional, Tuple
13
+ import torch
14
+ from torch import Tensor, nn
15
+ from torch.nn import Parameter
16
+ import torch.nn.functional as F
17
+
18
+
19
+ class TransposeLast(nn.Module):
20
+ def __init__(self, deconstruct_idx=None):
21
+ super().__init__()
22
+ self.deconstruct_idx = deconstruct_idx
23
+
24
+ def forward(self, x):
25
+ if self.deconstruct_idx is not None:
26
+ x = x[self.deconstruct_idx]
27
+ return x.transpose(-2, -1)
28
+
29
+
30
+ class Fp32LayerNorm(nn.LayerNorm):
31
+ def __init__(self, *args, **kwargs):
32
+ super().__init__(*args, **kwargs)
33
+
34
+ def forward(self, input):
35
+ output = F.layer_norm(
36
+ input.float(),
37
+ self.normalized_shape,
38
+ self.weight.float() if self.weight is not None else None,
39
+ self.bias.float() if self.bias is not None else None,
40
+ self.eps,
41
+ )
42
+ return output.type_as(input)
43
+
44
+
45
+ class Fp32GroupNorm(nn.GroupNorm):
46
+ def __init__(self, *args, **kwargs):
47
+ super().__init__(*args, **kwargs)
48
+
49
+ def forward(self, input):
50
+ output = F.group_norm(
51
+ input.float(),
52
+ self.num_groups,
53
+ self.weight.float() if self.weight is not None else None,
54
+ self.bias.float() if self.bias is not None else None,
55
+ self.eps,
56
+ )
57
+ return output.type_as(input)
58
+
59
+
60
+ class GradMultiply(torch.autograd.Function):
61
+ @staticmethod
62
+ def forward(ctx, x, scale):
63
+ ctx.scale = scale
64
+ res = x.new(x)
65
+ return res
66
+
67
+ @staticmethod
68
+ def backward(ctx, grad):
69
+ return grad * ctx.scale, None
70
+
71
+
72
+ class SamePad(nn.Module):
73
+ def __init__(self, kernel_size, causal=False):
74
+ super().__init__()
75
+ if causal:
76
+ self.remove = kernel_size - 1
77
+ else:
78
+ self.remove = 1 if kernel_size % 2 == 0 else 0
79
+
80
+ def forward(self, x):
81
+ if self.remove > 0:
82
+ x = x[:, :, : -self.remove]
83
+ return x
84
+
85
+
86
+ class Swish(nn.Module):
87
+ """Swish function
88
+ """
89
+
90
+ def __init__(self):
91
+ """Construct a Swish activation module."""
92
+ super(Swish, self).__init__()
93
+ self.act = torch.nn.Sigmoid()
94
+
95
+ def forward(self, x):
96
+ return x * self.act(x)
97
+
98
+
99
+ class GLU_Linear(nn.Module):
100
+ def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
101
+ super(GLU_Linear, self).__init__()
102
+
103
+ self.glu_type = glu_type
104
+ self.output_dim = output_dim
105
+
106
+ if glu_type == "sigmoid":
107
+ self.glu_act = torch.nn.Sigmoid()
108
+ elif glu_type == "swish":
109
+ self.glu_act = Swish()
110
+ elif glu_type == "relu":
111
+ self.glu_act = torch.nn.ReLU()
112
+ elif glu_type == "gelu":
113
+ self.glu_act = torch.nn.GELU()
114
+
115
+ if bias_in_glu:
116
+ self.linear = nn.Linear(input_dim, output_dim * 2, True)
117
+ else:
118
+ self.linear = nn.Linear(input_dim, output_dim * 2, False)
119
+
120
+ def forward(self, x):
121
+ # to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case
122
+ x = self.linear(x)
123
+
124
+ if self.glu_type == "bilinear":
125
+ x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2])
126
+ else:
127
+ x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2]))
128
+
129
+ return x
130
+
131
+
132
+ def gelu_accurate(x):
133
+ if not hasattr(gelu_accurate, "_a"):
134
+ gelu_accurate._a = math.sqrt(2 / math.pi)
135
+ return (
136
+ 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
137
+ )
138
+
139
+
140
+ def gelu(x: torch.Tensor) -> torch.Tensor:
141
+ return torch.nn.functional.gelu(x.float()).type_as(x)
142
+
143
+
144
+ def get_activation_fn(activation: str):
145
+ """Returns the activation function corresponding to `activation`"""
146
+
147
+ if activation == "relu":
148
+ return F.relu
149
+ elif activation == "gelu":
150
+ return gelu
151
+ elif activation == "gelu_fast":
152
+ warnings.warn(
153
+ "--activation-fn=gelu_fast has been renamed to gelu_accurate"
154
+ )
155
+ return gelu_accurate
156
+ elif activation == "gelu_accurate":
157
+ return gelu_accurate
158
+ elif activation == "tanh":
159
+ return torch.tanh
160
+ elif activation == "linear":
161
+ return lambda x: x
162
+ elif activation == "glu":
163
+ return lambda x: x
164
+ else:
165
+ raise RuntimeError("--activation-fn {} not supported".format(activation))
166
+
167
+
168
+ def init_bert_params(module):
169
+ """
170
+ Initialize the weights specific to the BERT Model.
171
+ This overrides the default initializations depending on the specified arguments.
172
+ 1. If normal_init_linear_weights is set then weights of linear
173
+ layer will be initialized using the normal distribution and
174
+ bias will be set to the specified value.
175
+ 2. If normal_init_embed_weights is set then weights of embedding
176
+ layer will be initialized using the normal distribution.
177
+ 3. If normal_init_proj_weights is set then weights of
178
+ in_project_weight for MultiHeadAttention initialized using
179
+ the normal distribution (to be validated).
180
+ """
181
+
182
+ def normal_(data):
183
+ # with FSDP, module params will be on CUDA, so we cast them back to CPU
184
+ # so that the RNG is consistent with and without FSDP
185
+ data.copy_(
186
+ data.cpu().normal_(mean=0.0, std=0.02).to(data.device)
187
+ )
188
+
189
+ if isinstance(module, nn.Linear):
190
+ normal_(module.weight.data)
191
+ if module.bias is not None:
192
+ module.bias.data.zero_()
193
+ if isinstance(module, nn.Embedding):
194
+ normal_(module.weight.data)
195
+ if module.padding_idx is not None:
196
+ module.weight.data[module.padding_idx].zero_()
197
+ if isinstance(module, MultiheadAttention):
198
+ normal_(module.q_proj.weight.data)
199
+ normal_(module.k_proj.weight.data)
200
+ normal_(module.v_proj.weight.data)
201
+
202
+
203
+ def quant_noise(module, p, block_size):
204
+ """
205
+ Wraps modules and applies quantization noise to the weights for
206
+ subsequent quantization with Iterative Product Quantization as
207
+ described in "Training with Quantization Noise for Extreme Model Compression"
208
+
209
+ Args:
210
+ - module: nn.Module
211
+ - p: amount of Quantization Noise
212
+ - block_size: size of the blocks for subsequent quantization with iPQ
213
+
214
+ Remarks:
215
+ - Module weights must have the right sizes wrt the block size
216
+ - Only Linear, Embedding and Conv2d modules are supported for the moment
217
+ - For more detail on how to quantize by blocks with convolutional weights,
218
+ see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
219
+ - We implement the simplest form of noise here as stated in the paper
220
+ which consists in randomly dropping blocks
221
+ """
222
+
223
+ # if no quantization noise, don't register hook
224
+ if p <= 0:
225
+ return module
226
+
227
+ # supported modules
228
+ assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
229
+
230
+ # test whether module.weight has the right sizes wrt block_size
231
+ is_conv = module.weight.ndim == 4
232
+
233
+ # 2D matrix
234
+ if not is_conv:
235
+ assert (
236
+ module.weight.size(1) % block_size == 0
237
+ ), "Input features must be a multiple of block sizes"
238
+
239
+ # 4D matrix
240
+ else:
241
+ # 1x1 convolutions
242
+ if module.kernel_size == (1, 1):
243
+ assert (
244
+ module.in_channels % block_size == 0
245
+ ), "Input channels must be a multiple of block sizes"
246
+ # regular convolutions
247
+ else:
248
+ k = module.kernel_size[0] * module.kernel_size[1]
249
+ assert k % block_size == 0, "Kernel size must be a multiple of block size"
250
+
251
+ def _forward_pre_hook(mod, input):
252
+ # no noise for evaluation
253
+ if mod.training:
254
+ if not is_conv:
255
+ # gather weight and sizes
256
+ weight = mod.weight
257
+ in_features = weight.size(1)
258
+ out_features = weight.size(0)
259
+
260
+ # split weight matrix into blocks and randomly drop selected blocks
261
+ mask = torch.zeros(
262
+ in_features // block_size * out_features, device=weight.device
263
+ )
264
+ mask.bernoulli_(p)
265
+ mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
266
+
267
+ else:
268
+ # gather weight and sizes
269
+ weight = mod.weight
270
+ in_channels = mod.in_channels
271
+ out_channels = mod.out_channels
272
+
273
+ # split weight matrix into blocks and randomly drop selected blocks
274
+ if mod.kernel_size == (1, 1):
275
+ mask = torch.zeros(
276
+ int(in_channels // block_size * out_channels),
277
+ device=weight.device,
278
+ )
279
+ mask.bernoulli_(p)
280
+ mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
281
+ else:
282
+ mask = torch.zeros(
283
+ weight.size(0), weight.size(1), device=weight.device
284
+ )
285
+ mask.bernoulli_(p)
286
+ mask = (
287
+ mask.unsqueeze(2)
288
+ .unsqueeze(3)
289
+ .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
290
+ )
291
+
292
+ # scale weights and apply mask
293
+ mask = mask.to(
294
+ torch.bool
295
+ ) # x.bool() is not currently supported in TorchScript
296
+ s = 1 / (1 - p)
297
+ mod.weight.data = s * weight.masked_fill(mask, 0)
298
+
299
+ module.register_forward_pre_hook(_forward_pre_hook)
300
+ return module
301
+
302
+
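The docstring above describes the quantization-noise trick; a minimal usage sketch under assumed sizes (the layer dimensions, `p` and `block_size` values are arbitrary placeholders, not taken from this commit):

# During training, random 8-wide blocks of the weight matrix are zeroed and the
# remaining weights rescaled by 1/(1-p); at eval time the hook is a no-op.
proj = quant_noise(nn.Linear(256, 256), p=0.1, block_size=8)  # 256 is a multiple of block_size
proj.train()
out = proj(torch.randn(4, 256))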
303
+ class MultiheadAttention(nn.Module):
304
+ """Multi-headed attention.
305
+
306
+ See "Attention Is All You Need" for more details.
307
+ """
308
+
309
+ def __init__(
310
+ self,
311
+ embed_dim,
312
+ num_heads,
313
+ kdim=None,
314
+ vdim=None,
315
+ dropout=0.0,
316
+ bias=True,
317
+ add_bias_kv=False,
318
+ add_zero_attn=False,
319
+ self_attention=False,
320
+ encoder_decoder_attention=False,
321
+ q_noise=0.0,
322
+ qn_block_size=8,
323
+ has_relative_attention_bias=False,
324
+ num_buckets=32,
325
+ max_distance=128,
326
+ gru_rel_pos=False,
327
+ rescale_init=False,
328
+ ):
329
+ super().__init__()
330
+ self.embed_dim = embed_dim
331
+ self.kdim = kdim if kdim is not None else embed_dim
332
+ self.vdim = vdim if vdim is not None else embed_dim
333
+ self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
334
+
335
+ self.num_heads = num_heads
336
+ self.dropout_module = nn.Dropout(dropout)
337
+
338
+ self.has_relative_attention_bias = has_relative_attention_bias
339
+ self.num_buckets = num_buckets
340
+ self.max_distance = max_distance
341
+ if self.has_relative_attention_bias:
342
+ self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
343
+
344
+ self.head_dim = embed_dim // num_heads
345
+ self.q_head_dim = self.head_dim
346
+ self.k_head_dim = self.head_dim
347
+ assert (
348
+ self.head_dim * num_heads == self.embed_dim
349
+ ), "embed_dim must be divisible by num_heads"
350
+ self.scaling = self.head_dim ** -0.5
351
+
352
+ self.self_attention = self_attention
353
+ self.encoder_decoder_attention = encoder_decoder_attention
354
+
355
+ assert not self.self_attention or self.qkv_same_dim, (
356
+ "Self-attention requires query, key and " "value to be of the same size"
357
+ )
358
+
359
+ k_bias = True
360
+ if rescale_init:
361
+ k_bias = False
362
+
363
+ k_embed_dim = embed_dim
364
+ q_embed_dim = embed_dim
365
+
366
+ self.k_proj = quant_noise(
367
+ nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size
368
+ )
369
+ self.v_proj = quant_noise(
370
+ nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
371
+ )
372
+ self.q_proj = quant_noise(
373
+ nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size
374
+ )
375
+
376
+ self.out_proj = quant_noise(
377
+ nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
378
+ )
379
+
380
+ if add_bias_kv:
381
+ self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
382
+ self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
383
+ else:
384
+ self.bias_k = self.bias_v = None
385
+
386
+ self.add_zero_attn = add_zero_attn
387
+
388
+ self.gru_rel_pos = gru_rel_pos
389
+ if self.gru_rel_pos:
390
+ self.grep_linear = nn.Linear(self.q_head_dim, 8)
391
+ self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))
392
+
393
+ self.reset_parameters()
394
+
395
+ def reset_parameters(self):
396
+ if self.qkv_same_dim:
397
+ # Empirically observed the convergence to be much better with
398
+ # the scaled initialization
399
+ nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
400
+ nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
401
+ nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
402
+ else:
403
+ nn.init.xavier_uniform_(self.k_proj.weight)
404
+ nn.init.xavier_uniform_(self.v_proj.weight)
405
+ nn.init.xavier_uniform_(self.q_proj.weight)
406
+
407
+ nn.init.xavier_uniform_(self.out_proj.weight)
408
+ if self.out_proj.bias is not None:
409
+ nn.init.constant_(self.out_proj.bias, 0.0)
410
+ if self.bias_k is not None:
411
+ nn.init.xavier_normal_(self.bias_k)
412
+ if self.bias_v is not None:
413
+ nn.init.xavier_normal_(self.bias_v)
414
+ if self.has_relative_attention_bias:
415
+ nn.init.xavier_normal_(self.relative_attention_bias.weight)
416
+
417
+ def _relative_positions_bucket(self, relative_positions, bidirectional=True):
418
+ num_buckets = self.num_buckets
419
+ max_distance = self.max_distance
420
+ relative_buckets = 0
421
+
422
+ if bidirectional:
423
+ num_buckets = num_buckets // 2
424
+ relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
425
+ relative_positions = torch.abs(relative_positions)
426
+ else:
427
+ relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))
428
+
429
+ max_exact = num_buckets // 2
430
+ is_small = relative_positions < max_exact
431
+
432
+ relative_postion_if_large = max_exact + (
433
+ torch.log(relative_positions.float() / max_exact)
434
+ / math.log(max_distance / max_exact)
435
+ * (num_buckets - max_exact)
436
+ ).to(torch.long)
437
+ relative_postion_if_large = torch.min(
438
+ relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
439
+ )
440
+
441
+ relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large)
442
+ return relative_buckets
443
+
444
+ def compute_bias(self, query_length, key_length):
445
+ context_position = torch.arange(query_length, dtype=torch.long)[:, None]
446
+ memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
447
+ relative_position = memory_position - context_position
448
+ relative_position_bucket = self._relative_positions_bucket(
449
+ relative_position,
450
+ bidirectional=True
451
+ )
452
+ relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
453
+ values = self.relative_attention_bias(relative_position_bucket)
454
+ values = values.permute([2, 0, 1])
455
+ return values
456
+
457
+ def forward(
458
+ self,
459
+ query,
460
+ key: Optional[Tensor],
461
+ value: Optional[Tensor],
462
+ key_padding_mask: Optional[Tensor] = None,
463
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
464
+ need_weights: bool = True,
465
+ static_kv: bool = False,
466
+ attn_mask: Optional[Tensor] = None,
467
+ before_softmax: bool = False,
468
+ need_head_weights: bool = False,
469
+ position_bias: Optional[Tensor] = None
470
+ ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
471
+ """Input shape: Time x Batch x Channel
472
+
473
+ Args:
474
+ key_padding_mask (ByteTensor, optional): mask to exclude
475
+ keys that are pads, of shape `(batch, src_len)`, where
476
+ padding elements are indicated by 1s.
477
+ need_weights (bool, optional): return the attention weights,
478
+ averaged over heads (default: False).
479
+ attn_mask (ByteTensor, optional): typically used to
480
+ implement causal attention, where the mask prevents the
481
+ attention from looking forward in time (default: None).
482
+ before_softmax (bool, optional): return the raw attention
483
+ weights and values before the attention softmax.
484
+ need_head_weights (bool, optional): return the attention
485
+ weights for each head. Implies *need_weights*. Default:
486
+ return the average attention weights over all heads.
487
+ """
488
+ if need_head_weights:
489
+ need_weights = True
490
+
491
+ is_tpu = query.device.type == "xla"
492
+
493
+ tgt_len, bsz, embed_dim = query.size()
494
+ src_len = tgt_len
495
+ assert embed_dim == self.embed_dim
496
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
497
+ if key is not None:
498
+ src_len, key_bsz, _ = key.size()
499
+ if not torch.jit.is_scripting():
500
+ assert key_bsz == bsz
501
+ assert value is not None
502
+ assert src_len, bsz == value.shape[:2]
503
+
504
+ if self.has_relative_attention_bias and position_bias is None:
505
+ position_bias = self.compute_bias(tgt_len, src_len)
506
+ position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)
507
+
508
+ if (
509
+ not is_tpu # don't use PyTorch version on TPUs
510
+ and incremental_state is None
511
+ and not static_kv
512
+ # A workaround for quantization to work. Otherwise JIT compilation
513
+ # treats bias in linear module as method.
514
+ and not torch.jit.is_scripting()
515
+ and self.q_head_dim == self.head_dim
516
+ ):
517
+ assert key is not None and value is not None
518
+ assert attn_mask is None
519
+
520
+ attn_mask_rel_pos = None
521
+ if position_bias is not None:
522
+ attn_mask_rel_pos = position_bias
523
+ if self.gru_rel_pos:
524
+ query_layer = query.transpose(0, 1)
525
+ new_x_shape = query_layer.size()[:-1] + (self.num_heads, -1)
526
+ query_layer = query_layer.view(*new_x_shape)
527
+ query_layer = query_layer.permute(0, 2, 1, 3)
528
+ _B, _H, _L, __ = query_layer.size()
529
+
530
+ gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
531
+ _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
532
+ gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
533
+ attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
534
+
535
+ attn_mask_rel_pos = attn_mask_rel_pos.view((-1, tgt_len, tgt_len))
536
+ k_proj_bias = self.k_proj.bias
537
+ if k_proj_bias is None:
538
+ k_proj_bias = torch.zeros_like(self.q_proj.bias)
539
+
540
+ x, attn = F.multi_head_attention_forward(
541
+ query,
542
+ key,
543
+ value,
544
+ self.embed_dim,
545
+ self.num_heads,
546
+ torch.empty([0]),
547
+ torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
548
+ self.bias_k,
549
+ self.bias_v,
550
+ self.add_zero_attn,
551
+ self.dropout_module.p,
552
+ self.out_proj.weight,
553
+ self.out_proj.bias,
554
+ self.training,
555
+ # self.training or self.dropout_module.apply_during_inference,
556
+ key_padding_mask,
557
+ need_weights,
558
+ attn_mask_rel_pos,
559
+ use_separate_proj_weight=True,
560
+ q_proj_weight=self.q_proj.weight,
561
+ k_proj_weight=self.k_proj.weight,
562
+ v_proj_weight=self.v_proj.weight,
563
+ )
564
+ return x, attn, position_bias
565
+
566
+ if incremental_state is not None:
567
+ saved_state = self._get_input_buffer(incremental_state)
568
+ if saved_state is not None and "prev_key" in saved_state:
569
+ # previous time steps are cached - no need to recompute
570
+ # key and value if they are static
571
+ if static_kv:
572
+ assert self.encoder_decoder_attention and not self.self_attention
573
+ key = value = None
574
+ else:
575
+ saved_state = None
576
+
577
+ if self.self_attention:
578
+ q = self.q_proj(query)
579
+ k = self.k_proj(query)
580
+ v = self.v_proj(query)
581
+ elif self.encoder_decoder_attention:
582
+ # encoder-decoder attention
583
+ q = self.q_proj(query)
584
+ if key is None:
585
+ assert value is None
586
+ k = v = None
587
+ else:
588
+ k = self.k_proj(key)
589
+ v = self.v_proj(key)
590
+
591
+ else:
592
+ assert key is not None and value is not None
593
+ q = self.q_proj(query)
594
+ k = self.k_proj(key)
595
+ v = self.v_proj(value)
596
+ q *= self.scaling
597
+
598
+ if self.bias_k is not None:
599
+ assert self.bias_v is not None
600
+ k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
601
+ v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
602
+ if attn_mask is not None:
603
+ attn_mask = torch.cat(
604
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
605
+ )
606
+ if key_padding_mask is not None:
607
+ key_padding_mask = torch.cat(
608
+ [
609
+ key_padding_mask,
610
+ key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
611
+ ],
612
+ dim=1,
613
+ )
614
+
615
+ q = (
616
+ q.contiguous()
617
+ .view(tgt_len, bsz * self.num_heads, self.q_head_dim)
618
+ .transpose(0, 1)
619
+ )
620
+ if k is not None:
621
+ k = (
622
+ k.contiguous()
623
+ .view(-1, bsz * self.num_heads, self.k_head_dim)
624
+ .transpose(0, 1)
625
+ )
626
+ if v is not None:
627
+ v = (
628
+ v.contiguous()
629
+ .view(-1, bsz * self.num_heads, self.head_dim)
630
+ .transpose(0, 1)
631
+ )
632
+
633
+ if saved_state is not None:
634
+ # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
635
+ if "prev_key" in saved_state:
636
+ _prev_key = saved_state["prev_key"]
637
+ assert _prev_key is not None
638
+ prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
639
+ if static_kv:
640
+ k = prev_key
641
+ else:
642
+ assert k is not None
643
+ k = torch.cat([prev_key, k], dim=1)
644
+ src_len = k.size(1)
645
+ if "prev_value" in saved_state:
646
+ _prev_value = saved_state["prev_value"]
647
+ assert _prev_value is not None
648
+ prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
649
+ if static_kv:
650
+ v = prev_value
651
+ else:
652
+ assert v is not None
653
+ v = torch.cat([prev_value, v], dim=1)
654
+ prev_key_padding_mask: Optional[Tensor] = None
655
+ if "prev_key_padding_mask" in saved_state:
656
+ prev_key_padding_mask = saved_state["prev_key_padding_mask"]
657
+ assert k is not None and v is not None
658
+ key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
659
+ key_padding_mask=key_padding_mask,
660
+ prev_key_padding_mask=prev_key_padding_mask,
661
+ batch_size=bsz,
662
+ src_len=k.size(1),
663
+ static_kv=static_kv,
664
+ )
665
+
666
+ saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
667
+ saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
668
+ saved_state["prev_key_padding_mask"] = key_padding_mask
669
+ # In this branch incremental_state is never None
670
+ assert incremental_state is not None
671
+ incremental_state = self._set_input_buffer(incremental_state, saved_state)
672
+ assert k is not None
673
+ assert k.size(1) == src_len
674
+
675
+ # This is part of a workaround to get around fork/join parallelism
676
+ # not supporting Optional types.
677
+ if key_padding_mask is not None and key_padding_mask.dim() == 0:
678
+ key_padding_mask = None
679
+
680
+ if key_padding_mask is not None:
681
+ assert key_padding_mask.size(0) == bsz
682
+ assert key_padding_mask.size(1) == src_len
683
+
684
+ if self.add_zero_attn:
685
+ assert v is not None
686
+ src_len += 1
687
+ k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
688
+ v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
689
+ if attn_mask is not None:
690
+ attn_mask = torch.cat(
691
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
692
+ )
693
+ if key_padding_mask is not None:
694
+ key_padding_mask = torch.cat(
695
+ [
696
+ key_padding_mask,
697
+ torch.zeros(key_padding_mask.size(0), 1).type_as(
698
+ key_padding_mask
699
+ ),
700
+ ],
701
+ dim=1,
702
+ )
703
+
704
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
705
+ attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
706
+
707
+ assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
708
+
709
+ if attn_mask is not None:
710
+ attn_mask = attn_mask.unsqueeze(0)
711
+ attn_weights += attn_mask
712
+
713
+ if key_padding_mask is not None:
714
+ # don't attend to padding symbols
715
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
716
+ if not is_tpu:
717
+ attn_weights = attn_weights.masked_fill(
718
+ key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
719
+ float("-inf"),
720
+ )
721
+ else:
722
+ attn_weights = attn_weights.transpose(0, 2)
723
+ attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
724
+ attn_weights = attn_weights.transpose(0, 2)
725
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
726
+
727
+ if before_softmax:
728
+ return attn_weights, v, position_bias
729
+
730
+ if position_bias is not None:
731
+ if self.gru_rel_pos == 1:
732
+ query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim)
733
+ _B, _H, _L, __ = query_layer.size()
734
+ gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
735
+ _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
736
+ gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
737
+ position_bias = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
738
+
739
+ position_bias = position_bias.view(attn_weights.size())
740
+
741
+ attn_weights = attn_weights + position_bias
742
+
743
+ attn_weights_float = F.softmax(
744
+ attn_weights, dim=-1
745
+ )
746
+ attn_weights = attn_weights_float.type_as(attn_weights)
747
+ attn_probs = self.dropout_module(attn_weights)
748
+
749
+ assert v is not None
750
+ attn = torch.bmm(attn_probs, v)
751
+ assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
752
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
753
+ attn = self.out_proj(attn)
754
+ attn_weights: Optional[Tensor] = None
755
+ if need_weights:
756
+ attn_weights = attn_weights_float.view(
757
+ bsz, self.num_heads, tgt_len, src_len
758
+ ).transpose(1, 0)
759
+ if not need_head_weights:
760
+ # average attention weights over heads
761
+ attn_weights = attn_weights.mean(dim=0)
762
+
763
+ return attn, attn_weights, position_bias
764
+
765
+ @staticmethod
766
+ def _append_prev_key_padding_mask(
767
+ key_padding_mask: Optional[Tensor],
768
+ prev_key_padding_mask: Optional[Tensor],
769
+ batch_size: int,
770
+ src_len: int,
771
+ static_kv: bool,
772
+ ) -> Optional[Tensor]:
773
+ # saved key padding masks have shape (bsz, seq_len)
774
+ if prev_key_padding_mask is not None and static_kv:
775
+ new_key_padding_mask = prev_key_padding_mask
776
+ elif prev_key_padding_mask is not None and key_padding_mask is not None:
777
+ new_key_padding_mask = torch.cat(
778
+ [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
779
+ )
780
+ # During incremental decoding, as the padding token enters and
781
+ # leaves the frame, there will be a time when prev or current
782
+ # is None
783
+ elif prev_key_padding_mask is not None:
784
+ if src_len > prev_key_padding_mask.size(1):
785
+ filler = torch.zeros(
786
+ (batch_size, src_len - prev_key_padding_mask.size(1)),
787
+ device=prev_key_padding_mask.device,
788
+ )
789
+ new_key_padding_mask = torch.cat(
790
+ [prev_key_padding_mask.float(), filler.float()], dim=1
791
+ )
792
+ else:
793
+ new_key_padding_mask = prev_key_padding_mask.float()
794
+ elif key_padding_mask is not None:
795
+ if src_len > key_padding_mask.size(1):
796
+ filler = torch.zeros(
797
+ (batch_size, src_len - key_padding_mask.size(1)),
798
+ device=key_padding_mask.device,
799
+ )
800
+ new_key_padding_mask = torch.cat(
801
+ [filler.float(), key_padding_mask.float()], dim=1
802
+ )
803
+ else:
804
+ new_key_padding_mask = key_padding_mask.float()
805
+ else:
806
+ new_key_padding_mask = prev_key_padding_mask
807
+ return new_key_padding_mask
808
+
809
+ def _get_input_buffer(
810
+ self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
811
+ ) -> Dict[str, Optional[Tensor]]:
812
+ result = self.get_incremental_state(incremental_state, "attn_state")
813
+ if result is not None:
814
+ return result
815
+ else:
816
+ empty_result: Dict[str, Optional[Tensor]] = {}
817
+ return empty_result
818
+
819
+ def _set_input_buffer(
820
+ self,
821
+ incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
822
+ buffer: Dict[str, Optional[Tensor]],
823
+ ):
824
+ return self.set_incremental_state(incremental_state, "attn_state", buffer)
825
+
826
+ def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
827
+ return attn_weights
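A small shape check for the relative position bias produced by `MultiheadAttention.compute_bias` (a sketch only; the bucket and distance values are assumptions in line with typical WavLM settings, not taken from this commit):

mha = MultiheadAttention(embed_dim=768, num_heads=12, self_attention=True,
                         has_relative_attention_bias=True,
                         num_buckets=320, max_distance=800)
bias = mha.compute_bias(query_length=50, key_length=50)
print(bias.shape)  # torch.Size([12, 50, 50]): one (query, key) bias map per attention head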
optimizer/adamw.py ADDED
@@ -0,0 +1,10 @@
1
+ #! /usr/bin/python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ import torch
5
+
6
+ def Optimizer(parameters, lr, **kwargs):
7
+
8
+ print('Initialised AdamW optimizer')
9
+
10
+ return torch.optim.AdamW(parameters, lr = lr);
pseudo_labeling.py ADDED
@@ -0,0 +1,79 @@
1
+ import time
2
+ import argparse
3
+
4
+ import torch
5
+
6
+ from cuml.cluster import KMeans
7
+
8
+ # from sklearn.cluster import KMeans
9
+ from sklearn.cluster import AgglomerativeClustering
10
+
11
+ from sklearn.metrics import normalized_mutual_info_score
12
+
13
+
14
+ def main(args):
15
+ # Load embeddings
16
+ embeddings_file = torch.load(args.embeddings_file)
17
+ files = list(embeddings_file.keys())
18
+ labels = [file.split('/')[-3] for file in files]
19
+ embeddings = torch.cat(list(embeddings_file.values())).numpy()
20
+ print(f"Embedding shape: {embeddings.shape}")
21
+
22
+ # K-Means
23
+ print("KMeans...")
24
+ kmeans_start_time = time.time()
25
+ kmeans = KMeans(
26
+ n_clusters=args.n_clusters,
27
+ random_state=0,
28
+ max_samples_per_batch=1000000,
29
+ verbose=True
30
+ ).fit(embeddings)
31
+ pseudo_labels = kmeans.labels_
32
+ centroids = kmeans.cluster_centers_
33
+ print(f"K-Means duration: {(time.time() - kmeans_start_time)/60:.2f} min")
34
+
35
+ # AHC
36
+ if args.n_clusters_ahc > 0:
37
+ print("AHC...")
38
+ ahc_start_time = time.time()
39
+ ahc_labels = AgglomerativeClustering(
40
+ n_clusters=args.n_clusters_ahc
41
+ ).fit_predict(centroids)
42
+ pseudo_labels = [ahc_labels[pl] for pl in pseudo_labels]
43
+ print(f"AHC duration: {(time.time() - ahc_start_time)/60:.2f} min")
44
+
45
+ # Print NMI
46
+ nmi_score = normalized_mutual_info_score(labels, pseudo_labels)
47
+ print(f"NMI: {nmi_score}")
48
+
49
+ # Export pseudo labels
50
+ with open(args.output_file, 'w') as f:
51
+ for file, pseudo_label in zip(files, pseudo_labels):
52
+ f.write(f"{pseudo_label} {file}\n")
53
+
54
+
55
+ if __name__ == "__main__":
56
+ parser = argparse.ArgumentParser()
57
+ parser.add_argument(
58
+ 'embeddings_file',
59
+ help='Path to embeddings file (.pt).'
60
+ )
61
+ parser.add_argument(
62
+ 'output_file',
63
+ help='Path to output file (.txt).'
64
+ )
65
+ parser.add_argument(
66
+ '--n_clusters',
67
+ help='Number of clusters for KMeans.',
68
+ type=int,
69
+ default=50000
70
+ )
71
+ parser.add_argument(
72
+ '--n_clusters_ahc',
73
+ help='Number of clusters for Agglomerative Clustering.',
74
+ type=int,
75
+ default=7500
76
+ )
77
+ args = parser.parse_args()
78
+
79
+ main(args)
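The script writes one `<pseudo_label> <file>` pair per line; a minimal sketch of reading that file back into a mapping (the path and example entry are placeholders):

pseudo = {}
with open("exp/pseudo_labels.txt") as f:      # placeholder path
    for line in f:
        label, path = line.strip().split(maxsplit=1)
        pseudo[path] = int(label)             # e.g. {'.../id00012/abc/00001.wav': 4213}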
requirements.txt ADDED
@@ -0,0 +1,10 @@
1
+ torch>=1.7.0
2
+ torchaudio>=0.7.0
3
+ numpy
4
+ scipy
5
+ scikit-learn
6
+ pyyaml
7
+ soundfile
8
+
9
+ --extra-index-url https://pypi.nvidia.com
10
+ cuml-cu12==24.8.*
scheduler/steplr.py ADDED
@@ -0,0 +1,16 @@
1
+ #! /usr/bin/python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ import torch
5
+
6
+ def Scheduler(optimizer, test_interval, max_epoch, lr_decay, **kwargs):
7
+
8
+ sche_fn = torch.optim.lr_scheduler.StepLR(optimizer, step_size=test_interval, gamma=lr_decay)
9
+
10
+ lr_step = 'epoch'
11
+
12
+ print('Initialised step LR scheduler')
13
+
14
+ return sche_fn, lr_step
15
+
16
+
tools/rsync_jz.sh ADDED
@@ -0,0 +1,23 @@
1
+ source_path="."
2
+ target_path="jeanzay:~/wavlm_ssl_sv"
3
+
4
+ rsync -azh $source_path $target_path \
5
+ --progress \
6
+ --force \
7
+ --delete \
8
+ --exclude="slurm_*" \
9
+ --exclude="data" \
10
+ --exclude="exp" \
11
+ --keep-dirlinks
12
+
13
+ while inotifywait -r -e modify,create,delete $source_path
14
+ do
15
+ rsync -azh $source_path $target_path \
16
+ --progress \
17
+ --force \
18
+ --delete \
19
+ --exclude="slurm_*" \
20
+ --exclude="data" \
21
+ --exclude="exp" \
22
+ --keep-dirlinks
23
+ done
trainSpeakerNet.py ADDED
@@ -0,0 +1,395 @@
1
+ #!/usr/bin/python
2
+ #-*- coding: utf-8 -*-
3
+
4
+ import sys, time, os, argparse, socket
5
+ import yaml
6
+ import numpy
+ import numpy as np  # 'np' alias is used in LGL_threshold_update_gmm below
7
+ import pdb
8
+ import torch
9
+ import glob
10
+ import zipfile
11
+ import warnings
12
+ import datetime
13
+ from tuneThreshold import *
14
+ from SpeakerNet import *
15
+ from DatasetLoader import *
16
+ import torch.distributed as dist
17
+ import torch.multiprocessing as mp
18
+ from scipy.stats import norm
19
+ from sklearn.mixture import GaussianMixture
20
+
21
+ ## ===== ===== ===== ===== ===== ===== ===== =====
22
+ ## Parse arguments
23
+ ## ===== ===== ===== ===== ===== ===== ===== =====
24
+ # os.environ['CUDA_VISIBLE_DEVICES']='0,1,2,3'
25
+
26
+ parser = argparse.ArgumentParser(description = "SpeakerNet");
27
+
28
+ parser.add_argument('--config', type=str, default=None, help='Config YAML file');
29
+
30
+ ## Data loader
31
+ parser.add_argument('--max_frames', type=int, default=200, help='Input length to the network for training');
32
+ parser.add_argument('--eval_frames', type=int, default=300, help='Input length to the network for testing; 0 uses the whole files');
33
+ parser.add_argument('--batch_size', type=int, default=400, help='Batch size, number of speakers per batch');
34
+ parser.add_argument('--max_seg_per_spk', type=int, default=500, help='Maximum number of utterances per speaker per epoch');
35
+ parser.add_argument('--nDataLoaderThread', type=int, default=10, help='Number of loader threads');
36
+ parser.add_argument('--augment', type=bool, default=True, help='Augment input')
37
+ parser.add_argument('--seed', type=int, default=20211202, help='Seed for the random number generator');
38
+
39
+
40
+
41
+ ## Training details
42
+ parser.add_argument('--test_interval', type=int, default=1, help='Test and save every [test_interval] epochs');
43
+ parser.add_argument('--max_epoch', type=int, default=50, help='Maximum number of epochs');
44
+ parser.add_argument('--trainfunc', type=str, default="aamsoftmax", help='Loss function');
45
+
46
+ ## Optimizer
47
+ parser.add_argument('--optimizer', type=str, default="adamw", help='Optimizer: sgd, adam or adamw');
48
+ parser.add_argument('--scheduler', type=str, default="steplr", help='Learning rate scheduler');
49
+ parser.add_argument('--lr', type=float, default=0.001, help='Learning rate');
50
+ parser.add_argument("--lr_decay", type=float, default=0.9, help='Learning rate decay every [test_interval] epochs');
51
+
52
+
53
+ ## Pre-trained Transformer Model
54
+ parser.add_argument('--pretrained_model_path', type=str, default="None", help='Absolute path to the pre-trained model');
55
+ parser.add_argument('--weight_finetuning_reg', type=float, default=0.001, help='L2 regularization towards the initial pre-trained model');
56
+ parser.add_argument('--LLRD_factor', type=float, default=1.0, help='Layer-wise Learning Rate Decay (LLRD) factor');
57
+ parser.add_argument('--LR_Transformer', type=float, default=2e-5, help='Learning rate of pre-trained model');
58
+ parser.add_argument('--LR_MHFA', type=float, default=5e-3, help='Learning rate of back-end attentive pooling model');
59
+
60
+ ## Loss functions
61
+ parser.add_argument("--hard_prob", type=float, default=0.5, help='Hard negative mining probability, otherwise random, only for some loss functions');
62
+ parser.add_argument("--hard_rank", type=int, default=10, help='Hard negative mining rank in the batch, only for some loss functions');
63
+ parser.add_argument('--margin', type=float, default=0.2, help='Loss margin, only for some loss functions');
64
+ parser.add_argument('--scale', type=float, default=30, help='Loss scale, only for some loss functions');
65
+ parser.add_argument('--nPerSpeaker', type=int, default=1, help='Number of utterances per speaker per batch, only for metric learning based losses');
66
+ parser.add_argument('--nClasses', type=int, default=5994, help='Number of speakers in the softmax layer, only for softmax-based losses');
67
+
68
+ ## Evaluation parameters
69
+ parser.add_argument('--dcf_p_target', type=float, default=0.05, help='A priori probability of the specified target speaker');
70
+ parser.add_argument('--dcf_c_miss', type=float, default=1, help='Cost of a missed detection');
71
+ parser.add_argument('--dcf_c_fa', type=float, default=1, help='Cost of a spurious detection');
72
+
73
+ ## Load and save
74
+ parser.add_argument('--initial_model', type=str, default="", help='Initial model weights');
75
+ parser.add_argument('--save_path', type=str, default="exps/exp1", help='Path for model and logs');
76
+
77
+ ## Training and test data
78
+ parser.add_argument('--train_list', type=str, default="data/train_list.txt", help='Train list');
79
+ parser.add_argument('--test_list', type=str, default="data/test_list.txt", help='Evaluation list');
80
+ parser.add_argument('--train_path', type=str, default="data/voxceleb2", help='Absolute path to the train set');
81
+ parser.add_argument('--test_path', type=str, default="data/voxceleb1", help='Absolute path to the test set');
82
+ parser.add_argument('--musan_path', type=str, default="data/musan_split", help='Absolute path to the test set');
83
+ parser.add_argument('--rir_path', type=str, default="data/simulated_rirs", help='Absolute path to the test set');
84
+
85
+ ## Model definition
86
+ parser.add_argument('--n_mels', type=int, default=80, help='Number of mel filterbanks');
87
+ parser.add_argument('--log_input', type=bool, default=False, help='Log input features')
88
+ parser.add_argument('--model', type=str, default="", help='Name of model definition');
89
+ parser.add_argument('--encoder_type', type=str, default="SAP", help='Type of encoder');
90
+ parser.add_argument('--nOut', type=int, default=192, help='Embedding size in the last FC layer');
91
+
92
+ ## For test only
93
+ parser.add_argument('--eval', dest='eval', action='store_true', help='Eval only')
94
+
95
+ ## Distributed and mixed precision training
96
+ parser.add_argument('--port', type=str, default="7888", help='Port for distributed training, input as text');
97
+ parser.add_argument('--distributed', dest='distributed', action='store_true', help='Enable distributed training')
98
+ parser.add_argument('--mixedprec', dest='mixedprec', action='store_true', help='Enable mixed precision training')
99
+
100
+ args = parser.parse_args();
101
+
102
+ ## Parse YAML
103
+ def find_option_type(key, parser):
104
+ for opt in parser._get_optional_actions():
105
+ if ('--' + key) in opt.option_strings:
106
+ return opt.type
107
+ raise ValueError
108
+
109
+ if args.config is not None:
110
+ with open(args.config, "r") as f:
111
+ yml_config = yaml.load(f, Loader=yaml.FullLoader)
112
+ for k, v in yml_config.items():
113
+ if k in args.__dict__:
114
+ typ = find_option_type(k, parser)
115
+ args.__dict__[k] = typ(v)
116
+ else:
117
+ sys.stderr.write("Ignored unknown parameter {} in yaml.\n".format(k))
118
+
119
+
120
+ ## Try to import NSML
121
+ try:
122
+ import nsml
123
+ from nsml import HAS_DATASET, DATASET_PATH, PARALLEL_WORLD, PARALLEL_PORTS, MY_RANK
124
+ from nsml import NSML_NFS_OUTPUT, SESSION_NAME
125
+ except:
126
+ pass;
127
+
128
+ warnings.simplefilter("ignore")
129
+
130
+ ## ===== ===== ===== ===== ===== ===== ===== =====
131
+ ## Trainer script
132
+ ## ===== ===== ===== ===== ===== ===== ===== =====
133
+
134
+ def LGL_threshold_update_gmm(loss_vals_path):
135
+ with open(loss_vals_path, 'r') as f:
136
+ lines = [line.strip().split() for line in f.readlines()]
137
+
138
+ # losses = [float(line[0]) for line in lines]
139
+ losses = []
140
+ errs = 0
141
+ for line in lines:
142
+ try:
143
+ losses.append(float(line[0]))
144
+ except ValueError:
145
+ errs += 1
146
+ pass
147
+ if errs > 0:
148
+ print('Could not read %d lines' % errs)
149
+
150
+ log_losses = np.log(losses)
151
+
152
+ gmm = GaussianMixture(n_components=2, random_state=0, covariance_type='full', tol=0.00001, max_iter=1000)
153
+ gmm.fit(log_losses.reshape(-1, 1))
154
+
155
+ mean1 = gmm.means_[0, 0]
156
+ covar1 = gmm.covariances_[0, 0]
157
+ weight1 = gmm.weights_[0]
158
+ x = np.linspace(min(log_losses), max(log_losses), 1000)
159
+ g1 = weight1 * norm.pdf(x, mean1, np.sqrt(covar1))
160
+
161
+ mean2 = gmm.means_[1, 0]
162
+ covar2 = gmm.covariances_[1, 0]
163
+ weight2 = gmm.weights_[1]
164
+ g2 = weight2 * norm.pdf(x, mean2, np.sqrt(covar2))
165
+
166
+ intersection = np.argwhere(np.diff(np.sign(g1 - g2))).flatten()
167
+
168
+ max1 = x[np.argmax(g1)]
169
+ max2 = x[np.argmax(g2)]
170
+ good_intersection = x[intersection][(x[intersection] > min(max1, max2)) & (x[intersection] < max(max1, max2))]
171
+ assert len(good_intersection) == 1, 'Wrong number of intersections'
172
+ good_intersection = good_intersection[0]
173
+
174
+ return good_intersection
175
+
176
+ import idr_torch
177
+
178
+ def main_worker(gpu, ngpus_per_node, args):
179
+
180
+ args.gpu = gpu
181
+
182
+ args.gpu = idr_torch.rank
183
+ ngpus_per_node = idr_torch.size
184
+
185
+ ## Load models
186
+ s = SpeakerNet(**vars(args));
187
+
188
+ if args.distributed:
189
+ # os.environ['MASTER_ADDR']='localhost'
190
+ # os.environ['MASTER_PORT']=args.port
191
+
192
+ # dist.init_process_group(backend='nccl', world_size=ngpus_per_node, rank=args.gpu, init_method='tcp://localhost:12345')
193
+ dist.init_process_group(backend='nccl', world_size=ngpus_per_node, rank=args.gpu)
194
+
195
+ torch.cuda.set_device(args.gpu)
196
+ s.cuda(args.gpu)
197
+
198
+ s = torch.nn.parallel.DistributedDataParallel(s, device_ids=[args.gpu])#, find_unused_parameters=True)
199
+
200
+ print('Loaded the model on GPU {:d}'.format(args.gpu))
201
+
202
+ else:
203
+ s = WrappedModel(s).cuda(args.gpu)
204
+
205
+ it = 1
206
+ eers = [100];
207
+
208
+ if args.gpu == 0:
209
+ ## Write args to scorefile
210
+ scorefile = open(args.result_save_path+"/scores.txt", "a+");
211
+
212
+ ## Initialise trainer and data loader
213
+ train_dataset = train_dataset_loader(**vars(args))
214
+
215
+ train_sampler = train_dataset_sampler(train_dataset, **vars(args))
216
+
217
+ train_loader = torch.utils.data.DataLoader(
218
+ train_dataset,
219
+ batch_size=args.batch_size,
220
+ num_workers=args.nDataLoaderThread,
221
+ sampler=train_sampler,
222
+ pin_memory=True,
223
+ worker_init_fn=worker_init_fn,
224
+ drop_last=True,
225
+ )
226
+
227
+ # trainLoader = get_data_loader(args.train_list, **vars(args));
228
+ trainer = ModelTrainer(s, **vars(args))
229
+
230
+ ## Load model weights
231
+ modelfiles = glob.glob('%s/model0*.model'%args.model_save_path)
232
+ modelfiles.sort()
233
+
234
+ if(args.initial_model != ""):
235
+ trainer.loadParameters(args.initial_model);
236
+ print("Model {} loaded!".format(args.initial_model));
237
+ elif len(modelfiles) >= 1:
238
+ print("Model {} loaded from previous state!".format(modelfiles[-1]));
239
+ trainer.loadParameters(modelfiles[-1]);
240
+ it = int(os.path.splitext(os.path.basename(modelfiles[-1]))[0][5:]) + 1
241
+
242
+ for ii in range(1,it):
243
+ trainer.__scheduler__.step()
244
+
245
+
246
+ pytorch_total_params = sum(p.numel() for p in s.module.__S__.parameters())
247
+
248
+ print('Total parameters: ',pytorch_total_params)
249
+ ## Evaluation code - must run on single GPU
250
+ if args.eval == True:
251
+
252
+
253
+ print('Test list',args.test_list)
254
+
255
+ sc, lab, _, sc1,sc2 = trainer.evaluateFromList(**vars(args))
256
+
257
+ if args.gpu == 0:
258
+
259
+ result = tuneThresholdfromScore(sc, lab, [1, 0.1]);
260
+ result_s1 = tuneThresholdfromScore(sc1, lab, [1, 0.1]);
261
+ result_s2 = tuneThresholdfromScore(sc2, lab, [1, 0.1]);
262
+
263
+
264
+
265
+ fnrs, fprs, thresholds = ComputeErrorRates(sc, lab)
266
+ mindcf, threshold = ComputeMinDcf(fnrs, fprs, thresholds, args.dcf_p_target, args.dcf_c_miss, args.dcf_c_fa)
267
+
268
+ print('\n',time.strftime("%Y-%m-%d %H:%M:%S"), "VEER {:2.4f}".format(result[1]), "VEER_s1 {:2.4f}".format(result_s1[1]),"VEER_s2 {:2.4f}".format(result_s2[1]),"MinDCF {:2.5f}".format(mindcf));
269
+
270
+ if ("nsml" in sys.modules) and args.gpu == 0:
271
+ training_report = {};
272
+ training_report["summary"] = True;
273
+ training_report["epoch"] = it;
274
+ training_report["step"] = it;
275
+ training_report["val_eer"] = result[1];
276
+ training_report["val_dcf"] = mindcf;
277
+
278
+ nsml.report(**training_report);
279
+
280
+ return
281
+
282
+ ## Save training code and params
283
+ if args.gpu == 0:
284
+ pyfiles = glob.glob('./*.py')
285
+ strtime = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
286
+
287
+ zipf = zipfile.ZipFile(args.result_save_path+ '/run%s.zip'%strtime, 'w', zipfile.ZIP_DEFLATED)
288
+ for file in pyfiles:
289
+ zipf.write(file)
290
+ zipf.close()
291
+
292
+ with open(args.result_save_path + '/run%s.cmd'%strtime, 'w') as f:
293
+ f.write('%s'%args)
294
+
295
+
296
+ ## Core training script
297
+ for it in range(it,args.max_epoch+1):
298
+
299
+ train_sampler.set_epoch(it)
300
+
301
+ clr = [x['lr'] for x in trainer.__optimizer__.param_groups]
302
+
303
+ loss_vals_dir = 'exp/' + args.save_path.split('/')[-1] + '/loss_vals'
304
+ os.makedirs(loss_vals_dir, exist_ok=True)
305
+ loss_vals_path = os.path.join(loss_vals_dir, 'epoch_%d.txt' % it)
306
+
307
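+ # From epoch 5 onward, re-estimate the LGL threshold from the previous epoch's per-sample loss values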
+ if it >= 5:
308
+ prev_loss_vals_path = os.path.join(loss_vals_dir, 'epoch_%d.txt' % (it - 1))
309
+ LGL_threshold = LGL_threshold_update_gmm(prev_loss_vals_path)
310
+ # LGL_threshold = 1
311
+
312
+ if args.gpu == 0:
313
+ if LGL_threshold is not None:
314
+ print('Updated LGL threshold to %f' % LGL_threshold)
315
+ else:
316
+ print('Wrong number of intersections, keeping the previous LGL threshold')
317
+
318
+ trainer.update_lgl_threshold(LGL_threshold)
319
+
320
+
321
+ loss, traineer = trainer.train_network(train_loader, loss_vals_path, it, verbose=(args.gpu == 0))
322
+
323
+ if args.distributed:
324
+ dist.barrier()
325
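+ # Merge the per-rank loss value files into a single epoch file; the barrier ensures all ranks have finished writing before the merge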
+ with open(loss_vals_path, 'w') as final_file:
326
+ for r in range(dist.get_world_size()):
327
+ part_file_path = f"{loss_vals_path.split('.')[0]}_rank{r}.txt"
328
+ with open(part_file_path, 'r') as part_file:
329
+ final_file.write(part_file.read())
330
+
331
+ if args.gpu == 0:
332
+ print('\n',time.strftime("%Y-%m-%d %H:%M:%S"), "Epoch {:d}, TEER/TAcc {:2.2f}, TLOSS {:f}, LR {:f}".format(it, traineer.item(), loss.item(), max(clr)));
333
+ scorefile.write("Epoch {:d}, TEER/TAcc {:2.2f}, TLOSS {:f}, LR {:f} \n".format(it, traineer.item(), loss.item(), max(clr)));
334
+
335
+ if it % args.test_interval == 0:
336
+
337
+ # sc, lab, _, as1, as2 = trainer.evaluateFromList(**vars(args))
338
+
339
+ if args.gpu == 0:
340
+ trainer.saveParameters(args.model_save_path+"/model%09d.model"%it);
341
+
342
+ scorefile.flush()
343
+
344
+ if ("nsml" in sys.modules) and args.gpu == 0:
345
+ training_report = {};
346
+ training_report["summary"] = True;
347
+ training_report["epoch"] = it;
348
+ training_report["step"] = it;
349
+ training_report["train_loss"] = loss;
350
+ training_report["min_eer"] = min(eers);
351
+
352
+ nsml.report(**training_report);
353
+
354
+ if args.gpu == 0:
355
+ scorefile.close();
356
+
357
+ ## ===== ===== ===== ===== ===== ===== ===== =====
358
+ ## Main function
359
+ ## ===== ===== ===== ===== ===== ===== ===== =====
360
+
361
+
362
+ def main():
363
+
364
+ # print(os.environ.get('CUDA_VISIBLE_DEVICES', 'Not set'))
365
+ # os.environ['CUDA_VISIBLE_DEVICES'] = '1,2'
366
+ # print(os.environ.get('CUDA_VISIBLE_DEVICES', 'Not set'))
367
+
368
+
369
+ if ("nsml" in sys.modules) and not args.eval:
370
+ args.save_path = os.path.join(args.save_path,SESSION_NAME.replace('/','_'))
371
+
372
+ args.model_save_path = args.save_path+"/model"
373
+ args.result_save_path = args.save_path+"/result"
374
+ args.feat_save_path = ""
375
+
376
+ os.makedirs(args.model_save_path, exist_ok=True)
377
+ os.makedirs(args.result_save_path, exist_ok=True)
378
+
379
+ n_gpus = torch.cuda.device_count()
380
+ print(n_gpus)
381
+
382
+ print('Python Version:', sys.version)
383
+ print('PyTorch Version:', torch.__version__)
384
+ print('Number of GPUs:', torch.cuda.device_count())
385
+ print('Save path:',args.save_path)
386
+
387
+ if args.distributed:
388
+ # mp.spawn(main_worker, nprocs=n_gpus, args=(n_gpus, args))
389
+ main_worker(None, None, args)
390
+ else:
391
+ main_worker(0, None, args)
392
+
393
+
394
+ if __name__ == '__main__':
395
+ main()
trainSpeakerNet_Eval.py ADDED
@@ -0,0 +1,250 @@
1
+ #!/usr/bin/python
2
+ #-*- coding: utf-8 -*-
3
+
4
+ import sys, time, os, argparse, socket
5
+ import yaml
6
+ import numpy
7
+ import pdb
8
+ import torch
9
+ import glob
10
+ import zipfile
11
+ import csv
12
+ import warnings
13
+ import datetime
14
+ from tuneThreshold import *
15
+ from SpeakerNet import *
16
+ from DatasetLoader import *
17
+ import torch.distributed as dist
18
+ import torch.multiprocessing as mp
19
+
20
+ ## ===== ===== ===== ===== ===== ===== ===== =====
21
+ ## Parse arguments
22
+ ## ===== ===== ===== ===== ===== ===== ===== =====
23
+ # os.environ['CUDA_VISIBLE_DEVICES']='0'
24
+ parser = argparse.ArgumentParser(description = "SpeakerNet");
25
+
26
+ parser.add_argument('--config', type=str, default=None, help='Config YAML file');
27
+
28
+ ## Data loader
29
+ parser.add_argument('--max_frames', type=int, default=200, help='Input length to the network for training');
30
+ parser.add_argument('--eval_frames', type=int, default=300, help='Input length to the network for testing; 0 uses the whole files');
31
+ parser.add_argument('--batch_size', type=int, default=400, help='Batch size, number of speakers per batch');
32
+ parser.add_argument('--max_seg_per_spk', type=int, default=500, help='Maximum number of utterances per speaker per epoch');
33
+ parser.add_argument('--nDataLoaderThread', type=int, default=10, help='Number of loader threads');
34
+ parser.add_argument('--augment', type=bool, default=True, help='Augment input')
35
+ parser.add_argument('--seed', type=int, default=20211202, help='Seed for the random number generator');
36
+
37
+
38
+
39
+ ## Training details
40
+ parser.add_argument('--test_interval', type=int, default=1, help='Test and save every [test_interval] epochs');
41
+ parser.add_argument('--max_epoch', type=int, default=50, help='Maximum number of epochs');
42
+ parser.add_argument('--trainfunc', type=str, default="aamsoftmax", help='Loss function');
43
+
44
+ ## Optimizer
45
+ parser.add_argument('--optimizer', type=str, default="adamw", help='Optimizer: sgd, adam or adamw');
46
+ parser.add_argument('--scheduler', type=str, default="steplr", help='Learning rate scheduler');
47
+ parser.add_argument('--lr', type=float, default=0.001, help='Learning rate');
48
+
49
+
50
+ ## Pre-trained Transformer Model
51
+ parser.add_argument('--pretrained_model_path', type=str, default="None", help='Absolute path to the pre-trained model');
52
+ parser.add_argument('--weight_finetuning_reg', type=float, default=0.001, help='L2 regularization towards the initial pre-trained model');
53
+ parser.add_argument('--LLRD_factor', type=float, default=1.0, help='Layer-wise Learning Rate Decay (LLRD) factor');
54
+ parser.add_argument('--LR_Transformer', type=float, default=2e-5, help='Learning rate of pre-trained model');
55
+ parser.add_argument('--LR_MHFA', type=float, default=5e-3, help='Learning rate of back-end attentive pooling model');
56
+
57
+ ## Loss functions
58
+ parser.add_argument("--hard_prob", type=float, default=0.5, help='Hard negative mining probability, otherwise random, only for some loss functions');
59
+ parser.add_argument("--hard_rank", type=int, default=10, help='Hard negative mining rank in the batch, only for some loss functions');
60
+ parser.add_argument('--margin', type=float, default=0.2, help='Loss margin, only for some loss functions');
61
+ parser.add_argument('--scale', type=float, default=30, help='Loss scale, only for some loss functions');
62
+ parser.add_argument('--nPerSpeaker', type=int, default=1, help='Number of utterances per speaker per batch, only for metric learning based losses');
63
+ parser.add_argument('--nClasses', type=int, default=5994, help='Number of speakers in the softmax layer, only for softmax-based losses');
64
+
65
+ ## Evaluation parameters
66
+ parser.add_argument('--dcf_p_target', type=float, default=0.05, help='A priori probability of the specified target speaker');
67
+ parser.add_argument('--dcf_c_miss', type=float, default=1, help='Cost of a missed detection');
68
+ parser.add_argument('--dcf_c_fa', type=float, default=1, help='Cost of a spurious detection');
69
+
70
+ ## Load and save
71
+ parser.add_argument('--initial_model', type=str, default="", help='Initial model weights');
72
+ parser.add_argument('--save_path', type=str, default="exps/exp1", help='Path for model and logs');
73
+
74
+ ## Training and test data
75
+ parser.add_argument('--train_list', type=str, default="data/train_list.txt", help='Train list');
76
+ parser.add_argument('--test_list', type=str, default="data/test_list.txt", help='Evaluation list');
77
+ parser.add_argument('--train_path', type=str, default="data/voxceleb2", help='Absolute path to the train set');
78
+ parser.add_argument('--test_path', type=str, default="data/voxceleb1", help='Absolute path to the test set');
79
+ parser.add_argument('--musan_path', type=str, default="data/musan_split", help='Absolute path to the MUSAN noise corpus');
80
+ parser.add_argument('--rir_path', type=str, default="data/simulated_rirs", help='Absolute path to the simulated RIRs');
81
+
82
+ ## Model definition
83
+ parser.add_argument('--n_mels', type=int, default=80, help='Number of mel filterbanks');
84
+ parser.add_argument('--log_input', type=bool, default=False, help='Log input features')
85
+ parser.add_argument('--model', type=str, default="", help='Name of model definition');
86
+ parser.add_argument('--encoder_type', type=str, default="SAP", help='Type of encoder');
87
+ parser.add_argument('--nOut', type=int, default=192, help='Embedding size in the last FC layer');
88
+
89
+ ## For test only
90
+ parser.add_argument('--eval', dest='eval', action='store_true', help='Eval only')
91
+
92
+ parser.add_argument('--generate_embeddings', dest='generate_embeddings', action='store_true', help='Generate embeddings for the train set')
93
+ parser.add_argument('--embeddings_path', type=str, default="")
94
+ parser.add_argument('--generate_pseudo_labels', dest='generate_pseudo_labels', action='store_true', help='Generate pseudo labels for the train set')
95
+
96
+ ## Distributed and mixed precision training
97
+ parser.add_argument('--port', type=str, default="7888", help='Port for distributed training, input as text');
98
+ parser.add_argument('--distributed', dest='distributed', action='store_true', help='Enable distributed training')
99
+ parser.add_argument('--mixedprec', dest='mixedprec', action='store_true', help='Enable mixed precision training')
100
+
101
+ args = parser.parse_args();
102
+
103
+ ## Parse YAML
104
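+ # Look up the argparse type of an option so YAML values can be cast to the expected type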
+ def find_option_type(key, parser):
105
+ for opt in parser._get_optional_actions():
106
+ if ('--' + key) in opt.option_strings:
107
+ return opt.type
108
+ raise ValueError
109
+
110
+ if args.config is not None:
111
+ with open(args.config, "r") as f:
112
+ yml_config = yaml.load(f, Loader=yaml.FullLoader)
113
+ for k, v in yml_config.items():
114
+ if k in args.__dict__:
115
+ typ = find_option_type(k, parser)
116
+ args.__dict__[k] = typ(v)
117
+ else:
118
+ sys.stderr.write("Ignored unknown parameter {} in yaml.\n".format(k))
119
+
120
+
121
+ ## Try to import NSML
122
+ try:
123
+ import nsml
124
+ from nsml import HAS_DATASET, DATASET_PATH, PARALLEL_WORLD, PARALLEL_PORTS, MY_RANK
125
+ from nsml import NSML_NFS_OUTPUT, SESSION_NAME
126
+ except:
127
+ pass;
128
+
129
+ warnings.simplefilter("ignore")
130
+
131
+ ## ===== ===== ===== ===== ===== ===== ===== =====
132
+ ## Trainer script
133
+ ## ===== ===== ===== ===== ===== ===== ===== =====
134
+
135
+ def main_worker(gpu, ngpus_per_node, args):
136
+
137
+ args.gpu = gpu
138
+
139
+ ## Load models
140
+ s = SpeakerNet(**vars(args));
141
+
142
+
143
+ s = WrappedModel(s).cuda(args.gpu)
144
+
145
+ it = 1
146
+ eers = [100];
147
+
148
+ # trainLoader = get_data_loader(args.train_list, **vars(args));
149
+ trainer = ModelTrainer(s, **vars(args))
150
+
151
+ ## Load model weights
152
+ modelfiles = glob.glob('%s/model0*.model'%args.model_save_path)
153
+ modelfiles.sort()
154
+
155
+ if(args.initial_model != ""):
156
+ trainer.loadParameters(args.initial_model);
157
+ print("Model {} loaded!".format(args.initial_model));
158
+ elif len(modelfiles) >= 1:
159
+ # print("Model {} loaded from previous state!".format(modelfiles[-2]));
160
+ # trainer.loadParameters(modelfiles[-2]);
161
+ it = int(os.path.splitext(os.path.basename(modelfiles[-1]))[0][5:]) + 1
162
+
163
+ for ii in range(1,it):
164
+ trainer.__scheduler__.step()
165
+
166
+
167
+ # pytorch_total_params = sum(p.numel() for p in s.module.__S__.model.feature_extractor.parameters())
168
+ pytorch_total_params = sum(p.numel() for p in s.module.__S__.parameters())
169
+
170
+
171
+ print('Total parameters: ',pytorch_total_params)
172
+ # quit();
173
+ ## Evaluation code - must run on single GPU
174
+ if args.eval == True:
175
+ scorefile_score = open(args.result_save_path+"/Eval_scores_mean_O_All.txt", "w");
176
+ print('Test list',args.test_list)
177
+
178
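+ # Evaluate the 14 most recent checkpoints, newest first (assumes at least 14 saved models)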
+ for i in range(1,15):
179
+ print("Model {} loaded from previous state!".format(modelfiles[-i]));
180
+ trainer.loadParameters(modelfiles[-i]);
181
+ # trainer.loadParameters(modelfiles[0]);
182
+
183
+ # sc, lab, _,sc1,sc2 = trainer.evaluateFromList_1utterance(**vars(args))
184
+ sc, lab, _,sc1,sc2 = trainer.evaluateFromList(**vars(args))
185
+
186
+ if args.gpu == 0:
187
+
188
+ result = tuneThresholdfromScore(sc, lab, [1, 0.1]);
189
+ result1 = tuneThresholdfromScore(sc1, lab, [1, 0.1]);
190
+ result2 = tuneThresholdfromScore(sc2, lab, [1, 0.1]);
191
+
192
+ fnrs, fprs, thresholds = ComputeErrorRates(sc, lab)
193
+
194
+ mindcf, threshold = ComputeMinDcf(fnrs, fprs, thresholds, args.dcf_p_target, args.dcf_c_miss, args.dcf_c_fa)
195
+ mindcf_1, threshold_1 = ComputeMinDcf(fnrs, fprs, thresholds, 0.01, args.dcf_c_miss, args.dcf_c_fa)
196
+
197
+ print('\n',time.strftime("%Y-%m-%d %H:%M:%S"), "VEER {:2.4f}".format(result[1]),"MinDCF05 {:2.5f}".format(mindcf), "MinDCF01 {:2.5f}".format(mindcf_1));
198
+
199
+ scorefile_score.write("Epoch {}, VEER {:2.4f}, VEER_S1 {:2.4f}, VEER_S2 {:2.4f}, MinDCF05 {:2.5f}, MinDCF01 {:2.5f}\n".format(modelfiles[-i], result[1], result1[1], result2[1], mindcf,mindcf_1));
200
+ scorefile_score.flush()
201
+
202
+ scorefile_score.close()
203
+ return
204
+
205
+ if args.generate_embeddings == True:
206
+ print('Generate embeddings for the train set')
207
+ wav_list_file = args.train_list
208
+ with open(wav_list_file,'r') as f:
209
+ wav_files = [args.train_path + '/' + line.strip().split()[1] for line in f.readlines()]
210
+
211
+ print("Model {} loaded from previous state!".format(modelfiles[-1]));
212
+ trainer.loadParameters(modelfiles[-1]);
213
+
214
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
215
+
216
+ trainer.generate_embeddings(wav_files, args.embeddings_path, device)
217
+
218
+
219
+ ## ===== ===== ===== ===== ===== ===== ===== =====
220
+ ## Main function
221
+ ## ===== ===== ===== ===== ===== ===== ===== =====
222
+
223
+
224
+ def main():
225
+
226
+ if ("nsml" in sys.modules) and not args.eval:
227
+ args.save_path = os.path.join(args.save_path,SESSION_NAME.replace('/','_'))
228
+
229
+ args.model_save_path = args.save_path+"/model"
230
+ args.result_save_path = args.save_path+"/result"
231
+ args.feat_save_path = ""
232
+
233
+ os.makedirs(args.model_save_path, exist_ok=True)
234
+ os.makedirs(args.result_save_path, exist_ok=True)
235
+
236
+ n_gpus = torch.cuda.device_count()
237
+
238
+ print('Python Version:', sys.version)
239
+ print('PyTorch Version:', torch.__version__)
240
+ print('Number of GPUs:', torch.cuda.device_count())
241
+ print('Save path:',args.save_path)
242
+
243
+ if args.distributed:
244
+ mp.spawn(main_worker, nprocs=n_gpus, args=(n_gpus, args))
245
+ else:
246
+ main_worker(0, None, args)
247
+
248
+
249
+ if __name__ == '__main__':
250
+ main()
train_ddp_jz.sh ADDED
@@ -0,0 +1,19 @@
1
+ #!/bin/bash
2
+
3
+ #SBATCH --job-name=wavlm_ssl_sv
4
+ #SBATCH --output=slurm_%j
5
+ #SBATCH --nodes=1
6
+ #SBATCH --ntasks=2
7
+ #SBATCH --gres=gpu:2
8
+ #SBATCH --cpus-per-task=10
9
+ #SBATCH --constraint=a100
10
+ #SBATCH --time=20:00:00
11
+ #SBATCH --hint=nomultithread
12
+ #SBATCH --account=kdp@a100
13
+
14
+ module purge
15
+
16
+ module load cpuarch/amd
17
+ module load pytorch-gpu/py3/1.12.1
18
+
19
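+ # srun launches one process per task (2 GPUs here); ranks and world size are picked up via idr_torch in trainSpeakerNet.py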
+ srun python -u trainSpeakerNet.py --config configs/wavlm_mhfa_dlg_lc.yaml --train_list exp/train_list_dino.txt --distributed
training_framework.svg ADDED
tuneThreshold.py ADDED
@@ -0,0 +1,87 @@
1
+ #!/usr/bin/python
2
+ #-*- coding: utf-8 -*-
3
+
4
+ import os
5
+ import glob
6
+ import sys
7
+ import time
8
+ from sklearn import metrics
9
+ import numpy
10
+ import pdb
11
+ from operator import itemgetter
12
+
13
+ def tuneThresholdfromScore(scores, labels, target_fa, target_fr = None):
14
+
15
+ fpr, tpr, thresholds = metrics.roc_curve(labels, scores, pos_label=1)
16
+ fnr = 1 - tpr
17
+
18
+ tunedThreshold = [];
19
+ if target_fr:
20
+ for tfr in target_fr:
21
+ idx = numpy.nanargmin(numpy.absolute((tfr - fnr)))
22
+ tunedThreshold.append([thresholds[idx], fpr[idx], fnr[idx]]);
23
+
24
+ for tfa in target_fa:
25
+ idx = numpy.nanargmin(numpy.absolute((tfa - fpr))) # numpy.where(fpr<=tfa)[0][-1]
26
+ tunedThreshold.append([thresholds[idx], fpr[idx], fnr[idx]]);
27
+
28
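+ # EER: the operating point where the false negative and false positive rates are (closest to) equal, reported in percent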
+ idxE = numpy.nanargmin(numpy.absolute((fnr - fpr)))
29
+ eer = max(fpr[idxE],fnr[idxE])*100
30
+
31
+ return (tunedThreshold, eer, fpr, fnr);
32
+
33
+ # Creates a list of false-negative rates, a list of false-positive rates
34
+ # and a list of decision thresholds that give those error-rates.
35
+ def ComputeErrorRates(scores, labels):
36
+
37
+ # Sort the scores from smallest to largest, and also get the corresponding
38
+ # indexes of the sorted scores. We will treat the sorted scores as the
39
+ # thresholds at which the the error-rates are evaluated.
40
+ sorted_indexes, thresholds = zip(*sorted(
41
+ [(index, threshold) for index, threshold in enumerate(scores)],
42
+ key=itemgetter(1)))
43
+ sorted_labels = []
44
+ labels = [labels[i] for i in sorted_indexes]
45
+ fnrs = []
46
+ fprs = []
47
+
48
+ # At the end of this loop, fnrs[i] is the number of target trials with
49
+ # scores at or below thresholds[i] (misses if the threshold were set there),
50
+ # and fprs[i] is the number of non-target trials with scores at or below
51
+ # thresholds[i] (correct rejections at that threshold).
52
+ for i in range(0, len(labels)):
53
+ if i == 0:
54
+ fnrs.append(labels[i])
55
+ fprs.append(1 - labels[i])
56
+ else:
57
+ fnrs.append(fnrs[i-1] + labels[i])
58
+ fprs.append(fprs[i-1] + 1 - labels[i])
59
+ fnrs_norm = sum(labels)
60
+ fprs_norm = len(labels) - fnrs_norm
61
+
62
+ # Now divide by the total number of target (positive) trials to
63
+ # obtain the false negative rates across all thresholds
64
+ fnrs = [x / float(fnrs_norm) for x in fnrs]
65
+
66
+ # Divide by the total number of non-target (negative) trials to get the
67
+ # true negative rate. Subtract these quantities from 1 to
68
+ # get the false positive rates.
69
+ fprs = [1 - x / float(fprs_norm) for x in fprs]
70
+ return fnrs, fprs, thresholds
71
+
72
+ # Computes the minimum of the detection cost function. The comments refer to
73
+ # equations in Section 3 of the NIST 2016 Speaker Recognition Evaluation Plan.
74
+ def ComputeMinDcf(fnrs, fprs, thresholds, p_target, c_miss, c_fa):
75
+ min_c_det = float("inf")
76
+ min_c_det_threshold = thresholds[0]
77
+ for i in range(0, len(fnrs)):
78
+ # See Equation (2). it is a weighted sum of false negative
79
+ # and false positive errors.
80
+ c_det = c_miss * fnrs[i] * p_target + c_fa * fprs[i] * (1 - p_target)
81
+ if c_det < min_c_det:
82
+ min_c_det = c_det
83
+ min_c_det_threshold = thresholds[i]
84
+ # See Equations (3) and (4). Now we normalize the cost.
85
+ c_def = min(c_miss * p_target, c_fa * (1 - p_target))
86
+ min_dcf = min_c_det / c_def
87
+ return min_dcf, min_c_det_threshold
utils.py ADDED
@@ -0,0 +1,38 @@
1
+ #! /usr/bin/python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+
7
+ def accuracy(output, target, topk=(1,)):
8
+ """Computes the precision@k for the specified values of k"""
9
+ maxk = max(topk)
10
+ batch_size = target.size(0)
11
+
12
+ _, pred = output.topk(maxk, 1, True, True)
13
+ pred = pred.t()
14
+ correct = pred.eq(target.view(1, -1).expand_as(pred))
15
+
16
+ res = []
17
+ for k in topk:
18
+ correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)  # reshape() works whether or not the slice is contiguous; view() can fail here in recent PyTorch versions
19
+ res.append(correct_k.mul_(100.0 / batch_size))
20
+ return res
21
+
22
+ class PreEmphasis(torch.nn.Module):
23
+
24
+ def __init__(self, coef: float = 0.97):
25
+ super().__init__()
26
+ self.coef = coef
27
+ # make kernel
28
+ # In pytorch, the convolution operation uses cross-correlation. So, filter is flipped.
29
+ self.register_buffer(
30
+ 'flipped_filter', torch.FloatTensor([-self.coef, 1.]).unsqueeze(0).unsqueeze(0)
31
+ )
32
+
33
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
34
+ assert len(input.size()) == 2, 'The number of dimensions of input tensor must be 2!'
35
+ # reflect padding to match lengths of in/out
36
+ input = input.unsqueeze(1)
37
+ input = F.pad(input, (1, 0), 'reflect')
38
+ return F.conv1d(input, self.flipped_filter).squeeze(1)
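+
+ # Usage sketch (illustrative, values are arbitrary): apply pre-emphasis to a
+ # batch of raw waveforms before feature extraction, e.g.
+ #   x = torch.randn(4, 16000)   # (batch, samples)
+ #   y = PreEmphasis()(x)        # same shape, high frequencies emphasised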