import itertools
import os
import warnings
from typing import cast

import matplotlib.pyplot as plt
import pyloudnorm
import sounddevice
import soundfile
import torch
import spaces

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    from speechbrain.pretrained import EncoderClassifier
    from torchaudio.transforms import Resample

from Architectures.ToucanTTS.InferenceToucanTTS import ToucanTTS
from Architectures.Vocoder.HiFiGAN_Generator import HiFiGAN
from Preprocessing.AudioPreprocessor import AudioPreprocessor
from Preprocessing.TextFrontend import ArticulatoryCombinedTextFrontend
from Preprocessing.TextFrontend import get_language_id
from Utility.storage_config import MODELS_DIR
from Utility.utils import cumsum_durations
from Utility.utils import float2pcm

# Sample rate used for the loudness meter and reported for every synthesized waveform.
_OUTPUT_SAMPLE_RATE = 24000

# Sample rate the reference audio is resampled to before speaker embedding extraction.
_SPEAKER_EMBEDDING_SAMPLE_RATE = 16000

# Shorthands that have no language ID of their own. Accent variants are mapped
# to their base language; the remaining codes are not in ISO 639-2 and fall
# back to English.
_ACCENT_FALLBACKS = {
    **dict.fromkeys(('ajp', 'ajt', 'lak', 'lno', 'nul', 'pii', 'plj', 'slq',
                     'smd', 'snb', 'tpw', 'wya', 'zua'), 'eng'),
    'vi-so': 'vie',
    'vi-ctr': 'vie',
    'spa-lat': 'spa',
    'pt-br': 'por',
    'fr-sw': 'fra',
    'fr-be': 'fra',
    'en-sc': 'eng',
    'en-us': 'eng',
}


