#!/usr/bin/python
#-*- coding: utf-8 -*-

import importlib
import itertools
import math
import os
import random
import shutil
import sys
import time

import numpy
import soundfile
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

from DatasetLoader import test_dataset_loader, loadWAV
from tuneThreshold import tuneThresholdfromScore


class WrappedModel(nn.Module):

    ## This wrapper keeps the model interface consistent between single- and multi-GPU
    ## training: the underlying network is always reachable as .module, mirroring DistributedDataParallel

    def __init__(self, model):
        super(WrappedModel, self).__init__()
        self.module = model

    def forward(self, x, x_clean=None, label=None, l2_reg_dict=None, epoch=-1):
        return self.module(x, x_clean, label, l2_reg_dict, epoch=epoch)
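
# A sketch of the intended wiring (the training script is not part of this
# file, so the names below are assumptions): distributed runs wrap SpeakerNet
# in DistributedDataParallel, single-GPU runs use WrappedModel, and either way
# the trainer reaches the network as model.module.
#
#   if distributed:
#       model = nn.parallel.DistributedDataParallel(SpeakerNet(**vars(args)), device_ids=[gpu])
#   else:
#       model = WrappedModel(SpeakerNet(**vars(args))).cuda(gpu)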


class SpeakerNet(nn.Module):

    def __init__(self, model, optimizer, trainfunc, nPerSpeaker, **kwargs):
        super(SpeakerNet, self).__init__()

        SpeakerNetModel = importlib.import_module('models.' + model).__getattribute__('MainModel')
        self.__S__ = SpeakerNetModel(**kwargs)

        LossFunction = importlib.import_module('loss.' + trainfunc).__getattribute__('LossFunction')
        self.__L__ = LossFunction(**kwargs)

        self.nPerSpeaker = nPerSpeaker
        self.weight_finetuning_reg = kwargs['weight_finetuning_reg']


    def forward(self, data, data_clean=None, label=None, l2_reg_dict=None, epoch=-1):
        if label is None:
            ## Inference: extract an embedding for a single utterance
            data_reshape = data[0].cuda()
            outp = self.__S__.forward([data_reshape, data[1]])
            return outp
        elif len(data) == 3 and data[2] == "gen_ps":
            ## Pseudo-label generation pass
            data_reshape = data[0].reshape(-1, data[0].size()[-1]).cuda()
            outp = self.__S__.forward([data_reshape, data[1]])
            pseudo_labels = self.__L__.get_pseudo_labels(outp, label)
            return pseudo_labels
        else:
            ## Training: forward both the noisy and the clean views, then combine
            ## the speaker loss with an L2-SP penalty towards the pretrained weights
            data_reshape = data[0].reshape(-1, data[0].size()[-1]).cuda()
            data_clean_reshape = data_clean.reshape(-1, data_clean.size()[-1]).cuda()
            outp = self.__S__.forward([data_reshape, data[1]])
            outp_clean = self.__S__.forward([data_clean_reshape, data[1]])
            nloss, prec1, ce = self.__L__.forward(outp, outp_clean, label, epoch)

            if l2_reg_dict is not None:
                Learned_dict = l2_reg_dict
                l2_reg = 0
                for name, param in self.__S__.model.named_parameters():
                    if name in Learned_dict:
                        l2_reg = l2_reg + torch.norm(param - Learned_dict[name].cuda(), 2)
                ## Dividing each term by its own detached value pins its forward
                ## magnitude to ~1 (with the gradient scaled by 1/term), so the
                ## speaker loss and the regulariser contribute comparably
                tloss = nloss/nloss.detach() + self.weight_finetuning_reg*l2_reg/(l2_reg.detach() + 1e-5)
            else:
                tloss = nloss
                print("Without L2 Reg")

            return tloss, prec1, nloss, ce




class ModelTrainer(object):

    def __init__(self, speaker_model, optimizer, scheduler, gpu, mixedprec, **kwargs):

        self.__model__  = speaker_model

        ## Split parameters: the pretrained WavLM encoder is optimised separately from the backend
        WavLM_params = list(map(id, self.__model__.module.__S__.model.parameters()))
        Backend_params = filter(lambda p: id(p) not in WavLM_params, self.__model__.module.parameters())
        self.path = kwargs['pretrained_model_path']

        Optimizer = importlib.import_module('optimizer.'+optimizer).__getattribute__('Optimizer')

        # Define the initial param groups
        param_groups = [{'params': Backend_params, 'lr': kwargs['LR_MHFA']}]

        # Extract the encoder layers
        encoder_layers = self.__model__.module.__S__.model.encoder.layers

        # Iterate over the encoder layers to create param groups
        for i in range(12):  # assumes the 12 encoder layers of the BASE model; use 24 for LARGE
            lr = kwargs['LR_Transformer'] * (kwargs['LLRD_factor'] ** i)
            param_groups.append({'params': encoder_layers[i].parameters(), 'lr': lr})
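
        # Example with hypothetical values (LR_Transformer and LLRD_factor come
        # from kwargs): LR_Transformer = 5e-5 and LLRD_factor = 0.9 give layer 0
        # a learning rate of 5e-5 and layer 11 one of 5e-5 * 0.9**11 ≈ 1.57e-5.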

        # Initialize the optimizer with these param groups
        self.__optimizer__ = Optimizer(param_groups, **kwargs)

        Scheduler = importlib.import_module('scheduler.' + scheduler).__getattribute__('Scheduler')
        try:
            self.__scheduler__, self.lr_step = Scheduler(self.__optimizer__, **kwargs)
        except Exception:
            ## Fall back to a default decay for schedulers that require an explicit lr_decay
            self.__scheduler__, self.lr_step = Scheduler(self.__optimizer__, lr_decay=0.9, **kwargs)

        self.gpu = gpu

        self.mixedprec = mixedprec
        print("Mixed precision: %s" % (self.mixedprec))

        assert self.lr_step in ['epoch', 'iteration']

    ## ===== ===== ===== ===== ===== ===== ===== =====
    ## Train network
    ## ===== ===== ===== ===== ===== ===== ===== =====

    def update_lgl_threshold(self, lgl_threshold):
        ## Update the loss-gated-learning (LGL) threshold used by the loss function
        self.__model__.module.__L__.lgl_threshold = lgl_threshold
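
    # Hypothetical call site (value and schedule are assumptions): the training
    # script can tighten the gate between epochs, e.g.
    #   trainer.update_lgl_threshold(3.0)
    # so later train_network() calls filter samples against the new threshold.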
    
    # """
    def train_network(self, loader, loss_vals_path, epoch, verbose):
        if torch.distributed.is_initialized():
            ## Give each rank its own loss-value log to avoid write collisions
            rank = torch.distributed.get_rank()
            base, ext = os.path.splitext(loss_vals_path)
            unique_loss_vals_path = f"{base}_rank{rank}{ext}"
        else:
            unique_loss_vals_path = loss_vals_path

        self.__model__.train()

        stepsize = loader.batch_size

        counter = 0
        index = 0
        loss = 0
        top1 = 0  # EER or accuracy

        tstart = time.time()

        ## Load the pretrained weights; they act as the anchor for the
        ## L2-SP regularisation applied in SpeakerNet.forward
        Learned_dict = {}
        checkpoint = torch.load(self.path)
        for name, param in checkpoint['model'].items():
            if 'w2v_encoder.w2v_model.' in name:
                newname = name.replace('w2v_encoder.w2v_model.', '')
            else:
                newname = name
            Learned_dict[newname] = param

        with open(unique_loss_vals_path, 'w') as loss_vals_file:
            for data_clean, data, data_label, data_path in loader:
                data_clean = data_clean.transpose(1, 0)
                data = data.transpose(1, 0)
                self.__model__.zero_grad()
                label = torch.LongTensor(data_label).cuda()

                nloss, prec1, spkloss, ce = self.__model__([data, "train"], data_clean, label, Learned_dict, epoch=epoch)

                ## One line per utterance: "<ce_value> <path>"; the [5:] slice
                ## assumes a fixed directory depth for the dataset root
                for ce_val, path in zip(ce.detach().cpu().numpy(), data_path):
                    loss_vals_file.write(f'{ce_val} {"/".join(path.split("/")[5:])}\n')

                nloss.backward()
                self.__optimizer__.step()

                loss += spkloss.detach().cpu()
                top1 += prec1.detach().cpu()
                counter += 1
                index += stepsize

                telapsed = time.time() - tstart
                tstart = time.time()

                if verbose:
                    sys.stdout.write("\rProcessing (%d) " % (index))
                    sys.stdout.write("Loss %f TEER/TAcc %2.3f%% - %.2f Hz " % (loss/counter, top1/counter, stepsize/telapsed))
                    sys.stdout.flush()

                if self.lr_step == 'iteration':
                    self.__scheduler__.step()

        if self.lr_step == 'epoch':
            self.__scheduler__.step()

        sys.stdout.write("\n")
        return (loss/counter, top1/counter)
    # """

    ## ===== ===== ===== ===== ===== ===== ===== =====
    ## Evaluate from list
    ## ===== ===== ===== ===== ===== ===== ===== =====

    def evaluateFromList(self, test_list, test_path, nDataLoaderThread, print_interval=10, num_eval=15, **kwargs):

        self.__model__.eval()

        lines = []
        files = []
        feats = {}
        tstart = time.time()

        ## Read all lines
        with open(test_list) as f:
            lines = f.readlines()

        ## Get a sorted list of unique file names
        files = sum([x.strip().split()[-2:] for x in lines], [])
        setfiles = list(set(files))
        setfiles.sort()

        ## Define test data loader
        test_dataset = test_dataset_loader(setfiles, test_path, num_eval=num_eval, **kwargs)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=1,
            shuffle=False,
            num_workers=nDataLoaderThread,
            drop_last=False,
        )
        ref_feat_list = []
        ref_feat_2_list = []
        max_len = 0
        forward = 0

        ## Extract features for every utterance in the test set
        for idx, data in enumerate(test_loader):

            inp1 = data[0][0].cuda()
            inp2 = data[1][0].cuda()
            telapsed_2 = time.time()
            b, utt_l = inp2.shape
            if utt_l > max_len:
                max_len = utt_l
            ref_feat = self.__model__([inp1, "test"]).cuda()
            ref_feat = ref_feat.detach().cpu()
            ## Cap the full utterance at 700000 samples to fit in GPU memory
            ref_feat_2 = self.__model__([inp2[:, :700000], "test"]).cuda()
            ref_feat_2 = ref_feat_2.detach().cpu()

            feats[data[2][0]] = [ref_feat, ref_feat_2]

            ref_feat_list.extend(ref_feat.numpy())
            ref_feat_2_list.extend(ref_feat_2.numpy())

            telapsed = time.time() - tstart
            forward = forward + time.time() - telapsed_2

            if idx % print_interval == 0:
                sys.stdout.write("\rReading %d of %d: %.2f Hz, forward speed: %.2f Hz, embedding size %d, max_len %d" % (idx, len(setfiles), idx/telapsed, idx/forward, ref_feat.size()[-1], max_len))

        print('')
        all_scores = []
        all_labels = []
        all_trials = []
        all_scores_1 = []
        all_scores_2 = []

        tstart = time.time()

        ref_feat_list = numpy.array(ref_feat_list)
        ref_feat_2_list = numpy.array(ref_feat_2_list)

        ## Kept at zero: no mean subtraction / score normalisation is applied
        ref_feat_list_mean = 0
        ref_feat_2_list_mean = 0

        ## Read trial pairs and compute all scores
        for idx, line in enumerate(lines):

            data = line.split()

            ## Append random label if missing
            if len(data) == 2:
                data = [random.randint(0, 1)] + data

            ref_feat, ref_feat_2 = feats[data[1]]
            com_feat, com_feat_2 = feats[data[2]]

            ref_feat = F.normalize(ref_feat - ref_feat_list_mean, p=2, dim=1)  # B, D
            com_feat = F.normalize(com_feat - ref_feat_list_mean, p=2, dim=1)
            ref_feat_2 = F.normalize(ref_feat_2 - ref_feat_2_list_mean, p=2, dim=1)  # B, D
            com_feat_2 = F.normalize(com_feat_2 - ref_feat_2_list_mean, p=2, dim=1)

            ## Average the cosine scores of the two embedding variants;
            ## a higher score indicates the same speaker
            score_1 = torch.mean(torch.matmul(ref_feat, com_feat.T))
            score_2 = torch.mean(torch.matmul(ref_feat_2, com_feat_2.T))
            score = (score_1 + score_2) / 2
            score = score.detach().cpu().numpy()

            all_scores.append(score)
            all_scores_1.append(score_1)
            all_scores_2.append(score_2)

            all_labels.append(int(data[0]))
            all_trials.append(data[1] + " " + data[2])

            if idx % (10*print_interval) == 0:
                telapsed = time.time() - tstart
                sys.stdout.write("\rComputing %d of %d: %.2f Hz" % (idx, len(lines), idx/telapsed))
                sys.stdout.flush()

        print('')

        return (all_scores, all_labels, all_trials, all_scores_1, all_scores_2)
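
    # Typical downstream use (a sketch; the operating points passed to
    # tuneThresholdfromScore are an assumption):
    #   sc, lab, trials, _, _ = trainer.evaluateFromList(test_list, test_path, 5)
    #   result = tuneThresholdfromScore(sc, lab, [1, 0.1])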
    
    def generate_embeddings(self, wav_files, output, device):
        ## Compute one embedding per wav file and save the resulting dict to `output`
        res = {}

        for file in tqdm(wav_files):
            wav, sr = soundfile.read(file)
            wav = torch.from_numpy(wav).float().to(device)

            with torch.no_grad():
                embedding = self.__model__([wav.unsqueeze(0), "test"]).detach().cpu()

            ## Key by the last three path components, e.g. "spk/session/file.wav"
            key = '/'.join(file.split('/')[-3:])
            res[key] = embedding

        torch.save(res, output)
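
    # Usage sketch (paths and device are placeholders, not values from this repo):
    #   wavs = sorted(glob.glob('/path/to/wavs/**/*.wav', recursive=True))
    #   trainer.generate_embeddings(wavs, 'embeddings.pt', torch.device('cuda:0'))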

    def saveParameters(self, path):
        torch.save(self.__model__.module.state_dict(), path)


    ## ===== ===== ===== ===== ===== ===== ===== =====
    ## Load parameters
    ## ===== ===== ===== ===== ===== ===== ===== =====

    def loadParameters(self, path):

        self_state = self.__model__.module.state_dict()
        loaded_state = torch.load(path, map_location="cuda:%d" % self.gpu)

        for name, param in loaded_state.items():
            origname = name

            if name not in self_state:
                ## Checkpoints saved from a DDP-wrapped model carry a "module." prefix
                name = name.replace("module.", "")

                if name not in self_state:
                    print("%s is not in the model." % origname)
                    continue

            if self_state[name].size() != loaded_state[origname].size():
                print("Wrong parameter length: %s, model: %s, loaded: %s" % (origname, self_state[name].size(), loaded_state[origname].size()))
                continue

            self_state[name].copy_(param)
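
# Checkpoint round-trip sketch (the path is a placeholder):
#   trainer.saveParameters('exps/exp1/model000000025.model')
#   trainer.loadParameters('exps/exp1/model000000025.model')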