|
from dataclasses import asdict, dataclass, field |
|
from typing import Dict, List |
|
|
|
from coqpit import MISSING |
|
|
|
from TTS.config.shared_configs import BaseAudioConfig, BaseDatasetConfig, BaseTrainingConfig |
|
|
|
|
|
@dataclass
class BaseEncoderConfig(BaseTrainingConfig):
    """Defines parameters for a Generic Encoder model.

    Args:
        model (str): Name of the encoder model to use. Defaults to None.
        audio (BaseAudioConfig): Audio processing configuration. Defaults to `BaseAudioConfig()`.
        datasets (List[BaseDatasetConfig]): List of dataset configurations. Defaults to a single
            default `BaseDatasetConfig`.
        model_params (Dict): Architecture hyper-parameters passed to the encoder implementation
            (model name, input/projection/LSTM dimensions, layer count, projection flag).
        audio_augmentation (Dict): Audio augmentation settings. Defaults to `{}` (disabled).
        epochs (int): Number of training epochs. Defaults to 10000.
        loss (str): Name of the training loss. Defaults to "angleproto".
        grad_clip (float): Gradient clipping value. Defaults to 3.0.
        lr (float): Initial learning rate. Defaults to 0.0001.
        optimizer (str): Name of the optimizer. Defaults to "radam".
        optimizer_params (Dict): Extra optimizer keyword arguments. Defaults to
            `{"betas": [0.9, 0.999], "weight_decay": 0}`.
        lr_decay (bool): Enable learning-rate decay. Defaults to False.
        warmup_steps (int): Number of warm-up steps for the scheduler. Defaults to 4000.
        tb_model_param_stats (bool): Log model parameter stats to TensorBoard. Defaults to False.
        steps_plot_stats (int): Interval (in steps) for plotting statistics. Defaults to 10.
        save_step (int): Interval (in steps) for saving checkpoints. Defaults to 1000.
        print_step (int): Interval (in steps) for printing training logs. Defaults to 20.
        run_eval (bool): Run evaluation during training. Defaults to False.
        num_classes_in_batch (int): Number of classes (e.g. speakers) per training batch. Required.
        num_utter_per_class (int): Number of utterances per class per training batch. Required.
        eval_num_classes_in_batch (int): Number of classes per eval batch. Defaults to None.
        eval_num_utter_per_class (int): Number of utterances per class per eval batch. Defaults to None.
        num_loader_workers (int): Number of data-loader worker processes. Required.
        voice_len (float): Length (in seconds) of the audio segments used for training. Defaults to 1.6.
    """

    model: str = None
    audio: BaseAudioConfig = field(default_factory=BaseAudioConfig)
    datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()])
    # Encoder architecture hyper-parameters; `input_dim` must match `audio.num_mels`
    # (enforced in `check_values`).
    model_params: Dict = field(
        default_factory=lambda: {
            "model_name": "lstm",
            "input_dim": 80,
            "proj_dim": 256,
            "lstm_dim": 768,
            "num_lstm_layers": 3,
            "use_lstm_with_projection": True,
        }
    )

    audio_augmentation: Dict = field(default_factory=lambda: {})

    # training params
    epochs: int = 10000
    loss: str = "angleproto"
    grad_clip: float = 3.0
    lr: float = 0.0001
    optimizer: str = "radam"
    optimizer_params: Dict = field(default_factory=lambda: {"betas": [0.9, 0.999], "weight_decay": 0})
    lr_decay: bool = False
    warmup_steps: int = 4000

    # logging params
    tb_model_param_stats: bool = False
    steps_plot_stats: int = 10
    save_step: int = 1000
    print_step: int = 20
    run_eval: bool = False

    # data loader params (MISSING fields must be supplied by the user config)
    num_classes_in_batch: int = MISSING
    num_utter_per_class: int = MISSING
    eval_num_classes_in_batch: int = None
    eval_num_utter_per_class: int = None

    num_loader_workers: int = MISSING
    voice_len: float = 1.6

    def check_values(self):
        """Validate field consistency on top of the base checks.

        Raises:
            AssertionError: If the model input dimension does not match the
                number of mel bands produced by the audio pipeline.
        """
        super().check_values()
        c = asdict(self)
        assert (
            c["model_params"]["input_dim"] == self.audio.num_mels
        ), " [!] model input dimension must be equal to melspectrogram dimension."
|
|