#! /usr/bin/python
# -*- encoding: utf-8 -*-

import torch
import numpy
import random
import os
import glob

from scipy import signal
import soundfile
from torch.utils.data import Dataset, DataLoader
import torch.distributed as dist


def round_down(num, divisor):
    return num - (num % divisor)


def worker_init_fn(worker_id):
    numpy.random.seed(numpy.random.get_state()[1][0] + worker_id)


def loadWAV(filename, max_frames, evalmode=True, num_eval=5):

    # Maximum audio length: 160 samples per frame (10 ms hop at 16 kHz),
    # plus 240 samples so the final 25 ms analysis window is covered
    max_audio = max_frames * 160 + 240

    # Read wav file and convert to torch tensor
    audio, sample_rate = soundfile.read(filename)

    audiosize = audio.shape[0]

    # Loop-pad utterances that are shorter than the crop length
    if audiosize <= max_audio:
        shortage = max_audio - audiosize + 1
        audio = numpy.pad(audio, (0, shortage), 'wrap')
        audiosize = audio.shape[0]

    if evalmode:
        # Evenly spaced crops over the whole utterance
        startframe = numpy.linspace(0, audiosize - max_audio, num=num_eval)
    else:
        # A single random crop for training
        startframe = numpy.array([numpy.int64(random.random() * (audiosize - max_audio))])

    feats = []
    if evalmode and max_frames == 0:
        # max_frames == 0 in eval mode returns the full utterance
        feats.append(audio)
    else:
        for asf in startframe:
            feats.append(audio[int(asf):int(asf) + max_audio])

    feat = numpy.stack(feats, axis=0).astype(float)

    return feat


class AugmentWAV(object):

    def __init__(self, musan_path, rir_path, max_frames):

        self.max_frames = max_frames
        self.max_audio  = max_frames * 160 + 240

        self.noisetypes = ['noise', 'speech', 'music']

        # SNR ranges (dB) and how many noise files to mix per category
        self.noisesnr  = {'noise': [0, 15], 'speech': [13, 20], 'music': [5, 15]}
        self.numnoise  = {'noise': [1, 1], 'speech': [3, 8], 'music': [1, 1]}
        self.noiselist = {}

        augment_files = glob.glob(os.path.join(musan_path, '*/*/*.wav'))

        for file in augment_files:
            if not file.split('/')[-3] in self.noiselist:
                self.noiselist[file.split('/')[-3]] = []
            self.noiselist[file.split('/')[-3]].append(file)

        self.rir_files = glob.glob(os.path.join(rir_path, '*/*/*.wav'))

    def additive_noise(self, noisecat, audio):

        clean_db = 10 * numpy.log10(numpy.mean(audio ** 2) + 1e-4)

        numnoise  = self.numnoise[noisecat]
        noiselist = random.sample(self.noiselist[noisecat], random.randint(numnoise[0], numnoise[1]))

        noises = []

        for noise in noiselist:
            noiseaudio = loadWAV(noise, self.max_frames, evalmode=False)
            noise_snr  = random.uniform(self.noisesnr[noisecat][0], self.noisesnr[noisecat][1])
            noise_db   = 10 * numpy.log10(numpy.mean(noiseaudio[0] ** 2) + 1e-4)
            # Scale the noise so that the mixture hits the sampled SNR
            noises.append(numpy.sqrt(10 ** ((clean_db - noise_db - noise_snr) / 10)) * noiseaudio)

        return numpy.sum(numpy.concatenate(noises, axis=0), axis=0, keepdims=True) + audio

    def reverberate(self, audio):

        rir_file = random.choice(self.rir_files)

        rir, fs = soundfile.read(rir_file)
        rir     = numpy.expand_dims(rir.astype(float), 0)
        rir     = rir / numpy.sqrt(numpy.sum(rir ** 2))

        return signal.convolve(audio, rir, mode='full')[:, :self.max_audio]
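
# ---------------------------------------------------------------------------
# Worked example (illustrative, not part of the original pipeline): the gain
# applied in AugmentWAV.additive_noise above. With powers measured in dB, a
# clean signal at -20 dB, a noise segment at -10 dB and a target SNR of 15 dB
# give an amplitude factor of sqrt(10 ** ((-20 - (-10) - 15) / 10)) ~= 0.056,
# i.e. the noise power is attenuated by 25 dB relative to the clean signal.
# The helper name below is hypothetical; it only restates the formula used in
# the loop above.

def _example_snr_gain(clean_db, noise_db, target_snr_db):
    # Amplitude factor that places the scaled noise target_snr_db below clean
    return numpy.sqrt(10 ** ((clean_db - noise_db - target_snr_db) / 10))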

class train_dataset_loader(Dataset):

    def __init__(self, train_list, augment, musan_path, rir_path, max_frames, train_path, **kwargs):

        self.augment_wav = AugmentWAV(musan_path=musan_path, rir_path=rir_path, max_frames=max_frames)

        self.train_list = train_list
        self.max_frames = max_frames
        self.musan_path = musan_path
        self.rir_path   = rir_path
        self.augment    = augment

        # Read training files
        with open(train_list) as dataset_file:
            lines = dataset_file.readlines()

        # Make a dictionary of ID names and ID indices
        dictkeys = list(set([x.split()[0] for x in lines]))
        dictkeys.sort()
        dictkeys = {key: ii for ii, key in enumerate(dictkeys)}

        # Parse the training list into file names and ID indices
        self.data_list  = []
        self.data_label = []

        for lidx, line in enumerate(lines):
            data = line.strip().split()

            speaker_label = dictkeys[data[0]]
            filename = os.path.join(train_path, data[1])

            self.data_label.append(speaker_label)
            self.data_list.append(filename)

    def __getitem__(self, indices):

        feat_clean = []
        feat = []

        for index in indices:

            try:
                audio_clean = loadWAV(self.data_list[index], self.max_frames, evalmode=False)
            except Exception:
                # Report the offending file before propagating the error
                print(self.data_list[index])
                raise

            if len(audio_clean.shape) == 3:
                # Debug guard: flag files that load with an unexpected
                # (e.g. multi-channel) shape
                print(self.data_list[index])

            # Default to the clean crop; overwrite when augmentation fires
            audio = audio_clean

            if self.augment:
                augtype = random.randint(0, 5)
                if augtype == 1:
                    audio = self.augment_wav.reverberate(audio_clean)
                elif augtype == 2:
                    audio = self.augment_wav.additive_noise('music', audio_clean)
                elif augtype == 3:
                    audio = self.augment_wav.additive_noise('speech', audio_clean)
                elif augtype == 4:
                    audio = self.augment_wav.additive_noise('noise', audio_clean)
                elif augtype == 5:
                    # Stack two noise types: babble first, then music on top
                    audio = self.augment_wav.additive_noise('speech', audio_clean)
                    audio = self.augment_wav.additive_noise('music', audio)

            feat_clean.append(audio_clean)
            feat.append(audio)

        feat_clean = numpy.concatenate(feat_clean, axis=0)
        feat       = numpy.concatenate(feat, axis=0)

        # All indices in one sample come from the same speaker, so the label
        # and path of the last index stand for the whole sample
        return torch.FloatTensor(feat_clean), torch.FloatTensor(feat), self.data_label[index], self.data_list[index]

    def __len__(self):
        return len(self.data_list)


class test_dataset_loader(Dataset):

    def __init__(self, test_list, test_path, eval_frames, num_eval, **kwargs):
        self.max_frames = eval_frames
        self.num_eval   = num_eval
        self.test_path  = test_path
        self.test_list  = test_list

    def __getitem__(self, index):
        # Fixed-length crops for batched scoring, plus the full utterance
        audio  = loadWAV(os.path.join(self.test_path, self.test_list[index]), self.max_frames, evalmode=True, num_eval=self.num_eval)
        audio2 = loadWAV(os.path.join(self.test_path, self.test_list[index]), 0, evalmode=True, num_eval=self.num_eval)
        return torch.FloatTensor(audio), torch.FloatTensor(audio2), self.test_list[index]

    def __len__(self):
        return len(self.test_list)
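
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original pipeline): how eval-mode
# cropping in loadWAV behaves. With max_frames=300 the crop length is
# 300 * 160 + 240 = 48240 samples (about 3 s at 16 kHz), and num_eval=5
# start offsets are spread evenly over the utterance, so the returned array
# has shape (5, 48240). The demo below runs on synthetic audio; the file
# path is hypothetical and the function is never called by this script.

def _example_eval_crops():
    dummy = numpy.random.randn(16000 * 5) * 0.1    # 5 s of fake 16 kHz audio
    soundfile.write('/tmp/_dummy_eval.wav', dummy, 16000)
    feat = loadWAV('/tmp/_dummy_eval.wav', max_frames=300, evalmode=True, num_eval=5)
    assert feat.shape == (5, 300 * 160 + 240)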

class train_dataset_sampler(torch.utils.data.Sampler):

    def __init__(self, data_source, nPerSpeaker, max_seg_per_spk, batch_size, distributed, seed, **kwargs):
        self.data_label      = data_source.data_label
        self.nPerSpeaker     = nPerSpeaker
        self.max_seg_per_spk = max_seg_per_spk
        self.batch_size      = batch_size
        self.epoch           = 0
        self.seed            = seed
        self.distributed     = distributed

    def __iter__(self):

        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        indices = torch.randperm(len(self.data_label), generator=g).tolist()

        data_dict = {}

        # Sort into dictionary of file indices for each ID
        for index in indices:
            speaker_label = self.data_label[index]
            if not (speaker_label in data_dict):
                data_dict[speaker_label] = []
            data_dict[speaker_label].append(index)

        ## Group file indices for each class
        dictkeys = list(data_dict.keys())
        dictkeys.sort()

        lol = lambda lst, sz: [lst[i:i + sz] for i in range(0, len(lst), sz)]

        flattened_list  = []
        flattened_label = []

        for findex, key in enumerate(dictkeys):
            data = data_dict[key]
            numSeg = round_down(min(len(data), self.max_seg_per_spk), self.nPerSpeaker)

            rp = lol(numpy.arange(numSeg), self.nPerSpeaker)
            flattened_label.extend([findex] * (len(rp)))
            for indices in rp:
                flattened_list.append([data[i] for i in indices])

        ## Mix data in random order
        mixid    = torch.randperm(len(flattened_label), generator=g).tolist()
        mixlabel = []
        mixmap   = []

        ## Prevent two pairs of the same speaker in the same batch
        for ii in mixid:
            startbatch = round_down(len(mixlabel), self.batch_size)
            if flattened_label[ii] not in mixlabel[startbatch:]:
                mixlabel.append(flattened_label[ii])
                mixmap.append(ii)

        mixed_list = [flattened_list[i] for i in mixmap]

        ## Divide data to each GPU
        if self.distributed:
            total_size  = round_down(len(mixed_list), self.batch_size * dist.get_world_size())
            start_index = int((dist.get_rank()) / dist.get_world_size() * total_size)
            end_index   = int((dist.get_rank() + 1) / dist.get_world_size() * total_size)
            self.num_samples = end_index - start_index
            return iter(mixed_list[start_index:end_index])
        else:
            total_size = round_down(len(mixed_list), self.batch_size)
            self.num_samples = total_size
            return iter(mixed_list[:total_size])

    def __len__(self) -> int:
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        self.epoch = epoch


if __name__ == '__main__':

    train_dataset = train_dataset_loader(
        train_list='/mnt/proj3/open-24-5/pengjy_new/WavLM_Adapter/CNCeleb_lst/CNCeleb_trainlist_200spk.txt',
        augment=False,
        musan_path='/mnt/proj3/open-24-5/pengjy_new/musan_split/',
        rir_path='/mnt/proj3/open-24-5/plchot/data_augment/16kHz/simulated_rirs/',
        max_frames=300,
        train_path='/mnt/proj3/open-24-5/pengjy_new/Data/CN-Celeb_flac/data',
    )

    train_sampler = train_dataset_sampler(train_dataset, nPerSpeaker=1, max_seg_per_spk=500, batch_size=100, distributed=False, seed=120)
    # train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=100,
        num_workers=10,
        sampler=train_sampler,
        pin_memory=True,
        drop_last=True,
    )

    # __getitem__ returns (clean, augmented, label, path), so unpack four fields
    for data_clean, data, data_label, data_path in train_loader:
        print(data.shape)
        data = data.transpose(1, 0)
        print(data.shape)
        quit()
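
# ---------------------------------------------------------------------------
# Illustrative note (an assumption about intended use, not from the original
# script): with nPerSpeaker > 1 each element yielded by train_dataset_sampler
# is a list of nPerSpeaker utterance indices from one speaker, so DataLoader
# passes a list to train_dataset_loader.__getitem__ and each returned tensor
# has shape (nPerSpeaker, num_samples). A minimal sketch of the per-speaker
# grouping step, using the same round_down/chunking logic as __iter__; the
# function name and toy indices are hypothetical.

def _example_speaker_grouping():
    lol = lambda lst, sz: [lst[i:i + sz] for i in range(0, len(lst), sz)]
    data = [11, 42, 7, 19, 23]                     # five utterances of one speaker
    numSeg = round_down(min(len(data), 500), 2)    # max_seg_per_spk=500, nPerSpeaker=2
    groups = [[data[i] for i in idx] for idx in lol(numpy.arange(numSeg), 2)]
    assert groups == [[11, 42], [7, 19]]           # the odd utterance is dropped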