import os
import random
import librosa
import soundfile as sf
import torch
from speechbrain.pretrained import EncoderClassifier
from torch.multiprocessing import Manager
from torch.multiprocessing import Process
from torch.utils.data import Dataset
from torchaudio.transforms import Resample
from tqdm import tqdm
from Preprocessing.EnCodecAudioPreprocessor import CodecAudioPreprocessor
from Preprocessing.TextFrontend import ArticulatoryCombinedTextFrontend
from Utility.storage_config import MODELS_DIR
class CodecAlignerDataset(Dataset):
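    """
    Dataset of (text, audio codec) pairs for training an aligner.

    On first use (or when rebuild_cache is set), all audio files referenced in
    path_to_transcript_dict are read, non-speech segments are zeroed out with the Silero VAD,
    the transcripts are converted into articulatory feature vectors, the audio is encoded into
    codec indexes and ECAPA speaker embeddings are extracted. Everything is then stored in a
    single cache file (aligner_train_cache.pt) in cache_dir, so later runs only load that cache.
    """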
def __init__(self,
path_to_transcript_dict,
cache_dir,
lang,
loading_processes,
device,
min_len_in_seconds=1,
max_len_in_seconds=15,
rebuild_cache=False,
verbose=False,
phone_input=False,
allow_unknown_symbols=False,
gpu_count=1,
rank=0):
self.gpu_count = gpu_count
self.rank = rank
if not os.path.exists(os.path.join(cache_dir, "aligner_train_cache.pt")) or rebuild_cache:
self._build_dataset_cache(path_to_transcript_dict=path_to_transcript_dict,
cache_dir=cache_dir,
lang=lang,
loading_processes=loading_processes,
device=device,
min_len_in_seconds=min_len_in_seconds,
max_len_in_seconds=max_len_in_seconds,
verbose=verbose,
phone_input=phone_input,
allow_unknown_symbols=allow_unknown_symbols,
gpu_count=gpu_count,
rank=rank)
self.lang = lang
self.device = device
self.cache_dir = cache_dir
self.tf = ArticulatoryCombinedTextFrontend(language=self.lang, device=device)
cache = torch.load(os.path.join(self.cache_dir, "aligner_train_cache.pt"), map_location='cpu')
self.speaker_embeddings = cache[2]
self.filepaths = cache[3]
self.datapoints = cache[0]
if self.gpu_count > 1:
            # we only keep one chunk of the dataset in memory to avoid redundancy; which chunk is determined by this process' rank.
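            # e.g. with 1000 datapoints and 4 GPUs, the process with rank 2 keeps datapoints 500 to 749.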
while len(self.datapoints) % self.gpu_count != 0:
self.datapoints.pop(-1) # a bit unfortunate, but if you're using multiple GPUs, you probably have a ton of datapoints anyway.
chunksize = int(len(self.datapoints) / self.gpu_count)
self.datapoints = self.datapoints[chunksize * self.rank:chunksize * (self.rank + 1)]
self.speaker_embeddings = self.speaker_embeddings[chunksize * self.rank:chunksize * (self.rank + 1)]
print(f"Loaded an Aligner dataset with {len(self.datapoints)} datapoints from {cache_dir}.")
def _build_dataset_cache(self,
path_to_transcript_dict,
cache_dir,
lang,
loading_processes,
device,
min_len_in_seconds=1,
max_len_in_seconds=15,
verbose=False,
phone_input=False,
allow_unknown_symbols=False,
gpu_count=1,
rank=0
):
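        """
        Builds the cache by reading, filtering and featurizing all audio-transcript pairs with
        multiple worker processes and then saving (datapoints, None, speaker_embeddings, filepaths)
        as aligner_train_cache.pt in cache_dir. Feature extraction has to run on a single GPU.
        """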
if gpu_count != 1:
import sys
print("Please run the feature extraction using only a single GPU. Multi-GPU is only supported for training.")
sys.exit()
os.makedirs(cache_dir, exist_ok=True)
        if not isinstance(path_to_transcript_dict, dict):
            path_to_transcript_dict = path_to_transcript_dict()  # a function can be passed instead of the dict, so that the (potentially costly) dict construction only runs if the cache actually needs to be built.
torch.multiprocessing.set_start_method('spawn', force=True)
torch.multiprocessing.set_sharing_strategy('file_system')
resource_manager = Manager()
self.path_to_transcript_dict = resource_manager.dict(path_to_transcript_dict)
key_list = list(self.path_to_transcript_dict.keys())
with open(os.path.join(cache_dir, "files_used.txt"), encoding='utf8', mode="w") as files_used_note:
files_used_note.write(str(key_list))
fisher_yates_shuffle(key_list)
# build cache
print("... building dataset cache ...")
torch.hub._validate_not_a_forked_repo = lambda a, b, c: True # torch 1.9 has a bug in the hub loading, this is a workaround
# careful: assumes 16kHz or 8kHz audio
_, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad', # make sure it gets downloaded during single-processing first, if it's not already downloaded
model='silero_vad',
force_reload=False,
onnx=False,
verbose=False)
self.result_pool = resource_manager.list()
# make processes
key_splits = list()
process_list = list()
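        # the key list is cut into roughly equal shards, one per loading process
        # (e.g. 10 keys and 3 processes result in shards of 3, 3 and 4 keys).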
for i in range(loading_processes):
key_splits.append(
key_list[i * len(key_list) // loading_processes:(i + 1) * len(key_list) // loading_processes])
for key_split in key_splits:
process_list.append(
Process(target=self._cache_builder_process,
args=(key_split,
lang,
min_len_in_seconds,
max_len_in_seconds,
verbose,
device,
phone_input,
allow_unknown_symbols),
daemon=True))
process_list[-1].start()
for process in process_list:
process.join()
print("pooling results...")
pooled_datapoints = list()
for chunk in self.result_pool:
for datapoint in chunk:
pooled_datapoints.append(datapoint) # unpack into a joint list
self.result_pool = pooled_datapoints
del pooled_datapoints
print("converting text to tensors...")
text_tensors = [torch.ShortTensor(x[0]) for x in self.result_pool] # turn everything back to tensors (had to turn it to np arrays to avoid multiprocessing issues)
print("converting speech to tensors...")
speech_tensors = [torch.ShortTensor(x[1]) for x in self.result_pool]
print("converting waves to tensors...")
norm_waves = [torch.Tensor(x[2]) for x in self.result_pool]
print("unpacking file list...")
filepaths = [x[3] for x in self.result_pool]
del self.result_pool
self.datapoints = list(zip(text_tensors, speech_tensors))
del text_tensors
del speech_tensors
print("done!")
# add speaker embeddings
self.speaker_embeddings = list()
speaker_embedding_func_ecapa = EncoderClassifier.from_hparams(source="speechbrain/spkrec-ecapa-voxceleb",
run_opts={"device": str(device)},
savedir=os.path.join(MODELS_DIR, "Embedding", "speechbrain_speaker_embedding_ecapa"))
with torch.inference_mode():
for wave in tqdm(norm_waves):
self.speaker_embeddings.append(speaker_embedding_func_ecapa.encode_batch(wavs=wave.to(device).unsqueeze(0)).squeeze().cpu())
# save to cache
if len(self.datapoints) == 0:
            raise RuntimeError("No datapoints were produced while building the dataset cache, something went wrong.")
torch.save((self.datapoints, None, self.speaker_embeddings, filepaths),
os.path.join(cache_dir, "aligner_train_cache.pt"))
def _cache_builder_process(self,
path_list,
lang,
min_len,
max_len,
verbose,
device,
phone_input,
allow_unknown_symbols):
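        """
        Worker that featurizes one shard of the file list: it loads and resamples the audio,
        silences non-speech regions using the Silero VAD, turns the transcript into text features
        and the audio into codec indexes, and appends the results to the shared result pool.
        """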
process_internal_dataset_chunk = list()
torch.hub._validate_not_a_forked_repo = lambda a, b, c: True # torch 1.9 has a bug in the hub loading, this is a workaround
# careful: assumes 16kHz or 8kHz audio
silero_model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
model='silero_vad',
force_reload=False,
onnx=False,
verbose=False)
(get_speech_timestamps,
save_audio,
read_audio,
VADIterator,
collect_chunks) = utils
        torch.set_grad_enabled(True)  # finding this issue was very infuriating: silero sets this
        # to False globally during model loading, rather than using inference mode or no_grad.
silero_model = silero_model.to(device)
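        # 2000 samples (an eighth of a second at 16 kHz) of silence, later padded to the start and end of every utterance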
silence = torch.zeros([16000 // 8]).to(device)
tf = ArticulatoryCombinedTextFrontend(language=lang, device=device)
_, sr = sf.read(path_list[0])
assumed_sr = sr
ap = CodecAudioPreprocessor(input_sr=assumed_sr, device=device)
resample = Resample(orig_freq=assumed_sr, new_freq=16000).to(device)
for path in tqdm(path_list):
if self.path_to_transcript_dict[path].strip() == "":
continue
try:
wave, sr = sf.read(path)
            except Exception:  # unreadable or corrupted audio file
print(f"Problem with an audio file: {path}")
continue
if len(wave.shape) > 1: # oh no, we found a stereo audio!
if len(wave[0]) == 2: # let's figure out whether we need to switch the axes
wave = wave.transpose() # if yes, we switch the axes.
wave = librosa.to_mono(wave)
if sr != assumed_sr:
assumed_sr = sr
ap = CodecAudioPreprocessor(input_sr=assumed_sr, device=device)
resample = Resample(orig_freq=assumed_sr, new_freq=16000).to(device)
print(f"{path} has a different sampling rate --> adapting the codec processor")
try:
norm_wave = resample(torch.tensor(wave).float().to(device))
except ValueError:
continue
dur_in_seconds = len(norm_wave) / 16000
if not (min_len <= dur_in_seconds <= max_len):
if verbose:
print(f"Excluding {path} because of its duration of {round(dur_in_seconds, 2)} seconds.")
continue
with torch.inference_mode():
speech_timestamps = get_speech_timestamps(norm_wave, silero_model, sampling_rate=16000)
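            # everything the VAD marked as non-speech is overwritten with zeros, then leading and
            # trailing silence is cut off and the short fixed silence is padded to both ends.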
try:
silence_timestamps = invert_segments(speech_timestamps, len(norm_wave))
for silence_timestamp in silence_timestamps:
begin = silence_timestamp['start']
end = silence_timestamp['end']
norm_wave = torch.cat([norm_wave[:begin], torch.zeros([end - begin], device=device), norm_wave[end:]])
result = norm_wave[speech_timestamps[0]['start']:speech_timestamps[-1]['end']]
except IndexError:
print("Audio might be too short to cut silences from front and back.")
continue
norm_wave = torch.cat([silence, result, silence])
# raw audio preprocessing is done
transcript = self.path_to_transcript_dict[path]
try:
try:
cached_text = tf.string_to_tensor(transcript, handle_missing=False, input_phonemes=phone_input).squeeze(0).cpu().numpy()
except KeyError:
cached_text = tf.string_to_tensor(transcript, handle_missing=True, input_phonemes=phone_input).squeeze(0).cpu().numpy()
if not allow_unknown_symbols:
continue # we skip sentences with unknown symbols
            except (ValueError, KeyError):
                # this can happen for Mandarin Chinese, when the syllabification of pinyin doesn't work. In that case, we just skip the sample.
                continue
cached_speech = ap.audio_to_codebook_indexes(audio=norm_wave, current_sampling_rate=16000).transpose(0, 1).cpu().numpy()
process_internal_dataset_chunk.append([cached_text,
cached_speech,
norm_wave.cpu().detach().numpy(),
path])
self.result_pool.append(process_internal_dataset_chunk)
def __getitem__(self, index):
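        """
        Returns the token sequence, its length, the codec indexes, a placeholder (None) and the
        speaker embedding for the requested datapoint.
        """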
text_vector = self.datapoints[index][0]
tokens = self.tf.text_vectors_to_id_sequence(text_vector=text_vector)
tokens = torch.LongTensor(tokens)
token_len = torch.LongTensor([len(tokens)])
codes = self.datapoints[index][1]
if codes.size()[0] != 24: # no clue why this is sometimes the case
codes = codes.transpose(0, 1)
return tokens, \
token_len, \
codes, \
None, \
self.speaker_embeddings[index]
def __len__(self):
return len(self.datapoints)
def remove_samples(self, list_of_samples_to_remove):
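        """
        Removes the datapoints at the given indices and rewrites the cache file accordingly.
        """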
for remove_id in sorted(list_of_samples_to_remove, reverse=True):
self.datapoints.pop(remove_id)
self.speaker_embeddings.pop(remove_id)
self.filepaths.pop(remove_id)
torch.save((self.datapoints, None, self.speaker_embeddings, self.filepaths),
os.path.join(self.cache_dir, "aligner_train_cache.pt"))
print("Dataset updated!")
def fisher_yates_shuffle(lst):
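    """Shuffles the given list in place using the Fisher-Yates algorithm."""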
for i in range(len(lst) - 1, 0, -1):
j = random.randint(0, i)
lst[i], lst[j] = lst[j], lst[i]
def invert_segments(segments, total_duration):
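    """
    Returns the gaps between the given segments, i.e. the regions of the total duration that no
    segment covers.

    e.g. invert_segments([{'start': 2, 'end': 5}], 10)
         returns [{'start': 0, 'end': 2}, {'start': 5, 'end': 10}]
    """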
if not segments:
return [{'start': 0, 'end': total_duration}]
inverted_segments = []
previous_end = 0
for segment in segments:
start = segment['start']
if previous_end < start:
inverted_segments.append({'start': previous_end, 'end': start})
previous_end = segment['end']
if previous_end < total_duration:
inverted_segments.append({'start': previous_end, 'end': total_duration})
    return inverted_segments