|
import os |
|
import random |
|
from typing import Dict, List, Tuple, Union |
|
|
|
import torch |
|
import torch.distributed as dist |
|
from coqpit import Coqpit |
|
from torch import nn |
|
from torch.utils.data import DataLoader |
|
from torch.utils.data.sampler import WeightedRandomSampler |
|
from trainer.torch import DistributedSampler, DistributedSamplerWrapper |
|
|
|
from TTS.model import BaseTrainerModel |
|
from TTS.tts.datasets.dataset import TTSDataset |
|
from TTS.tts.utils.data import get_length_balancer_weights |
|
from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights |
|
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights, get_speaker_manager |
|
from TTS.tts.utils.synthesis import synthesis |
|
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram |
|
|
|
|
|
|
|
|
|
class BaseTTS(BaseTrainerModel):
    """Base `tts` class. Every new `tts` model must inherit this.

    It defines common `tts` specific functions on top of `Model` implementation.
    """

    MODEL_TYPE = "tts"

    def __init__(
        self,
        config: Coqpit,
        ap: "AudioProcessor",
        tokenizer: "TTSTokenizer",
        speaker_manager: SpeakerManager = None,
        language_manager: LanguageManager = None,
    ):
        """Store the configuration and helper objects and resolve the model args.

        Args:
            config (Coqpit): A `*Config` or `*Args` object (see `_set_model_args`).
            ap (AudioProcessor): Audio processor, kept on `self.ap`.
            tokenizer (TTSTokenizer): Text tokenizer, kept on `self.tokenizer`. May be None.
            speaker_manager (SpeakerManager, optional): Speaker manager. Defaults to None.
            language_manager (LanguageManager, optional): Language manager. Defaults to None.
        """
        super().__init__()
        self.config = config
        self.ap = ap
        self.tokenizer = tokenizer
        self.speaker_manager = speaker_manager
        self.language_manager = language_manager
        self._set_model_args(config)

    def _set_model_args(self, config: Coqpit):
        """Setup model args based on the config type (`ModelConfig` or `ModelArgs`).

        `ModelArgs` has all the fields required to initialize the model architecture.

        `ModelConfig` has all the fields required for training and inference, and contains `ModelArgs`.

        If the config is for training with a name like "*Config", then the model args are embedded in
        `config.model_args`.

        If the config is for the model with a name like "*Args", then it is assigned directly to `self.args`.

        Args:
            config (Coqpit): The configuration object to inspect.

        Raises:
            ValueError: If the class name of `config` contains neither "Config" nor "Args".
        """
        # The class name is inspected instead of using `isinstance`.
        config_type = config.__class__.__name__

        if "Config" in config_type:
            if hasattr(self.config, "model_args"):
                config_num_chars = self.config.model_args.num_chars
            else:
                config_num_chars = self.config.num_chars

            # The tokenizer, when available, is the source of truth for the vocabulary size.
            num_chars = config_num_chars if self.tokenizer is None else self.tokenizer.characters.num_chars

            if "characters" in config:
                self.config.num_chars = num_chars
                # NOTE(review): when the config has `characters` but no `model_args`,
                # `self.args` is left unset by this method - confirm this is intended.
                if hasattr(self.config, "model_args"):
                    config.model_args.num_chars = num_chars
                    self.args = self.config.model_args
            else:
                self.config = config
                self.args = config.model_args
        elif "Args" in config_type:
            self.args = config
        else:
            raise ValueError("config must be either a *Config or *Args")

    def init_multispeaker(self, config: Coqpit, data: List = None):
        """Initialize a speaker embedding layer if needed and define expected embedding channel size for defining
        `in_channels` size of the connected layers.

        This implementation yields 3 possible outcomes:

        1. If `config.use_speaker_embedding` and `config.use_d_vector_file` are False, do nothing.
        2. If `config.use_d_vector_file` is True, set expected embedding channel size to `config.d_vector_dim` or 512.
        3. If `config.use_speaker_embedding`, initialize a speaker embedding layer with channel size of
        `config.d_vector_dim` or 512.

        You can override this function for new models.

        Args:
            config (Coqpit): Model configuration.
            data (List, optional): Not used by this implementation. Defaults to None.
        """
        # Number of speakers: the speaker manager takes precedence over the config value.
        if self.speaker_manager is not None:
            self.num_speakers = self.speaker_manager.num_speakers
        elif hasattr(config, "num_speakers"):
            self.num_speakers = config.num_speakers

        # Channel size of the speaker embedding, whichever way it is provided.
        if config.use_speaker_embedding or config.use_d_vector_file:
            has_d_vector_dim = "d_vector_dim" in config and config.d_vector_dim is not None
            self.embedded_speaker_dim = config.d_vector_dim if has_d_vector_dim else 512

        # A trainable lookup table is only needed when d-vectors are not read from a file.
        if config.use_speaker_embedding and not config.use_d_vector_file:
            print(" > Init speaker_embedding layer.")
            self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
            self.speaker_embedding.weight.data.normal_(0, 0.3)

    def get_aux_input(self, **kwargs) -> Dict:
        """Prepare and return `aux_input` used by `forward()`.

        This base implementation ignores `kwargs` and returns a dict with every auxiliary input set to None.
        """
        return {"speaker_id": None, "style_wav": None, "d_vector": None, "language_id": None}

    def get_aux_input_from_test_sentences(self, sentence_info):
        """Resolve the auxiliary inputs described by one test sentence entry.

        Args:
            sentence_info (Union[str, List]): Either the text itself or a list holding, in order,
                `[text, speaker_name, style_wav, language_name]`, where trailing items may be omitted.
                A list with no items or more than 4 items leaves every field, including the text, as None.

        Returns:
            Dict: A dict with the keys `text`, `speaker_id`, `style_wav`, `d_vector` and `language_id`.
        """
        config = self.config.model_args if hasattr(self.config, "model_args") else self.config

        # Unpack the sentence description; missing trailing fields default to None.
        text, speaker_name, style_wav, language_name = None, None, None, None
        if isinstance(sentence_info, list):
            if 1 <= len(sentence_info) <= 4:
                padding = [None] * (4 - len(sentence_info))
                text, speaker_name, style_wav, language_name = list(sentence_info) + padding
        else:
            text = sentence_info

        # Speaker conditioning: d-vectors take precedence over speaker ids.
        speaker_id, d_vector, language_id = None, None, None
        if self.speaker_manager is not None:
            if config.use_d_vector_file:
                if speaker_name is None:
                    d_vector = self.speaker_manager.get_random_embedding()
                else:
                    d_vector = self.speaker_manager.get_d_vector_by_name(speaker_name)
            elif config.use_speaker_embedding:
                if speaker_name is None:
                    speaker_id = self.speaker_manager.get_random_id()
                else:
                    speaker_id = self.speaker_manager.name_to_id[speaker_name]

        # Language conditioning.
        if self.language_manager is not None and config.use_language_embedding and language_name is not None:
            language_id = self.language_manager.name_to_id[language_name]

        return {
            "text": text,
            "speaker_id": speaker_id,
            "style_wav": style_wav,
            "d_vector": d_vector,
            "language_id": language_id,
        }

    def format_batch(self, batch: Dict) -> Dict:
        """Generic batch formatting for `TTSDataset`.

        You must override this if you use a custom dataset.

        Besides renaming the batch fields, this computes per-token durations from the attention masks
        (when `batch["attns"]` is not None) and regroups the stop targets by the reduction factor
        `self.config.r`.

        Args:
            batch (Dict): Collated batch. The keys read here are `token_id`, `token_id_lengths`,
                `speaker_names`, `linear`, `mel`, `mel_lengths`, `stop_targets`, `item_idxs`, `d_vectors`,
                `speaker_ids`, `attns`, `waveform`, `pitch`, `energy`, `language_ids` and `audio_unique_names`.

        Returns:
            Dict: The formatted batch consumed by the model.
        """
        text_input = batch["token_id"]
        text_lengths = batch["token_id_lengths"]
        mel_lengths = batch["mel_lengths"]
        stop_targets = batch["stop_targets"]
        attn_mask = batch["attns"]

        max_text_length = torch.max(text_lengths.float())
        max_spec_length = torch.max(mel_lengths.float())

        # Compute durations from the attention masks.
        durations = None
        if attn_mask is not None:
            durations = torch.zeros(attn_mask.shape[0], attn_mask.shape[2])
            for idx, am in enumerate(attn_mask):
                # For every frame (dim 2) pick the token (dim 1) with the largest attention value.
                c_idxs = am[:, : text_lengths[idx], : mel_lengths[idx]].max(1)[1]
                # Count how many frames each token won.
                c_idxs, counts = torch.unique(c_idxs, return_counts=True)
                # Every token starts with a duration of 1, so tokens that won no frame do not end up with 0.
                dur = torch.ones([text_lengths[idx]]).to(counts.dtype)
                dur[c_idxs] = counts

                # The extra frames introduced by the floor of 1 are taken back from the longest tokens.
                extra_frames = dur.sum() - mel_lengths[idx]
                largest_idxs = torch.argsort(-dur)[:extra_frames]
                dur[largest_idxs] -= 1
                assert (
                    dur.sum() == mel_lengths[idx]
                ), f" [!] total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}"
                durations[idx, : text_lengths[idx]] = dur

        # Set stop targets w.r.t. the reduction factor: a group of `r` frames is a stop step
        # if any frame in the group is flagged.
        stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // self.config.r, -1)
        stop_targets = (stop_targets.sum(2) > 0.0).float()
        stop_target_lengths = torch.divide(mel_lengths, self.config.r).ceil_()

        return {
            "text_input": text_input,
            "text_lengths": text_lengths,
            "speaker_names": batch["speaker_names"],
            "mel_input": batch["mel"],
            "mel_lengths": mel_lengths,
            "linear_input": batch["linear"],
            "stop_targets": stop_targets,
            "stop_target_lengths": stop_target_lengths,
            "attn_mask": attn_mask,
            "durations": durations,
            "speaker_ids": batch["speaker_ids"],
            "d_vectors": batch["d_vectors"],
            "max_text_length": float(max_text_length),
            "max_spec_length": float(max_spec_length),
            "item_idx": batch["item_idxs"],
            "waveform": batch["waveform"],
            "pitch": batch["pitch"],
            "energy": batch["energy"],
            "language_ids": batch["language_ids"],
            "audio_unique_names": batch["audio_unique_names"],
        }

    def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1):
        """Build the sampler used by the data loader.

        The language, speaker and length balancers enabled in `config` are each scaled by their
        `*_weighted_sampler_alpha` (1.0 when absent) and summed into a single weight vector for a
        `WeightedRandomSampler`. With more than one GPU the result is made distributed.

        Args:
            config (Coqpit): Configuration holding the optional `use_*_weighted_sampler` flags.
            dataset (TTSDataset): Dataset whose `samples` are weighted.
            num_gpus (int, optional): Number of GPUs used for training. Defaults to 1.

        Returns:
            The sampler to use, or None when no balancer is enabled and `num_gpus` is not greater than 1.
        """
        data_items = dataset.samples
        balancers = (
            (
                "use_language_weighted_sampler",
                "language_weighted_sampler_alpha",
                "Language",
                get_language_balancer_weights,
            ),
            (
                "use_speaker_weighted_sampler",
                "speaker_weighted_sampler_alpha",
                "Speaker",
                get_speaker_balancer_weights,
            ),
            (
                "use_length_weighted_sampler",
                "length_weighted_sampler_alpha",
                "Length",
                get_length_balancer_weights,
            ),
        )

        weights = None
        for flag_name, alpha_name, label, weight_fn in balancers:
            if not getattr(config, flag_name, False):
                continue
            alpha = getattr(config, alpha_name, 1.0)
            print(f" > Using {label} weighted sampler with alpha:", alpha)
            scaled = weight_fn(data_items) * alpha
            weights = scaled if weights is None else weights + scaled

        sampler = None if weights is None else WeightedRandomSampler(weights, len(weights))

        # Sampler for DDP: wrap an existing sampler, otherwise fall back to a plain distributed one.
        if num_gpus > 1:
            sampler = DistributedSampler(dataset) if sampler is None else DistributedSamplerWrapper(sampler)

        return sampler

    def get_data_loader(
        self,
        config: Coqpit,
        assets: Dict,
        is_eval: bool,
        samples: Union[List[Dict], List[List]],
        verbose: bool,
        num_gpus: int,
        rank: int = None,
    ) -> "DataLoader":
        """Create the `TTSDataset` and wrap it in a `DataLoader`.

        Args:
            config (Coqpit): Training configuration.
            assets (Dict): Not used by this implementation.
            is_eval (bool): Whether the loader is for evaluation.
            samples (Union[List[Dict], List[List]]): Data samples handed to `TTSDataset`.
            verbose (bool): Verbosity flag handed to `TTSDataset`.
            num_gpus (int): Number of GPUs; values above 1 enable the distributed barrier and sampler.
            rank (int, optional): Not used by this implementation. Defaults to None.

        Returns:
            DataLoader: The data loader, or None when `is_eval` is set and `config.run_eval` is False.
        """
        if is_eval and not config.run_eval:
            return None

        # Setup multi-speaker attributes.
        speaker_id_mapping = None
        d_vector_mapping = None
        if self.speaker_manager is not None:
            if hasattr(config, "model_args"):
                speaker_id_mapping = (
                    self.speaker_manager.name_to_id if config.model_args.use_speaker_embedding else None
                )
                d_vector_mapping = self.speaker_manager.embeddings if config.model_args.use_d_vector_file else None
                # Mirror the flag on the top-level config, which is the one read below.
                config.use_d_vector_file = config.model_args.use_d_vector_file
            else:
                speaker_id_mapping = self.speaker_manager.name_to_id if config.use_speaker_embedding else None
                d_vector_mapping = self.speaker_manager.embeddings if config.use_d_vector_file else None

        # Setup multi-lingual attributes.
        language_id_mapping = None
        if self.language_manager is not None and self.args.use_language_embedding:
            language_id_mapping = self.language_manager.name_to_id

        # Init the dataset.
        dataset = TTSDataset(
            outputs_per_step=config.r if "r" in config else 1,
            compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec,
            compute_f0=config.get("compute_f0", False),
            f0_cache_path=config.get("f0_cache_path", None),
            compute_energy=config.get("compute_energy", False),
            energy_cache_path=config.get("energy_cache_path", None),
            samples=samples,
            ap=self.ap,
            return_wav=config.return_wav if "return_wav" in config else False,
            batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
            min_text_len=config.min_text_len,
            max_text_len=config.max_text_len,
            min_audio_len=config.min_audio_len,
            max_audio_len=config.max_audio_len,
            phoneme_cache_path=config.phoneme_cache_path,
            precompute_num_workers=config.precompute_num_workers,
            use_noise_augment=False if is_eval else config.use_noise_augment,
            verbose=verbose,
            speaker_id_mapping=speaker_id_mapping,
            d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None,
            tokenizer=self.tokenizer,
            start_by_longest=config.start_by_longest,
            language_id_mapping=language_id_mapping,
        )

        # Wait for all the DDP processes to be ready.
        if num_gpus > 1:
            dist.barrier()

        dataset.preprocess_samples()

        sampler = self.get_sampler(config, dataset, num_gpus)

        return DataLoader(
            dataset,
            batch_size=config.eval_batch_size if is_eval else config.batch_size,
            # `shuffle` is only honoured when no sampler is given.
            shuffle=config.shuffle if sampler is None else False,
            collate_fn=dataset.collate_fn,
            drop_last=config.drop_last,
            sampler=sampler,
            num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
            pin_memory=False,
        )

    def _get_test_aux_input(
        self,
    ) -> Dict:
        """Pick random default auxiliary inputs for the test sentences.

        Returns:
            Dict: A dict with the keys `speaker_id`, `d_vector` and `style_wav`. `style_wav` is always None.
        """
        d_vector = None
        if self.config.use_d_vector_file:
            embeddings = self.speaker_manager.embeddings
            d_vector = [embeddings[name]["embedding"] for name in embeddings]
            # The sampled one-item list is wrapped in a 1-tuple; this shape is kept as it was.
            d_vector = (random.sample(sorted(d_vector), 1),)

        speaker_id = None
        if self.config.use_speaker_embedding:
            # `random.sample` returns a one-item list here, not a bare id.
            speaker_id = random.sample(sorted(self.speaker_manager.name_to_id.values()), 1)

        return {
            "speaker_id": speaker_id,
            "d_vector": d_vector,
            "style_wav": None,
        }

    def test_run(self, assets: Dict) -> Tuple[Dict, Dict]:
        """Generic test run for `tts` models used by `Trainer`.

        You can override this for a different behaviour.

        Sentences given as a list carry their own speaker/style description; plain-string sentences use
        the randomly picked defaults from `_get_test_aux_input()`.

        Args:
            assets (dict): A dict of training assets. For `tts` models, it must include `{'audio_processor': ap}`.

        Returns:
            Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
        """
        print(" | > Synthesizing test sentences.")
        test_audios = {}
        test_figures = {}
        test_sentences = self.config.test_sentences
        default_aux_inputs = self._get_test_aux_input()
        for idx, sen in enumerate(test_sentences):
            if isinstance(sen, list):
                aux_inputs = self.get_aux_input_from_test_sentences(sen)
                sen = aux_inputs["text"]
            else:
                # Fall back to the defaults so that the inputs of a previous list-form sentence
                # do not leak into this one.
                aux_inputs = default_aux_inputs
            # NOTE(review): `language_id` from the sentence description is not forwarded to
            # `synthesis` - confirm whether it should be.
            outputs_dict = synthesis(
                self,
                sen,
                self.config,
                "cuda" in str(next(self.parameters()).device),
                speaker_id=aux_inputs["speaker_id"],
                d_vector=aux_inputs["d_vector"],
                style_wav=aux_inputs["style_wav"],
                use_griffin_lim=True,
                do_trim_silence=False,
            )
            test_audios[f"{idx}-audio"] = outputs_dict["wav"]
            test_figures[f"{idx}-prediction"] = plot_spectrogram(
                outputs_dict["outputs"]["model_outputs"], self.ap, output_fig=False
            )
            test_figures[f"{idx}-alignment"] = plot_alignment(
                outputs_dict["outputs"]["alignments"], output_fig=False
            )
        return test_figures, test_audios

    def on_init_start(self, trainer):
        """Save the speakers.pth and language_ids.json at the beginning of the training. Also update both paths.

        Args:
            trainer: Trainer object providing `output_path` and `config`.
        """
        config_path = os.path.join(trainer.output_path, "config.json")

        if self.speaker_manager is not None:
            output_path = os.path.join(trainer.output_path, "speakers.pth")
            self.speaker_manager.save_ids_to_file(output_path)
            trainer.config.speakers_file = output_path
            # Some configs do not have `model_args`.
            if hasattr(trainer.config, "model_args"):
                trainer.config.model_args.speakers_file = output_path
            trainer.config.save_json(config_path)
            print(f" > `speakers.pth` is saved to {output_path}.")
            print(" > `speakers_file` is updated in the config.json.")

        if self.language_manager is not None:
            output_path = os.path.join(trainer.output_path, "language_ids.json")
            self.language_manager.save_ids_to_file(output_path)
            trainer.config.language_ids_file = output_path
            if hasattr(trainer.config, "model_args"):
                trainer.config.model_args.language_ids_file = output_path
            trainer.config.save_json(config_path)
            print(f" > `language_ids.json` is saved to {output_path}.")
            print(" > `language_ids_file` is updated in the config.json.")
|
|
|
|
|
class BaseTTSE2E(BaseTTS):
    """Base class variant that overrides how the model args are resolved from the config."""

    def _set_model_args(self, config: Coqpit):
        """Set `self.config` and `self.args` and sync the number of characters.

        For a `*Config` object the model args are read from `config.model_args`, and `num_chars` is written
        to both the config and its model args, using the tokenizer's value when a tokenizer is available.
        For a `*Args` object the config itself becomes `self.args`.

        Args:
            config (Coqpit): A `*Config` or `*Args` object.

        Raises:
            ValueError: If the class name of `config` contains neither "Config" nor "Args".
        """
        self.config = config
        config_type = config.__class__.__name__

        if "Config" in config_type:
            # `self.args` is the same object as `self.config.model_args`, so one write updates both.
            self.args = config.model_args
            num_chars = self.args.num_chars if self.tokenizer is None else self.tokenizer.characters.num_chars
            self.args.num_chars = num_chars
            self.config.num_chars = num_chars
        elif "Args" in config_type:
            self.args = config
        else:
            raise ValueError("config must be either a *Config or *Args")
|
|