File size: 9,519 Bytes
9b2107c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
from dataclasses import dataclass, field
from typing import List, Optional

from coqpit import Coqpit

from TTS.vc.configs.shared_configs import BaseVCConfig


@dataclass
class FreeVCAudioConfig(Coqpit):
    """Audio settings used by FreeVC.

    Args:
        max_wav_value (float):
            Maximum value of the waveform. Defaults to `32768.0`.

        input_sample_rate (int):
            Sampling rate of the input waveform. Defaults to `16000`.

        output_sample_rate (int):
            Sampling rate of the output waveform. Defaults to `24000`.

        filter_length (int):
            Length of the filter. Defaults to `1280`.

        hop_length (int):
            Hop length. Defaults to `320`.

        win_length (int):
            Window length. Defaults to `1280`.

        n_mel_channels (int):
            Number of mel channels. Defaults to `80`.

        mel_fmin (float):
            Minimum frequency of the mel filterbank. Defaults to `0.0`.

        mel_fmax (Optional[float]):
            Maximum frequency of the mel filterbank. Defaults to `None`.
    """

    # waveform range and sampling rates
    max_wav_value: float = 32768.0
    input_sample_rate: int = 16000
    output_sample_rate: int = 24000

    # filter / hop / window lengths
    filter_length: int = 1280
    hop_length: int = 320
    win_length: int = 1280

    # mel filterbank
    n_mel_channels: int = 80
    mel_fmin: float = 0.0
    mel_fmax: Optional[float] = None


@dataclass
class FreeVCArgs(Coqpit):
    """Architecture arguments of the FreeVC model.

    Args:
        spec_channels (int):
            Number of spectrogram channels. Defaults to `641`.

        inter_channels (int):
            Number of channels in the intermediate layers. Defaults to `192`.

        hidden_channels (int):
            Number of channels in the hidden layers. Defaults to `192`.

        filter_channels (int):
            Number of channels in the filter layers. Defaults to `768`.

        n_heads (int):
            Number of attention heads. Defaults to `2`.

        n_layers (int):
            Number of layers. Defaults to `6`.

        kernel_size (int):
            Kernel size. Defaults to `3`.

        p_dropout (float):
            Dropout probability. Defaults to `0.1`.

        resblock (str):
            Type of residual block. Defaults to `"1"`.

        resblock_kernel_sizes (List[int]):
            Kernel sizes of the residual blocks. Defaults to `[3, 7, 11]`.

        resblock_dilation_sizes (List[List[int]]):
            Dilation sizes of the residual blocks. Defaults to
            `[[1, 3, 5], [1, 3, 5], [1, 3, 5]]`.

        upsample_rates (List[int]):
            Upsample rates. Defaults to `[10, 8, 2, 2]`.

        upsample_initial_channel (int):
            Number of channels in the initial upsample layer. Defaults to `512`.

        upsample_kernel_sizes (List[int]):
            Kernel sizes of the upsample layers. Defaults to `[16, 16, 4, 4]`.

        n_layers_q (int):
            Number of layers in the quantization network. Defaults to `3`.

        use_spectral_norm (bool):
            Whether to use spectral normalization. Defaults to `False`.

        gin_channels (int):
            Number of channels in the global conditioning vector. Defaults to `256`.

        ssl_dim (int):
            Dimension of the self-supervised learning embedding. Defaults to `1024`.

        use_spk (bool):
            Whether to use an external speaker encoder. Defaults to `False`.

        num_spks (int):
            Presumably the number of speakers (TODO confirm against the model code).
            Defaults to `0`.

        segment_size (int):
            Segment size; its unit and use are not visible in this file (TODO confirm).
            Defaults to `8960`.
    """

    # channel sizes
    spec_channels: int = 641
    inter_channels: int = 192
    hidden_channels: int = 192
    filter_channels: int = 768

    # attention / layer settings
    n_heads: int = 2
    n_layers: int = 6
    kernel_size: int = 3
    p_dropout: float = 0.1

    # residual blocks (list defaults need a factory so instances never share a list)
    resblock: str = "1"
    resblock_kernel_sizes: List[int] = field(default_factory=lambda: [3, 7, 11])
    resblock_dilation_sizes: List[List[int]] = field(
        default_factory=lambda: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
    )

    # upsampling
    upsample_rates: List[int] = field(default_factory=lambda: [10, 8, 2, 2])
    upsample_initial_channel: int = 512
    upsample_kernel_sizes: List[int] = field(default_factory=lambda: [16, 16, 4, 4])

    # remaining switches and sizes
    n_layers_q: int = 3
    use_spectral_norm: bool = False
    gin_channels: int = 256
    ssl_dim: int = 1024
    use_spk: bool = False
    num_spks: int = 0
    segment_size: int = 8960


