This view is limited to 50 files because it contains too many changes.
- .dockerignore +3 -0
- .gitignore +9 -0
- BIGVGAN-LICENSE +21 -0
- RADTTS-LICENSE +19 -0
- README.md +24 -7
- activations.py +126 -0
- alias_free_activation/cuda/__init__.py +0 -0
- alias_free_activation/cuda/activation1d.py +77 -0
- alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
- alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
- alias_free_activation/cuda/compat.h +29 -0
- alias_free_activation/cuda/load.py +86 -0
- alias_free_activation/cuda/type_shim.h +92 -0
- alias_free_activation/torch/__init__.py +6 -0
- alias_free_activation/torch/act.py +30 -0
- alias_free_activation/torch/filter.py +101 -0
- alias_free_activation/torch/resample.py +58 -0
- alignment.py +54 -0
- app.py +360 -0
- attribute_prediction_model.py +402 -0
- audio_processing.py +328 -0
- autoregressive_flow.py +259 -0
- bigvgan.py +528 -0
- common.py +1083 -0
- configs/bigvgan_config.json +63 -0
- configs/radtts-pp-dap-model.json +218 -0
- data.py +606 -0
- distributed.py +161 -0
- filelists/3speakers_ukrainian_train_filelist.txt +0 -0
- filelists/3speakers_ukrainian_train_filelist_dc.txt +0 -0
- filelists/3speakers_ukrainian_val_filelist.txt +85 -0
- filelists/3speakers_ukrainian_val_filelist_dc.txt +85 -0
- loss.py +228 -0
- partialconv1d.py +77 -0
- radam.py +114 -0
- radtts.py +936 -0
- requirements-dev.txt +1 -0
- requirements.txt +13 -0
- splines.py +326 -0
- transformer.py +219 -0
- tts_text_processing/LICENSE +19 -0
- tts_text_processing/abbreviations.py +57 -0
- tts_text_processing/acronyms.py +69 -0
- tts_text_processing/cleaners.py +123 -0
- tts_text_processing/cmudict.py +140 -0
- tts_text_processing/datestime.py +24 -0
- tts_text_processing/grapheme_dictionary.py +37 -0
- tts_text_processing/heteronyms +413 -0
- tts_text_processing/letters_and_numbers.py +96 -0
- tts_text_processing/numerical.py +175 -0
.dockerignore
ADDED
@@ -0,0 +1,3 @@
.ruff_cache/
.venv/
models/
.gitignore
ADDED
@@ -0,0 +1,9 @@
.idea/
.venv/
.ruff_cache/
__pycache__/

flagged/
models/

audio.wav
BIGVGAN-LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 NVIDIA CORPORATION.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
RADTTS-LICENSE
ADDED
@@ -0,0 +1,19 @@
Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
README.md
CHANGED
@@ -1,12 +1,29 @@
 ---
-
-
-colorFrom: blue
-colorTo: indigo
+license: apache-2.0
+title: RAD-TTS++ Ukrainian (BigVGAN v2)
 sdk: gradio
+emoji: 🎧
+colorFrom: blue
+colorTo: gray
+short_description: Use RAD-TTS++ model to synthesize text in Ukrainian
 sdk_version: 5.19.0
-app_file: app.py
-pinned: false
 ---
 
-
+## Install
+
+```shell
+uv venv --python 3.10
+
+source .venv/bin/activate
+
+uv pip install -r requirements.txt
+
+# in development mode
+uv pip install -r requirements-dev.txt
+```
+
+## Run
+
+```shell
+python app.py
+```
activations.py
ADDED
@@ -0,0 +1,126 @@
# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
# LICENSE is in incl_licenses directory.

import torch
from torch import nn, sin, pow
from torch.nn import Parameter


class Snake(nn.Module):
    """
    Implementation of a sine-based periodic activation function
    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable parameter
    References:
        - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
        https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = snake(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)
    """

    def __init__(
        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
    ):
        """
        Initialization.
        INPUT:
            - in_features: shape of the input
            - alpha: trainable parameter
            alpha is initialized to 1 by default, higher values = higher-frequency.
            alpha will be trained along with the rest of your model.
        """
        super(Snake, self).__init__()
        self.in_features = in_features

        # initialize alpha
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:  # log scale alphas initialized to zeros
            self.alpha = Parameter(torch.zeros(in_features) * alpha)
        else:  # linear scale alphas initialized to ones
            self.alpha = Parameter(torch.ones(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        """
        Forward pass of the function.
        Applies the function to the input elementwise.
        Snake ∶= x + 1/a * sin^2 (xa)
        """
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
        x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)

        return x


class SnakeBeta(nn.Module):
    """
    A modified Snake function which uses separate parameters for the magnitude of the periodic components
    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable parameter that controls frequency
        - beta - trainable parameter that controls magnitude
    References:
        - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
        https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = snakebeta(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)
    """

    def __init__(
        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
    ):
        """
        Initialization.
        INPUT:
            - in_features: shape of the input
            - alpha - trainable parameter that controls frequency
            - beta - trainable parameter that controls magnitude
            alpha is initialized to 1 by default, higher values = higher-frequency.
            beta is initialized to 1 by default, higher values = higher-magnitude.
            alpha will be trained along with the rest of your model.
        """
        super(SnakeBeta, self).__init__()
        self.in_features = in_features

        # initialize alpha
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:  # log scale alphas initialized to zeros
            self.alpha = Parameter(torch.zeros(in_features) * alpha)
            self.beta = Parameter(torch.zeros(in_features) * alpha)
        else:  # linear scale alphas initialized to ones
            self.alpha = Parameter(torch.ones(in_features) * alpha)
            self.beta = Parameter(torch.ones(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        """
        Forward pass of the function.
        Applies the function to the input elementwise.
        SnakeBeta ∶= x + 1/b * sin^2 (xa)
        """
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
        beta = self.beta.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
            beta = torch.exp(beta)
        x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)

        return x
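A minimal usage sketch for the two activations above (not part of the diff); the batch, channel, and length values are illustrative and it assumes the file is importable as `activations`:

```python
# Hedged sketch: drive Snake / SnakeBeta on a [B, C, T] tensor.
import torch

from activations import Snake, SnakeBeta

x = torch.randn(2, 80, 100)                       # e.g. 80 mel channels, 100 frames
snake = Snake(in_features=80)                     # one alpha per channel
snake_beta = SnakeBeta(in_features=80, alpha_logscale=True)

y1 = snake(x)       # x + (1/alpha) * sin^2(alpha * x), elementwise
y2 = snake_beta(x)  # x + (1/beta)  * sin^2(alpha * x)
print(y1.shape, y2.shape)  # both remain torch.Size([2, 80, 100])
```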
alias_free_activation/cuda/__init__.py
ADDED
File without changes
alias_free_activation/cuda/activation1d.py
ADDED
@@ -0,0 +1,77 @@
# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.

import torch
import torch.nn as nn
from alias_free_activation.torch.resample import UpSample1d, DownSample1d

# load fused CUDA kernel: this enables importing anti_alias_activation_cuda
from alias_free_activation.cuda import load

anti_alias_activation_cuda = load.load()


class FusedAntiAliasActivation(torch.autograd.Function):
    """
    Assumes filter size 12, replication padding on upsampling/downsampling, and logscale alpha/beta parameters as inputs.
    The hyperparameters are hard-coded in the kernel to maximize speed.
    NOTE: The fused kernel is incorrect for Activation1d with different hyperparameters.
    """

    @staticmethod
    def forward(ctx, inputs, up_ftr, down_ftr, alpha, beta):
        activation_results = anti_alias_activation_cuda.forward(
            inputs, up_ftr, down_ftr, alpha, beta
        )

        return activation_results

    @staticmethod
    def backward(ctx, output_grads):
        raise NotImplementedError
        return output_grads, None, None


class Activation1d(nn.Module):
    def __init__(
        self,
        activation,
        up_ratio: int = 2,
        down_ratio: int = 2,
        up_kernel_size: int = 12,
        down_kernel_size: int = 12,
        fused: bool = True,
    ):
        super().__init__()
        self.up_ratio = up_ratio
        self.down_ratio = down_ratio
        self.act = activation
        self.upsample = UpSample1d(up_ratio, up_kernel_size)
        self.downsample = DownSample1d(down_ratio, down_kernel_size)

        self.fused = fused  # Whether to use fused CUDA kernel or not

    def forward(self, x):
        if not self.fused:
            x = self.upsample(x)
            x = self.act(x)
            x = self.downsample(x)
            return x
        else:
            if self.act.__class__.__name__ == "Snake":
                beta = self.act.alpha.data  # Snake uses same params for alpha and beta
            else:
                beta = (
                    self.act.beta.data
                )  # Snakebeta uses different params for alpha and beta
            alpha = self.act.alpha.data
            if (
                not self.act.alpha_logscale
            ):  # Exp baked into cuda kernel, cancel it out with a log
                alpha = torch.log(alpha)
                beta = torch.log(beta)

            x = FusedAntiAliasActivation.apply(
                x, self.upsample.filter, self.downsample.lowpass.filter, alpha, beta
            )
            return x
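A brief sketch of how the fused wrapper above might be exercised; this assumes a working CUDA toolchain (importing the module JIT-compiles `anti_alias_activation_cuda` via `load.load()`), and the channel count is illustrative:

```python
# Hedged sketch: fused anti-aliased activation (CUDA only; kernel is built at import time).
import torch

from activations import SnakeBeta
from alias_free_activation.cuda.activation1d import Activation1d

channels = 80
act = Activation1d(
    activation=SnakeBeta(channels, alpha_logscale=True),  # fused path expects logscale params
    fused=True,
).cuda()

x = torch.randn(1, channels, 4096, device="cuda")
y = act(x)  # upsample 2x -> SnakeBeta -> downsample 2x inside one fused kernel
print(y.shape)  # torch.Size([1, 80, 4096])
```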
alias_free_activation/cuda/anti_alias_activation.cpp
ADDED
@@ -0,0 +1,23 @@
/* coding=utf-8
 * Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <torch/extension.h>

extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &fwd_cuda, "Anti-Alias Activation forward (CUDA)");
}
alias_free_activation/cuda/anti_alias_activation_cuda.cu
ADDED
@@ -0,0 +1,246 @@
/* coding=utf-8
 * Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "type_shim.h"
#include <assert.h>
#include <cfloat>
#include <limits>
#include <stdint.h>
#include <c10/macros/Macros.h>

namespace
{
    // Hard-coded hyperparameters
    // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
    constexpr int ELEMENTS_PER_LDG_STG = 1; //(WARP_ITERATIONS < 4) ? 1 : 4;
    constexpr int BUFFER_SIZE = 32;
    constexpr int FILTER_SIZE = 12;
    constexpr int HALF_FILTER_SIZE = 6;
    constexpr int UPSAMPLE_REPLICATION_PAD = 5;         // 5 on each side, matching torch impl
    constexpr int DOWNSAMPLE_REPLICATION_PAD_LEFT = 5;  // matching torch impl
    constexpr int DOWNSAMPLE_REPLICATION_PAD_RIGHT = 6; // matching torch impl

    template <typename input_t, typename output_t, typename acc_t>
    __global__ void anti_alias_activation_forward(
        output_t *dst,
        const input_t *src,
        const input_t *up_ftr,
        const input_t *down_ftr,
        const input_t *alpha,
        const input_t *beta,
        int batch_size,
        int channels,
        int seq_len)
    {
        // Up and downsample filters
        input_t up_filter[FILTER_SIZE];
        input_t down_filter[FILTER_SIZE];

        // Load data from global memory including extra indices reserved for replication paddings
        input_t elements[2 * FILTER_SIZE + 2 * BUFFER_SIZE + 2 * UPSAMPLE_REPLICATION_PAD] = {0};
        input_t intermediates[2 * FILTER_SIZE + 2 * BUFFER_SIZE + DOWNSAMPLE_REPLICATION_PAD_LEFT + DOWNSAMPLE_REPLICATION_PAD_RIGHT] = {0};

        // Output stores downsampled output before writing to dst
        output_t output[BUFFER_SIZE];

        // blockDim/threadIdx = (128, 1, 1)
        // gridDim/blockIdx = (seq_blocks, channels, batches)
        int block_offset = (blockIdx.x * 128 * BUFFER_SIZE + seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
        int local_offset = threadIdx.x * BUFFER_SIZE;
        int seq_offset = blockIdx.x * 128 * BUFFER_SIZE + local_offset;

        // intermediate have double the seq_len
        int intermediate_local_offset = threadIdx.x * BUFFER_SIZE * 2;
        int intermediate_seq_offset = blockIdx.x * 128 * BUFFER_SIZE * 2 + intermediate_local_offset;

        // Get values needed for replication padding before moving pointer
        const input_t *right_most_pntr = src + (seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
        input_t seq_left_most_value = right_most_pntr[0];
        input_t seq_right_most_value = right_most_pntr[seq_len - 1];

        // Move src and dst pointers
        src += block_offset + local_offset;
        dst += block_offset + local_offset;

        // Alpha and beta values for snake activations. Applies exp by default
        alpha = alpha + blockIdx.y;
        input_t alpha_val = expf(alpha[0]);
        beta = beta + blockIdx.y;
        input_t beta_val = expf(beta[0]);

        #pragma unroll
        for (int it = 0; it < FILTER_SIZE; it += 1)
        {
            up_filter[it] = up_ftr[it];
            down_filter[it] = down_ftr[it];
        }

        // Apply replication padding for upsampling, matching torch impl
        #pragma unroll
        for (int it = -HALF_FILTER_SIZE; it < BUFFER_SIZE + HALF_FILTER_SIZE; it += 1)
        {
            int element_index = seq_offset + it; // index for element
            if ((element_index < 0) && (element_index >= -UPSAMPLE_REPLICATION_PAD))
            {
                elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_left_most_value;
            }
            if ((element_index >= seq_len) && (element_index < seq_len + UPSAMPLE_REPLICATION_PAD))
            {
                elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_right_most_value;
            }
            if ((element_index >= 0) && (element_index < seq_len))
            {
                elements[2 * (HALF_FILTER_SIZE + it)] = 2 * src[it];
            }
        }

        // Apply upsampling strided convolution and write to intermediates. It reserves DOWNSAMPLE_REPLICATION_PAD_LEFT for replication padding of the downsampling conv later
        #pragma unroll
        for (int it = 0; it < (2 * BUFFER_SIZE + 2 * FILTER_SIZE); it += 1)
        {
            input_t acc = 0.0;
            int element_index = intermediate_seq_offset + it; // index for intermediate
            #pragma unroll
            for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)
            {
                if ((element_index + f_idx) >= 0)
                {
                    acc += up_filter[f_idx] * elements[it + f_idx];
                }
            }
            intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] = acc;
        }

        // Apply activation function. It reserves DOWNSAMPLE_REPLICATION_PAD_LEFT and DOWNSAMPLE_REPLICATION_PAD_RIGHT for replication padding of the downsampling conv later
        double no_div_by_zero = 0.000000001;
        #pragma unroll
        for (int it = 0; it < 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it += 1)
        {
            intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] += (1.0 / (beta_val + no_div_by_zero)) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val);
        }

        // Apply replication padding before downsampling conv from intermediates
        #pragma unroll
        for (int it = 0; it < DOWNSAMPLE_REPLICATION_PAD_LEFT; it += 1)
        {
            intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT];
        }
        #pragma unroll
        for (int it = DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it < DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE + DOWNSAMPLE_REPLICATION_PAD_RIGHT; it += 1)
        {
            intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE - 1];
        }

        // Apply downsample strided convolution (assuming stride=2) from intermediates
        #pragma unroll
        for (int it = 0; it < BUFFER_SIZE; it += 1)
        {
            input_t acc = 0.0;
            #pragma unroll
            for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)
            {
                // Add constant DOWNSAMPLE_REPLICATION_PAD_RIGHT to match torch implementation
                acc += down_filter[f_idx] * intermediates[it * 2 + f_idx + DOWNSAMPLE_REPLICATION_PAD_RIGHT];
            }
            output[it] = acc;
        }

        // Write output to dst
        #pragma unroll
        for (int it = 0; it < BUFFER_SIZE; it += ELEMENTS_PER_LDG_STG)
        {
            int element_index = seq_offset + it;
            if (element_index < seq_len)
            {
                dst[it] = output[it];
            }
        }
    }

    template <typename input_t, typename output_t, typename acc_t>
    void dispatch_anti_alias_activation_forward(
        output_t *dst,
        const input_t *src,
        const input_t *up_ftr,
        const input_t *down_ftr,
        const input_t *alpha,
        const input_t *beta,
        int batch_size,
        int channels,
        int seq_len)
    {
        if (seq_len == 0)
        {
            return;
        }
        else
        {
            // Use 128 threads per block to maximize gpu utilization
            constexpr int threads_per_block = 128;
            constexpr int seq_len_per_block = 4096;
            int blocks_per_seq_len = (seq_len + seq_len_per_block - 1) / seq_len_per_block;
            dim3 blocks(blocks_per_seq_len, channels, batch_size);
            dim3 threads(threads_per_block, 1, 1);

            anti_alias_activation_forward<input_t, output_t, acc_t>
                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, up_ftr, down_ftr, alpha, beta, batch_size, channels, seq_len);
        }
    }
}

extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta)
{
    // Input is a 3d tensor with dimensions [batches, channels, seq_len]
    const int batches = input.size(0);
    const int channels = input.size(1);
    const int seq_len = input.size(2);

    // Output
    auto act_options = input.options().requires_grad(false);

    torch::Tensor anti_alias_activation_results =
        torch::empty({batches, channels, seq_len}, act_options);

    void *input_ptr = static_cast<void *>(input.data_ptr());
    void *up_filter_ptr = static_cast<void *>(up_filter.data_ptr());
    void *down_filter_ptr = static_cast<void *>(down_filter.data_ptr());
    void *alpha_ptr = static_cast<void *>(alpha.data_ptr());
    void *beta_ptr = static_cast<void *>(beta.data_ptr());
    void *anti_alias_activation_results_ptr = static_cast<void *>(anti_alias_activation_results.data_ptr());

    DISPATCH_FLOAT_HALF_AND_BFLOAT(
        input.scalar_type(),
        "dispatch anti alias activation_forward",
        dispatch_anti_alias_activation_forward<scalar_t, scalar_t, float>(
            reinterpret_cast<scalar_t *>(anti_alias_activation_results_ptr),
            reinterpret_cast<const scalar_t *>(input_ptr),
            reinterpret_cast<const scalar_t *>(up_filter_ptr),
            reinterpret_cast<const scalar_t *>(down_filter_ptr),
            reinterpret_cast<const scalar_t *>(alpha_ptr),
            reinterpret_cast<const scalar_t *>(beta_ptr),
            batches,
            channels,
            seq_len););
    return anti_alias_activation_results;
}
alias_free_activation/cuda/compat.h
ADDED
@@ -0,0 +1,29 @@
/* coding=utf-8
 * Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* This code is copied from NVIDIA apex:
 *     https://github.com/NVIDIA/apex
 *     with minor changes. */

#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif

#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
alias_free_activation/cuda/load.py
ADDED
@@ -0,0 +1,86 @@
# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.

import os
import pathlib
import subprocess

from torch.utils import cpp_extension

"""
Setting this param to a list has a problem of generating different compilation commands (with different order of architectures) and leading to recompilation of fused kernels.
Set it to an empty string to avoid recompilation and assign arch flags explicitly in extra_cuda_cflags below
"""
os.environ["TORCH_CUDA_ARCH_LIST"] = ""


def load():
    # Check if cuda 11 is installed for compute capability 8.0
    cc_flag = []
    _, bare_metal_major, _ = _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
    if int(bare_metal_major) >= 11:
        cc_flag.append("-gencode")
        cc_flag.append("arch=compute_80,code=sm_80")

    # Build path
    srcpath = pathlib.Path(__file__).parent.absolute()
    buildpath = srcpath / "build"
    _create_build_dir(buildpath)

    # Helper function to build the kernels.
    def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
        return cpp_extension.load(
            name=name,
            sources=sources,
            build_directory=buildpath,
            extra_cflags=[
                "-O3",
            ],
            extra_cuda_cflags=[
                "-O3",
                "-gencode",
                "arch=compute_70,code=sm_70",
                "--use_fast_math",
            ]
            + extra_cuda_flags
            + cc_flag,
            verbose=True,
        )

    extra_cuda_flags = [
        "-U__CUDA_NO_HALF_OPERATORS__",
        "-U__CUDA_NO_HALF_CONVERSIONS__",
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
    ]

    sources = [
        srcpath / "anti_alias_activation.cpp",
        srcpath / "anti_alias_activation_cuda.cu",
    ]
    anti_alias_activation_cuda = _cpp_extention_load_helper(
        "anti_alias_activation_cuda", sources, extra_cuda_flags
    )

    return anti_alias_activation_cuda


def _get_cuda_bare_metal_version(cuda_dir):
    raw_output = subprocess.check_output(
        [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
    )
    output = raw_output.split()
    release_idx = output.index("release") + 1
    release = output[release_idx].split(".")
    bare_metal_major = release[0]
    bare_metal_minor = release[1][0]

    return raw_output, bare_metal_major, bare_metal_minor


def _create_build_dir(buildpath):
    try:
        os.mkdir(buildpath)
    except OSError:
        if not os.path.isdir(buildpath):
            print(f"Creation of the build directory {buildpath} failed")
alias_free_activation/cuda/type_shim.h
ADDED
@@ -0,0 +1,92 @@
/* coding=utf-8
 * Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <ATen/ATen.h>
#include "compat.h"

#define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, ...)                 \
    switch (TYPE)                                                       \
    {                                                                   \
    case at::ScalarType::Float:                                         \
    {                                                                   \
        using scalar_t = float;                                         \
        __VA_ARGS__;                                                    \
        break;                                                          \
    }                                                                   \
    case at::ScalarType::Half:                                          \
    {                                                                   \
        using scalar_t = at::Half;                                      \
        __VA_ARGS__;                                                    \
        break;                                                          \
    }                                                                   \
    case at::ScalarType::BFloat16:                                      \
    {                                                                   \
        using scalar_t = at::BFloat16;                                  \
        __VA_ARGS__;                                                    \
        break;                                                          \
    }                                                                   \
    default:                                                            \
        AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
    }

#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...)  \
    switch (TYPEIN)                                                             \
    {                                                                           \
    case at::ScalarType::Float:                                                 \
    {                                                                           \
        using scalar_t_in = float;                                              \
        switch (TYPEOUT)                                                        \
        {                                                                       \
        case at::ScalarType::Float:                                             \
        {                                                                       \
            using scalar_t_out = float;                                         \
            __VA_ARGS__;                                                        \
            break;                                                              \
        }                                                                       \
        case at::ScalarType::Half:                                              \
        {                                                                       \
            using scalar_t_out = at::Half;                                      \
            __VA_ARGS__;                                                        \
            break;                                                              \
        }                                                                       \
        case at::ScalarType::BFloat16:                                          \
        {                                                                       \
            using scalar_t_out = at::BFloat16;                                  \
            __VA_ARGS__;                                                        \
            break;                                                              \
        }                                                                       \
        default:                                                                \
            AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'");  \
        }                                                                       \
        break;                                                                  \
    }                                                                           \
    case at::ScalarType::Half:                                                  \
    {                                                                           \
        using scalar_t_in = at::Half;                                           \
        using scalar_t_out = at::Half;                                          \
        __VA_ARGS__;                                                            \
        break;                                                                  \
    }                                                                           \
    case at::ScalarType::BFloat16:                                              \
    {                                                                           \
        using scalar_t_in = at::BFloat16;                                       \
        using scalar_t_out = at::BFloat16;                                      \
        __VA_ARGS__;                                                            \
        break;                                                                  \
    }                                                                           \
    default:                                                                    \
        AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'");       \
    }
alias_free_activation/torch/__init__.py
ADDED
@@ -0,0 +1,6 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.

from .filter import *
from .resample import *
from .act import *
alias_free_activation/torch/act.py
ADDED
@@ -0,0 +1,30 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.

import torch.nn as nn
from alias_free_activation.torch.resample import UpSample1d, DownSample1d


class Activation1d(nn.Module):
    def __init__(
        self,
        activation,
        up_ratio: int = 2,
        down_ratio: int = 2,
        up_kernel_size: int = 12,
        down_kernel_size: int = 12,
    ):
        super().__init__()
        self.up_ratio = up_ratio
        self.down_ratio = down_ratio
        self.act = activation
        self.upsample = UpSample1d(up_ratio, up_kernel_size)
        self.downsample = DownSample1d(down_ratio, down_kernel_size)

    # x: [B,C,T]
    def forward(self, x):
        x = self.upsample(x)
        x = self.act(x)
        x = self.downsample(x)

        return x
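For comparison with the fused CUDA path, a minimal sketch of the pure-PyTorch wrapper above, which needs no compiled extension; the sizes are illustrative:

```python
# Hedged sketch: CPU-friendly alias-free activation using the torch fallback.
import torch

from activations import Snake
from alias_free_activation.torch.act import Activation1d

act = Activation1d(activation=Snake(32))
x = torch.randn(4, 32, 256)  # [B, C, T]
y = act(x)                   # upsample -> Snake -> downsample
print(y.shape)               # torch.Size([4, 32, 256]); the time axis length is preserved
```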
alias_free_activation/torch/filter.py
ADDED
@@ -0,0 +1,101 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

if "sinc" in dir(torch):
    sinc = torch.sinc
else:
    # This code is adopted from adefossez's julius.core.sinc under the MIT License
    # https://adefossez.github.io/julius/julius/core.html
    # LICENSE is in incl_licenses directory.
    def sinc(x: torch.Tensor):
        """
        Implementation of sinc, i.e. sin(pi * x) / (pi * x)
        __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
        """
        return torch.where(
            x == 0,
            torch.tensor(1.0, device=x.device, dtype=x.dtype),
            torch.sin(math.pi * x) / math.pi / x,
        )


# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
# https://adefossez.github.io/julius/julius/lowpass.html
# LICENSE is in incl_licenses directory.
def kaiser_sinc_filter1d(
    cutoff, half_width, kernel_size
):  # return filter [1,1,kernel_size]
    even = kernel_size % 2 == 0
    half_size = kernel_size // 2

    # For kaiser window
    delta_f = 4 * half_width
    A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
    if A > 50.0:
        beta = 0.1102 * (A - 8.7)
    elif A >= 21.0:
        beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
    else:
        beta = 0.0
    window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)

    # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
    if even:
        time = torch.arange(-half_size, half_size) + 0.5
    else:
        time = torch.arange(kernel_size) - half_size
    if cutoff == 0:
        filter_ = torch.zeros_like(time)
    else:
        filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
        """
        Normalize filter to have sum = 1, otherwise we will have a small leakage of the constant component in the input signal.
        """
        filter_ /= filter_.sum()
        filter = filter_.view(1, 1, kernel_size)

    return filter


class LowPassFilter1d(nn.Module):
    def __init__(
        self,
        cutoff=0.5,
        half_width=0.6,
        stride: int = 1,
        padding: bool = True,
        padding_mode: str = "replicate",
        kernel_size: int = 12,
    ):
        """
        kernel_size should be even number for stylegan3 setup, in this implementation, odd number is also possible.
        """
        super().__init__()
        if cutoff < -0.0:
            raise ValueError("Minimum cutoff must be larger than zero.")
        if cutoff > 0.5:
            raise ValueError("A cutoff above 0.5 does not make sense.")
        self.kernel_size = kernel_size
        self.even = kernel_size % 2 == 0
        self.pad_left = kernel_size // 2 - int(self.even)
        self.pad_right = kernel_size // 2
        self.stride = stride
        self.padding = padding
        self.padding_mode = padding_mode
        filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
        self.register_buffer("filter", filter)

    # Input [B, C, T]
    def forward(self, x):
        _, C, _ = x.shape

        if self.padding:
            x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
        out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)

        return out
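A short sketch of the low-pass building blocks above, with illustrative cutoff and width values:

```python
# Hedged sketch: build a Kaiser-windowed sinc filter and run the low-pass module.
import torch

from alias_free_activation.torch.filter import LowPassFilter1d, kaiser_sinc_filter1d

filt = kaiser_sinc_filter1d(cutoff=0.25, half_width=0.3, kernel_size=12)
print(filt.shape, float(filt.sum()))  # torch.Size([1, 1, 12]), sum ~= 1.0 after normalization

lowpass = LowPassFilter1d(cutoff=0.25, half_width=0.3, stride=2, kernel_size=12)
x = torch.randn(1, 8, 128)
y = lowpass(x)  # replicate padding + depthwise conv1d with stride 2
print(y.shape)  # torch.Size([1, 8, 64])
```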
alias_free_activation/torch/resample.py
ADDED
@@ -0,0 +1,58 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.

import torch.nn as nn
from torch.nn import functional as F
from alias_free_activation.torch.filter import LowPassFilter1d
from alias_free_activation.torch.filter import kaiser_sinc_filter1d


class UpSample1d(nn.Module):
    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        self.ratio = ratio
        self.kernel_size = (
            int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
        )
        self.stride = ratio
        self.pad = self.kernel_size // ratio - 1
        self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
        self.pad_right = (
            self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
        )
        filter = kaiser_sinc_filter1d(
            cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
        )
        self.register_buffer("filter", filter)

    # x: [B, C, T]
    def forward(self, x):
        _, C, _ = x.shape

        x = F.pad(x, (self.pad, self.pad), mode="replicate")
        x = self.ratio * F.conv_transpose1d(
            x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
        )
        x = x[..., self.pad_left : -self.pad_right]

        return x


class DownSample1d(nn.Module):
    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        self.ratio = ratio
        self.kernel_size = (
            int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
        )
        self.lowpass = LowPassFilter1d(
            cutoff=0.5 / ratio,
            half_width=0.6 / ratio,
            stride=ratio,
            kernel_size=self.kernel_size,
        )

    def forward(self, x):
        xx = self.lowpass(x)

        return xx
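A minimal round-trip sketch for the resamplers above, using the default 2x ratio; the tensor sizes are illustrative:

```python
# Hedged sketch: alias-free 2x upsampling followed by 2x downsampling.
import torch

from alias_free_activation.torch.resample import DownSample1d, UpSample1d

up = UpSample1d(ratio=2)    # kernel_size defaults to 6 * ratio = 12
down = DownSample1d(ratio=2)

x = torch.randn(1, 4, 100)  # [B, C, T]
x_up = up(x)                # torch.Size([1, 4, 200])
x_back = down(x_up)         # torch.Size([1, 4, 100])
print(x_up.shape, x_back.shape)
```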
alignment.py
ADDED
@@ -0,0 +1,54 @@
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: MIT
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

import numpy as np
from numba import jit


@jit(nopython=True)
def mas_width1(attn_map):
    """mas with hardcoded width=1"""
    # assumes mel x text
    opt = np.zeros_like(attn_map)
    attn_map = np.log(attn_map)
    attn_map[0, 1:] = -np.inf
    log_p = np.zeros_like(attn_map)
    log_p[0, :] = attn_map[0, :]
    prev_ind = np.zeros_like(attn_map, dtype=np.int64)
    for i in range(1, attn_map.shape[0]):
        for j in range(attn_map.shape[1]):  # for each text dim
            prev_log = log_p[i - 1, j]
            prev_j = j

            if j - 1 >= 0 and log_p[i - 1, j - 1] >= log_p[i - 1, j]:
                prev_log = log_p[i - 1, j - 1]
                prev_j = j - 1

            log_p[i, j] = attn_map[i, j] + prev_log
            prev_ind[i, j] = prev_j

    # now backtrack
    curr_text_idx = attn_map.shape[1] - 1
    for i in range(attn_map.shape[0] - 1, -1, -1):
        opt[i, curr_text_idx] = 1
        curr_text_idx = prev_ind[i, curr_text_idx]
    opt[0, curr_text_idx] = 1
    return opt
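A tiny worked sketch of the monotonic alignment search above, on a made-up 4x3 soft attention map (mel frames x text tokens):

```python
# Hedged sketch: hard monotonic alignment from a soft attention map.
import numpy as np

from alignment import mas_width1

soft_attn = np.array(
    [
        [0.7, 0.2, 0.1],
        [0.3, 0.6, 0.1],
        [0.1, 0.6, 0.3],
        [0.1, 0.2, 0.7],
    ]
)
hard_attn = mas_width1(soft_attn)
print(hard_attn)
# Expected: each mel frame picks exactly one token and the path never moves backwards:
# [[1. 0. 0.]
#  [0. 1. 0.]
#  [0. 1. 0.]
#  [0. 0. 1.]]
```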
app.py
ADDED
@@ -0,0 +1,360 @@
import os
import sys
import json
import time

from importlib.metadata import version
from enum import Enum

from huggingface_hub import hf_hub_download

use_zerogpu = False

try:
    import spaces  # it's for ZeroGPU

    use_zerogpu = True
    print("ZeroGPU is available, changing inference call.")
except ImportError:
    print("ZeroGPU is not available, skipping...")

import gradio as gr

import torch
import torchaudio

# BigVGAN
import bigvgan

# RAD-TTS code
from radtts import RADTTS
from data import Data
from common import update_params

use_cuda = torch.cuda.is_available()

if use_cuda:
    print("CUDA is available, setting correct inference_device variable.")
    device = "cuda"
else:
    device = "cpu"


def download_file_from_repo(
    repo_id: str,
    filename: str,
    local_dir: str = ".",
    repo_type: str = "model",
) -> str:
    try:
        os.makedirs(local_dir, exist_ok=True)

        file_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            local_dir=local_dir,
            cache_dir=None,
            force_download=False,
            repo_type=repo_type,
        )

        return file_path
    except Exception as e:
        raise Exception(f"An error occurred during download: {e}") from e


download_file_from_repo(
    "Yehor/radtts-uk",
    "radtts-pp-dap-model/model_dap_84000.pt",
    "./models/",
)

# Init the model
seed = 1234

config = "configs/radtts-pp-dap-model.json"
radtts_path = "models/radtts-pp-dap-model/model_dap_84000.pt"

params = []

# Load the config
with open(config) as f:
    data = f.read()

config = json.loads(data)
update_params(config, params)

data_config = config["data_config"]
model_config = config["model_config"]

# Seed
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Load vocoder
vocoder_model = bigvgan.BigVGAN.from_pretrained(
    "nvidia/bigvgan_v2_22khz_80band_fmax8k_256x", use_cuda_kernel=False
)
vocoder_model.remove_weight_norm()
vocoder_model = vocoder_model.eval().to(device)

# Load RAD-TTS
if use_cuda:
    radtts = RADTTS(**model_config).cuda()
else:
    radtts = RADTTS(**model_config)

radtts.enable_inverse_cache()  # cache inverse matrix for 1x1 invertible convs

checkpoint_dict = torch.load(radtts_path, map_location="cpu")  # todo: CPU?
radtts.load_state_dict(checkpoint_dict["state_dict"], strict=False)
radtts.eval()

print(f"Loaded checkpoint '{radtts_path}')")

ignore_keys = ["training_files", "validation_files"]
trainset = Data(
    data_config["training_files"],
    **dict((k, v) for k, v in data_config.items() if k not in ignore_keys),
)

# Config
concurrency_limit = 5

title = "RAD-TTS++ Ukrainian"

# https://www.tablesgenerator.com/markdown_tables
authors_table = """
## Authors

Follow them on social networks and **contact** if you need any help or have any questions:

| <img src="https://avatars.githubusercontent.com/u/7875085?v=4" width="100"> **Yehor Smoliakov** |
|-------------------------------------------------------------------------------------------------|
| https://t.me/smlkw in Telegram |
| https://x.com/yehor_smoliakov at X |
| https://github.com/egorsmkv at GitHub |
| https://huggingface.co/Yehor at Hugging Face |
| or use [email protected] |
""".strip()

description_head = f"""
# {title}

## Overview

Type your text in Ukrainian and select a voice to synthesize speech using [the RAD-TTS++ model](https://huggingface.co/Yehor/radtts-uk) and [BigVGAN v2](https://huggingface.co/nvidia/bigvgan_v2_22khz_80band_fmax8k_256x) with 22050 Hz.
""".strip()

description_foot = f"""
{authors_table}
""".strip()

tech_env = f"""
#### Environment

- Python: {sys.version}
""".strip()

tech_libraries = f"""
#### Libraries

- gradio: {version("gradio")}
- torch: {version("torch")}
- scipy: {version("scipy")}
- numba: {version("numba")}
- librosa: {version("librosa")}
- unidecode: {version("unidecode")}
- inflect: {version("inflect")}
""".strip()


class VoiceOption(Enum):
    Tetiana = "Tetiana (female) 👩"
    Mykyta = "Mykyta (male) 👨"
    Lada = "Lada (female) 👩"


voice_mapping = {
    VoiceOption.Tetiana.value: "tetiana",
    VoiceOption.Mykyta.value: "mykyta",
    VoiceOption.Lada.value: "lada",
}


examples = [
    [
        "Прокинувся ґазда вранці. Пішов, вичистив з-під коня, вичистив з-під бика, вичистив з-під овечок, вибрав молодняк, відніс його набік.",
        VoiceOption.Mykyta.value,
    ],
    [
        "Пішов взяв сіна, дав корові. Пішов взяв сіна, дав бикові. Ячміню коняці насипав. Зайшов почистив корову, зайшов почистив бика, зайшов почистив коня, за яйця його мацнув.",
        VoiceOption.Lada.value,
    ],
    [
        "Кінь ногою здригнув, на хазяїна ласкавим оком подивився. Тоді дядько пішов відкрив курей, гусей, качок, повиносив їм зерна, огірків нарізаних, нагодував. Коли чує – з хати дружина кличе. Зайшов. Дітки повмивані, сидять за столом, всі чекають тата. Взяв він ложку, перехрестив дітей, перехрестив лоба, почали снідати. Поснідали, він дістав пряників, роздав дітям. Діти зібралися, пішли в школу. Дядько вийшов, сів на призьбі, взяв сапку, почав мантачити. Мантачив-мантачив, коли – жінка виходить. Він їй ту сапку дає, ласкаво за сраку вщипнув, жінка до нього лагідно всміхнулася, пішла на город – сапати. Коли – йде пастух і товар кличе в череду. Повідмикав дядько овечок, коровку, бика, коня, все відпустив. Сів попри хати, дістав табАку, відірвав шмат газети, насипав, наслинив собі гарну таку цигарку. Благодать божа – і сонечко вже здійнялося над деревами. Дядько встромив цигарку в рота, дістав сірники, тільки чиркати – коли раптом з хати: Доброе утро! Московское время – шесть часов утра! Витяг дядько цигарку с рота, сплюнув набік, і сам собі каже: Ана маєш. Прокинулись, бляді!",
        VoiceOption.Tetiana.value,
    ],
]


def inference(text, voice):
    if not text:
        raise gr.Error("Please paste your text.")

    gr.Info("Starting...", duration=0.5)

    speaker = voice_mapping[voice]
    speaker = speaker_text = speaker_attributes = speaker

    n_takes = 1

    sigma = 0.8  # sampling sigma for decoder
    sigma_tkndur = 0.666  # sampling sigma for duration
    sigma_f0 = 1.0  # sampling sigma for f0
    sigma_energy = 1.0  # sampling sigma for energy avg

    token_dur_scaling = 1.0

    f0_mean = 0
    f0_std = 0
    energy_mean = 0
    energy_std = 0

    if use_cuda:
        speaker_id = trainset.get_speaker_id(speaker).cuda()
        speaker_id_text, speaker_id_attributes = speaker_id, speaker_id

        if speaker_text is not None:
            speaker_id_text = trainset.get_speaker_id(speaker_text).cuda()

        if speaker_attributes is not None:
            speaker_id_attributes = trainset.get_speaker_id(speaker_attributes).cuda()

        tensor_text = trainset.get_text(text).cuda()[None]
    else:
        speaker_id = trainset.get_speaker_id(speaker)
        speaker_id_text, speaker_id_attributes = speaker_id, speaker_id

        if speaker_text is not None:
            speaker_id_text = trainset.get_speaker_id(speaker_text)

        if speaker_attributes is not None:
            speaker_id_attributes = trainset.get_speaker_id(speaker_attributes)

        tensor_text = trainset.get_text(text)[None]

    inference_start = time.time()

    for take in range(n_takes):
        with torch.autocast(device, enabled=False):
            with torch.inference_mode():
                outputs = radtts.infer(
                    speaker_id,
                    tensor_text,
                    sigma,
                    sigma_tkndur,
                    sigma_f0,
                    sigma_energy,
                    token_dur_scaling,
                    token_duration_max=100,
                    speaker_id_text=speaker_id_text,
                    speaker_id_attributes=speaker_id_attributes,
                    f0_mean=f0_mean,
                    f0_std=f0_std,
                    energy_mean=energy_mean,
                    energy_std=energy_std,
                    use_cuda=use_cuda,
                )

                mel = outputs["mel"]

                gr.Info(
                    "Synthesized MEL spectrogram, converting to WAVE.", duration=0.5
                )

                wav_gen = vocoder_model(mel)
                wav_gen_float = wav_gen.squeeze(0).cpu()

                torchaudio.save("audio.wav", wav_gen_float, 22_050, encoding="PCM_S")

    duration = len(wav_gen_float[0]) / 22_050

    elapsed_time = time.time() - inference_start
    rtf = elapsed_time / duration

    speed_ratio = duration / elapsed_time
    speech_rate = len(text.split(" ")) / duration

    rtf_value = f"Real-Time Factor: {round(rtf, 4)}, time: {round(elapsed_time, 4)} seconds, audio duration: {round(duration, 4)} seconds. Speed ratio: {round(speed_ratio, 2)}x. Speech rate: {round(speech_rate, 4)} words-per-second."

    gr.Success("Finished!", duration=0.5)

    return [gr.Audio("audio.wav"), rtf_value]


try:

    @spaces.GPU
    def inference_zerogpu(text, voice):
        return inference(text, voice)

except NameError:
    print("ZeroGPU is not available, skipping...")


def inference_cpu(text, voice):
    return inference(text, voice)


demo = gr.Blocks(
    title=title,
    analytics_enabled=False,
    theme=gr.themes.Base(),
)

with demo:
    gr.Markdown(description_head)

    gr.Markdown("## Usage")

    with gr.Row():
        with gr.Column():
            audio = gr.Audio(label="Synthesized audio")
            rtf = gr.Markdown(
                label="Real-Time Factor",
                value="Here you will see how fast the model and the speaker is.",
            )

    with gr.Row():
        with gr.Column():
            text = gr.Text(
                label="Text",
                value="Сл+ава Укра+їні! — українське вітання, національне гасло.",
            )
            voice = gr.Radio(
                label="Voice",
                choices=[option.value for option in VoiceOption],
                value=VoiceOption.Tetiana.value,
            )

    gr.Button("Run").click(
        inference_zerogpu if use_zerogpu else inference_cpu,
        concurrency_limit=concurrency_limit,
        inputs=[text, voice],
        outputs=[audio, rtf],
    )

    with gr.Row():
        gr.Examples(
            label="Choose an example",
            inputs=[text, voice],
            examples=examples,
        )

    gr.Markdown(description_foot)
|
353 |
+
|
354 |
+
gr.Markdown("### Gradio app uses:")
|
355 |
+
gr.Markdown(tech_env)
|
356 |
+
gr.Markdown(tech_libraries)
|
357 |
+
|
358 |
+
if __name__ == "__main__":
|
359 |
+
demo.queue()
|
360 |
+
demo.launch()
|
attribute_prediction_model.py
ADDED
@@ -0,0 +1,402 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: MIT
|
3 |
+
#
|
4 |
+
# Permission is hereby granted, free of charge, to any person obtaining a
|
5 |
+
# copy of this software and associated documentation files (the "Software"),
|
6 |
+
# to deal in the Software without restriction, including without limitation
|
7 |
+
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
8 |
+
# and/or sell copies of the Software, and to permit persons to whom the
|
9 |
+
# Software is furnished to do so, subject to the following conditions:
|
10 |
+
#
|
11 |
+
# The above copyright notice and this permission notice shall be included in
|
12 |
+
# all copies or substantial portions of the Software.
|
13 |
+
#
|
14 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
15 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
16 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
17 |
+
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
18 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
19 |
+
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
20 |
+
# DEALINGS IN THE SOFTWARE.
|
21 |
+
import torch
|
22 |
+
from torch import nn
|
23 |
+
from common import ConvNorm, Invertible1x1Conv
|
24 |
+
from common import AffineTransformationLayer, SplineTransformationLayer
|
25 |
+
from common import ConvLSTMLinear
|
26 |
+
from transformer import FFTransformer
|
27 |
+
from autoregressive_flow import AR_Step, AR_Back_Step
|
28 |
+
|
29 |
+
|
30 |
+
def get_attribute_prediction_model(config):
|
31 |
+
name = config["name"]
|
32 |
+
hparams = config["hparams"]
|
33 |
+
if name == "dap":
|
34 |
+
model = DAP(**hparams)
|
35 |
+
elif name == "bgap":
|
36 |
+
model = BGAP(**hparams)
|
37 |
+
elif name == "agap":
|
38 |
+
model = AGAP(**hparams)
|
39 |
+
else:
|
40 |
+
raise Exception("{} model is not supported".format(name))
|
41 |
+
|
42 |
+
return model
|
43 |
+
|
44 |
+
|
45 |
+
class AttributeProcessing:
|
46 |
+
def __init__(self, take_log_of_input=False):
|
47 |
+
super(AttributeProcessing).__init__()
|
48 |
+
self.take_log_of_input = take_log_of_input
|
49 |
+
|
50 |
+
def normalize(self, x):
|
51 |
+
if self.take_log_of_input:
|
52 |
+
x = torch.log(x + 1)
|
53 |
+
return x
|
54 |
+
|
55 |
+
def denormalize(self, x):
|
56 |
+
if self.take_log_of_input:
|
57 |
+
x = torch.exp(x) - 1
|
58 |
+
return x
|
59 |
+
|
60 |
+
|
61 |
+
class BottleneckLayerLayer(nn.Module):
|
62 |
+
def __init__(
|
63 |
+
self,
|
64 |
+
in_dim,
|
65 |
+
reduction_factor,
|
66 |
+
norm="weightnorm",
|
67 |
+
non_linearity="relu",
|
68 |
+
kernel_size=3,
|
69 |
+
use_partial_padding=False,
|
70 |
+
):
|
71 |
+
super(BottleneckLayerLayer, self).__init__()
|
72 |
+
|
73 |
+
self.reduction_factor = reduction_factor
|
74 |
+
reduced_dim = int(in_dim / reduction_factor)
|
75 |
+
self.out_dim = reduced_dim
|
76 |
+
if self.reduction_factor > 1:
|
77 |
+
fn = ConvNorm(
|
78 |
+
in_dim,
|
79 |
+
reduced_dim,
|
80 |
+
kernel_size=kernel_size,
|
81 |
+
use_weight_norm=(norm == "weightnorm"),
|
82 |
+
)
|
83 |
+
if norm == "instancenorm":
|
84 |
+
fn = nn.Sequential(fn, nn.InstanceNorm1d(reduced_dim, affine=True))
|
85 |
+
|
86 |
+
self.projection_fn = fn
|
87 |
+
self.non_linearity = nn.ReLU()
|
88 |
+
if non_linearity == "leakyrelu":
|
89 |
+
self.non_linearity = nn.LeakyReLU()
|
90 |
+
|
91 |
+
def forward(self, x):
|
92 |
+
if self.reduction_factor > 1:
|
93 |
+
x = self.projection_fn(x)
|
94 |
+
x = self.non_linearity(x)
|
95 |
+
return x
|
96 |
+
|
97 |
+
|
98 |
+
class DAP(nn.Module):
|
99 |
+
def __init__(
|
100 |
+
self,
|
101 |
+
n_speaker_dim,
|
102 |
+
bottleneck_hparams,
|
103 |
+
take_log_of_input,
|
104 |
+
arch_hparams,
|
105 |
+
use_transformer=False,
|
106 |
+
):
|
107 |
+
super(DAP, self).__init__()
|
108 |
+
self.attribute_processing = AttributeProcessing(take_log_of_input)
|
109 |
+
self.bottleneck_layer = BottleneckLayerLayer(**bottleneck_hparams)
|
110 |
+
|
111 |
+
arch_hparams["in_dim"] = self.bottleneck_layer.out_dim + n_speaker_dim
|
112 |
+
if use_transformer:
|
113 |
+
self.feat_pred_fn = FFTransformer(**arch_hparams)
|
114 |
+
else:
|
115 |
+
self.feat_pred_fn = ConvLSTMLinear(**arch_hparams)
|
116 |
+
|
117 |
+
def forward(self, txt_enc, spk_emb, x, lens):
|
118 |
+
if x is not None:
|
119 |
+
x = self.attribute_processing.normalize(x)
|
120 |
+
|
121 |
+
txt_enc = self.bottleneck_layer(txt_enc)
|
122 |
+
spk_emb_expanded = spk_emb[..., None].expand(-1, -1, txt_enc.shape[2])
|
123 |
+
context = torch.cat((txt_enc, spk_emb_expanded), 1)
|
124 |
+
|
125 |
+
x_hat = self.feat_pred_fn(context, lens)
|
126 |
+
|
127 |
+
outputs = {"x_hat": x_hat, "x": x}
|
128 |
+
return outputs
|
129 |
+
|
130 |
+
def infer(self, z, txt_enc, spk_emb, lens=None):
|
131 |
+
x_hat = self.forward(txt_enc, spk_emb, x=None, lens=lens)["x_hat"]
|
132 |
+
x_hat = self.attribute_processing.denormalize(x_hat)
|
133 |
+
return x_hat
|
134 |
+
|
135 |
+
|
136 |
+
class BGAP(torch.nn.Module):
|
137 |
+
def __init__(
|
138 |
+
self,
|
139 |
+
n_in_dim,
|
140 |
+
n_speaker_dim,
|
141 |
+
bottleneck_hparams,
|
142 |
+
n_flows,
|
143 |
+
n_group_size,
|
144 |
+
n_layers,
|
145 |
+
with_dilation,
|
146 |
+
kernel_size,
|
147 |
+
scaling_fn,
|
148 |
+
take_log_of_input=False,
|
149 |
+
n_channels=1024,
|
150 |
+
use_quadratic=False,
|
151 |
+
n_bins=8,
|
152 |
+
n_spline_steps=2,
|
153 |
+
):
|
154 |
+
super(BGAP, self).__init__()
|
155 |
+
# assert(n_group_size % 2 == 0)
|
156 |
+
self.n_flows = n_flows
|
157 |
+
self.n_group_size = n_group_size
|
158 |
+
self.transforms = torch.nn.ModuleList()
|
159 |
+
self.convinv = torch.nn.ModuleList()
|
160 |
+
self.n_speaker_dim = n_speaker_dim
|
161 |
+
self.scaling_fn = scaling_fn
|
162 |
+
self.attribute_processing = AttributeProcessing(take_log_of_input)
|
163 |
+
self.n_spline_steps = n_spline_steps
|
164 |
+
self.bottleneck_layer = BottleneckLayerLayer(**bottleneck_hparams)
|
165 |
+
n_txt_reduced_dim = self.bottleneck_layer.out_dim
|
166 |
+
context_dim = n_txt_reduced_dim * n_group_size + n_speaker_dim
|
167 |
+
|
168 |
+
if self.n_group_size > 1:
|
169 |
+
self.unfold_params = {
|
170 |
+
"kernel_size": (n_group_size, 1),
|
171 |
+
"stride": n_group_size,
|
172 |
+
"padding": 0,
|
173 |
+
"dilation": 1,
|
174 |
+
}
|
175 |
+
self.unfold = nn.Unfold(**self.unfold_params)
|
176 |
+
|
177 |
+
for k in range(n_flows):
|
178 |
+
self.convinv.append(Invertible1x1Conv(n_in_dim * n_group_size))
|
179 |
+
if k >= n_flows - self.n_spline_steps:
|
180 |
+
left = -3
|
181 |
+
right = 3
|
182 |
+
top = 3
|
183 |
+
bottom = -3
|
184 |
+
self.transforms.append(
|
185 |
+
SplineTransformationLayer(
|
186 |
+
n_in_dim * n_group_size,
|
187 |
+
context_dim,
|
188 |
+
n_layers,
|
189 |
+
with_dilation=with_dilation,
|
190 |
+
kernel_size=kernel_size,
|
191 |
+
scaling_fn=scaling_fn,
|
192 |
+
n_channels=n_channels,
|
193 |
+
top=top,
|
194 |
+
bottom=bottom,
|
195 |
+
left=left,
|
196 |
+
right=right,
|
197 |
+
use_quadratic=use_quadratic,
|
198 |
+
n_bins=n_bins,
|
199 |
+
)
|
200 |
+
)
|
201 |
+
else:
|
202 |
+
self.transforms.append(
|
203 |
+
AffineTransformationLayer(
|
204 |
+
n_in_dim * n_group_size,
|
205 |
+
context_dim,
|
206 |
+
n_layers,
|
207 |
+
with_dilation=with_dilation,
|
208 |
+
kernel_size=kernel_size,
|
209 |
+
scaling_fn=scaling_fn,
|
210 |
+
affine_model="simple_conv",
|
211 |
+
n_channels=n_channels,
|
212 |
+
)
|
213 |
+
)
|
214 |
+
|
215 |
+
def fold(self, data):
|
216 |
+
"""Inverse of the self.unfold(data.unsqueeze(-1)) operation used for
|
217 |
+
the grouping or "squeeze" operation on input
|
218 |
+
|
219 |
+
Args:
|
220 |
+
data: B x C x T tensor of temporal data
|
221 |
+
"""
|
222 |
+
output_size = (data.shape[2] * self.n_group_size, 1)
|
223 |
+
data = nn.functional.fold(
|
224 |
+
data, output_size=output_size, **self.unfold_params
|
225 |
+
).squeeze(-1)
|
226 |
+
return data
|
227 |
+
|
228 |
+
def preprocess_context(self, txt_emb, speaker_vecs, std_scale=None):
|
229 |
+
if self.n_group_size > 1:
|
230 |
+
txt_emb = self.unfold(txt_emb[..., None])
|
231 |
+
speaker_vecs = speaker_vecs[..., None].expand(-1, -1, txt_emb.shape[2])
|
232 |
+
context = torch.cat((txt_emb, speaker_vecs), 1)
|
233 |
+
return context
|
234 |
+
|
235 |
+
def forward(self, txt_enc, spk_emb, x, lens):
|
236 |
+
"""x<tensor>: duration or pitch or energy average"""
|
237 |
+
assert txt_enc.size(2) >= x.size(1)
|
238 |
+
if len(x.shape) == 2:
|
239 |
+
# add channel dimension
|
240 |
+
x = x[:, None]
|
241 |
+
txt_enc = self.bottleneck_layer(txt_enc)
|
242 |
+
|
243 |
+
# lens including padded values
|
244 |
+
lens_grouped = (lens // self.n_group_size).long()
|
245 |
+
context = self.preprocess_context(txt_enc, spk_emb)
|
246 |
+
x = self.unfold(x[..., None])
|
247 |
+
log_s_list, log_det_W_list = [], []
|
248 |
+
for k in range(self.n_flows):
|
249 |
+
x, log_s = self.transforms[k](x, context, seq_lens=lens_grouped)
|
250 |
+
x, log_det_W = self.convinv[k](x)
|
251 |
+
log_det_W_list.append(log_det_W)
|
252 |
+
log_s_list.append(log_s)
|
253 |
+
# prepare outputs
|
254 |
+
outputs = {"z": x, "log_det_W_list": log_det_W_list, "log_s_list": log_s_list}
|
255 |
+
|
256 |
+
return outputs
|
257 |
+
|
258 |
+
def infer(self, z, txt_enc, spk_emb, seq_lens):
|
259 |
+
txt_enc = self.bottleneck_layer(txt_enc)
|
260 |
+
context = self.preprocess_context(txt_enc, spk_emb)
|
261 |
+
lens_grouped = (seq_lens // self.n_group_size).long()
|
262 |
+
z = self.unfold(z[..., None])
|
263 |
+
for k in reversed(range(self.n_flows)):
|
264 |
+
z = self.convinv[k](z, inverse=True)
|
265 |
+
z = self.transforms[k].forward(
|
266 |
+
z, context, inverse=True, seq_lens=lens_grouped
|
267 |
+
)
|
268 |
+
# z mapped to input domain
|
269 |
+
x_hat = self.fold(z)
|
270 |
+
# pad on the way out
|
271 |
+
return x_hat
|
272 |
+
|
273 |
+
|
274 |
+
class AGAP(torch.nn.Module):
|
275 |
+
def __init__(
|
276 |
+
self,
|
277 |
+
n_in_dim,
|
278 |
+
n_speaker_dim,
|
279 |
+
n_flows,
|
280 |
+
n_hidden,
|
281 |
+
n_lstm_layers,
|
282 |
+
bottleneck_hparams,
|
283 |
+
scaling_fn="exp",
|
284 |
+
take_log_of_input=False,
|
285 |
+
p_dropout=0.0,
|
286 |
+
setup="",
|
287 |
+
spline_flow_params=None,
|
288 |
+
n_group_size=1,
|
289 |
+
):
|
290 |
+
super(AGAP, self).__init__()
|
291 |
+
self.flows = torch.nn.ModuleList()
|
292 |
+
self.n_group_size = n_group_size
|
293 |
+
self.n_speaker_dim = n_speaker_dim
|
294 |
+
self.attribute_processing = AttributeProcessing(take_log_of_input)
|
295 |
+
self.n_in_dim = n_in_dim
|
296 |
+
self.bottleneck_layer = BottleneckLayerLayer(**bottleneck_hparams)
|
297 |
+
n_txt_reduced_dim = self.bottleneck_layer.out_dim
|
298 |
+
|
299 |
+
if self.n_group_size > 1:
|
300 |
+
self.unfold_params = {
|
301 |
+
"kernel_size": (n_group_size, 1),
|
302 |
+
"stride": n_group_size,
|
303 |
+
"padding": 0,
|
304 |
+
"dilation": 1,
|
305 |
+
}
|
306 |
+
self.unfold = nn.Unfold(**self.unfold_params)
|
307 |
+
|
308 |
+
if spline_flow_params is not None:
|
309 |
+
spline_flow_params["n_in_channels"] *= self.n_group_size
|
310 |
+
|
311 |
+
for i in range(n_flows):
|
312 |
+
if i % 2 == 0:
|
313 |
+
self.flows.append(
|
314 |
+
AR_Step(
|
315 |
+
n_in_dim * n_group_size,
|
316 |
+
n_speaker_dim,
|
317 |
+
n_txt_reduced_dim * n_group_size,
|
318 |
+
n_hidden,
|
319 |
+
n_lstm_layers,
|
320 |
+
scaling_fn,
|
321 |
+
spline_flow_params,
|
322 |
+
)
|
323 |
+
)
|
324 |
+
else:
|
325 |
+
self.flows.append(
|
326 |
+
AR_Back_Step(
|
327 |
+
n_in_dim * n_group_size,
|
328 |
+
n_speaker_dim,
|
329 |
+
n_txt_reduced_dim * n_group_size,
|
330 |
+
n_hidden,
|
331 |
+
n_lstm_layers,
|
332 |
+
scaling_fn,
|
333 |
+
spline_flow_params,
|
334 |
+
)
|
335 |
+
)
|
336 |
+
|
337 |
+
def fold(self, data):
|
338 |
+
"""Inverse of the self.unfold(data.unsqueeze(-1)) operation used for
|
339 |
+
the grouping or "squeeze" operation on input
|
340 |
+
|
341 |
+
Args:
|
342 |
+
data: B x C x T tensor of temporal data
|
343 |
+
"""
|
344 |
+
output_size = (data.shape[2] * self.n_group_size, 1)
|
345 |
+
data = nn.functional.fold(
|
346 |
+
data, output_size=output_size, **self.unfold_params
|
347 |
+
).squeeze(-1)
|
348 |
+
return data
|
349 |
+
|
350 |
+
def preprocess_context(self, txt_emb, speaker_vecs):
|
351 |
+
if self.n_group_size > 1:
|
352 |
+
txt_emb = self.unfold(txt_emb[..., None])
|
353 |
+
speaker_vecs = speaker_vecs[..., None].expand(-1, -1, txt_emb.shape[2])
|
354 |
+
context = torch.cat((txt_emb, speaker_vecs), 1)
|
355 |
+
return context
|
356 |
+
|
357 |
+
def forward(self, txt_emb, spk_emb, x, lens):
|
358 |
+
"""x<tensor>: duration or pitch or energy average"""
|
359 |
+
|
360 |
+
x = x[:, None] if len(x.shape) == 2 else x # add channel dimension
|
361 |
+
if self.n_group_size > 1:
|
362 |
+
x = self.unfold(x[..., None])
|
363 |
+
x = x.permute(2, 0, 1) # permute to time, batch, dims
|
364 |
+
x = self.attribute_processing.normalize(x)
|
365 |
+
|
366 |
+
txt_emb = self.bottleneck_layer(txt_emb)
|
367 |
+
context = self.preprocess_context(txt_emb, spk_emb)
|
368 |
+
context = context.permute(2, 0, 1) # permute to time, batch, dims
|
369 |
+
|
370 |
+
lens_groupped = (lens / self.n_group_size).long()
|
371 |
+
log_s_list = []
|
372 |
+
for i, flow in enumerate(self.flows):
|
373 |
+
x, log_s = flow(x, context, lens_groupped)
|
374 |
+
log_s_list.append(log_s)
|
375 |
+
|
376 |
+
x = x.permute(1, 2, 0) # x mapped to z
|
377 |
+
log_s_list = [log_s_elt.permute(1, 2, 0) for log_s_elt in log_s_list]
|
378 |
+
outputs = {"z": x, "log_s_list": log_s_list, "log_det_W_list": []}
|
379 |
+
return outputs
|
380 |
+
|
381 |
+
def infer(self, z, txt_emb, spk_emb, seq_lens=None):
|
382 |
+
if self.n_group_size > 1:
|
383 |
+
n_frames = z.shape[2]
|
384 |
+
z = self.unfold(z[..., None])
|
385 |
+
z = z.permute(2, 0, 1) # permute to time, batch, dims
|
386 |
+
|
387 |
+
txt_emb = self.bottleneck_layer(txt_emb)
|
388 |
+
context = self.preprocess_context(txt_emb, spk_emb)
|
389 |
+
context = context.permute(2, 0, 1) # permute to time, batch, dims
|
390 |
+
|
391 |
+
for i, flow in enumerate(reversed(self.flows)):
|
392 |
+
z = flow.infer(z, context)
|
393 |
+
|
394 |
+
x_hat = z.permute(1, 2, 0)
|
395 |
+
if self.n_group_size > 1:
|
396 |
+
x_hat = self.fold(x_hat)
|
397 |
+
if n_frames > x_hat.shape[2]:
|
398 |
+
m = nn.ReflectionPad1d((0, n_frames - x_hat.shape[2]))
|
399 |
+
x_hat = m(x_hat)
|
400 |
+
|
401 |
+
x_hat = self.attribute_processing.denormalize(x_hat)
|
402 |
+
return x_hat
|
audio_processing.py
ADDED
@@ -0,0 +1,328 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: MIT
|
3 |
+
#
|
4 |
+
# Permission is hereby granted, free of charge, to any person obtaining a
|
5 |
+
# copy of this software and associated documentation files (the "Software"),
|
6 |
+
# to deal in the Software without restriction, including without limitation
|
7 |
+
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
8 |
+
# and/or sell copies of the Software, and to permit persons to whom the
|
9 |
+
# Software is furnished to do so, subject to the following conditions:
|
10 |
+
#
|
11 |
+
# The above copyright notice and this permission notice shall be included in
|
12 |
+
# all copies or substantial portions of the Software.
|
13 |
+
#
|
14 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
15 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
16 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
17 |
+
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
18 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
19 |
+
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
20 |
+
# DEALINGS IN THE SOFTWARE.
|
21 |
+
import torch
|
22 |
+
import numpy as np
|
23 |
+
from scipy.signal import get_window
|
24 |
+
from librosa.filters import mel as librosa_mel_fn
|
25 |
+
import librosa.util as librosa_util
|
26 |
+
|
27 |
+
|
28 |
+
def window_sumsquare(
|
29 |
+
window,
|
30 |
+
n_frames,
|
31 |
+
hop_length=200,
|
32 |
+
win_length=800,
|
33 |
+
n_fft=800,
|
34 |
+
dtype=np.float32,
|
35 |
+
norm=None,
|
36 |
+
):
|
37 |
+
"""
|
38 |
+
# from librosa 0.6
|
39 |
+
Compute the sum-square envelope of a window function at a given hop length.
|
40 |
+
|
41 |
+
This is used to estimate modulation effects induced by windowing
|
42 |
+
observations in short-time fourier transforms.
|
43 |
+
|
44 |
+
Parameters
|
45 |
+
----------
|
46 |
+
window : string, tuple, number, callable, or list-like
|
47 |
+
Window specification, as in `get_window`
|
48 |
+
|
49 |
+
n_frames : int > 0
|
50 |
+
The number of analysis frames
|
51 |
+
|
52 |
+
hop_length : int > 0
|
53 |
+
The number of samples to advance between frames
|
54 |
+
|
55 |
+
win_length : [optional]
|
56 |
+
The length of the window function. By default, this matches `n_fft`.
|
57 |
+
|
58 |
+
n_fft : int > 0
|
59 |
+
The length of each analysis frame.
|
60 |
+
|
61 |
+
dtype : np.dtype
|
62 |
+
The data type of the output
|
63 |
+
|
64 |
+
Returns
|
65 |
+
-------
|
66 |
+
wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
|
67 |
+
The sum-squared envelope of the window function
|
68 |
+
"""
|
69 |
+
if win_length is None:
|
70 |
+
win_length = n_fft
|
71 |
+
|
72 |
+
n = n_fft + hop_length * (n_frames - 1)
|
73 |
+
x = np.zeros(n, dtype=dtype)
|
74 |
+
|
75 |
+
# Compute the squared window at the desired length
|
76 |
+
win_sq = get_window(window, win_length, fftbins=True)
|
77 |
+
win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2
|
78 |
+
win_sq = librosa_util.pad_center(win_sq, size=n_fft)
|
79 |
+
|
80 |
+
# Fill the envelope
|
81 |
+
for i in range(n_frames):
|
82 |
+
sample = i * hop_length
|
83 |
+
x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
|
84 |
+
return x
|
85 |
+
|
86 |
+
|
87 |
+
def griffin_lim(magnitudes, stft_fn, n_iters=30):
|
88 |
+
"""
|
89 |
+
PARAMS
|
90 |
+
------
|
91 |
+
magnitudes: spectrogram magnitudes
|
92 |
+
stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods
|
93 |
+
"""
|
94 |
+
|
95 |
+
angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))
|
96 |
+
angles = angles.astype(np.float32)
|
97 |
+
angles = torch.autograd.Variable(torch.from_numpy(angles))
|
98 |
+
signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
|
99 |
+
|
100 |
+
for i in range(n_iters):
|
101 |
+
_, angles = stft_fn.transform(signal)
|
102 |
+
signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
|
103 |
+
return signal
|
104 |
+
|
105 |
+
|
106 |
+
def dynamic_range_compression(x, C=1, clip_val=1e-5):
|
107 |
+
"""
|
108 |
+
PARAMS
|
109 |
+
------
|
110 |
+
C: compression factor
|
111 |
+
"""
|
112 |
+
return torch.log(torch.clamp(x, min=clip_val) * C)
|
113 |
+
|
114 |
+
|
115 |
+
def dynamic_range_decompression(x, C=1):
|
116 |
+
"""
|
117 |
+
PARAMS
|
118 |
+
------
|
119 |
+
C: compression factor used to compress
|
120 |
+
"""
|
121 |
+
return torch.exp(x) / C
|
122 |
+
|
123 |
+
|
124 |
+
class TacotronSTFT(torch.nn.Module):
|
125 |
+
def __init__(
|
126 |
+
self,
|
127 |
+
filter_length=1024,
|
128 |
+
hop_length=256,
|
129 |
+
win_length=1024,
|
130 |
+
n_mel_channels=80,
|
131 |
+
sampling_rate=22050,
|
132 |
+
mel_fmin=0.0,
|
133 |
+
mel_fmax=None,
|
134 |
+
):
|
135 |
+
super(TacotronSTFT, self).__init__()
|
136 |
+
self.n_mel_channels = n_mel_channels
|
137 |
+
self.sampling_rate = sampling_rate
|
138 |
+
self.stft_fn = STFT(filter_length, hop_length, win_length)
|
139 |
+
mel_basis = librosa_mel_fn(
|
140 |
+
sr=sampling_rate,
|
141 |
+
n_fft=filter_length,
|
142 |
+
n_mels=n_mel_channels,
|
143 |
+
fmin=mel_fmin,
|
144 |
+
fmax=mel_fmax,
|
145 |
+
)
|
146 |
+
mel_basis = torch.from_numpy(mel_basis).float()
|
147 |
+
self.register_buffer("mel_basis", mel_basis)
|
148 |
+
|
149 |
+
def spectral_normalize(self, magnitudes):
|
150 |
+
output = dynamic_range_compression(magnitudes)
|
151 |
+
return output
|
152 |
+
|
153 |
+
def spectral_de_normalize(self, magnitudes):
|
154 |
+
output = dynamic_range_decompression(magnitudes)
|
155 |
+
return output
|
156 |
+
|
157 |
+
def mel_spectrogram(self, y):
|
158 |
+
"""Computes mel-spectrograms from a batch of waves
|
159 |
+
PARAMS
|
160 |
+
------
|
161 |
+
y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
|
162 |
+
|
163 |
+
RETURNS
|
164 |
+
-------
|
165 |
+
mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
|
166 |
+
"""
|
167 |
+
assert torch.min(y.data) >= -1
|
168 |
+
assert torch.max(y.data) <= 1
|
169 |
+
|
170 |
+
magnitudes, phases = self.stft_fn.transform(y)
|
171 |
+
magnitudes = magnitudes.data
|
172 |
+
mel_output = torch.matmul(self.mel_basis, magnitudes)
|
173 |
+
mel_output = self.spectral_normalize(mel_output)
|
174 |
+
return mel_output
|
175 |
+
|
176 |
+
|
177 |
+
"""
|
178 |
+
BSD 3-Clause License
|
179 |
+
|
180 |
+
Copyright (c) 2017, Prem Seetharaman
|
181 |
+
All rights reserved.
|
182 |
+
|
183 |
+
* Redistribution and use in source and binary forms, with or without
|
184 |
+
modification, are permitted provided that the following conditions are met:
|
185 |
+
|
186 |
+
* Redistributions of source code must retain the above copyright notice,
|
187 |
+
this list of conditions and the following disclaimer.
|
188 |
+
|
189 |
+
* Redistributions in binary form must reproduce the above copyright notice, this
|
190 |
+
list of conditions and the following disclaimer in the
|
191 |
+
documentation and/or other materials provided with the distribution.
|
192 |
+
|
193 |
+
* Neither the name of the copyright holder nor the names of its
|
194 |
+
contributors may be used to endorse or promote products derived from this
|
195 |
+
software without specific prior written permission.
|
196 |
+
|
197 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
198 |
+
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
199 |
+
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
200 |
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
|
201 |
+
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
202 |
+
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
203 |
+
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
|
204 |
+
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
205 |
+
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
206 |
+
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
207 |
+
"""
|
208 |
+
import torch.nn.functional as F
|
209 |
+
from torch.autograd import Variable
|
210 |
+
from scipy.signal import get_window
|
211 |
+
from librosa.util import pad_center, tiny
|
212 |
+
|
213 |
+
|
214 |
+
class STFT(torch.nn.Module):
|
215 |
+
"""adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
|
216 |
+
|
217 |
+
def __init__(
|
218 |
+
self, filter_length=800, hop_length=200, win_length=800, window="hann"
|
219 |
+
):
|
220 |
+
super(STFT, self).__init__()
|
221 |
+
self.filter_length = filter_length
|
222 |
+
self.hop_length = hop_length
|
223 |
+
self.win_length = win_length
|
224 |
+
self.window = window
|
225 |
+
self.forward_transform = None
|
226 |
+
scale = self.filter_length / self.hop_length
|
227 |
+
fourier_basis = np.fft.fft(np.eye(self.filter_length))
|
228 |
+
|
229 |
+
cutoff = int((self.filter_length / 2 + 1))
|
230 |
+
fourier_basis = np.vstack(
|
231 |
+
[np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
|
232 |
+
)
|
233 |
+
|
234 |
+
forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
|
235 |
+
inverse_basis = torch.FloatTensor(
|
236 |
+
np.linalg.pinv(scale * fourier_basis).T[:, None, :]
|
237 |
+
)
|
238 |
+
|
239 |
+
if window is not None:
|
240 |
+
assert win_length >= filter_length
|
241 |
+
# get window and zero center pad it to filter_length
|
242 |
+
fft_window = get_window(window, win_length, fftbins=True)
|
243 |
+
fft_window = pad_center(fft_window, size=filter_length)
|
244 |
+
fft_window = torch.from_numpy(fft_window).float()
|
245 |
+
|
246 |
+
# window the bases
|
247 |
+
forward_basis *= fft_window
|
248 |
+
inverse_basis *= fft_window
|
249 |
+
|
250 |
+
self.register_buffer("forward_basis", forward_basis.float())
|
251 |
+
self.register_buffer("inverse_basis", inverse_basis.float())
|
252 |
+
|
253 |
+
def transform(self, input_data):
|
254 |
+
num_batches = input_data.size(0)
|
255 |
+
num_samples = input_data.size(1)
|
256 |
+
|
257 |
+
self.num_samples = num_samples
|
258 |
+
|
259 |
+
# similar to librosa, reflect-pad the input
|
260 |
+
input_data = input_data.view(num_batches, 1, num_samples)
|
261 |
+
input_data = F.pad(
|
262 |
+
input_data.unsqueeze(1),
|
263 |
+
(int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
|
264 |
+
mode="reflect",
|
265 |
+
)
|
266 |
+
input_data = input_data.squeeze(1)
|
267 |
+
|
268 |
+
forward_transform = F.conv1d(
|
269 |
+
input_data,
|
270 |
+
Variable(self.forward_basis, requires_grad=False),
|
271 |
+
stride=self.hop_length,
|
272 |
+
padding=0,
|
273 |
+
)
|
274 |
+
|
275 |
+
cutoff = int((self.filter_length / 2) + 1)
|
276 |
+
real_part = forward_transform[:, :cutoff, :]
|
277 |
+
imag_part = forward_transform[:, cutoff:, :]
|
278 |
+
|
279 |
+
magnitude = torch.sqrt(real_part**2 + imag_part**2)
|
280 |
+
phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data))
|
281 |
+
|
282 |
+
return magnitude, phase
|
283 |
+
|
284 |
+
def inverse(self, magnitude, phase):
|
285 |
+
recombine_magnitude_phase = torch.cat(
|
286 |
+
[magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
|
287 |
+
)
|
288 |
+
|
289 |
+
inverse_transform = F.conv_transpose1d(
|
290 |
+
recombine_magnitude_phase,
|
291 |
+
Variable(self.inverse_basis, requires_grad=False),
|
292 |
+
stride=self.hop_length,
|
293 |
+
padding=0,
|
294 |
+
)
|
295 |
+
|
296 |
+
if self.window is not None:
|
297 |
+
window_sum = window_sumsquare(
|
298 |
+
self.window,
|
299 |
+
magnitude.size(-1),
|
300 |
+
hop_length=self.hop_length,
|
301 |
+
win_length=self.win_length,
|
302 |
+
n_fft=self.filter_length,
|
303 |
+
dtype=np.float32,
|
304 |
+
)
|
305 |
+
# remove modulation effects
|
306 |
+
approx_nonzero_indices = torch.from_numpy(
|
307 |
+
np.where(window_sum > tiny(window_sum))[0]
|
308 |
+
)
|
309 |
+
window_sum = torch.autograd.Variable(
|
310 |
+
torch.from_numpy(window_sum), requires_grad=False
|
311 |
+
)
|
312 |
+
window_sum = window_sum.to(magnitude.device)
|
313 |
+
inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
|
314 |
+
approx_nonzero_indices
|
315 |
+
]
|
316 |
+
|
317 |
+
# scale by hop ratio
|
318 |
+
inverse_transform *= float(self.filter_length) / self.hop_length
|
319 |
+
|
320 |
+
inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :]
|
321 |
+
inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :]
|
322 |
+
|
323 |
+
return inverse_transform
|
324 |
+
|
325 |
+
def forward(self, input_data):
|
326 |
+
self.magnitude, self.phase = self.transform(input_data)
|
327 |
+
reconstruction = self.inverse(self.magnitude, self.phase)
|
328 |
+
return reconstruction
|
autoregressive_flow.py
ADDED
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: MIT
|
3 |
+
#
|
4 |
+
# Permission is hereby granted, free of charge, to any person obtaining a
|
5 |
+
# copy of this software and associated documentation files (the "Software"),
|
6 |
+
# to deal in the Software without restriction, including without limitation
|
7 |
+
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
8 |
+
# and/or sell copies of the Software, and to permit persons to whom the
|
9 |
+
# Software is furnished to do so, subject to the following conditions:
|
10 |
+
#
|
11 |
+
# The above copyright notice and this permission notice shall be included in
|
12 |
+
# all copies or substantial portions of the Software.
|
13 |
+
#
|
14 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
15 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
16 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
17 |
+
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
18 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
19 |
+
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
20 |
+
# DEALINGS IN THE SOFTWARE.
|
21 |
+
|
22 |
+
# AR_Back_Step and AR_Step based on implementation from
|
23 |
+
# https://github.com/NVIDIA/flowtron/blob/master/flowtron.py
|
24 |
+
# Original license text:
|
25 |
+
###############################################################################
|
26 |
+
#
|
27 |
+
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
28 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
29 |
+
# you may not use this file except in compliance with the License.
|
30 |
+
# You may obtain a copy of the License at
|
31 |
+
#
|
32 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
33 |
+
#
|
34 |
+
# Unless required by applicable law or agreed to in writing, software
|
35 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
36 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
37 |
+
# See the License for the specific language governing permissions and
|
38 |
+
# limitations under the License.
|
39 |
+
#
|
40 |
+
###############################################################################
|
41 |
+
# Original Author and Contact: Rafael Valle
|
42 |
+
# Modification by Rafael Valle
|
43 |
+
|
44 |
+
import torch
|
45 |
+
from torch import nn
|
46 |
+
from common import DenseLayer, SplineTransformationLayerAR
|
47 |
+
|
48 |
+
|
49 |
+
class AR_Back_Step(torch.nn.Module):
|
50 |
+
def __init__(
|
51 |
+
self,
|
52 |
+
n_attr_channels,
|
53 |
+
n_speaker_dim,
|
54 |
+
n_text_dim,
|
55 |
+
n_hidden,
|
56 |
+
n_lstm_layers,
|
57 |
+
scaling_fn,
|
58 |
+
spline_flow_params=None,
|
59 |
+
):
|
60 |
+
super(AR_Back_Step, self).__init__()
|
61 |
+
self.ar_step = AR_Step(
|
62 |
+
n_attr_channels,
|
63 |
+
n_speaker_dim,
|
64 |
+
n_text_dim,
|
65 |
+
n_hidden,
|
66 |
+
n_lstm_layers,
|
67 |
+
scaling_fn,
|
68 |
+
spline_flow_params,
|
69 |
+
)
|
70 |
+
|
71 |
+
def forward(self, mel, context, lens):
|
72 |
+
mel = torch.flip(mel, (0,))
|
73 |
+
context = torch.flip(context, (0,))
|
74 |
+
# backwards flow, send padded zeros back to end
|
75 |
+
for k in range(mel.size(1)):
|
76 |
+
mel[:, k] = mel[:, k].roll(lens[k].item(), dims=0)
|
77 |
+
context[:, k] = context[:, k].roll(lens[k].item(), dims=0)
|
78 |
+
|
79 |
+
mel, log_s = self.ar_step(mel, context, lens)
|
80 |
+
|
81 |
+
# move padded zeros back to beginning
|
82 |
+
for k in range(mel.size(1)):
|
83 |
+
mel[:, k] = mel[:, k].roll(-lens[k].item(), dims=0)
|
84 |
+
|
85 |
+
return torch.flip(mel, (0,)), log_s
|
86 |
+
|
87 |
+
def infer(self, residual, context):
|
88 |
+
residual = self.ar_step.infer(
|
89 |
+
torch.flip(residual, (0,)), torch.flip(context, (0,))
|
90 |
+
)
|
91 |
+
residual = torch.flip(residual, (0,))
|
92 |
+
return residual
|
93 |
+
|
94 |
+
|
95 |
+
class AR_Step(torch.nn.Module):
|
96 |
+
def __init__(
|
97 |
+
self,
|
98 |
+
n_attr_channels,
|
99 |
+
n_speaker_dim,
|
100 |
+
n_text_channels,
|
101 |
+
n_hidden,
|
102 |
+
n_lstm_layers,
|
103 |
+
scaling_fn,
|
104 |
+
spline_flow_params=None,
|
105 |
+
):
|
106 |
+
super(AR_Step, self).__init__()
|
107 |
+
if spline_flow_params is not None:
|
108 |
+
self.spline_flow = SplineTransformationLayerAR(**spline_flow_params)
|
109 |
+
else:
|
110 |
+
self.n_out_dims = n_attr_channels
|
111 |
+
self.conv = torch.nn.Conv1d(n_hidden, 2 * n_attr_channels, 1)
|
112 |
+
self.conv.weight.data = 0.0 * self.conv.weight.data
|
113 |
+
self.conv.bias.data = 0.0 * self.conv.bias.data
|
114 |
+
|
115 |
+
self.attr_lstm = torch.nn.LSTM(n_attr_channels, n_hidden)
|
116 |
+
self.lstm = torch.nn.LSTM(
|
117 |
+
n_hidden + n_text_channels + n_speaker_dim, n_hidden, n_lstm_layers
|
118 |
+
)
|
119 |
+
|
120 |
+
if spline_flow_params is None:
|
121 |
+
self.dense_layer = DenseLayer(in_dim=n_hidden, sizes=[n_hidden, n_hidden])
|
122 |
+
self.scaling_fn = scaling_fn
|
123 |
+
|
124 |
+
def run_padded_sequence(
|
125 |
+
self, sorted_idx, unsort_idx, lens, padded_data, recurrent_model
|
126 |
+
):
|
127 |
+
"""Sorts input data by previded ordering (and un-ordering) and runs the
|
128 |
+
packed data through the recurrent model
|
129 |
+
|
130 |
+
Args:
|
131 |
+
sorted_idx (torch.tensor): 1D sorting index
|
132 |
+
unsort_idx (torch.tensor): 1D unsorting index (inverse sorted_idx)
|
133 |
+
lens: lengths of input data (sorted in descending order)
|
134 |
+
padded_data (torch.tensor): input sequences (padded)
|
135 |
+
recurrent_model (nn.Module): recurrent model to run data through
|
136 |
+
Returns:
|
137 |
+
hidden_vectors (torch.tensor): outputs of the RNN, in the original,
|
138 |
+
unsorted, ordering
|
139 |
+
"""
|
140 |
+
|
141 |
+
# sort the data by decreasing length using provided index
|
142 |
+
# we assume batch index is in dim=1
|
143 |
+
padded_data = padded_data[:, sorted_idx]
|
144 |
+
padded_data = nn.utils.rnn.pack_padded_sequence(padded_data, lens.cpu())
|
145 |
+
hidden_vectors = recurrent_model(padded_data)[0]
|
146 |
+
hidden_vectors, _ = nn.utils.rnn.pad_packed_sequence(hidden_vectors)
|
147 |
+
# unsort the results at dim=1 and return
|
148 |
+
hidden_vectors = hidden_vectors[:, unsort_idx]
|
149 |
+
return hidden_vectors
|
150 |
+
|
151 |
+
def get_scaling_and_logs(self, scale_unconstrained):
|
152 |
+
if self.scaling_fn == "translate":
|
153 |
+
s = torch.exp(scale_unconstrained * 0)
|
154 |
+
log_s = scale_unconstrained * 0
|
155 |
+
elif self.scaling_fn == "exp":
|
156 |
+
s = torch.exp(scale_unconstrained)
|
157 |
+
log_s = scale_unconstrained # log(exp
|
158 |
+
elif self.scaling_fn == "tanh":
|
159 |
+
s = torch.tanh(scale_unconstrained) + 1 + 1e-6
|
160 |
+
log_s = torch.log(s)
|
161 |
+
elif self.scaling_fn == "sigmoid":
|
162 |
+
s = torch.sigmoid(scale_unconstrained + 10) + 1e-6
|
163 |
+
log_s = torch.log(s)
|
164 |
+
else:
|
165 |
+
raise Exception("Scaling fn {} not supp.".format(self.scaling_fn))
|
166 |
+
|
167 |
+
return s, log_s
|
168 |
+
|
169 |
+
def forward(self, mel, context, lens):
|
170 |
+
dummy = torch.FloatTensor(1, mel.size(1), mel.size(2)).zero_()
|
171 |
+
dummy = dummy.type(mel.type())
|
172 |
+
# seq_len x batch x dim
|
173 |
+
mel0 = torch.cat([dummy, mel[:-1]], 0)
|
174 |
+
|
175 |
+
self.lstm.flatten_parameters()
|
176 |
+
self.attr_lstm.flatten_parameters()
|
177 |
+
if lens is not None:
|
178 |
+
# collect decreasing length indices
|
179 |
+
lens, ids = torch.sort(lens, descending=True)
|
180 |
+
original_ids = [0] * lens.size(0)
|
181 |
+
for i, ids_i in enumerate(ids):
|
182 |
+
original_ids[ids_i] = i
|
183 |
+
# mel_seq_len x batch x hidden_dim
|
184 |
+
mel_hidden = self.run_padded_sequence(
|
185 |
+
ids, original_ids, lens, mel0, self.attr_lstm
|
186 |
+
)
|
187 |
+
else:
|
188 |
+
mel_hidden = self.attr_lstm(mel0)[0]
|
189 |
+
|
190 |
+
decoder_input = torch.cat((mel_hidden, context), -1)
|
191 |
+
|
192 |
+
if lens is not None:
|
193 |
+
# reorder, run padded sequence and undo reordering
|
194 |
+
lstm_hidden = self.run_padded_sequence(
|
195 |
+
ids, original_ids, lens, decoder_input, self.lstm
|
196 |
+
)
|
197 |
+
else:
|
198 |
+
lstm_hidden = self.lstm(decoder_input)[0]
|
199 |
+
|
200 |
+
if hasattr(self, "spline_flow"):
|
201 |
+
# spline flow fn expects inputs to be batch, channel, time
|
202 |
+
lstm_hidden = lstm_hidden.permute(1, 2, 0)
|
203 |
+
mel = mel.permute(1, 2, 0)
|
204 |
+
mel, log_s = self.spline_flow(mel, lstm_hidden, inverse=False)
|
205 |
+
mel = mel.permute(2, 0, 1)
|
206 |
+
log_s = log_s.permute(2, 0, 1)
|
207 |
+
else:
|
208 |
+
lstm_hidden = self.dense_layer(lstm_hidden).permute(1, 2, 0)
|
209 |
+
decoder_output = self.conv(lstm_hidden).permute(2, 0, 1)
|
210 |
+
|
211 |
+
scale, log_s = self.get_scaling_and_logs(
|
212 |
+
decoder_output[:, :, : self.n_out_dims]
|
213 |
+
)
|
214 |
+
bias = decoder_output[:, :, self.n_out_dims :]
|
215 |
+
|
216 |
+
mel = scale * mel + bias
|
217 |
+
|
218 |
+
return mel, log_s
|
219 |
+
|
220 |
+
def infer(self, residual, context):
|
221 |
+
total_output = [] # seems 10FPS faster than pre-allocation
|
222 |
+
|
223 |
+
output = None
|
224 |
+
dummy = torch.cuda.FloatTensor(1, residual.size(1), residual.size(2)).zero_()
|
225 |
+
self.attr_lstm.flatten_parameters()
|
226 |
+
|
227 |
+
for i in range(0, residual.size(0)):
|
228 |
+
if i == 0:
|
229 |
+
output = dummy
|
230 |
+
mel_hidden, (h, c) = self.attr_lstm(output)
|
231 |
+
else:
|
232 |
+
mel_hidden, (h, c) = self.attr_lstm(output, (h, c))
|
233 |
+
|
234 |
+
decoder_input = torch.cat((mel_hidden, context[i][None]), -1)
|
235 |
+
|
236 |
+
if i == 0:
|
237 |
+
lstm_hidden, (h1, c1) = self.lstm(decoder_input)
|
238 |
+
else:
|
239 |
+
lstm_hidden, (h1, c1) = self.lstm(decoder_input, (h1, c1))
|
240 |
+
|
241 |
+
if hasattr(self, "spline_flow"):
|
242 |
+
# expects inputs to be batch, channel, time
|
243 |
+
lstm_hidden = lstm_hidden.permute(1, 2, 0)
|
244 |
+
output = residual[i : i + 1].permute(1, 2, 0)
|
245 |
+
output = self.spline_flow(output, lstm_hidden, inverse=True)
|
246 |
+
output = output.permute(2, 0, 1)
|
247 |
+
else:
|
248 |
+
lstm_hidden = self.dense_layer(lstm_hidden).permute(1, 2, 0)
|
249 |
+
decoder_output = self.conv(lstm_hidden).permute(2, 0, 1)
|
250 |
+
|
251 |
+
s, log_s = self.get_scaling_and_logs(
|
252 |
+
decoder_output[:, :, : decoder_output.size(2) // 2]
|
253 |
+
)
|
254 |
+
b = decoder_output[:, :, decoder_output.size(2) // 2 :]
|
255 |
+
output = (residual[i : i + 1] - b) / s
|
256 |
+
total_output.append(output)
|
257 |
+
|
258 |
+
total_output = torch.cat(total_output, 0)
|
259 |
+
return total_output
|
bigvgan.py
ADDED
@@ -0,0 +1,528 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 NVIDIA CORPORATION.
|
2 |
+
# Licensed under the MIT license.
|
3 |
+
|
4 |
+
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
5 |
+
# LICENSE is in incl_licenses directory.
|
6 |
+
|
7 |
+
import os
|
8 |
+
import shutil
|
9 |
+
import json
|
10 |
+
from pathlib import Path
|
11 |
+
from typing import Optional, Union, Dict
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import torch.nn as nn
|
15 |
+
from torch.nn import Conv1d, ConvTranspose1d
|
16 |
+
from torch.nn.utils import weight_norm, remove_weight_norm
|
17 |
+
|
18 |
+
import activations
|
19 |
+
from alias_free_activation.torch.act import Activation1d as TorchActivation1d
|
20 |
+
|
21 |
+
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
|
22 |
+
|
23 |
+
|
24 |
+
class AttrDict(dict):
|
25 |
+
def __init__(self, *args, **kwargs):
|
26 |
+
super(AttrDict, self).__init__(*args, **kwargs)
|
27 |
+
self.__dict__ = self
|
28 |
+
|
29 |
+
|
30 |
+
def build_env(config, config_name, path):
|
31 |
+
t_path = os.path.join(path, config_name)
|
32 |
+
if config != t_path:
|
33 |
+
os.makedirs(path, exist_ok=True)
|
34 |
+
shutil.copyfile(config, os.path.join(path, config_name))
|
35 |
+
|
36 |
+
|
37 |
+
def init_weights(m, mean=0.0, std=0.01):
|
38 |
+
classname = m.__class__.__name__
|
39 |
+
if classname.find("Conv") != -1:
|
40 |
+
m.weight.data.normal_(mean, std)
|
41 |
+
|
42 |
+
|
43 |
+
def apply_weight_norm(m):
|
44 |
+
classname = m.__class__.__name__
|
45 |
+
if classname.find("Conv") != -1:
|
46 |
+
weight_norm(m)
|
47 |
+
|
48 |
+
|
49 |
+
def get_padding(kernel_size, dilation=1):
|
50 |
+
return int((kernel_size * dilation - dilation) / 2)
|
51 |
+
|
52 |
+
|
53 |
+
def load_checkpoint(filepath, device):
|
54 |
+
assert os.path.isfile(filepath)
|
55 |
+
print(f"Loading '{filepath}'")
|
56 |
+
checkpoint_dict = torch.load(filepath, map_location=device)
|
57 |
+
print("Complete.")
|
58 |
+
return checkpoint_dict
|
59 |
+
|
60 |
+
|
61 |
+
def load_hparams_from_json(path) -> AttrDict:
|
62 |
+
with open(path) as f:
|
63 |
+
data = f.read()
|
64 |
+
return AttrDict(json.loads(data))
|
65 |
+
|
66 |
+
|
67 |
+
class AMPBlock1(torch.nn.Module):
|
68 |
+
"""
|
69 |
+
AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer.
|
70 |
+
AMPBlock1 has additional self.convs2 that contains additional Conv1d layers with a fixed dilation=1 followed by each layer in self.convs1
|
71 |
+
|
72 |
+
Args:
|
73 |
+
h (AttrDict): Hyperparameters.
|
74 |
+
channels (int): Number of convolution channels.
|
75 |
+
kernel_size (int): Size of the convolution kernel. Default is 3.
|
76 |
+
dilation (tuple): Dilation rates for the convolutions. Each dilation layer has two convolutions. Default is (1, 3, 5).
|
77 |
+
activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. Default is None.
|
78 |
+
"""
|
79 |
+
|
80 |
+
def __init__(
|
81 |
+
self,
|
82 |
+
h: AttrDict,
|
83 |
+
channels: int,
|
84 |
+
kernel_size: int = 3,
|
85 |
+
dilation: tuple = (1, 3, 5),
|
86 |
+
activation: str = None,
|
87 |
+
):
|
88 |
+
super().__init__()
|
89 |
+
|
90 |
+
self.h = h
|
91 |
+
|
92 |
+
self.convs1 = nn.ModuleList(
|
93 |
+
[
|
94 |
+
weight_norm(
|
95 |
+
Conv1d(
|
96 |
+
channels,
|
97 |
+
channels,
|
98 |
+
kernel_size,
|
99 |
+
stride=1,
|
100 |
+
dilation=d,
|
101 |
+
padding=get_padding(kernel_size, d),
|
102 |
+
)
|
103 |
+
)
|
104 |
+
for d in dilation
|
105 |
+
]
|
106 |
+
)
|
107 |
+
self.convs1.apply(init_weights)
|
108 |
+
|
109 |
+
self.convs2 = nn.ModuleList(
|
110 |
+
[
|
111 |
+
weight_norm(
|
112 |
+
Conv1d(
|
113 |
+
channels,
|
114 |
+
channels,
|
115 |
+
kernel_size,
|
116 |
+
stride=1,
|
117 |
+
dilation=1,
|
118 |
+
padding=get_padding(kernel_size, 1),
|
119 |
+
)
|
120 |
+
)
|
121 |
+
for _ in range(len(dilation))
|
122 |
+
]
|
123 |
+
)
|
124 |
+
self.convs2.apply(init_weights)
|
125 |
+
|
126 |
+
self.num_layers = len(self.convs1) + len(
|
127 |
+
self.convs2
|
128 |
+
) # Total number of conv layers
|
129 |
+
|
130 |
+
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
|
131 |
+
if self.h.get("use_cuda_kernel", False):
|
132 |
+
from alias_free_activation.cuda.activation1d import (
|
133 |
+
Activation1d as CudaActivation1d,
|
134 |
+
)
|
135 |
+
|
136 |
+
Activation1d = CudaActivation1d
|
137 |
+
else:
|
138 |
+
Activation1d = TorchActivation1d
|
139 |
+
|
140 |
+
# Activation functions
|
141 |
+
if activation == "snake":
|
142 |
+
self.activations = nn.ModuleList(
|
143 |
+
[
|
144 |
+
Activation1d(
|
145 |
+
activation=activations.Snake(
|
146 |
+
channels, alpha_logscale=h.snake_logscale
|
147 |
+
)
|
148 |
+
)
|
149 |
+
for _ in range(self.num_layers)
|
150 |
+
]
|
151 |
+
)
|
152 |
+
elif activation == "snakebeta":
|
153 |
+
self.activations = nn.ModuleList(
|
154 |
+
[
|
                    Activation1d(
                        activation=activations.SnakeBeta(
                            channels, alpha_logscale=h.snake_logscale
                        )
                    )
                    for _ in range(self.num_layers)
                ]
            )
        else:
            raise NotImplementedError(
                "activation incorrectly specified. check the config file and look for 'activation'."
            )

    def forward(self, x):
        acts1, acts2 = self.activations[::2], self.activations[1::2]
        for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
            xt = a1(x)
            xt = c1(xt)
            xt = a2(xt)
            xt = c2(xt)
            x = xt + x

        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)


class AMPBlock2(torch.nn.Module):
    """
    AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer.
    Unlike AMPBlock1, AMPBlock2 does not contain extra Conv1d layers with fixed dilation=1.

    Args:
        h (AttrDict): Hyperparameters.
        channels (int): Number of convolution channels.
        kernel_size (int): Size of the convolution kernel. Default is 3.
        dilation (tuple): Dilation rates for the convolutions. Each dilation layer has two convolutions. Default is (1, 3, 5).
        activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. Default is None.
    """

    def __init__(
        self,
        h: AttrDict,
        channels: int,
        kernel_size: int = 3,
        dilation: tuple = (1, 3, 5),
        activation: str = None,
    ):
        super().__init__()

        self.h = h

        self.convs = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        stride=1,
                        dilation=d,
                        padding=get_padding(kernel_size, d),
                    )
                )
                for d in dilation
            ]
        )
        self.convs.apply(init_weights)

        self.num_layers = len(self.convs)  # Total number of conv layers

        # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
        if self.h.get("use_cuda_kernel", False):
            from alias_free_activation.cuda.activation1d import (
                Activation1d as CudaActivation1d,
            )

            Activation1d = CudaActivation1d
        else:
            Activation1d = TorchActivation1d

        # Activation functions
        if activation == "snake":
            self.activations = nn.ModuleList(
                [
                    Activation1d(
                        activation=activations.Snake(
                            channels, alpha_logscale=h.snake_logscale
                        )
                    )
                    for _ in range(self.num_layers)
                ]
            )
        elif activation == "snakebeta":
            self.activations = nn.ModuleList(
                [
                    Activation1d(
                        activation=activations.SnakeBeta(
                            channels, alpha_logscale=h.snake_logscale
                        )
                    )
                    for _ in range(self.num_layers)
                ]
            )
        else:
            raise NotImplementedError(
                "activation incorrectly specified. check the config file and look for 'activation'."
            )

    def forward(self, x):
        for c, a in zip(self.convs, self.activations):
            xt = a(x)
            xt = c(xt)
            x = xt + x

        return x

    def remove_weight_norm(self):
        for l in self.convs:
            remove_weight_norm(l)


class BigVGAN(
    torch.nn.Module,
    PyTorchModelHubMixin,
    library_name="bigvgan",
    repo_url="https://github.com/NVIDIA/BigVGAN",
    docs_url="https://github.com/NVIDIA/BigVGAN/blob/main/README.md",
    pipeline_tag="audio-to-audio",
    license="mit",
    tags=["neural-vocoder", "audio-generation", "arxiv:2206.04658"],
):
    """
    BigVGAN is a neural vocoder model that applies anti-aliased periodic activation for residual blocks (resblocks).
    New in BigVGAN-v2: it can optionally use optimized CUDA kernels for AMP (anti-aliased multi-periodicity) blocks.

    Args:
        h (AttrDict): Hyperparameters.
        use_cuda_kernel (bool): If set to True, loads optimized CUDA kernels for AMP. This should be used for inference only, as training is not supported with CUDA kernels.

    Note:
        - The `use_cuda_kernel` parameter should be used for inference only, as training with CUDA kernels is not supported.
        - Ensure that the activation function is correctly specified in the hyperparameters (h.activation).
    """

    def __init__(self, h: AttrDict, use_cuda_kernel: bool = False):
        super().__init__()
        self.h = h
        self.h["use_cuda_kernel"] = use_cuda_kernel

        # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
        if self.h.get("use_cuda_kernel", False):
            from alias_free_activation.cuda.activation1d import (
                Activation1d as CudaActivation1d,
            )

            Activation1d = CudaActivation1d
        else:
            Activation1d = TorchActivation1d

        self.num_kernels = len(h.resblock_kernel_sizes)
        self.num_upsamples = len(h.upsample_rates)

        # Pre-conv
        self.conv_pre = weight_norm(
            Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
        )

        # Define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
        if h.resblock == "1":
            resblock_class = AMPBlock1
        elif h.resblock == "2":
            resblock_class = AMPBlock2
        else:
            raise ValueError(
                f"Incorrect resblock class specified in hyperparameters. Got {h.resblock}"
            )

        # Transposed conv-based upsamplers. Does not apply anti-aliasing
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
            self.ups.append(
                nn.ModuleList(
                    [
                        weight_norm(
                            ConvTranspose1d(
                                h.upsample_initial_channel // (2**i),
                                h.upsample_initial_channel // (2 ** (i + 1)),
                                k,
                                u,
                                padding=(k - u) // 2,
                            )
                        )
                    ]
                )
            )

        # Residual blocks using anti-aliased multi-periodicity composition modules (AMP)
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = h.upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(
                zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
            ):
                self.resblocks.append(
                    resblock_class(h, ch, k, d, activation=h.activation)
                )

        # Post-conv
        activation_post = (
            activations.Snake(ch, alpha_logscale=h.snake_logscale)
            if h.activation == "snake"
            else (
                activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale)
                if h.activation == "snakebeta"
                else None
            )
        )
        if activation_post is None:
            raise NotImplementedError(
                "activation incorrectly specified. check the config file and look for 'activation'."
            )

        self.activation_post = Activation1d(activation=activation_post)

        # Whether to use bias for the final conv_post. Defaults to True for backward compatibility
        self.use_bias_at_final = h.get("use_bias_at_final", True)
        self.conv_post = weight_norm(
            Conv1d(ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final)
        )

        # Weight initialization
        for i in range(len(self.ups)):
            self.ups[i].apply(init_weights)
        self.conv_post.apply(init_weights)

        # Final tanh activation. Defaults to True for backward compatibility
        self.use_tanh_at_final = h.get("use_tanh_at_final", True)

    def forward(self, x):
        # Pre-conv
        x = self.conv_pre(x)

        for i in range(self.num_upsamples):
            # Upsampling
            for i_up in range(len(self.ups[i])):
                x = self.ups[i][i_up](x)
            # AMP blocks
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels

        # Post-conv
        x = self.activation_post(x)
        x = self.conv_post(x)
        # Final tanh activation
        if self.use_tanh_at_final:
            x = torch.tanh(x)
        else:
            x = torch.clamp(x, min=-1.0, max=1.0)  # Bound the output to [-1, 1]

        return x

    def remove_weight_norm(self):
        try:
            print("Removing weight norm...")
            for l in self.ups:
                for l_i in l:
                    remove_weight_norm(l_i)
            for l in self.resblocks:
                l.remove_weight_norm()
            remove_weight_norm(self.conv_pre)
            remove_weight_norm(self.conv_post)
        except ValueError:
            print("[INFO] Model already removed weight norm. Skipping!")
            pass

    # Additional methods for huggingface_hub support
    def _save_pretrained(self, save_directory: Path) -> None:
        """Save weights and config.json from a Pytorch model to a local directory."""

        model_path = save_directory / "bigvgan_generator.pt"
        torch.save({"generator": self.state_dict()}, model_path)

        config_path = save_directory / "config.json"
        with open(config_path, "w") as config_file:
            json.dump(self.h, config_file, indent=4)

    @classmethod
    def _from_pretrained(
        cls,
        *,
        model_id: str,
        revision: str,
        cache_dir: str,
        force_download: bool,
        proxies: Optional[Dict],
        resume_download: bool,
        local_files_only: bool,
        token: Union[str, bool, None],
        map_location: str = "cpu",  # Additional argument
        strict: bool = False,  # Additional argument
        use_cuda_kernel: bool = False,
        **model_kwargs,
    ):
        """Load Pytorch pretrained weights and return the loaded model."""

        # Download and load hyperparameters (h) used by BigVGAN
        if os.path.isdir(model_id):
            print("Loading config.json from local directory")
            config_file = os.path.join(model_id, "config.json")
        else:
            config_file = hf_hub_download(
                repo_id=model_id,
                filename="config.json",
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                token=token,
                local_files_only=local_files_only,
            )
        h = load_hparams_from_json(config_file)

        # Instantiate BigVGAN using h
        if use_cuda_kernel:
            print(
                "[WARNING] You have specified use_cuda_kernel=True during BigVGAN.from_pretrained(). Only inference is supported (training is not implemented)!"
            )
            print(
                "[WARNING] You need nvcc and ninja installed on your system that match the PyTorch build you are using to build the kernel. If not, the model will fail to initialize or generate incorrect waveform!"
            )
            print(
                "[WARNING] For detail, see the official GitHub repository: https://github.com/NVIDIA/BigVGAN?tab=readme-ov-file#using-custom-cuda-kernel-for-synthesis"
            )
        model = cls(h, use_cuda_kernel=use_cuda_kernel)

        # Download and load pretrained generator weight
        if os.path.isdir(model_id):
            print("Loading weights from local directory")
            model_file = os.path.join(model_id, "bigvgan_generator.pt")
        else:
            print(f"Loading weights from {model_id}")
            model_file = hf_hub_download(
                repo_id=model_id,
                filename="bigvgan_generator.pt",
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                token=token,
                local_files_only=local_files_only,
            )

        checkpoint_dict = torch.load(model_file, map_location=map_location)

        try:
            model.load_state_dict(checkpoint_dict["generator"])
        except RuntimeError:
            print(
                "[INFO] The pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!"
            )
            model.remove_weight_norm()
            model.load_state_dict(checkpoint_dict["generator"])

        return model
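
A minimal inference sketch for the class above, not part of the diff: `from_pretrained` is provided by `PyTorchModelHubMixin` and routes into `_from_pretrained` shown above. The repo id and mel length are assumptions for illustration; any checkpoint whose config matches `h` works.

import torch
from bigvgan import BigVGAN

# Assumed Hub repo id (22 kHz, 80 mel bands, 256x upsampling); swap in your own checkpoint.
model = BigVGAN.from_pretrained("nvidia/bigvgan_v2_22khz_80band_fmax8k_256x")
model.remove_weight_norm()
model.eval()

mel = torch.randn(1, model.h.num_mels, 200)  # [B, num_mels, frames] -- dummy input
with torch.no_grad():
    wav = model(mel)  # [B, 1, frames * prod(h.upsample_rates)]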
common.py
ADDED
@@ -0,0 +1,1083 @@
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: MIT
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

# 1x1InvertibleConv and WN based on implementation from WaveGlow https://github.com/NVIDIA/waveglow/blob/master/glow.py
# Original license:
# *****************************************************************************
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#     * Redistributions of source code must retain the above copyright
#       notice, this list of conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above copyright
#       notice, this list of conditions and the following disclaimer in the
#       documentation and/or other materials provided with the distribution.
#     * Neither the name of the NVIDIA CORPORATION nor the
#       names of its contributors may be used to endorse or promote products
#       derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# *****************************************************************************

import torch
from torch import nn
from torch.nn import functional as F

import numpy as np
import ast

from splines import (
    piecewise_linear_transform,
    piecewise_linear_inverse_transform,
    unbounded_piecewise_quadratic_transform,
)
from partialconv1d import PartialConv1d as pconv1d
from typing import Tuple

use_cuda = torch.cuda.is_available()

if use_cuda:
    device = "cuda"
else:
    device = "cpu"


def update_params(config, params):
    for param in params:
        print(param)
        k, v = param.split("=")
        try:
            v = ast.literal_eval(v)
        except:
            pass

        k_split = k.split(".")
        if len(k_split) > 1:
            parent_k = k_split[0]
            cur_param = [".".join(k_split[1:]) + "=" + str(v)]
            update_params(config[parent_k], cur_param)
        elif k in config and len(k_split) == 1:
            print(f"overriding {k} with {v}")
            config[k] = v
        else:
            print("{}, {} params not updated".format(k, v))


def get_mask_from_lengths(lengths):
    """Constructs binary mask from a 1D torch tensor of input lengths

    Args:
        lengths (torch.tensor): 1D tensor
    Returns:
        mask (torch.tensor): num_sequences x max_length x 1 binary tensor
    """
    max_len = torch.max(lengths).item()
    if torch.cuda.is_available():
        ids = torch.arange(0, max_len, out=torch.cuda.LongTensor(max_len))
    else:
        ids = torch.arange(0, max_len, out=torch.LongTensor(max_len))
    mask = (ids < lengths.unsqueeze(1)).bool()
    return mask


class ExponentialClass(torch.nn.Module):
    def __init__(self):
        super(ExponentialClass, self).__init__()

    def forward(self, x):
        return torch.exp(x)


class LinearNorm(torch.nn.Module):
    def __init__(self, in_dim, out_dim, bias=True, w_init_gain="linear"):
        super(LinearNorm, self).__init__()
        self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)

        torch.nn.init.xavier_uniform_(
            self.linear_layer.weight, gain=torch.nn.init.calculate_gain(w_init_gain)
        )

    def forward(self, x):
        return self.linear_layer(x)


class ConvNorm(torch.nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=1,
        stride=1,
        padding=None,
        dilation=1,
        bias=True,
        w_init_gain="linear",
        use_partial_padding=False,
        use_weight_norm=False,
    ):
        super(ConvNorm, self).__init__()
        if padding is None:
            assert kernel_size % 2 == 1
            padding = int(dilation * (kernel_size - 1) / 2)
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.use_partial_padding = use_partial_padding
        self.use_weight_norm = use_weight_norm
        conv_fn = torch.nn.Conv1d
        if self.use_partial_padding:
            conv_fn = pconv1d
        self.conv = conv_fn(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )
        torch.nn.init.xavier_uniform_(
            self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain)
        )
        if self.use_weight_norm:
            self.conv = nn.utils.weight_norm(self.conv)

    def forward(self, signal, mask=None):
        if self.use_partial_padding:
            conv_signal = self.conv(signal, mask)
        else:
            conv_signal = self.conv(signal)
        if mask is not None:
            # always re-zero output if mask is
            # available to match zero-padding
            conv_signal = conv_signal * mask
        return conv_signal


class DenseLayer(nn.Module):
    def __init__(self, in_dim=1024, sizes=[1024, 1024]):
        super(DenseLayer, self).__init__()
        in_sizes = [in_dim] + sizes[:-1]
        self.layers = nn.ModuleList(
            [
                LinearNorm(in_size, out_size, bias=True)
                for (in_size, out_size) in zip(in_sizes, sizes)
            ]
        )

    def forward(self, x):
        for linear in self.layers:
            x = torch.tanh(linear(x))
        return x


class LengthRegulator(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, dur):
        output = []
        for x_i, dur_i in zip(x, dur):
            expanded = self.expand(x_i, dur_i)
            output.append(expanded)
        output = self.pad(output)
        return output

    def expand(self, x, dur):
        output = []
        for i, frame in enumerate(x):
            expanded_len = int(dur[i] + 0.5)
            expanded = frame.expand(expanded_len, -1)
            output.append(expanded)
        output = torch.cat(output, 0)
        return output

    def pad(self, x):
        output = []
        max_len = max([x[i].size(0) for i in range(len(x))])
        for i, seq in enumerate(x):
            padded = F.pad(seq, [0, 0, 0, max_len - seq.size(0)], "constant", 0.0)
            output.append(padded)
        output = torch.stack(output)
        return output


class ConvLSTMLinear(nn.Module):
    def __init__(
        self,
        in_dim,
        out_dim,
        n_layers=2,
        n_channels=256,
        kernel_size=3,
        p_dropout=0.1,
        lstm_type="bilstm",
        use_linear=True,
    ):
        super(ConvLSTMLinear, self).__init__()
        self.out_dim = out_dim
        self.lstm_type = lstm_type
        self.use_linear = use_linear
        self.dropout = nn.Dropout(p=p_dropout)

        convolutions = []
        for i in range(n_layers):
            conv_layer = ConvNorm(
                in_dim if i == 0 else n_channels,
                n_channels,
                kernel_size=kernel_size,
                stride=1,
                padding=int((kernel_size - 1) / 2),
                dilation=1,
                w_init_gain="relu",
            )
            conv_layer = torch.nn.utils.weight_norm(conv_layer.conv, name="weight")
            convolutions.append(conv_layer)

        self.convolutions = nn.ModuleList(convolutions)

        if not self.use_linear:
            n_channels = out_dim

        if self.lstm_type != "":
            use_bilstm = False
            lstm_channels = n_channels
            if self.lstm_type == "bilstm":
                use_bilstm = True
                lstm_channels = int(n_channels // 2)

            self.bilstm = nn.LSTM(
                n_channels, lstm_channels, 1, batch_first=True, bidirectional=use_bilstm
            )
            lstm_norm_fn_pntr = nn.utils.spectral_norm
            self.bilstm = lstm_norm_fn_pntr(self.bilstm, "weight_hh_l0")
            if self.lstm_type == "bilstm":
                self.bilstm = lstm_norm_fn_pntr(self.bilstm, "weight_hh_l0_reverse")

        if self.use_linear:
            self.dense = nn.Linear(n_channels, out_dim)

    def run_padded_sequence(self, context, lens):
        context_embedded = []
        for b_ind in range(context.size()[0]):  # TODO: speed up
            curr_context = context[b_ind : b_ind + 1, :, : lens[b_ind]].clone()
            for conv in self.convolutions:
                curr_context = self.dropout(F.relu(conv(curr_context)))
            context_embedded.append(curr_context[0].transpose(0, 1))
        context = torch.nn.utils.rnn.pad_sequence(context_embedded, batch_first=True)
        return context

    def run_unsorted_inputs(self, fn, context, lens):
        lens_sorted, ids_sorted = torch.sort(lens, descending=True)
        unsort_ids = [0] * lens.size(0)
        for i in range(len(ids_sorted)):
            unsort_ids[ids_sorted[i]] = i
        lens_sorted = lens_sorted.long().cpu()

        context = context[ids_sorted]
        context = nn.utils.rnn.pack_padded_sequence(
            context, lens_sorted, batch_first=True
        )
        context = fn(context)[0]
        context = nn.utils.rnn.pad_packed_sequence(context, batch_first=True)[0]

        # map back to original indices
        context = context[unsort_ids]
        return context

    def forward(self, context, lens):
        if context.size()[0] > 1:
            context = self.run_padded_sequence(context, lens)
            # to B, D, T
            context = context.transpose(1, 2)
        else:
            for conv in self.convolutions:
                context = self.dropout(F.relu(conv(context)))

        if self.lstm_type != "":
            context = context.transpose(1, 2)
            self.bilstm.flatten_parameters()
            if lens is not None:
                context = self.run_unsorted_inputs(self.bilstm, context, lens)
            else:
                context = self.bilstm(context)[0]
            context = context.transpose(1, 2)

        x_hat = context
        if self.use_linear:
            x_hat = self.dense(context.transpose(1, 2)).transpose(1, 2)

        return x_hat

    def infer(self, z, txt_enc, spk_emb):
        x_hat = self.forward(txt_enc, spk_emb)["x_hat"]
        x_hat = self.feature_processing.denormalize(x_hat)
        return x_hat


class Encoder(nn.Module):
    """Encoder module:
    - Three 1-d convolution banks
    - Bidirectional LSTM
    """

    def __init__(
        self,
        encoder_n_convolutions=3,
        encoder_embedding_dim=512,
        encoder_kernel_size=5,
        norm_fn=nn.BatchNorm1d,
        lstm_norm_fn=None,
    ):
        super(Encoder, self).__init__()

        convolutions = []
        for _ in range(encoder_n_convolutions):
            conv_layer = nn.Sequential(
                ConvNorm(
                    encoder_embedding_dim,
                    encoder_embedding_dim,
                    kernel_size=encoder_kernel_size,
                    stride=1,
                    padding=int((encoder_kernel_size - 1) / 2),
                    dilation=1,
                    w_init_gain="relu",
                    use_partial_padding=True,
                ),
                norm_fn(encoder_embedding_dim, affine=True),
            )
            convolutions.append(conv_layer)
        self.convolutions = nn.ModuleList(convolutions)

        self.lstm = nn.LSTM(
            encoder_embedding_dim,
            int(encoder_embedding_dim / 2),
            1,
            batch_first=True,
            bidirectional=True,
        )
        if lstm_norm_fn is not None:
            if "spectral" in lstm_norm_fn:
                print("Applying spectral norm to text encoder LSTM")
                lstm_norm_fn_pntr = torch.nn.utils.spectral_norm
            elif "weight" in lstm_norm_fn:
                print("Applying weight norm to text encoder LSTM")
                lstm_norm_fn_pntr = torch.nn.utils.weight_norm
            self.lstm = lstm_norm_fn_pntr(self.lstm, "weight_hh_l0")
            self.lstm = lstm_norm_fn_pntr(self.lstm, "weight_hh_l0_reverse")

    @torch.autocast(device, enabled=False)
    def forward(self, x, in_lens):
        """
        Args:
            x (torch.tensor): N x C x L padded input of text embeddings
            in_lens (torch.tensor): 1D tensor of sequence lengths
        """
        if x.size()[0] > 1:
            x_embedded = []
            for b_ind in range(x.size()[0]):  # TODO: improve speed
                curr_x = x[b_ind : b_ind + 1, :, : in_lens[b_ind]].clone()
                for conv in self.convolutions:
                    curr_x = F.dropout(F.relu(conv(curr_x)), 0.5, self.training)
                x_embedded.append(curr_x[0].transpose(0, 1))
            x = torch.nn.utils.rnn.pad_sequence(x_embedded, batch_first=True)
        else:
            for conv in self.convolutions:
                x = F.dropout(F.relu(conv(x)), 0.5, self.training)
            x = x.transpose(1, 2)

        # recent amp change -- change in_lens to int
        in_lens = in_lens.int().cpu()

        x = nn.utils.rnn.pack_padded_sequence(x, in_lens, batch_first=True)

        self.lstm.flatten_parameters()
        outputs, _ = self.lstm(x)

        outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)

        return outputs

    @torch.autocast(device, enabled=False)
    def infer(self, x):
        for conv in self.convolutions:
            x = F.dropout(F.relu(conv(x)), 0.5, self.training)

        x = x.transpose(1, 2)
        self.lstm.flatten_parameters()
        outputs, _ = self.lstm(x)

        return outputs


class Invertible1x1ConvLUS(torch.nn.Module):
    def __init__(self, c, cache_inverse=False):
        super(Invertible1x1ConvLUS, self).__init__()
        # Sample a random orthonormal matrix to initialize weights
        W = torch.linalg.qr(torch.FloatTensor(c, c).normal_())[0]
        # Ensure determinant is 1.0 not -1.0
        if torch.det(W) < 0:
            W[:, 0] = -1 * W[:, 0]
        p, lower, upper = torch.lu_unpack(*torch.lu(W))

        self.register_buffer("p", p)
        # diagonals of lower will always be 1s anyway
        lower = torch.tril(lower, -1)
        lower_diag = torch.diag(torch.eye(c, c))
        self.register_buffer("lower_diag", lower_diag)
        self.lower = nn.Parameter(lower)
        self.upper_diag = nn.Parameter(torch.diag(upper))
        self.upper = nn.Parameter(torch.triu(upper, 1))
        self.cache_inverse = cache_inverse

    @torch.autocast(device, enabled=False)
    def forward(self, z, inverse=False):
        U = torch.triu(self.upper, 1) + torch.diag(self.upper_diag)
        L = torch.tril(self.lower, -1) + torch.diag(self.lower_diag)
        W = torch.mm(self.p, torch.mm(L, U))
        if inverse:
            if not hasattr(self, "W_inverse"):
                # inverse computation
                W_inverse = W.float().inverse()
                if z.type() == "torch.cuda.HalfTensor":
                    W_inverse = W_inverse.half()

                self.W_inverse = W_inverse[..., None]
            z = F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0)
            if not self.cache_inverse:
                delattr(self, "W_inverse")
            return z
        else:
            W = W[..., None]
            z = F.conv1d(z, W, bias=None, stride=1, padding=0)
            log_det_W = torch.sum(torch.log(torch.abs(self.upper_diag)))
            return z, log_det_W


class Invertible1x1Conv(torch.nn.Module):
    """
    The layer outputs both the convolution, and the log determinant
    of its weight matrix. If inverse=True it does convolution with
    inverse
    """

    def __init__(self, c, cache_inverse=False):
        super(Invertible1x1Conv, self).__init__()
        self.conv = torch.nn.Conv1d(
            c, c, kernel_size=1, stride=1, padding=0, bias=False
        )

        # Sample a random orthonormal matrix to initialize weights
        W = torch.qr(torch.FloatTensor(c, c).normal_())[0]

        # Ensure determinant is 1.0 not -1.0
        if torch.det(W) < 0:
            W[:, 0] = -1 * W[:, 0]
        W = W.view(c, c, 1)
        self.conv.weight.data = W
        self.cache_inverse = cache_inverse

    def forward(self, z, inverse=False):
        # DO NOT apply n_of_groups, as it doesn't account for padded sequences
        W = self.conv.weight.squeeze()

        if inverse:
            if not hasattr(self, "W_inverse"):
                # Inverse computation
                W_inverse = W.float().inverse()
                if z.type() == "torch.cuda.HalfTensor":
                    W_inverse = W_inverse.half()

                self.W_inverse = W_inverse[..., None]
            z = F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0)
            if not self.cache_inverse:
                delattr(self, "W_inverse")
            return z
        else:
            # Forward computation
            log_det_W = torch.logdet(W).clone()
            z = self.conv(z)
            return z, log_det_W


class SimpleConvNet(torch.nn.Module):
    def __init__(
        self,
        n_mel_channels,
        n_context_dim,
        final_out_channels,
        n_layers=2,
        kernel_size=5,
        with_dilation=True,
        max_channels=1024,
        zero_init=True,
        use_partial_padding=True,
    ):
        super(SimpleConvNet, self).__init__()
        self.layers = torch.nn.ModuleList()
        self.n_layers = n_layers
        in_channels = n_mel_channels + n_context_dim
        out_channels = -1
        self.use_partial_padding = use_partial_padding
        for i in range(n_layers):
            dilation = 2**i if with_dilation else 1
            padding = int((kernel_size * dilation - dilation) / 2)
            out_channels = min(max_channels, in_channels * 2)
            self.layers.append(
                ConvNorm(
                    in_channels,
                    out_channels,
                    kernel_size=kernel_size,
                    stride=1,
                    padding=padding,
                    dilation=dilation,
                    bias=True,
                    w_init_gain="relu",
                    use_partial_padding=use_partial_padding,
                )
            )
            in_channels = out_channels

        self.last_layer = torch.nn.Conv1d(
            out_channels, final_out_channels, kernel_size=1
        )

        if zero_init:
            self.last_layer.weight.data *= 0
            self.last_layer.bias.data *= 0

    def forward(self, z_w_context, seq_lens: torch.Tensor = None):
        # seq_lens: tensor of sequence lengths
        # output should be b x n_mel_channels x z_w_context.shape(2)
        mask = None
        if seq_lens is not None:
            mask = get_mask_from_lengths(seq_lens).unsqueeze(1).float()

        for i in range(self.n_layers):
            z_w_context = self.layers[i](z_w_context, mask)
            z_w_context = torch.relu(z_w_context)

        z_w_context = self.last_layer(z_w_context)
        return z_w_context


class WN(torch.nn.Module):
    """
    Adapted from WN() module in WaveGlow with modifications to variable names
    """

    def __init__(
        self,
        n_in_channels,
        n_context_dim,
        n_layers,
        n_channels,
        kernel_size=5,
        affine_activation="softplus",
        use_partial_padding=True,
    ):
        super(WN, self).__init__()
        assert kernel_size % 2 == 1
        assert n_channels % 2 == 0
        self.n_layers = n_layers
        self.n_channels = n_channels
        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()
        start = torch.nn.Conv1d(n_in_channels + n_context_dim, n_channels, 1)
        start = torch.nn.utils.weight_norm(start, name="weight")
        self.start = start
        self.softplus = torch.nn.Softplus()
        self.affine_activation = affine_activation
        self.use_partial_padding = use_partial_padding
        # Initializing last layer to 0 makes the affine coupling layers
        # do nothing at first. This helps with training stability
        end = torch.nn.Conv1d(n_channels, 2 * n_in_channels, 1)
        end.weight.data.zero_()
        end.bias.data.zero_()
        self.end = end

        for i in range(n_layers):
            dilation = 2**i
            padding = int((kernel_size * dilation - dilation) / 2)
            in_layer = ConvNorm(
                n_channels,
                n_channels,
                kernel_size=kernel_size,
                dilation=dilation,
                padding=padding,
                use_partial_padding=use_partial_padding,
                use_weight_norm=True,
            )
            # in_layer = nn.Conv1d(n_channels, n_channels, kernel_size,
            #                      dilation=dilation, padding=padding)
            # in_layer = nn.utils.weight_norm(in_layer)
            self.in_layers.append(in_layer)
            res_skip_layer = nn.Conv1d(n_channels, n_channels, 1)
            res_skip_layer = nn.utils.weight_norm(res_skip_layer)
            self.res_skip_layers.append(res_skip_layer)

    def forward(
        self,
        forward_input: Tuple[torch.Tensor, torch.Tensor],
        seq_lens: torch.Tensor = None,
    ):
        z, context = forward_input
        z = torch.cat((z, context), 1)  # append context to z as well
        z = self.start(z)
        output = torch.zeros_like(z)
        mask = None
        if seq_lens is not None:
            mask = get_mask_from_lengths(seq_lens).unsqueeze(1).float()
        non_linearity = torch.relu
        if self.affine_activation == "softplus":
            non_linearity = self.softplus

        for i in range(self.n_layers):
            z = non_linearity(self.in_layers[i](z, mask))
            res_skip_acts = non_linearity(self.res_skip_layers[i](z))
            output = output + res_skip_acts

        output = self.end(output)  # [B, dim, seq_len]
        return output


# Affine Coupling Layers
class SplineTransformationLayerAR(torch.nn.Module):
    def __init__(
        self,
        n_in_channels,
        n_context_dim,
        n_layers,
        affine_model="simple_conv",
        kernel_size=1,
        scaling_fn="exp",
        affine_activation="softplus",
        n_channels=1024,
        n_bins=8,
        left=-6,
        right=6,
        bottom=-6,
        top=6,
        use_quadratic=False,
    ):
        super(SplineTransformationLayerAR, self).__init__()
        self.n_in_channels = n_in_channels  # input dimensions
        self.left = left
        self.right = right
        self.bottom = bottom
        self.top = top
        self.n_bins = n_bins
        self.spline_fn = piecewise_linear_transform
        self.inv_spline_fn = piecewise_linear_inverse_transform
        self.use_quadratic = use_quadratic

        if self.use_quadratic:
            self.spline_fn = unbounded_piecewise_quadratic_transform
            self.inv_spline_fn = unbounded_piecewise_quadratic_transform
            self.n_bins = 2 * self.n_bins + 1
        final_out_channels = self.n_in_channels * self.n_bins

        # autoregressive flow, kernel size 1 and no dilation
        self.param_predictor = SimpleConvNet(
            n_context_dim,
            0,
            final_out_channels,
            n_layers,
            with_dilation=False,
            kernel_size=1,
            zero_init=True,
            use_partial_padding=False,
        )

        # output is unnormalized bin weights

    def normalize(self, z, inverse):
        # normalize to [0, 1]
        if inverse:
            z = (z - self.bottom) / (self.top - self.bottom)
        else:
            z = (z - self.left) / (self.right - self.left)

        return z

    def denormalize(self, z, inverse):
        if inverse:
            z = z * (self.right - self.left) + self.left
        else:
            z = z * (self.top - self.bottom) + self.bottom

        return z

    def forward(self, z, context, inverse=False):
        b_s, c_s, t_s = z.size(0), z.size(1), z.size(2)

        z = self.normalize(z, inverse)

        if z.min() < 0.0 or z.max() > 1.0:
            print("spline z scaled beyond [0, 1]", z.min(), z.max())

        z_reshaped = z.permute(0, 2, 1).reshape(b_s * t_s, -1)
        affine_params = self.param_predictor(context)
        q_tilde = affine_params.permute(0, 2, 1).reshape(b_s * t_s, c_s, -1)
        with torch.autocast(device, enabled=False):
            if self.use_quadratic:
                w = q_tilde[:, :, : self.n_bins // 2]
                v = q_tilde[:, :, self.n_bins // 2 :]
                z_tformed, log_s = self.spline_fn(
                    z_reshaped.float(), w.float(), v.float(), inverse=inverse
                )
            else:
                z_tformed, log_s = self.spline_fn(z_reshaped.float(), q_tilde.float())

        z = z_tformed.reshape(b_s, t_s, -1).permute(0, 2, 1)
        z = self.denormalize(z, inverse)
        if inverse:
            return z

        log_s = log_s.reshape(b_s, t_s, -1)
        log_s = log_s.permute(0, 2, 1)
        log_s = log_s + c_s * (
            np.log(self.top - self.bottom) - np.log(self.right - self.left)
        )
        return z, log_s


class SplineTransformationLayer(torch.nn.Module):
    def __init__(
        self,
        n_mel_channels,
        n_context_dim,
        n_layers,
        with_dilation=True,
        kernel_size=5,
        scaling_fn="exp",
        affine_activation="softplus",
        n_channels=1024,
        n_bins=8,
        left=-4,
        right=4,
        bottom=-4,
        top=4,
        use_quadratic=False,
    ):
        super(SplineTransformationLayer, self).__init__()
        self.n_mel_channels = n_mel_channels  # input dimensions
        self.half_mel_channels = int(n_mel_channels / 2)  # half, because we split
        self.left = left
        self.right = right
        self.bottom = bottom
        self.top = top
        self.n_bins = n_bins
        self.spline_fn = piecewise_linear_transform
        self.inv_spline_fn = piecewise_linear_inverse_transform
        self.use_quadratic = use_quadratic

        if self.use_quadratic:
            self.spline_fn = unbounded_piecewise_quadratic_transform
            self.inv_spline_fn = unbounded_piecewise_quadratic_transform
            self.n_bins = 2 * self.n_bins + 1
        final_out_channels = self.half_mel_channels * self.n_bins

        self.param_predictor = SimpleConvNet(
            self.half_mel_channels,
            n_context_dim,
            final_out_channels,
            n_layers,
            with_dilation=with_dilation,
            kernel_size=kernel_size,
            zero_init=False,
        )

        # output is unnormalized bin weights

    def forward(self, z, context, inverse=False, seq_lens=None):
        b_s, c_s, t_s = z.size(0), z.size(1), z.size(2)

        # condition on z_0, transform z_1
        n_half = self.half_mel_channels
        z_0, z_1 = z[:, :n_half], z[:, n_half:]

        # normalize to [0,1]
        if inverse:
            z_1 = (z_1 - self.bottom) / (self.top - self.bottom)
        else:
            z_1 = (z_1 - self.left) / (self.right - self.left)

        z_w_context = torch.cat((z_0, context), 1)
        affine_params = self.param_predictor(z_w_context, seq_lens)
        z_1_reshaped = z_1.permute(0, 2, 1).reshape(b_s * t_s, -1)
        q_tilde = affine_params.permute(0, 2, 1).reshape(b_s * t_s, n_half, self.n_bins)

        with torch.autocast(device, enabled=False):
            if self.use_quadratic:
                w = q_tilde[:, :, : self.n_bins // 2]
                v = q_tilde[:, :, self.n_bins // 2 :]
                z_1_tformed, log_s = self.spline_fn(
                    z_1_reshaped.float(), w.float(), v.float(), inverse=inverse
                )
                if not inverse:
                    log_s = torch.sum(log_s, 1)
            else:
                if inverse:
                    z_1_tformed, _dc = self.inv_spline_fn(
                        z_1_reshaped.float(), q_tilde.float(), False
                    )
                else:
                    z_1_tformed, log_s = self.spline_fn(
                        z_1_reshaped.float(), q_tilde.float()
                    )

        z_1 = z_1_tformed.reshape(b_s, t_s, -1).permute(0, 2, 1)

        # undo [0, 1] normalization
        if inverse:
            z_1 = z_1 * (self.right - self.left) + self.left
            z = torch.cat((z_0, z_1), dim=1)
            return z
        else:  # training
            z_1 = z_1 * (self.top - self.bottom) + self.bottom
            z = torch.cat((z_0, z_1), dim=1)
            log_s = log_s.reshape(b_s, t_s).unsqueeze(1) + n_half * (
                np.log(self.top - self.bottom) - np.log(self.right - self.left)
            )
            return z, log_s


class AffineTransformationLayer(torch.nn.Module):
    def __init__(
        self,
        n_mel_channels,
        n_context_dim,
        n_layers,
        affine_model="simple_conv",
        with_dilation=True,
        kernel_size=5,
        scaling_fn="exp",
        affine_activation="softplus",
        n_channels=1024,
        use_partial_padding=False,
    ):
        super(AffineTransformationLayer, self).__init__()
        if affine_model not in ("wavenet", "simple_conv"):
            raise Exception("{} affine model not supported".format(affine_model))
        if isinstance(scaling_fn, list):
            if not all(
                [x in ("translate", "exp", "tanh", "sigmoid") for x in scaling_fn]
            ):
                raise Exception("{} scaling fn not supported".format(scaling_fn))
        else:
            if scaling_fn not in ("translate", "exp", "tanh", "sigmoid"):
                raise Exception("{} scaling fn not supported".format(scaling_fn))

        self.affine_model = affine_model
        self.scaling_fn = scaling_fn
        if affine_model == "wavenet":
            self.affine_param_predictor = WN(
                int(n_mel_channels / 2),
                n_context_dim,
                n_layers=n_layers,
                n_channels=n_channels,
                affine_activation=affine_activation,
                use_partial_padding=use_partial_padding,
            )
        elif affine_model == "simple_conv":
            self.affine_param_predictor = SimpleConvNet(
                int(n_mel_channels / 2),
                n_context_dim,
                n_mel_channels,
                n_layers,
                with_dilation=with_dilation,
                kernel_size=kernel_size,
                use_partial_padding=use_partial_padding,
            )
        self.n_mel_channels = n_mel_channels

    def get_scaling_and_logs(self, scale_unconstrained):
        if self.scaling_fn == "translate":
            s = torch.exp(scale_unconstrained * 0)
            log_s = scale_unconstrained * 0
        elif self.scaling_fn == "exp":
            s = torch.exp(scale_unconstrained)
            log_s = scale_unconstrained  # log(exp
        elif self.scaling_fn == "tanh":
            s = torch.tanh(scale_unconstrained) + 1 + 1e-6
            log_s = torch.log(s)
        elif self.scaling_fn == "sigmoid":
            s = torch.sigmoid(scale_unconstrained + 10) + 1e-6
            log_s = torch.log(s)
        elif isinstance(self.scaling_fn, list):
            s_list, log_s_list = [], []
            for i in range(scale_unconstrained.shape[1]):
                scaling_i = self.scaling_fn[i]
                if scaling_i == "translate":
                    s_i = torch.exp(scale_unconstrained[:, i] * 0)
                    log_s_i = scale_unconstrained[:, i] * 0
                elif scaling_i == "exp":
                    s_i = torch.exp(scale_unconstrained[:, i])
                    log_s_i = scale_unconstrained[:, i]
                elif scaling_i == "tanh":
                    s_i = torch.tanh(scale_unconstrained[:, i]) + 1 + 1e-6
                    log_s_i = torch.log(s_i)
                elif scaling_i == "sigmoid":
                    s_i = torch.sigmoid(scale_unconstrained[:, i]) + 1e-6
                    log_s_i = torch.log(s_i)
                s_list.append(s_i[:, None])
                log_s_list.append(log_s_i[:, None])
            s = torch.cat(s_list, dim=1)
            log_s = torch.cat(log_s_list, dim=1)
        return s, log_s

    def forward(self, z, context, inverse=False, seq_lens=None):
        n_half = int(self.n_mel_channels / 2)
        z_0, z_1 = z[:, :n_half], z[:, n_half:]
        if self.affine_model == "wavenet":
            affine_params = self.affine_param_predictor(
                (z_0, context), seq_lens=seq_lens
            )
        elif self.affine_model == "simple_conv":
            z_w_context = torch.cat((z_0, context), 1)
            affine_params = self.affine_param_predictor(z_w_context, seq_lens=seq_lens)

        scale_unconstrained = affine_params[:, :n_half, :]
        b = affine_params[:, n_half:, :]
        s, log_s = self.get_scaling_and_logs(scale_unconstrained)

        if inverse:
            z_1 = (z_1 - b) / s
            z = torch.cat((z_0, z_1), dim=1)
            return z
        else:
            z_1 = s * z_1 + b
            z = torch.cat((z_0, z_1), dim=1)
            return z, log_s


class ConvAttention(torch.nn.Module):
    def __init__(
        self, n_mel_channels=80, n_text_channels=512, n_att_channels=80, temperature=1.0
    ):
        super(ConvAttention, self).__init__()
        self.temperature = temperature
        self.softmax = torch.nn.Softmax(dim=3)
        self.log_softmax = torch.nn.LogSoftmax(dim=3)

        self.key_proj = nn.Sequential(
            ConvNorm(
                n_text_channels,
                n_text_channels * 2,
                kernel_size=3,
                bias=True,
                w_init_gain="relu",
            ),
            torch.nn.ReLU(),
            ConvNorm(n_text_channels * 2, n_att_channels, kernel_size=1, bias=True),
        )

        self.query_proj = nn.Sequential(
            ConvNorm(
                n_mel_channels,
                n_mel_channels * 2,
                kernel_size=3,
                bias=True,
                w_init_gain="relu",
            ),
            torch.nn.ReLU(),
            ConvNorm(n_mel_channels * 2, n_mel_channels, kernel_size=1, bias=True),
            torch.nn.ReLU(),
            ConvNorm(n_mel_channels, n_att_channels, kernel_size=1, bias=True),
        )

    def run_padded_sequence(
        self, sorted_idx, unsort_idx, lens, padded_data, recurrent_model
    ):
        """Sorts input data by provided ordering (and un-ordering) and runs the
        packed data through the recurrent model

        Args:
            sorted_idx (torch.tensor): 1D sorting index
            unsort_idx (torch.tensor): 1D unsorting index (inverse of sorted_idx)
            lens: lengths of input data (sorted in descending order)
            padded_data (torch.tensor): input sequences (padded)
            recurrent_model (nn.Module): recurrent model to run data through
        Returns:
            hidden_vectors (torch.tensor): outputs of the RNN, in the original,
            unsorted, ordering
        """

        # sort the data by decreasing length using provided index
        # we assume batch index is in dim=1
        padded_data = padded_data[:, sorted_idx]
        padded_data = nn.utils.rnn.pack_padded_sequence(padded_data, lens)
        hidden_vectors = recurrent_model(padded_data)[0]
        hidden_vectors, _ = nn.utils.rnn.pad_packed_sequence(hidden_vectors)
        # unsort the results at dim=1 and return
        hidden_vectors = hidden_vectors[:, unsort_idx]
        return hidden_vectors

    def forward(
        self, queries, keys, query_lens, mask=None, key_lens=None, attn_prior=None
    ):
        """Attention mechanism for radtts. Unlike in Flowtron, we have no
        restrictions such as causality etc, since we only need this during
        training.

        Args:
            queries (torch.tensor): B x C x T1 tensor (likely mel data)
            keys (torch.tensor): B x C2 x T2 tensor (text data)
            query_lens: lengths for sorting the queries in descending order
            mask (torch.tensor): uint8 binary mask for variable length entries
                (should be in the T2 domain)
        Output:
            attn (torch.tensor): B x 1 x T1 x T2 attention mask.
                Final dim T2 should sum to 1
        """
        temp = 0.0005
        keys_enc = self.key_proj(keys)  # B x n_attn_dims x T2
        # Beware can only do this since query_dim = attn_dim = n_mel_channels
        queries_enc = self.query_proj(queries)

        # Gaussian Isotropic Attention
        # B x n_attn_dims x T1 x T2
        attn = (queries_enc[:, :, :, None] - keys_enc[:, :, None]) ** 2

        # compute log-likelihood from gaussian
        eps = 1e-8
        attn = -temp * attn.sum(1, keepdim=True)
        if attn_prior is not None:
            attn = self.log_softmax(attn) + torch.log(attn_prior[:, None] + eps)

        attn_logprob = attn.clone()

        if mask is not None:
            attn.data.masked_fill_(mask.permute(0, 2, 1).unsqueeze(2), -float("inf"))

        attn = self.softmax(attn)  # softmax along T2
        return attn, attn_logprob
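
A quick illustrative sketch (not part of the repo, just an assumption about typical usage): get_mask_from_lengths builds the boolean padding mask that ConvNorm, WN and SimpleConvNet consume after unsqueezing it to [B, 1, T].

import torch
from common import get_mask_from_lengths

lengths = torch.tensor([3, 5, 2])
mask = get_mask_from_lengths(lengths)   # [3, 5] boolean, True on valid frames
mask = mask.unsqueeze(1).float()        # [3, 1, 5], the shape used inside WN / SimpleConvNet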
configs/bigvgan_config.json
ADDED
@@ -0,0 +1,63 @@
1 |
+
{
|
2 |
+
"resblock": "1",
|
3 |
+
"num_gpus": 0,
|
4 |
+
"batch_size": 32,
|
5 |
+
"learning_rate": 0.0001,
|
6 |
+
"adam_b1": 0.8,
|
7 |
+
"adam_b2": 0.99,
|
8 |
+
"lr_decay": 0.9999996,
|
9 |
+
"seed": 1234,
|
10 |
+
|
11 |
+
"upsample_rates": [4,4,2,2,2,2],
|
12 |
+
"upsample_kernel_sizes": [8,8,4,4,4,4],
|
13 |
+
"upsample_initial_channel": 1536,
|
14 |
+
"resblock_kernel_sizes": [3,7,11],
|
15 |
+
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
|
16 |
+
|
17 |
+
"use_tanh_at_final": false,
|
18 |
+
"use_bias_at_final": false,
|
19 |
+
|
20 |
+
"activation": "snakebeta",
|
21 |
+
"snake_logscale": true,
|
22 |
+
|
23 |
+
"use_cqtd_instead_of_mrd": true,
|
24 |
+
"cqtd_filters": 128,
|
25 |
+
"cqtd_max_filters": 1024,
|
26 |
+
"cqtd_filters_scale": 1,
|
27 |
+
"cqtd_dilations": [1, 2, 4],
|
28 |
+
"cqtd_hop_lengths": [512, 256, 256],
|
29 |
+
"cqtd_n_octaves": [9, 9, 9],
|
30 |
+
"cqtd_bins_per_octaves": [24, 36, 48],
|
31 |
+
|
32 |
+
"mpd_reshapes": [2, 3, 5, 7, 11],
|
33 |
+
"use_spectral_norm": false,
|
34 |
+
"discriminator_channel_mult": 1,
|
35 |
+
|
36 |
+
"use_multiscale_melloss": true,
|
37 |
+
"lambda_melloss": 15,
|
38 |
+
|
39 |
+
"clip_grad_norm": 500,
|
40 |
+
|
41 |
+
"segment_size": 65536,
|
42 |
+
"num_mels": 80,
|
43 |
+
"num_freq": 1025,
|
44 |
+
"n_fft": 1024,
|
45 |
+
"hop_size": 256,
|
46 |
+
"win_size": 1024,
|
47 |
+
|
48 |
+
"sampling_rate": 22050,
|
49 |
+
|
50 |
+
"fmin": 0,
|
51 |
+
"fmax": 8000,
|
52 |
+
"fmax_for_loss": null,
|
53 |
+
|
54 |
+
"normalize_volume": true,
|
55 |
+
|
56 |
+
"num_workers": 4,
|
57 |
+
|
58 |
+
"dist_config": {
|
59 |
+
"dist_backend": "nccl",
|
60 |
+
"dist_url": "tcp://localhost:54321",
|
61 |
+
"world_size": 1
|
62 |
+
}
|
63 |
+
}
|
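This JSON follows the HiFi-GAN/BigVGAN hyperparameter layout: generator upsampling schedule, discriminator settings, and the 22.05 kHz / 80-mel analysis parameters shared with the acoustic model. A hedged sketch of reading it into an attribute-style object, the pattern HiFi-GAN-derived vocoder code typically expects; the exact constructor call used by bigvgan.py is not shown here, so treat this as illustrative only:

import json
from types import SimpleNamespace

with open("configs/bigvgan_config.json") as f:
    h = SimpleNamespace(**json.load(f))          # h.num_mels, h.hop_size, ...

print(h.sampling_rate, h.num_mels, h.hop_size)   # 22050 80 256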
configs/radtts-pp-dap-model.json
ADDED
@@ -0,0 +1,218 @@
1 |
+
{
|
2 |
+
"train_config": {
|
3 |
+
"output_directory": "outdir_pp_model",
|
4 |
+
"epochs": 10000000,
|
5 |
+
"optim_algo": "RAdam",
|
6 |
+
"learning_rate": 0.001,
|
7 |
+
"weight_decay": 1e-06,
|
8 |
+
"sigma": 1.0,
|
9 |
+
"iters_per_checkpoint": 1000,
|
10 |
+
"batch_size": 16,
|
11 |
+
"seed": null,
|
12 |
+
"checkpoint_path": "",
|
13 |
+
"ignore_layers": [],
|
14 |
+
"ignore_layers_warmstart": [],
|
15 |
+
"finetune_layers": [],
|
16 |
+
"include_layers": [],
|
17 |
+
"vocoder_config_path": "models/hifigan_22khz_config.json",
|
18 |
+
"vocoder_checkpoint_path": "models/hifigan_ljs_generator_v1.pt",
|
19 |
+
"log_attribute_samples": true,
|
20 |
+
"log_decoder_samples": true,
|
21 |
+
"warmstart_checkpoint_path": "outdir_pp/model_100000",
|
22 |
+
"use_amp": true,
|
23 |
+
"grad_clip_val": 1.0,
|
24 |
+
"loss_weights": {
|
25 |
+
"blank_logprob": -1,
|
26 |
+
"ctc_loss_weight": 0.1,
|
27 |
+
"binarization_loss_weight": 1.0,
|
28 |
+
"dur_loss_weight": 1.0,
|
29 |
+
"f0_loss_weight": 1.0,
|
30 |
+
"energy_loss_weight": 1.0,
|
31 |
+
"vpred_loss_weight": 1.0
|
32 |
+
},
|
33 |
+
"binarization_start_iter": 0,
|
34 |
+
"kl_loss_start_iter": 0,
|
35 |
+
"unfreeze_modules": "all"
|
36 |
+
},
|
37 |
+
"data_config": {
|
38 |
+
"training_files": {
|
39 |
+
"LJS": {
|
40 |
+
"basedir": "filelists/",
|
41 |
+
"audiodir": "wavs",
|
42 |
+
"filelist": "3speakers_ukrainian_train_filelist_dc.txt",
|
43 |
+
"lmdbpath": ""
|
44 |
+
}
|
45 |
+
},
|
46 |
+
"validation_files": {
|
47 |
+
"LJS": {
|
48 |
+
"basedir": "filelists/",
|
49 |
+
"audiodir": "wavs",
|
50 |
+
"filelist": "3speakers_ukrainian_val_filelist_dc.txt",
|
51 |
+
"lmdbpath": ""
|
52 |
+
}
|
53 |
+
},
|
54 |
+
"dur_min": 0.1,
|
55 |
+
"dur_max": 10.2,
|
56 |
+
"sampling_rate": 22050,
|
57 |
+
"filter_length": 1024,
|
58 |
+
"hop_length": 256,
|
59 |
+
"win_length": 1024,
|
60 |
+
"n_mel_channels": 80,
|
61 |
+
"mel_fmin": 0.0,
|
62 |
+
"mel_fmax": 8000.0,
|
63 |
+
"f0_min": 80.0,
|
64 |
+
"f0_max": 640.0,
|
65 |
+
"max_wav_value": 32768.0,
|
66 |
+
"use_f0": true,
|
67 |
+
"use_log_f0": 0,
|
68 |
+
"use_energy_avg": true,
|
69 |
+
"use_scaled_energy": true,
|
70 |
+
"symbol_set": "ukrainian",
|
71 |
+
"cleaner_names": [
|
72 |
+
"ukrainian_cleaners"
|
73 |
+
],
|
74 |
+
"heteronyms_path": "tts_text_processing/heteronyms",
|
75 |
+
"phoneme_dict_path": "tts_text_processing/cmudict-0.7b",
|
76 |
+
"p_phoneme": 0.0,
|
77 |
+
"handle_phoneme": "word",
|
78 |
+
"handle_phoneme_ambiguous": "ignore",
|
79 |
+
"include_speakers": null,
|
80 |
+
"n_frames": -1,
|
81 |
+
"betabinom_cache_path": "/home/dmytro_chaplinsky/RAD-TTS/radtts-code/cache",
|
82 |
+
"lmdb_cache_path": "",
|
83 |
+
"use_attn_prior_masking": true,
|
84 |
+
"prepend_space_to_text": true,
|
85 |
+
"append_space_to_text": true,
|
86 |
+
"add_bos_eos_to_text": false,
|
87 |
+
"betabinom_scaling_factor": 1.0,
|
88 |
+
"distance_tx_unvoiced": false,
|
89 |
+
"mel_noise_scale": 0.0
|
90 |
+
},
|
91 |
+
"dist_config": {
|
92 |
+
"dist_backend": "nccl",
|
93 |
+
"dist_url": "tcp://localhost:54321"
|
94 |
+
},
|
95 |
+
"model_config": {
|
96 |
+
"n_speakers": 3,
|
97 |
+
"n_speaker_dim": 16,
|
98 |
+
"n_text": 185,
|
99 |
+
"n_text_dim": 512,
|
100 |
+
"n_flows": 8,
|
101 |
+
"n_conv_layers_per_step": 4,
|
102 |
+
"n_mel_channels": 80,
|
103 |
+
"n_hidden": 1024,
|
104 |
+
"mel_encoder_n_hidden": 512,
|
105 |
+
"dummy_speaker_embedding": false,
|
106 |
+
"n_early_size": 2,
|
107 |
+
"n_early_every": 2,
|
108 |
+
"n_group_size": 2,
|
109 |
+
"affine_model": "wavenet",
|
110 |
+
"include_modules": "decatndpmvpredapm",
|
111 |
+
"scaling_fn": "tanh",
|
112 |
+
"matrix_decomposition": "LUS",
|
113 |
+
"learn_alignments": true,
|
114 |
+
"use_speaker_emb_for_alignment": false,
|
115 |
+
"attn_straight_through_estimator": true,
|
116 |
+
"use_context_lstm": true,
|
117 |
+
"context_lstm_norm": "spectral",
|
118 |
+
"context_lstm_w_f0_and_energy": true,
|
119 |
+
"text_encoder_lstm_norm": "spectral",
|
120 |
+
"n_f0_dims": 1,
|
121 |
+
"n_energy_avg_dims": 1,
|
122 |
+
"use_first_order_features": false,
|
123 |
+
"unvoiced_bias_activation": "relu",
|
124 |
+
"decoder_use_partial_padding": true,
|
125 |
+
"decoder_use_unvoiced_bias": true,
|
126 |
+
"ap_pred_log_f0": true,
|
127 |
+
"ap_use_unvoiced_bias": false,
|
128 |
+
"ap_use_voiced_embeddings": true,
|
129 |
+
"dur_model_config": {
|
130 |
+
"name": "dap",
|
131 |
+
"hparams": {
|
132 |
+
"n_speaker_dim": 16,
|
133 |
+
"bottleneck_hparams": {
|
134 |
+
"in_dim": 512,
|
135 |
+
"reduction_factor": 16,
|
136 |
+
"norm": "weightnorm",
|
137 |
+
"non_linearity": "relu"
|
138 |
+
},
|
139 |
+
"take_log_of_input": true,
|
140 |
+
"arch_hparams": {
|
141 |
+
"out_dim": 1,
|
142 |
+
"n_layers": 2,
|
143 |
+
"n_channels": 256,
|
144 |
+
"kernel_size": 3,
|
145 |
+
"p_dropout": 0.25,
|
146 |
+
"in_dim": 48
|
147 |
+
}
|
148 |
+
}
|
149 |
+
},
|
150 |
+
"f0_model_config": {
|
151 |
+
"name": "dap",
|
152 |
+
"hparams": {
|
153 |
+
"n_speaker_dim": 16,
|
154 |
+
"bottleneck_hparams": {
|
155 |
+
"in_dim": 512,
|
156 |
+
"reduction_factor": 16,
|
157 |
+
"norm": "weightnorm",
|
158 |
+
"non_linearity": "relu"
|
159 |
+
},
|
160 |
+
"take_log_of_input": false,
|
161 |
+
"use_transformer": false,
|
162 |
+
"arch_hparams": {
|
163 |
+
"out_dim": 1,
|
164 |
+
"n_layers": 2,
|
165 |
+
"n_channels": 256,
|
166 |
+
"kernel_size": 11,
|
167 |
+
"p_dropout": 0.5,
|
168 |
+
"in_dim": 48
|
169 |
+
}
|
170 |
+
}
|
171 |
+
},
|
172 |
+
"energy_model_config": {
|
173 |
+
"name": "dap",
|
174 |
+
"hparams": {
|
175 |
+
"n_speaker_dim": 16,
|
176 |
+
"bottleneck_hparams": {
|
177 |
+
"in_dim": 512,
|
178 |
+
"reduction_factor": 16,
|
179 |
+
"norm": "weightnorm",
|
180 |
+
"non_linearity": "relu"
|
181 |
+
},
|
182 |
+
"take_log_of_input": false,
|
183 |
+
"use_transformer": false,
|
184 |
+
"arch_hparams": {
|
185 |
+
"out_dim": 1,
|
186 |
+
"n_layers": 2,
|
187 |
+
"n_channels": 256,
|
188 |
+
"kernel_size": 3,
|
189 |
+
"p_dropout": 0.25,
|
190 |
+
"in_dim": 48
|
191 |
+
}
|
192 |
+
}
|
193 |
+
},
|
194 |
+
"v_model_config": {
|
195 |
+
"name": "dap",
|
196 |
+
"hparams": {
|
197 |
+
"n_speaker_dim": 16,
|
198 |
+
"take_log_of_input": false,
|
199 |
+
"bottleneck_hparams": {
|
200 |
+
"in_dim": 512,
|
201 |
+
"reduction_factor": 16,
|
202 |
+
"norm": "weightnorm",
|
203 |
+
"non_linearity": "relu"
|
204 |
+
},
|
205 |
+
"arch_hparams": {
|
206 |
+
"out_dim": 1,
|
207 |
+
"n_layers": 2,
|
208 |
+
"n_channels": 256,
|
209 |
+
"kernel_size": 3,
|
210 |
+
"p_dropout": 0.5,
|
211 |
+
"lstm_type": "",
|
212 |
+
"use_linear": 1,
|
213 |
+
"in_dim": 48
|
214 |
+
}
|
215 |
+
}
|
216 |
+
}
|
217 |
+
}
|
218 |
+
}
|
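The RAD-TTS++ config bundles four top-level sections: train_config (optimizer, checkpointing, loss weights), data_config (filelists, mel/F0 extraction, Ukrainian text processing), dist_config, and model_config (the flow decoder plus the duration/F0/energy/voicing attribute predictors). A small sketch of inspecting it, using only values visible above:

import json

with open("configs/radtts-pp-dap-model.json") as f:
    config = json.load(f)

print(sorted(config))
# ['data_config', 'dist_config', 'model_config', 'train_config']
print(config["model_config"]["n_speakers"],      # 3
      config["data_config"]["sampling_rate"])    # 22050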
data.py
ADDED
@@ -0,0 +1,606 @@
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: MIT
|
3 |
+
#
|
4 |
+
# Permission is hereby granted, free of charge, to any person obtaining a
|
5 |
+
# copy of this software and associated documentation files (the "Software"),
|
6 |
+
# to deal in the Software without restriction, including without limitation
|
7 |
+
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
8 |
+
# and/or sell copies of the Software, and to permit persons to whom the
|
9 |
+
# Software is furnished to do so, subject to the following conditions:
|
10 |
+
#
|
11 |
+
# The above copyright notice and this permission notice shall be included in
|
12 |
+
# all copies or substantial portions of the Software.
|
13 |
+
#
|
14 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
15 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
16 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
17 |
+
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
18 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
19 |
+
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
20 |
+
# DEALINGS IN THE SOFTWARE.
|
21 |
+
|
22 |
+
# Based on https://github.com/NVIDIA/flowtron/blob/master/data.py
|
23 |
+
# Original license text:
|
24 |
+
###############################################################################
|
25 |
+
#
|
26 |
+
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
27 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
28 |
+
# you may not use this file except in compliance with the License.
|
29 |
+
# You may obtain a copy of the License at
|
30 |
+
#
|
31 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
32 |
+
#
|
33 |
+
# Unless required by applicable law or agreed to in writing, software
|
34 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
35 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
36 |
+
# See the License for the specific language governing permissions and
|
37 |
+
# limitations under the License.
|
38 |
+
#
|
39 |
+
###############################################################################
|
40 |
+
|
41 |
+
import os
|
42 |
+
import argparse
|
43 |
+
import json
|
44 |
+
import numpy as np
|
45 |
+
import lmdb
|
46 |
+
import pickle as pkl
|
47 |
+
import torch
|
48 |
+
import torch.utils.data
|
49 |
+
from scipy.io.wavfile import read
|
50 |
+
from audio_processing import TacotronSTFT
|
51 |
+
from tts_text_processing.text_processing import TextProcessing
|
52 |
+
from scipy.stats import betabinom
|
53 |
+
from librosa import pyin
|
54 |
+
from common import update_params
|
55 |
+
from scipy.ndimage import distance_transform_edt as distance_transform
|
56 |
+
|
57 |
+
|
58 |
+
def beta_binomial_prior_distribution(phoneme_count, mel_count, scaling_factor=0.05):
|
59 |
+
P = phoneme_count
|
60 |
+
M = mel_count
|
61 |
+
x = np.arange(0, P)
|
62 |
+
mel_text_probs = []
|
63 |
+
for i in range(1, M + 1):
|
64 |
+
a, b = scaling_factor * i, scaling_factor * (M + 1 - i)
|
65 |
+
rv = betabinom(P - 1, a, b)
|
66 |
+
mel_i_prob = rv.pmf(x)
|
67 |
+
mel_text_probs.append(mel_i_prob)
|
68 |
+
return torch.tensor(np.array(mel_text_probs))
|
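beta_binomial_prior_distribution builds the static alignment prior used during attention training: for each of the M mel frames it places a Beta-Binomial pmf over the P text positions, whose mode drifts from the first token towards the last as the frame index grows, softly encouraging monotonic alignments. A quick sanity check, assuming data.py is importable:

from data import beta_binomial_prior_distribution

prior = beta_binomial_prior_distribution(phoneme_count=20, mel_count=50)
print(prior.shape)            # torch.Size([50, 20])
print(float(prior[0].sum()))  # ~1.0, each row is a pmf over the 20 tokens
print(prior.argmax(dim=1))    # climbs from 0 towards 19 across the 50 frames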
69 |
+
|
70 |
+
|
71 |
+
def load_wav_to_torch(full_path):
|
72 |
+
"""Loads wavdata into torch array"""
|
73 |
+
sampling_rate, data = read(full_path)
|
74 |
+
return torch.from_numpy(np.array(data)).float(), sampling_rate
|
75 |
+
|
76 |
+
|
77 |
+
class Data(torch.utils.data.Dataset):
|
78 |
+
def __init__(
|
79 |
+
self,
|
80 |
+
datasets,
|
81 |
+
filter_length,
|
82 |
+
hop_length,
|
83 |
+
win_length,
|
84 |
+
sampling_rate,
|
85 |
+
n_mel_channels,
|
86 |
+
mel_fmin,
|
87 |
+
mel_fmax,
|
88 |
+
f0_min,
|
89 |
+
f0_max,
|
90 |
+
max_wav_value,
|
91 |
+
use_f0,
|
92 |
+
use_energy_avg,
|
93 |
+
use_log_f0,
|
94 |
+
use_scaled_energy,
|
95 |
+
symbol_set,
|
96 |
+
cleaner_names,
|
97 |
+
heteronyms_path,
|
98 |
+
phoneme_dict_path,
|
99 |
+
p_phoneme,
|
100 |
+
handle_phoneme="word",
|
101 |
+
handle_phoneme_ambiguous="ignore",
|
102 |
+
speaker_ids=None,
|
103 |
+
include_speakers=None,
|
104 |
+
n_frames=-1,
|
105 |
+
use_attn_prior_masking=True,
|
106 |
+
prepend_space_to_text=True,
|
107 |
+
append_space_to_text=True,
|
108 |
+
add_bos_eos_to_text=False,
|
109 |
+
betabinom_cache_path="",
|
110 |
+
betabinom_scaling_factor=0.05,
|
111 |
+
lmdb_cache_path="",
|
112 |
+
dur_min=None,
|
113 |
+
dur_max=None,
|
114 |
+
combine_speaker_and_emotion=False,
|
115 |
+
**kwargs,
|
116 |
+
):
|
117 |
+
self.combine_speaker_and_emotion = combine_speaker_and_emotion
|
118 |
+
self.max_wav_value = max_wav_value
|
119 |
+
self.audio_lmdb_dict = {} # dictionary of lmdbs for audio data
|
120 |
+
self.data = self.load_data(datasets)
|
121 |
+
self.distance_tx_unvoiced = False
|
122 |
+
if "distance_tx_unvoiced" in kwargs.keys():
|
123 |
+
self.distance_tx_unvoiced = kwargs["distance_tx_unvoiced"]
|
124 |
+
self.stft = TacotronSTFT(
|
125 |
+
filter_length=filter_length,
|
126 |
+
hop_length=hop_length,
|
127 |
+
win_length=win_length,
|
128 |
+
sampling_rate=sampling_rate,
|
129 |
+
n_mel_channels=n_mel_channels,
|
130 |
+
mel_fmin=mel_fmin,
|
131 |
+
mel_fmax=mel_fmax,
|
132 |
+
)
|
133 |
+
|
134 |
+
self.do_mel_scaling = kwargs.get("do_mel_scaling", True)
|
135 |
+
self.mel_noise_scale = kwargs.get("mel_noise_scale", 0.0)
|
136 |
+
self.filter_length = filter_length
|
137 |
+
self.hop_length = hop_length
|
138 |
+
self.win_length = win_length
|
139 |
+
self.mel_fmin = mel_fmin
|
140 |
+
self.mel_fmax = mel_fmax
|
141 |
+
self.f0_min = f0_min
|
142 |
+
self.f0_max = f0_max
|
143 |
+
self.use_f0 = use_f0
|
144 |
+
self.use_log_f0 = use_log_f0
|
145 |
+
self.use_energy_avg = use_energy_avg
|
146 |
+
self.use_scaled_energy = use_scaled_energy
|
147 |
+
self.sampling_rate = sampling_rate
|
148 |
+
self.tp = TextProcessing(
|
149 |
+
symbol_set,
|
150 |
+
cleaner_names,
|
151 |
+
heteronyms_path,
|
152 |
+
phoneme_dict_path,
|
153 |
+
p_phoneme=p_phoneme,
|
154 |
+
handle_phoneme=handle_phoneme,
|
155 |
+
handle_phoneme_ambiguous=handle_phoneme_ambiguous,
|
156 |
+
prepend_space_to_text=prepend_space_to_text,
|
157 |
+
append_space_to_text=append_space_to_text,
|
158 |
+
add_bos_eos_to_text=add_bos_eos_to_text,
|
159 |
+
)
|
160 |
+
|
161 |
+
self.dur_min = dur_min
|
162 |
+
self.dur_max = dur_max
|
163 |
+
if speaker_ids is None or speaker_ids == "":
|
164 |
+
self.speaker_ids = self.create_speaker_lookup_table(self.data)
|
165 |
+
else:
|
166 |
+
self.speaker_ids = speaker_ids
|
167 |
+
|
168 |
+
print("Number of files", len(self.data))
|
169 |
+
if include_speakers is not None:
|
170 |
+
for speaker_set, include in include_speakers:
|
171 |
+
self.filter_by_speakers_(speaker_set, include)
|
172 |
+
print("Number of files after speaker filtering", len(self.data))
|
173 |
+
|
174 |
+
if dur_min is not None and dur_max is not None:
|
175 |
+
self.filter_by_duration_(dur_min, dur_max)
|
176 |
+
print("Number of files after duration filtering", len(self.data))
|
177 |
+
|
178 |
+
self.use_attn_prior_masking = bool(use_attn_prior_masking)
|
179 |
+
self.prepend_space_to_text = bool(prepend_space_to_text)
|
180 |
+
self.append_space_to_text = bool(append_space_to_text)
|
181 |
+
self.betabinom_cache_path = betabinom_cache_path
|
182 |
+
self.betabinom_scaling_factor = betabinom_scaling_factor
|
183 |
+
self.lmdb_cache_path = lmdb_cache_path
|
184 |
+
if self.lmdb_cache_path != "":
|
185 |
+
self.cache_data_lmdb = lmdb.open(
|
186 |
+
self.lmdb_cache_path, readonly=True, max_readers=1024, lock=False
|
187 |
+
).begin()
|
188 |
+
|
189 |
+
# # make sure caching path exists
|
190 |
+
# if not os.path.exists(self.betabinom_cache_path):
|
191 |
+
# os.makedirs(self.betabinom_cache_path)
|
192 |
+
|
193 |
+
print("Dataloader initialized with no augmentations")
|
194 |
+
self.speaker_map = None
|
195 |
+
if "speaker_map" in kwargs:
|
196 |
+
self.speaker_map = kwargs["speaker_map"]
|
197 |
+
|
198 |
+
def load_data(self, datasets, split="|"):
|
199 |
+
dataset = []
|
200 |
+
for dset_name, dset_dict in datasets.items():
|
201 |
+
folder_path = dset_dict["basedir"]
|
202 |
+
audiodir = dset_dict["audiodir"]
|
203 |
+
filename = dset_dict["filelist"]
|
204 |
+
audio_lmdb_key = None
|
205 |
+
if "lmdbpath" in dset_dict.keys() and len(dset_dict["lmdbpath"]) > 0:
|
206 |
+
self.audio_lmdb_dict[dset_name] = lmdb.open(
|
207 |
+
dset_dict["lmdbpath"], readonly=True, max_readers=256, lock=False
|
208 |
+
).begin()
|
209 |
+
audio_lmdb_key = dset_name
|
210 |
+
|
211 |
+
wav_folder_prefix = os.path.join(folder_path, audiodir)
|
212 |
+
filelist_path = os.path.join(folder_path, filename)
|
213 |
+
with open(filelist_path, encoding="utf-8") as f:
|
214 |
+
data = [line.strip().split(split) for line in f]
|
215 |
+
|
216 |
+
for d in data:
|
217 |
+
emotion = "other" if len(d) == 3 else d[3]
|
218 |
+
duration = -1 if len(d) == 3 else d[4]
|
219 |
+
dataset.append(
|
220 |
+
{
|
221 |
+
"audiopath": os.path.join(wav_folder_prefix, d[0]),
|
222 |
+
"text": d[1],
|
223 |
+
"speaker": d[2] + "-" + emotion
|
224 |
+
if self.combine_speaker_and_emotion
|
225 |
+
else d[2],
|
226 |
+
"emotion": emotion,
|
227 |
+
"duration": float(duration),
|
228 |
+
"lmdb_key": audio_lmdb_key,
|
229 |
+
}
|
230 |
+
)
|
231 |
+
return dataset
|
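load_data expects pipe-delimited filelists: three fields (audiopath|text|speaker), as in the Ukrainian lists later in this diff, or five fields with emotion and duration appended. The "+" characters inside the text appear to be stress marks and are passed through to the text processor unchanged. A tiny parsing sketch (path shortened for the example):

line = "datasets/lada/accept/48849.wav|мандрівник+и вп+ерто відмовл+ялися.|lada"
fields = line.strip().split("|")
entry = {
    "audiopath": fields[0],
    "text": fields[1],
    "speaker": fields[2],
    "emotion": "other" if len(fields) == 3 else fields[3],
    "duration": -1 if len(fields) == 3 else float(fields[4]),
}
print(entry["speaker"])   # lada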
232 |
+
|
233 |
+
def filter_by_speakers_(self, speakers, include=True):
|
234 |
+
print("Include spaker {}: {}".format(speakers, include))
|
235 |
+
if include:
|
236 |
+
self.data = [x for x in self.data if x["speaker"] in speakers]
|
237 |
+
else:
|
238 |
+
self.data = [x for x in self.data if x["speaker"] not in speakers]
|
239 |
+
|
240 |
+
def filter_by_duration_(self, dur_min, dur_max):
|
241 |
+
self.data = [
|
242 |
+
x
|
243 |
+
for x in self.data
|
244 |
+
if x["duration"] == -1
|
245 |
+
or (x["duration"] >= dur_min and x["duration"] <= dur_max)
|
246 |
+
]
|
247 |
+
|
248 |
+
def create_speaker_lookup_table(self, data):
|
249 |
+
speaker_ids = np.sort(np.unique([x["speaker"] for x in data]))
|
250 |
+
d = {speaker_ids[i]: i for i in range(len(speaker_ids))}
|
251 |
+
print("Number of speakers:", len(d))
|
252 |
+
print("Speaker IDS", d)
|
253 |
+
return d
|
254 |
+
|
255 |
+
def f0_normalize(self, x):
|
256 |
+
if self.use_log_f0:
|
257 |
+
mask = x >= self.f0_min
|
258 |
+
x[mask] = torch.log(x[mask])
|
259 |
+
x[~mask] = 0.0
|
260 |
+
|
261 |
+
return x
|
262 |
+
|
263 |
+
def f0_denormalize(self, x):
|
264 |
+
if self.use_log_f0:
|
265 |
+
log_f0_min = np.log(self.f0_min)
|
266 |
+
mask = x >= log_f0_min
|
267 |
+
x[mask] = torch.exp(x[mask])
|
268 |
+
x[~mask] = 0.0
|
269 |
+
x[x <= 0.0] = 0.0
|
270 |
+
|
271 |
+
return x
|
272 |
+
|
273 |
+
def energy_avg_normalize(self, x):
|
274 |
+
if self.use_scaled_energy:
|
275 |
+
x = (x + 20.0) / 20.0
|
276 |
+
return x
|
277 |
+
|
278 |
+
def energy_avg_denormalize(self, x):
|
279 |
+
if self.use_scaled_energy:
|
280 |
+
x = x * 20.0 - 20.0
|
281 |
+
return x
|
282 |
+
|
283 |
+
def get_f0_pvoiced(
|
284 |
+
self,
|
285 |
+
audio,
|
286 |
+
sampling_rate=22050,
|
287 |
+
frame_length=1024,
|
288 |
+
hop_length=256,
|
289 |
+
f0_min=100,
|
290 |
+
f0_max=300,
|
291 |
+
):
|
292 |
+
audio_norm = audio / self.max_wav_value
|
293 |
+
f0, voiced_mask, p_voiced = pyin(
|
294 |
+
audio_norm,
|
295 |
+
f0_min,
|
296 |
+
f0_max,
|
297 |
+
sampling_rate,
|
298 |
+
frame_length=frame_length,
|
299 |
+
win_length=frame_length // 2,
|
300 |
+
hop_length=hop_length,
|
301 |
+
)
|
302 |
+
f0[~voiced_mask] = 0.0
|
303 |
+
f0 = torch.FloatTensor(f0)
|
304 |
+
p_voiced = torch.FloatTensor(p_voiced)
|
305 |
+
voiced_mask = torch.FloatTensor(voiced_mask)
|
306 |
+
return f0, voiced_mask, p_voiced
|
307 |
+
|
308 |
+
def get_energy_average(self, mel):
|
309 |
+
energy_avg = mel.mean(0)
|
310 |
+
energy_avg = self.energy_avg_normalize(energy_avg)
|
311 |
+
return energy_avg
|
312 |
+
|
313 |
+
def get_mel(self, audio):
|
314 |
+
audio_norm = audio / self.max_wav_value
|
315 |
+
audio_norm = audio_norm.unsqueeze(0)
|
316 |
+
audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
|
317 |
+
melspec = self.stft.mel_spectrogram(audio_norm)
|
318 |
+
melspec = torch.squeeze(melspec, 0)
|
319 |
+
if self.do_mel_scaling:
|
320 |
+
melspec = (melspec + 5.5) / 2
|
321 |
+
if self.mel_noise_scale > 0:
|
322 |
+
melspec += torch.randn_like(melspec) * self.mel_noise_scale
|
323 |
+
return melspec
|
324 |
+
|
325 |
+
def get_speaker_id(self, speaker):
|
326 |
+
if self.speaker_map is not None and speaker in self.speaker_map:
|
327 |
+
speaker = self.speaker_map[speaker]
|
328 |
+
|
329 |
+
return torch.LongTensor([self.speaker_ids[speaker]])
|
330 |
+
|
331 |
+
def get_text(self, text):
|
332 |
+
text = self.tp.encode_text(text)
|
333 |
+
text = torch.LongTensor(text)
|
334 |
+
return text
|
335 |
+
|
336 |
+
def get_attention_prior(self, n_tokens, n_frames):
|
337 |
+
# cache the entire attn_prior by filename
|
338 |
+
if self.use_attn_prior_masking:
|
339 |
+
filename = "{}_{}".format(n_tokens, n_frames)
|
340 |
+
prior_path = os.path.join(self.betabinom_cache_path, filename)
|
341 |
+
prior_path += "_prior.pth"
|
342 |
+
if self.lmdb_cache_path != "":
|
343 |
+
attn_prior = pkl.loads(
|
344 |
+
self.cache_data_lmdb.get(prior_path.encode("ascii"))
|
345 |
+
)
|
346 |
+
elif os.path.exists(prior_path):
|
347 |
+
attn_prior = torch.load(prior_path)
|
348 |
+
else:
|
349 |
+
attn_prior = beta_binomial_prior_distribution(
|
350 |
+
n_tokens, n_frames, self.betabinom_scaling_factor
|
351 |
+
)
|
352 |
+
torch.save(attn_prior, prior_path)
|
353 |
+
else:
|
354 |
+
attn_prior = torch.ones(n_frames, n_tokens) # all ones baseline
|
355 |
+
|
356 |
+
return attn_prior
|
357 |
+
|
358 |
+
def __getitem__(self, index):
|
359 |
+
data = self.data[index]
|
360 |
+
audiopath, text = data["audiopath"], data["text"]
|
361 |
+
speaker_id = data["speaker"]
|
362 |
+
|
363 |
+
if data["lmdb_key"] is not None:
|
364 |
+
data_dict = pkl.loads(
|
365 |
+
self.audio_lmdb_dict[data["lmdb_key"]].get(audiopath.encode("ascii"))
|
366 |
+
)
|
367 |
+
audio = data_dict["audio"]
|
368 |
+
sampling_rate = data_dict["sampling_rate"]
|
369 |
+
else:
|
370 |
+
audio, sampling_rate = load_wav_to_torch(audiopath)
|
371 |
+
|
372 |
+
if sampling_rate != self.sampling_rate:
|
373 |
+
raise ValueError(
|
374 |
+
"{} SR doesn't match target {} SR".format(
|
375 |
+
sampling_rate, self.sampling_rate
|
376 |
+
)
|
377 |
+
)
|
378 |
+
|
379 |
+
mel = self.get_mel(audio)
|
380 |
+
f0 = None
|
381 |
+
p_voiced = None
|
382 |
+
voiced_mask = None
|
383 |
+
if self.use_f0:
|
384 |
+
filename = "_".join(audiopath.split("/")[-3:])
|
385 |
+
f0_path = os.path.join(self.betabinom_cache_path, filename)
|
386 |
+
f0_path += "_f0_sr{}_fl{}_hl{}_f0min{}_f0max{}_log{}.pt".format(
|
387 |
+
self.sampling_rate,
|
388 |
+
self.filter_length,
|
389 |
+
self.hop_length,
|
390 |
+
self.f0_min,
|
391 |
+
self.f0_max,
|
392 |
+
self.use_log_f0,
|
393 |
+
)
|
394 |
+
|
395 |
+
dikt = None
|
396 |
+
if len(self.lmdb_cache_path) > 0:
|
397 |
+
dikt = pkl.loads(self.cache_data_lmdb.get(f0_path.encode("ascii")))
|
398 |
+
f0 = dikt["f0"]
|
399 |
+
p_voiced = dikt["p_voiced"]
|
400 |
+
voiced_mask = dikt["voiced_mask"]
|
401 |
+
elif os.path.exists(f0_path):
|
402 |
+
try:
|
403 |
+
dikt = torch.load(f0_path)
|
404 |
+
except Exception:
|
405 |
+
print(f"f0 loading from {f0_path} is broken, recomputing.")
|
406 |
+
|
407 |
+
if dikt is not None:
|
408 |
+
f0 = dikt["f0"]
|
409 |
+
p_voiced = dikt["p_voiced"]
|
410 |
+
voiced_mask = dikt["voiced_mask"]
|
411 |
+
else:
|
412 |
+
f0, voiced_mask, p_voiced = self.get_f0_pvoiced(
|
413 |
+
audio.cpu().numpy(),
|
414 |
+
self.sampling_rate,
|
415 |
+
self.filter_length,
|
416 |
+
self.hop_length,
|
417 |
+
self.f0_min,
|
418 |
+
self.f0_max,
|
419 |
+
)
|
420 |
+
print("saving f0 to {}".format(f0_path))
|
421 |
+
torch.save(
|
422 |
+
{"f0": f0, "voiced_mask": voiced_mask, "p_voiced": p_voiced},
|
423 |
+
f0_path,
|
424 |
+
)
|
425 |
+
if f0 is None:
|
426 |
+
raise Exception("STOP, BROKEN F0 {}".format(audiopath))
|
427 |
+
|
428 |
+
f0 = self.f0_normalize(f0)
|
429 |
+
if self.distance_tx_unvoiced:
|
430 |
+
mask = f0 <= 0.0
|
431 |
+
distance_map = np.log(distance_transform(mask))
|
432 |
+
distance_map[distance_map <= 0] = 0.0
|
433 |
+
f0 = f0 - distance_map
|
434 |
+
|
435 |
+
energy_avg = None
|
436 |
+
if self.use_energy_avg:
|
437 |
+
energy_avg = self.get_energy_average(mel)
|
438 |
+
if self.use_scaled_energy and energy_avg.min() < 0.0:
|
439 |
+
print(audiopath, "has scaled energy avg smaller than 0")
|
440 |
+
|
441 |
+
speaker_id = self.get_speaker_id(speaker_id)
|
442 |
+
text_encoded = self.get_text(text)
|
443 |
+
|
444 |
+
attn_prior = self.get_attention_prior(text_encoded.shape[0], mel.shape[1])
|
445 |
+
|
446 |
+
if not self.use_attn_prior_masking:
|
447 |
+
attn_prior = None
|
448 |
+
|
449 |
+
return {
|
450 |
+
"mel": mel,
|
451 |
+
"speaker_id": speaker_id,
|
452 |
+
"text_encoded": text_encoded,
|
453 |
+
"audiopath": audiopath,
|
454 |
+
"attn_prior": attn_prior,
|
455 |
+
"f0": f0,
|
456 |
+
"p_voiced": p_voiced,
|
457 |
+
"voiced_mask": voiced_mask,
|
458 |
+
"energy_avg": energy_avg,
|
459 |
+
}
|
460 |
+
|
461 |
+
def __len__(self):
|
462 |
+
return len(self.data)
|
463 |
+
|
464 |
+
|
465 |
+
class DataCollate:
|
466 |
+
"""Zero-pads model inputs and targets given number of steps"""
|
467 |
+
|
468 |
+
def __init__(self, n_frames_per_step=1):
|
469 |
+
self.n_frames_per_step = n_frames_per_step
|
470 |
+
|
471 |
+
def __call__(self, batch):
|
472 |
+
"""Collate from normalized data"""
|
473 |
+
# Right zero-pad all one-hot text sequences to max input length
|
474 |
+
input_lengths, ids_sorted_decreasing = torch.sort(
|
475 |
+
torch.LongTensor([len(x["text_encoded"]) for x in batch]),
|
476 |
+
dim=0,
|
477 |
+
descending=True,
|
478 |
+
)
|
479 |
+
|
480 |
+
max_input_len = input_lengths[0]
|
481 |
+
text_padded = torch.LongTensor(len(batch), max_input_len)
|
482 |
+
text_padded.zero_()
|
483 |
+
|
484 |
+
for i in range(len(ids_sorted_decreasing)):
|
485 |
+
text = batch[ids_sorted_decreasing[i]]["text_encoded"]
|
486 |
+
text_padded[i, : text.size(0)] = text
|
487 |
+
|
488 |
+
# Right zero-pad mel-spec
|
489 |
+
num_mel_channels = batch[0]["mel"].size(0)
|
490 |
+
max_target_len = max([x["mel"].size(1) for x in batch])
|
491 |
+
|
492 |
+
# include mel padded, gate padded and speaker ids
|
493 |
+
mel_padded = torch.FloatTensor(len(batch), num_mel_channels, max_target_len)
|
494 |
+
mel_padded.zero_()
|
495 |
+
f0_padded = None
|
496 |
+
p_voiced_padded = None
|
497 |
+
voiced_mask_padded = None
|
498 |
+
energy_avg_padded = None
|
499 |
+
if batch[0]["f0"] is not None:
|
500 |
+
f0_padded = torch.FloatTensor(len(batch), max_target_len)
|
501 |
+
f0_padded.zero_()
|
502 |
+
|
503 |
+
if batch[0]["p_voiced"] is not None:
|
504 |
+
p_voiced_padded = torch.FloatTensor(len(batch), max_target_len)
|
505 |
+
p_voiced_padded.zero_()
|
506 |
+
|
507 |
+
if batch[0]["voiced_mask"] is not None:
|
508 |
+
voiced_mask_padded = torch.FloatTensor(len(batch), max_target_len)
|
509 |
+
voiced_mask_padded.zero_()
|
510 |
+
|
511 |
+
if batch[0]["energy_avg"] is not None:
|
512 |
+
energy_avg_padded = torch.FloatTensor(len(batch), max_target_len)
|
513 |
+
energy_avg_padded.zero_()
|
514 |
+
|
515 |
+
attn_prior_padded = torch.FloatTensor(len(batch), max_target_len, max_input_len)
|
516 |
+
attn_prior_padded.zero_()
|
517 |
+
|
518 |
+
output_lengths = torch.LongTensor(len(batch))
|
519 |
+
speaker_ids = torch.LongTensor(len(batch))
|
520 |
+
audiopaths = []
|
521 |
+
for i in range(len(ids_sorted_decreasing)):
|
522 |
+
mel = batch[ids_sorted_decreasing[i]]["mel"]
|
523 |
+
mel_padded[i, :, : mel.size(1)] = mel
|
524 |
+
if batch[ids_sorted_decreasing[i]]["f0"] is not None:
|
525 |
+
f0 = batch[ids_sorted_decreasing[i]]["f0"]
|
526 |
+
f0_padded[i, : len(f0)] = f0
|
527 |
+
|
528 |
+
if batch[ids_sorted_decreasing[i]]["voiced_mask"] is not None:
|
529 |
+
voiced_mask = batch[ids_sorted_decreasing[i]]["voiced_mask"]
|
530 |
+
voiced_mask_padded[i, : len(f0)] = voiced_mask
|
531 |
+
|
532 |
+
if batch[ids_sorted_decreasing[i]]["p_voiced"] is not None:
|
533 |
+
p_voiced = batch[ids_sorted_decreasing[i]]["p_voiced"]
|
534 |
+
p_voiced_padded[i, : len(f0)] = p_voiced
|
535 |
+
|
536 |
+
if batch[ids_sorted_decreasing[i]]["energy_avg"] is not None:
|
537 |
+
energy_avg = batch[ids_sorted_decreasing[i]]["energy_avg"]
|
538 |
+
energy_avg_padded[i, : len(energy_avg)] = energy_avg
|
539 |
+
|
540 |
+
output_lengths[i] = mel.size(1)
|
541 |
+
speaker_ids[i] = batch[ids_sorted_decreasing[i]]["speaker_id"]
|
542 |
+
audiopath = batch[ids_sorted_decreasing[i]]["audiopath"]
|
543 |
+
audiopaths.append(audiopath)
|
544 |
+
cur_attn_prior = batch[ids_sorted_decreasing[i]]["attn_prior"]
|
545 |
+
if cur_attn_prior is None:
|
546 |
+
attn_prior_padded = None
|
547 |
+
else:
|
548 |
+
attn_prior_padded[
|
549 |
+
i, : cur_attn_prior.size(0), : cur_attn_prior.size(1)
|
550 |
+
] = cur_attn_prior
|
551 |
+
|
552 |
+
return {
|
553 |
+
"mel": mel_padded,
|
554 |
+
"speaker_ids": speaker_ids,
|
555 |
+
"text": text_padded,
|
556 |
+
"input_lengths": input_lengths,
|
557 |
+
"output_lengths": output_lengths,
|
558 |
+
"audiopaths": audiopaths,
|
559 |
+
"attn_prior": attn_prior_padded,
|
560 |
+
"f0": f0_padded,
|
561 |
+
"p_voiced": p_voiced_padded,
|
562 |
+
"voiced_mask": voiced_mask_padded,
|
563 |
+
"energy_avg": energy_avg_padded,
|
564 |
+
}
|
565 |
+
|
566 |
+
|
567 |
+
# ===================================================================
|
568 |
+
# Takes directory of clean audio and makes directory of spectrograms
|
569 |
+
# Useful for making test sets
|
570 |
+
# ===================================================================
|
571 |
+
if __name__ == "__main__":
|
572 |
+
# Get defaults so it can work with no Sacred
|
573 |
+
parser = argparse.ArgumentParser()
|
574 |
+
parser.add_argument("-c", "--config", type=str, help="JSON file for configuration")
|
575 |
+
parser.add_argument("-p", "--params", nargs="+", default=[])
|
576 |
+
args = parser.parse_args()
|
577 |
+
args.rank = 0
|
578 |
+
|
579 |
+
# Parse configs. Globals nicer in this case
|
580 |
+
with open(args.config) as f:
|
581 |
+
data = f.read()
|
582 |
+
|
583 |
+
config = json.loads(data)
|
584 |
+
update_params(config, args.params)
|
585 |
+
print(config)
|
586 |
+
|
587 |
+
data_config = config["data_config"]
|
588 |
+
|
589 |
+
ignore_keys = ["training_files", "validation_files"]
|
590 |
+
trainset = Data(
|
591 |
+
data_config["training_files"],
|
592 |
+
**dict((k, v) for k, v in data_config.items() if k not in ignore_keys),
|
593 |
+
)
|
594 |
+
|
595 |
+
valset = Data(
|
596 |
+
data_config["validation_files"],
|
597 |
+
**dict((k, v) for k, v in data_config.items() if k not in ignore_keys),
|
598 |
+
speaker_ids=trainset.speaker_ids,
|
599 |
+
)
|
600 |
+
|
601 |
+
collate_fn = DataCollate()
|
602 |
+
|
603 |
+
for dataset in (trainset, valset):
|
604 |
+
for i, batch in enumerate(dataset):
|
605 |
+
out = batch
|
606 |
+
print("{}/{}".format(i, len(dataset)))
|
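The __main__ block above only iterates the datasets item by item; for training, Data and DataCollate are meant to be combined through a PyTorch DataLoader so that text, mel, F0, voicing and energy are right-padded per batch. A minimal sketch reusing the trainset built above (batch size and options are illustrative):

from torch.utils.data import DataLoader

train_loader = DataLoader(
    trainset,
    batch_size=4,
    shuffle=True,
    collate_fn=DataCollate(),
    drop_last=True,
)
batch = next(iter(train_loader))
print(batch["mel"].shape)    # (4, n_mel_channels, max_target_len)
print(batch["text"].shape)   # (4, max_input_len)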
distributed.py
ADDED
@@ -0,0 +1,161 @@
1 |
+
# Original source: https://github.com/NVIDIA/waveglow/blob/master/distributed.py
|
2 |
+
#
|
3 |
+
# Original license text:
|
4 |
+
# *****************************************************************************
|
5 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
6 |
+
#
|
7 |
+
# Redistribution and use in source and binary forms, with or without
|
8 |
+
# modification, are permitted provided that the following conditions are met:
|
9 |
+
# * Redistributions of source code must retain the above copyright
|
10 |
+
# notice, this list of conditions and the following disclaimer.
|
11 |
+
# * Redistributions in binary form must reproduce the above copyright
|
12 |
+
# notice, this list of conditions and the following disclaimer in the
|
13 |
+
# documentation and/or other materials provided with the distribution.
|
14 |
+
# * Neither the name of the NVIDIA CORPORATION nor the
|
15 |
+
# names of its contributors may be used to endorse or promote products
|
16 |
+
# derived from this software without specific prior written permission.
|
17 |
+
#
|
18 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
19 |
+
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
20 |
+
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
21 |
+
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
22 |
+
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
23 |
+
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
24 |
+
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
25 |
+
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
26 |
+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
27 |
+
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
28 |
+
#
|
29 |
+
# *****************************************************************************
|
30 |
+
|
31 |
+
import os
|
32 |
+
import torch
|
33 |
+
import torch.distributed as dist
|
34 |
+
from torch.autograd import Variable
|
35 |
+
|
36 |
+
|
37 |
+
def reduce_tensor(tensor, num_gpus, reduce_dst=None):
|
38 |
+
if num_gpus <= 1: # pass-thru
|
39 |
+
return tensor
|
40 |
+
rt = tensor.clone()
|
41 |
+
if reduce_dst is not None:
|
42 |
+
dist.reduce(rt, reduce_dst, op=dist.ReduceOp.SUM)
|
43 |
+
else:
|
44 |
+
dist.all_reduce(rt, op=dist.ReduceOp.SUM)
|
45 |
+
rt /= num_gpus
|
46 |
+
return rt
|
47 |
+
|
48 |
+
|
49 |
+
def init_distributed(rank, num_gpus, dist_backend, dist_url):
|
50 |
+
assert torch.cuda.is_available(), "Distributed mode requires CUDA."
|
51 |
+
|
52 |
+
print("> initializing distributed for rank {} out of {}".format(rank, num_gpus))
|
53 |
+
|
54 |
+
# Set cuda device so everything is done on the right GPU.
|
55 |
+
torch.cuda.set_device(rank % torch.cuda.device_count())
|
56 |
+
|
57 |
+
init_method = "tcp://"
|
58 |
+
master_ip = os.getenv("MASTER_ADDR", "localhost")
|
59 |
+
master_port = os.getenv("MASTER_PORT", "6000")
|
60 |
+
init_method += master_ip + ":" + master_port
|
61 |
+
torch.distributed.init_process_group(
|
62 |
+
backend="nccl", world_size=num_gpus, rank=rank, init_method=init_method
|
63 |
+
)
|
64 |
+
|
65 |
+
|
66 |
+
def _flatten_dense_tensors(tensors):
|
67 |
+
"""Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of
|
68 |
+
same dense type.
|
69 |
+
Since inputs are dense, the resulting tensor will be a concatenated 1D
|
70 |
+
buffer. Element-wise operation on this buffer will be equivalent to
|
71 |
+
operating individually.
|
72 |
+
Arguments:
|
73 |
+
tensors (Iterable[Tensor]): dense tensors to flatten.
|
74 |
+
Returns:
|
75 |
+
A contiguous 1D buffer containing input tensors.
|
76 |
+
"""
|
77 |
+
if len(tensors) == 1:
|
78 |
+
return tensors[0].contiguous().view(-1)
|
79 |
+
flat = torch.cat([t.contiguous().view(-1) for t in tensors], dim=0)
|
80 |
+
return flat
|
81 |
+
|
82 |
+
|
83 |
+
def _unflatten_dense_tensors(flat, tensors):
|
84 |
+
"""View a flat buffer using the sizes of tensors. Assume that tensors are of
|
85 |
+
same dense type, and that flat is given by _flatten_dense_tensors.
|
86 |
+
Arguments:
|
87 |
+
flat (Tensor): flattened dense tensors to unflatten.
|
88 |
+
tensors (Iterable[Tensor]): dense tensors whose sizes will be used to
|
89 |
+
unflatten flat.
|
90 |
+
Returns:
|
91 |
+
Unflattened dense tensors with sizes same as tensors and values from
|
92 |
+
flat.
|
93 |
+
"""
|
94 |
+
outputs = []
|
95 |
+
offset = 0
|
96 |
+
for tensor in tensors:
|
97 |
+
numel = tensor.numel()
|
98 |
+
outputs.append(flat.narrow(0, offset, numel).view_as(tensor))
|
99 |
+
offset += numel
|
100 |
+
return tuple(outputs)
|
101 |
+
|
102 |
+
|
103 |
+
def apply_gradient_allreduce(module):
|
104 |
+
"""
|
105 |
+
Modifies existing model to do gradient allreduce, but doesn't change class
|
106 |
+
so you don't need "module"
|
107 |
+
"""
|
108 |
+
if not hasattr(dist, "_backend"):
|
109 |
+
module.warn_on_half = True
|
110 |
+
else:
|
111 |
+
module.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False
|
112 |
+
|
113 |
+
for p in module.state_dict().values():
|
114 |
+
if not torch.is_tensor(p):
|
115 |
+
continue
|
116 |
+
dist.broadcast(p, 0)
|
117 |
+
|
118 |
+
def allreduce_params():
|
119 |
+
if module.needs_reduction:
|
120 |
+
module.needs_reduction = False
|
121 |
+
buckets = {}
|
122 |
+
for param in module.parameters():
|
123 |
+
if param.requires_grad and param.grad is not None:
|
124 |
+
tp = type(param.data)
|
125 |
+
if tp not in buckets:
|
126 |
+
buckets[tp] = []
|
127 |
+
buckets[tp].append(param)
|
128 |
+
if module.warn_on_half:
|
129 |
+
if torch.cuda.HalfTensor in buckets:
|
130 |
+
print(
|
131 |
+
"WARNING: gloo dist backend for half parameters may be extremely slow."
|
132 |
+
+ " It is recommended to use the NCCL backend in this case. This currently requires"
|
133 |
+
+ "PyTorch built from top of tree master."
|
134 |
+
)
|
135 |
+
module.warn_on_half = False
|
136 |
+
|
137 |
+
for tp in buckets:
|
138 |
+
bucket = buckets[tp]
|
139 |
+
grads = [param.grad.data for param in bucket]
|
140 |
+
coalesced = _flatten_dense_tensors(grads)
|
141 |
+
dist.all_reduce(coalesced)
|
142 |
+
coalesced /= dist.get_world_size()
|
143 |
+
for buf, synced in zip(
|
144 |
+
grads, _unflatten_dense_tensors(coalesced, grads)
|
145 |
+
):
|
146 |
+
buf.copy_(synced)
|
147 |
+
|
148 |
+
for param in list(module.parameters()):
|
149 |
+
|
150 |
+
def allreduce_hook(*unused):
|
151 |
+
Variable._execution_engine.queue_callback(allreduce_params)
|
152 |
+
|
153 |
+
if param.requires_grad:
|
154 |
+
param.register_hook(allreduce_hook)
|
155 |
+
dir(param)
|
156 |
+
|
157 |
+
def set_needs_reduction(self, input, output):
|
158 |
+
self.needs_reduction = True
|
159 |
+
|
160 |
+
module.register_forward_hook(set_needs_reduction)
|
161 |
+
return module
|
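apply_gradient_allreduce relies on the two flatten/unflatten helpers to pack same-typed gradients into one contiguous buffer, all-reduce it once, and copy the averaged values back. A round-trip check, assuming distributed.py is importable:

import torch
from distributed import _flatten_dense_tensors, _unflatten_dense_tensors

grads = [torch.randn(3, 4), torch.randn(10), torch.randn(2, 2, 2)]
flat = _flatten_dense_tensors(grads)             # 1-D buffer of 12 + 10 + 8 = 30 elements
restored = _unflatten_dense_tensors(flat, grads)

print(flat.shape)                                # torch.Size([30])
print([tuple(r.shape) for r in restored])        # [(3, 4), (10,), (2, 2, 2)]
print(torch.equal(restored[0], grads[0]))        # True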
filelists/3speakers_ukrainian_train_filelist.txt
ADDED
The diff for this file is too large to render.
See raw diff
filelists/3speakers_ukrainian_train_filelist_dc.txt
ADDED
The diff for this file is too large to render.
See raw diff
filelists/3speakers_ukrainian_val_filelist.txt
ADDED
@@ -0,0 +1,85 @@
1 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48849.wav|мандрівник+и вп+ерто відмовл+ялися.|lada
|
2 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48850.wav|він уз+яв сок+иру й г+острим кінц+ем поч+ав розв+ажувати з+уби.|lada
|
3 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48851.wav|розгр+ібши сніг, тр+охи прос+унув г+олову й пл+ечі під шатр+о.|lada
|
4 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48853.wav|ал+е раз зас+идівся до п+ізнього в+ечора.|lada
|
5 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48854.wav|то ж не дим їй +очі роз'їд+ав, бо др+ова бул+и сух+і.|lada
|
6 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48855.wav|вон+а не м+ала теп+ер с+умніву, що в портоса з д+амою бул+а інтр+ига.|lada
|
7 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48857.wav|х+очуть укра+їну з під л+яхів визвол+яти.|lada
|
8 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48858.wav|там жінк+ам не д+уже догодж+ають.|lada
|
9 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48859.wav|і б+удьте спок+ійні! якщ+о вин+о нам не спод+обається, ми пошлем+о по +інше.|lada
|
10 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48830.wav|мій д+івер і я м+арно чек+али на вас вч+ора й позавч+ора.|lada
|
11 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48831.wav|п+ане д'артаньяне, ви п+ерший.|lada
|
12 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48832.wav|ось мо+я в+ідповідь.|lada
|
13 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48833.wav|хоч той так+и й д+ійсно д+урень.|lada
|
14 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48834.wav|ви давн+о не гр+али?|lada
|
15 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48835.wav|теп+ер їм довел+ось зазн+ати д+оброї бід+и в цій кра+їні.|lada
|
16 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48836.wav|позавч+ора був пісн+ий день, а там подав+али лиш+е скор+омне.|lada
|
17 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48837.wav|і не потреб+уєте всі роб+ити.|lada
|
18 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48838.wav|у рук+ах у н+еї бул+а нов+а зап+иска міл+еді.|lada
|
19 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48839.wav|і ч+етверо др+узів одн+им г+олосом повтор+или прис+ягу, запропон+овану від д'артаньяна.|lada
|
20 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48841.wav|іг+уменя ст+ала сл+ухати ув+ажніш, тр+охи пожвав+іла й всміхн+улася.|lada
|
21 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48842.wav|так ти цьог+о не роб+и й не втрач+айся, бо одн+аково не пом+оже.|lada
|
22 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48843.wav|туд+и і рв+еться н+аша душ+а, кол+и х+очеш зн+ати.|lada
|
23 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48844.wav|б+олісно всміх+ався і трясс+я, як у проп+асниці.|lada
|
24 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48845.wav|я прив+ів тоб+і др+угого, сказ+ав д'артаньян.|lada
|
25 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48846.wav|я поб+ачу корол+я сьог+одні увечорі, ал+е вас не р+аджу наверт+атись йому на в+ічі.|lada
|
26 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48847.wav|ще весел+іш почал+и тод+і гомон+іти.|lada
|
27 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48848.wav|споч+атку вон+а нарахув+ала двох, п+отім п'ять, нар+ешті в+ісім.|lada
|
28 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68540.wav|кр+аще вже пуст+ити соб+і к+улю в л+оба і відр+азу покл+асти всь+ому край.|mykyta
|
29 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68541.wav|ал+е сидяч+и за стол+ом, при п+иві, знов поч+ув як+есь невдов+олення.|mykyta
|
30 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68543.wav|на шабл+ях!|mykyta
|
31 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68544.wav|вон+а пров+адила з незнай+омим д+уже жв+аву розм+ову.|mykyta
|
32 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68545.wav|офіц+ер взяв зі ст+олу вк+азані пап+ери, под+ав їх і, н+изько вклонившися, в+ийшов.|mykyta
|
33 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68546.wav|аж с+умно йому ст+ало.|mykyta
|
34 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68547.wav|житт+я не ласк+аве з багать+ох прич+ин.|mykyta
|
35 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68548.wav|так, звич+айно тр+еба, ств+ердила корол+ева.|mykyta
|
36 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68549.wav|вон+а, не зверн+увши ув+аги на цей д+ок+ір, промовл+яла д+алі.|mykyta
|
37 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68550.wav|зда+ється, не дочув+аю.|mykyta
|
38 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68551.wav|відв+ажний і завз+ятий, він не вп+ерше в+ажив сво+ї+++м житт+ям у так+их приг+одах.|mykyta
|
39 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68552.wav|як ч+асом, г+аво.|mykyta
|
40 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68553.wav|мій друг араміс, що оц+е сто+їть п+еред вами, здоб+ув легк+ого вд+ара шпад+ою в р+уку.|mykyta
|
41 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68554.wav|я знав+ець свог+о д+іла.|mykyta
|
42 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68556.wav|пог+онич леж+ав на с+анк+ах, а соб+аки шв+идко б+ігли пр+ямо до хат+ини.|mykyta
|
43 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68557.wav|міл+еді к+инулась до нього.|mykyta
|
44 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68558.wav|хто тоб+і сказ+ав?|mykyta
|
45 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68559.wav|то й не поваж+ай, не зляк+аєш.|mykyta
|
46 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68560.wav|поясн+іть, бо я не розум+ію, що ви х+очете сказ+ати.|mykyta
|
47 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68561.wav|шрам наздогн+ав свій п+оїзд к+оло вис+оких вор+іт п+ана гвинтовки.|mykyta
|
48 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68562.wav|що ж він так+е?|mykyta
|
49 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68563.wav|що це так+е? спит+ав портос.|mykyta
|
50 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68565.wav|див+іться, тут зн+ову втруч+алася ц+ерква, з+авжд+и та ц+ерква.|mykyta
|
51 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67117.wav|а чолов+ік цьог+о жахл+ивого створ+іння ще жив+ий? зацік+авився араміс.|tetiana
|
52 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67118.wav|ви, дик, не ч+ули ці+єї т+иші.|tetiana
|
53 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67119.wav|він баг+атий на р+ок+и, шан+обу й сл+аву вел+ику.|tetiana
|
54 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67120.wav|в +осени зар+ані, ск+оро п+ісля сп+аса под+ався макс+им до київа.|tetiana
|
55 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67121.wav|а до н+еї п+ишеш?|tetiana
|
56 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67122.wav|я, б+ачилось, н+авіть не люб+ив її так, як л+юблять зак+охані.|tetiana
|
57 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67123.wav|юрб+а провал+ила тим ч+асом м+имо петр+а.|tetiana
|
58 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67124.wav|хай так! приєдн+ався швайц+арець.|tetiana
|
59 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67125.wav|к+онюх підтв+ердив кардин+алові слов+а мушкет+ерів про атоса.|tetiana
|
60 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67126.wav|що завин+ив, те б+уду терп+іти.|tetiana
|
61 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67127.wav|чи є у вас тр+охи піск+у? ск+ільки? він показ+ав їй свій міш+ок.|tetiana
|
62 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67128.wav|я скаж+у це т+ільки том+у, хто прозирн+е в мо+ю д+ушу.|tetiana
|
63 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67129.wav|і в оц+ій хв+илі вон+а не міркув+ала тог+о.|tetiana
|
64 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67130.wav|ти б+ачив сво+ю ж?|tetiana
|
65 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67132.wav|прот+е, тр+еба скл+асти як+ийсь плян б+ою, пром+овив араміс.|tetiana
|
66 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67133.wav|огого! д+уже швидк+а! так я теб+е й пуст+ив до богун+а!|tetiana
|
67 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67134.wav|бог з тоб+ою, добр+одію!|tetiana
|
68 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67135.wav|киценька! ти т+ямиш її?|tetiana
|
69 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67136.wav|розм+ова поверн+ула на вес+еле.|tetiana
|
70 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67137.wav|розум+іється, сказ+ала вон+а к+оротко.|tetiana
|
71 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67138.wav|їй с+оромно ст+ало, що на оч+ах у всіх її так знев+ажено, і вон+а знен+авиділа фреду.|tetiana
|
72 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67139.wav|це бул+о м+ужнє обл+иччя.|tetiana
|
73 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67140.wav|св+екра зн+ала м+ало, не ч+асто й б+ачилася з ним, на рік раз+ів зо три.|tetiana
|
74 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67141.wav|спр+ава ця єсть особл+ивої делікатности.|tetiana
|
75 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67143.wav|я так отощ+ав, не +ївши зр+анку, що й р+адуватись незд+ужаю.|tetiana
|
76 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67145.wav|т+ільки в+ірна будь мен+і.|tetiana
|
77 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67146.wav|п'єр піш+ов за н+ею і відч+алив.|tetiana
|
78 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67147.wav|і по цих слов+ах к+инув торб+инку із з+олотом в р+ічку.|tetiana
|
79 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67148.wav|а, він в пор+ядку, сказ+ав нач+альник, та з чуд+овою рекоменд+ацією.|tetiana
|
80 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67149.wav|тод+і підожд+іть тр+ошки, зачек+айте.|tetiana
|
81 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67150.wav|із як+ими вістьми? пит+ає г+етьман.|tetiana
|
82 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67151.wav|стар+ий сарабр+ин міг л+егко пот+ішитися.|tetiana
|
83 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67152.wav|о, я, нещ+асний!|tetiana
|
84 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67153.wav|кр+оки в сальоні.|tetiana
|
85 |
+
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67154.wav|щоб н+ашим ворог+ам бул+о т+яжко!|tetiana
|
filelists/3speakers_ukrainian_val_filelist_dc.txt
ADDED
@@ -0,0 +1,85 @@
1 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48849.wav|мандрівник+и вп+ерто відмовл+ялися.|lada
|
2 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48850.wav|він уз+яв сок+иру й г+острим кінц+ем поч+ав розв+ажувати з+уби.|lada
|
3 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48851.wav|розгр+ібши сніг, тр+охи прос+унув г+олову й пл+ечі під шатр+о.|lada
|
4 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48853.wav|ал+е раз зас+идівся до п+ізнього в+ечора.|lada
|
5 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48854.wav|то ж не дим їй +очі роз'їд+ав, бо др+ова бул+и сух+і.|lada
|
6 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48855.wav|вон+а не м+ала теп+ер с+умніву, що в портоса з д+амою бул+а інтр+ига.|lada
|
7 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48857.wav|х+очуть укра+їну з під л+яхів визвол+яти.|lada
|
8 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48858.wav|там жінк+ам не д+уже догодж+ають.|lada
|
9 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48859.wav|і б+удьте спок+ійні! якщ+о вин+о нам не спод+обається, ми пошлем+о по +інше.|lada
|
10 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48830.wav|мій д+івер і я м+арно чек+али на вас вч+ора й позавч+ора.|lada
|
11 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48831.wav|п+ане д'артаньяне, ви п+ерший.|lada
|
12 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48832.wav|ось мо+я в+ідповідь.|lada
|
13 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48833.wav|хоч той так+и й д+ійсно д+урень.|lada
|
14 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48834.wav|ви давн+о не гр+али?|lada
|
15 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48835.wav|теп+ер їм довел+ось зазн+ати д+оброї бід+и в цій кра+їні.|lada
|
16 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48836.wav|позавч+ора був пісн+ий день, а там подав+али лиш+е скор+омне.|lada
|
17 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48837.wav|і не потреб+уєте всі роб+ити.|lada
|
18 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48838.wav|у рук+ах у н+еї бул+а нов+а зап+иска міл+еді.|lada
|
19 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48839.wav|і ч+етверо др+узів одн+им г+олосом повтор+или прис+ягу, запропон+овану від д'артаньяна.|lada
|
20 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48841.wav|іг+уменя ст+ала сл+ухати ув+ажніш, тр+охи пожвав+іла й всміхн+улася.|lada
|
21 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48842.wav|так ти цьог+о не роб+и й не втрач+айся, бо одн+аково не пом+оже.|lada
|
22 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48843.wav|туд+и і рв+еться н+аша душ+а, кол+и х+очеш зн+ати.|lada
|
23 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48844.wav|б+олісно всміх+ався і трясс+я, як у проп+асниці.|lada
|
24 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48845.wav|я прив+ів тоб+і др+угого, сказ+ав д'артаньян.|lada
|
25 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48846.wav|я поб+ачу корол+я сьог+одні увечорі, ал+е вас не р+аджу наверт+атись йому на в+ічі.|lada
|
26 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48847.wav|ще весел+іш почал+и тод+і гомон+іти.|lada
|
27 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48848.wav|споч+атку вон+а нарахув+ала двох, п+отім п'ять, нар+ешті в+ісім.|lada
|
28 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68540.wav|кр+аще вже пуст+ити соб+і к+улю в л+оба і відр+азу покл+асти всь+ому край.|mykyta
|
29 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68541.wav|ал+е сидяч+и за стол+ом, при п+иві, знов поч+ув як+есь невдов+олення.|mykyta
|
30 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68543.wav|на шабл+ях!|mykyta
|
31 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68544.wav|вон+а пров+адила з незнай+омим д+уже жв+аву розм+ову.|mykyta
|
32 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68545.wav|офіц+ер взяв зі ст+олу вк+азані пап+ери, под+ав їх і, н+изько вклонившися, в+ийшов.|mykyta
|
33 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68546.wav|аж с+умно йому ст+ало.|mykyta
|
34 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68547.wav|житт+я не ласк+аве з багать+ох прич+ин.|mykyta
|
35 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68548.wav|так, звич+айно тр+еба, ств+ердила корол+ева.|mykyta
|
36 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68549.wav|вон+а, не зверн+увши ув+аги на цей д+ок+ір, промовл+яла д+алі.|mykyta
|
37 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68550.wav|зда+ється, не дочув+аю.|mykyta
|
38 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68551.wav|відв+ажний і завз+ятий, він не вп+ерше в+ажив сво+ї+++м житт+ям у так+их приг+одах.|mykyta
|
39 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68552.wav|як ч+асом, г+аво.|mykyta
|
40 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68553.wav|мій друг араміс, що оц+е сто+їть п+еред вами, здоб+ув легк+ого вд+ара шпад+ою в р+уку.|mykyta
|
41 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68554.wav|я знав+ець свог+о д+іла.|mykyta
|
42 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68556.wav|пог+онич леж+ав на с+анк+ах, а соб+аки шв+идко б+ігли пр+ямо до хат+ини.|mykyta
|
43 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68557.wav|міл+еді к+инулась до нього.|mykyta
|
44 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68558.wav|хто тоб+і сказ+ав?|mykyta
|
45 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68559.wav|то й не поваж+ай, не зляк+аєш.|mykyta
|
46 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68560.wav|поясн+іть, бо я не розум+ію, що ви х+очете сказ+ати.|mykyta
|
47 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68561.wav|шрам наздогн+ав свій п+оїзд к+оло вис+оких вор+іт п+ана гвинтовки.|mykyta
|
48 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68562.wav|що ж він так+е?|mykyta
|
49 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68563.wav|що це так+е? спит+ав портос.|mykyta
|
50 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68565.wav|див+іться, тут зн+ову втруч+алася ц+ерква, з+авжд+и та ц+ерква.|mykyta
|
51 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67117.wav|а чолов+ік цьог+о жахл+ивого створ+іння ще жив+ий? зацік+авився араміс.|tetiana
|
52 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67118.wav|ви, дик, не ч+ули ці+єї т+иші.|tetiana
|
53 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67119.wav|він баг+атий на р+ок+и, шан+обу й сл+аву вел+ику.|tetiana
|
54 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67120.wav|в +осени зар+ані, ск+оро п+ісля сп+аса под+ався макс+им до київа.|tetiana
|
55 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67121.wav|а до н+еї п+ишеш?|tetiana
|
56 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67122.wav|я, б+ачилось, н+авіть не люб+ив її так, як л+юблять зак+охані.|tetiana
|
57 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67123.wav|юрб+а провал+ила тим ч+асом м+имо петр+а.|tetiana
|
58 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67124.wav|хай так! приєдн+ався швайц+арець.|tetiana
|
59 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67125.wav|к+онюх підтв+ердив кардин+алові слов+а мушкет+ерів про атоса.|tetiana
|
60 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67126.wav|що завин+ив, те б+уду терп+іти.|tetiana
|
61 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67127.wav|чи є у вас тр+охи піск+у? ск+ільки? він показ+ав їй свій міш+ок.|tetiana
|
62 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67128.wav|я скаж+у це т+ільки том+у, хто прозирн+е в мо+ю д+ушу.|tetiana
|
63 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67129.wav|і в оц+ій хв+илі вон+а не міркув+ала тог+о.|tetiana
|
64 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67130.wav|ти б+ачив сво+ю ж?|tetiana
|
65 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67132.wav|прот+е, тр+еба скл+асти як+ийсь плян б+ою, пром+овив араміс.|tetiana
|
66 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67133.wav|огого! д+уже швидк+а! так я теб+е й пуст+ив до богун+а!|tetiana
|
67 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67134.wav|бог з тоб+ою, добр+одію!|tetiana
|
68 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67135.wav|киценька! ти т+ямиш її?|tetiana
|
69 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67136.wav|розм+ова поверн+ула на вес+еле.|tetiana
|
70 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67137.wav|розум+іється, сказ+ала вон+а к+оротко.|tetiana
|
71 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67138.wav|їй с+оромно ст+ало, що на оч+ах у всіх її так знев+ажено, і вон+а знен+авиділа фреду.|tetiana
|
72 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67139.wav|це бул+о м+ужнє обл+иччя.|tetiana
|
73 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67140.wav|св+екра зн+ала м+ало, не ч+асто й б+ачилася з ним, на рік раз+ів зо три.|tetiana
|
74 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67141.wav|спр+ава ця єсть особл+ивої делікатности.|tetiana
|
75 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67143.wav|я так отощ+ав, не +ївши зр+анку, що й р+адуватись незд+ужаю.|tetiana
|
76 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67145.wav|т+ільки в+ірна будь мен+і.|tetiana
|
77 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67146.wav|п'єр піш+ов за н+ею і відч+алив.|tetiana
|
78 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67147.wav|і по цих слов+ах к+инув торб+инку із з+олотом в р+ічку.|tetiana
|
79 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67148.wav|а, він в пор+ядку, сказ+ав нач+альник, та з чуд+овою рекоменд+ацією.|tetiana
|
80 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67149.wav|тод+і підожд+іть тр+ошки, зачек+айте.|tetiana
|
81 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67150.wav|із як+ими вістьми? пит+ає г+етьман.|tetiana
|
82 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67151.wav|стар+ий сарабр+ин міг л+егко пот+ішитися.|tetiana
|
83 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67152.wav|о, я, нещ+асний!|tetiana
|
84 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67153.wav|кр+оки в сальоні.|tetiana
|
85 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67154.wav|щоб н+ашим ворог+ам бул+о т+яжко!|tetiana
|
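Each row in these filelists packs three pipe-separated fields: the absolute path to a WAV file, the transcript with `+` marking the stressed vowel, and the speaker id (`lada`, `mykyta`, or `tetiana`). A minimal, hypothetical reader for this format (not taken from the repo's `data.py`; it assumes exactly three fields per row):

```python
from pathlib import Path

def load_filelist(path):
    """Parse a wav_path|accented text|speaker filelist into tuples."""
    entries = []
    for line in Path(path).read_text(encoding="utf-8").splitlines():
        line = line.strip()
        if not line:
            continue
        # assumes three fields and no extra pipe characters in the transcript
        wav_path, text, speaker = line.split("|")
        entries.append((wav_path, text, speaker))
    return entries

# e.g. entries = load_filelist("filelists/3speakers_ukrainian_val_filelist_dc.txt")
```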
loss.py
ADDED
@@ -0,0 +1,228 @@
+# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: MIT
+#
+# Permission is hereby granted, free of charge, to any person obtaining a
+# copy of this software and associated documentation files (the "Software"),
+# to deal in the Software without restriction, including without limitation
+# the rights to use, copy, modify, merge, publish, distribute, sublicense,
+# and/or sell copies of the Software, and to permit persons to whom the
+# Software is furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+# DEALINGS IN THE SOFTWARE.
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from common import get_mask_from_lengths
+
+
+def compute_flow_loss(
+    z, log_det_W_list, log_s_list, n_elements, n_dims, mask, sigma=1.0
+):
+    log_det_W_total = 0.0
+    for i, log_s in enumerate(log_s_list):
+        if i == 0:
+            log_s_total = torch.sum(log_s * mask)
+            if len(log_det_W_list):
+                log_det_W_total = log_det_W_list[i]
+        else:
+            log_s_total = log_s_total + torch.sum(log_s * mask)
+            if len(log_det_W_list):
+                log_det_W_total += log_det_W_list[i]
+
+    if len(log_det_W_list):
+        log_det_W_total *= n_elements
+
+    z = z * mask
+    prior_NLL = torch.sum(z * z) / (2 * sigma * sigma)
+
+    loss = prior_NLL - log_s_total - log_det_W_total
+
+    denom = n_elements * n_dims
+    loss = loss / denom
+    loss_prior = prior_NLL / denom
+    return loss, loss_prior
+
+
+def compute_regression_loss(x_hat, x, mask, name=False):
+    x = x[:, None] if len(x.shape) == 2 else x  # add channel dim
+    mask = mask[:, None] if len(mask.shape) == 2 else mask  # add channel dim
+    assert len(x.shape) == len(mask.shape)
+
+    x = x * mask
+    x_hat = x_hat * mask
+
+    if name == "vpred":
+        loss = F.binary_cross_entropy_with_logits(x_hat, x, reduction="sum")
+    else:
+        loss = F.mse_loss(x_hat, x, reduction="sum")
+    loss = loss / mask.sum()
+
+    loss_dict = {"loss_{}".format(name): loss}
+
+    return loss_dict
+
+
+class AttributePredictionLoss(torch.nn.Module):
+    def __init__(self, name, model_config, loss_weight, sigma=1.0):
+        super(AttributePredictionLoss, self).__init__()
+        self.name = name
+        self.sigma = sigma
+        self.model_name = model_config["name"]
+        self.loss_weight = loss_weight
+        self.n_group_size = 1
+        if "n_group_size" in model_config["hparams"]:
+            self.n_group_size = model_config["hparams"]["n_group_size"]
+
+    def forward(self, model_output, lens):
+        mask = get_mask_from_lengths(lens // self.n_group_size)
+        mask = mask[:, None].float()
+        loss_dict = {}
+        if "z" in model_output:
+            n_elements = lens.sum() // self.n_group_size
+            n_dims = model_output["z"].size(1)
+
+            loss, loss_prior = compute_flow_loss(
+                model_output["z"],
+                model_output["log_det_W_list"],
+                model_output["log_s_list"],
+                n_elements,
+                n_dims,
+                mask,
+                self.sigma,
+            )
+            loss_dict = {
+                "loss_{}".format(self.name): (loss, self.loss_weight),
+                "loss_prior_{}".format(self.name): (loss_prior, 0.0),
+            }
+        elif "x_hat" in model_output:
+            loss_dict = compute_regression_loss(
+                model_output["x_hat"], model_output["x"], mask, self.name
+            )
+            for k, v in loss_dict.items():
+                loss_dict[k] = (v, self.loss_weight)
+
+        if len(loss_dict) == 0:
+            raise Exception("loss not supported")
+
+        return loss_dict
+
+
+class AttentionCTCLoss(torch.nn.Module):
+    def __init__(self, blank_logprob=-1):
+        super(AttentionCTCLoss, self).__init__()
+        self.log_softmax = torch.nn.LogSoftmax(dim=3)
+        self.blank_logprob = blank_logprob
+        self.CTCLoss = nn.CTCLoss(zero_infinity=True)
+
+    def forward(self, attn_logprob, in_lens, out_lens):
+        key_lens = in_lens
+        query_lens = out_lens
+        attn_logprob_padded = F.pad(
+            input=attn_logprob, pad=(1, 0, 0, 0, 0, 0, 0, 0), value=self.blank_logprob
+        )
+        cost_total = 0.0
+        for bid in range(attn_logprob.shape[0]):
+            target_seq = torch.arange(1, key_lens[bid] + 1).unsqueeze(0)
+            curr_logprob = attn_logprob_padded[bid].permute(1, 0, 2)[
+                : query_lens[bid], :, : key_lens[bid] + 1
+            ]
+            curr_logprob = self.log_softmax(curr_logprob[None])[0]
+            ctc_cost = self.CTCLoss(
+                curr_logprob,
+                target_seq,
+                input_lengths=query_lens[bid : bid + 1],
+                target_lengths=key_lens[bid : bid + 1],
+            )
+            cost_total += ctc_cost
+        cost = cost_total / attn_logprob.shape[0]
+        return cost
+
+
+class AttentionBinarizationLoss(torch.nn.Module):
+    def __init__(self):
+        super(AttentionBinarizationLoss, self).__init__()
+
+    def forward(self, hard_attention, soft_attention):
+        log_sum = torch.log(soft_attention[hard_attention == 1]).sum()
+        return -log_sum / hard_attention.sum()
+
+
+class RADTTSLoss(torch.nn.Module):
+    def __init__(
+        self,
+        sigma=1.0,
+        n_group_size=1,
+        dur_model_config=None,
+        f0_model_config=None,
+        energy_model_config=None,
+        vpred_model_config=None,
+        loss_weights=None,
+    ):
+        super(RADTTSLoss, self).__init__()
+        self.sigma = sigma
+        self.n_group_size = n_group_size
+        self.loss_weights = loss_weights
+        self.attn_ctc_loss = AttentionCTCLoss(
+            blank_logprob=loss_weights.get("blank_logprob", -1)
+        )
+        self.loss_fns = {}
+        if dur_model_config is not None:
+            self.loss_fns["duration_model_outputs"] = AttributePredictionLoss(
+                "duration", dur_model_config, loss_weights["dur_loss_weight"]
+            )
+
+        if f0_model_config is not None:
+            self.loss_fns["f0_model_outputs"] = AttributePredictionLoss(
+                "f0", f0_model_config, loss_weights["f0_loss_weight"], sigma=1.0
+            )
+
+        if energy_model_config is not None:
+            self.loss_fns["energy_model_outputs"] = AttributePredictionLoss(
+                "energy", energy_model_config, loss_weights["energy_loss_weight"]
+            )
+
+        if vpred_model_config is not None:
+            self.loss_fns["vpred_model_outputs"] = AttributePredictionLoss(
+                "vpred", vpred_model_config, loss_weights["vpred_loss_weight"]
+            )
+
+    def forward(self, model_output, in_lens, out_lens):
+        loss_dict = {}
+        if len(model_output["z_mel"]):
+            n_elements = out_lens.sum() // self.n_group_size
+            mask = get_mask_from_lengths(out_lens // self.n_group_size)
+            mask = mask[:, None].float()
+            n_dims = model_output["z_mel"].size(1)
+            loss_mel, loss_prior_mel = compute_flow_loss(
+                model_output["z_mel"],
+                model_output["log_det_W_list"],
+                model_output["log_s_list"],
+                n_elements,
+                n_dims,
+                mask,
+                self.sigma,
+            )
+            loss_dict["loss_mel"] = (loss_mel, 1.0)  # loss, weight
+            loss_dict["loss_prior_mel"] = (loss_prior_mel, 0.0)
+
+        ctc_cost = self.attn_ctc_loss(model_output["attn_logprob"], in_lens, out_lens)
+        loss_dict["loss_ctc"] = (ctc_cost, self.loss_weights["ctc_loss_weight"])
+
+        for k in model_output:
+            if k in self.loss_fns:
+                if model_output[k] is not None and len(model_output[k]) > 0:
+                    t_lens = in_lens if "dur" in k else out_lens
+                    mout = model_output[k]
+                    for loss_name, v in self.loss_fns[k](mout, t_lens).items():
+                        loss_dict[loss_name] = v
+
+        return loss_dict
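`compute_flow_loss` above is the normalizing-flow negative log-likelihood: the Gaussian prior term sum(z*z) / (2*sigma^2) on the masked latents, minus the accumulated log-scale terms and the per-step log|det W| (scaled by the number of valid frames), all normalized by n_elements * n_dims. A small smoke test under assumed shapes (z as (B, n_dims, T), mask as (B, 1, T), scalar log-det entries); this is a sketch, not taken from the repo's training code:

```python
import torch
from loss import compute_flow_loss

B, n_dims, T = 2, 80, 120
lens = torch.tensor([120, 90])
mask = (torch.arange(T)[None, :] < lens[:, None]).float()[:, None]  # (B, 1, T)

z = torch.randn(B, n_dims, T)                                 # latents from the decoder flow
log_s_list = [torch.zeros(B, n_dims, T) for _ in range(2)]    # log-scales per flow step
log_det_W_list = [torch.zeros(()) for _ in range(2)]          # log|det W| per 1x1 conv

n_elements = lens.sum()
loss, loss_prior = compute_flow_loss(
    z, log_det_W_list, log_s_list, n_elements, n_dims, mask, sigma=1.0
)
print(loss.item(), loss_prior.item())
```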
partialconv1d.py
ADDED
@@ -0,0 +1,77 @@
+# Modified partialconv source code based on implementation from
+# https://github.com/NVIDIA/partialconv/blob/master/models/partialconv2d.py
+###############################################################################
+# BSD 3-Clause License
+#
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Author & Contact: Guilin Liu ([email protected])
+###############################################################################
+
+# Original Author & Contact: Guilin Liu ([email protected])
+# Modified by Kevin Shih ([email protected])
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class PartialConv1d(nn.Conv1d):
+    def __init__(self, *args, **kwargs):
+        self.multi_channel = False
+        self.return_mask = False
+        super(PartialConv1d, self).__init__(*args, **kwargs)
+
+        self.weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0])
+        self.slide_winsize = (
+            self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2]
+        )
+
+        self.last_size = (None, None, None)
+        self.update_mask = None
+        self.mask_ratio = None
+
+    @torch.jit.ignore
+    def forward(self, input: torch.Tensor, mask_in: torch.Tensor = None):
+        """
+        input: standard input to a 1D conv
+        mask_in: binary mask for valid values, same shape as input
+        """
+        assert len(input.shape) == 3
+        # if a mask is input, or tensor shape changed, update mask ratio
+        if mask_in is not None or self.last_size != tuple(input.shape):
+            self.last_size = tuple(input.shape)
+            with torch.no_grad():
+                if self.weight_maskUpdater.type() != input.type():
+                    self.weight_maskUpdater = self.weight_maskUpdater.to(input)
+                if mask_in is None:
+                    mask = torch.ones(1, 1, input.data.shape[2]).to(input)
+                else:
+                    mask = mask_in
+                self.update_mask = F.conv1d(
+                    mask,
+                    self.weight_maskUpdater,
+                    bias=None,
+                    stride=self.stride,
+                    padding=self.padding,
+                    dilation=self.dilation,
+                    groups=1,
+                )
+                # for mixed precision training, change 1e-8 to 1e-6
+                self.mask_ratio = self.slide_winsize / (self.update_mask + 1e-6)
+                self.update_mask = torch.clamp(self.update_mask, 0, 1)
+                self.mask_ratio = torch.mul(self.mask_ratio, self.update_mask)
+        raw_out = super(PartialConv1d, self).forward(
+            torch.mul(input, mask) if mask_in is not None else input
+        )
+        if self.bias is not None:
+            bias_view = self.bias.view(1, self.out_channels, 1)
+            output = torch.mul(raw_out - bias_view, self.mask_ratio) + bias_view
+            output = torch.mul(output, self.update_mask)
+        else:
+            output = torch.mul(raw_out, self.mask_ratio)
+
+        if self.return_mask:
+            return output, self.update_mask
+        else:
+            return output
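`PartialConv1d` behaves like `nn.Conv1d` but renormalizes each output position by `slide_winsize / (valid samples under the kernel)` and zeroes positions whose receptive field is fully masked, so padded frames do not leak into the convolution. A hedged usage sketch (the single-channel `(B, 1, T)` mask shape is an assumption based on how the mask-update convolution is configured here):

```python
import torch
from partialconv1d import PartialConv1d

conv = PartialConv1d(in_channels=80, out_channels=80, kernel_size=5, padding=2)
conv.return_mask = True

x = torch.randn(1, 80, 100)
mask = torch.ones(1, 1, 100)
mask[:, :, 60:] = 0.0  # treat the last 40 frames as padding

out, update_mask = conv(x, mask_in=mask)
print(out.shape, update_mask.shape)  # (1, 80, 100) and (1, 1, 100)
```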
radam.py
ADDED
@@ -0,0 +1,114 @@
+# Original source taken from https://github.com/LiyuanLucasLiu/RAdam
+#
+# Copyright 2019 Liyuan Liu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+
+import torch
+
+# pylint: disable=no-name-in-module
+from torch.optim.optimizer import Optimizer
+
+
+class RAdam(Optimizer):
+    """RAdam optimizer"""
+
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
+        """
+        Init
+
+        :param params: parameters to optimize
+        :param lr: learning rate
+        :param betas: beta
+        :param eps: numerical precision
+        :param weight_decay: weight decay weight
+        """
+        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
+        self.buffer = [[None, None, None] for _ in range(10)]
+        super().__init__(params, defaults)
+
+    def step(self, closure=None):
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for group in self.param_groups:
+            for p in group["params"]:
+                if p.grad is None:
+                    continue
+                grad = p.grad.data.float()
+                if grad.is_sparse:
+                    raise RuntimeError("RAdam does not support sparse gradients")
+
+                p_data_fp32 = p.data.float()
+
+                state = self.state[p]
+
+                if len(state) == 0:
+                    state["step"] = 0
+                    state["exp_avg"] = torch.zeros_like(p_data_fp32)
+                    state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
+                else:
+                    state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32)
+                    state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32)
+
+                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
+                beta1, beta2 = group["betas"]
+
+                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
+                exp_avg.mul_(beta1).add_(1 - beta1, grad)
+
+                state["step"] += 1
+                buffered = self.buffer[int(state["step"] % 10)]
+                if state["step"] == buffered[0]:
+                    N_sma, step_size = buffered[1], buffered[2]
+                else:
+                    buffered[0] = state["step"]
+                    beta2_t = beta2 ** state["step"]
+                    N_sma_max = 2 / (1 - beta2) - 1
+                    N_sma = N_sma_max - 2 * state["step"] * beta2_t / (1 - beta2_t)
+                    buffered[1] = N_sma
+
+                    # more conservative since it's an approximated value
+                    if N_sma >= 5:
+                        step_size = (
+                            group["lr"]
+                            * math.sqrt(
+                                (1 - beta2_t)
+                                * (N_sma - 4)
+                                / (N_sma_max - 4)
+                                * (N_sma - 2)
+                                / N_sma
+                                * N_sma_max
+                                / (N_sma_max - 2)
+                            )
+                            / (1 - beta1 ** state["step"])
+                        )
+                    else:
+                        step_size = group["lr"] / (1 - beta1 ** state["step"])
+                    buffered[2] = step_size
+
+                if group["weight_decay"] != 0:
+                    p_data_fp32.add_(-group["weight_decay"] * group["lr"], p_data_fp32)
+
+                # more conservative since it's an approximated value
+                if N_sma >= 5:
+                    denom = exp_avg_sq.sqrt().add_(group["eps"])
+                    p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
+                else:
+                    p_data_fp32.add_(-step_size, exp_avg)
+
+                p.data.copy_(p_data_fp32)
+
+        return loss
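`RAdam` rectifies Adam's adaptive step size while the second-moment estimate is still unreliable (small `N_sma`), falling back to a plain momentum update early in training. A minimal usage sketch on a toy module; note this copy keeps the older scalar-first `add_`/`addcmul_`/`addcdiv_` call signatures, which newer PyTorch releases deprecate:

```python
import torch
from radam import RAdam

model = torch.nn.Linear(10, 1)
optimizer = RAdam(model.parameters(), lr=1e-3)

x, y = torch.randn(8, 10), torch.randn(8, 1)
for _ in range(3):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
```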
radtts.py
ADDED
@@ -0,0 +1,936 @@
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: MIT
|
3 |
+
#
|
4 |
+
# Permission is hereby granted, free of charge, to any person obtaining a
|
5 |
+
# copy of this software and associated documentation files (the "Software"),
|
6 |
+
# to deal in the Software without restriction, including without limitation
|
7 |
+
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
8 |
+
# and/or sell copies of the Software, and to permit persons to whom the
|
9 |
+
# Software is furnished to do so, subject to the following conditions:
|
10 |
+
#
|
11 |
+
# The above copyright notice and this permission notice shall be included in
|
12 |
+
# all copies or substantial portions of the Software.
|
13 |
+
#
|
14 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
15 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
16 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
17 |
+
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
18 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
19 |
+
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
20 |
+
# DEALINGS IN THE SOFTWARE.
|
21 |
+
import torch
|
22 |
+
from torch import nn
|
23 |
+
from common import Encoder, LengthRegulator, ConvAttention
|
24 |
+
from common import Invertible1x1ConvLUS, Invertible1x1Conv
|
25 |
+
from common import AffineTransformationLayer, LinearNorm, ExponentialClass
|
26 |
+
from common import get_mask_from_lengths
|
27 |
+
from attribute_prediction_model import get_attribute_prediction_model
|
28 |
+
from alignment import mas_width1 as mas
|
29 |
+
|
30 |
+
|
31 |
+
class FlowStep(nn.Module):
|
32 |
+
def __init__(
|
33 |
+
self,
|
34 |
+
n_mel_channels,
|
35 |
+
n_context_dim,
|
36 |
+
n_layers,
|
37 |
+
affine_model="simple_conv",
|
38 |
+
scaling_fn="exp",
|
39 |
+
matrix_decomposition="",
|
40 |
+
affine_activation="softplus",
|
41 |
+
use_partial_padding=False,
|
42 |
+
cache_inverse=False,
|
43 |
+
):
|
44 |
+
super(FlowStep, self).__init__()
|
45 |
+
if matrix_decomposition == "LUS":
|
46 |
+
self.invtbl_conv = Invertible1x1ConvLUS(
|
47 |
+
n_mel_channels, cache_inverse=cache_inverse
|
48 |
+
)
|
49 |
+
else:
|
50 |
+
self.invtbl_conv = Invertible1x1Conv(
|
51 |
+
n_mel_channels, cache_inverse=cache_inverse
|
52 |
+
)
|
53 |
+
|
54 |
+
self.affine_tfn = AffineTransformationLayer(
|
55 |
+
n_mel_channels,
|
56 |
+
n_context_dim,
|
57 |
+
n_layers,
|
58 |
+
affine_model=affine_model,
|
59 |
+
scaling_fn=scaling_fn,
|
60 |
+
affine_activation=affine_activation,
|
61 |
+
use_partial_padding=use_partial_padding,
|
62 |
+
)
|
63 |
+
|
64 |
+
def enable_inverse_cache(self):
|
65 |
+
self.invtbl_conv.cache_inverse = True
|
66 |
+
|
67 |
+
def forward(self, z, context, inverse=False, seq_lens=None):
|
68 |
+
if inverse: # for inference z-> mel
|
69 |
+
z = self.affine_tfn(z, context, inverse, seq_lens=seq_lens)
|
70 |
+
z = self.invtbl_conv(z, inverse)
|
71 |
+
return z
|
72 |
+
else: # training mel->z
|
73 |
+
z, log_det_W = self.invtbl_conv(z)
|
74 |
+
z, log_s = self.affine_tfn(z, context, seq_lens=seq_lens)
|
75 |
+
return z, log_det_W, log_s
|
76 |
+
|
77 |
+
|
78 |
+
class RADTTS(torch.nn.Module):
|
79 |
+
def __init__(
|
80 |
+
self,
|
81 |
+
n_speakers,
|
82 |
+
n_speaker_dim,
|
83 |
+
n_text,
|
84 |
+
n_text_dim,
|
85 |
+
n_flows,
|
86 |
+
n_conv_layers_per_step,
|
87 |
+
n_mel_channels,
|
88 |
+
n_hidden,
|
89 |
+
mel_encoder_n_hidden,
|
90 |
+
dummy_speaker_embedding,
|
91 |
+
n_early_size,
|
92 |
+
n_early_every,
|
93 |
+
n_group_size,
|
94 |
+
affine_model,
|
95 |
+
dur_model_config,
|
96 |
+
f0_model_config,
|
97 |
+
energy_model_config,
|
98 |
+
v_model_config=None,
|
99 |
+
include_modules="dec",
|
100 |
+
scaling_fn="exp",
|
101 |
+
matrix_decomposition="",
|
102 |
+
learn_alignments=False,
|
103 |
+
affine_activation="softplus",
|
104 |
+
attn_use_CTC=True,
|
105 |
+
use_speaker_emb_for_alignment=False,
|
106 |
+
use_context_lstm=False,
|
107 |
+
context_lstm_norm=None,
|
108 |
+
text_encoder_lstm_norm=None,
|
109 |
+
n_f0_dims=0,
|
110 |
+
n_energy_avg_dims=0,
|
111 |
+
context_lstm_w_f0_and_energy=True,
|
112 |
+
use_first_order_features=False,
|
113 |
+
unvoiced_bias_activation="",
|
114 |
+
ap_pred_log_f0=False,
|
115 |
+
**kwargs,
|
116 |
+
):
|
117 |
+
super(RADTTS, self).__init__()
|
118 |
+
assert n_early_size % 2 == 0
|
119 |
+
self.do_mel_descaling = kwargs.get("do_mel_descaling", True)
|
120 |
+
self.n_mel_channels = n_mel_channels
|
121 |
+
self.n_f0_dims = n_f0_dims # >= 1 to train with f0
|
122 |
+
self.n_energy_avg_dims = n_energy_avg_dims # >= 1 to train with energy
|
123 |
+
self.decoder_use_partial_padding = kwargs.get(
|
124 |
+
"decoder_use_partial_padding", True
|
125 |
+
)
|
126 |
+
self.n_speaker_dim = n_speaker_dim
|
127 |
+
assert self.n_speaker_dim % 2 == 0
|
128 |
+
self.speaker_embedding = torch.nn.Embedding(n_speakers, self.n_speaker_dim)
|
129 |
+
self.embedding = torch.nn.Embedding(n_text, n_text_dim)
|
130 |
+
self.flows = torch.nn.ModuleList()
|
131 |
+
self.encoder = Encoder(
|
132 |
+
encoder_embedding_dim=n_text_dim,
|
133 |
+
norm_fn=nn.InstanceNorm1d,
|
134 |
+
lstm_norm_fn=text_encoder_lstm_norm,
|
135 |
+
)
|
136 |
+
self.dummy_speaker_embedding = dummy_speaker_embedding
|
137 |
+
self.learn_alignments = learn_alignments
|
138 |
+
self.affine_activation = affine_activation
|
139 |
+
self.include_modules = include_modules
|
140 |
+
self.attn_use_CTC = bool(attn_use_CTC)
|
141 |
+
self.use_speaker_emb_for_alignment = use_speaker_emb_for_alignment
|
142 |
+
self.use_context_lstm = bool(use_context_lstm)
|
143 |
+
self.context_lstm_norm = context_lstm_norm
|
144 |
+
self.context_lstm_w_f0_and_energy = context_lstm_w_f0_and_energy
|
145 |
+
self.length_regulator = LengthRegulator()
|
146 |
+
self.use_first_order_features = bool(use_first_order_features)
|
147 |
+
self.decoder_use_unvoiced_bias = kwargs.get("decoder_use_unvoiced_bias", True)
|
148 |
+
self.ap_pred_log_f0 = ap_pred_log_f0
|
149 |
+
self.ap_use_unvoiced_bias = kwargs.get("ap_use_unvoiced_bias", True)
|
150 |
+
self.attn_straight_through_estimator = kwargs.get(
|
151 |
+
"attn_straight_through_estimator", False
|
152 |
+
)
|
153 |
+
if "atn" in include_modules or "dec" in include_modules:
|
154 |
+
if self.learn_alignments:
|
155 |
+
if self.use_speaker_emb_for_alignment:
|
156 |
+
self.attention = ConvAttention(
|
157 |
+
n_mel_channels, n_text_dim + self.n_speaker_dim
|
158 |
+
)
|
159 |
+
else:
|
160 |
+
self.attention = ConvAttention(n_mel_channels, n_text_dim)
|
161 |
+
|
162 |
+
self.n_flows = n_flows
|
163 |
+
self.n_group_size = n_group_size
|
164 |
+
|
165 |
+
n_flowstep_cond_dims = (
|
166 |
+
self.n_speaker_dim
|
167 |
+
+ (n_text_dim + n_f0_dims + n_energy_avg_dims) * n_group_size
|
168 |
+
)
|
169 |
+
|
170 |
+
if self.use_context_lstm:
|
171 |
+
n_in_context_lstm = self.n_speaker_dim + n_text_dim * n_group_size
|
172 |
+
n_context_lstm_hidden = int(
|
173 |
+
(self.n_speaker_dim + n_text_dim * n_group_size) / 2
|
174 |
+
)
|
175 |
+
|
176 |
+
if self.context_lstm_w_f0_and_energy:
|
177 |
+
n_in_context_lstm = n_f0_dims + n_energy_avg_dims + n_text_dim
|
178 |
+
n_in_context_lstm *= n_group_size
|
179 |
+
n_in_context_lstm += self.n_speaker_dim
|
180 |
+
|
181 |
+
n_context_hidden = n_f0_dims + n_energy_avg_dims + n_text_dim
|
182 |
+
n_context_hidden = n_context_hidden * n_group_size / 2
|
183 |
+
n_context_hidden = self.n_speaker_dim + n_context_hidden
|
184 |
+
n_context_hidden = int(n_context_hidden)
|
185 |
+
|
186 |
+
n_flowstep_cond_dims = (
|
187 |
+
self.n_speaker_dim + n_text_dim * n_group_size
|
188 |
+
)
|
189 |
+
|
190 |
+
self.context_lstm = torch.nn.LSTM(
|
191 |
+
input_size=n_in_context_lstm,
|
192 |
+
hidden_size=n_context_lstm_hidden,
|
193 |
+
num_layers=1,
|
194 |
+
batch_first=True,
|
195 |
+
bidirectional=True,
|
196 |
+
)
|
197 |
+
|
198 |
+
if context_lstm_norm is not None:
|
199 |
+
if "spectral" in context_lstm_norm:
|
200 |
+
print("Applying spectral norm to context encoder LSTM")
|
201 |
+
lstm_norm_fn_pntr = torch.nn.utils.spectral_norm
|
202 |
+
elif "weight" in context_lstm_norm:
|
203 |
+
print("Applying weight norm to context encoder LSTM")
|
204 |
+
lstm_norm_fn_pntr = torch.nn.utils.weight_norm
|
205 |
+
|
206 |
+
self.context_lstm = lstm_norm_fn_pntr(
|
207 |
+
self.context_lstm, "weight_hh_l0"
|
208 |
+
)
|
209 |
+
self.context_lstm = lstm_norm_fn_pntr(
|
210 |
+
self.context_lstm, "weight_hh_l0_reverse"
|
211 |
+
)
|
212 |
+
|
213 |
+
if self.n_group_size > 1:
|
214 |
+
self.unfold_params = {
|
215 |
+
"kernel_size": (n_group_size, 1),
|
216 |
+
"stride": n_group_size,
|
217 |
+
"padding": 0,
|
218 |
+
"dilation": 1,
|
219 |
+
}
|
220 |
+
self.unfold = nn.Unfold(**self.unfold_params)
|
221 |
+
|
222 |
+
self.exit_steps = []
|
223 |
+
self.n_early_size = n_early_size
|
224 |
+
n_mel_channels = n_mel_channels * n_group_size
|
225 |
+
|
226 |
+
for i in range(self.n_flows):
|
227 |
+
if i > 0 and i % n_early_every == 0: # early exiting
|
228 |
+
n_mel_channels -= self.n_early_size
|
229 |
+
self.exit_steps.append(i)
|
230 |
+
|
231 |
+
self.flows.append(
|
232 |
+
FlowStep(
|
233 |
+
n_mel_channels,
|
234 |
+
n_flowstep_cond_dims,
|
235 |
+
n_conv_layers_per_step,
|
236 |
+
affine_model,
|
237 |
+
scaling_fn,
|
238 |
+
matrix_decomposition,
|
239 |
+
affine_activation=affine_activation,
|
240 |
+
use_partial_padding=self.decoder_use_partial_padding,
|
241 |
+
)
|
242 |
+
)
|
243 |
+
|
244 |
+
if "dpm" in include_modules:
|
245 |
+
dur_model_config["hparams"]["n_speaker_dim"] = n_speaker_dim
|
246 |
+
self.dur_pred_layer = get_attribute_prediction_model(dur_model_config)
|
247 |
+
|
248 |
+
self.use_unvoiced_bias = False
|
249 |
+
self.use_vpred_module = False
|
250 |
+
self.ap_use_voiced_embeddings = kwargs.get("ap_use_voiced_embeddings", True)
|
251 |
+
|
252 |
+
if self.decoder_use_unvoiced_bias or self.ap_use_unvoiced_bias:
|
253 |
+
assert unvoiced_bias_activation in {"relu", "exp"}
|
254 |
+
self.use_unvoiced_bias = True
|
255 |
+
if unvoiced_bias_activation == "relu":
|
256 |
+
unvbias_nonlin = nn.ReLU()
|
257 |
+
elif unvoiced_bias_activation == "exp":
|
258 |
+
unvbias_nonlin = ExponentialClass()
|
259 |
+
else:
|
260 |
+
exit(1) # we won't reach here anyway due to the assertion
|
261 |
+
self.unvoiced_bias_module = nn.Sequential(
|
262 |
+
LinearNorm(n_text_dim, 1), unvbias_nonlin
|
263 |
+
)
|
264 |
+
|
265 |
+
# all situations in which the vpred module is necessary
|
266 |
+
if (
|
267 |
+
self.ap_use_voiced_embeddings
|
268 |
+
or self.use_unvoiced_bias
|
269 |
+
or "vpred" in include_modules
|
270 |
+
):
|
271 |
+
self.use_vpred_module = True
|
272 |
+
|
273 |
+
if self.use_vpred_module:
|
274 |
+
v_model_config["hparams"]["n_speaker_dim"] = n_speaker_dim
|
275 |
+
self.v_pred_module = get_attribute_prediction_model(v_model_config)
|
276 |
+
# 4 embeddings, first two are scales, second two are biases
|
277 |
+
if self.ap_use_voiced_embeddings:
|
278 |
+
self.v_embeddings = torch.nn.Embedding(4, n_text_dim)
|
279 |
+
|
280 |
+
if "apm" in include_modules:
|
281 |
+
f0_model_config["hparams"]["n_speaker_dim"] = n_speaker_dim
|
282 |
+
energy_model_config["hparams"]["n_speaker_dim"] = n_speaker_dim
|
283 |
+
if self.use_first_order_features:
|
284 |
+
f0_model_config["hparams"]["n_in_dim"] = 2
|
285 |
+
energy_model_config["hparams"]["n_in_dim"] = 2
|
286 |
+
if (
|
287 |
+
"spline_flow_params" in f0_model_config["hparams"]
|
288 |
+
and f0_model_config["hparams"]["spline_flow_params"] is not None
|
289 |
+
):
|
290 |
+
f0_model_config["hparams"]["spline_flow_params"][
|
291 |
+
"n_in_channels"
|
292 |
+
] = 2
|
293 |
+
if (
|
294 |
+
"spline_flow_params" in energy_model_config["hparams"]
|
295 |
+
and energy_model_config["hparams"]["spline_flow_params"] is not None
|
296 |
+
):
|
297 |
+
energy_model_config["hparams"]["spline_flow_params"][
|
298 |
+
"n_in_channels"
|
299 |
+
] = 2
|
300 |
+
else:
|
301 |
+
if (
|
302 |
+
"spline_flow_params" in f0_model_config["hparams"]
|
303 |
+
and f0_model_config["hparams"]["spline_flow_params"] is not None
|
304 |
+
):
|
305 |
+
f0_model_config["hparams"]["spline_flow_params"][
|
306 |
+
"n_in_channels"
|
307 |
+
] = f0_model_config["hparams"]["n_in_dim"]
|
308 |
+
if (
|
309 |
+
"spline_flow_params" in energy_model_config["hparams"]
|
310 |
+
and energy_model_config["hparams"]["spline_flow_params"] is not None
|
311 |
+
):
|
312 |
+
energy_model_config["hparams"]["spline_flow_params"][
|
313 |
+
"n_in_channels"
|
314 |
+
] = energy_model_config["hparams"]["n_in_dim"]
|
315 |
+
|
316 |
+
self.f0_pred_module = get_attribute_prediction_model(f0_model_config)
|
317 |
+
self.energy_pred_module = get_attribute_prediction_model(
|
318 |
+
energy_model_config
|
319 |
+
)
|
320 |
+
|
321 |
+
def is_attribute_unconditional(self):
|
322 |
+
"""
|
323 |
+
returns true if the decoder is conditioned on neither energy nor F0
|
324 |
+
"""
|
325 |
+
return self.n_f0_dims == 0 and self.n_energy_avg_dims == 0
|
326 |
+
|
327 |
+
def encode_speaker(self, spk_ids):
|
328 |
+
spk_ids = spk_ids * 0 if self.dummy_speaker_embedding else spk_ids
|
329 |
+
spk_vecs = self.speaker_embedding(spk_ids)
|
330 |
+
return spk_vecs
|
331 |
+
|
332 |
+
def encode_text(self, text, in_lens):
|
333 |
+
# text_embeddings: b x len_text x n_text_dim
|
334 |
+
text_embeddings = self.embedding(text).transpose(1, 2)
|
335 |
+
# text_enc: b x n_text_dim x encoder_dim (512)
|
336 |
+
if in_lens is None:
|
337 |
+
text_enc = self.encoder.infer(text_embeddings).transpose(1, 2)
|
338 |
+
else:
|
339 |
+
text_enc = self.encoder(text_embeddings, in_lens).transpose(1, 2)
|
340 |
+
|
341 |
+
return text_enc, text_embeddings
|
342 |
+
|
343 |
+
def preprocess_context(
|
344 |
+
self, context, speaker_vecs, out_lens=None, f0=None, energy_avg=None
|
345 |
+
):
|
346 |
+
if self.n_group_size > 1:
|
347 |
+
# unfolding zero-padded values
|
348 |
+
context = self.unfold(context.unsqueeze(-1))
|
349 |
+
if f0 is not None:
|
350 |
+
f0 = self.unfold(f0[:, None, :, None])
|
351 |
+
if energy_avg is not None:
|
352 |
+
energy_avg = self.unfold(energy_avg[:, None, :, None])
|
353 |
+
speaker_vecs = speaker_vecs[..., None].expand(-1, -1, context.shape[2])
|
354 |
+
context_w_spkvec = torch.cat((context, speaker_vecs), 1)
|
355 |
+
|
356 |
+
if self.use_context_lstm:
|
357 |
+
if self.context_lstm_w_f0_and_energy:
|
358 |
+
if f0 is not None:
|
359 |
+
context_w_spkvec = torch.cat((context_w_spkvec, f0), 1)
|
360 |
+
|
361 |
+
if energy_avg is not None:
|
362 |
+
context_w_spkvec = torch.cat((context_w_spkvec, energy_avg), 1)
|
363 |
+
|
364 |
+
unfolded_out_lens = (out_lens // self.n_group_size).long().cpu()
|
365 |
+
unfolded_out_lens_packed = nn.utils.rnn.pack_padded_sequence(
|
366 |
+
context_w_spkvec.transpose(1, 2),
|
367 |
+
unfolded_out_lens,
|
368 |
+
batch_first=True,
|
369 |
+
enforce_sorted=False,
|
370 |
+
)
|
371 |
+
self.context_lstm.flatten_parameters()
|
372 |
+
context_lstm_packed_output, _ = self.context_lstm(unfolded_out_lens_packed)
|
373 |
+
context_lstm_padded_output, _ = nn.utils.rnn.pad_packed_sequence(
|
374 |
+
context_lstm_packed_output, batch_first=True
|
375 |
+
)
|
376 |
+
context_w_spkvec = context_lstm_padded_output.transpose(1, 2)
|
377 |
+
|
378 |
+
if not self.context_lstm_w_f0_and_energy:
|
379 |
+
if f0 is not None:
|
380 |
+
context_w_spkvec = torch.cat((context_w_spkvec, f0), 1)
|
381 |
+
|
382 |
+
if energy_avg is not None:
|
383 |
+
context_w_spkvec = torch.cat((context_w_spkvec, energy_avg), 1)
|
384 |
+
|
385 |
+
return context_w_spkvec
|
386 |
+
|
387 |
+
def enable_inverse_cache(self):
|
388 |
+
for flow_step in self.flows:
|
389 |
+
flow_step.enable_inverse_cache()
|
390 |
+
|
391 |
+
def fold(self, mel):
|
392 |
+
"""Inverse of the self.unfold(mel.unsqueeze(-1)) operation used for the
|
393 |
+
grouping or "squeeze" operation on input
|
394 |
+
|
395 |
+
Args:
|
396 |
+
mel: B x C x T tensor of temporal data
|
397 |
+
"""
|
398 |
+
mel = nn.functional.fold(
|
399 |
+
mel, output_size=(mel.shape[2] * self.n_group_size, 1), **self.unfold_params
|
400 |
+
).squeeze(-1)
|
401 |
+
return mel
|
402 |
+
|
403 |
+
def binarize_attention(self, attn, in_lens, out_lens):
|
404 |
+
"""For training purposes only. Binarizes attention with MAS. These will
|
405 |
+
no longer receive a gradient
|
406 |
+
Args:
|
407 |
+
attn: B x 1 x max_mel_len x max_text_len
|
408 |
+
"""
|
409 |
+
b_size = attn.shape[0]
|
410 |
+
with torch.no_grad():
|
411 |
+
attn_cpu = attn.data.cpu().numpy()
|
412 |
+
attn_out = torch.zeros_like(attn)
|
413 |
+
for ind in range(b_size):
|
414 |
+
hard_attn = mas(attn_cpu[ind, 0, : out_lens[ind], : in_lens[ind]])
|
415 |
+
attn_out[ind, 0, : out_lens[ind], : in_lens[ind]] = torch.tensor(
|
416 |
+
hard_attn, device=attn.get_device()
|
417 |
+
)
|
418 |
+
return attn_out
|
419 |
+
|
420 |
+
def get_first_order_features(self, feats, out_lens, dilation=1):
|
421 |
+
"""
|
422 |
+
feats: b x max_length
|
423 |
+
out_lens: b-dim
|
424 |
+
"""
|
425 |
+
# add an extra column
|
426 |
+
feats_extended_R = torch.cat(
|
427 |
+
(feats, torch.zeros_like(feats[:, 0:dilation])), dim=1
|
428 |
+
)
|
429 |
+
feats_extended_L = torch.cat(
|
430 |
+
(torch.zeros_like(feats[:, 0:dilation]), feats), dim=1
|
431 |
+
)
|
432 |
+
dfeats_R = feats_extended_R[:, dilation:] - feats
|
433 |
+
dfeats_L = feats - feats_extended_L[:, 0:-dilation]
|
434 |
+
|
435 |
+
return (dfeats_R + dfeats_L) * 0.5
|
436 |
+
|
437 |
+
def apply_voice_mask_to_text(self, text_enc, voiced_mask):
|
438 |
+
"""
|
439 |
+
text_enc: b x C x N
|
440 |
+
voiced_mask: b x N
|
441 |
+
"""
|
442 |
+
voiced_mask = voiced_mask.unsqueeze(1)
|
443 |
+
voiced_embedding_s = self.v_embeddings.weight[0:1, :, None]
|
444 |
+
unvoiced_embedding_s = self.v_embeddings.weight[1:2, :, None]
|
445 |
+
voiced_embedding_b = self.v_embeddings.weight[2:3, :, None]
|
446 |
+
unvoiced_embedding_b = self.v_embeddings.weight[3:4, :, None]
|
447 |
+
scale = torch.sigmoid(
|
448 |
+
voiced_embedding_s * voiced_mask + unvoiced_embedding_s * (1 - voiced_mask)
|
449 |
+
)
|
450 |
+
bias = 0.1 * torch.tanh(
|
451 |
+
voiced_embedding_b * voiced_mask + unvoiced_embedding_b * (1 - voiced_mask)
|
452 |
+
)
|
453 |
+
return text_enc * scale + bias
|
454 |
+
|
455 |
+
def forward(
|
456 |
+
self,
|
457 |
+
mel,
|
458 |
+
speaker_ids,
|
459 |
+
text,
|
460 |
+
in_lens,
|
461 |
+
out_lens,
|
462 |
+
binarize_attention=False,
|
463 |
+
attn_prior=None,
|
464 |
+
f0=None,
|
465 |
+
energy_avg=None,
|
466 |
+
voiced_mask=None,
|
467 |
+
p_voiced=None,
|
468 |
+
):
|
469 |
+
speaker_vecs = self.encode_speaker(speaker_ids)
|
470 |
+
text_enc, text_embeddings = self.encode_text(text, in_lens)
|
471 |
+
|
472 |
+
log_s_list, log_det_W_list, z_mel = [], [], []
|
473 |
+
attn = None
|
474 |
+
attn_soft = None
|
475 |
+
attn_hard = None
|
476 |
+
if "atn" in self.include_modules or "dec" in self.include_modules:
|
477 |
+
# make sure to do the alignments before folding
|
478 |
+
attn_mask = get_mask_from_lengths(in_lens)[..., None] == 0
|
479 |
+
|
480 |
+
text_embeddings_for_attn = text_embeddings
|
481 |
+
if self.use_speaker_emb_for_alignment:
|
482 |
+
speaker_vecs_expd = speaker_vecs[:, :, None].expand(
|
483 |
+
-1, -1, text_embeddings.shape[2]
|
484 |
+
)
|
485 |
+
text_embeddings_for_attn = torch.cat(
|
486 |
+
(text_embeddings_for_attn, speaker_vecs_expd.detach()), 1
|
487 |
+
)
|
488 |
+
|
489 |
+
# attn_mask should be 1 for unused timesteps in the text_enc_w_spkvec tensor
|
490 |
+
attn_soft, attn_logprob = self.attention(
|
491 |
+
mel,
|
492 |
+
text_embeddings_for_attn,
|
493 |
+
out_lens,
|
494 |
+
attn_mask,
|
495 |
+
key_lens=in_lens,
|
496 |
+
attn_prior=attn_prior,
|
497 |
+
)
|
498 |
+
|
499 |
+
if binarize_attention:
|
500 |
+
attn = self.binarize_attention(attn_soft, in_lens, out_lens)
|
501 |
+
attn_hard = attn
|
502 |
+
if self.attn_straight_through_estimator:
|
503 |
+
attn_hard = attn_soft + (attn_hard - attn_soft).detach()
|
504 |
+
else:
|
505 |
+
attn = attn_soft
|
506 |
+
|
507 |
+
context = torch.bmm(text_enc, attn.squeeze(1).transpose(1, 2))
|
508 |
+
|
509 |
+
f0_bias = 0
|
510 |
+
# unvoiced bias forward pass
|
511 |
+
if self.use_unvoiced_bias:
|
512 |
+
f0_bias = self.unvoiced_bias_module(context.permute(0, 2, 1))
|
513 |
+
f0_bias = -f0_bias[..., 0]
|
514 |
+
f0_bias = f0_bias * (~voiced_mask.bool()).float()
|
515 |
+
|
516 |
+
# mel decoder forward pass
|
517 |
+
if "dec" in self.include_modules:
|
518 |
+
if self.n_group_size > 1:
|
519 |
+
# might truncate some frames at the end, but that's ok
|
520 |
+
# sometimes referred to as the "squeeze" operation
|
521 |
+
# invert this by calling self.fold(mel_or_z)
|
522 |
+
mel = self.unfold(mel.unsqueeze(-1))
|
523 |
+
z_out = []
|
524 |
+
# where context is folded
|
525 |
+
# mask f0 in case values are interpolated
|
526 |
+
|
527 |
+
if f0 is None:
|
528 |
+
f0_aug = None
|
529 |
+
else:
|
530 |
+
if self.decoder_use_unvoiced_bias:
|
531 |
+
f0_aug = f0 * voiced_mask + f0_bias
|
532 |
+
else:
|
533 |
+
f0_aug = f0 * voiced_mask
|
534 |
+
|
535 |
+
context_w_spkvec = self.preprocess_context(
|
536 |
+
context, speaker_vecs, out_lens, f0_aug, energy_avg
|
537 |
+
)
|
538 |
+
|
539 |
+
log_s_list, log_det_W_list, z_out = [], [], []
|
540 |
+
unfolded_seq_lens = out_lens // self.n_group_size
|
541 |
+
for i, flow_step in enumerate(self.flows):
|
542 |
+
if i in self.exit_steps:
|
543 |
+
z = mel[:, : self.n_early_size]
|
544 |
+
z_out.append(z)
|
545 |
+
mel = mel[:, self.n_early_size :]
|
546 |
+
mel, log_det_W, log_s = flow_step(
|
547 |
+
mel, context_w_spkvec, seq_lens=unfolded_seq_lens
|
548 |
+
)
|
549 |
+
log_s_list.append(log_s)
|
550 |
+
log_det_W_list.append(log_det_W)
|
551 |
+
|
552 |
+
z_out.append(mel)
|
553 |
+
z_mel = torch.cat(z_out, 1)
|
554 |
+
|
555 |
+
# duration predictor forward pass
|
556 |
+
duration_model_outputs = None
|
557 |
+
if "dpm" in self.include_modules:
|
558 |
+
if attn_hard is None:
|
559 |
+
attn_hard = self.binarize_attention(attn_soft, in_lens, out_lens)
|
560 |
+
|
561 |
+
# convert hard attention to durations
|
562 |
+
attn_hard_reduced = attn_hard.sum(2)[:, 0, :]
|
563 |
+
duration_model_outputs = self.dur_pred_layer(
|
564 |
+
torch.detach(text_enc),
|
565 |
+
torch.detach(speaker_vecs),
|
566 |
+
torch.detach(attn_hard_reduced.float()),
|
567 |
+
in_lens,
|
568 |
+
)
|
569 |
+
|
570 |
+
# f0, energy, vpred predictors forward pass
|
571 |
+
f0_model_outputs = None
|
572 |
+
energy_model_outputs = None
|
573 |
+
vpred_model_outputs = None
|
574 |
+
if "apm" in self.include_modules:
|
575 |
+
if attn_hard is None:
|
576 |
+
attn_hard = self.binarize_attention(attn_soft, in_lens, out_lens)
|
577 |
+
|
578 |
+
# convert hard attention to durations
|
579 |
+
if binarize_attention:
|
580 |
+
text_enc_time_expanded = context.clone()
|
581 |
+
else:
|
582 |
+
text_enc_time_expanded = torch.bmm(
|
583 |
+
text_enc, attn_hard.squeeze(1).transpose(1, 2)
|
584 |
+
)
|
585 |
+
|
586 |
+
if self.use_vpred_module:
|
587 |
+
# unvoiced bias requires voiced mask prediction
|
588 |
+
vpred_model_outputs = self.v_pred_module(
|
589 |
+
torch.detach(text_enc_time_expanded),
|
590 |
+
torch.detach(speaker_vecs),
|
591 |
+
torch.detach(voiced_mask),
|
592 |
+
out_lens,
|
593 |
+
)
|
594 |
+
|
595 |
+
# affine transform context using voiced mask
|
596 |
+
if self.ap_use_voiced_embeddings:
|
597 |
+
text_enc_time_expanded = self.apply_voice_mask_to_text(
|
598 |
+
text_enc_time_expanded, voiced_mask
|
599 |
+
)
|
600 |
+
|
601 |
+
# whether to use the unvoiced bias in the attribute predictor
|
602 |
+
# circumvent in-place modification
|
603 |
+
f0_target = f0.clone()
|
604 |
+
if self.ap_use_unvoiced_bias:
|
605 |
+
f0_target = torch.detach(f0_target * voiced_mask + f0_bias)
|
606 |
+
else:
|
607 |
+
f0_target = torch.detach(f0_target)
|
608 |
+
|
609 |
+
# fit to log f0 in f0 predictor
|
610 |
+
f0_target[voiced_mask.bool()] = torch.log(f0_target[voiced_mask.bool()])
|
611 |
+
f0_target = f0_target / 6 # scale to ~ [0, 1] in log space
|
612 |
+
energy_avg = energy_avg * 2 - 1 # scale to ~ [-1, 1]
|
613 |
+
|
614 |
+
if self.use_first_order_features:
|
615 |
+
df0 = self.get_first_order_features(f0_target, out_lens)
|
616 |
+
denergy_avg = self.get_first_order_features(energy_avg, out_lens)
|
617 |
+
|
618 |
+
f0_voiced = torch.cat((f0_target[:, None], df0[:, None]), dim=1)
|
619 |
+
energy_avg = torch.cat(
|
620 |
+
(energy_avg[:, None], denergy_avg[:, None]), dim=1
|
621 |
+
)
|
622 |
+
|
623 |
+
f0_voiced = f0_voiced * 3 # scale to ~ 1 std
|
624 |
+
energy_avg = energy_avg * 3 # scale to ~ 1 std
|
625 |
+
else:
|
626 |
+
f0_voiced = f0_target * 2 # scale to ~ 1 std
|
627 |
+
energy_avg = energy_avg * 1.4 # scale to ~ 1 std
|
628 |
+
|
629 |
+
f0_model_outputs = self.f0_pred_module(
|
630 |
+
text_enc_time_expanded, torch.detach(speaker_vecs), f0_voiced, out_lens
|
631 |
+
)
|
632 |
+
|
633 |
+
energy_model_outputs = self.energy_pred_module(
|
634 |
+
text_enc_time_expanded, torch.detach(speaker_vecs), energy_avg, out_lens
|
635 |
+
)
|
636 |
+
|
637 |
+
outputs = {
|
638 |
+
"z_mel": z_mel,
|
639 |
+
"log_det_W_list": log_det_W_list,
|
640 |
+
"log_s_list": log_s_list,
|
641 |
+
"duration_model_outputs": duration_model_outputs,
|
642 |
+
"f0_model_outputs": f0_model_outputs,
|
643 |
+
"energy_model_outputs": energy_model_outputs,
|
644 |
+
"vpred_model_outputs": vpred_model_outputs,
|
645 |
+
"attn_soft": attn_soft,
|
646 |
+
"attn": attn,
|
647 |
+
"text_embeddings": text_embeddings,
|
648 |
+
"attn_logprob": attn_logprob,
|
649 |
+
}
|
650 |
+
|
651 |
+
return outputs
|
652 |
+
|
653 |
+
def infer(
|
654 |
+
self,
|
655 |
+
speaker_id,
|
656 |
+
text,
|
657 |
+
sigma,
|
658 |
+
sigma_dur=0.8,
|
659 |
+
sigma_f0=0.8,
|
660 |
+
sigma_energy=0.8,
|
661 |
+
token_dur_scaling=1.0,
|
662 |
+
token_duration_max=100,
|
663 |
+
speaker_id_text=None,
|
664 |
+
speaker_id_attributes=None,
|
665 |
+
dur=None,
|
666 |
+
f0=None,
|
667 |
+
energy_avg=None,
|
668 |
+
voiced_mask=None,
|
669 |
+
f0_mean=0.0,
|
670 |
+
f0_std=0.0,
|
671 |
+
energy_mean=0.0,
|
672 |
+
energy_std=0.0,
|
673 |
+
use_cuda=False,
|
674 |
+
):
|
675 |
+
batch_size = text.shape[0]
|
676 |
+
n_tokens = text.shape[1]
|
677 |
+
spk_vec = self.encode_speaker(speaker_id)
|
678 |
+
spk_vec_text, spk_vec_attributes = spk_vec, spk_vec
|
679 |
+
if speaker_id_text is not None:
|
680 |
+
spk_vec_text = self.encode_speaker(speaker_id_text)
|
681 |
+
if speaker_id_attributes is not None:
|
682 |
+
spk_vec_attributes = self.encode_speaker(speaker_id_attributes)
|
683 |
+
|
684 |
+
txt_enc, txt_emb = self.encode_text(text, None)
|
685 |
+
|
686 |
+
if dur is None:
|
687 |
+
# get token durations
|
688 |
+
if use_cuda:
|
689 |
+
z_dur = torch.cuda.FloatTensor(batch_size, 1, n_tokens)
|
690 |
+
else:
|
691 |
+
z_dur = torch.FloatTensor(batch_size, 1, n_tokens)
|
692 |
+
|
693 |
+
z_dur = z_dur.normal_() * sigma_dur
|
694 |
+
|
695 |
+
dur = self.dur_pred_layer.infer(z_dur, txt_enc, spk_vec_text)
|
696 |
+
if dur.shape[-1] < txt_enc.shape[-1]:
|
697 |
+
to_pad = txt_enc.shape[-1] - dur.shape[2]
|
698 |
+
pad_fn = nn.ReplicationPad1d((0, to_pad))
|
699 |
+
dur = pad_fn(dur)
|
700 |
+
dur = dur[:, 0]
|
701 |
+
dur = dur.clamp(0, token_duration_max)
|
702 |
+
dur = dur * token_dur_scaling if token_dur_scaling > 0 else dur
|
703 |
+
dur = (dur + 0.5).floor().int()
|
704 |
+
|
705 |
+
out_lens = dur.sum(1).long().cpu() if dur.shape[0] != 1 else [dur.sum(1)]
|
706 |
+
max_n_frames = max(out_lens)
|
707 |
+
|
708 |
+
out_lens = torch.LongTensor(out_lens).to(txt_enc.device)
|
709 |
+
|
710 |
+
# get attributes (f0, energy, vpred, etc.)
|
711 |
+
txt_enc_time_expanded = self.length_regulator(
|
712 |
+
txt_enc.transpose(1, 2), dur
|
713 |
+
).transpose(1, 2)
|
714 |
+
|
715 |
+
if not self.is_attribute_unconditional():
|
716 |
+
# if explicitly modeling attributes
|
717 |
+
if voiced_mask is None:
|
718 |
+
if self.use_vpred_module:
|
719 |
+
# get logits
|
720 |
+
voiced_mask = self.v_pred_module.infer(
|
721 |
+
None, txt_enc_time_expanded, spk_vec_attributes
|
722 |
+
)
|
723 |
+
voiced_mask = torch.sigmoid(voiced_mask[:, 0]) > 0.5
|
724 |
+
voiced_mask = voiced_mask.float()
|
725 |
+
|
726 |
+
ap_txt_enc_time_expanded = txt_enc_time_expanded
|
727 |
+
# voice mask augmentation only used for attribute prediction
|
728 |
+
if self.ap_use_voiced_embeddings:
|
729 |
+
ap_txt_enc_time_expanded = self.apply_voice_mask_to_text(
|
730 |
+
txt_enc_time_expanded, voiced_mask
|
731 |
+
)
|
732 |
+
|
733 |
+
f0_bias = 0
|
734 |
+
# unvoiced bias forward pass
|
735 |
+
if self.use_unvoiced_bias:
|
736 |
+
f0_bias = self.unvoiced_bias_module(
|
737 |
+
txt_enc_time_expanded.permute(0, 2, 1)
|
738 |
+
)
|
739 |
+
f0_bias = -f0_bias[..., 0]
|
740 |
+
f0_bias = f0_bias * (~voiced_mask.bool()).float()
|
741 |
+
|
742 |
+
if f0 is None:
|
743 |
+
n_f0_feature_channels = 2 if self.use_first_order_features else 1
|
744 |
+
|
745 |
+
if use_cuda:
|
746 |
+
z_f0 = (
|
747 |
+
torch.cuda.FloatTensor(
|
748 |
+
batch_size, n_f0_feature_channels, max_n_frames
|
749 |
+
).normal_()
|
750 |
+
* sigma_f0
|
751 |
+
)
|
752 |
+
else:
|
753 |
+
z_f0 = (
|
754 |
+
torch.FloatTensor(
|
755 |
+
batch_size, n_f0_feature_channels, max_n_frames
|
756 |
+
).normal_()
|
757 |
+
* sigma_f0
|
758 |
+
)
|
759 |
+
|
760 |
+
f0 = self.infer_f0(
|
761 |
+
z_f0,
|
762 |
+
ap_txt_enc_time_expanded,
|
763 |
+
spk_vec_attributes,
|
764 |
+
voiced_mask,
|
765 |
+
out_lens,
|
766 |
+
)[:, 0]
|
767 |
+
|
768 |
+
if f0_mean > 0.0:
|
769 |
+
vmask_bool = voiced_mask.bool()
|
770 |
+
f0_mu, f0_sigma = f0[vmask_bool].mean(), f0[vmask_bool].std()
|
771 |
+
f0[vmask_bool] = (f0[vmask_bool] - f0_mu) / f0_sigma
|
772 |
+
f0_std = f0_std if f0_std > 0 else f0_sigma
|
773 |
+
f0[vmask_bool] = f0[vmask_bool] * f0_std + f0_mean
|
774 |
+
|
775 |
+
if energy_avg is None:
|
776 |
+
n_energy_feature_channels = 2 if self.use_first_order_features else 1
|
777 |
+
if use_cuda:
|
778 |
+
z_energy_avg = (
|
779 |
+
torch.cuda.FloatTensor(
|
780 |
+
batch_size, n_energy_feature_channels, max_n_frames
|
781 |
+
).normal_()
|
782 |
+
* sigma_energy
|
783 |
+
)
|
784 |
+
else:
|
785 |
+
z_energy_avg = (
|
786 |
+
torch.FloatTensor(
|
787 |
+
batch_size, n_energy_feature_channels, max_n_frames
|
788 |
+
).normal_()
|
789 |
+
* sigma_energy
|
790 |
+
)
|
791 |
+
energy_avg = self.infer_energy(
|
792 |
+
z_energy_avg, ap_txt_enc_time_expanded, spk_vec, out_lens
|
793 |
+
)[:, 0]
|
794 |
+
|
795 |
+
# replication pad, because ungrouping with different group sizes
|
796 |
+
# may lead to mismatched lengths
|
797 |
+
if energy_avg.shape[1] < out_lens[0]:
|
798 |
+
to_pad = out_lens[0] - energy_avg.shape[1]
|
799 |
+
pad_fn = nn.ReplicationPad1d((0, to_pad))
|
800 |
+
f0 = pad_fn(f0[None])[0]
|
801 |
+
energy_avg = pad_fn(energy_avg[None])[0]
|
802 |
+
if f0.shape[1] < out_lens[0]:
|
803 |
+
to_pad = out_lens[0] - f0.shape[1]
|
804 |
+
pad_fn = nn.ReplicationPad1d((0, to_pad))
|
805 |
+
f0 = pad_fn(f0[None])[0]
|
806 |
+
|
807 |
+
if self.decoder_use_unvoiced_bias:
|
808 |
+
context_w_spkvec = self.preprocess_context(
|
809 |
+
txt_enc_time_expanded,
|
810 |
+
spk_vec,
|
811 |
+
out_lens,
|
812 |
+
f0 * voiced_mask + f0_bias,
|
813 |
+
energy_avg,
|
814 |
+
)
|
815 |
+
else:
|
816 |
+
context_w_spkvec = self.preprocess_context(
|
817 |
+
txt_enc_time_expanded,
|
818 |
+
spk_vec,
|
819 |
+
out_lens,
|
820 |
+
f0 * voiced_mask,
|
821 |
+
energy_avg,
|
822 |
+
)
|
823 |
+
else:
|
824 |
+
context_w_spkvec = self.preprocess_context(
|
825 |
+
txt_enc_time_expanded, spk_vec, out_lens, None, None
|
826 |
+
)
|
827 |
+
|
828 |
+
if use_cuda:
|
829 |
+
residual = torch.cuda.FloatTensor(
|
830 |
+
batch_size, 80 * self.n_group_size, max_n_frames // self.n_group_size
|
831 |
+
)
|
832 |
+
else:
|
833 |
+
residual = torch.FloatTensor(
|
834 |
+
batch_size, 80 * self.n_group_size, max_n_frames // self.n_group_size
|
835 |
+
)
|
836 |
+
|
837 |
+
residual = residual.normal_() * sigma
|
838 |
+
|
839 |
+
# map from z sample to data
|
840 |
+
exit_steps_stack = self.exit_steps.copy()
|
841 |
+
mel = residual[:, len(exit_steps_stack) * self.n_early_size :]
|
842 |
+
remaining_residual = residual[:, : len(exit_steps_stack) * self.n_early_size]
|
843 |
+
unfolded_seq_lens = out_lens // self.n_group_size
|
844 |
+
for i, flow_step in enumerate(reversed(self.flows)):
|
845 |
+
curr_step = len(self.flows) - i - 1
|
846 |
+
mel = flow_step(
|
847 |
+
mel, context_w_spkvec, inverse=True, seq_lens=unfolded_seq_lens
|
848 |
+
)
|
849 |
+
if len(exit_steps_stack) > 0 and curr_step == exit_steps_stack[-1]:
|
850 |
+
# concatenate the next chunk of z
|
851 |
+
exit_steps_stack.pop()
|
852 |
+
residual_to_add = remaining_residual[
|
853 |
+
:, len(exit_steps_stack) * self.n_early_size :
|
854 |
+
]
|
855 |
+
remaining_residual = remaining_residual[
|
856 |
+
:, : len(exit_steps_stack) * self.n_early_size
|
857 |
+
]
|
858 |
+
mel = torch.cat((residual_to_add, mel), 1)
|
859 |
+
|
860 |
+
if self.n_group_size > 1:
|
861 |
+
mel = self.fold(mel)
|
862 |
+
if self.do_mel_descaling:
|
863 |
+
mel = mel * 2 - 5.5
|
864 |
+
|
865 |
+
return {
|
866 |
+
"mel": mel,
|
867 |
+
"dur": dur,
|
868 |
+
"f0": f0,
|
869 |
+
"energy_avg": energy_avg,
|
870 |
+
"voiced_mask": voiced_mask,
|
871 |
+
}
|
872 |
+
|
873 |
+
def infer_f0(
|
874 |
+
self, residual, txt_enc_time_expanded, spk_vec, voiced_mask=None, lens=None
|
875 |
+
):
|
876 |
+
f0 = self.f0_pred_module.infer(residual, txt_enc_time_expanded, spk_vec, lens)
|
877 |
+
|
878 |
+
if voiced_mask is not None and len(voiced_mask.shape) == 2:
|
879 |
+
voiced_mask = voiced_mask[:, None]
|
880 |
+
|
881 |
+
# constants
|
882 |
+
if self.ap_pred_log_f0:
|
883 |
+
if self.use_first_order_features:
|
884 |
+
f0 = f0[:, 0:1, :] / 3
|
885 |
+
else:
|
886 |
+
f0 = f0 / 2
|
887 |
+
f0 = f0 * 6
|
888 |
+
else:
|
889 |
+
f0 = f0 / 6
|
890 |
+
f0 = f0 / 640
|
891 |
+
|
892 |
+
if voiced_mask is None:
|
893 |
+
voiced_mask = f0 > 0.0
|
894 |
+
else:
|
895 |
+
voiced_mask = voiced_mask.bool()
|
896 |
+
|
897 |
+
# due to grouping, f0 might be 1 frame short
|
898 |
+
voiced_mask = voiced_mask[:, :, : f0.shape[-1]]
|
899 |
+
if self.ap_pred_log_f0:
|
900 |
+
# if variable is set, decoder sees linear f0
|
901 |
+
# mask = f0 > 0.0 if voiced_mask is None else voiced_mask.bool()
|
902 |
+
f0[voiced_mask] = torch.exp(f0[voiced_mask])
|
903 |
+
f0[~voiced_mask] = 0.0
|
904 |
+
return f0
|
905 |
+
|
906 |
+
def infer_energy(self, residual, txt_enc_time_expanded, spk_vec, lens):
|
907 |
+
energy = self.energy_pred_module.infer(
|
908 |
+
residual, txt_enc_time_expanded, spk_vec, lens
|
909 |
+
)
|
910 |
+
|
911 |
+
# magic constants
|
912 |
+
if self.use_first_order_features:
|
913 |
+
energy = energy / 3
|
914 |
+
else:
|
915 |
+
energy = energy / 1.4
|
916 |
+
energy = (energy + 1) / 2
|
917 |
+
return energy
|
918 |
+
|
919 |
+
def remove_norms(self):
|
920 |
+
"""Removes spectral and weightnorms from model. Call before inference"""
|
921 |
+
for name, module in self.named_modules():
|
922 |
+
try:
|
923 |
+
nn.utils.remove_spectral_norm(module, name="weight_hh_l0")
|
924 |
+
print("Removed spectral norm from {}".format(name))
|
925 |
+
except:
|
926 |
+
pass
|
927 |
+
try:
|
928 |
+
nn.utils.remove_spectral_norm(module, name="weight_hh_l0_reverse")
|
929 |
+
print("Removed spectral norm from {}".format(name))
|
930 |
+
except:
|
931 |
+
pass
|
932 |
+
try:
|
933 |
+
nn.utils.remove_weight_norm(module)
|
934 |
+
print("Removed wnorm from {}".format(name))
|
935 |
+
except:
|
936 |
+
pass
|
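
The `infer` method above is the whole synthesis chain: sample token durations, optionally sample a voiced mask, F0 and energy, build the decoder context, then invert the flow steps into a mel spectrogram. A minimal usage sketch follows; the config key, speaker index and token ids are placeholders for illustration, not the exact values this Space uses (app.py does the real wiring, including text encoding and vocoding).

```python
# Sketch: running RADTTS.infer end to end (paths and ids below are placeholder assumptions).
import json
import torch
from radtts import RADTTS

# Assumption: the training config exposes the model kwargs under a "model_config" key.
with open("configs/radtts-pp-dap-model.json") as f:
    model_kwargs = json.load(f)["model_config"]

model = RADTTS(**model_kwargs)
model.remove_norms()   # strip weight/spectral norms before inference
model.eval()

speaker_id = torch.LongTensor([0])           # assumed speaker index
text = torch.LongTensor([[10, 42, 7, 3]])    # assumed encoded symbol ids, shape (1, n_tokens)

with torch.no_grad():
    out = model.infer(speaker_id, text, sigma=0.8,
                      sigma_dur=0.8, sigma_f0=0.8, sigma_energy=0.8)

mel = out["mel"]   # (1, 80, n_frames); a vocoder such as the bundled BigVGAN turns this into audio
print(mel.shape, out["dur"].sum().item())
```
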
requirements-dev.txt
ADDED
@@ -0,0 +1 @@
1 |
+
ruff
|
requirements.txt
ADDED
@@ -0,0 +1,13 @@
1 |
+
huggingface_hub
|
2 |
+
|
3 |
+
gradio==5.18.0
|
4 |
+
|
5 |
+
torch
|
6 |
+
torchaudio
|
7 |
+
scipy
|
8 |
+
numba
|
9 |
+
lmdb
|
10 |
+
librosa
|
11 |
+
|
12 |
+
unidecode
|
13 |
+
inflect
|
splines.py
ADDED
@@ -0,0 +1,326 @@
1 |
+
# Original Source:
|
2 |
+
# Original Source:
|
3 |
+
# https://github.com/ndeutschmann/zunis/blob/master/zunis_lib/zunis/models/flows/coupling_cells/piecewise_coupling/piecewise_linear.py
|
4 |
+
# https://github.com/ndeutschmann/zunis/blob/master/zunis_lib/zunis/models/flows/coupling_cells/piecewise_coupling/piecewise_quadratic.py
|
5 |
+
# Modifications made to jacobian computation by Yurong You and Kevin Shih
|
6 |
+
# Original License Text:
|
7 |
+
#########################################################################
|
8 |
+
|
9 |
+
# The MIT License (MIT)
|
10 |
+
# Copyright (c) 2020, nicolas deutschmann
|
11 |
+
|
12 |
+
# Permission is hereby granted, free of charge, to any person obtaining
|
13 |
+
# a copy of this software and associated documentation files (the
|
14 |
+
# "Software"), to deal in the Software without restriction, including
|
15 |
+
# without limitation the rights to use, copy, modify, merge, publish,
|
16 |
+
# distribute, sublicense, and/or sell copies of the Software, and to
|
17 |
+
# permit persons to whom the Software is furnished to do so, subject to
|
18 |
+
# the following conditions:
|
19 |
+
|
20 |
+
# The above copyright notice and this permission notice shall be
|
21 |
+
# included in all copies or substantial portions of the Software.
|
22 |
+
|
23 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
24 |
+
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
25 |
+
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
26 |
+
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
27 |
+
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
28 |
+
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
29 |
+
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
30 |
+
|
31 |
+
|
32 |
+
import torch
|
33 |
+
import torch.nn.functional as F
|
34 |
+
|
35 |
+
third_dimension_softmax = torch.nn.Softmax(dim=2)
|
36 |
+
|
37 |
+
|
38 |
+
def piecewise_linear_transform(
|
39 |
+
x, q_tilde, compute_jacobian=True, outlier_passthru=True
|
40 |
+
):
|
41 |
+
"""Apply an element-wise piecewise-linear transformation to some variables
|
42 |
+
|
43 |
+
Parameters
|
44 |
+
----------
|
45 |
+
x : torch.Tensor
|
46 |
+
a tensor with shape (N,k) where N is the batch dimension while k is the
|
47 |
+
dimension of the variable space. This variable spans the k-dimensional unit
|
48 |
+
hypercube
|
49 |
+
|
50 |
+
q_tilde: torch.Tensor
|
51 |
+
is a tensor with shape (N,k,b) where b is the number of bins.
|
52 |
+
This contains the un-normalized heights of the bins of the piecewise-constant PDF for dimension k,
|
53 |
+
i.e. q_tilde lives in all of R and we don't impose a constraint on their sum yet.
|
54 |
+
Normalization is imposed in this function using softmax.
|
55 |
+
|
56 |
+
compute_jacobian : bool, optional
|
57 |
+
determines whether the jacobian should be computed; otherwise None is returned
|
58 |
+
|
59 |
+
Returns
|
60 |
+
-------
|
61 |
+
tuple of torch.Tensor
|
62 |
+
pair `(y, j)`.
|
63 |
+
- `y` is a tensor with shape (N,k) living in the k-dimensional unit hypercube
|
64 |
+
- `j` is the jacobian of the transformation with shape (N,) if compute_jacobian==True, else None.
|
65 |
+
"""
|
66 |
+
logj = None
|
67 |
+
|
68 |
+
# TODO bottom-up assessment of handling the differentiability of variables
|
69 |
+
# Compute the bin width w
|
70 |
+
N, k, b = q_tilde.shape
|
71 |
+
Nx, kx = x.shape
|
72 |
+
assert N == Nx and k == kx, "Shape mismatch"
|
73 |
+
|
74 |
+
w = 1.0 / b
|
75 |
+
|
76 |
+
# Compute normalized bin heights with softmax function on bin dimension
|
77 |
+
q = 1.0 / w * third_dimension_softmax(q_tilde)
|
78 |
+
# x is in the mx-th bin: x \in [0,1],
|
79 |
+
# mx \in [[0,b-1]], so we clamp away the case x == 1
|
80 |
+
mx = torch.clamp(torch.floor(b * x), 0, b - 1).to(torch.long)
|
81 |
+
# Need special error handling because trying to index with mx
|
82 |
+
# if it contains nans will lock the GPU. (device-side assert triggered)
|
83 |
+
if torch.any(torch.isnan(mx)).item() or torch.any(mx < 0) or torch.any(mx >= b):
|
84 |
+
raise Exception("NaN detected in PWLinear bin indexing")
|
85 |
+
|
86 |
+
# We compute the output variable in-place
|
87 |
+
out = x - mx * w  # alpha (element of [0., w]), the position of x within its bin
|
88 |
+
|
89 |
+
# Multiply by the slope
|
90 |
+
# q has shape (N,k,b), mxu = mx.unsqueeze(-1) has shape (N,k) with entries that are a b-index
|
91 |
+
# gather defines slope[i, j, k] = q[i, j, mxu[i, j, k]] with k taking only 0 as a value
|
92 |
+
# i.e. we say slope[i, j] = q[i, j, mx [i, j]]
|
93 |
+
slopes = torch.gather(q, 2, mx.unsqueeze(-1)).squeeze(-1)
|
94 |
+
out = out * slopes
|
95 |
+
# The jacobian is the product of the slopes in all dimensions
|
96 |
+
|
97 |
+
# Compute the integral over the left-bins.
|
98 |
+
# 1. Compute all integrals: cumulative sum of bin height * bin weight.
|
99 |
+
# We want that index i contains the cumsum *strictly to the left* so we shift by 1
|
100 |
+
# leaving the first entry null, which is achieved with a roll and assignment
|
101 |
+
q_left_integrals = torch.roll(torch.cumsum(q, 2) * w, 1, 2)
|
102 |
+
q_left_integrals[:, :, 0] = 0
|
103 |
+
|
104 |
+
# 2. Access the correct index to get the left integral of each point and add it to our transformation
|
105 |
+
out = out + torch.gather(q_left_integrals, 2, mx.unsqueeze(-1)).squeeze(-1)
|
106 |
+
|
107 |
+
# Regularization: points must be strictly within the unit hypercube
|
108 |
+
# Use the dtype information from pytorch
|
109 |
+
eps = torch.finfo(out.dtype).eps
|
110 |
+
out = out.clamp(min=eps, max=1.0 - eps)
|
111 |
+
oob_mask = torch.logical_or(x < 0.0, x > 1.0).detach().float()
|
112 |
+
if outlier_passthru:
|
113 |
+
out = out * (1 - oob_mask) + x * oob_mask
|
114 |
+
slopes = slopes * (1 - oob_mask) + oob_mask
|
115 |
+
|
116 |
+
if compute_jacobian:
|
117 |
+
# logj = torch.log(torch.prod(slopes.float(), 1))
|
118 |
+
logj = torch.sum(torch.log(slopes), 1)
|
119 |
+
del slopes
|
120 |
+
|
121 |
+
return out, logj
|
122 |
+
|
123 |
+
|
124 |
+
def piecewise_linear_inverse_transform(
|
125 |
+
y, q_tilde, compute_jacobian=True, outlier_passthru=True
|
126 |
+
):
|
127 |
+
"""
|
128 |
+
Apply inverse of an element-wise piecewise-linear transformation to some
|
129 |
+
variables
|
130 |
+
|
131 |
+
Parameters
|
132 |
+
----------
|
133 |
+
y : torch.Tensor
|
134 |
+
a tensor with shape (N,k) where N is the batch dimension while k is the
|
135 |
+
dimension of the variable space. This variable spans the k-dimensional unit
|
136 |
+
hypercube
|
137 |
+
|
138 |
+
q_tilde: torch.Tensor
|
139 |
+
is a tensor with shape (N,k,b) where b is the number of bins.
|
140 |
+
This contains the un-normalized heights of the bins of the piecewise-constant PDF for dimension k,
|
141 |
+
i.e. q_tilde lives in all of R and we don't impose a constraint on their sum yet.
|
142 |
+
Normalization is imposed in this function using softmax.
|
143 |
+
|
144 |
+
compute_jacobian : bool, optional
|
145 |
+
determines whether the jacobian should be computed; otherwise None is returned
|
146 |
+
|
147 |
+
Returns
|
148 |
+
-------
|
149 |
+
tuple of torch.Tensor
|
150 |
+
pair `(x, j)`.
|
151 |
+
- `x` is a tensor with shape (N,k) living in the k-dimensional unit hypercube
|
152 |
+
- `j` is the jacobian of the transformation with shape (N,) if compute_jacobian==True, else None.
|
153 |
+
"""
|
154 |
+
|
155 |
+
# TODO bottom-up assessment of handling the differentiability of variables
|
156 |
+
|
157 |
+
# Compute the bin width w
|
158 |
+
N, k, b = q_tilde.shape
|
159 |
+
Ny, ky = y.shape
|
160 |
+
assert N == Ny and k == ky, "Shape mismatch"
|
161 |
+
|
162 |
+
w = 1.0 / b
|
163 |
+
|
164 |
+
# Compute normalized bin heights with softmax function on the bin dimension
|
165 |
+
q = 1.0 / w * third_dimension_softmax(q_tilde)
|
166 |
+
|
167 |
+
# Compute the integral over the left-bins in the forward transform.
|
168 |
+
# 1. Compute all integrals: cumulative sum of bin height * bin weight.
|
169 |
+
# We want that index i contains the cumsum *strictly to the left*,
|
170 |
+
# so we shift by 1 leaving the first entry null,
|
171 |
+
# which is achieved with a roll and assignment
|
172 |
+
q_left_integrals = torch.roll(torch.cumsum(q.float(), 2) * w, 1, 2)
|
173 |
+
q_left_integrals[:, :, 0] = 0
|
174 |
+
|
175 |
+
# Find which bin each y belongs to by finding the smallest bin such that
|
176 |
+
# y - q_left_integral is positive
|
177 |
+
|
178 |
+
edges = (y.unsqueeze(-1) - q_left_integrals).detach()
|
179 |
+
# y and q_left_integrals are between 0 and 1,
|
180 |
+
# so that their difference is at most 1.
|
181 |
+
# By setting the negative values to 2., we know that the
|
182 |
+
# smallest value left is the smallest positive
|
183 |
+
edges[edges < 0] = 2.0
|
184 |
+
edges = torch.clamp(torch.argmin(edges, dim=2), 0, b - 1).to(torch.long)
|
185 |
+
|
186 |
+
# Need special error handling because trying to index with edges
|
187 |
+
# if it contains nans will lock the GPU. (device-side assert triggered)
|
188 |
+
if (
|
189 |
+
torch.any(torch.isnan(edges)).item()
|
190 |
+
or torch.any(edges < 0)
|
191 |
+
or torch.any(edges >= b)
|
192 |
+
):
|
193 |
+
raise Exception("NaN detected in PWLinear bin indexing")
|
194 |
+
|
195 |
+
# Gather the left integrals at each edge. See comment about gathering in q_left_integrals
|
196 |
+
# for the unsqueeze
|
197 |
+
q_left_integrals = q_left_integrals.gather(2, edges.unsqueeze(-1)).squeeze(-1)
|
198 |
+
|
199 |
+
# Gather the slope at each edge.
|
200 |
+
q = q.gather(2, edges.unsqueeze(-1)).squeeze(-1)
|
201 |
+
|
202 |
+
# Build the output
|
203 |
+
x = (y - q_left_integrals) / q + edges * w
|
204 |
+
|
205 |
+
# Regularization: points must be strictly within the unit hypercube
|
206 |
+
# Use the dtype information from pytorch
|
207 |
+
eps = torch.finfo(x.dtype).eps
|
208 |
+
x = x.clamp(min=eps, max=1.0 - eps)
|
209 |
+
oob_mask = torch.logical_or(y < 0.0, y > 1.0).detach().float()
|
210 |
+
if outlier_passthru:
|
211 |
+
x = x * (1 - oob_mask) + y * oob_mask
|
212 |
+
q = q * (1 - oob_mask) + oob_mask
|
213 |
+
|
214 |
+
# Prepare the jacobian
|
215 |
+
logj = None
|
216 |
+
if compute_jacobian:
|
217 |
+
# logj = - torch.log(torch.prod(q, 1))
|
218 |
+
logj = -torch.sum(torch.log(q.float()), 1)
|
219 |
+
return x.detach(), logj
|
220 |
+
|
221 |
+
|
222 |
+
def unbounded_piecewise_quadratic_transform(
|
223 |
+
x, w_tilde, v_tilde, upper=1, lower=0, inverse=False
|
224 |
+
):
|
225 |
+
assert upper > lower
|
226 |
+
_range = upper - lower
|
227 |
+
inside_interval_mask = (x >= lower) & (x < upper)
|
228 |
+
outside_interval_mask = ~inside_interval_mask
|
229 |
+
|
230 |
+
outputs = torch.zeros_like(x)
|
231 |
+
log_j = torch.zeros_like(x)
|
232 |
+
|
233 |
+
outputs[outside_interval_mask] = x[outside_interval_mask]
|
234 |
+
log_j[outside_interval_mask] = 0
|
235 |
+
|
236 |
+
output, _log_j = piecewise_quadratic_transform(
|
237 |
+
(x[inside_interval_mask] - lower) / _range,
|
238 |
+
w_tilde[inside_interval_mask, :],
|
239 |
+
v_tilde[inside_interval_mask, :],
|
240 |
+
inverse=inverse,
|
241 |
+
)
|
242 |
+
outputs[inside_interval_mask] = output * _range + lower
|
243 |
+
if not inverse:
|
244 |
+
# the before and after transformation cancel out, so the log_j would be just as it is.
|
245 |
+
log_j[inside_interval_mask] = _log_j
|
246 |
+
else:
|
247 |
+
log_j = None
|
248 |
+
return outputs, log_j
|
249 |
+
|
250 |
+
|
251 |
+
def weighted_softmax(v, w):
|
252 |
+
# to avoid NaN...
|
253 |
+
v = v - torch.max(v, dim=-1, keepdim=True)[0]
|
254 |
+
v = torch.exp(v) + 1e-8 # to avoid NaN...
|
255 |
+
v_sum = torch.sum((v[..., :-1] + v[..., 1:]) / 2 * w, dim=-1, keepdim=True)
|
256 |
+
return v / v_sum
|
257 |
+
|
258 |
+
|
259 |
+
def piecewise_quadratic_transform(x, w_tilde, v_tilde, inverse=False):
|
260 |
+
"""Element-wise piecewise-quadratic transformation
|
261 |
+
Parameters
|
262 |
+
----------
|
263 |
+
x : torch.Tensor
|
264 |
+
*, The variable spans the D-dim unit hypercube ([0,1))
|
265 |
+
w_tilde : torch.Tensor
|
266 |
+
* x K defined in the paper
|
267 |
+
v_tilde : torch.Tensor
|
268 |
+
* x (K+1) defined in the paper
|
269 |
+
inverse : bool
|
270 |
+
forward or inverse
|
271 |
+
Returns
|
272 |
+
-------
|
273 |
+
c : torch.Tensor
|
274 |
+
*, transformed value
|
275 |
+
log_j : torch.Tensor
|
276 |
+
*, log determinant of the Jacobian matrix
|
277 |
+
"""
|
278 |
+
w = torch.softmax(w_tilde, dim=-1)
|
279 |
+
v = weighted_softmax(v_tilde, w)
|
280 |
+
w_cumsum = torch.cumsum(w, dim=-1)
|
281 |
+
# force sum = 1
|
282 |
+
w_cumsum[..., -1] = 1.0
|
283 |
+
w_cumsum_shift = F.pad(w_cumsum, (1, 0), "constant", 0)
|
284 |
+
cdf = torch.cumsum((v[..., 1:] + v[..., :-1]) / 2 * w, dim=-1)
|
285 |
+
# force sum = 1
|
286 |
+
cdf[..., -1] = 1.0
|
287 |
+
cdf_shift = F.pad(cdf, (1, 0), "constant", 0)
|
288 |
+
|
289 |
+
if not inverse:
|
290 |
+
# * x D x 1, (w_cumsum[idx-1] < x <= w_cumsum[idx])
|
291 |
+
bin_index = torch.searchsorted(w_cumsum, x.unsqueeze(-1))
|
292 |
+
else:
|
293 |
+
# * x D x 1, (cdf[idx-1] < x <= cdf[idx])
|
294 |
+
bin_index = torch.searchsorted(cdf, x.unsqueeze(-1))
|
295 |
+
|
296 |
+
w_b = torch.gather(w, -1, bin_index).squeeze(-1)
|
297 |
+
w_bn1 = torch.gather(w_cumsum_shift, -1, bin_index).squeeze(-1)
|
298 |
+
v_b = torch.gather(v, -1, bin_index).squeeze(-1)
|
299 |
+
v_bp1 = torch.gather(v, -1, bin_index + 1).squeeze(-1)
|
300 |
+
cdf_bn1 = torch.gather(cdf_shift, -1, bin_index).squeeze(-1)
|
301 |
+
|
302 |
+
if not inverse:
|
303 |
+
alpha = (x - w_bn1) / w_b.clamp(min=torch.finfo(w_b.dtype).eps)
|
304 |
+
c = (alpha**2) / 2 * (v_bp1 - v_b) * w_b + alpha * v_b * w_b + cdf_bn1
|
305 |
+
|
306 |
+
# just sum of log pdfs
|
307 |
+
log_j = torch.lerp(v_b, v_bp1, alpha).clamp(min=torch.finfo(c.dtype).eps).log()
|
308 |
+
|
309 |
+
# make sure it falls into [0,1)
|
310 |
+
c = c.clamp(min=torch.finfo(c.dtype).eps, max=1.0 - torch.finfo(c.dtype).eps)
|
311 |
+
return c, log_j
|
312 |
+
else:
|
313 |
+
# quadratic equation for alpha
|
314 |
+
# alpha should fall into (0, 1]. Since a, b > 0, the symmetry axis -b/2a < 0 and we should pick the larger root
|
315 |
+
# skip calculating the log_j in inverse since we don't need it
|
316 |
+
a = (v_bp1 - v_b) * w_b / 2
|
317 |
+
b = v_b * w_b
|
318 |
+
c = cdf_bn1 - x
|
319 |
+
alpha = (-b + torch.sqrt((b**2) - 4 * a * c)) / (2 * a)
|
320 |
+
inv = alpha * w_b + w_bn1
|
321 |
+
|
322 |
+
# make sure it falls into [0,1)
|
323 |
+
inv = inv.clamp(
|
324 |
+
min=torch.finfo(c.dtype).eps, max=1.0 - torch.finfo(inv.dtype).eps
|
325 |
+
)
|
326 |
+
return inv, None
|
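
A quick round-trip check of the piecewise-linear pair above, with random bin logits: `q_tilde` is unnormalized, exactly as the docstrings describe, and the forward and inverse log-Jacobians should cancel up to the eps clamping at the cube edges. This is an illustrative sketch, not part of the repo.

```python
# Sanity sketch for the piecewise-linear coupling transform (illustrative only).
import torch
from splines import piecewise_linear_transform, piecewise_linear_inverse_transform

N, k, b = 4, 8, 16                    # batch size, variable dims, number of bins
x = torch.rand(N, k)                  # points inside the unit hypercube
q_tilde = torch.randn(N, k, b)        # unnormalized bin heights (softmaxed internally)

y, logj_fwd = piecewise_linear_transform(x, q_tilde)
x_rec, logj_inv = piecewise_linear_inverse_transform(y, q_tilde)

print(torch.allclose(x, x_rec, atol=1e-4))              # True, up to edge clamping
print(torch.allclose(logj_fwd, -logj_inv, atol=1e-4))   # forward/inverse Jacobians cancel
```
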
transformer.py
ADDED
@@ -0,0 +1,219 @@
1 |
+
# adapted from https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/FastPitch/fastpitch/transformer.py
|
2 |
+
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
|
3 |
+
# you may not use this file except in compliance with the License.
|
4 |
+
# You may obtain a copy of the License at
|
5 |
+
#
|
6 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
7 |
+
#
|
8 |
+
# Unless required by applicable law or agreed to in writing, software
|
9 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
10 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
11 |
+
# See the License for the specific language governing permissions and
|
12 |
+
# limitations under the License.
|
13 |
+
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
import torch.nn.functional as F
|
17 |
+
|
18 |
+
from common import get_mask_from_lengths, LinearNorm
|
19 |
+
|
20 |
+
|
21 |
+
class PositionalEmbedding(nn.Module):
|
22 |
+
def __init__(self, demb):
|
23 |
+
super(PositionalEmbedding, self).__init__()
|
24 |
+
self.demb = demb
|
25 |
+
inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
|
26 |
+
self.register_buffer("inv_freq", inv_freq)
|
27 |
+
|
28 |
+
def forward(self, pos_seq, bsz=None):
|
29 |
+
sinusoid_inp = torch.matmul(
|
30 |
+
torch.unsqueeze(pos_seq, -1), torch.unsqueeze(self.inv_freq, 0)
|
31 |
+
)
|
32 |
+
pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=1)
|
33 |
+
if bsz is not None:
|
34 |
+
return pos_emb[None, :, :].expand(bsz, -1, -1)
|
35 |
+
else:
|
36 |
+
return pos_emb[None, :, :]
|
37 |
+
|
38 |
+
|
39 |
+
class PositionwiseConvFF(nn.Module):
|
40 |
+
def __init__(self, d_model, d_inner, kernel_size, dropout, pre_lnorm=False):
|
41 |
+
super(PositionwiseConvFF, self).__init__()
|
42 |
+
|
43 |
+
self.d_model = d_model
|
44 |
+
self.d_inner = d_inner
|
45 |
+
self.dropout = dropout
|
46 |
+
|
47 |
+
self.CoreNet = nn.Sequential(
|
48 |
+
nn.Conv1d(d_model, d_inner, kernel_size, 1, (kernel_size // 2)),
|
49 |
+
nn.ReLU(),
|
50 |
+
# nn.Dropout(dropout), # worse convergence
|
51 |
+
nn.Conv1d(d_inner, d_model, kernel_size, 1, (kernel_size // 2)),
|
52 |
+
nn.Dropout(dropout),
|
53 |
+
)
|
54 |
+
self.layer_norm = nn.LayerNorm(d_model)
|
55 |
+
self.pre_lnorm = pre_lnorm
|
56 |
+
|
57 |
+
def forward(self, inp):
|
58 |
+
return self._forward(inp)
|
59 |
+
|
60 |
+
def _forward(self, inp):
|
61 |
+
if self.pre_lnorm:
|
62 |
+
# layer normalization + positionwise feed-forward
|
63 |
+
core_out = inp.transpose(1, 2)
|
64 |
+
core_out = self.CoreNet(self.layer_norm(core_out).to(inp.dtype))
|
65 |
+
core_out = core_out.transpose(1, 2)
|
66 |
+
|
67 |
+
# residual connection
|
68 |
+
output = core_out + inp
|
69 |
+
else:
|
70 |
+
# positionwise feed-forward
|
71 |
+
core_out = inp.transpose(1, 2)
|
72 |
+
core_out = self.CoreNet(core_out)
|
73 |
+
core_out = core_out.transpose(1, 2)
|
74 |
+
|
75 |
+
# residual connection + layer normalization
|
76 |
+
output = self.layer_norm(inp + core_out).to(inp.dtype)
|
77 |
+
|
78 |
+
return output
|
79 |
+
|
80 |
+
|
81 |
+
class MultiHeadAttn(nn.Module):
|
82 |
+
def __init__(self, n_head, d_model, d_head, dropout, dropatt=0.1, pre_lnorm=False):
|
83 |
+
super(MultiHeadAttn, self).__init__()
|
84 |
+
|
85 |
+
self.n_head = n_head
|
86 |
+
self.d_model = d_model
|
87 |
+
self.d_head = d_head
|
88 |
+
self.scale = 1 / (d_head**0.5)
|
89 |
+
self.pre_lnorm = pre_lnorm
|
90 |
+
|
91 |
+
self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head)
|
92 |
+
self.drop = nn.Dropout(dropout)
|
93 |
+
self.dropatt = nn.Dropout(dropatt)
|
94 |
+
self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
|
95 |
+
self.layer_norm = nn.LayerNorm(d_model)
|
96 |
+
|
97 |
+
def forward(self, inp, attn_mask=None):
|
98 |
+
return self._forward(inp, attn_mask)
|
99 |
+
|
100 |
+
def _forward(self, inp, attn_mask=None):
|
101 |
+
residual = inp
|
102 |
+
|
103 |
+
if self.pre_lnorm:
|
104 |
+
# layer normalization
|
105 |
+
inp = self.layer_norm(inp)
|
106 |
+
|
107 |
+
n_head, d_head = self.n_head, self.d_head
|
108 |
+
|
109 |
+
head_q, head_k, head_v = torch.chunk(self.qkv_net(inp), 3, dim=2)
|
110 |
+
head_q = head_q.view(inp.size(0), inp.size(1), n_head, d_head)
|
111 |
+
head_k = head_k.view(inp.size(0), inp.size(1), n_head, d_head)
|
112 |
+
head_v = head_v.view(inp.size(0), inp.size(1), n_head, d_head)
|
113 |
+
|
114 |
+
q = head_q.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head)
|
115 |
+
k = head_k.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head)
|
116 |
+
v = head_v.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head)
|
117 |
+
|
118 |
+
attn_score = torch.bmm(q, k.transpose(1, 2))
|
119 |
+
attn_score.mul_(self.scale)
|
120 |
+
|
121 |
+
if attn_mask is not None:
|
122 |
+
attn_mask = attn_mask.unsqueeze(1).to(attn_score.dtype)
|
123 |
+
attn_mask = attn_mask.repeat(n_head, attn_mask.size(2), 1)
|
124 |
+
attn_score.masked_fill_(attn_mask.to(torch.bool), -float("inf"))
|
125 |
+
|
126 |
+
attn_prob = F.softmax(attn_score, dim=2)
|
127 |
+
attn_prob = self.dropatt(attn_prob)
|
128 |
+
attn_vec = torch.bmm(attn_prob, v)
|
129 |
+
|
130 |
+
attn_vec = attn_vec.view(n_head, inp.size(0), inp.size(1), d_head)
|
131 |
+
attn_vec = (
|
132 |
+
attn_vec.permute(1, 2, 0, 3)
|
133 |
+
.contiguous()
|
134 |
+
.view(inp.size(0), inp.size(1), n_head * d_head)
|
135 |
+
)
|
136 |
+
|
137 |
+
# linear projection
|
138 |
+
attn_out = self.o_net(attn_vec)
|
139 |
+
attn_out = self.drop(attn_out)
|
140 |
+
|
141 |
+
# residual connection + layer normalization
|
142 |
+
output = self.layer_norm(residual + attn_out)
|
143 |
+
|
144 |
+
output = output.to(attn_out.dtype)
|
145 |
+
|
146 |
+
return output
|
147 |
+
|
148 |
+
|
149 |
+
class TransformerLayer(nn.Module):
|
150 |
+
def __init__(
|
151 |
+
self, n_head, d_model, d_head, d_inner, kernel_size, dropout, **kwargs
|
152 |
+
):
|
153 |
+
super(TransformerLayer, self).__init__()
|
154 |
+
|
155 |
+
self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
|
156 |
+
self.pos_ff = PositionwiseConvFF(d_model, d_inner, kernel_size, dropout)
|
157 |
+
|
158 |
+
def forward(self, dec_inp, mask=None):
|
159 |
+
output = self.dec_attn(dec_inp, attn_mask=~mask.squeeze(2))
|
160 |
+
output *= mask
|
161 |
+
output = self.pos_ff(output)
|
162 |
+
output *= mask
|
163 |
+
return output
|
164 |
+
|
165 |
+
|
166 |
+
class FFTransformer(nn.Module):
|
167 |
+
def __init__(
|
168 |
+
self,
|
169 |
+
in_dim,
|
170 |
+
out_dim=1,
|
171 |
+
n_layers=6,
|
172 |
+
n_head=1,
|
173 |
+
d_head=64,
|
174 |
+
d_inner=1024,
|
175 |
+
kernel_size=3,
|
176 |
+
dropout=0.1,
|
177 |
+
dropatt=0.1,
|
178 |
+
dropemb=0.0,
|
179 |
+
):
|
180 |
+
super(FFTransformer, self).__init__()
|
181 |
+
self.in_dim = in_dim
|
182 |
+
self.out_dim = out_dim
|
183 |
+
self.n_head = n_head
|
184 |
+
self.d_head = d_head
|
185 |
+
|
186 |
+
self.pos_emb = PositionalEmbedding(self.in_dim)
|
187 |
+
self.drop = nn.Dropout(dropemb)
|
188 |
+
self.layers = nn.ModuleList()
|
189 |
+
|
190 |
+
for _ in range(n_layers):
|
191 |
+
self.layers.append(
|
192 |
+
TransformerLayer(
|
193 |
+
n_head,
|
194 |
+
in_dim,
|
195 |
+
d_head,
|
196 |
+
d_inner,
|
197 |
+
kernel_size,
|
198 |
+
dropout,
|
199 |
+
dropatt=dropatt,
|
200 |
+
)
|
201 |
+
)
|
202 |
+
|
203 |
+
self.dense = LinearNorm(in_dim, out_dim)
|
204 |
+
|
205 |
+
def forward(self, dec_inp, in_lens):
|
206 |
+
# B, C, T --> B, T, C
|
207 |
+
inp = dec_inp.transpose(1, 2)
|
208 |
+
mask = get_mask_from_lengths(in_lens)[..., None]
|
209 |
+
|
210 |
+
pos_seq = torch.arange(inp.size(1), device=inp.device).to(inp.dtype)
|
211 |
+
pos_emb = self.pos_emb(pos_seq) * mask
|
212 |
+
|
213 |
+
out = self.drop(inp + pos_emb)
|
214 |
+
|
215 |
+
for layer in self.layers:
|
216 |
+
out = layer(out, mask=mask)
|
217 |
+
|
218 |
+
out = self.dense(out).transpose(1, 2)
|
219 |
+
return out
|
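
The FFTransformer above is a FastPitch-style feed-forward Transformer stack that consumes (B, C, T) feature maps together with per-item lengths and returns (B, out_dim, T) frames. A shape-level sketch with made-up dimensions (these are not the config values used by this Space):

```python
# Shape check for FFTransformer (dimensions below are illustrative assumptions).
import torch
from transformer import FFTransformer

model = FFTransformer(in_dim=256, out_dim=80, n_layers=2, n_head=2,
                      d_head=64, d_inner=512, kernel_size=3)

x = torch.randn(3, 256, 50)               # (B, C, T), e.g. time-expanded text encodings
lens = torch.LongTensor([50, 42, 30])     # valid length of each batch item

out = model(x, lens)
print(out.shape)                          # torch.Size([3, 80, 50])
```
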
tts_text_processing/LICENSE
ADDED
@@ -0,0 +1,19 @@
1 |
+
Copyright (c) 2017 Keith Ito
|
2 |
+
|
3 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
4 |
+
of this software and associated documentation files (the "Software"), to deal
|
5 |
+
in the Software without restriction, including without limitation the rights
|
6 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
7 |
+
copies of the Software, and to permit persons to whom the Software is
|
8 |
+
furnished to do so, subject to the following conditions:
|
9 |
+
|
10 |
+
The above copyright notice and this permission notice shall be included in
|
11 |
+
all copies or substantial portions of the Software.
|
12 |
+
|
13 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
15 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
16 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
17 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
18 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
19 |
+
THE SOFTWARE.
|
tts_text_processing/abbreviations.py
ADDED
@@ -0,0 +1,57 @@
1 |
+
import re
|
2 |
+
|
3 |
+
_no_period_re = re.compile(r"(No[.])(?=[ ]?[0-9])")
|
4 |
+
_percent_re = re.compile(r"([ ]?[%])")
|
5 |
+
_half_re = re.compile("([0-9]½)|(½)")
|
6 |
+
|
7 |
+
|
8 |
+
# List of (regular expression, replacement) pairs for abbreviations:
|
9 |
+
_abbreviations = [
|
10 |
+
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
11 |
+
for x in [
|
12 |
+
("mrs", "misess"),
|
13 |
+
("ms", "miss"),
|
14 |
+
("mr", "mister"),
|
15 |
+
("dr", "doctor"),
|
16 |
+
("st", "saint"),
|
17 |
+
("co", "company"),
|
18 |
+
("jr", "junior"),
|
19 |
+
("maj", "major"),
|
20 |
+
("gen", "general"),
|
21 |
+
("drs", "doctors"),
|
22 |
+
("rev", "reverend"),
|
23 |
+
("lt", "lieutenant"),
|
24 |
+
("hon", "honorable"),
|
25 |
+
("sgt", "sergeant"),
|
26 |
+
("capt", "captain"),
|
27 |
+
("esq", "esquire"),
|
28 |
+
("ltd", "limited"),
|
29 |
+
("col", "colonel"),
|
30 |
+
("ft", "fort"),
|
31 |
+
]
|
32 |
+
]
|
33 |
+
|
34 |
+
|
35 |
+
def _expand_no_period(m):
|
36 |
+
word = m.group(0)
|
37 |
+
if word[0] == "N":
|
38 |
+
return "Number"
|
39 |
+
return "number"
|
40 |
+
|
41 |
+
|
42 |
+
def _expand_percent(m):
|
43 |
+
return " percent"
|
44 |
+
|
45 |
+
|
46 |
+
def _expand_half(m):
|
47 |
+
word = m.group(1)
|
48 |
+
if word is None:
|
49 |
+
return "half"
|
50 |
+
return word[0] + " and a half"
|
51 |
+
|
52 |
+
|
53 |
+
def normalize_abbreviations(text):
|
54 |
+
text = re.sub(_no_period_re, _expand_no_period, text)
|
55 |
+
text = re.sub(_percent_re, _expand_percent, text)
|
56 |
+
text = re.sub(_half_re, _expand_half, text)
|
57 |
+
return text
|
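
A quick example of the rules `normalize_abbreviations` applies directly (the No., percent and ½ patterns above); the module is imported here as a package module, which assumes the repo root is on the import path.

```python
# The three inline substitutions handled by normalize_abbreviations.
from tts_text_processing.abbreviations import normalize_abbreviations

print(normalize_abbreviations("Order No. 3 rose 5% in 2½ weeks."))
# -> "Order Number 3 rose 5 percent in 2 and a half weeks."
```
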
tts_text_processing/acronyms.py
ADDED
@@ -0,0 +1,69 @@
1 |
+
import re
|
2 |
+
|
3 |
+
_letter_to_arpabet = {
|
4 |
+
"A": "EY1",
|
5 |
+
"B": "B IY1",
|
6 |
+
"C": "S IY1",
|
7 |
+
"D": "D IY1",
|
8 |
+
"E": "IY1",
|
9 |
+
"F": "EH1 F",
|
10 |
+
"G": "JH IY1",
|
11 |
+
"H": "EY1 CH",
|
12 |
+
"I": "AY1",
|
13 |
+
"J": "JH EY1",
|
14 |
+
"K": "K EY1",
|
15 |
+
"L": "EH1 L",
|
16 |
+
"M": "EH1 M",
|
17 |
+
"N": "EH1 N",
|
18 |
+
"O": "OW1",
|
19 |
+
"P": "P IY1",
|
20 |
+
"Q": "K Y UW1",
|
21 |
+
"R": "AA1 R",
|
22 |
+
"S": "EH1 S",
|
23 |
+
"T": "T IY1",
|
24 |
+
"U": "Y UW1",
|
25 |
+
"V": "V IY1",
|
26 |
+
"X": "EH1 K S",
|
27 |
+
"Y": "W AY1",
|
28 |
+
"W": "D AH1 B AH0 L Y UW0",
|
29 |
+
"Z": "Z IY1",
|
30 |
+
"s": "Z",
|
31 |
+
}
|
32 |
+
|
33 |
+
# must ignore roman numerals
|
34 |
+
# _acronym_re = re.compile(r'([A-Z][A-Z]+)s?|([A-Z]\.([A-Z]\.)+s?)')
|
35 |
+
_acronym_re = re.compile(r"([A-Z][A-Z]+)s?")
|
36 |
+
|
37 |
+
|
38 |
+
class AcronymNormalizer(object):
|
39 |
+
def __init__(self, phoneme_dict):
|
40 |
+
self.phoneme_dict = phoneme_dict
|
41 |
+
|
42 |
+
def normalize_acronyms(self, text):
|
43 |
+
def _expand_acronyms(m, add_spaces=True):
|
44 |
+
acronym = m.group(0)
|
45 |
+
# remove dots if they exist
|
46 |
+
acronym = re.sub(r"\.", "", acronym)
|
47 |
+
|
48 |
+
acronym = "".join(acronym.split())
|
49 |
+
arpabet = self.phoneme_dict.lookup(acronym)
|
50 |
+
|
51 |
+
if arpabet is None:
|
52 |
+
acronym = list(acronym)
|
53 |
+
arpabet = ["{" + _letter_to_arpabet[letter] + "}" for letter in acronym]
|
54 |
+
# temporary fix
|
55 |
+
if arpabet[-1] == "{Z}" and len(arpabet) > 1:
|
56 |
+
arpabet[-2] = arpabet[-2][:-1] + " " + arpabet[-1][1:]
|
57 |
+
del arpabet[-1]
|
58 |
+
arpabet = " ".join(arpabet)
|
59 |
+
elif len(arpabet) == 1:
|
60 |
+
arpabet = "{" + arpabet[0] + "}"
|
61 |
+
else:
|
62 |
+
arpabet = acronym
|
63 |
+
return arpabet
|
64 |
+
|
65 |
+
text = re.sub(_acronym_re, _expand_acronyms, text)
|
66 |
+
return text
|
67 |
+
|
68 |
+
def __call__(self, text):
|
69 |
+
return self.normalize_acronyms(text)
|
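
`AcronymNormalizer` falls back to spelling an acronym letter by letter in ARPAbet when the phoneme dictionary has no entry; a stub dictionary makes that fallback path easy to see. The stub class below is purely illustrative.

```python
# Letter-by-letter ARPAbet fallback of AcronymNormalizer (stub dict forces the fallback).
from tts_text_processing.acronyms import AcronymNormalizer

class StubDict:
    def lookup(self, word):
        return None            # pretend the acronym is not in the lexicon

normalizer = AcronymNormalizer(StubDict())
print(normalizer("send it to NASA"))
# -> "send it to {EH1 N} {EY1} {EH1 S} {EY1}"
```
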
tts_text_processing/cleaners.py
ADDED
@@ -0,0 +1,123 @@
1 |
+
"""adapted from https://github.com/keithito/tacotron"""
|
2 |
+
|
3 |
+
"""
|
4 |
+
Cleaners are transformations that run over the input text at both training and eval time.
|
5 |
+
|
6 |
+
Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
|
7 |
+
hyperparameter. Some cleaners are English-specific. You'll typically want to use:
|
8 |
+
1. "english_cleaners" for English text
|
9 |
+
2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
|
10 |
+
the Unidecode library (https://pypi.python.org/pypi/Unidecode)
|
11 |
+
3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
|
12 |
+
the symbols in symbols.py to match your data).
|
13 |
+
"""
|
14 |
+
|
15 |
+
import re
|
16 |
+
from string import punctuation
|
17 |
+
from functools import reduce
|
18 |
+
from unidecode import unidecode
|
19 |
+
from .numerical import normalize_numbers, normalize_currency
|
20 |
+
from .acronyms import AcronymNormalizer
|
21 |
+
from .datestime import normalize_datestime
|
22 |
+
from .letters_and_numbers import normalize_letters_and_numbers
|
23 |
+
from .abbreviations import normalize_abbreviations
|
24 |
+
|
25 |
+
|
26 |
+
# Regular expression matching whitespace:
|
27 |
+
_whitespace_re = re.compile(r"\s+")
|
28 |
+
|
29 |
+
# Regular expression separating words enclosed in curly braces for cleaning
|
30 |
+
_arpa_re = re.compile(r"{[^}]+}|\S+")
|
31 |
+
|
32 |
+
|
33 |
+
def expand_abbreviations(text):
|
34 |
+
return normalize_abbreviations(text)
|
35 |
+
|
36 |
+
|
37 |
+
def expand_numbers(text):
|
38 |
+
return normalize_numbers(text)
|
39 |
+
|
40 |
+
|
41 |
+
def expand_currency(text):
|
42 |
+
return normalize_currency(text)
|
43 |
+
|
44 |
+
|
45 |
+
def expand_datestime(text):
|
46 |
+
return normalize_datestime(text)
|
47 |
+
|
48 |
+
|
49 |
+
def expand_letters_and_numbers(text):
|
50 |
+
return normalize_letters_and_numbers(text)
|
51 |
+
|
52 |
+
|
53 |
+
def lowercase(text):
|
54 |
+
return text.lower()
|
55 |
+
|
56 |
+
|
57 |
+
def collapse_whitespace(text):
|
58 |
+
return re.sub(_whitespace_re, " ", text)
|
59 |
+
|
60 |
+
|
61 |
+
def separate_acronyms(text):
|
62 |
+
text = re.sub(r"([0-9]+)([a-zA-Z]+)", r"\1 \2", text)
|
63 |
+
text = re.sub(r"([a-zA-Z]+)([0-9]+)", r"\1 \2", text)
|
64 |
+
return text
|
65 |
+
|
66 |
+
|
67 |
+
def convert_to_ascii(text):
|
68 |
+
return unidecode(text)
|
69 |
+
|
70 |
+
|
71 |
+
def dehyphenize_compound_words(text):
|
72 |
+
text = re.sub(r"(?<=[a-zA-Z0-9])-(?=[a-zA-Z])", " ", text)
|
73 |
+
return text
|
74 |
+
|
75 |
+
|
76 |
+
def remove_space_before_punctuation(text):
|
77 |
+
return re.sub(r"\s([{}](?:\s|$))".format(punctuation), r"\1", text)
|
78 |
+
|
79 |
+
|
80 |
+
class Cleaner(object):
|
81 |
+
def __init__(self, cleaner_names, phonemedict):
|
82 |
+
self.cleaner_names = cleaner_names
|
83 |
+
self.phonemedict = phonemedict
|
84 |
+
self.acronym_normalizer = AcronymNormalizer(self.phonemedict)
|
85 |
+
|
86 |
+
def __call__(self, text):
|
87 |
+
for cleaner_name in self.cleaner_names:
|
88 |
+
sequence_fns, word_fns = self.get_cleaner_fns(cleaner_name)
|
89 |
+
for fn in sequence_fns:
|
90 |
+
text = fn(text)
|
91 |
+
|
92 |
+
text = [
|
93 |
+
reduce(lambda x, y: y(x), word_fns, split) if split[0] != "{" else split
|
94 |
+
for split in _arpa_re.findall(text)
|
95 |
+
]
|
96 |
+
text = " ".join(text)
|
97 |
+
text = remove_space_before_punctuation(text)
|
98 |
+
return text
|
99 |
+
|
100 |
+
def get_cleaner_fns(self, cleaner_name):
|
101 |
+
if cleaner_name == "basic_cleaners":
|
102 |
+
sequence_fns = [lowercase, collapse_whitespace]
|
103 |
+
word_fns = []
|
104 |
+
elif cleaner_name == "english_cleaners":
|
105 |
+
sequence_fns = [collapse_whitespace, convert_to_ascii, lowercase]
|
106 |
+
word_fns = [expand_numbers, expand_abbreviations]
|
107 |
+
elif cleaner_name == "radtts_cleaners":
|
108 |
+
sequence_fns = [
|
109 |
+
collapse_whitespace,
|
110 |
+
expand_currency,
|
111 |
+
expand_datestime,
|
112 |
+
expand_letters_and_numbers,
|
113 |
+
]
|
114 |
+
word_fns = [expand_numbers, expand_abbreviations]
|
115 |
+
elif cleaner_name == "ukrainian_cleaners":
|
116 |
+
sequence_fns = [lowercase, collapse_whitespace]
|
117 |
+
word_fns = []
|
118 |
+
elif cleaner_name == "transliteration_cleaners":
|
119 |
+
sequence_fns = [convert_to_ascii, lowercase, collapse_whitespace]
|
120 |
+
else:
|
121 |
+
raise Exception("{} cleaner not supported".format(cleaner_name))
|
122 |
+
|
123 |
+
return sequence_fns, word_fns
|
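
The `Cleaner` class runs the selected sequence-level functions first and then the word-level ones on every non-ARPAbet token. The `ukrainian_cleaners` preset only lowercases and collapses whitespace, so it can be exercised without a phoneme dictionary:

```python
# Running the ukrainian_cleaners preset (no phoneme dictionary needed for this preset).
from tts_text_processing.cleaners import Cleaner

cleaner = Cleaner(["ukrainian_cleaners"], phonemedict=None)
print(cleaner("Привіт,   Світе!"))
# -> "привіт, світе!"
```
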
tts_text_processing/cmudict.py
ADDED
@@ -0,0 +1,140 @@
1 |
+
"""adapted from https://github.com/keithito/tacotron"""
|
2 |
+
|
3 |
+
import re
|
4 |
+
|
5 |
+
|
6 |
+
valid_symbols = [
|
7 |
+
"AA",
|
8 |
+
"AA0",
|
9 |
+
"AA1",
|
10 |
+
"AA2",
|
11 |
+
"AE",
|
12 |
+
"AE0",
|
13 |
+
"AE1",
|
14 |
+
"AE2",
|
15 |
+
"AH",
|
16 |
+
"AH0",
|
17 |
+
"AH1",
|
18 |
+
"AH2",
|
19 |
+
"AO",
|
20 |
+
"AO0",
|
21 |
+
"AO1",
|
22 |
+
"AO2",
|
23 |
+
"AW",
|
24 |
+
"AW0",
|
25 |
+
"AW1",
|
26 |
+
"AW2",
|
27 |
+
"AY",
|
28 |
+
"AY0",
|
29 |
+
"AY1",
|
30 |
+
"AY2",
|
31 |
+
"B",
|
32 |
+
"CH",
|
33 |
+
"D",
|
34 |
+
"DH",
|
35 |
+
"EH",
|
36 |
+
"EH0",
|
37 |
+
"EH1",
|
38 |
+
"EH2",
|
39 |
+
"ER",
|
40 |
+
"ER0",
|
41 |
+
"ER1",
|
42 |
+
"ER2",
|
43 |
+
"EY",
|
44 |
+
"EY0",
|
45 |
+
"EY1",
|
46 |
+
"EY2",
|
47 |
+
"F",
|
48 |
+
"G",
|
49 |
+
"HH",
|
50 |
+
"IH",
|
51 |
+
"IH0",
|
52 |
+
"IH1",
|
53 |
+
"IH2",
|
54 |
+
"IY",
|
55 |
+
"IY0",
|
56 |
+
"IY1",
|
57 |
+
"IY2",
|
58 |
+
"JH",
|
59 |
+
"K",
|
60 |
+
"L",
|
61 |
+
"M",
|
62 |
+
"N",
|
63 |
+
"NG",
|
64 |
+
"OW",
|
65 |
+
"OW0",
|
66 |
+
"OW1",
|
67 |
+
"OW2",
|
68 |
+
"OY",
|
69 |
+
"OY0",
|
70 |
+
"OY1",
|
71 |
+
"OY2",
|
72 |
+
"P",
|
73 |
+
"R",
|
74 |
+
"S",
|
75 |
+
"SH",
|
76 |
+
"T",
|
77 |
+
"TH",
|
78 |
+
"UH",
|
79 |
+
"UH0",
|
80 |
+
"UH1",
|
81 |
+
"UH2",
|
82 |
+
"UW",
|
83 |
+
"UW0",
|
84 |
+
"UW1",
|
85 |
+
"UW2",
|
86 |
+
"V",
|
87 |
+
"W",
|
88 |
+
"Y",
|
89 |
+
"Z",
|
90 |
+
"ZH",
|
91 |
+
]
|
92 |
+
|
93 |
+
_valid_symbol_set = set(valid_symbols)
|
94 |
+
|
95 |
+
|
96 |
+
class CMUDict:
|
97 |
+
"""Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict"""
|
98 |
+
|
99 |
+
def __init__(self, file_or_path, keep_ambiguous=True):
|
100 |
+
if isinstance(file_or_path, str):
|
101 |
+
with open(file_or_path, encoding="latin-1") as f:
|
102 |
+
entries = _parse_cmudict(f)
|
103 |
+
else:
|
104 |
+
entries = _parse_cmudict(file_or_path)
|
105 |
+
if not keep_ambiguous:
|
106 |
+
entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
|
107 |
+
self._entries = entries
|
108 |
+
|
109 |
+
def __len__(self):
|
110 |
+
return len(self._entries)
|
111 |
+
|
112 |
+
def lookup(self, word):
|
113 |
+
"""Returns list of ARPAbet pronunciations of the given word."""
|
114 |
+
return self._entries.get(word.upper())
|
115 |
+
|
116 |
+
|
117 |
+
_alt_re = re.compile(r"\([0-9]+\)")
|
118 |
+
|
119 |
+
|
120 |
+
def _parse_cmudict(file):
|
121 |
+
cmudict = {}
|
122 |
+
for line in file:
|
123 |
+
if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"):
|
124 |
+
parts = line.split(" ")
|
125 |
+
word = re.sub(_alt_re, "", parts[0])
|
126 |
+
pronunciation = _get_pronunciation(parts[1])
|
127 |
+
if pronunciation:
|
128 |
+
if word in cmudict:
|
129 |
+
cmudict[word].append(pronunciation)
|
130 |
+
else:
|
131 |
+
cmudict[word] = [pronunciation]
|
132 |
+
return cmudict
|
133 |
+
|
134 |
+
|
135 |
+
def _get_pronunciation(s):
|
136 |
+
parts = s.strip().split(" ")
|
137 |
+
for part in parts:
|
138 |
+
if part not in _valid_symbol_set:
|
139 |
+
return None
|
140 |
+
return " ".join(parts)
|
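
`CMUDict` accepts either a path or an open file object, so a two-line in-memory lexicon is enough to exercise `lookup`. The sketch assumes the standard cmudict layout of the word and its ARPAbet pronunciation separated by two spaces, which is what the parser above expects.

```python
# CMUDict lookup with an in-memory snippet in the standard "WORD  PRON" format.
import io
from tts_text_processing.cmudict import CMUDict

snippet = io.StringIO("HELLO  HH AH0 L OW1\nWORLD  W ER1 L D\n")
cmudict = CMUDict(snippet)

print(len(cmudict))             # 2
print(cmudict.lookup("hello"))  # expected: ['HH AH0 L OW1']
```
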
tts_text_processing/datestime.py
ADDED
@@ -0,0 +1,24 @@
1 |
+
"""adapted from https://github.com/keithito/tacotron"""
|
2 |
+
|
3 |
+
import re
|
4 |
+
|
5 |
+
_ampm_re = re.compile(r"([0-9]|0[0-9]|1[0-9]|2[0-3]):?([0-5][0-9])?\s*([AaPp][Mm]\b)")
|
6 |
+
|
7 |
+
|
8 |
+
def _expand_ampm(m):
|
9 |
+
matches = list(m.groups(0))
|
10 |
+
txt = matches[0]
|
11 |
+
txt = txt if int(matches[1]) == 0 else txt + " " + matches[1]
|
12 |
+
|
13 |
+
if matches[2][0].lower() == "a":
|
14 |
+
txt += " a.m."
|
15 |
+
elif matches[2][0].lower() == "p":
|
16 |
+
txt += " p.m."
|
17 |
+
|
18 |
+
return txt
|
19 |
+
|
20 |
+
|
21 |
+
def normalize_datestime(text):
|
22 |
+
text = re.sub(_ampm_re, _expand_ampm, text)
|
23 |
+
# text = re.sub(r"([0-9]|0[0-9]|1[0-9]|2[0-3]):([0-5][0-9])?", r"\1 \2", text)
|
24 |
+
return text
|
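
`normalize_datestime` rewrites clock times followed by an AM/PM marker into a spoken-friendly form:

```python
# a.m./p.m. expansion performed by normalize_datestime.
from tts_text_processing.datestime import normalize_datestime

print(normalize_datestime("The call is at 10:30PM, not 9 AM tomorrow."))
# -> "The call is at 10 30 p.m., not 9 a.m. tomorrow."
```
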
tts_text_processing/grapheme_dictionary.py
ADDED
@@ -0,0 +1,37 @@
1 |
+
"""adapted from https://github.com/keithito/tacotron"""
|
2 |
+
|
3 |
+
import re
|
4 |
+
|
5 |
+
_alt_re = re.compile(r"\([0-9]+\)")
|
6 |
+
|
7 |
+
|
8 |
+
class Grapheme2PhonemeDictionary:
|
9 |
+
"""Thin wrapper around g2p data."""
|
10 |
+
|
11 |
+
def __init__(self, file_or_path, keep_ambiguous=True, encoding="latin-1"):
|
12 |
+
with open(file_or_path, encoding=encoding) as f:
|
13 |
+
entries = _parse_g2p(f)
|
14 |
+
if not keep_ambiguous:
|
15 |
+
entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
|
16 |
+
self._entries = entries
|
17 |
+
|
18 |
+
def __len__(self):
|
19 |
+
return len(self._entries)
|
20 |
+
|
21 |
+
def lookup(self, word):
|
22 |
+
"""Returns list of pronunciations of the given word."""
|
23 |
+
return self._entries.get(word.upper())
|
24 |
+
|
25 |
+
|
26 |
+
def _parse_g2p(file):
|
27 |
+
g2p = {}
|
28 |
+
for line in file:
|
29 |
+
if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"):
|
30 |
+
parts = line.split(" ")
|
31 |
+
word = re.sub(_alt_re, "", parts[0])
|
32 |
+
pronunciation = parts[1].strip()
|
33 |
+
if word in g2p:
|
34 |
+
g2p[word].append(pronunciation)
|
35 |
+
else:
|
36 |
+
g2p[word] = [pronunciation]
|
37 |
+
return g2p
|
tts_text_processing/heteronyms
ADDED
@@ -0,0 +1,413 @@
1 |
+
abject
|
2 |
+
abrogate
|
3 |
+
absent
|
4 |
+
abstract
|
5 |
+
abuse
|
6 |
+
ache
|
7 |
+
acre
|
8 |
+
acuminate
|
9 |
+
addict
|
10 |
+
address
|
11 |
+
adduct
|
12 |
+
adele
|
13 |
+
advocate
|
14 |
+
affect
|
15 |
+
affiliate
|
16 |
+
agape
|
17 |
+
aged
|
18 |
+
agglomerate
|
19 |
+
aggregate
|
20 |
+
agonic
|
21 |
+
agora
|
22 |
+
allied
|
23 |
+
ally
|
24 |
+
alternate
|
25 |
+
alum
|
26 |
+
am
|
27 |
+
analyses
|
28 |
+
andrea
|
29 |
+
animate
|
30 |
+
apply
|
31 |
+
appropriate
|
32 |
+
approximate
|
33 |
+
ares
|
34 |
+
arithmetic
|
35 |
+
arsenic
|
36 |
+
articulate
|
37 |
+
associate
|
38 |
+
attribute
|
39 |
+
august
|
40 |
+
axes
|
41 |
+
ay
|
42 |
+
aye
|
43 |
+
bases
|
44 |
+
bass
|
45 |
+
bathed
|
46 |
+
bested
|
47 |
+
bifurcate
|
48 |
+
blessed
|
49 |
+
blotto
|
50 |
+
bow
|
51 |
+
bowed
|
52 |
+
bowman
|
53 |
+
brassy
|
54 |
+
buffet
|
55 |
+
bustier
|
56 |
+
carbonate
|
57 |
+
celtic
|
58 |
+
choral
|
59 |
+
chumash
|
60 |
+
close
|
61 |
+
closer
|
62 |
+
coax
|
63 |
+
coincidence
|
64 |
+
color coordinate
|
65 |
+
colour coordinate
comber
combine
combs
committee
commune
compact
complex
compound
compress
concert
conduct
confine
confines
conflict
conglomerate
conscript
conserve
consist
console
consort
construct
consult
consummate
content
contest
contract
contracts
contrast
converse
convert
convict
coop
coordinate
covey
crooked
curate
cussed
decollate
decrease
defect
defense
delegate
deliberate
denier
desert
detail
deviate
diagnoses
diffuse
digest
discard
discharge
discount
do
document
does
dogged
domesticate
dominican
dove
dr
drawer
duplicate
egress
ejaculate
eject
elaborate
ellipses
email
emu
entrace
entrance
escort
estimate
eta
etna
evening
excise
excuse
exploit
export
extract
fine
flower
forbear
four-legged
frequent
furrier
gallant
gel
geminate
gillie
glower
gotham
graduate
haggis
heavy
hinder
house
housewife
impact
imped
implant
implement
import
impress
incense
incline
increase
infix
insert
instar
insult
integral
intercept
interchange
interflow
interleaf
intermediate
intern
interspace
intimate
intrigue
invalid
invert
invite
irony
jagged
jesses
julies
kite
laminate
laos
lather
lead
learned
leasing
lech
legitimate
lied
lima
lipread
live
lower
lunged
maas
magdalen
manes
mare
marked
merchandise
merlion
minute
misconduct
misled
misprint
mobile
moderate
mong
moped
moth
mouth
mow
mpg
multiply
mush
nana
nice
nice
number
numerate
nun
object
opiate
ornament
outbox
outcry
outpour
outreach
outride
outright
outside
outwork
overall
overbid
overcall
overcast
overfall
overflow
overhaul
overhead
overlap
overlay
overuse
overweight
overwork
pace
palled
palling
para
pasty
pate
pauline
pedal
peer
perfect
periodic
permit
pervert
pinta
placer
platy
polish
polish
poll
pontificate
postulate
pram
prayer
precipitate
predate
predicate
prefix
preposition
present
pretest
primer
proceeds
produce
progress
project
proportionate
prospect
protest
pussy
putter
putting
quite
ragged
raven
re
read
reading
reading
real
rebel
recall
recap
recitative
recollect
record
recreate
recreation
redress
refill
refund
refuse
reject
relay
remake
repaint
reprint
reread
rerun
resent
reside
resign
respray
resume
retard
retest
retread
rewrite
root
routed
routing
row
rugged
rummy
sais
sake
sambuca
saucier
second
secrete
secreted
secreting
segment
separate
sewer
shirk
shower
sin
skied
slaver
slough
sow
spoof
squid
stingy
subject
subordinate
subvert
supply
supposed
survey
suspect
syringes
tabulate
tales
tarrier
tarry
taxes
taxis
tear
theron
thou
three-legged
tier
tinged
torment
transfer
transform
transplant
transport
transpose
tush
two-legged
unionised
unionized
update
uplift
upset
use
used
vale
violist
viva
ware
whinged
whoop
wicked
wind
windy
wino
won
worsted
wound
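The heteronyms file above is a plain word list, one entry per line, of words whose pronunciation depends on context (read, lead, wind, ...). Below is a minimal sketch of how such a list is typically consumed by a text-cleaning pipeline: it is loaded into a set so that ambiguous words can bypass automatic grapheme-to-phoneme conversion. The load_heteronyms and skip_phonemization helpers are illustrative assumptions, not the repository's actual API; see cleaners.py for the real wiring.

# Hedged sketch: loading the heteronyms list and using it to decide whether a word
# should bypass automatic phonemization. The file path and the helper names below
# are illustrative assumptions, not this repository's actual interface.
def load_heteronyms(path="tts_text_processing/heteronyms"):
    with open(path, encoding="utf-8") as f:
        # one (possibly multi-word) entry per line; ignore blank lines
        return {line.strip().lower() for line in f if line.strip()}

def skip_phonemization(word, heteronyms):
    # heteronyms have context-dependent pronunciation, so a dictionary/G2P lookup
    # alone cannot pick the right phonemes; keep them as graphemes instead
    return word.lower() in heteronyms

heteronyms = load_heteronyms()
print(skip_phonemization("record", heteronyms))  # True when "record" is in the list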
tts_text_processing/letters_and_numbers.py
ADDED
@@ -0,0 +1,96 @@
"""adapted from https://github.com/keithito/tacotron"""

import re

_letters_and_numbers_re = re.compile(
    r"((?:[a-zA-Z]+[0-9]|[0-9]+[a-zA-Z])[a-zA-Z0-9']*)", re.IGNORECASE
)

_hardware_re = re.compile(
    r"([0-9]+(?:[.,][0-9]+)?)(?:\s?)(tb|gb|mb|kb|ghz|mhz|khz|hz|mm)", re.IGNORECASE
)
_hardware_key = {
    "tb": "terabyte",
    "gb": "gigabyte",
    "mb": "megabyte",
    "kb": "kilobyte",
    "ghz": "gigahertz",
    "mhz": "megahertz",
    "khz": "kilohertz",
    "hz": "hertz",
    "mm": "millimeter",
    "cm": "centimeter",
    "km": "kilometer",
}

_dimension_re = re.compile(
    r"\b(\d+(?:[,.]\d+)?\s*[xX]\s*\d+(?:[,.]\d+)?\s*[xX]\s*\d+(?:[,.]\d+)?(?:in|inch|m)?)\b|\b(\d+(?:[,.]\d+)?\s*[xX]\s*\d+(?:[,.]\d+)?(?:in|inch|m)?)\b"
)
_dimension_key = {"m": "meter", "in": "inch", "inch": "inch"}


def _expand_letters_and_numbers(m):
    text = re.split(r"(\d+)", m.group(0))

    # drop the empty string left at either end by the split
    if text[-1] == "":
        text = text[:-1]
    elif text[0] == "":
        text = text[1:]

    # keep suffixes attached to their digits, e.g. 1920s, AK47's, 20th, 1st, 2nd, 3rd
    if text[-1] in ("'s", "s", "th", "nd", "st", "rd") and text[-2].isdigit():
        text[-2] = text[-2] + text[-1]
        text = text[:-1]

    # combine digits two by two so they are read as digit pairs
    new_text = []
    for i in range(len(text)):
        string = text[i]
        if string.isdigit() and len(string) < 5:
            # heuristics
            if len(string) > 2 and string[-2] == "0":
                if string[-1] == "0":
                    string = [string]
                else:
                    string = [string[:-3], string[-2], string[-1]]
            elif len(string) % 2 == 0:
                string = [string[i : i + 2] for i in range(0, len(string), 2)]
            elif len(string) > 2:
                string = [string[0]] + [
                    string[i : i + 2] for i in range(1, len(string), 2)
                ]
            new_text.extend(string)
        else:
            new_text.append(string)

    text = new_text
    text = " ".join(text)
    return text


def _expand_hardware(m):
    quantity, measure = m.groups(0)
    measure = _hardware_key[measure.lower()]
    if measure[-1] != "z" and float(quantity.replace(",", "")) > 1:
        return "{} {}s".format(quantity, measure)
    return "{} {}".format(quantity, measure)


def _expand_dimension(m):
    text = "".join([x for x in m.groups(0) if x != 0])
    text = text.replace(" x ", " by ")
    text = text.replace("x", " by ")
    if text.endswith(tuple(_dimension_key.keys())):
        if text[-2].isdigit():
            text = "{} {}".format(text[:-1], _dimension_key[text[-1:]])
        elif text[-3].isdigit():
            text = "{} {}".format(text[:-2], _dimension_key[text[-2:]])
    return text


def normalize_letters_and_numbers(text):
    text = re.sub(_hardware_re, _expand_hardware, text)
    text = re.sub(_dimension_re, _expand_dimension, text)
    text = re.sub(_letters_and_numbers_re, _expand_letters_and_numbers, text)
    return text
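A brief, hedged usage sketch of normalize_letters_and_numbers as defined above, assuming tts_text_processing is importable as a package; the expanded strings in the comments are indicative of what the heuristics produce rather than guaranteed output.

from tts_text_processing.letters_and_numbers import normalize_letters_and_numbers

# hardware units are spelled out; the digit itself is left for numerical.py to expand
print(normalize_letters_and_numbers("The drive has 2TB of storage."))
# -> roughly "The drive has 2 terabytes of storage."

# the dimension rule rewrites "x" between numbers as "by"
print(normalize_letters_and_numbers("A 1920x1080 panel."))
# -> roughly "A 1920 by 1080 panel."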
tts_text_processing/numerical.py
ADDED
@@ -0,0 +1,175 @@
"""adapted from https://github.com/keithito/tacotron"""

import inflect
import re

_magnitudes = ["trillion", "billion", "million", "thousand", "hundred", "m", "b", "t"]
_magnitudes_key = {"m": "million", "b": "billion", "t": "trillion"}
_measurements = "(f|c|k|d|m)"
_measurements_key = {"f": "fahrenheit", "c": "celsius", "k": "thousand", "m": "meters"}
_currency_key = {"$": "dollar", "£": "pound", "€": "euro", "₩": "won"}
_inflect = inflect.engine()
_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
_currency_re = re.compile(
    r"([\$€£₩])([0-9\.\,]*[0-9]+)(?:[ ]?({})(?=[^a-zA-Z]))?".format(
        "|".join(_magnitudes)
    ),
    re.IGNORECASE,
)
_measurement_re = re.compile(
    r"([0-9\.\,]*[0-9]+(\s)?{}\b)".format(_measurements), re.IGNORECASE
)
_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
# _range_re = re.compile(r'(?<=[0-9])+(-)(?=[0-9])+.*?')
_roman_re = re.compile(
    r"\b(?=[MDCLXVI]+\b)M{0,4}(CM|CD|D?C{0,3})(XC|XL|L?X{0,3})(IX|IV|V?I{2,3})\b"
)  # avoid matching a lone "I"
_multiply_re = re.compile(r"(\b[0-9]+)(x)([0-9]+)")
_number_re = re.compile(r"[0-9]+'s|[0-9]+s|[0-9]+")


def _remove_commas(m):
    return m.group(1).replace(",", "")


def _expand_decimal_point(m):
    return m.group(1).replace(".", " point ")


def _expand_currency(m):
    currency = _currency_key[m.group(1)]
    quantity = m.group(2)
    magnitude = m.group(3)

    # remove commas from quantity to be able to convert to numerical
    quantity = quantity.replace(",", "")

    # check for million, billion, etc...
    if magnitude is not None and magnitude.lower() in _magnitudes:
        if len(magnitude) == 1:
            magnitude = _magnitudes_key[magnitude.lower()]
        return "{} {} {}".format(_expand_hundreds(quantity), magnitude, currency + "s")

    parts = quantity.split(".")
    if len(parts) > 2:
        return quantity + " " + currency + "s"  # unexpected format

    dollars = int(parts[0]) if parts[0] else 0
    cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
    if dollars and cents:
        dollar_unit = currency if dollars == 1 else currency + "s"
        cent_unit = "cent" if cents == 1 else "cents"
        return "{} {}, {} {}".format(
            _expand_hundreds(dollars),
            dollar_unit,
            _inflect.number_to_words(cents),
            cent_unit,
        )
    elif dollars:
        dollar_unit = currency if dollars == 1 else currency + "s"
        return "{} {}".format(_expand_hundreds(dollars), dollar_unit)
    elif cents:
        cent_unit = "cent" if cents == 1 else "cents"
        return "{} {}".format(_inflect.number_to_words(cents), cent_unit)
    else:
        return "zero" + " " + currency + "s"


def _expand_hundreds(text):
    number = float(text)
    if 1000 < number < 10000 and (number % 100 == 0) and (number % 1000 != 0):
        return _inflect.number_to_words(int(number / 100)) + " hundred"
    else:
        return _inflect.number_to_words(text)


def _expand_ordinal(m):
    return _inflect.number_to_words(m.group(0))


def _expand_measurement(m):
    _, number, measurement = re.split(r"(\d+(?:\.\d+)?)", m.group(0))
    number = _inflect.number_to_words(number)
    measurement = "".join(measurement.split())
    measurement = _measurements_key[measurement.lower()]
    return "{} {}".format(number, measurement)


def _expand_range(m):
    return " to "


def _expand_multiply(m):
    left = m.group(1)
    right = m.group(3)
    return "{} by {}".format(left, right)


def _expand_roman(m):
    # from https://stackoverflow.com/questions/19308177/converting-roman-numerals-to-integers-in-python
    roman_numerals = {"I": 1, "V": 5, "X": 10, "L": 50, "C": 100, "D": 500, "M": 1000}
    result = 0
    num = m.group(0)
    for i, c in enumerate(num):
        if (i + 1) == len(num) or roman_numerals[c] >= roman_numerals[num[i + 1]]:
            result += roman_numerals[c]
        else:
            result -= roman_numerals[c]
    return str(result)


def _expand_number(m):
    _, number, suffix = re.split(r"(\d+(?:'?\d+)?)", m.group(0))
    number = int(number)
    if (
        number > 1000
        and number < 10000
        and (number % 100 == 0)
        and (number % 1000 != 0)
    ):
        text = _inflect.number_to_words(number // 100) + " hundred"
    elif number > 1000 and number < 3000:
        if number == 2000:
            text = "two thousand"
        elif number > 2000 and number < 2010:
            text = "two thousand " + _inflect.number_to_words(number % 100)
        elif number % 100 == 0:
            text = _inflect.number_to_words(number // 100) + " hundred"
        else:
            number = _inflect.number_to_words(
                number, andword="", zero="oh", group=2
            ).replace(", ", " ")
            number = re.sub(r"-", " ", number)
            text = number
    else:
        number = _inflect.number_to_words(number, andword="and")
        number = re.sub(r"-", " ", number)
        number = re.sub(r",", "", number)
        text = number

    if suffix in ("'s", "s"):
        if text[-1] == "y":
            text = text[:-1] + "ies"
        else:
            text = text + suffix

    return text


def normalize_currency(text):
    return re.sub(_currency_re, _expand_currency, text)


def normalize_numbers(text):
    text = re.sub(_comma_number_re, _remove_commas, text)
    text = re.sub(_currency_re, _expand_currency, text)
    text = re.sub(_decimal_number_re, _expand_decimal_point, text)
    text = re.sub(_ordinal_re, _expand_ordinal, text)
    # text = re.sub(_range_re, _expand_range, text)
    # text = re.sub(_measurement_re, _expand_measurement, text)
    text = re.sub(_roman_re, _expand_roman, text)
    text = re.sub(_multiply_re, _expand_multiply, text)
    text = re.sub(_number_re, _expand_number, text)
    return text
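Likewise, a hedged usage sketch of normalize_numbers from this module, assuming tts_text_processing is importable as a package; the outputs shown in the comments depend on inflect's wording and are indicative only.

from tts_text_processing.numerical import normalize_numbers

# roman numerals are converted to digits, then all numbers are spelled out;
# four-digit years are read as digit pairs via inflect's group=2 mode
print(normalize_numbers("Chapter XIV was written in 1850."))
# -> roughly "Chapter fourteen was written in eighteen fifty."

# currency amounts are expanded into dollars and cents
print(normalize_numbers("It costs $3.50."))
# -> roughly "It costs three dollars, fifty cents."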