Spaces:
Running
on
Zero
Running
on
Zero
import os | |
import random | |
import librosa | |
import soundfile as sf | |
import torch | |
from speechbrain.pretrained import EncoderClassifier | |
from torch.multiprocessing import Manager | |
from torch.multiprocessing import Process | |
from torch.utils.data import Dataset | |
from torchaudio.transforms import Resample | |
from tqdm import tqdm | |
from Preprocessing.EnCodecAudioPreprocessor import CodecAudioPreprocessor | |
from Preprocessing.TextFrontend import ArticulatoryCombinedTextFrontend | |
from Utility.storage_config import MODELS_DIR | |
class CodecAlignerDataset(Dataset):
    """
    Dataset for aligner training that serves token ID sequences, codec indexes and speaker embeddings.

    The features are read from the cache file ``aligner_train_cache.pt`` inside ``cache_dir``. If that file
    does not exist yet (or a rebuild is requested), it is created first from the given audio/transcript pairs.
    """

    def __init__(self,
                 path_to_transcript_dict,
                 cache_dir,
                 lang,
                 loading_processes,
                 device,
                 min_len_in_seconds=1,
                 max_len_in_seconds=15,
                 rebuild_cache=False,
                 verbose=False,
                 phone_input=False,
                 allow_unknown_symbols=False,
                 gpu_count=1,
                 rank=0):
        """
        Build the cache if necessary and load it to CPU memory.

        Args:
            path_to_transcript_dict: dict that maps audio file paths to transcripts, or a callable without
                arguments that returns such a dict (it is only called if the cache has to be built).
            cache_dir: directory that holds (or will hold) ``aligner_train_cache.pt``.
            lang: language identifier that is handed to the text frontend.
            loading_processes: number of worker processes used while building the cache.
            device: device used for the text frontend and for feature extraction.
            min_len_in_seconds: audios shorter than this are excluded while building the cache.
            max_len_in_seconds: audios longer than this are excluded while building the cache.
            rebuild_cache: build the cache even if the cache file already exists.
            verbose: print a note for every audio that is excluded because of its duration.
            phone_input: handed to the text frontend as ``input_phonemes``.
            allow_unknown_symbols: if False, samples whose transcript raises a KeyError in the
                text frontend are skipped while building the cache.
            gpu_count: number of GPUs used; if greater than 1, only one chunk of the data is kept in memory.
            rank: index of the chunk to keep when ``gpu_count`` is greater than 1.
        """
        self.gpu_count = gpu_count
        self.rank = rank
        cache_file = os.path.join(cache_dir, "aligner_train_cache.pt")
        if rebuild_cache or not os.path.exists(cache_file):
            self._build_dataset_cache(path_to_transcript_dict=path_to_transcript_dict,
                                      cache_dir=cache_dir,
                                      lang=lang,
                                      loading_processes=loading_processes,
                                      device=device,
                                      min_len_in_seconds=min_len_in_seconds,
                                      max_len_in_seconds=max_len_in_seconds,
                                      verbose=verbose,
                                      phone_input=phone_input,
                                      allow_unknown_symbols=allow_unknown_symbols,
                                      gpu_count=gpu_count,
                                      rank=rank)
        self.lang = lang
        self.device = device
        self.cache_dir = cache_dir
        self.tf = ArticulatoryCombinedTextFrontend(language=self.lang, device=device)
        cache = torch.load(cache_file, map_location='cpu')
        self.speaker_embeddings = cache[2]
        self.filepaths = cache[3]
        self.datapoints = cache[0]
        if self.gpu_count > 1:
            # we only keep a chunk of the dataset in memory to avoid redundancy. Which chunk, we figure out using the rank.
            # The floor division drops the few trailing datapoints that cannot be distributed evenly.
            chunksize = len(self.datapoints) // self.gpu_count
            chunk = slice(chunksize * self.rank, chunksize * (self.rank + 1))
            self.datapoints = self.datapoints[chunk]
            self.speaker_embeddings = self.speaker_embeddings[chunk]
            # the filepaths have to be chunked as well, otherwise index i no longer refers to the same sample in all three lists.
            self.filepaths = self.filepaths[chunk]
        print(f"Loaded an Aligner dataset with {len(self.datapoints)} datapoints from {cache_dir}.")

    def _build_dataset_cache(self,
                             path_to_transcript_dict,
                             cache_dir,
                             lang,
                             loading_processes,
                             device,
                             min_len_in_seconds=1,
                             max_len_in_seconds=15,
                             verbose=False,
                             phone_input=False,
                             allow_unknown_symbols=False,
                             gpu_count=1,
                             rank=0
                             ):
        """
        Extract the features for all audio/transcript pairs and save them to ``aligner_train_cache.pt``.

        The file list is shuffled, split across ``loading_processes`` worker processes that run
        ``_cache_builder_process`` and the results are pooled. Afterwards a speaker embedding is computed
        for every processed wave. The saved tuple is (datapoints, None, speaker_embeddings, filepaths).
        The list of all files that were considered is written to ``files_used.txt`` in ``cache_dir``.

        Exits the program if ``gpu_count`` is not 1.

        Raises:
            RuntimeError: if not a single datapoint could be produced.
        """
        if gpu_count != 1:
            import sys
            print("Please run the feature extraction using only a single GPU. Multi-GPU is only supported for training.")
            sys.exit()
        os.makedirs(cache_dir, exist_ok=True)
        if not isinstance(path_to_transcript_dict, dict):
            path_to_transcript_dict = path_to_transcript_dict()  # in this case we passed a function instead of the dict, so that the function isn't executed if not necessary.
        torch.multiprocessing.set_start_method('spawn', force=True)
        torch.multiprocessing.set_sharing_strategy('file_system')
        resource_manager = Manager()
        self.path_to_transcript_dict = resource_manager.dict(path_to_transcript_dict)
        key_list = list(self.path_to_transcript_dict.keys())
        with open(os.path.join(cache_dir, "files_used.txt"), encoding='utf8', mode="w") as files_used_note:
            files_used_note.write(str(key_list))
        fisher_yates_shuffle(key_list)
        # build cache
        print("... building dataset cache ...")
        torch.hub._validate_not_a_forked_repo = lambda a, b, c: True  # torch 1.9 has a bug in the hub loading, this is a workaround
        # careful: assumes 16kHz or 8kHz audio
        _, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad',  # make sure it gets downloaded during single-processing first, if it's not already downloaded
                              model='silero_vad',
                              force_reload=False,
                              onnx=False,
                              verbose=False)
        self.result_pool = resource_manager.list()
        # make processes
        num_keys = len(key_list)
        key_splits = [key_list[i * num_keys // loading_processes:(i + 1) * num_keys // loading_processes]
                      for i in range(loading_processes)]
        process_list = list()
        for key_split in key_splits:
            if not key_split:
                continue  # more processes than files were requested, a worker without files would have nothing to do
            process = Process(target=self._cache_builder_process,
                              args=(key_split,
                                    lang,
                                    min_len_in_seconds,
                                    max_len_in_seconds,
                                    verbose,
                                    device,
                                    phone_input,
                                    allow_unknown_symbols),
                              daemon=True)
            process.start()
            process_list.append(process)
        for process in process_list:
            process.join()
        print("pooling results...")
        pooled_datapoints = [datapoint for chunk in self.result_pool for datapoint in chunk]  # unpack into a joint list
        del self.result_pool
        if not pooled_datapoints:
            # fail before the speaker embedding model is loaded, there is nothing to embed anyway
            raise RuntimeError("Something went wrong during feature extraction: no datapoints were produced.")
        print("converting text to tensors...")
        text_tensors = [torch.ShortTensor(x[0]) for x in pooled_datapoints]  # turn everything back to tensors (had to turn it to np arrays to avoid multiprocessing issues)
        print("converting speech to tensors...")
        speech_tensors = [torch.ShortTensor(x[1]) for x in pooled_datapoints]
        print("converting waves to tensors...")
        norm_waves = [torch.Tensor(x[2]) for x in pooled_datapoints]
        print("unpacking file list...")
        filepaths = [x[3] for x in pooled_datapoints]
        del pooled_datapoints
        self.datapoints = list(zip(text_tensors, speech_tensors))
        del text_tensors
        del speech_tensors
        print("done!")
        # add speaker embeddings
        self.speaker_embeddings = list()
        speaker_embedding_func_ecapa = EncoderClassifier.from_hparams(source="speechbrain/spkrec-ecapa-voxceleb",
                                                                      run_opts={"device": str(device)},
                                                                      savedir=os.path.join(MODELS_DIR, "Embedding", "speechbrain_speaker_embedding_ecapa"))
        with torch.inference_mode():
            for wave in tqdm(norm_waves):
                self.speaker_embeddings.append(speaker_embedding_func_ecapa.encode_batch(wavs=wave.to(device).unsqueeze(0)).squeeze().cpu())
        # save to cache
        torch.save((self.datapoints, None, self.speaker_embeddings, filepaths),
                   os.path.join(cache_dir, "aligner_train_cache.pt"))

    def _cache_builder_process(self,
                               path_list,
                               lang,
                               min_len,
                               max_len,
                               verbose,
                               device,
                               phone_input,
                               allow_unknown_symbols):
        """
        Worker that turns the audio files in ``path_list`` into datapoints.

        For every file the audio is read, mixed down to mono, resampled to 16kHz, filtered by duration,
        stretches without detected speech are zeroed, the wave is cut to the span from the first to the last
        speech segment and padded with a short silence on both sides. The transcript is converted by the
        text frontend and the wave is converted to codec indexes. Each datapoint is a list of
        [text as numpy array, codec indexes as numpy array, wave as numpy array, path]. The list of all
        datapoints of this worker is appended to ``self.result_pool`` as one chunk at the end.
        """
        process_internal_dataset_chunk = list()
        torch.hub._validate_not_a_forked_repo = lambda a, b, c: True  # torch 1.9 has a bug in the hub loading, this is a workaround
        # careful: assumes 16kHz or 8kHz audio
        silero_model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                                             model='silero_vad',
                                             force_reload=False,
                                             onnx=False,
                                             verbose=False)
        get_speech_timestamps = utils[0]  # the remaining utilities are not needed here
        torch.set_grad_enabled(True)  # finding this issue was very infuriating: silero sets
        # this to false globally during model loading rather than using inference mode or no_grad
        silero_model = silero_model.to(device)
        silence = torch.zeros([16000 // 8]).to(device)  # an eighth of a second at 16kHz
        tf = ArticulatoryCombinedTextFrontend(language=lang, device=device)
        # the sampling rate dependent modules are created as soon as the first audio has been read successfully
        assumed_sr = None
        ap = None
        resample = None
        for path in tqdm(path_list):
            transcript = self.path_to_transcript_dict[path]  # fetched only once, every access goes through the manager
            if transcript.strip() == "":
                continue
            try:
                wave, sr = sf.read(path)
            except Exception:
                print(f"Problem with an audio file: {path}")
                continue
            if len(wave.shape) > 1:  # oh no, we found a stereo audio!
                if len(wave[0]) == 2:  # let's figure out whether we need to switch the axes
                    wave = wave.transpose()  # if yes, we switch the axes.
                wave = librosa.to_mono(wave)
            if sr != assumed_sr:
                if assumed_sr is not None:
                    print(f"{path} has a different sampling rate --> adapting the codec processor")
                assumed_sr = sr
                ap = CodecAudioPreprocessor(input_sr=assumed_sr, device=device)
                resample = Resample(orig_freq=assumed_sr, new_freq=16000).to(device)
            try:
                norm_wave = resample(torch.tensor(wave).float().to(device))
            except ValueError:
                continue
            dur_in_seconds = len(norm_wave) / 16000
            if not (min_len <= dur_in_seconds <= max_len):
                if verbose:
                    print(f"Excluding {path} because of its duration of {round(dur_in_seconds, 2)} seconds.")
                continue
            with torch.inference_mode():
                speech_timestamps = get_speech_timestamps(norm_wave, silero_model, sampling_rate=16000)
            if not speech_timestamps:
                # without any detected speech there is nothing to cut to
                print("Audio might be too short to cut silences from front and back.")
                continue
            for silence_timestamp in invert_segments(speech_timestamps, len(norm_wave)):
                # everything that is not detected as speech is replaced with digital silence
                norm_wave[silence_timestamp['start']:silence_timestamp['end']] = 0.0
            result = norm_wave[speech_timestamps[0]['start']:speech_timestamps[-1]['end']]
            norm_wave = torch.cat([silence, result, silence])
            # raw audio preprocessing is done
            try:
                try:
                    cached_text = tf.string_to_tensor(transcript, handle_missing=False, input_phonemes=phone_input).squeeze(0).cpu().numpy()
                except KeyError:
                    cached_text = tf.string_to_tensor(transcript, handle_missing=True, input_phonemes=phone_input).squeeze(0).cpu().numpy()
                    if not allow_unknown_symbols:
                        continue  # we skip sentences with unknown symbols
            except (ValueError, KeyError):
                # this can happen for Mandarin Chinese, when the syllabification of pinyin doesn't work. In that case, we just skip the sample.
                continue
            cached_speech = ap.audio_to_codebook_indexes(audio=norm_wave, current_sampling_rate=16000).transpose(0, 1).cpu().numpy()
            process_internal_dataset_chunk.append([cached_text,
                                                   cached_speech,
                                                   norm_wave.cpu().detach().numpy(),
                                                   path])
        self.result_pool.append(process_internal_dataset_chunk)

    def __getitem__(self, index):
        """
        Return the tuple (tokens, token_len, codes, None, speaker_embedding) for the datapoint at ``index``.

        ``tokens`` is a LongTensor with the ID sequence derived from the cached text vectors and
        ``token_len`` is a LongTensor holding its length.
        """
        text_vector = self.datapoints[index][0]
        tokens = self.tf.text_vectors_to_id_sequence(text_vector=text_vector)
        tokens = torch.LongTensor(tokens)
        token_len = torch.LongTensor([len(tokens)])
        codes = self.datapoints[index][1]
        if codes.size()[0] != 24:  # no clue why this is sometimes the case
            # NOTE(review): 24 is presumably the number of codebooks, confirm against the codec configuration
            codes = codes.transpose(0, 1)
        return tokens, \
               token_len, \
               codes, \
               None, \
               self.speaker_embeddings[index]

    def __len__(self):
        """Return the number of datapoints that are held in memory."""
        return len(self.datapoints)

    def remove_samples(self, list_of_samples_to_remove):
        """
        Remove the datapoints with the given indexes from memory and update the cache file.

        Every index is removed only once, even if it is listed multiple times. If the dataset only holds
        one chunk of the data (``gpu_count`` greater than 1), the cache file is left untouched, because
        saving the chunk would overwrite the complete cache with a part of it.
        """
        for remove_id in sorted(set(list_of_samples_to_remove), reverse=True):
            # going from the back to the front keeps the remaining indexes valid
            self.datapoints.pop(remove_id)
            self.speaker_embeddings.pop(remove_id)
            self.filepaths.pop(remove_id)
        if self.gpu_count > 1:
            print("Samples removed from memory only, the cache is not updated when using multiple GPUs.")
            return
        torch.save((self.datapoints, None, self.speaker_embeddings, self.filepaths),
                   os.path.join(self.cache_dir, "aligner_train_cache.pt"))
        print("Dataset updated!")
def fisher_yates_shuffle(lst):
    """
    Shuffle the given list in place.

    Walks from the last position down to the second one and swaps each position with a randomly
    drawn position that is not behind it (it may be the position itself). Nothing is returned.
    """
    position = len(lst) - 1
    while position > 0:
        partner = random.randint(0, position)
        lst[position], lst[partner] = lst[partner], lst[position]
        position -= 1
def invert_segments(segments, total_duration):
    """
    Return the gaps between the given segments within the range from 0 to total_duration.

    Every segment is a dict with a 'start' and an 'end' entry, and so is every returned gap.
    A gap is only reported if it starts before it ends. If no segments are given, the
    whole range is returned as a single gap.
    The result is only the true complement if the segments are sorted and do not overlap.
    """
    if not segments:
        return [{'start': 0, 'end': total_duration}]
    # a gap begins where the previous segment ended (or at 0) and ends where the next segment starts (or at the total duration)
    gap_begins = [0] + [segment['end'] for segment in segments]
    gap_ends = [segment['start'] for segment in segments] + [total_duration]
    return [{'start': begin, 'end': end} for begin, end in zip(gap_begins, gap_ends) if begin < end]