@dataclass
class FreeVCConfig(BaseVCConfig):
    """Defines parameters for the FreeVC voice conversion (VC) model.

    Args:
        model (str):
            Model name. Do not change unless you know what you are doing.

        model_args (FreeVCArgs):
            Model architecture arguments. Defaults to `FreeVCArgs()`.

        audio (FreeVCAudioConfig):
            Audio processing configuration. Defaults to `FreeVCAudioConfig()`.

        grad_clip (List):
            Gradient clipping thresholds for each optimizer. Defaults to `[1000.0, 1000.0]`.

        lr_gen (float):
            Initial learning rate for the generator. Defaults to 0.0002.

        lr_disc (float):
            Initial learning rate for the discriminator. Defaults to 0.0002.

        lr_scheduler_gen (str):
            Name of the learning rate scheduler for the generator. One of the `torch.optim.lr_scheduler.*`. Defaults to
            `ExponentialLR`.

        lr_scheduler_gen_params (dict):
            Parameters for the learning rate scheduler of the generator. Defaults to `{'gamma': 0.999875, "last_epoch":-1}`.

        lr_scheduler_disc (str):
            Name of the learning rate scheduler for the discriminator. One of the `torch.optim.lr_scheduler.*`. Defaults to
            `ExponentialLR`.

        lr_scheduler_disc_params (dict):
            Parameters for the learning rate scheduler of the discriminator. Defaults to `{'gamma': 0.999875, "last_epoch":-1}`.

        scheduler_after_epoch (bool):
            If true, step the schedulers after each epoch else after each step. Defaults to `False`.

        optimizer (str):
            Name of the optimizer to use with both the generator and the discriminator networks. One of the
            `torch.optim.*`. Defaults to `AdamW`.

        kl_loss_alpha (float):
            Loss weight for KL loss. Defaults to 1.0.

        disc_loss_alpha (float):
            Loss weight for the discriminator loss. Defaults to 1.0.

        gen_loss_alpha (float):
            Loss weight for the generator loss. Defaults to 1.0.

        feat_loss_alpha (float):
            Loss weight for the feature matching loss. Defaults to 1.0.

        mel_loss_alpha (float):
            Loss weight for the mel loss. Defaults to 45.0.

        return_wav (bool):
            If true, data loader returns the waveform as well as the other outputs. Do not change. Defaults to `True`.

        compute_linear_spec (bool):
            If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`.

        use_weighted_sampler (bool):
            If true, use weighted sampler with bucketing for balancing samples between datasets used in training. Defaults to `False`.

        weighted_sampler_attrs (dict):
            Key returned by the formatter to be used for weighted sampler. For example `{"root_path": 2.0, "speaker_name": 1.0}` sets sample probabilities
            by overweighting `root_path` by 2.0. Defaults to `{}`.

        weighted_sampler_multipliers (dict):
            Weight each unique value of a key returned by the formatter for weighted sampling.
            For example `{"root_path":{"/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-100/":1.0, "/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-360/": 0.5}`.
            It will sample instances from `train-clean-100` 2 times more than `train-clean-360`. Defaults to `{}`.

        r (int):
            Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`.

        add_blank (bool):
            If true, a blank token is added in between every character. Defaults to `True`.

        test_sentences (List[List]):
            List of sentences with speaker and language information to be used for testing.

        language_ids_file (str):
            Path to the language ids file.

        use_language_embedding (bool):
            If true, language embedding is used. Defaults to `False`.

        num_speakers (int):
            Presumably the number of speakers for the speaker embedding layer (TODO confirm). Defaults to `0`.

        speakers_file (str):
            Presumably the path to a speakers file (TODO confirm). Defaults to `None`.

        speaker_embedding_channels (int):
            Presumably the number of speaker embedding channels (TODO confirm). Defaults to `256`.

        use_d_vector_file (bool):
            Presumably enables the use of d-vectors loaded from `d_vector_file` (TODO confirm). Defaults to `False`.

        d_vector_file (List[str]):
            Presumably the path(s) to the d-vector file(s) (TODO confirm). Defaults to `None`.

        d_vector_dim (int):
            Presumably the dimension of the d-vectors (TODO confirm). Defaults to `None`.

    Note:
        Check :class:`TTS.vc.configs.shared_configs.BaseVCConfig` for the inherited parameters.

        Several arguments documented above (`grad_clip`, the `lr_*` and scheduler settings, `optimizer`, the
        `*_loss_alpha` weights, `test_sentences`, `language_ids_file` and `use_language_embedding`) are not declared
        in this class body; the optimizer and loss sections are still marked `TODO with training support`. Whether
        the base config provides them is not visible from this file.

    Example:

        >>> from TTS.vc.configs.freevc_config import FreeVCConfig
        >>> config = FreeVCConfig()
    """

    model: str = "freevc"
    # model specific params
    model_args: FreeVCArgs = field(default_factory=FreeVCArgs)
    audio: FreeVCAudioConfig = field(default_factory=FreeVCAudioConfig)

    # optimizer
    # TODO with training support

    # loss params
    # TODO with training support

    # data loader params
    return_wav: bool = True
    compute_linear_spec: bool = True

    # sampler params
    use_weighted_sampler: bool = False  # TODO: move it to the base config
    # `dict` as the factory gives every instance its own empty dict
    weighted_sampler_attrs: dict = field(default_factory=dict)
    weighted_sampler_multipliers: dict = field(default_factory=dict)

    # overrides
    r: int = 1  # DO NOT CHANGE
    add_blank: bool = True

    # multi-speaker settings
    # use speaker embedding layer
    num_speakers: int = 0
    speakers_file: str = None
    speaker_embedding_channels: int = 256

    # use d-vectors
    use_d_vector_file: bool = False
    d_vector_file: List[str] = None
    d_vector_dim: int = None

    def __post_init__(self):
        """Copy `model_args` entries onto this config where a same-named attribute exists.

        Every key/value pair of `self.model_args` is inspected; when this config
        already has an attribute with that name, it is overwritten with the
        value from `model_args`. Keys without a matching attribute are skipped.
        """
        for key, val in self.model_args.items():
            if hasattr(self, key):
                self[key] = val