class ToucanTTSInterface(torch.nn.Module):
    """
    Inference wrapper that chains text frontend, ToucanTTS spectrogram model
    and HiFiGAN vocoder, and offers helpers to play or save the result.
    """

    def __init__(self,
                 device="cpu",  # device that everything computes on. If a cuda device is available, this can speed things up by an order of magnitude.
                 tts_model_path=os.path.join(MODELS_DIR, "ToucanTTS_Meta", "best.pt"),  # path to the ToucanTTS checkpoint or just a shorthand if run standalone
                 vocoder_model_path=os.path.join(MODELS_DIR, "Vocoder", "best.pt"),  # path to the Vocoder checkpoint
                 language="eng",  # initial language of the model, can be changed later with the setter methods
                 enhance=None  # legacy argument, ignored
                 ):
        """
        Load all models from disk and put them into eval mode on `device`.

        Args:
            device: device that all models are moved to
            tts_model_path: path to a ToucanTTS checkpoint ending in ".pt", or a
                shorthand that is expanded to MODELS_DIR/ToucanTTS_<shorthand>/best.pt
            vocoder_model_path: path to the vocoder checkpoint
            language: language shorthand used for the text frontend and the language ID
            enhance: legacy argument that is not used
        """
        super().__init__()
        self.device = device
        if not tts_model_path.endswith(".pt"):
            # default to shorthand system
            tts_model_path = os.path.join(MODELS_DIR, f"ToucanTTS_{tts_model_path}", "best.pt")

        ################################
        #   build text to phone        #
        ################################
        self.text2phone = ArticulatoryCombinedTextFrontend(language=language, add_silence_to_end=True)

        #####################################
        #   load phone to features model    #
        #####################################
        # NOTE: torch.load unpickles the file, only load checkpoints from trusted sources.
        checkpoint = torch.load(tts_model_path, map_location='cpu')
        self.phone2mel = ToucanTTS(weights=checkpoint["model"], config=checkpoint["config"])
        with torch.no_grad():
            self.phone2mel.store_inverse_all()  # this also removes weight norm
        self.phone2mel = self.phone2mel.to(torch.device(device))

        ######################################
        #  load features to style models     #
        ######################################
        self.speaker_embedding_func_ecapa = EncoderClassifier.from_hparams(source="speechbrain/spkrec-ecapa-voxceleb",
                                                                           run_opts={"device": str(device)},
                                                                           savedir=os.path.join(MODELS_DIR, "Embedding", "speechbrain_speaker_embedding_ecapa"))

        ################################
        #  load mel to wave model      #
        ################################
        # NOTE: torch.load unpickles the file, only load checkpoints from trusted sources.
        vocoder_checkpoint = torch.load(vocoder_model_path, map_location="cpu")
        self.vocoder = HiFiGAN()
        self.vocoder.load_state_dict(vocoder_checkpoint)
        self.vocoder = self.vocoder.to(device).eval()
        self.vocoder.remove_weight_norm()
        self.meter = pyloudnorm.Meter(_OUTPUT_SAMPLE_RATE)

        ################################
        #  set defaults                #
        ################################
        self.default_utterance_embedding = checkpoint["default_emb"].to(self.device)
        self.ap = AudioPreprocessor(input_sr=100, output_sr=16000, device=device)
        self.phone2mel.eval()
        self.vocoder.eval()
        # placed on the device like set_accent_language does, so both code paths agree
        self.lang_id = get_language_id(language).to(self.device)
        self.to(torch.device(device))
        self.eval()

    def set_utterance_embedding(self, path_to_reference_audio="", embedding=None):
        """
        Set the speaker embedding that is used for all following syntheses.

        Args:
            path_to_reference_audio: a path or a list/tuple of paths to audio files.
                One embedding is extracted per file and their mean is used.
                Ignored if `embedding` is given.
            embedding: a precomputed embedding tensor that is used directly
                (after squeezing and moving it to the current device)
        """
        if embedding is not None:
            self.default_utterance_embedding = embedding.squeeze().to(self.device)
            return

        if isinstance(path_to_reference_audio, tuple):
            path_to_reference_audio = list(path_to_reference_audio)
        elif not isinstance(path_to_reference_audio, list):
            path_to_reference_audio = [path_to_reference_audio]
        if not path_to_reference_audio:
            return  # nothing to do, keep the current embedding

        # validate everything up front so that we do not fail halfway through
        for path in path_to_reference_audio:
            assert os.path.exists(path), f"reference audio does not exist: {path}"

        speaker_embs = list()
        for path in path_to_reference_audio:
            wave, sr = soundfile.read(path)
            if wave.ndim > 1:
                # multichannel files are read as (frames, channels), the resampler
                # however works on the last axis, so we mix down to mono first
                wave = wave.mean(axis=1)
            resample = Resample(orig_freq=sr, new_freq=_SPEAKER_EMBEDDING_SAMPLE_RATE).to(self.device)
            wave = resample(torch.tensor(wave, device=self.device, dtype=torch.float32))
            speaker_embedding = self.speaker_embedding_func_ecapa.encode_batch(wavs=wave.to(self.device).unsqueeze(0)).squeeze()
            speaker_embs.append(speaker_embedding)
        self.default_utterance_embedding = sum(speaker_embs) / len(speaker_embs)

    def set_language(self, lang_id):
        """
        The id parameter actually refers to the shorthand. This has become ambiguous with the introduction of the actual language IDs
        """
        self.set_phonemizer_language(lang_id=lang_id)
        self.set_accent_language(lang_id=lang_id)

    def set_phonemizer_language(self, lang_id):
        """
        Rebuild the text frontend for the language given by its shorthand.
        """
        self.text2phone = ArticulatoryCombinedTextFrontend(language=lang_id, add_silence_to_end=True)

    def set_accent_language(self, lang_id):
        """
        Set the language ID that is passed to the spectrogram model.

        Shorthands listed in the fallback table (accent variants and a few
        codes outside of ISO 639-2) are replaced by their base language first.
        """
        lang_id = _ACCENT_FALLBACKS.get(lang_id, lang_id)
        self.lang_id = get_language_id(lang_id).to(self.device)

    def _plot_spectrogram(self, mel, durations, text, input_is_phones):
        """
        Draw the spectrogram with phone and word boundaries and return the figure.
        """
        fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(9, 5))
        ax.imshow(mel.cpu().numpy(), origin="lower", cmap='GnBu')
        ax.yaxis.set_visible(False)
        duration_splits, label_positions = cumsum_durations(durations.cpu().numpy())
        ax.xaxis.grid(True, which='minor')
        ax.set_xticks(label_positions, minor=False)
        if input_is_phones:
            phones = text.replace(" ", "|")
        else:
            phones = self.text2phone.get_phone_string(text, for_plot_labels=True)
        ax.set_xticklabels(phones)

        # a "|" in the phone string marks the position of a word boundary
        word_boundaries = [label_positions[label_index]
                           for label_index, phone in enumerate(phones)
                           if phone == "|"]

        try:
            # every word label is centered between the two boundaries that enclose the word
            prev_word_boundary = 0
            word_label_positions = list()
            for word_boundary in word_boundaries:
                word_label_positions.append((word_boundary + prev_word_boundary) / 2)
                prev_word_boundary = word_boundary
            word_label_positions.append((duration_splits[-1] + prev_word_boundary) / 2)

            secondary_ax = ax.secondary_xaxis('bottom')
            secondary_ax.tick_params(axis="x", direction="out", pad=24)
            secondary_ax.set_xticks(word_label_positions, minor=False)
            secondary_ax.set_xticklabels(text.split())
            secondary_ax.tick_params(axis='x', colors='orange')
            secondary_ax.xaxis.label.set_color('orange')
        except (ValueError, IndexError):
            # if the word labels cannot be placed, we show the text as title instead
            ax.set_title(text)

        ax.vlines(x=duration_splits, colors="green", linestyles="solid", ymin=0, ymax=120, linewidth=0.5)
        ax.vlines(x=word_boundaries, colors="orange", linestyles="solid", ymin=0, ymax=120, linewidth=1.0)
        fig.subplots_adjust(left=0.02, bottom=0.2, right=0.98, top=.9, wspace=0.0, hspace=0.0)
        ax.set_aspect("auto")
        return fig

    @spaces.GPU
    def forward(self,
                text,
                view=False,
                duration_scaling_factor=1.0,
                pitch_variance_scale=1.0,
                energy_variance_scale=1.0,
                pause_duration_scaling_factor=1.0,
                durations=None,
                pitch=None,
                energy=None,
                input_is_phones=False,
                return_plot_as_filepath=False,
                loudness_in_db=-24.0,
                glow_sampling_temperature=0.2):
        """
        Synthesize a single text and return the loudness normalized waveform.

        Args:
            text: the text to be read, or a phone string if input_is_phones is set
            view: whether to prepare a plot of the spectrogram (shown by the caller with plt.show())
            duration_scaling_factor: reasonable values are 0.8 < scale < 1.2.
                                     1.0 means no scaling happens, higher values increase durations for the whole
                                     utterance, lower values decrease durations for the whole utterance.
            pitch_variance_scale: reasonable values are 0.6 < scale < 1.4.
                                  1.0 means no scaling happens, higher values increase variance of the pitch curve,
                                  lower values decrease variance of the pitch curve.
            energy_variance_scale: reasonable values are 0.6 < scale < 1.4.
                                   1.0 means no scaling happens, higher values increase variance of the energy curve,
                                   lower values decrease variance of the energy curve.
            pause_duration_scaling_factor: passed through to the spectrogram model
            durations: optional durations to use instead of predicted ones
            pitch: optional pitch values to use instead of predicted ones
            energy: optional energy values to use instead of predicted ones
            input_is_phones: whether text is already a phone string
            return_plot_as_filepath: whether to save the plot to "tmp.png" and return its path
            loudness_in_db: the loudness the waveform is normalized to
            glow_sampling_temperature: passed through to the spectrogram model

        Returns:
            (wave, sr) with wave being a numpy array, or (wave, sr, "tmp.png")
            if return_plot_as_filepath is set.
        """
        # the device is decided anew on every call, regardless of what was passed to the constructor
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.to(device)

        with torch.inference_mode():
            phones = self.text2phone.string_to_tensor(text, input_phonemes=input_is_phones).to(torch.device(self.device))
            mel, durations, pitch, energy = self.phone2mel(phones,
                                                           return_duration_pitch_energy=True,
                                                           utterance_embedding=self.default_utterance_embedding,
                                                           durations=durations,
                                                           pitch=pitch,
                                                           energy=energy,
                                                           lang_id=self.lang_id,
                                                           duration_scaling_factor=duration_scaling_factor,
                                                           pitch_variance_scale=pitch_variance_scale,
                                                           energy_variance_scale=energy_variance_scale,
                                                           pause_duration_scaling_factor=pause_duration_scaling_factor,
                                                           glow_sampling_temperature=glow_sampling_temperature)
            wave, _, _ = self.vocoder(mel.unsqueeze(0))
            wave = wave.squeeze().cpu()
        wave = wave.numpy()
        sr = _OUTPUT_SAMPLE_RATE

        try:
            loudness = self.meter.integrated_loudness(wave)
            wave = pyloudnorm.normalize.loudness(wave, loudness, loudness_in_db)
        except ValueError:
            # if the audio is too short, a value error will arise, then we keep it unnormalized
            pass

        if view or return_plot_as_filepath:
            fig = self._plot_spectrogram(mel=mel, durations=durations, text=text, input_is_phones=input_is_phones)
            if return_plot_as_filepath:
                fig.savefig("tmp.png")
                if not view:
                    # nobody will call plt.show() on this figure, so we free it to not pile up figures over many calls
                    plt.close(fig)
                # NOTE(review): unlike the regular path below, this return leaves the
                # model on the inference device. Kept as is, confirm whether intended.
                return wave, sr, "tmp.png"

        self.to("cpu")
        self.device = "cpu"
        return wave, sr

    def read_to_file(self,
                     text_list,
                     file_location,
                     duration_scaling_factor=1.0,
                     pitch_variance_scale=1.0,
                     energy_variance_scale=1.0,
                     pause_duration_scaling_factor=1.0,
                     silent=False,
                     dur_list=None,
                     pitch_list=None,
                     energy_list=None,
                     glow_sampling_temperature=0.2):
        """
        Synthesize all texts, join them with short silences and write the result as 16 bit PCM file.

        Empty texts are skipped. If all texts are empty, the file contains only the leading silence.

        Args:
            silent: Whether to be verbose about the process
            text_list: A list of strings to be read
            file_location: The path and name of the file it should be saved to
            energy_list: list of energy tensors to be used for the texts
            pitch_list: list of pitch tensors to be used for the texts
            dur_list: list of duration tensors to be used for the texts
            duration_scaling_factor: reasonable values are 0.8 < scale < 1.2.
                                     1.0 means no scaling happens, higher values increase durations for the whole
                                     utterance, lower values decrease durations for the whole utterance.
            pitch_variance_scale: reasonable values are 0.6 < scale < 1.4.
                                  1.0 means no scaling happens, higher values increase variance of the pitch curve,
                                  lower values decrease variance of the pitch curve.
            energy_variance_scale: reasonable values are 0.6 < scale < 1.4.
                                   1.0 means no scaling happens, higher values increase variance of the energy curve,
                                   lower values decrease variance of the energy curve.
            pause_duration_scaling_factor: passed through to the synthesis
            glow_sampling_temperature: passed through to the synthesis
        """
        if not dur_list:
            dur_list = []
        if not pitch_list:
            pitch_list = []
        if not energy_list:
            energy_list = []

        silence = torch.zeros([14300])
        # collected and concatenated once at the end, so the growing waveform is not copied on every sentence
        segments = [silence]
        # fallback in case nothing is synthesized, otherwise writing the file would fail on an unset sample rate
        sr = _OUTPUT_SAMPLE_RATE

        for (text, durations, pitch, energy) in itertools.zip_longest(text_list, dur_list, pitch_list, energy_list):
            # text is None if one of the other lists is longer than the text list
            if text is None or text.strip() == "":
                continue
            if not silent:
                print("Now synthesizing: {}".format(text))
            spoken_sentence, sr = self(text,
                                       durations=durations.to(self.device) if durations is not None else None,
                                       pitch=pitch.to(self.device) if pitch is not None else None,
                                       energy=energy.to(self.device) if energy is not None else None,
                                       duration_scaling_factor=duration_scaling_factor,
                                       pitch_variance_scale=pitch_variance_scale,
                                       energy_variance_scale=energy_variance_scale,
                                       pause_duration_scaling_factor=pause_duration_scaling_factor,
                                       glow_sampling_temperature=glow_sampling_temperature)
            segments.append(torch.tensor(spoken_sentence).cpu())
            segments.append(silence)

        wav = torch.cat(segments, 0)
        soundfile.write(file=file_location, data=float2pcm(wav), samplerate=sr, subtype="PCM_16")

    def read_aloud(self,
                   text,
                   view=False,
                   duration_scaling_factor=1.0,
                   pitch_variance_scale=1.0,
                   energy_variance_scale=1.0,
                   blocking=False,
                   glow_sampling_temperature=0.2):
        """
        Synthesize the text and play it on the default audio device.

        Args:
            text: the text to be read, nothing happens if it is empty
            view: whether to show a plot of the spectrogram
            duration_scaling_factor: see forward
            pitch_variance_scale: see forward
            energy_variance_scale: see forward
            blocking: whether to wait until the playback has finished
            glow_sampling_temperature: see forward
        """
        if text.strip() == "":
            return
        wav, sr = self(text,
                       view,
                       duration_scaling_factor=duration_scaling_factor,
                       pitch_variance_scale=pitch_variance_scale,
                       energy_variance_scale=energy_variance_scale,
                       glow_sampling_temperature=glow_sampling_temperature)
        # half a second of silence on both sides
        silence = torch.zeros([sr // 2])
        wav = torch.cat((silence, torch.tensor(wav), silence), 0).numpy()
        sounddevice.play(float2pcm(wav), samplerate=sr)
        if view:
            plt.show()
        if blocking:
            sounddevice.wait()