MassivelyMultilingualTTS / Utility /silence_removal.py
Flux9665's picture
use explicit code instead of relying on release download
9e275b8
raw
history blame
4.32 kB
import librosa
import soundfile as sf
from tqdm import tqdm
from Preprocessing.TextFrontend import get_feature_to_index_lookup
from Utility.path_to_transcript_dicts import *
def make_silence_cleaned_versions(train_sets):
torch.hub._validate_not_a_forked_repo = lambda a, b, c: True # torch 1.9 has a bug in the hub loading, this is a workaround
# careful: assumes 16kHz or 8kHz audio
silero_model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
model='silero_vad',
force_reload=False,
onnx=False,
verbose=False)
(get_speech_timestamps,
save_audio,
read_audio,
VADIterator,
collect_chunks) = utils
torch.set_grad_enabled(True) # finding this issue was very infuriating: silero sets
# this to false globally during model loading rather than using inference mode or no_grad
device = "cuda" if torch.cuda.is_available() else "cpu"
silero_model = silero_model.to(device)
for train_set in train_sets:
for index in tqdm(range(len(train_set))):
filepath = train_set.datapoints[index][8]
phonemes = train_set.datapoints[index][0]
speech_length = train_set.datapoints[index][3]
durations = train_set.datapoints[index][4]
cumsum = 0
legal_silences = list()
for phoneme_index, phone in enumerate(phonemes):
if phone[get_feature_to_index_lookup()["silence"]] == 1 or phone[get_feature_to_index_lookup()["end of sentence"]] == 1 or phone[get_feature_to_index_lookup()["questionmark"]] == 1 or phone[get_feature_to_index_lookup()["exclamationmark"]] == 1 or phone[get_feature_to_index_lookup()["fullstop"]] == 1:
legal_silences.append([cumsum, cumsum + durations[phoneme_index]])
cumsum = cumsum + durations[phoneme_index]
wave, sr = sf.read(filepath)
resampled_wave = librosa.resample(wave, orig_sr=sr, target_sr=16000)
with torch.inference_mode():
speech_timestamps = get_speech_timestamps(torch.Tensor(resampled_wave).to(device), silero_model, sampling_rate=16000)
silences = list()
prev_end = 0
for speech_segment in speech_timestamps:
if prev_end != 0:
silences.append([prev_end, speech_segment["start"]])
prev_end = speech_segment["end"]
# at this point we know all the silences and we know the legal silences.
# We have to transform them both into ratios, so we can compare them.
# If a silence overlaps with a legal silence, it can stay.
illegal_silences = list()
for silence in silences:
illegal = True
start = silence[0] / len(resampled_wave)
end = silence[1] / len(resampled_wave)
for legal_silence in legal_silences:
legal_start = legal_silence[0] / speech_length
legal_end = legal_silence[1] / speech_length
if legal_start < start < legal_end or legal_start < end < legal_end:
illegal = False
break
if illegal:
# If it is an illegal silence, it is marked for removal in the original wave according to ration with real samplingrate.
illegal_silences.append([start, end])
# print(f"{len(illegal_silences)} illegal silences detected. ({len(silences) - len(illegal_silences)} legal silences left)")
wave = list(wave)
orig_wave_length = len(wave)
for illegal_silence in reversed(illegal_silences):
wave = wave[:int(illegal_silence[0] * orig_wave_length)] + wave[int(illegal_silence[1] * orig_wave_length):]
# Audio with illegal silences removed will be saved into a new directory.
new_filepath_list = filepath.split("/")
new_filepath_list[-2] = new_filepath_list[-2] + "_silence_removed"
os.makedirs("/".join(new_filepath_list[:-1]), exist_ok=True)
sf.write("/".join(new_filepath_list), wave, sr)