MassivelyMultilingualTTS / Modules / Aligner / CodecAlignerDataset.py
Flux9665's picture
update to the current version
70399da
raw
history blame
15.2 kB
import os
import random
import librosa
import soundfile as sf
import torch
from speechbrain.pretrained import EncoderClassifier
from torch.multiprocessing import Manager
from torch.multiprocessing import Process
from torch.utils.data import Dataset
from torchaudio.transforms import Resample
from tqdm import tqdm
from Preprocessing.EnCodecAudioPreprocessor import CodecAudioPreprocessor
from Preprocessing.TextFrontend import ArticulatoryCombinedTextFrontend
from Utility.storage_config import MODELS_DIR
class CodecAlignerDataset(Dataset):
    """
    Dataset of (text feature tensor, codec index tensor) pairs plus one speaker embedding per sample.

    On first use (or when ``rebuild_cache`` is set) every audio/transcript pair is preprocessed in parallel
    worker processes and the result is written to ``<cache_dir>/aligner_train_cache.pt``; afterwards only
    that cache file is loaded. The cache is the tuple ``(datapoints, None, speaker_embeddings, filepaths)``.
    """

    def __init__(self,
                 path_to_transcript_dict,
                 cache_dir,
                 lang,
                 loading_processes,
                 device,
                 min_len_in_seconds=1,
                 max_len_in_seconds=15,
                 rebuild_cache=False,
                 verbose=False,
                 phone_input=False,
                 allow_unknown_symbols=False,
                 gpu_count=1,
                 rank=0):
        """
        Args:
            path_to_transcript_dict: dict mapping audio file paths to transcripts, or a zero-argument callable
                returning such a dict (it is only called if the cache actually has to be built).
            cache_dir: directory that holds (or will hold) ``aligner_train_cache.pt``.
            lang: language passed to the text frontend.
            loading_processes: number of worker processes used while building the cache.
            device: torch device used for preprocessing and by the text frontend.
            min_len_in_seconds: shortest accepted duration (measured after resampling to 16kHz, before silence trimming).
            max_len_in_seconds: longest accepted duration (same measurement).
            rebuild_cache: rebuild the cache even if the cache file already exists.
            verbose: print why individual files are excluded.
            phone_input: transcripts are phonemes instead of graphemes.
            allow_unknown_symbols: keep transcripts that can only be converted with ``handle_missing=True``.
            gpu_count: number of GPUs used for training; with more than one, every rank keeps only its own shard.
            rank: index of this process among the ``gpu_count`` training processes.
        """
        self.gpu_count = gpu_count
        self.rank = rank
        cache_path = os.path.join(cache_dir, "aligner_train_cache.pt")
        if not os.path.exists(cache_path) or rebuild_cache:
            self._build_dataset_cache(path_to_transcript_dict=path_to_transcript_dict,
                                      cache_dir=cache_dir,
                                      lang=lang,
                                      loading_processes=loading_processes,
                                      device=device,
                                      min_len_in_seconds=min_len_in_seconds,
                                      max_len_in_seconds=max_len_in_seconds,
                                      verbose=verbose,
                                      phone_input=phone_input,
                                      allow_unknown_symbols=allow_unknown_symbols,
                                      gpu_count=gpu_count,
                                      rank=rank)
        self.lang = lang
        self.device = device
        self.cache_dir = cache_dir
        self.tf = ArticulatoryCombinedTextFrontend(language=self.lang, device=device)
        cache = torch.load(cache_path, map_location='cpu')
        # cache layout: (datapoints, unused placeholder, speaker_embeddings, filepaths)
        self.speaker_embeddings = cache[2]
        self.filepaths = cache[3]
        self.datapoints = cache[0]
        if self.gpu_count > 1:
            # we only keep a chunk of the dataset in memory to avoid redundancy. Which chunk, we figure out using the rank.
            self._keep_rank_shard()
        print(f"Loaded an Aligner dataset with {len(self.datapoints)} datapoints from {cache_dir}.")

    def _keep_rank_shard(self):
        """
        Keep only this rank's contiguous slice of datapoints, speaker embeddings and file paths.

        All ranks get equally sized shards; the trailing ``len % gpu_count`` samples are dropped (a bit unfortunate,
        but if you're using multiple GPUs, you probably have a ton of datapoints anyway). The three lists are sliced
        identically so they stay index-aligned.
        """
        usable = len(self.datapoints) - len(self.datapoints) % self.gpu_count
        chunksize = usable // self.gpu_count
        start = chunksize * self.rank
        end = start + chunksize
        self.datapoints = self.datapoints[start:end]
        self.speaker_embeddings = self.speaker_embeddings[start:end]
        self.filepaths = self.filepaths[start:end]

    def _build_dataset_cache(self,
                             path_to_transcript_dict,
                             cache_dir,
                             lang,
                             loading_processes,
                             device,
                             min_len_in_seconds=1,
                             max_len_in_seconds=15,
                             verbose=False,
                             phone_input=False,
                             allow_unknown_symbols=False,
                             gpu_count=1,
                             rank=0
                             ):
        """
        Preprocess all files in worker processes, compute speaker embeddings and write the cache file.

        Side effects: forces the 'spawn' multiprocessing start method and the 'file_system' sharing strategy,
        writes ``files_used.txt`` and ``aligner_train_cache.pt`` into ``cache_dir``, and sets
        ``self.path_to_transcript_dict``, ``self.datapoints`` and ``self.speaker_embeddings``.

        Raises:
            RuntimeError: if no file survived preprocessing.
        """
        if gpu_count != 1:
            import sys
            print("Please run the feature extraction using only a single GPU. Multi-GPU is only supported for training.")
            sys.exit(1)  # non-zero exit status: this is a misconfiguration, not a successful run
        os.makedirs(cache_dir, exist_ok=True)
        if not isinstance(path_to_transcript_dict, dict):
            path_to_transcript_dict = path_to_transcript_dict()  # in this case we passed a function instead of the dict, so that the function isn't executed if not necessary.
        torch.multiprocessing.set_start_method('spawn', force=True)
        torch.multiprocessing.set_sharing_strategy('file_system')
        resource_manager = Manager()
        # shared with the worker processes through the pickled bound method used as Process target
        self.path_to_transcript_dict = resource_manager.dict(path_to_transcript_dict)
        key_list = list(self.path_to_transcript_dict.keys())
        with open(os.path.join(cache_dir, "files_used.txt"), encoding='utf8', mode="w") as files_used_note:
            files_used_note.write(str(key_list))
        fisher_yates_shuffle(key_list)
        # build cache
        print("... building dataset cache ...")
        torch.hub._validate_not_a_forked_repo = lambda a, b, c: True  # torch 1.9 has a bug in the hub loading, this is a workaround
        # careful: assumes 16kHz or 8kHz audio
        _, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad',  # make sure it gets downloaded during single-processing first, if it's not already downloaded
                              model='silero_vad',
                              force_reload=False,
                              onnx=False,
                              verbose=False)
        self.result_pool = resource_manager.list()
        # make processes, one per contiguous slice of the shuffled keys
        process_list = list()
        for i in range(loading_processes):
            key_split = key_list[i * len(key_list) // loading_processes:(i + 1) * len(key_list) // loading_processes]
            if not key_split:
                continue  # more processes than files: a worker without files has nothing to do
            process = Process(target=self._cache_builder_process,
                              args=(key_split,
                                    lang,
                                    min_len_in_seconds,
                                    max_len_in_seconds,
                                    verbose,
                                    device,
                                    phone_input,
                                    allow_unknown_symbols),
                              daemon=True)
            process.start()
            process_list.append(process)
        for process in process_list:
            process.join()
        print("pooling results...")
        # every worker appended one list of datapoints; unpack them into a joint list
        pooled_datapoints = [datapoint for chunk in self.result_pool for datapoint in chunk]
        del self.result_pool
        if len(pooled_datapoints) == 0:
            raise RuntimeError("No datapoints survived preprocessing, so there is nothing to cache. "
                               "Check the audio paths, the transcripts and the duration limits.")
        print("converting text to tensors...")
        text_tensors = [torch.ShortTensor(x[0]) for x in pooled_datapoints]  # turn everything back to tensors (had to turn it to np arrays to avoid multiprocessing issues)
        print("converting speech to tensors...")
        speech_tensors = [torch.ShortTensor(x[1]) for x in pooled_datapoints]
        print("converting waves to tensors...")
        norm_waves = [torch.Tensor(x[2]) for x in pooled_datapoints]
        print("unpacking file list...")
        filepaths = [x[3] for x in pooled_datapoints]
        del pooled_datapoints
        self.datapoints = list(zip(text_tensors, speech_tensors))
        del text_tensors
        del speech_tensors
        print("done!")
        # add speaker embeddings, computed from the silence-trimmed 16kHz waves
        self.speaker_embeddings = list()
        speaker_embedding_func_ecapa = EncoderClassifier.from_hparams(source="speechbrain/spkrec-ecapa-voxceleb",
                                                                      run_opts={"device": str(device)},
                                                                      savedir=os.path.join(MODELS_DIR, "Embedding", "speechbrain_speaker_embedding_ecapa"))
        with torch.inference_mode():
            for wave in tqdm(norm_waves):
                self.speaker_embeddings.append(speaker_embedding_func_ecapa.encode_batch(wavs=wave.to(device).unsqueeze(0)).squeeze().cpu())
        # save to cache
        torch.save((self.datapoints, None, self.speaker_embeddings, filepaths),
                   os.path.join(cache_dir, "aligner_train_cache.pt"))

    def _cache_builder_process(self,
                               path_list,
                               lang,
                               min_len,
                               max_len,
                               verbose,
                               device,
                               phone_input,
                               allow_unknown_symbols):
        """
        Worker body: preprocess every file in ``path_list`` and append the survivors to ``self.result_pool``.

        All survivors are appended as ONE list whose entries are ``[text_features, codec_indexes, wave_16k, path]``,
        with numpy arrays instead of tensors (they are converted back in ``_build_dataset_cache``).
        Files are skipped when the transcript is empty, the audio cannot be read or resampled, the duration is
        outside ``[min_len, max_len]`` seconds, no speech is detected, or the text cannot be converted.
        """
        process_internal_dataset_chunk = list()
        torch.hub._validate_not_a_forked_repo = lambda a, b, c: True  # torch 1.9 has a bug in the hub loading, this is a workaround
        # careful: assumes 16kHz or 8kHz audio
        silero_model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                                             model='silero_vad',
                                             force_reload=False,
                                             onnx=False,
                                             verbose=False)
        get_speech_timestamps = utils[0]  # the remaining silero helpers are not needed here
        torch.set_grad_enabled(True)  # finding this issue was very infuriating: silero sets
        # this to false globally during model loading rather than using inference mode or no_grad
        silero_model = silero_model.to(device)
        silence = torch.zeros(16000 // 8, device=device)  # 0.125 seconds of silence at 16kHz, padded on both ends
        tf = ArticulatoryCombinedTextFrontend(language=lang, device=device)
        # codec preprocessor and resampler are (re)built lazily whenever the sampling rate of a file changes
        assumed_sr = None
        ap = None
        resample = None
        for path in tqdm(path_list):
            transcript = self.path_to_transcript_dict[path]
            if transcript.strip() == "":
                continue
            try:
                wave, sr = sf.read(path)
            except Exception:
                print(f"Problem with an audio file: {path}")
                continue
            if wave.ndim > 1:  # oh no, we found a multi-channel audio!
                # soundfile returns (frames, channels); librosa.to_mono averages over the first axis, so channels go first
                wave = librosa.to_mono(wave.transpose())
            if sr != assumed_sr:
                if assumed_sr is not None:
                    print(f"{path} has a different sampling rate --> adapting the codec processor")
                assumed_sr = sr
                ap = CodecAudioPreprocessor(input_sr=assumed_sr, device=device)
                resample = Resample(orig_freq=assumed_sr, new_freq=16000).to(device)
            try:
                norm_wave = resample(torch.tensor(wave).float().to(device))
            except ValueError:
                continue
            dur_in_seconds = len(norm_wave) / 16000
            if not (min_len <= dur_in_seconds <= max_len):
                if verbose:
                    print(f"Excluding {path} because of its duration of {round(dur_in_seconds, 2)} seconds.")
                continue
            with torch.inference_mode():
                speech_timestamps = get_speech_timestamps(norm_wave, silero_model, sampling_rate=16000)
            try:
                # zero out everything between detected speech segments, then cut leading and trailing non-speech
                silence_timestamps = invert_segments(speech_timestamps, len(norm_wave))
                for silence_timestamp in silence_timestamps:
                    begin = silence_timestamp['start']
                    end = silence_timestamp['end']
                    norm_wave = torch.cat([norm_wave[:begin], torch.zeros([end - begin], device=device), norm_wave[end:]])
                result = norm_wave[speech_timestamps[0]['start']:speech_timestamps[-1]['end']]
            except IndexError:
                print("Audio might be too short to cut silences from front and back.")
                continue
            norm_wave = torch.cat([silence, result, silence])
            # raw audio preprocessing is done
            try:
                try:
                    cached_text = tf.string_to_tensor(transcript, handle_missing=False, input_phonemes=phone_input).squeeze(0).cpu().numpy()
                except KeyError:
                    cached_text = tf.string_to_tensor(transcript, handle_missing=True, input_phonemes=phone_input).squeeze(0).cpu().numpy()
                    if not allow_unknown_symbols:
                        continue  # we skip sentences with unknown symbols
            except (ValueError, KeyError):
                # this can happen for Mandarin Chinese, when the syllabification of pinyin doesn't work. In that case, we just skip the sample.
                continue
            cached_speech = ap.audio_to_codebook_indexes(audio=norm_wave, current_sampling_rate=16000).transpose(0, 1).cpu().numpy()
            process_internal_dataset_chunk.append([cached_text,
                                                   cached_speech,
                                                   norm_wave.cpu().detach().numpy(),
                                                   path])
        self.result_pool.append(process_internal_dataset_chunk)

    def __getitem__(self, index):
        """
        Return ``(tokens, token_len, codes, None, speaker_embedding)`` for one sample.

        ``tokens`` is a LongTensor of ids from the text frontend, ``token_len`` a one-element LongTensor with
        its length. ``codes`` are the cached codec indexes, transposed if their first dimension is not 24.
        """
        text_vector = self.datapoints[index][0]
        tokens = self.tf.text_vectors_to_id_sequence(text_vector=text_vector)
        tokens = torch.LongTensor(tokens)
        token_len = torch.LongTensor([len(tokens)])
        codes = self.datapoints[index][1]
        if codes.size()[0] != 24:  # no clue why this is sometimes the case
            # NOTE(review): presumably 24 is the number of codebooks and caches store codes frames-first
            # (see the transpose(0, 1) in _cache_builder_process) — confirm
            codes = codes.transpose(0, 1)
        return tokens, \
            token_len, \
            codes, \
            None, \
            self.speaker_embeddings[index]

    def __len__(self):
        """Number of datapoints held by this instance (only this rank's shard when training on multiple GPUs)."""
        return len(self.datapoints)

    def remove_samples(self, list_of_samples_to_remove):
        """
        Remove the given datapoint indexes and overwrite the cache file with the reduced dataset.

        Duplicate indexes are removed only once. Indexes refer to the current (pre-removal) positions.
        NOTE(review): with gpu_count > 1 this instance only holds its rank's shard, so saving here overwrites the
        shared cache file with that shard — presumably only meant for single-GPU use; confirm.
        """
        # deleting from the back keeps the remaining indexes valid
        for remove_id in sorted(set(list_of_samples_to_remove), reverse=True):
            self.datapoints.pop(remove_id)
            self.speaker_embeddings.pop(remove_id)
            self.filepaths.pop(remove_id)
        torch.save((self.datapoints, None, self.speaker_embeddings, self.filepaths),
                   os.path.join(self.cache_dir, "aligner_train_cache.pt"))
        print("Dataset updated!")
def fisher_yates_shuffle(lst):
    """
    Shuffle ``lst`` in place with the Fisher-Yates algorithm, drawing from the ``random`` module.

    Walks from the last position down to position 1, swapping each element with a uniformly chosen
    element at or before it. Returns None.
    """
    position = len(lst) - 1
    while position > 0:
        swap_with = random.randint(0, position)
        lst[position], lst[swap_with] = lst[swap_with], lst[position]
        position -= 1
def invert_segments(segments, total_duration):
    """
    Return the gaps between ``segments`` inside ``[0, total_duration)``.

    Args:
        segments: list of dicts with integer ``'start'`` and ``'end'`` keys, expected to be sorted by ``'start'``.
        total_duration: end of the covered range (same unit as the segment bounds).

    Returns:
        list of ``{'start': ..., 'end': ...}`` dicts for every uncovered stretch. With no segments, the whole
        range is returned as a single gap. Overlapping or nested segments are treated as one covered region,
        so a gap never starts inside an earlier segment.
    """
    if not segments:
        return [{'start': 0, 'end': total_duration}]
    inverted_segments = []
    covered_until = 0
    for segment in segments:
        start = segment['start']
        if covered_until < start:
            inverted_segments.append({'start': covered_until, 'end': start})
        # max() keeps a short segment nested inside a longer one from moving the covered end backwards
        covered_until = max(covered_until, segment['end'])
    if covered_until < total_duration:
        inverted_segments.append({'start': covered_until, 'end': total_duration})
    return inverted_segments