# Spaces: Running on T4
import itertools
import os
import warnings
from typing import cast

import matplotlib.pyplot as plt
import pyloudnorm
import sounddevice
import soundfile
import torch
import spaces

# NOTE(review): the indentation of this block was lost in the source; the two third-party
# imports below are assumed to be the ones wrapped by the warnings filter -- confirm.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    from speechbrain.pretrained import EncoderClassifier
    from torchaudio.transforms import Resample

from Architectures.ToucanTTS.InferenceToucanTTS import ToucanTTS
from Architectures.Vocoder.HiFiGAN_Generator import HiFiGAN
from Preprocessing.AudioPreprocessor import AudioPreprocessor
from Preprocessing.TextFrontend import ArticulatoryCombinedTextFrontend
from Preprocessing.TextFrontend import get_language_id
from Utility.storage_config import MODELS_DIR
from Utility.utils import cumsum_durations
from Utility.utils import float2pcm
class ToucanTTSInterface(torch.nn.Module):
    """
    Inference wrapper that chains the text frontend, the ToucanTTS phone-to-mel model and the
    HiFiGAN vocoder. Calling an instance with a text returns a loudness-normalized waveform
    (numpy array) together with its sampling rate of 24000 Hz.
    """

    # Accent shorthands that get replaced before the language ID is looked up.
    _ACCENT_FALLBACKS = {
        'vi-so'  : 'vie',
        'vi-ctr' : 'vie',
        'spa-lat': 'spa',
        'pt-br'  : 'por',
        'fr-sw'  : 'fra',
        'fr-be'  : 'fra',
        'en-sc'  : 'eng',
        'en-us'  : 'eng',
        # no clue where these others are even coming from, they are not in ISO 639-2
        'ajp'    : 'eng',
        'ajt'    : 'eng',
        'lak'    : 'eng',
        'lno'    : 'eng',
        'nul'    : 'eng',
        'pii'    : 'eng',
        'plj'    : 'eng',
        'slq'    : 'eng',
        'smd'    : 'eng',
        'snb'    : 'eng',
        'tpw'    : 'eng',
        'wya'    : 'eng',
        'zua'    : 'eng',
    }

    def __init__(self,
                 device="cpu",  # device that everything computes on. If a cuda device is available, this can speed things up by an order of magnitude.
                 tts_model_path=os.path.join(MODELS_DIR, "ToucanTTS_Meta", "best.pt"),  # path to the ToucanTTS checkpoint or just a shorthand if run standalone
                 vocoder_model_path=os.path.join(MODELS_DIR, "Vocoder", "best.pt"),  # path to the Vocoder checkpoint
                 language="eng",  # initial language of the model, can be changed later with the setter methods
                 enhance=None  # legacy argument
                 ):
        super().__init__()
        self.device = device
        if not tts_model_path.endswith(".pt"):
            # default to shorthand system
            tts_model_path = os.path.join(MODELS_DIR, f"ToucanTTS_{tts_model_path}", "best.pt")

        ################################
        #   build text to phone        #
        ################################
        self.text2phone = ArticulatoryCombinedTextFrontend(language=language, add_silence_to_end=True)

        #####################################
        #   load phone to features model    #
        #####################################
        # NOTE: torch.load unpickles the file, so only checkpoints from trusted sources should be loaded.
        checkpoint = torch.load(tts_model_path, map_location='cpu')
        self.phone2mel = ToucanTTS(weights=checkpoint["model"], config=checkpoint["config"])
        with torch.no_grad():
            self.phone2mel.store_inverse_all()  # this also removes weight norm
        self.phone2mel = self.phone2mel.to(torch.device(device))

        ######################################
        #  load features to style models     #
        ######################################
        self.speaker_embedding_func_ecapa = EncoderClassifier.from_hparams(source="speechbrain/spkrec-ecapa-voxceleb",
                                                                           run_opts={"device": str(device)},
                                                                           savedir=os.path.join(MODELS_DIR, "Embedding", "speechbrain_speaker_embedding_ecapa"))

        ################################
        #  load mel to wave model      #
        ################################
        # NOTE: same trust requirement as for the TTS checkpoint above.
        vocoder_checkpoint = torch.load(vocoder_model_path, map_location="cpu")
        self.vocoder = HiFiGAN()
        self.vocoder.load_state_dict(vocoder_checkpoint)
        self.vocoder = self.vocoder.to(device).eval()
        self.vocoder.remove_weight_norm()
        self.meter = pyloudnorm.Meter(24000)

        ################################
        #  set defaults                #
        ################################
        self.default_utterance_embedding = checkpoint["default_emb"].to(self.device)
        self.ap = AudioPreprocessor(input_sr=100, output_sr=16000, device=device)
        self.phone2mel.eval()
        self.vocoder.eval()
        self.lang_id = get_language_id(language)
        self.to(torch.device(device))
        self.eval()

    def set_utterance_embedding(self, path_to_reference_audio="", embedding=None):
        """
        Set the utterance embedding that is used for all following syntheses.

        Args:
            path_to_reference_audio: A path or a list/tuple of paths to reference audio files. The
                embeddings extracted from all files are averaged. Ignored if embedding is given.
            embedding: A precomputed embedding tensor. If given, it is squeezed, moved to the
                current device and used directly.

        Raises:
            AssertionError: if one of the reference paths does not exist.
        """
        if embedding is not None:
            self.default_utterance_embedding = embedding.squeeze().to(self.device)
            return
        if not isinstance(path_to_reference_audio, (list, tuple)):
            path_to_reference_audio = [path_to_reference_audio]
        if len(path_to_reference_audio) == 0:
            # nothing to average, keep the current embedding
            return
        # validate all paths up front so that no work is done if one of them is missing
        for path in path_to_reference_audio:
            assert os.path.exists(path), f"Reference audio does not exist: {path}"
        speaker_embs = list()
        for path in path_to_reference_audio:
            wave, sr = soundfile.read(path)
            # the speaker embedding function is fed 16kHz audio
            resample = Resample(orig_freq=sr, new_freq=16000).to(self.device)
            wave = resample(torch.tensor(wave, device=self.device, dtype=torch.float32))
            speaker_embedding = self.speaker_embedding_func_ecapa.encode_batch(wavs=wave.to(self.device).unsqueeze(0)).squeeze()
            speaker_embs.append(speaker_embedding)
        self.default_utterance_embedding = sum(speaker_embs) / len(speaker_embs)

    def set_language(self, lang_id):
        """
        The id parameter actually refers to the shorthand. This has become ambiguous with the introduction of the actual language IDs
        """
        self.set_phonemizer_language(lang_id=lang_id)
        self.set_accent_language(lang_id=lang_id)

    def set_phonemizer_language(self, lang_id):
        """
        Switch the language of the text frontend to the given shorthand.
        """
        self.text2phone.change_lang(language=lang_id, add_silence_to_end=True)

    def set_accent_language(self, lang_id):
        """
        Set the language ID that is passed to the phone-to-mel model. Shorthands listed in
        _ACCENT_FALLBACKS are replaced first, every other shorthand is looked up unchanged.
        """
        lang_id = self._ACCENT_FALLBACKS.get(lang_id, lang_id)
        self.lang_id = get_language_id(lang_id).to(self.device)

    @staticmethod
    def _to_device(value, device):
        """
        Move value to device if it is a tensor, pass anything else (e.g. None) through unchanged.
        """
        return value.to(device) if torch.is_tensor(value) else value

    def _plot_spectrogram(self, text, mel, durations, input_is_phones):
        """
        Build the spectrogram figure with phone labels and word boundaries and return it.
        """
        fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(9, 5))
        ax.imshow(mel.cpu().numpy(), origin="lower", cmap='GnBu')
        ax.yaxis.set_visible(False)
        duration_splits, label_positions = cumsum_durations(durations.cpu().numpy())
        ax.xaxis.grid(True, which='minor')
        ax.set_xticks(label_positions, minor=False)
        if input_is_phones:
            phones = text.replace(" ", "|")
        else:
            phones = self.text2phone.get_phone_string(text, for_plot_labels=True)
        ax.set_xticklabels(phones)
        # "|" marks a word boundary in the phone string
        word_boundaries = [label_positions[label_index] for label_index, phone in enumerate(phones) if phone == "|"]
        try:
            prev_word_boundary = 0
            word_label_positions = list()
            for word_boundary in word_boundaries:
                word_label_positions.append((word_boundary + prev_word_boundary) / 2)
                prev_word_boundary = word_boundary
            word_label_positions.append((duration_splits[-1] + prev_word_boundary) / 2)
            secondary_ax = ax.secondary_xaxis('bottom')
            secondary_ax.tick_params(axis="x", direction="out", pad=24)
            secondary_ax.set_xticks(word_label_positions, minor=False)
            secondary_ax.set_xticklabels(text.split())
            secondary_ax.tick_params(axis='x', colors='orange')
            secondary_ax.xaxis.label.set_color('orange')
        except (ValueError, IndexError):
            # fall back to showing the text as title if the word labels cannot be placed
            ax.set_title(text)
        ax.vlines(x=duration_splits, colors="green", linestyles="solid", ymin=0, ymax=120, linewidth=0.5)
        ax.vlines(x=word_boundaries, colors="orange", linestyles="solid", ymin=0, ymax=120, linewidth=1.0)
        fig.subplots_adjust(left=0.02, bottom=0.2, right=0.98, top=.9, wspace=0.0, hspace=0.0)
        ax.set_aspect("auto")
        return fig

    def forward(self,
                text,
                view=False,
                duration_scaling_factor=1.0,
                pitch_variance_scale=1.0,
                energy_variance_scale=1.0,
                pause_duration_scaling_factor=1.0,
                durations=None,
                pitch=None,
                energy=None,
                input_is_phones=False,
                return_plot_as_filepath=False,
                loudness_in_db=-24.0,
                glow_sampling_temperature=0.2):
        """
        Synthesize the given text. Returns (wave, sr), or (wave, sr, "tmp.png") if
        return_plot_as_filepath is set. The compute device is re-selected on every call (cuda if
        available, otherwise cpu) and the model is always moved back to the cpu before returning.

        text: the text to synthesize, or a phone string if input_is_phones is set.
        view: whether to build the spectrogram figure. The figure is left open so that the
              caller can show it.
        durations, pitch, energy: optional values that are passed on to the phone-to-mel model.
        return_plot_as_filepath: whether to save the spectrogram figure to "tmp.png" and return
              that path as third element.
        loudness_in_db: the loudness the waveform is normalized to.
        glow_sampling_temperature: passed on to the phone-to-mel model.
        duration_scaling_factor: reasonable values are 0.8 < scale < 1.2.
                                     1.0 means no scaling happens, higher values increase durations for the whole
                                     utterance, lower values decrease durations for the whole utterance.
        pitch_variance_scale: reasonable values are 0.6 < scale < 1.4.
                                  1.0 means no scaling happens, higher values increase variance of the pitch curve,
                                  lower values decrease variance of the pitch curve.
        energy_variance_scale: reasonable values are 0.6 < scale < 1.4.
                                   1.0 means no scaling happens, higher values increase variance of the energy curve,
                                   lower values decrease variance of the energy curve.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.to(device)
        try:
            with torch.inference_mode():
                phones = self.text2phone.string_to_tensor(text, input_phonemes=input_is_phones).to(torch.device(self.device))
                mel, durations, pitch, energy = self.phone2mel(phones,
                                                               return_duration_pitch_energy=True,
                                                               utterance_embedding=self.default_utterance_embedding.to(device),
                                                               # user supplied prosody may live on a different device than the one just selected
                                                               durations=self._to_device(durations, device),
                                                               pitch=self._to_device(pitch, device),
                                                               energy=self._to_device(energy, device),
                                                               lang_id=self.lang_id.to(device),
                                                               duration_scaling_factor=duration_scaling_factor,
                                                               pitch_variance_scale=pitch_variance_scale,
                                                               energy_variance_scale=energy_variance_scale,
                                                               pause_duration_scaling_factor=pause_duration_scaling_factor,
                                                               glow_sampling_temperature=glow_sampling_temperature)
                wave, _, _ = self.vocoder(mel.unsqueeze(0))
                wave = wave.squeeze().cpu()
                wave = wave.numpy()
            sr = 24000
            try:
                loudness = self.meter.integrated_loudness(wave)
                wave = pyloudnorm.normalize.loudness(wave, loudness, loudness_in_db)
            except ValueError:
                # if the audio is too short, a value error will arise
                pass

            if view or return_plot_as_filepath:
                fig = self._plot_spectrogram(text, mel, durations, input_is_phones)
                if return_plot_as_filepath:
                    fig.savefig("tmp.png")
                    if not view:
                        # nobody is going to show this figure, so release it instead of piling up open figures
                        plt.close(fig)
                    return wave, sr, "tmp.png"
            return wave, sr
        finally:
            # runs on every exit path, including the early return with the plot and errors
            self.to("cpu")
            self.device = "cpu"

    def read_to_file(self,
                     text_list,
                     file_location,
                     duration_scaling_factor=1.0,
                     pitch_variance_scale=1.0,
                     energy_variance_scale=1.0,
                     pause_duration_scaling_factor=1.0,
                     silent=False,
                     dur_list=None,
                     pitch_list=None,
                     energy_list=None,
                     glow_sampling_temperature=0.2):
        """
        Synthesize all texts, join them with silence in between and write the result as 16 bit PCM.
        Empty texts are skipped. If nothing is synthesized, the file contains only silence.

        Args:
            silent: Whether to be verbose about the process
            text_list: A list of strings to be read
            file_location: The path and name of the file it should be saved to
            energy_list: list of energy tensors to be used for the texts
            pitch_list: list of pitch tensors to be used for the texts
            dur_list: list of duration tensors to be used for the texts
            pause_duration_scaling_factor: passed on to the synthesis call
            glow_sampling_temperature: passed on to the synthesis call
            duration_scaling_factor: reasonable values are 0.8 < scale < 1.2.
                                     1.0 means no scaling happens, higher values increase durations for the whole
                                     utterance, lower values decrease durations for the whole utterance.
            pitch_variance_scale: reasonable values are 0.6 < scale < 1.4.
                                  1.0 means no scaling happens, higher values increase variance of the pitch curve,
                                  lower values decrease variance of the pitch curve.
            energy_variance_scale: reasonable values are 0.6 < scale < 1.4.
                                   1.0 means no scaling happens, higher values increase variance of the energy curve,
                                   lower values decrease variance of the energy curve.
        """
        dur_list = dur_list or []
        pitch_list = pitch_list or []
        energy_list = energy_list or []
        silence = torch.zeros([14300])
        segments = [silence]
        sr = 24000  # the rate forward() returns; needed as default if no text gets synthesized
        for (text, durations, pitch, energy) in itertools.zip_longest(text_list, dur_list, pitch_list, energy_list):
            # zip_longest pads with None if one of the prosody lists is longer than the text list
            if text is None or text.strip() == "":
                continue
            if not silent:
                print("Now synthesizing: {}".format(text))
            spoken_sentence, sr = self(text,
                                       durations=durations,
                                       pitch=pitch,
                                       energy=energy,
                                       duration_scaling_factor=duration_scaling_factor,
                                       pitch_variance_scale=pitch_variance_scale,
                                       energy_variance_scale=energy_variance_scale,
                                       pause_duration_scaling_factor=pause_duration_scaling_factor,
                                       glow_sampling_temperature=glow_sampling_temperature)
            segments.append(torch.tensor(spoken_sentence).cpu())
            segments.append(silence)
        # concatenate once at the end instead of growing the tensor in every iteration
        wav = torch.cat(segments, 0)
        soundfile.write(file=file_location, data=float2pcm(wav), samplerate=sr, subtype="PCM_16")

    def read_aloud(self,
                   text,
                   view=False,
                   duration_scaling_factor=1.0,
                   pitch_variance_scale=1.0,
                   energy_variance_scale=1.0,
                   blocking=False,
                   glow_sampling_temperature=0.2):
        """
        Synthesize the text and play it on the default sound device, padded with half a second of
        silence on both sides. Does nothing for empty text.

        Args:
            text: The text to be read
            view: Whether to show the spectrogram plot after starting the playback
            blocking: Whether to wait until the playback has finished
            duration_scaling_factor, pitch_variance_scale, energy_variance_scale,
            glow_sampling_temperature: passed on to the synthesis call
        """
        if text.strip() == "":
            return
        wav, sr = self(text,
                       view=view,
                       duration_scaling_factor=duration_scaling_factor,
                       pitch_variance_scale=pitch_variance_scale,
                       energy_variance_scale=energy_variance_scale,
                       glow_sampling_temperature=glow_sampling_temperature)
        silence = torch.zeros([sr // 2])
        wav = torch.cat((silence, torch.tensor(wav), silence), 0).numpy()
        sounddevice.play(float2pcm(wav), samplerate=sr)
        if view:
            plt.show()
        if blocking:
            sounddevice.wait()