Spaces:
Running
on
Zero
Running
on
Zero
File size: 8,341 Bytes
6faeba1 6a79837 6faeba1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
"""
This module is meant to find potentially problematic samples
in the data you are using. There are two types: The alignment
scorer and the TTS scorer. The alignment scorer can help you
find mispronunciations or errors in the labels. The TTS scorer
can help you find outliers in the audio part of text-audio pairs.
"""
import math
import statistics
import torch
import torch.multiprocessing
from tqdm import tqdm
from Modules.ToucanTTS.ToucanTTS import ToucanTTS
from Preprocessing.AudioPreprocessor import AudioPreprocessor
from Preprocessing.EnCodecAudioPreprocessor import CodecAudioPreprocessor
from Utility.corpus_preparation import prepare_tts_corpus
class TTSScorer:
    """
    Scores every sample of a ToucanTTS dataset with the regression loss of a
    pretrained ToucanTTS model, so that samples with an unusually high or a
    non-finite loss can be inspected and removed from the dataset cache.

    Removing samples invalidates the stored dataset indexes, so after any
    removal score() has to be run again before removing anything else.
    """

    def __init__(self,
                 path_to_model,
                 device,
                 ):
        """
        Load the model and the audio preprocessors used for scoring.

        Args:
            path_to_model: path to a checkpoint whose "model" entry holds a ToucanTTS state dict.
            device: device on which the model and the preprocessors run.
        """
        self.device = device
        self.path_to_score = dict()  # filepath -> loss of that sample as a Python float
        self.path_to_id = dict()  # filepath -> index into the scored dataset's datapoints
        self.nans = list()  # filepaths whose loss was not finite (NaN or inf)
        self.nan_indexes = list()  # dataset indexes belonging to self.nans
        self.tts = ToucanTTS()
        checkpoint = torch.load(path_to_model, map_location='cpu')
        weights = checkpoint["model"]
        self.tts.load_state_dict(weights)
        self.tts.to(self.device)
        # NOTE(review): the model is never switched to eval(); if ToucanTTS contains dropout,
        # the scores are not deterministic across runs -- confirm this is intended.
        self.nans_removed = False  # True once samples were removed; stored indexes are stale then
        self.current_dset = None
        self.ap = CodecAudioPreprocessor(input_sr=-1, device=device)
        self.spec_extractor = AudioPreprocessor(input_sr=16000, output_sr=16000, device=device)

    def score(self, path_to_toucantts_dataset, lang_id):
        """
        Score every sample of a dataset with the model's regression loss.

        Results of any previous call are discarded: path_to_score, path_to_id,
        nans and nan_indexes are rebuilt for this dataset, and the dataset is
        kept as current_dset for the removal methods. Samples whose loss is NaN
        or infinite are recorded in nans / nan_indexes.

        Args:
            path_to_toucantts_dataset: dataset location, forwarded to prepare_tts_corpus.
            lang_id: language id, forwarded to prepare_tts_corpus.
        """
        dataset = prepare_tts_corpus(dict(), path_to_toucantts_dataset, lang_id)
        self.current_dset = dataset
        self.nans = list()
        self.nan_indexes = list()
        self.path_to_score = dict()
        self.path_to_id = dict()
        _ = dataset[0]  # NOTE(review): presumably forces the dataset to materialize its datapoints -- confirm
        lang_ids = dataset.language_id.to(self.device)  # the same for every sample, so moved only once
        for index in tqdm(range(len(dataset.datapoints))):
            datapoint = dataset.datapoints[index]
            text_tensors = datapoint[0].to(self.device).unsqueeze(0).float()
            text_lengths = datapoint[1].squeeze().to(self.device).unsqueeze(0)
            speech_indexes = datapoint[2]
            speech_lengths = datapoint[3].squeeze().to(self.device).unsqueeze(0)
            gold_durations = datapoint[4].to(self.device).unsqueeze(0)
            gold_pitch = datapoint[6].to(self.device).unsqueeze(0)  # mind the switched order
            gold_energy = datapoint[5].to(self.device).unsqueeze(0)  # mind the switched order
            utterance_embedding = datapoint[7].unsqueeze(0).to(self.device)
            filepath = datapoint[8]
            with torch.inference_mode():
                # turn the stored speech indexes into a waveform and compute the gold spectrogram from it
                wave = self.ap.indexes_to_audio(speech_indexes.int().to(self.device)).detach()
                mel = self.spec_extractor.audio_to_mel_spec_tensor(wave, explicit_sampling_rate=16000).transpose(0, 1).detach().cpu()
            # cloned outside of inference_mode so the model receives a regular tensor
            gold_speech_sample = mel.clone().to(self.device).unsqueeze(0)
            try:
                # only the loss value is needed, so no autograd graph has to be built
                with torch.no_grad():
                    regression_loss, _, _, _, _ = self.tts(text_tensors=text_tensors,
                                                           text_lengths=text_lengths,
                                                           gold_speech=gold_speech_sample,
                                                           speech_lengths=speech_lengths,
                                                           gold_durations=gold_durations,
                                                           gold_pitch=gold_pitch,
                                                           gold_energy=gold_energy,
                                                           utterance_embedding=utterance_embedding,
                                                           lang_ids=lang_ids,
                                                           return_feats=False,
                                                           run_stochastic=False)
                # only the regression loss is used as score; duration, pitch and energy losses are left out
                loss = regression_loss
            except TypeError:
                # a sample the model cannot process is recorded with a NaN score
                loss = torch.tensor(torch.nan)
            loss_value = loss.cpu().item()
            # inf is flagged as well: it would otherwise poison the mean/stdev used for outlier removal
            if not math.isfinite(loss_value):
                self.nans.append(filepath)
                self.nan_indexes.append(index)
            self.path_to_score[filepath] = loss_value
            self.path_to_id[filepath] = index
        if len(self.nans) > 0:
            print("NaNs detected during scoring!")
            for path in self.nans:
                print(path)
            print("\n\n")
        self.nans_removed = False

    def _paths_ranked_by_loss(self):
        """Return the filepaths with a finite score, highest loss first (NaN cannot be ordered)."""
        finite_paths = [path for path, value in self.path_to_score.items() if math.isfinite(value)]
        return sorted(finite_paths, key=self.path_to_score.get, reverse=True)

    def show_samples_with_highest_loss(self, n=-1):
        """
        Print the scored samples from highest to lowest loss.

        Samples with a non-finite loss are always listed first, separately,
        and are not part of the ranking.

        Args:
            n: number of ranked samples to print; -1 prints all of them.
        """
        if len(self.nans) > 0:
            print("The following filepaths had an infinite loss:")
            for path in self.nans:
                print(path)
            print("\n\n")
        ranked_paths = self._paths_ranked_by_loss()
        if n != -1:
            ranked_paths = ranked_paths[:max(n, 0)]
        for path in ranked_paths:
            print(f"Loss: {round(self.path_to_score[path], 3)} - Path: {path}")
        print("\n\n")

    def remove_samples_with_highest_loss(self, n=10):
        """
        Remove the non-finite samples and the n finite samples with the highest loss.

        Args:
            n: number of highest-loss samples to remove in addition to the non-finite ones.
        """
        if self.current_dset is None:
            print("Please run the scoring first.")
            return
        if self.nans_removed:
            print("Indexes are no longer accurate. Please re-run the scoring. \n\n"
                  "This function also removes NaNs, so if you want to remove the NaN samples and the n samples "
                  "with the highest loss, only call this function.")
            return
        remove_ids = list(self.nan_indexes)
        remove_ids.extend(self.path_to_id[path] for path in self._paths_ranked_by_loss()[:max(n, 0)])
        # de-duplicate while keeping order, so no index is handed to remove_samples twice
        self.current_dset.remove_samples(list(dict.fromkeys(remove_ids)))
        self.nans_removed = True

    def remove_samples_with_loss_three_std_devs_higher_than_avg(self):
        """
        Remove the non-finite samples and every sample whose loss exceeds
        mean + 3 * stdev of the finite losses.

        With fewer than two finite scores no statistics can be computed, and
        only the non-finite samples are removed.
        """
        if self.current_dset is None:
            print("Please run the scoring first.")
            return
        if self.nans_removed:
            print("Indexes are no longer accurate. Please re-run the scoring. \n\n"
                  "This function also removes NaNs, so if you want to remove the NaN samples and the outliers, only call this one here.")
            return
        remove_ids = list(self.nan_indexes)
        finite_scores = [value for value in self.path_to_score.values() if math.isfinite(value)]
        if len(finite_scores) < 2:
            print("Not enough finite scores to compute outlier statistics, only removing samples with a non-finite loss.")
        else:
            thresh = statistics.mean(finite_scores) + (3 * statistics.stdev(finite_scores))
            for path, value in self.path_to_score.items():
                if not math.isnan(value) and value > thresh:  # we found an outlier!
                    remove_ids.append(self.path_to_id[path])
        # inf samples appear both in nan_indexes and as outliers, so de-duplicate (keeping order)
        remove_ids = list(dict.fromkeys(remove_ids))
        print(f"removing {len(remove_ids)} outliers!")
        self.current_dset.remove_samples(remove_ids)
        self.nans_removed = True

    def remove_nans(self):
        """Remove only the samples whose loss was not finite from the scored dataset."""
        if self.nans_removed:
            print("NaNs have already been removed!")
            return
        if self.current_dset is None:
            print("Please run the scoring first to find NaNs.")
            return
        if len(self.nans) == 0:
            print("No NaNs detected in this dataset.")
            return
        print("The following filepaths had an infinite loss and are being removed from the dataset cache:")
        for path in self.nans:
            print(path)
        self.current_dset.remove_samples(self.nan_indexes)
        self.nans_removed = True
|