antoniomae1234
commited on
Commit
•
2493d72
1
Parent(s):
945170a
changes in flenema
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .cardboardlint.yml +5 -0
- .circleci/config.yml +53 -0
- .compute +17 -0
- .dockerignore +1 -0
- .github/ISSUE_TEMPLATE.md +19 -0
- .github/PR_TEMPLATE.md +18 -0
- .github/stale.yml +19 -0
- .gitignore +132 -0
- .pylintrc +586 -0
- CODE_OF_CONDUCT.md +19 -0
- CODE_OWNERS.rst +75 -0
- CONTRIBUTING.md +51 -0
- LICENSE.txt +373 -0
- MANIFEST.in +11 -0
- README.md +281 -3
- TTS/.models.json +77 -0
- TTS/__init__.py +0 -0
- TTS/bin/__init__.py +0 -0
- TTS/bin/compute_attention_masks.py +166 -0
- TTS/bin/compute_embeddings.py +130 -0
- TTS/bin/compute_statistics.py +90 -0
- TTS/bin/convert_melgan_tflite.py +32 -0
- TTS/bin/convert_melgan_torch_to_tf.py +116 -0
- TTS/bin/convert_tacotron2_tflite.py +37 -0
- TTS/bin/convert_tacotron2_torch_to_tf.py +213 -0
- TTS/bin/distribute.py +69 -0
- TTS/bin/synthesize.py +218 -0
- TTS/bin/train_encoder.py +274 -0
- TTS/bin/train_glow_tts.py +657 -0
- TTS/bin/train_speedy_speech.py +618 -0
- TTS/bin/train_tacotron.py +731 -0
- TTS/bin/train_vocoder_gan.py +664 -0
- TTS/bin/train_vocoder_wavegrad.py +511 -0
- TTS/bin/train_vocoder_wavernn.py +539 -0
- TTS/bin/tune_wavegrad.py +91 -0
- TTS/server/README.md +65 -0
- TTS/server/__init__.py +0 -0
- TTS/server/conf.json +12 -0
- TTS/server/server.py +116 -0
- TTS/server/static/TTS_circle.png +0 -0
- TTS/server/templates/details.html +131 -0
- TTS/server/templates/index.html +114 -0
- TTS/speaker_encoder/README.md +18 -0
- TTS/speaker_encoder/__init__.py +0 -0
- TTS/speaker_encoder/config.json +103 -0
- TTS/speaker_encoder/dataset.py +169 -0
- TTS/speaker_encoder/losses.py +160 -0
- TTS/speaker_encoder/model.py +112 -0
- TTS/speaker_encoder/requirements.txt +2 -0
- TTS/speaker_encoder/umap.png +0 -0
.cardboardlint.yml
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
linters:
|
2 |
+
- pylint:
|
3 |
+
# pylintrc: pylintrc
|
4 |
+
filefilter: ['- test_*.py', '+ *.py', '- *.npy']
|
5 |
+
# exclude:
|
.circleci/config.yml
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
version: 2
|
2 |
+
|
3 |
+
workflows:
|
4 |
+
version: 2
|
5 |
+
test:
|
6 |
+
jobs:
|
7 |
+
- test-3.6
|
8 |
+
- test-3.7
|
9 |
+
- test-3.8
|
10 |
+
|
11 |
+
executor: ubuntu-latest
|
12 |
+
|
13 |
+
on:
|
14 |
+
push:
|
15 |
+
pull_request:
|
16 |
+
types: [opened, synchronize, reopened]
|
17 |
+
|
18 |
+
jobs:
|
19 |
+
test-3.6: &test-template
|
20 |
+
docker:
|
21 |
+
- image: circleci/python:3.6
|
22 |
+
resource_class: large
|
23 |
+
working_directory: ~/repo
|
24 |
+
steps:
|
25 |
+
- checkout
|
26 |
+
- run: |
|
27 |
+
sudo apt update
|
28 |
+
sudo apt install espeak git
|
29 |
+
- run: sudo pip install --upgrade pip
|
30 |
+
- run: sudo pip install -e .
|
31 |
+
- run: |
|
32 |
+
sudo pip install --quiet --upgrade cardboardlint pylint
|
33 |
+
cardboardlinter --refspec ${CIRCLE_BRANCH} -n auto
|
34 |
+
- run: nosetests tests --nocapture
|
35 |
+
- run: |
|
36 |
+
sudo ./tests/test_server_package.sh
|
37 |
+
sudo ./tests/test_glow-tts_train.sh
|
38 |
+
sudo ./tests/test_server_package.sh
|
39 |
+
sudo ./tests/test_tacotron_train.sh
|
40 |
+
sudo ./tests/test_vocoder_gan_train.sh
|
41 |
+
sudo ./tests/test_vocoder_wavegrad_train.sh
|
42 |
+
sudo ./tests/test_vocoder_wavernn_train.sh
|
43 |
+
sudo ./tests/test_speedy_speech_train.sh
|
44 |
+
|
45 |
+
test-3.7:
|
46 |
+
<<: *test-template
|
47 |
+
docker:
|
48 |
+
- image: circleci/python:3.7
|
49 |
+
|
50 |
+
test-3.8:
|
51 |
+
<<: *test-template
|
52 |
+
docker:
|
53 |
+
- image: circleci/python:3.8
|
.compute
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
yes | apt-get install sox
|
3 |
+
yes | apt-get install ffmpeg
|
4 |
+
yes | apt-get install espeak
|
5 |
+
yes | apt-get install tmux
|
6 |
+
yes | apt-get install zsh
|
7 |
+
sh -c "$(curl -fsSL https://raw.githubusercontent.com/robbyrussell/oh-my-zsh/master/tools/install.sh)"
|
8 |
+
pip3 install https://download.pytorch.org/whl/cu100/torch-1.3.0%2Bcu100-cp36-cp36m-linux_x86_64.whl
|
9 |
+
sudo sh install.sh
|
10 |
+
# pip install pytorch==1.7.0+cu100
|
11 |
+
# python3 setup.py develop
|
12 |
+
# python3 distribute.py --config_path config.json --data_path /data/ro/shared/data/keithito/LJSpeech-1.1/
|
13 |
+
# cp -R ${USER_DIR}/Mozilla_22050 ../tmp/
|
14 |
+
# python3 distribute.py --config_path config_tacotron_gst.json --data_path ../tmp/Mozilla_22050/
|
15 |
+
# python3 distribute.py --config_path config.json --data_path /data/rw/home/LibriTTS/train-clean-360
|
16 |
+
# python3 distribute.py --config_path config.json
|
17 |
+
while true; do sleep 1000000; done
|
.dockerignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
.git/
|
.github/ISSUE_TEMPLATE.md
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
name: 'TTS Discourse '
|
3 |
+
about: Pls consider to use TTS Discourse page.
|
4 |
+
title: ''
|
5 |
+
labels: ''
|
6 |
+
assignees: ''
|
7 |
+
|
8 |
+
---
|
9 |
+
<b>Questions</b> will not be answered here!!
|
10 |
+
|
11 |
+
Help is much more valuable if it's shared publicly, so that more people can benefit from it.
|
12 |
+
|
13 |
+
Please consider posting on [TTS Discourse](https://discourse.mozilla.org/c/tts) page or matrix [chat room](https://matrix.to/#/!KTePhNahjgiVumkqca:matrix.org?via=matrix.org) if your issue is not directly related to TTS development (Bugs, code updates etc.).
|
14 |
+
|
15 |
+
You can also check https://github.com/mozilla/TTS/wiki/FAQ for common questions and answers.
|
16 |
+
|
17 |
+
Happy posting!
|
18 |
+
|
19 |
+
https://discourse.mozilla.org/c/tts
|
.github/PR_TEMPLATE.md
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
name: 'Contribution Guideline '
|
3 |
+
about: Refer to Contirbution Guideline
|
4 |
+
title: ''
|
5 |
+
labels: ''
|
6 |
+
assignees: ''
|
7 |
+
|
8 |
+
---
|
9 |
+
### Contribution Guideline
|
10 |
+
|
11 |
+
Please send your PRs to `dev` branch if it is not directly related to a specific branch.
|
12 |
+
Before making a Pull Request, check your changes for basic mistakes and style problems by using a linter.
|
13 |
+
We have cardboardlinter setup in this repository, so for example, if you've made some changes and would like to run the linter on just the changed code, you can use the follow command:
|
14 |
+
|
15 |
+
```bash
|
16 |
+
pip install pylint cardboardlint
|
17 |
+
cardboardlinter --refspec master
|
18 |
+
```
|
.github/stale.yml
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Number of days of inactivity before an issue becomes stale
|
2 |
+
daysUntilStale: 60
|
3 |
+
# Number of days of inactivity before a stale issue is closed
|
4 |
+
daysUntilClose: 7
|
5 |
+
# Issues with these labels will never be considered stale
|
6 |
+
exemptLabels:
|
7 |
+
- pinned
|
8 |
+
- security
|
9 |
+
# Label to use when marking an issue as stale
|
10 |
+
staleLabel: wontfix
|
11 |
+
# Comment to post when marking an issue as stale. Set to `false` to disable
|
12 |
+
markComment: >
|
13 |
+
This issue has been automatically marked as stale because it has not had
|
14 |
+
recent activity. It will be closed if no further activity occurs. Thank you
|
15 |
+
for your contributions. You might also look our discourse page for further help.
|
16 |
+
https://discourse.mozilla.org/c/tts
|
17 |
+
# Comment to post when closing a stale issue. Set to `false` to disable
|
18 |
+
closeComment: false
|
19 |
+
|
.gitignore
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
WadaSNR/
|
2 |
+
.idea/
|
3 |
+
*.pyc
|
4 |
+
.DS_Store
|
5 |
+
./__init__.py
|
6 |
+
# Byte-compiled / optimized / DLL files
|
7 |
+
__pycache__/
|
8 |
+
*.py[cod]
|
9 |
+
*$py.class
|
10 |
+
|
11 |
+
# C extensions
|
12 |
+
*.so
|
13 |
+
|
14 |
+
# Distribution / packaging
|
15 |
+
.Python
|
16 |
+
build/
|
17 |
+
develop-eggs/
|
18 |
+
dist/
|
19 |
+
downloads/
|
20 |
+
eggs/
|
21 |
+
.eggs/
|
22 |
+
lib/
|
23 |
+
lib64/
|
24 |
+
parts/
|
25 |
+
sdist/
|
26 |
+
var/
|
27 |
+
wheels/
|
28 |
+
*.egg-info/
|
29 |
+
.installed.cfg
|
30 |
+
*.egg
|
31 |
+
MANIFEST
|
32 |
+
|
33 |
+
# PyInstaller
|
34 |
+
# Usually these files are written by a python script from a template
|
35 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
36 |
+
*.manifest
|
37 |
+
*.spec
|
38 |
+
|
39 |
+
# Installer logs
|
40 |
+
pip-log.txt
|
41 |
+
pip-delete-this-directory.txt
|
42 |
+
|
43 |
+
# Unit test / coverage reports
|
44 |
+
htmlcov/
|
45 |
+
.tox/
|
46 |
+
.coverage
|
47 |
+
.coverage.*
|
48 |
+
.cache
|
49 |
+
nosetests.xml
|
50 |
+
coverage.xml
|
51 |
+
*.cover
|
52 |
+
.hypothesis/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
.static_storage/
|
61 |
+
.media/
|
62 |
+
local_settings.py
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
target/
|
76 |
+
|
77 |
+
# Jupyter Notebook
|
78 |
+
.ipynb_checkpoints
|
79 |
+
|
80 |
+
# pyenv
|
81 |
+
.python-version
|
82 |
+
|
83 |
+
# celery beat schedule file
|
84 |
+
celerybeat-schedule
|
85 |
+
|
86 |
+
# SageMath parsed files
|
87 |
+
*.sage.py
|
88 |
+
|
89 |
+
# Environments
|
90 |
+
.env
|
91 |
+
.venv
|
92 |
+
env/
|
93 |
+
venv/
|
94 |
+
ENV/
|
95 |
+
env.bak/
|
96 |
+
venv.bak/
|
97 |
+
|
98 |
+
# Spyder project settings
|
99 |
+
.spyderproject
|
100 |
+
.spyproject
|
101 |
+
|
102 |
+
# Rope project settings
|
103 |
+
.ropeproject
|
104 |
+
|
105 |
+
# mkdocs documentation
|
106 |
+
/site
|
107 |
+
|
108 |
+
# mypy
|
109 |
+
.mypy_cache/
|
110 |
+
|
111 |
+
# vim
|
112 |
+
*.swp
|
113 |
+
*.swm
|
114 |
+
*.swn
|
115 |
+
*.swo
|
116 |
+
|
117 |
+
# pytorch models
|
118 |
+
*.pth.tar
|
119 |
+
result/
|
120 |
+
|
121 |
+
# setup.py
|
122 |
+
version.py
|
123 |
+
|
124 |
+
# jupyter dummy files
|
125 |
+
core
|
126 |
+
|
127 |
+
tests/outputs/*
|
128 |
+
TODO.txt
|
129 |
+
.vscode/*
|
130 |
+
data/*
|
131 |
+
notebooks/data/*
|
132 |
+
TTS/tts/layers/glow_tts/monotonic_align/core.c
|
.pylintrc
ADDED
@@ -0,0 +1,586 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[MASTER]
|
2 |
+
|
3 |
+
# A comma-separated list of package or module names from where C extensions may
|
4 |
+
# be loaded. Extensions are loading into the active Python interpreter and may
|
5 |
+
# run arbitrary code.
|
6 |
+
extension-pkg-whitelist=
|
7 |
+
|
8 |
+
# Add files or directories to the blacklist. They should be base names, not
|
9 |
+
# paths.
|
10 |
+
ignore=CVS
|
11 |
+
|
12 |
+
# Add files or directories matching the regex patterns to the blacklist. The
|
13 |
+
# regex matches against base names, not paths.
|
14 |
+
ignore-patterns=
|
15 |
+
|
16 |
+
# Python code to execute, usually for sys.path manipulation such as
|
17 |
+
# pygtk.require().
|
18 |
+
#init-hook=
|
19 |
+
|
20 |
+
# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
|
21 |
+
# number of processors available to use.
|
22 |
+
jobs=1
|
23 |
+
|
24 |
+
# Control the amount of potential inferred values when inferring a single
|
25 |
+
# object. This can help the performance when dealing with large functions or
|
26 |
+
# complex, nested conditions.
|
27 |
+
limit-inference-results=100
|
28 |
+
|
29 |
+
# List of plugins (as comma separated values of python modules names) to load,
|
30 |
+
# usually to register additional checkers.
|
31 |
+
load-plugins=
|
32 |
+
|
33 |
+
# Pickle collected data for later comparisons.
|
34 |
+
persistent=yes
|
35 |
+
|
36 |
+
# Specify a configuration file.
|
37 |
+
#rcfile=
|
38 |
+
|
39 |
+
# When enabled, pylint would attempt to guess common misconfiguration and emit
|
40 |
+
# user-friendly hints instead of false-positive error messages.
|
41 |
+
suggestion-mode=yes
|
42 |
+
|
43 |
+
# Allow loading of arbitrary C extensions. Extensions are imported into the
|
44 |
+
# active Python interpreter and may run arbitrary code.
|
45 |
+
unsafe-load-any-extension=no
|
46 |
+
|
47 |
+
|
48 |
+
[MESSAGES CONTROL]
|
49 |
+
|
50 |
+
# Only show warnings with the listed confidence levels. Leave empty to show
|
51 |
+
# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED.
|
52 |
+
confidence=
|
53 |
+
|
54 |
+
# Disable the message, report, category or checker with the given id(s). You
|
55 |
+
# can either give multiple identifiers separated by comma (,) or put this
|
56 |
+
# option multiple times (only on the command line, not in the configuration
|
57 |
+
# file where it should appear only once). You can also use "--disable=all" to
|
58 |
+
# disable everything first and then reenable specific checks. For example, if
|
59 |
+
# you want to run only the similarities checker, you can use "--disable=all
|
60 |
+
# --enable=similarities". If you want to run only the classes checker, but have
|
61 |
+
# no Warning level messages displayed, use "--disable=all --enable=classes
|
62 |
+
# --disable=W".
|
63 |
+
disable=missing-docstring,
|
64 |
+
line-too-long,
|
65 |
+
fixme,
|
66 |
+
wrong-import-order,
|
67 |
+
ungrouped-imports,
|
68 |
+
wrong-import-position,
|
69 |
+
import-error,
|
70 |
+
invalid-name,
|
71 |
+
too-many-instance-attributes,
|
72 |
+
arguments-differ,
|
73 |
+
no-name-in-module,
|
74 |
+
no-member,
|
75 |
+
unsubscriptable-object,
|
76 |
+
print-statement,
|
77 |
+
parameter-unpacking,
|
78 |
+
unpacking-in-except,
|
79 |
+
old-raise-syntax,
|
80 |
+
backtick,
|
81 |
+
long-suffix,
|
82 |
+
old-ne-operator,
|
83 |
+
old-octal-literal,
|
84 |
+
import-star-module-level,
|
85 |
+
non-ascii-bytes-literal,
|
86 |
+
raw-checker-failed,
|
87 |
+
bad-inline-option,
|
88 |
+
locally-disabled,
|
89 |
+
file-ignored,
|
90 |
+
suppressed-message,
|
91 |
+
useless-suppression,
|
92 |
+
deprecated-pragma,
|
93 |
+
use-symbolic-message-instead,
|
94 |
+
useless-object-inheritance,
|
95 |
+
too-few-public-methods,
|
96 |
+
too-many-branches,
|
97 |
+
too-many-arguments,
|
98 |
+
too-many-locals,
|
99 |
+
too-many-statements,
|
100 |
+
apply-builtin,
|
101 |
+
basestring-builtin,
|
102 |
+
buffer-builtin,
|
103 |
+
cmp-builtin,
|
104 |
+
coerce-builtin,
|
105 |
+
execfile-builtin,
|
106 |
+
file-builtin,
|
107 |
+
long-builtin,
|
108 |
+
raw_input-builtin,
|
109 |
+
reduce-builtin,
|
110 |
+
standarderror-builtin,
|
111 |
+
unicode-builtin,
|
112 |
+
xrange-builtin,
|
113 |
+
coerce-method,
|
114 |
+
delslice-method,
|
115 |
+
getslice-method,
|
116 |
+
setslice-method,
|
117 |
+
no-absolute-import,
|
118 |
+
old-division,
|
119 |
+
dict-iter-method,
|
120 |
+
dict-view-method,
|
121 |
+
next-method-called,
|
122 |
+
metaclass-assignment,
|
123 |
+
indexing-exception,
|
124 |
+
raising-string,
|
125 |
+
reload-builtin,
|
126 |
+
oct-method,
|
127 |
+
hex-method,
|
128 |
+
nonzero-method,
|
129 |
+
cmp-method,
|
130 |
+
input-builtin,
|
131 |
+
round-builtin,
|
132 |
+
intern-builtin,
|
133 |
+
unichr-builtin,
|
134 |
+
map-builtin-not-iterating,
|
135 |
+
zip-builtin-not-iterating,
|
136 |
+
range-builtin-not-iterating,
|
137 |
+
filter-builtin-not-iterating,
|
138 |
+
using-cmp-argument,
|
139 |
+
eq-without-hash,
|
140 |
+
div-method,
|
141 |
+
idiv-method,
|
142 |
+
rdiv-method,
|
143 |
+
exception-message-attribute,
|
144 |
+
invalid-str-codec,
|
145 |
+
sys-max-int,
|
146 |
+
bad-python3-import,
|
147 |
+
deprecated-string-function,
|
148 |
+
deprecated-str-translate-call,
|
149 |
+
deprecated-itertools-function,
|
150 |
+
deprecated-types-field,
|
151 |
+
next-method-defined,
|
152 |
+
dict-items-not-iterating,
|
153 |
+
dict-keys-not-iterating,
|
154 |
+
dict-values-not-iterating,
|
155 |
+
deprecated-operator-function,
|
156 |
+
deprecated-urllib-function,
|
157 |
+
xreadlines-attribute,
|
158 |
+
deprecated-sys-function,
|
159 |
+
exception-escape,
|
160 |
+
comprehension-escape,
|
161 |
+
duplicate-code
|
162 |
+
|
163 |
+
# Enable the message, report, category or checker with the given id(s). You can
|
164 |
+
# either give multiple identifier separated by comma (,) or put this option
|
165 |
+
# multiple time (only on the command line, not in the configuration file where
|
166 |
+
# it should appear only once). See also the "--disable" option for examples.
|
167 |
+
enable=c-extension-no-member
|
168 |
+
|
169 |
+
|
170 |
+
[REPORTS]
|
171 |
+
|
172 |
+
# Python expression which should return a note less than 10 (10 is the highest
|
173 |
+
# note). You have access to the variables errors warning, statement which
|
174 |
+
# respectively contain the number of errors / warnings messages and the total
|
175 |
+
# number of statements analyzed. This is used by the global evaluation report
|
176 |
+
# (RP0004).
|
177 |
+
evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
|
178 |
+
|
179 |
+
# Template used to display messages. This is a python new-style format string
|
180 |
+
# used to format the message information. See doc for all details.
|
181 |
+
#msg-template=
|
182 |
+
|
183 |
+
# Set the output format. Available formats are text, parseable, colorized, json
|
184 |
+
# and msvs (visual studio). You can also give a reporter class, e.g.
|
185 |
+
# mypackage.mymodule.MyReporterClass.
|
186 |
+
output-format=text
|
187 |
+
|
188 |
+
# Tells whether to display a full report or only the messages.
|
189 |
+
reports=no
|
190 |
+
|
191 |
+
# Activate the evaluation score.
|
192 |
+
score=yes
|
193 |
+
|
194 |
+
|
195 |
+
[REFACTORING]
|
196 |
+
|
197 |
+
# Maximum number of nested blocks for function / method body
|
198 |
+
max-nested-blocks=5
|
199 |
+
|
200 |
+
# Complete name of functions that never returns. When checking for
|
201 |
+
# inconsistent-return-statements if a never returning function is called then
|
202 |
+
# it will be considered as an explicit return statement and no message will be
|
203 |
+
# printed.
|
204 |
+
never-returning-functions=sys.exit
|
205 |
+
|
206 |
+
|
207 |
+
[LOGGING]
|
208 |
+
|
209 |
+
# Format style used to check logging format string. `old` means using %
|
210 |
+
# formatting, while `new` is for `{}` formatting.
|
211 |
+
logging-format-style=old
|
212 |
+
|
213 |
+
# Logging modules to check that the string format arguments are in logging
|
214 |
+
# function parameter format.
|
215 |
+
logging-modules=logging
|
216 |
+
|
217 |
+
|
218 |
+
[SPELLING]
|
219 |
+
|
220 |
+
# Limits count of emitted suggestions for spelling mistakes.
|
221 |
+
max-spelling-suggestions=4
|
222 |
+
|
223 |
+
# Spelling dictionary name. Available dictionaries: none. To make it working
|
224 |
+
# install python-enchant package..
|
225 |
+
spelling-dict=
|
226 |
+
|
227 |
+
# List of comma separated words that should not be checked.
|
228 |
+
spelling-ignore-words=
|
229 |
+
|
230 |
+
# A path to a file that contains private dictionary; one word per line.
|
231 |
+
spelling-private-dict-file=
|
232 |
+
|
233 |
+
# Tells whether to store unknown words to indicated private dictionary in
|
234 |
+
# --spelling-private-dict-file option instead of raising a message.
|
235 |
+
spelling-store-unknown-words=no
|
236 |
+
|
237 |
+
|
238 |
+
[MISCELLANEOUS]
|
239 |
+
|
240 |
+
# List of note tags to take in consideration, separated by a comma.
|
241 |
+
notes=FIXME,
|
242 |
+
XXX,
|
243 |
+
TODO
|
244 |
+
|
245 |
+
|
246 |
+
[TYPECHECK]
|
247 |
+
|
248 |
+
# List of decorators that produce context managers, such as
|
249 |
+
# contextlib.contextmanager. Add to this list to register other decorators that
|
250 |
+
# produce valid context managers.
|
251 |
+
contextmanager-decorators=contextlib.contextmanager
|
252 |
+
|
253 |
+
# List of members which are set dynamically and missed by pylint inference
|
254 |
+
# system, and so shouldn't trigger E1101 when accessed. Python regular
|
255 |
+
# expressions are accepted.
|
256 |
+
generated-members=
|
257 |
+
|
258 |
+
# Tells whether missing members accessed in mixin class should be ignored. A
|
259 |
+
# mixin class is detected if its name ends with "mixin" (case insensitive).
|
260 |
+
ignore-mixin-members=yes
|
261 |
+
|
262 |
+
# Tells whether to warn about missing members when the owner of the attribute
|
263 |
+
# is inferred to be None.
|
264 |
+
ignore-none=yes
|
265 |
+
|
266 |
+
# This flag controls whether pylint should warn about no-member and similar
|
267 |
+
# checks whenever an opaque object is returned when inferring. The inference
|
268 |
+
# can return multiple potential results while evaluating a Python object, but
|
269 |
+
# some branches might not be evaluated, which results in partial inference. In
|
270 |
+
# that case, it might be useful to still emit no-member and other checks for
|
271 |
+
# the rest of the inferred objects.
|
272 |
+
ignore-on-opaque-inference=yes
|
273 |
+
|
274 |
+
# List of class names for which member attributes should not be checked (useful
|
275 |
+
# for classes with dynamically set attributes). This supports the use of
|
276 |
+
# qualified names.
|
277 |
+
ignored-classes=optparse.Values,thread._local,_thread._local
|
278 |
+
|
279 |
+
# List of module names for which member attributes should not be checked
|
280 |
+
# (useful for modules/projects where namespaces are manipulated during runtime
|
281 |
+
# and thus existing member attributes cannot be deduced by static analysis. It
|
282 |
+
# supports qualified module names, as well as Unix pattern matching.
|
283 |
+
ignored-modules=
|
284 |
+
|
285 |
+
# Show a hint with possible names when a member name was not found. The aspect
|
286 |
+
# of finding the hint is based on edit distance.
|
287 |
+
missing-member-hint=yes
|
288 |
+
|
289 |
+
# The minimum edit distance a name should have in order to be considered a
|
290 |
+
# similar match for a missing member name.
|
291 |
+
missing-member-hint-distance=1
|
292 |
+
|
293 |
+
# The total number of similar names that should be taken in consideration when
|
294 |
+
# showing a hint for a missing member.
|
295 |
+
missing-member-max-choices=1
|
296 |
+
|
297 |
+
|
298 |
+
[VARIABLES]
|
299 |
+
|
300 |
+
# List of additional names supposed to be defined in builtins. Remember that
|
301 |
+
# you should avoid defining new builtins when possible.
|
302 |
+
additional-builtins=
|
303 |
+
|
304 |
+
# Tells whether unused global variables should be treated as a violation.
|
305 |
+
allow-global-unused-variables=yes
|
306 |
+
|
307 |
+
# List of strings which can identify a callback function by name. A callback
|
308 |
+
# name must start or end with one of those strings.
|
309 |
+
callbacks=cb_,
|
310 |
+
_cb
|
311 |
+
|
312 |
+
# A regular expression matching the name of dummy variables (i.e. expected to
|
313 |
+
# not be used).
|
314 |
+
dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
|
315 |
+
|
316 |
+
# Argument names that match this expression will be ignored. Default to name
|
317 |
+
# with leading underscore.
|
318 |
+
ignored-argument-names=_.*|^ignored_|^unused_
|
319 |
+
|
320 |
+
# Tells whether we should check for unused import in __init__ files.
|
321 |
+
init-import=no
|
322 |
+
|
323 |
+
# List of qualified module names which can have objects that can redefine
|
324 |
+
# builtins.
|
325 |
+
redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io
|
326 |
+
|
327 |
+
|
328 |
+
[FORMAT]
|
329 |
+
|
330 |
+
# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
|
331 |
+
expected-line-ending-format=
|
332 |
+
|
333 |
+
# Regexp for a line that is allowed to be longer than the limit.
|
334 |
+
ignore-long-lines=^\s*(# )?<?https?://\S+>?$
|
335 |
+
|
336 |
+
# Number of spaces of indent required inside a hanging or continued line.
|
337 |
+
indent-after-paren=4
|
338 |
+
|
339 |
+
# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
|
340 |
+
# tab).
|
341 |
+
indent-string=' '
|
342 |
+
|
343 |
+
# Maximum number of characters on a single line.
|
344 |
+
max-line-length=100
|
345 |
+
|
346 |
+
# Maximum number of lines in a module.
|
347 |
+
max-module-lines=1000
|
348 |
+
|
349 |
+
# List of optional constructs for which whitespace checking is disabled. `dict-
|
350 |
+
# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
|
351 |
+
# `trailing-comma` allows a space between comma and closing bracket: (a, ).
|
352 |
+
# `empty-line` allows space-only lines.
|
353 |
+
no-space-check=trailing-comma,
|
354 |
+
dict-separator
|
355 |
+
|
356 |
+
# Allow the body of a class to be on the same line as the declaration if body
|
357 |
+
# contains single statement.
|
358 |
+
single-line-class-stmt=no
|
359 |
+
|
360 |
+
# Allow the body of an if to be on the same line as the test if there is no
|
361 |
+
# else.
|
362 |
+
single-line-if-stmt=no
|
363 |
+
|
364 |
+
|
365 |
+
[SIMILARITIES]
|
366 |
+
|
367 |
+
# Ignore comments when computing similarities.
|
368 |
+
ignore-comments=yes
|
369 |
+
|
370 |
+
# Ignore docstrings when computing similarities.
|
371 |
+
ignore-docstrings=yes
|
372 |
+
|
373 |
+
# Ignore imports when computing similarities.
|
374 |
+
ignore-imports=no
|
375 |
+
|
376 |
+
# Minimum lines number of a similarity.
|
377 |
+
min-similarity-lines=4
|
378 |
+
|
379 |
+
|
380 |
+
[BASIC]
|
381 |
+
|
382 |
+
# Naming style matching correct argument names.
|
383 |
+
argument-naming-style=snake_case
|
384 |
+
|
385 |
+
# Regular expression matching correct argument names. Overrides argument-
|
386 |
+
# naming-style.
|
387 |
+
argument-rgx=[a-z_][a-z0-9_]{0,30}$
|
388 |
+
|
389 |
+
# Naming style matching correct attribute names.
|
390 |
+
attr-naming-style=snake_case
|
391 |
+
|
392 |
+
# Regular expression matching correct attribute names. Overrides attr-naming-
|
393 |
+
# style.
|
394 |
+
#attr-rgx=
|
395 |
+
|
396 |
+
# Bad variable names which should always be refused, separated by a comma.
|
397 |
+
bad-names=
|
398 |
+
|
399 |
+
# Naming style matching correct class attribute names.
|
400 |
+
class-attribute-naming-style=any
|
401 |
+
|
402 |
+
# Regular expression matching correct class attribute names. Overrides class-
|
403 |
+
# attribute-naming-style.
|
404 |
+
#class-attribute-rgx=
|
405 |
+
|
406 |
+
# Naming style matching correct class names.
|
407 |
+
class-naming-style=PascalCase
|
408 |
+
|
409 |
+
# Regular expression matching correct class names. Overrides class-naming-
|
410 |
+
# style.
|
411 |
+
#class-rgx=
|
412 |
+
|
413 |
+
# Naming style matching correct constant names.
|
414 |
+
const-naming-style=UPPER_CASE
|
415 |
+
|
416 |
+
# Regular expression matching correct constant names. Overrides const-naming-
|
417 |
+
# style.
|
418 |
+
#const-rgx=
|
419 |
+
|
420 |
+
# Minimum line length for functions/classes that require docstrings, shorter
|
421 |
+
# ones are exempt.
|
422 |
+
docstring-min-length=-1
|
423 |
+
|
424 |
+
# Naming style matching correct function names.
|
425 |
+
function-naming-style=snake_case
|
426 |
+
|
427 |
+
# Regular expression matching correct function names. Overrides function-
|
428 |
+
# naming-style.
|
429 |
+
#function-rgx=
|
430 |
+
|
431 |
+
# Good variable names which should always be accepted, separated by a comma.
|
432 |
+
good-names=i,
|
433 |
+
j,
|
434 |
+
k,
|
435 |
+
x,
|
436 |
+
ex,
|
437 |
+
Run,
|
438 |
+
_
|
439 |
+
|
440 |
+
# Include a hint for the correct naming format with invalid-name.
|
441 |
+
include-naming-hint=no
|
442 |
+
|
443 |
+
# Naming style matching correct inline iteration names.
|
444 |
+
inlinevar-naming-style=any
|
445 |
+
|
446 |
+
# Regular expression matching correct inline iteration names. Overrides
|
447 |
+
# inlinevar-naming-style.
|
448 |
+
#inlinevar-rgx=
|
449 |
+
|
450 |
+
# Naming style matching correct method names.
|
451 |
+
method-naming-style=snake_case
|
452 |
+
|
453 |
+
# Regular expression matching correct method names. Overrides method-naming-
|
454 |
+
# style.
|
455 |
+
#method-rgx=
|
456 |
+
|
457 |
+
# Naming style matching correct module names.
|
458 |
+
module-naming-style=snake_case
|
459 |
+
|
460 |
+
# Regular expression matching correct module names. Overrides module-naming-
|
461 |
+
# style.
|
462 |
+
#module-rgx=
|
463 |
+
|
464 |
+
# Colon-delimited sets of names that determine each other's naming style when
|
465 |
+
# the name regexes allow several styles.
|
466 |
+
name-group=
|
467 |
+
|
468 |
+
# Regular expression which should only match function or class names that do
|
469 |
+
# not require a docstring.
|
470 |
+
no-docstring-rgx=^_
|
471 |
+
|
472 |
+
# List of decorators that produce properties, such as abc.abstractproperty. Add
|
473 |
+
# to this list to register other decorators that produce valid properties.
|
474 |
+
# These decorators are taken in consideration only for invalid-name.
|
475 |
+
property-classes=abc.abstractproperty
|
476 |
+
|
477 |
+
# Naming style matching correct variable names.
|
478 |
+
variable-naming-style=snake_case
|
479 |
+
|
480 |
+
# Regular expression matching correct variable names. Overrides variable-
|
481 |
+
# naming-style.
|
482 |
+
variable-rgx=[a-z_][a-z0-9_]{0,30}$
|
483 |
+
|
484 |
+
|
485 |
+
[STRING]
|
486 |
+
|
487 |
+
# This flag controls whether the implicit-str-concat-in-sequence should
|
488 |
+
# generate a warning on implicit string concatenation in sequences defined over
|
489 |
+
# several lines.
|
490 |
+
check-str-concat-over-line-jumps=no
|
491 |
+
|
492 |
+
|
493 |
+
[IMPORTS]
|
494 |
+
|
495 |
+
# Allow wildcard imports from modules that define __all__.
|
496 |
+
allow-wildcard-with-all=no
|
497 |
+
|
498 |
+
# Analyse import fallback blocks. This can be used to support both Python 2 and
|
499 |
+
# 3 compatible code, which means that the block might have code that exists
|
500 |
+
# only in one or another interpreter, leading to false positives when analysed.
|
501 |
+
analyse-fallback-blocks=no
|
502 |
+
|
503 |
+
# Deprecated modules which should not be used, separated by a comma.
|
504 |
+
deprecated-modules=optparse,tkinter.tix
|
505 |
+
|
506 |
+
# Create a graph of external dependencies in the given file (report RP0402 must
|
507 |
+
# not be disabled).
|
508 |
+
ext-import-graph=
|
509 |
+
|
510 |
+
# Create a graph of every (i.e. internal and external) dependencies in the
|
511 |
+
# given file (report RP0402 must not be disabled).
|
512 |
+
import-graph=
|
513 |
+
|
514 |
+
# Create a graph of internal dependencies in the given file (report RP0402 must
|
515 |
+
# not be disabled).
|
516 |
+
int-import-graph=
|
517 |
+
|
518 |
+
# Force import order to recognize a module as part of the standard
|
519 |
+
# compatibility libraries.
|
520 |
+
known-standard-library=
|
521 |
+
|
522 |
+
# Force import order to recognize a module as part of a third party library.
|
523 |
+
known-third-party=enchant
|
524 |
+
|
525 |
+
|
526 |
+
[CLASSES]
|
527 |
+
|
528 |
+
# List of method names used to declare (i.e. assign) instance attributes.
|
529 |
+
defining-attr-methods=__init__,
|
530 |
+
__new__,
|
531 |
+
setUp
|
532 |
+
|
533 |
+
# List of member names, which should be excluded from the protected access
|
534 |
+
# warning.
|
535 |
+
exclude-protected=_asdict,
|
536 |
+
_fields,
|
537 |
+
_replace,
|
538 |
+
_source,
|
539 |
+
_make
|
540 |
+
|
541 |
+
# List of valid names for the first argument in a class method.
|
542 |
+
valid-classmethod-first-arg=cls
|
543 |
+
|
544 |
+
# List of valid names for the first argument in a metaclass class method.
|
545 |
+
valid-metaclass-classmethod-first-arg=cls
|
546 |
+
|
547 |
+
|
548 |
+
[DESIGN]
|
549 |
+
|
550 |
+
# Maximum number of arguments for function / method.
|
551 |
+
max-args=5
|
552 |
+
|
553 |
+
# Maximum number of attributes for a class (see R0902).
|
554 |
+
max-attributes=7
|
555 |
+
|
556 |
+
# Maximum number of boolean expressions in an if statement.
|
557 |
+
max-bool-expr=5
|
558 |
+
|
559 |
+
# Maximum number of branch for function / method body.
|
560 |
+
max-branches=12
|
561 |
+
|
562 |
+
# Maximum number of locals for function / method body.
|
563 |
+
max-locals=15
|
564 |
+
|
565 |
+
# Maximum number of parents for a class (see R0901).
|
566 |
+
max-parents=7
|
567 |
+
|
568 |
+
# Maximum number of public methods for a class (see R0904).
|
569 |
+
max-public-methods=20
|
570 |
+
|
571 |
+
# Maximum number of return / yield for function / method body.
|
572 |
+
max-returns=6
|
573 |
+
|
574 |
+
# Maximum number of statements in function / method body.
|
575 |
+
max-statements=50
|
576 |
+
|
577 |
+
# Minimum number of public methods for a class (see R0903).
|
578 |
+
min-public-methods=2
|
579 |
+
|
580 |
+
|
581 |
+
[EXCEPTIONS]
|
582 |
+
|
583 |
+
# Exceptions that will emit a warning when being caught. Defaults to
|
584 |
+
# "BaseException, Exception".
|
585 |
+
overgeneral-exceptions=BaseException,
|
586 |
+
Exception
|
CODE_OF_CONDUCT.md
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ethical Notice
|
2 |
+
|
3 |
+
Please consider possible consequences and be mindful of any adversarial use cases of this project. In this regard, please contact us if you have any concerns.
|
4 |
+
|
5 |
+
# Community Participation Guidelines
|
6 |
+
|
7 |
+
This repository is governed by Mozilla's code of conduct and etiquette guidelines.
|
8 |
+
For more details, please read the
|
9 |
+
[Mozilla Community Participation Guidelines](https://www.mozilla.org/about/governance/policies/participation/).
|
10 |
+
|
11 |
+
## How to Report
|
12 |
+
For more information on how to report violations of the Community Participation Guidelines, please read our '[How to Report](https://www.mozilla.org/about/governance/policies/participation/reporting/)' page.
|
13 |
+
|
14 |
+
<!--
|
15 |
+
## Project Specific Etiquette
|
16 |
+
|
17 |
+
In some cases, there will be additional project etiquette i.e.: (https://bugzilla.mozilla.org/page.cgi?id=etiquette.html).
|
18 |
+
Please update for your project.
|
19 |
+
-->
|
CODE_OWNERS.rst
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
TTS code owners / governance system
|
2 |
+
==========================================
|
3 |
+
|
4 |
+
TTS is run under a governance system inspired (and partially copied from) by the `Mozilla module ownership system <https://www.mozilla.org/about/governance/policies/module-ownership/>`_. The project is roughly divided into modules, and each module has its owners, which are responsible for reviewing pull requests and deciding on technical direction for their modules. Module ownership authority is given to people who have worked extensively on areas of the project.
|
5 |
+
|
6 |
+
Module owners also have the authority of naming other module owners or appointing module peers, which are people with authority to review pull requests in that module. They can also sub-divide their module into sub-modules with their owners.
|
7 |
+
|
8 |
+
Module owners are not tyrants. They are chartered to make decisions with input from the community and in the best interest of the community. Module owners are not required to make code changes or additions solely because the community wants them to do so. (Like anyone else, the module owners may write code because they want to, because their employers want them to, because the community wants them to, or for some other reason.) Module owners do need to pay attention to patches submitted to that module. However “pay attention” does not mean agreeing to every patch. Some patches may not make sense for the WebThings project; some may be poorly implemented. Module owners have the authority to decline a patch; this is a necessary part of the role. We ask the module owners to describe in the relevant issue their reasons for wanting changes to a patch, for declining it altogether, or for postponing review for some period. We don’t ask or expect them to rewrite patches to make them acceptable. Similarly, module owners may need to delay review of a promising patch due to an upcoming deadline. For example, a patch may be of interest, but not for the next milestone. In such a case it may make sense for the module owner to postpone review of a patch until after matters needed for a milestone have been finalized. Again, we expect this to be described in the relevant issue. And of course, it shouldn’t go on very often or for very long or escalation and review is likely.
|
9 |
+
|
10 |
+
The work of the various module owners and peers is overseen by the global owners, which are responsible for making final decisions in case there's conflict between owners as well as set the direction for the project as a whole.
|
11 |
+
|
12 |
+
This file describes module owners who are active on the project and which parts of the code they have expertise on (and interest in). If you're making changes to the code and are wondering who's an appropriate person to talk to, this list will tell you who to ping.
|
13 |
+
|
14 |
+
There's overlap in the areas of expertise of each owner, and in particular when looking at which files are covered by each area, there is a lot of overlap. Don't worry about getting it exactly right when requesting review, any code owner will be happy to redirect the request to a more appropriate person.
|
15 |
+
|
16 |
+
Global owners
|
17 |
+
----------------
|
18 |
+
|
19 |
+
These are people who have worked on the project extensively and are familiar with all or most parts of it. Their expertise and review guidance is trusted by other code owners to cover their own areas of expertise. In case of conflicting opinions from other owners, global owners will make a final decision.
|
20 |
+
|
21 |
+
- Eren Gölge (@erogol)
|
22 |
+
- Reuben Morais (@reuben)
|
23 |
+
|
24 |
+
Training, feeding
|
25 |
+
-----------------
|
26 |
+
|
27 |
+
- Eren Gölge (@erogol)
|
28 |
+
|
29 |
+
Model exporting
|
30 |
+
---------------
|
31 |
+
|
32 |
+
- Eren Gölge (@erogol)
|
33 |
+
|
34 |
+
Multi-Speaker TTS
|
35 |
+
-----------------
|
36 |
+
|
37 |
+
- Eren Gölge (@erogol)
|
38 |
+
- Edresson Casanova (@edresson)
|
39 |
+
|
40 |
+
TTS
|
41 |
+
---
|
42 |
+
|
43 |
+
- Eren Gölge (@erogol)
|
44 |
+
|
45 |
+
Vocoders
|
46 |
+
--------
|
47 |
+
|
48 |
+
- Eren Gölge (@erogol)
|
49 |
+
|
50 |
+
Speaker Encoder
|
51 |
+
---------------
|
52 |
+
|
53 |
+
- Eren Gölge (@erogol)
|
54 |
+
|
55 |
+
Testing & CI
|
56 |
+
------------
|
57 |
+
|
58 |
+
- Eren Gölge (@erogol)
|
59 |
+
- Reuben Morais (@reuben)
|
60 |
+
|
61 |
+
Python bindings
|
62 |
+
---------------
|
63 |
+
|
64 |
+
- Eren Gölge (@erogol)
|
65 |
+
- Reuben Morais (@reuben)
|
66 |
+
|
67 |
+
Documentation
|
68 |
+
-------------
|
69 |
+
|
70 |
+
- Eren Gölge (@erogol)
|
71 |
+
|
72 |
+
Third party bindings
|
73 |
+
--------------------
|
74 |
+
|
75 |
+
Owned by the author.
|
CONTRIBUTING.md
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Contribution guidelines
|
2 |
+
|
3 |
+
This repository is governed by Mozilla's code of conduct and etiquette guidelines. For more details, please read the [Mozilla Community Participation Guidelines](https://www.mozilla.org/about/governance/policies/participation/).
|
4 |
+
|
5 |
+
Before making a Pull Request, check your changes for basic mistakes and style problems by using a linter. We have cardboardlinter setup in this repository, so for example, if you've made some changes and would like to run the linter on just the differences between your work and master, you can use the follow command:
|
6 |
+
|
7 |
+
```bash
|
8 |
+
pip install pylint cardboardlint
|
9 |
+
cardboardlinter --refspec master
|
10 |
+
```
|
11 |
+
|
12 |
+
This will compare the code against master and run the linter on all the changes. To run it automatically as a git pre-commit hook, you can do do the following:
|
13 |
+
|
14 |
+
```bash
|
15 |
+
cat <<\EOF > .git/hooks/pre-commit
|
16 |
+
#!/bin/bash
|
17 |
+
if [ ! -x "$(command -v cardboardlinter)" ]; then
|
18 |
+
exit 0
|
19 |
+
fi
|
20 |
+
|
21 |
+
# First, stash index and work dir, keeping only the
|
22 |
+
# to-be-committed changes in the working directory.
|
23 |
+
echo "Stashing working tree changes..." 1>&2
|
24 |
+
old_stash=$(git rev-parse -q --verify refs/stash)
|
25 |
+
git stash save -q --keep-index
|
26 |
+
new_stash=$(git rev-parse -q --verify refs/stash)
|
27 |
+
|
28 |
+
# If there were no changes (e.g., `--amend` or `--allow-empty`)
|
29 |
+
# then nothing was stashed, and we should skip everything,
|
30 |
+
# including the tests themselves. (Presumably the tests passed
|
31 |
+
# on the previous commit, so there is no need to re-run them.)
|
32 |
+
if [ "$old_stash" = "$new_stash" ]; then
|
33 |
+
echo "No changes, skipping lint." 1>&2
|
34 |
+
exit 0
|
35 |
+
fi
|
36 |
+
|
37 |
+
# Run tests
|
38 |
+
cardboardlinter --refspec HEAD -n auto
|
39 |
+
status=$?
|
40 |
+
|
41 |
+
# Restore changes
|
42 |
+
echo "Restoring working tree changes..." 1>&2
|
43 |
+
git reset --hard -q && git stash apply --index -q && git stash drop -q
|
44 |
+
|
45 |
+
# Exit with status from test-run: nonzero prevents commit
|
46 |
+
exit $status
|
47 |
+
EOF
|
48 |
+
chmod +x .git/hooks/pre-commit
|
49 |
+
```
|
50 |
+
|
51 |
+
This will run the linters on just the changes made in your commit.
|
LICENSE.txt
ADDED
@@ -0,0 +1,373 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Mozilla Public License Version 2.0
|
2 |
+
==================================
|
3 |
+
|
4 |
+
1. Definitions
|
5 |
+
--------------
|
6 |
+
|
7 |
+
1.1. "Contributor"
|
8 |
+
means each individual or legal entity that creates, contributes to
|
9 |
+
the creation of, or owns Covered Software.
|
10 |
+
|
11 |
+
1.2. "Contributor Version"
|
12 |
+
means the combination of the Contributions of others (if any) used
|
13 |
+
by a Contributor and that particular Contributor's Contribution.
|
14 |
+
|
15 |
+
1.3. "Contribution"
|
16 |
+
means Covered Software of a particular Contributor.
|
17 |
+
|
18 |
+
1.4. "Covered Software"
|
19 |
+
means Source Code Form to which the initial Contributor has attached
|
20 |
+
the notice in Exhibit A, the Executable Form of such Source Code
|
21 |
+
Form, and Modifications of such Source Code Form, in each case
|
22 |
+
including portions thereof.
|
23 |
+
|
24 |
+
1.5. "Incompatible With Secondary Licenses"
|
25 |
+
means
|
26 |
+
|
27 |
+
(a) that the initial Contributor has attached the notice described
|
28 |
+
in Exhibit B to the Covered Software; or
|
29 |
+
|
30 |
+
(b) that the Covered Software was made available under the terms of
|
31 |
+
version 1.1 or earlier of the License, but not also under the
|
32 |
+
terms of a Secondary License.
|
33 |
+
|
34 |
+
1.6. "Executable Form"
|
35 |
+
means any form of the work other than Source Code Form.
|
36 |
+
|
37 |
+
1.7. "Larger Work"
|
38 |
+
means a work that combines Covered Software with other material, in
|
39 |
+
a separate file or files, that is not Covered Software.
|
40 |
+
|
41 |
+
1.8. "License"
|
42 |
+
means this document.
|
43 |
+
|
44 |
+
1.9. "Licensable"
|
45 |
+
means having the right to grant, to the maximum extent possible,
|
46 |
+
whether at the time of the initial grant or subsequently, any and
|
47 |
+
all of the rights conveyed by this License.
|
48 |
+
|
49 |
+
1.10. "Modifications"
|
50 |
+
means any of the following:
|
51 |
+
|
52 |
+
(a) any file in Source Code Form that results from an addition to,
|
53 |
+
deletion from, or modification of the contents of Covered
|
54 |
+
Software; or
|
55 |
+
|
56 |
+
(b) any new file in Source Code Form that contains any Covered
|
57 |
+
Software.
|
58 |
+
|
59 |
+
1.11. "Patent Claims" of a Contributor
|
60 |
+
means any patent claim(s), including without limitation, method,
|
61 |
+
process, and apparatus claims, in any patent Licensable by such
|
62 |
+
Contributor that would be infringed, but for the grant of the
|
63 |
+
License, by the making, using, selling, offering for sale, having
|
64 |
+
made, import, or transfer of either its Contributions or its
|
65 |
+
Contributor Version.
|
66 |
+
|
67 |
+
1.12. "Secondary License"
|
68 |
+
means either the GNU General Public License, Version 2.0, the GNU
|
69 |
+
Lesser General Public License, Version 2.1, the GNU Affero General
|
70 |
+
Public License, Version 3.0, or any later versions of those
|
71 |
+
licenses.
|
72 |
+
|
73 |
+
1.13. "Source Code Form"
|
74 |
+
means the form of the work preferred for making modifications.
|
75 |
+
|
76 |
+
1.14. "You" (or "Your")
|
77 |
+
means an individual or a legal entity exercising rights under this
|
78 |
+
License. For legal entities, "You" includes any entity that
|
79 |
+
controls, is controlled by, or is under common control with You. For
|
80 |
+
purposes of this definition, "control" means (a) the power, direct
|
81 |
+
or indirect, to cause the direction or management of such entity,
|
82 |
+
whether by contract or otherwise, or (b) ownership of more than
|
83 |
+
fifty percent (50%) of the outstanding shares or beneficial
|
84 |
+
ownership of such entity.
|
85 |
+
|
86 |
+
2. License Grants and Conditions
|
87 |
+
--------------------------------
|
88 |
+
|
89 |
+
2.1. Grants
|
90 |
+
|
91 |
+
Each Contributor hereby grants You a world-wide, royalty-free,
|
92 |
+
non-exclusive license:
|
93 |
+
|
94 |
+
(a) under intellectual property rights (other than patent or trademark)
|
95 |
+
Licensable by such Contributor to use, reproduce, make available,
|
96 |
+
modify, display, perform, distribute, and otherwise exploit its
|
97 |
+
Contributions, either on an unmodified basis, with Modifications, or
|
98 |
+
as part of a Larger Work; and
|
99 |
+
|
100 |
+
(b) under Patent Claims of such Contributor to make, use, sell, offer
|
101 |
+
for sale, have made, import, and otherwise transfer either its
|
102 |
+
Contributions or its Contributor Version.
|
103 |
+
|
104 |
+
2.2. Effective Date
|
105 |
+
|
106 |
+
The licenses granted in Section 2.1 with respect to any Contribution
|
107 |
+
become effective for each Contribution on the date the Contributor first
|
108 |
+
distributes such Contribution.
|
109 |
+
|
110 |
+
2.3. Limitations on Grant Scope
|
111 |
+
|
112 |
+
The licenses granted in this Section 2 are the only rights granted under
|
113 |
+
this License. No additional rights or licenses will be implied from the
|
114 |
+
distribution or licensing of Covered Software under this License.
|
115 |
+
Notwithstanding Section 2.1(b) above, no patent license is granted by a
|
116 |
+
Contributor:
|
117 |
+
|
118 |
+
(a) for any code that a Contributor has removed from Covered Software;
|
119 |
+
or
|
120 |
+
|
121 |
+
(b) for infringements caused by: (i) Your and any other third party's
|
122 |
+
modifications of Covered Software, or (ii) the combination of its
|
123 |
+
Contributions with other software (except as part of its Contributor
|
124 |
+
Version); or
|
125 |
+
|
126 |
+
(c) under Patent Claims infringed by Covered Software in the absence of
|
127 |
+
its Contributions.
|
128 |
+
|
129 |
+
This License does not grant any rights in the trademarks, service marks,
|
130 |
+
or logos of any Contributor (except as may be necessary to comply with
|
131 |
+
the notice requirements in Section 3.4).
|
132 |
+
|
133 |
+
2.4. Subsequent Licenses
|
134 |
+
|
135 |
+
No Contributor makes additional grants as a result of Your choice to
|
136 |
+
distribute the Covered Software under a subsequent version of this
|
137 |
+
License (see Section 10.2) or under the terms of a Secondary License (if
|
138 |
+
permitted under the terms of Section 3.3).
|
139 |
+
|
140 |
+
2.5. Representation
|
141 |
+
|
142 |
+
Each Contributor represents that the Contributor believes its
|
143 |
+
Contributions are its original creation(s) or it has sufficient rights
|
144 |
+
to grant the rights to its Contributions conveyed by this License.
|
145 |
+
|
146 |
+
2.6. Fair Use
|
147 |
+
|
148 |
+
This License is not intended to limit any rights You have under
|
149 |
+
applicable copyright doctrines of fair use, fair dealing, or other
|
150 |
+
equivalents.
|
151 |
+
|
152 |
+
2.7. Conditions
|
153 |
+
|
154 |
+
Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted
|
155 |
+
in Section 2.1.
|
156 |
+
|
157 |
+
3. Responsibilities
|
158 |
+
-------------------
|
159 |
+
|
160 |
+
3.1. Distribution of Source Form
|
161 |
+
|
162 |
+
All distribution of Covered Software in Source Code Form, including any
|
163 |
+
Modifications that You create or to which You contribute, must be under
|
164 |
+
the terms of this License. You must inform recipients that the Source
|
165 |
+
Code Form of the Covered Software is governed by the terms of this
|
166 |
+
License, and how they can obtain a copy of this License. You may not
|
167 |
+
attempt to alter or restrict the recipients' rights in the Source Code
|
168 |
+
Form.
|
169 |
+
|
170 |
+
3.2. Distribution of Executable Form
|
171 |
+
|
172 |
+
If You distribute Covered Software in Executable Form then:
|
173 |
+
|
174 |
+
(a) such Covered Software must also be made available in Source Code
|
175 |
+
Form, as described in Section 3.1, and You must inform recipients of
|
176 |
+
the Executable Form how they can obtain a copy of such Source Code
|
177 |
+
Form by reasonable means in a timely manner, at a charge no more
|
178 |
+
than the cost of distribution to the recipient; and
|
179 |
+
|
180 |
+
(b) You may distribute such Executable Form under the terms of this
|
181 |
+
License, or sublicense it under different terms, provided that the
|
182 |
+
license for the Executable Form does not attempt to limit or alter
|
183 |
+
the recipients' rights in the Source Code Form under this License.
|
184 |
+
|
185 |
+
3.3. Distribution of a Larger Work
|
186 |
+
|
187 |
+
You may create and distribute a Larger Work under terms of Your choice,
|
188 |
+
provided that You also comply with the requirements of this License for
|
189 |
+
the Covered Software. If the Larger Work is a combination of Covered
|
190 |
+
Software with a work governed by one or more Secondary Licenses, and the
|
191 |
+
Covered Software is not Incompatible With Secondary Licenses, this
|
192 |
+
License permits You to additionally distribute such Covered Software
|
193 |
+
under the terms of such Secondary License(s), so that the recipient of
|
194 |
+
the Larger Work may, at their option, further distribute the Covered
|
195 |
+
Software under the terms of either this License or such Secondary
|
196 |
+
License(s).
|
197 |
+
|
198 |
+
3.4. Notices
|
199 |
+
|
200 |
+
You may not remove or alter the substance of any license notices
|
201 |
+
(including copyright notices, patent notices, disclaimers of warranty,
|
202 |
+
or limitations of liability) contained within the Source Code Form of
|
203 |
+
the Covered Software, except that You may alter any license notices to
|
204 |
+
the extent required to remedy known factual inaccuracies.
|
205 |
+
|
206 |
+
3.5. Application of Additional Terms
|
207 |
+
|
208 |
+
You may choose to offer, and to charge a fee for, warranty, support,
|
209 |
+
indemnity or liability obligations to one or more recipients of Covered
|
210 |
+
Software. However, You may do so only on Your own behalf, and not on
|
211 |
+
behalf of any Contributor. You must make it absolutely clear that any
|
212 |
+
such warranty, support, indemnity, or liability obligation is offered by
|
213 |
+
You alone, and You hereby agree to indemnify every Contributor for any
|
214 |
+
liability incurred by such Contributor as a result of warranty, support,
|
215 |
+
indemnity or liability terms You offer. You may include additional
|
216 |
+
disclaimers of warranty and limitations of liability specific to any
|
217 |
+
jurisdiction.
|
218 |
+
|
219 |
+
4. Inability to Comply Due to Statute or Regulation
|
220 |
+
---------------------------------------------------
|
221 |
+
|
222 |
+
If it is impossible for You to comply with any of the terms of this
|
223 |
+
License with respect to some or all of the Covered Software due to
|
224 |
+
statute, judicial order, or regulation then You must: (a) comply with
|
225 |
+
the terms of this License to the maximum extent possible; and (b)
|
226 |
+
describe the limitations and the code they affect. Such description must
|
227 |
+
be placed in a text file included with all distributions of the Covered
|
228 |
+
Software under this License. Except to the extent prohibited by statute
|
229 |
+
or regulation, such description must be sufficiently detailed for a
|
230 |
+
recipient of ordinary skill to be able to understand it.
|
231 |
+
|
232 |
+
5. Termination
|
233 |
+
--------------
|
234 |
+
|
235 |
+
5.1. The rights granted under this License will terminate automatically
|
236 |
+
if You fail to comply with any of its terms. However, if You become
|
237 |
+
compliant, then the rights granted under this License from a particular
|
238 |
+
Contributor are reinstated (a) provisionally, unless and until such
|
239 |
+
Contributor explicitly and finally terminates Your grants, and (b) on an
|
240 |
+
ongoing basis, if such Contributor fails to notify You of the
|
241 |
+
non-compliance by some reasonable means prior to 60 days after You have
|
242 |
+
come back into compliance. Moreover, Your grants from a particular
|
243 |
+
Contributor are reinstated on an ongoing basis if such Contributor
|
244 |
+
notifies You of the non-compliance by some reasonable means, this is the
|
245 |
+
first time You have received notice of non-compliance with this License
|
246 |
+
from such Contributor, and You become compliant prior to 30 days after
|
247 |
+
Your receipt of the notice.
|
248 |
+
|
249 |
+
5.2. If You initiate litigation against any entity by asserting a patent
|
250 |
+
infringement claim (excluding declaratory judgment actions,
|
251 |
+
counter-claims, and cross-claims) alleging that a Contributor Version
|
252 |
+
directly or indirectly infringes any patent, then the rights granted to
|
253 |
+
You by any and all Contributors for the Covered Software under Section
|
254 |
+
2.1 of this License shall terminate.
|
255 |
+
|
256 |
+
5.3. In the event of termination under Sections 5.1 or 5.2 above, all
|
257 |
+
end user license agreements (excluding distributors and resellers) which
|
258 |
+
have been validly granted by You or Your distributors under this License
|
259 |
+
prior to termination shall survive termination.
|
260 |
+
|
261 |
+
************************************************************************
|
262 |
+
* *
|
263 |
+
* 6. Disclaimer of Warranty *
|
264 |
+
* ------------------------- *
|
265 |
+
* *
|
266 |
+
* Covered Software is provided under this License on an "as is" *
|
267 |
+
* basis, without warranty of any kind, either expressed, implied, or *
|
268 |
+
* statutory, including, without limitation, warranties that the *
|
269 |
+
* Covered Software is free of defects, merchantable, fit for a *
|
270 |
+
* particular purpose or non-infringing. The entire risk as to the *
|
271 |
+
* quality and performance of the Covered Software is with You. *
|
272 |
+
* Should any Covered Software prove defective in any respect, You *
|
273 |
+
* (not any Contributor) assume the cost of any necessary servicing, *
|
274 |
+
* repair, or correction. This disclaimer of warranty constitutes an *
|
275 |
+
* essential part of this License. No use of any Covered Software is *
|
276 |
+
* authorized under this License except under this disclaimer. *
|
277 |
+
* *
|
278 |
+
************************************************************************
|
279 |
+
|
280 |
+
************************************************************************
|
281 |
+
* *
|
282 |
+
* 7. Limitation of Liability *
|
283 |
+
* -------------------------- *
|
284 |
+
* *
|
285 |
+
* Under no circumstances and under no legal theory, whether tort *
|
286 |
+
* (including negligence), contract, or otherwise, shall any *
|
287 |
+
* Contributor, or anyone who distributes Covered Software as *
|
288 |
+
* permitted above, be liable to You for any direct, indirect, *
|
289 |
+
* special, incidental, or consequential damages of any character *
|
290 |
+
* including, without limitation, damages for lost profits, loss of *
|
291 |
+
* goodwill, work stoppage, computer failure or malfunction, or any *
|
292 |
+
* and all other commercial damages or losses, even if such party *
|
293 |
+
* shall have been informed of the possibility of such damages. This *
|
294 |
+
* limitation of liability shall not apply to liability for death or *
|
295 |
+
* personal injury resulting from such party's negligence to the *
|
296 |
+
* extent applicable law prohibits such limitation. Some *
|
297 |
+
* jurisdictions do not allow the exclusion or limitation of *
|
298 |
+
* incidental or consequential damages, so this exclusion and *
|
299 |
+
* limitation may not apply to You. *
|
300 |
+
* *
|
301 |
+
************************************************************************
|
302 |
+
|
303 |
+
8. Litigation
|
304 |
+
-------------
|
305 |
+
|
306 |
+
Any litigation relating to this License may be brought only in the
|
307 |
+
courts of a jurisdiction where the defendant maintains its principal
|
308 |
+
place of business and such litigation shall be governed by laws of that
|
309 |
+
jurisdiction, without reference to its conflict-of-law provisions.
|
310 |
+
Nothing in this Section shall prevent a party's ability to bring
|
311 |
+
cross-claims or counter-claims.
|
312 |
+
|
313 |
+
9. Miscellaneous
|
314 |
+
----------------
|
315 |
+
|
316 |
+
This License represents the complete agreement concerning the subject
|
317 |
+
matter hereof. If any provision of this License is held to be
|
318 |
+
unenforceable, such provision shall be reformed only to the extent
|
319 |
+
necessary to make it enforceable. Any law or regulation which provides
|
320 |
+
that the language of a contract shall be construed against the drafter
|
321 |
+
shall not be used to construe this License against a Contributor.
|
322 |
+
|
323 |
+
10. Versions of the License
|
324 |
+
---------------------------
|
325 |
+
|
326 |
+
10.1. New Versions
|
327 |
+
|
328 |
+
Mozilla Foundation is the license steward. Except as provided in Section
|
329 |
+
10.3, no one other than the license steward has the right to modify or
|
330 |
+
publish new versions of this License. Each version will be given a
|
331 |
+
distinguishing version number.
|
332 |
+
|
333 |
+
10.2. Effect of New Versions
|
334 |
+
|
335 |
+
You may distribute the Covered Software under the terms of the version
|
336 |
+
of the License under which You originally received the Covered Software,
|
337 |
+
or under the terms of any subsequent version published by the license
|
338 |
+
steward.
|
339 |
+
|
340 |
+
10.3. Modified Versions
|
341 |
+
|
342 |
+
If you create software not governed by this License, and you want to
|
343 |
+
create a new license for such software, you may create and use a
|
344 |
+
modified version of this License if you rename the license and remove
|
345 |
+
any references to the name of the license steward (except to note that
|
346 |
+
such modified license differs from this License).
|
347 |
+
|
348 |
+
10.4. Distributing Source Code Form that is Incompatible With Secondary
|
349 |
+
Licenses
|
350 |
+
|
351 |
+
If You choose to distribute Source Code Form that is Incompatible With
|
352 |
+
Secondary Licenses under the terms of this version of the License, the
|
353 |
+
notice described in Exhibit B of this License must be attached.
|
354 |
+
|
355 |
+
Exhibit A - Source Code Form License Notice
|
356 |
+
-------------------------------------------
|
357 |
+
|
358 |
+
This Source Code Form is subject to the terms of the Mozilla Public
|
359 |
+
License, v. 2.0. If a copy of the MPL was not distributed with this
|
360 |
+
file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
361 |
+
|
362 |
+
If it is not possible or desirable to put the notice in a particular
|
363 |
+
file, then You may include the notice in a location (such as a LICENSE
|
364 |
+
file in a relevant directory) where a recipient would be likely to look
|
365 |
+
for such a notice.
|
366 |
+
|
367 |
+
You may add additional accurate notices of copyright ownership.
|
368 |
+
|
369 |
+
Exhibit B - "Incompatible With Secondary Licenses" Notice
|
370 |
+
---------------------------------------------------------
|
371 |
+
|
372 |
+
This Source Code Form is "Incompatible With Secondary Licenses", as
|
373 |
+
defined by the Mozilla Public License, v. 2.0.
|
MANIFEST.in
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
include README.md
|
2 |
+
include LICENSE.txt
|
3 |
+
include requirements.txt
|
4 |
+
recursive-include TTS *.json
|
5 |
+
recursive-include TTS *.html
|
6 |
+
recursive-include TTS *.png
|
7 |
+
recursive-include TTS *.md
|
8 |
+
recursive-include TTS *.py
|
9 |
+
recursive-include TTS *.pyx
|
10 |
+
recursive-include images *.png
|
11 |
+
|
README.md
CHANGED
@@ -1,3 +1,281 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<img src="https://user-images.githubusercontent.com/1402048/104139991-3fd15e00-53af-11eb-8640-3a78a64641dd.png" data-canonical-src="![TTS banner](https://user-images.githubusercontent.com/1402048/104139991-3fd15e00-53af-11eb-8640-3a78a64641dd.png =250x250)
|
2 |
+
" width="256" height="256" align="right" />
|
3 |
+
|
4 |
+
# TTS: Text-to-Speech for all.
|
5 |
+
|
6 |
+
TTS is a library for advanced Text-to-Speech generation. It's built on the latest research, was designed to achieve the best trade-off among ease-of-training, speed and quality.
|
7 |
+
TTS comes with [pretrained models](https://github.com/mozilla/TTS/wiki/Released-Models), tools for measuring dataset quality and already used in **20+ languages** for products and research projects.
|
8 |
+
|
9 |
+
[![CircleCI](<https://circleci.com/gh/mozilla/TTS/tree/dev.svg?style=svg>)]()
|
10 |
+
[![License](<https://img.shields.io/badge/License-MPL%202.0-brightgreen.svg>)](https://opensource.org/licenses/MPL-2.0)
|
11 |
+
[![PyPI version](https://badge.fury.io/py/TTS.svg)](https://badge.fury.io/py/TTS)
|
12 |
+
|
13 |
+
:loudspeaker: [English Voice Samples](https://erogol.github.io/ddc-samples/) and [SoundCloud playlist](https://soundcloud.com/user-565970875/pocket-article-wavernn-and-tacotron2)
|
14 |
+
|
15 |
+
:man_cook: [TTS training recipes](https://github.com/erogol/TTS_recipes)
|
16 |
+
|
17 |
+
:page_facing_up: [Text-to-Speech paper collection](https://github.com/erogol/TTS-papers)
|
18 |
+
|
19 |
+
## 💬 Where to ask questions
|
20 |
+
Please use our dedicated channels for questions and discussion. Help is much more valuable if it's shared publicly, so that more people can benefit from it.
|
21 |
+
|
22 |
+
| Type | Platforms |
|
23 |
+
| ------------------------------- | --------------------------------------- |
|
24 |
+
| 🚨 **Bug Reports** | [GitHub Issue Tracker] |
|
25 |
+
| ❔ **FAQ** | [TTS/Wiki](https://github.com/mozilla/TTS/wiki/FAQ) |
|
26 |
+
| 🎁 **Feature Requests & Ideas** | [GitHub Issue Tracker] |
|
27 |
+
| 👩💻 **Usage Questions** | [Discourse Forum] |
|
28 |
+
| 🗯 **General Discussion** | [Discourse Forum] and [Matrix Channel] |
|
29 |
+
|
30 |
+
[github issue tracker]: https://github.com/mozilla/tts/issues
|
31 |
+
[discourse forum]: https://discourse.mozilla.org/c/tts/
|
32 |
+
[matrix channel]: https://matrix.to/#/!KTePhNahjgiVumkqca:matrix.org?via=matrix.org
|
33 |
+
[Tutorials and Examples]: https://github.com/mozilla/TTS/wiki/TTS-Notebooks-and-Tutorials
|
34 |
+
|
35 |
+
|
36 |
+
## 🔗 Links and Resources
|
37 |
+
| Type | Links |
|
38 |
+
| ------------------------------- | --------------------------------------- |
|
39 |
+
| 💾 **Installation** | [TTS/README.md](https://github.com/mozilla/TTS/tree/dev#install-tts)|
|
40 |
+
| 👩🏾🏫 **Tutorials and Examples** | [TTS/Wiki](https://github.com/mozilla/TTS/wiki/TTS-Notebooks-and-Tutorials) |
|
41 |
+
| 🚀 **Released Models** | [TTS/Wiki](https://github.com/mozilla/TTS/wiki/Released-Models)|
|
42 |
+
| 💻 **Docker Image** | [Repository by @synesthesiam](https://github.com/synesthesiam/docker-mozillatts)|
|
43 |
+
| 🖥️ **Demo Server** | [TTS/server](https://github.com/mozilla/TTS/tree/master/TTS/server)|
|
44 |
+
| 🤖 **Running TTS on Terminal** | [TTS/README.md](https://github.com/mozilla/TTS#example-synthesizing-speech-on-terminal-using-the-released-models)|
|
45 |
+
| ✨ **How to contribute** |[TTS/README.md](#contribution-guidelines)|
|
46 |
+
|
47 |
+
## 🥇 TTS Performance
|
48 |
+
<p align="center"><img src="https://discourse-prod-uploads-81679984178418.s3.dualstack.us-west-2.amazonaws.com/optimized/3X/6/4/6428f980e9ec751c248e591460895f7881aec0c6_2_1035x591.png" width="800" /></p>
|
49 |
+
|
50 |
+
"Mozilla*" and "Judy*" are our models.
|
51 |
+
[Details...](https://github.com/mozilla/TTS/wiki/Mean-Opinion-Score-Results)
|
52 |
+
|
53 |
+
## Features
|
54 |
+
- High performance Deep Learning models for Text2Speech tasks.
|
55 |
+
- Text2Spec models (Tacotron, Tacotron2, Glow-TTS, SpeedySpeech).
|
56 |
+
- Speaker Encoder to compute speaker embeddings efficiently.
|
57 |
+
- Vocoder models (MelGAN, Multiband-MelGAN, GAN-TTS, ParallelWaveGAN, WaveGrad, WaveRNN)
|
58 |
+
- Fast and efficient model training.
|
59 |
+
- Detailed training logs on console and Tensorboard.
|
60 |
+
- Support for multi-speaker TTS.
|
61 |
+
- Efficient Multi-GPUs training.
|
62 |
+
- Ability to convert PyTorch models to Tensorflow 2.0 and TFLite for inference.
|
63 |
+
- Released models in PyTorch, Tensorflow and TFLite.
|
64 |
+
- Tools to curate Text2Speech datasets under```dataset_analysis```.
|
65 |
+
- Demo server for model testing.
|
66 |
+
- Notebooks for extensive model benchmarking.
|
67 |
+
- Modular (but not too much) code base enabling easy testing for new ideas.
|
68 |
+
|
69 |
+
## Implemented Models
|
70 |
+
### Text-to-Spectrogram
|
71 |
+
- Tacotron: [paper](https://arxiv.org/abs/1703.10135)
|
72 |
+
- Tacotron2: [paper](https://arxiv.org/abs/1712.05884)
|
73 |
+
- Glow-TTS: [paper](https://arxiv.org/abs/2005.11129)
|
74 |
+
- Speedy-Speech: [paper](https://arxiv.org/abs/2008.03802)
|
75 |
+
|
76 |
+
### Attention Methods
|
77 |
+
- Guided Attention: [paper](https://arxiv.org/abs/1710.08969)
|
78 |
+
- Forward Backward Decoding: [paper](https://arxiv.org/abs/1907.09006)
|
79 |
+
- Graves Attention: [paper](https://arxiv.org/abs/1907.09006)
|
80 |
+
- Double Decoder Consistency: [blog](https://erogol.com/solving-attention-problems-of-tts-models-with-double-decoder-consistency/)
|
81 |
+
|
82 |
+
### Speaker Encoder
|
83 |
+
- GE2E: [paper](https://arxiv.org/abs/1710.10467)
|
84 |
+
- Angular Loss: [paper](https://arxiv.org/pdf/2003.11982.pdf)
|
85 |
+
|
86 |
+
### Vocoders
|
87 |
+
- MelGAN: [paper](https://arxiv.org/abs/1910.06711)
|
88 |
+
- MultiBandMelGAN: [paper](https://arxiv.org/abs/2005.05106)
|
89 |
+
- ParallelWaveGAN: [paper](https://arxiv.org/abs/1910.11480)
|
90 |
+
- GAN-TTS discriminators: [paper](https://arxiv.org/abs/1909.11646)
|
91 |
+
- WaveRNN: [origin](https://github.com/fatchord/WaveRNN/)
|
92 |
+
- WaveGrad: [paper](https://arxiv.org/abs/2009.00713)
|
93 |
+
|
94 |
+
You can also help us implement more models. Some TTS related work can be found [here](https://github.com/erogol/TTS-papers).
|
95 |
+
|
96 |
+
## Install TTS
|
97 |
+
TTS supports **python >= 3.6, <3.9**.
|
98 |
+
|
99 |
+
If you are only interested in [synthesizing speech](https://github.com/mozilla/TTS/tree/dev#example-synthesizing-speech-on-terminal-using-the-released-models) with the released TTS models, installing from PyPI is the easiest option.
|
100 |
+
|
101 |
+
```bash
|
102 |
+
pip install TTS
|
103 |
+
```
|
104 |
+
|
105 |
+
If you plan to code or train models, clone TTS and install it locally.
|
106 |
+
|
107 |
+
```bash
|
108 |
+
git clone https://github.com/mozilla/TTS
|
109 |
+
pip install -e .
|
110 |
+
```
|
111 |
+
|
112 |
+
## Directory Structure
|
113 |
+
```
|
114 |
+
|- notebooks/ (Jupyter Notebooks for model evaluation, parameter selection and data analysis.)
|
115 |
+
|- utils/ (common utilities.)
|
116 |
+
|- TTS
|
117 |
+
|- bin/ (folder for all the executables.)
|
118 |
+
|- train*.py (train your target model.)
|
119 |
+
|- distribute.py (train your TTS model using Multiple GPUs.)
|
120 |
+
|- compute_statistics.py (compute dataset statistics for normalization.)
|
121 |
+
|- convert*.py (convert target torch model to TF.)
|
122 |
+
|- tts/ (text to speech models)
|
123 |
+
|- layers/ (model layer definitions)
|
124 |
+
|- models/ (model definitions)
|
125 |
+
|- tf/ (Tensorflow 2 utilities and model implementations)
|
126 |
+
|- utils/ (model specific utilities.)
|
127 |
+
|- speaker_encoder/ (Speaker Encoder models.)
|
128 |
+
|- (same)
|
129 |
+
|- vocoder/ (Vocoder models.)
|
130 |
+
|- (same)
|
131 |
+
```
|
132 |
+
|
133 |
+
## Sample Model Output
|
134 |
+
Below you see Tacotron model state after 16K iterations with batch-size 32 with LJSpeech dataset.
|
135 |
+
|
136 |
+
> "Recent research at Harvard has shown meditating for as little as 8 weeks can actually increase the grey matter in the parts of the brain responsible for emotional regulation and learning."
|
137 |
+
|
138 |
+
Audio examples: [soundcloud](https://soundcloud.com/user-565970875/pocket-article-wavernn-and-tacotron2)
|
139 |
+
|
140 |
+
<img src="images/example_model_output.png?raw=true" alt="example_output" width="400"/>
|
141 |
+
|
142 |
+
## Datasets and Data-Loading
|
143 |
+
TTS provides a generic dataloader easy to use for your custom dataset.
|
144 |
+
You just need to write a simple function to format the dataset. Check ```datasets/preprocess.py``` to see some examples.
|
145 |
+
After that, you need to set ```dataset``` fields in ```config.json```.
|
146 |
+
|
147 |
+
Some of the public datasets that we successfully applied TTS:
|
148 |
+
|
149 |
+
- [LJ Speech](https://keithito.com/LJ-Speech-Dataset/)
|
150 |
+
- [Nancy](http://www.cstr.ed.ac.uk/projects/blizzard/2011/lessac_blizzard2011/)
|
151 |
+
- [TWEB](https://www.kaggle.com/bryanpark/the-world-english-bible-speech-dataset)
|
152 |
+
- [M-AI-Labs](http://www.caito.de/2019/01/the-m-ailabs-speech-dataset/)
|
153 |
+
- [LibriTTS](https://openslr.org/60/)
|
154 |
+
- [Spanish](https://drive.google.com/file/d/1Sm_zyBo67XHkiFhcRSQ4YaHPYM0slO_e/view?usp=sharing) - thx! @carlfm01
|
155 |
+
|
156 |
+
## Example: Synthesizing Speech on Terminal Using the Released Models.
|
157 |
+
|
158 |
+
After the installation, TTS provides a CLI interface for synthesizing speech using pre-trained models. You can either use your own model or the release models under the TTS project.
|
159 |
+
|
160 |
+
Listing released TTS models.
|
161 |
+
```bash
|
162 |
+
tts --list_models
|
163 |
+
```
|
164 |
+
|
165 |
+
Run a tts and a vocoder model from the released model list. (Simply copy and paste the full model names from the list as arguments for the command below.)
|
166 |
+
```bash
|
167 |
+
tts --text "Text for TTS" \
|
168 |
+
--model_name "<type>/<language>/<dataset>/<model_name>" \
|
169 |
+
--vocoder_name "<type>/<language>/<dataset>/<model_name>" \
|
170 |
+
--out_path folder/to/save/output/
|
171 |
+
```
|
172 |
+
|
173 |
+
Run your own TTS model (Using Griffin-Lim Vocoder)
|
174 |
+
```bash
|
175 |
+
tts --text "Text for TTS" \
|
176 |
+
--model_path path/to/model.pth.tar \
|
177 |
+
--config_path path/to/config.json \
|
178 |
+
--out_path output/path/speech.wav
|
179 |
+
```
|
180 |
+
|
181 |
+
Run your own TTS and Vocoder models
|
182 |
+
```bash
|
183 |
+
tts --text "Text for TTS" \
|
184 |
+
--model_path path/to/config.json \
|
185 |
+
--config_path path/to/model.pth.tar \
|
186 |
+
--out_path output/path/speech.wav \
|
187 |
+
--vocoder_path path/to/vocoder.pth.tar \
|
188 |
+
--vocoder_config_path path/to/vocoder_config.json
|
189 |
+
```
|
190 |
+
|
191 |
+
**Note:** You can use ```./TTS/bin/synthesize.py``` if you prefer running ```tts``` from the TTS project folder.
|
192 |
+
|
193 |
+
## Example: Training and Fine-tuning LJ-Speech Dataset
|
194 |
+
Here you can find a [CoLab](https://gist.github.com/erogol/97516ad65b44dbddb8cd694953187c5b) notebook for a hands-on example, training LJSpeech. Or you can manually follow the guideline below.
|
195 |
+
|
196 |
+
To start with, split ```metadata.csv``` into train and validation subsets respectively ```metadata_train.csv``` and ```metadata_val.csv```. Note that for text-to-speech, validation performance might be misleading since the loss value does not directly measure the voice quality to the human ear and it also does not measure the attention module performance. Therefore, running the model with new sentences and listening to the results is the best way to go.
|
197 |
+
|
198 |
+
```
|
199 |
+
shuf metadata.csv > metadata_shuf.csv
|
200 |
+
head -n 12000 metadata_shuf.csv > metadata_train.csv
|
201 |
+
tail -n 1100 metadata_shuf.csv > metadata_val.csv
|
202 |
+
```
|
203 |
+
|
204 |
+
To train a new model, you need to define your own ```config.json``` to define model details, trainin configuration and more (check the examples). Then call the corressponding train script.
|
205 |
+
|
206 |
+
For instance, in order to train a tacotron or tacotron2 model on LJSpeech dataset, follow these steps.
|
207 |
+
|
208 |
+
```bash
|
209 |
+
python TTS/bin/train_tacotron.py --config_path TTS/tts/configs/config.json
|
210 |
+
```
|
211 |
+
|
212 |
+
To fine-tune a model, use ```--restore_path```.
|
213 |
+
|
214 |
+
```bash
|
215 |
+
python TTS/bin/train_tacotron.py --config_path TTS/tts/configs/config.json --restore_path /path/to/your/model.pth.tar
|
216 |
+
```
|
217 |
+
|
218 |
+
To continue an old training run, use ```--continue_path```.
|
219 |
+
|
220 |
+
```bash
|
221 |
+
python TTS/bin/train_tacotron.py --continue_path /path/to/your/run_folder/
|
222 |
+
```
|
223 |
+
|
224 |
+
For multi-GPU training, call ```distribute.py```. It runs any provided train script in multi-GPU setting.
|
225 |
+
|
226 |
+
```bash
|
227 |
+
CUDA_VISIBLE_DEVICES="0,1,4" python TTS/bin/distribute.py --script train_tacotron.py --config_path TTS/tts/configs/config.json
|
228 |
+
```
|
229 |
+
|
230 |
+
Each run creates a new output folder accomodating used ```config.json```, model checkpoints and tensorboard logs.
|
231 |
+
|
232 |
+
In case of any error or intercepted execution, if there is no checkpoint yet under the output folder, the whole folder is going to be removed.
|
233 |
+
|
234 |
+
You can also enjoy Tensorboard, if you point Tensorboard argument```--logdir``` to the experiment folder.
|
235 |
+
|
236 |
+
## Contribution Guidelines
|
237 |
+
This repository is governed by Mozilla's code of conduct and etiquette guidelines. For more details, please read the [Mozilla Community Participation Guidelines.](https://www.mozilla.org/about/governance/policies/participation/)
|
238 |
+
|
239 |
+
1. Create a new branch.
|
240 |
+
2. Implement your changes.
|
241 |
+
3. (if applicable) Add [Google Style](https://google.github.io/styleguide/pyguide.html#381-docstrings) docstrings.
|
242 |
+
4. (if applicable) Implement a test case under ```tests``` folder.
|
243 |
+
5. (Optional but Prefered) Run tests.
|
244 |
+
```bash
|
245 |
+
./run_tests.sh
|
246 |
+
```
|
247 |
+
6. Run the linter.
|
248 |
+
```bash
|
249 |
+
pip install pylint cardboardlint
|
250 |
+
cardboardlinter --refspec master
|
251 |
+
```
|
252 |
+
7. Send a PR to ```dev``` branch, explain what the change is about.
|
253 |
+
8. Let us discuss until we make it perfect :).
|
254 |
+
9. We merge it to the ```dev``` branch once things look good.
|
255 |
+
|
256 |
+
Feel free to ping us at any step you need help using our communication channels.
|
257 |
+
|
258 |
+
## Collaborative Experimentation Guide
|
259 |
+
If you like to use TTS to try a new idea and like to share your experiments with the community, we urge you to use the following guideline for a better collaboration.
|
260 |
+
(If you have an idea for better collaboration, let us know)
|
261 |
+
- Create a new branch.
|
262 |
+
- Open an issue pointing your branch.
|
263 |
+
- Explain your idea and experiment.
|
264 |
+
- Share your results regularly. (Tensorboard log files, audio results, visuals etc.)
|
265 |
+
|
266 |
+
## Major TODOs
|
267 |
+
- [x] Implement the model.
|
268 |
+
- [x] Generate human-like speech on LJSpeech dataset.
|
269 |
+
- [x] Generate human-like speech on a different dataset (Nancy) (TWEB).
|
270 |
+
- [x] Train TTS with r=1 successfully.
|
271 |
+
- [x] Enable process based distributed training. Similar to (https://github.com/fastai/imagenet-fast/).
|
272 |
+
- [x] Adapting Neural Vocoder. TTS works with WaveRNN and ParallelWaveGAN (https://github.com/erogol/WaveRNN and https://github.com/erogol/ParallelWaveGAN)
|
273 |
+
- [x] Multi-speaker embedding.
|
274 |
+
- [x] Model optimization (model export, model pruning etc.)
|
275 |
+
|
276 |
+
### Acknowledgement
|
277 |
+
- https://github.com/keithito/tacotron (Dataset pre-processing)
|
278 |
+
- https://github.com/r9y9/tacotron_pytorch (Initial Tacotron architecture)
|
279 |
+
- https://github.com/kan-bayashi/ParallelWaveGAN (vocoder library)
|
280 |
+
- https://github.com/jaywalnut310/glow-tts (Original Glow-TTS implementation)
|
281 |
+
- https://github.com/fatchord/WaveRNN/ (Original WaveRNN implementation)
|
TTS/.models.json
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"tts_models":{
|
3 |
+
"en":{
|
4 |
+
"ljspeech":{
|
5 |
+
"glow-tts":{
|
6 |
+
"description": "",
|
7 |
+
"model_file": "1NFsfhH8W8AgcfJ-BsL8CYAwQfZ5k4T-n",
|
8 |
+
"config_file": "1IAROF3yy9qTK43vG_-R67y3Py9yYbD6t",
|
9 |
+
"stats_file": null,
|
10 |
+
"commit": ""
|
11 |
+
},
|
12 |
+
"tacotron2-DCA": {
|
13 |
+
"description": "",
|
14 |
+
"model_file": "1CFoPDQBnhfBFu2Gc0TBSJn8o-TuNKQn7",
|
15 |
+
"config_file": "1lWSscNfKet1zZSJCNirOn7v9bigUZ8C1",
|
16 |
+
"stats_file": "1qevpGRVHPmzfiRBNuugLMX62x1k7B5vK",
|
17 |
+
"commit": ""
|
18 |
+
},
|
19 |
+
"speedy-speech-wn":{
|
20 |
+
"description": "Speedy Speech model with wavenet decoder.",
|
21 |
+
"model_file": "1VXAwiq6N-Viq3rsSXlf43bdoi0jSvMAJ",
|
22 |
+
"config_file": "1KvZilhsNP3EumVggDcD46yd834eO5hR3",
|
23 |
+
"stats_file": "1Ju7apZ5JlgsVECcETL-GEx3DRoNzWfkR",
|
24 |
+
"commit": "77b6145"
|
25 |
+
}
|
26 |
+
}
|
27 |
+
},
|
28 |
+
"es":{
|
29 |
+
"mai":{
|
30 |
+
"tacotron2-DDC":{
|
31 |
+
"model_file": "1jZ4HvYcAXI5ZClke2iGA7qFQQJBXIovw",
|
32 |
+
"config_file": "1s7g4n-B73ChCB48AQ88_DV_8oyLth8r0",
|
33 |
+
"stats_file": "13st0CZ743v6Br5R5Qw_lH1OPQOr3M-Jv",
|
34 |
+
"commit": ""
|
35 |
+
}
|
36 |
+
}
|
37 |
+
},
|
38 |
+
"fr":{
|
39 |
+
"mai":{
|
40 |
+
"tacotron2-DDC":{
|
41 |
+
"model_file": "1qyxrrCyoXUvBG2lqVd0KqAlHj-2nZCgS",
|
42 |
+
"config_file": "1yECKeP2LI7tNv4E8yVNx1yLmCfTCpkqG",
|
43 |
+
"stats_file": "13st0CZ743v6Br5R5Qw_lH1OPQOr3M-Jv",
|
44 |
+
"commit": ""
|
45 |
+
}
|
46 |
+
}
|
47 |
+
}
|
48 |
+
},
|
49 |
+
"vocoder_models":{
|
50 |
+
"universal":{
|
51 |
+
"libri-tts":{
|
52 |
+
"wavegrad":{
|
53 |
+
"model_file": "1r2g90JaZsfCj9dJkI9ioIU6JCFMPRqi6",
|
54 |
+
"config_file": "1POrrLf5YEpZyjvWyMccj1nGCVc94mR6s",
|
55 |
+
"stats_file": "1Vwbv4t-N1i3jXqI0bgKAhShAEO097sK0",
|
56 |
+
"commit": "ea976b0"
|
57 |
+
},
|
58 |
+
"fullband-melgan":{
|
59 |
+
"model_file": "1Ty5DZdOc0F7OTGj9oJThYbL5iVu_2G0K",
|
60 |
+
"config_file": "1Rd0R_nRCrbjEdpOwq6XwZAktvugiBvmu",
|
61 |
+
"stats_file": "11oY3Tv0kQtxK_JPgxrfesa99maVXHNxU",
|
62 |
+
"commit": "4132240"
|
63 |
+
}
|
64 |
+
}
|
65 |
+
},
|
66 |
+
"en": {
|
67 |
+
"ljspeech":{
|
68 |
+
"mulitband-melgan":{
|
69 |
+
"model_file": "1Ty5DZdOc0F7OTGj9oJThYbL5iVu_2G0K",
|
70 |
+
"config_file": "1Rd0R_nRCrbjEdpOwq6XwZAktvugiBvmu",
|
71 |
+
"stats_file": "11oY3Tv0kQtxK_JPgxrfesa99maVXHNxU",
|
72 |
+
"commit": "ea976b0"
|
73 |
+
}
|
74 |
+
}
|
75 |
+
}
|
76 |
+
}
|
77 |
+
}
|
TTS/__init__.py
ADDED
File without changes
|
TTS/bin/__init__.py
ADDED
File without changes
|
TTS/bin/compute_attention_masks.py
ADDED
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import importlib
|
3 |
+
import os
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from torch.utils.data import DataLoader
|
8 |
+
from tqdm import tqdm
|
9 |
+
from argparse import RawTextHelpFormatter
|
10 |
+
from TTS.tts.datasets.TTSDataset import MyDataset
|
11 |
+
from TTS.tts.utils.generic_utils import setup_model
|
12 |
+
from TTS.tts.utils.io import load_checkpoint
|
13 |
+
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
|
14 |
+
from TTS.utils.audio import AudioProcessor
|
15 |
+
from TTS.utils.io import load_config
|
16 |
+
|
17 |
+
|
18 |
+
if __name__ == '__main__':
|
19 |
+
parser = argparse.ArgumentParser(
|
20 |
+
description='''Extract attention masks from trained Tacotron/Tacotron2 models.
|
21 |
+
These masks can be used for different purposes including training a TTS model with a Duration Predictor.\n\n'''
|
22 |
+
|
23 |
+
'''Each attention mask is written to the same path as the input wav file with ".npy" file extension.
|
24 |
+
(e.g. path/bla.wav (wav file) --> path/bla.npy (attention mask))\n'''
|
25 |
+
|
26 |
+
'''
|
27 |
+
Example run:
|
28 |
+
CUDA_VISIBLE_DEVICE="0" python TTS/bin/compute_attention_masks.py
|
29 |
+
--model_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/checkpoint_200000.pth.tar
|
30 |
+
--config_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/config.json
|
31 |
+
--dataset_metafile /root/LJSpeech-1.1/metadata.csv
|
32 |
+
--data_path /root/LJSpeech-1.1/
|
33 |
+
--batch_size 32
|
34 |
+
--dataset ljspeech
|
35 |
+
--use_cuda True
|
36 |
+
''',
|
37 |
+
formatter_class=RawTextHelpFormatter
|
38 |
+
)
|
39 |
+
parser.add_argument('--model_path',
|
40 |
+
type=str,
|
41 |
+
required=True,
|
42 |
+
help='Path to Tacotron/Tacotron2 model file ')
|
43 |
+
parser.add_argument(
|
44 |
+
'--config_path',
|
45 |
+
type=str,
|
46 |
+
required=True,
|
47 |
+
help='Path to Tacotron/Tacotron2 config file.',
|
48 |
+
)
|
49 |
+
parser.add_argument('--dataset',
|
50 |
+
type=str,
|
51 |
+
default='',
|
52 |
+
required=True,
|
53 |
+
help='Target dataset processor name from TTS.tts.dataset.preprocess.')
|
54 |
+
|
55 |
+
parser.add_argument(
|
56 |
+
'--dataset_metafile',
|
57 |
+
type=str,
|
58 |
+
default='',
|
59 |
+
required=True,
|
60 |
+
help='Dataset metafile inclusing file paths with transcripts.')
|
61 |
+
parser.add_argument(
|
62 |
+
'--data_path',
|
63 |
+
type=str,
|
64 |
+
default='',
|
65 |
+
help='Defines the data path. It overwrites config.json.')
|
66 |
+
parser.add_argument('--use_cuda',
|
67 |
+
type=bool,
|
68 |
+
default=False,
|
69 |
+
help="enable/disable cuda.")
|
70 |
+
|
71 |
+
parser.add_argument(
|
72 |
+
'--batch_size',
|
73 |
+
default=16,
|
74 |
+
type=int,
|
75 |
+
help='Batch size for the model. Use batch_size=1 if you have no CUDA.')
|
76 |
+
args = parser.parse_args()
|
77 |
+
|
78 |
+
C = load_config(args.config_path)
|
79 |
+
ap = AudioProcessor(**C.audio)
|
80 |
+
|
81 |
+
# if the vocabulary was passed, replace the default
|
82 |
+
if 'characters' in C.keys():
|
83 |
+
symbols, phonemes = make_symbols(**C.characters)
|
84 |
+
|
85 |
+
# load the model
|
86 |
+
num_chars = len(phonemes) if C.use_phonemes else len(symbols)
|
87 |
+
# TODO: handle multi-speaker
|
88 |
+
model = setup_model(num_chars, num_speakers=0, c=C)
|
89 |
+
model, _ = load_checkpoint(model, args.model_path, None, args.use_cuda)
|
90 |
+
model.eval()
|
91 |
+
|
92 |
+
# data loader
|
93 |
+
preprocessor = importlib.import_module('TTS.tts.datasets.preprocess')
|
94 |
+
preprocessor = getattr(preprocessor, args.dataset)
|
95 |
+
meta_data = preprocessor(args.data_path, args.dataset_metafile)
|
96 |
+
dataset = MyDataset(model.decoder.r,
|
97 |
+
C.text_cleaner,
|
98 |
+
compute_linear_spec=False,
|
99 |
+
ap=ap,
|
100 |
+
meta_data=meta_data,
|
101 |
+
tp=C.characters if 'characters' in C.keys() else None,
|
102 |
+
add_blank=C['add_blank'] if 'add_blank' in C.keys() else False,
|
103 |
+
use_phonemes=C.use_phonemes,
|
104 |
+
phoneme_cache_path=C.phoneme_cache_path,
|
105 |
+
phoneme_language=C.phoneme_language,
|
106 |
+
enable_eos_bos=C.enable_eos_bos_chars)
|
107 |
+
|
108 |
+
dataset.sort_items()
|
109 |
+
loader = DataLoader(dataset,
|
110 |
+
batch_size=args.batch_size,
|
111 |
+
num_workers=4,
|
112 |
+
collate_fn=dataset.collate_fn,
|
113 |
+
shuffle=False,
|
114 |
+
drop_last=False)
|
115 |
+
|
116 |
+
# compute attentions
|
117 |
+
file_paths = []
|
118 |
+
with torch.no_grad():
|
119 |
+
for data in tqdm(loader):
|
120 |
+
# setup input data
|
121 |
+
text_input = data[0]
|
122 |
+
text_lengths = data[1]
|
123 |
+
linear_input = data[3]
|
124 |
+
mel_input = data[4]
|
125 |
+
mel_lengths = data[5]
|
126 |
+
stop_targets = data[6]
|
127 |
+
item_idxs = data[7]
|
128 |
+
|
129 |
+
# dispatch data to GPU
|
130 |
+
if args.use_cuda:
|
131 |
+
text_input = text_input.cuda()
|
132 |
+
text_lengths = text_lengths.cuda()
|
133 |
+
mel_input = mel_input.cuda()
|
134 |
+
mel_lengths = mel_lengths.cuda()
|
135 |
+
|
136 |
+
mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(
|
137 |
+
text_input, text_lengths, mel_input)
|
138 |
+
|
139 |
+
alignments = alignments.detach()
|
140 |
+
for idx, alignment in enumerate(alignments):
|
141 |
+
item_idx = item_idxs[idx]
|
142 |
+
# interpolate if r > 1
|
143 |
+
alignment = torch.nn.functional.interpolate(
|
144 |
+
alignment.transpose(0, 1).unsqueeze(0),
|
145 |
+
size=None,
|
146 |
+
scale_factor=model.decoder.r,
|
147 |
+
mode='nearest',
|
148 |
+
align_corners=None,
|
149 |
+
recompute_scale_factor=None).squeeze(0).transpose(0, 1)
|
150 |
+
# remove paddings
|
151 |
+
alignment = alignment[:mel_lengths[idx], :text_lengths[idx]].cpu().numpy()
|
152 |
+
# set file paths
|
153 |
+
wav_file_name = os.path.basename(item_idx)
|
154 |
+
align_file_name = os.path.splitext(wav_file_name)[0] + '.npy'
|
155 |
+
file_path = item_idx.replace(wav_file_name, align_file_name)
|
156 |
+
# save output
|
157 |
+
file_paths.append([item_idx, file_path])
|
158 |
+
np.save(file_path, alignment)
|
159 |
+
|
160 |
+
# ourput metafile
|
161 |
+
metafile = os.path.join(args.data_path, "metadata_attn_mask.txt")
|
162 |
+
|
163 |
+
with open(metafile, "w") as f:
|
164 |
+
for p in file_paths:
|
165 |
+
f.write(f"{p[0]}|{p[1]}\n")
|
166 |
+
print(f" >> Metafile created: {metafile}")
|
TTS/bin/compute_embeddings.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import glob
|
3 |
+
import os
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
from tqdm import tqdm
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from TTS.speaker_encoder.model import SpeakerEncoder
|
10 |
+
from TTS.utils.audio import AudioProcessor
|
11 |
+
from TTS.utils.io import load_config
|
12 |
+
from TTS.tts.utils.speakers import save_speaker_mapping
|
13 |
+
from TTS.tts.datasets.preprocess import load_meta_data
|
14 |
+
|
15 |
+
parser = argparse.ArgumentParser(
|
16 |
+
description='Compute embedding vectors for each wav file in a dataset. If "target_dataset" is defined, it generates "speakers.json" necessary for training a multi-speaker model.')
|
17 |
+
parser.add_argument(
|
18 |
+
'model_path',
|
19 |
+
type=str,
|
20 |
+
help='Path to model outputs (checkpoint, tensorboard etc.).')
|
21 |
+
parser.add_argument(
|
22 |
+
'config_path',
|
23 |
+
type=str,
|
24 |
+
help='Path to config file for training.',
|
25 |
+
)
|
26 |
+
parser.add_argument(
|
27 |
+
'data_path',
|
28 |
+
type=str,
|
29 |
+
help='Data path for wav files - directory or CSV file')
|
30 |
+
parser.add_argument(
|
31 |
+
'output_path',
|
32 |
+
type=str,
|
33 |
+
help='path for training outputs.')
|
34 |
+
parser.add_argument(
|
35 |
+
'--target_dataset',
|
36 |
+
type=str,
|
37 |
+
default='',
|
38 |
+
help='Target dataset to pick a processor from TTS.tts.dataset.preprocess. Necessary to create a speakers.json file.'
|
39 |
+
)
|
40 |
+
parser.add_argument(
|
41 |
+
'--use_cuda', type=bool, help='flag to set cuda.', default=False
|
42 |
+
)
|
43 |
+
parser.add_argument(
|
44 |
+
'--separator', type=str, help='Separator used in file if CSV is passed for data_path', default='|'
|
45 |
+
)
|
46 |
+
args = parser.parse_args()
|
47 |
+
|
48 |
+
|
49 |
+
c = load_config(args.config_path)
|
50 |
+
ap = AudioProcessor(**c['audio'])
|
51 |
+
|
52 |
+
data_path = args.data_path
|
53 |
+
split_ext = os.path.splitext(data_path)
|
54 |
+
sep = args.separator
|
55 |
+
|
56 |
+
if args.target_dataset != '':
|
57 |
+
# if target dataset is defined
|
58 |
+
dataset_config = [
|
59 |
+
{
|
60 |
+
"name": args.target_dataset,
|
61 |
+
"path": args.data_path,
|
62 |
+
"meta_file_train": None,
|
63 |
+
"meta_file_val": None
|
64 |
+
},
|
65 |
+
]
|
66 |
+
wav_files, _ = load_meta_data(dataset_config, eval_split=False)
|
67 |
+
output_files = [wav_file[1].replace(data_path, args.output_path).replace(
|
68 |
+
'.wav', '.npy') for wav_file in wav_files]
|
69 |
+
else:
|
70 |
+
# if target dataset is not defined
|
71 |
+
if len(split_ext) > 0 and split_ext[1].lower() == '.csv':
|
72 |
+
# Parse CSV
|
73 |
+
print(f'CSV file: {data_path}')
|
74 |
+
with open(data_path) as f:
|
75 |
+
wav_path = os.path.join(os.path.dirname(data_path), 'wavs')
|
76 |
+
wav_files = []
|
77 |
+
print(f'Separator is: {sep}')
|
78 |
+
for line in f:
|
79 |
+
components = line.split(sep)
|
80 |
+
if len(components) != 2:
|
81 |
+
print("Invalid line")
|
82 |
+
continue
|
83 |
+
wav_file = os.path.join(wav_path, components[0] + '.wav')
|
84 |
+
#print(f'wav_file: {wav_file}')
|
85 |
+
if os.path.exists(wav_file):
|
86 |
+
wav_files.append(wav_file)
|
87 |
+
print(f'Count of wavs imported: {len(wav_files)}')
|
88 |
+
else:
|
89 |
+
# Parse all wav files in data_path
|
90 |
+
wav_files = glob.glob(data_path + '/**/*.wav', recursive=True)
|
91 |
+
|
92 |
+
output_files = [wav_file.replace(data_path, args.output_path).replace(
|
93 |
+
'.wav', '.npy') for wav_file in wav_files]
|
94 |
+
|
95 |
+
for output_file in output_files:
|
96 |
+
os.makedirs(os.path.dirname(output_file), exist_ok=True)
|
97 |
+
|
98 |
+
# define Encoder model
|
99 |
+
model = SpeakerEncoder(**c.model)
|
100 |
+
model.load_state_dict(torch.load(args.model_path)['model'])
|
101 |
+
model.eval()
|
102 |
+
if args.use_cuda:
|
103 |
+
model.cuda()
|
104 |
+
|
105 |
+
# compute speaker embeddings
|
106 |
+
speaker_mapping = {}
|
107 |
+
for idx, wav_file in enumerate(tqdm(wav_files)):
|
108 |
+
if isinstance(wav_file, list):
|
109 |
+
speaker_name = wav_file[2]
|
110 |
+
wav_file = wav_file[1]
|
111 |
+
|
112 |
+
mel_spec = ap.melspectrogram(ap.load_wav(wav_file, sr=ap.sample_rate)).T
|
113 |
+
mel_spec = torch.FloatTensor(mel_spec[None, :, :])
|
114 |
+
if args.use_cuda:
|
115 |
+
mel_spec = mel_spec.cuda()
|
116 |
+
embedd = model.compute_embedding(mel_spec)
|
117 |
+
embedd = embedd.detach().cpu().numpy()
|
118 |
+
np.save(output_files[idx], embedd)
|
119 |
+
|
120 |
+
if args.target_dataset != '':
|
121 |
+
# create speaker_mapping if target dataset is defined
|
122 |
+
wav_file_name = os.path.basename(wav_file)
|
123 |
+
speaker_mapping[wav_file_name] = {}
|
124 |
+
speaker_mapping[wav_file_name]['name'] = speaker_name
|
125 |
+
speaker_mapping[wav_file_name]['embedding'] = embedd.flatten().tolist()
|
126 |
+
|
127 |
+
if args.target_dataset != '':
|
128 |
+
# save speaker_mapping if target dataset is defined
|
129 |
+
mapping_file_path = os.path.join(args.output_path, 'speakers.json')
|
130 |
+
save_speaker_mapping(args.output_path, speaker_mapping)
|
TTS/bin/compute_statistics.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
import os
|
5 |
+
import glob
|
6 |
+
import argparse
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
from tqdm import tqdm
|
10 |
+
|
11 |
+
from TTS.tts.datasets.preprocess import load_meta_data
|
12 |
+
from TTS.utils.io import load_config
|
13 |
+
from TTS.utils.audio import AudioProcessor
|
14 |
+
|
15 |
+
|
16 |
+
def main():
|
17 |
+
"""Run preprocessing process."""
|
18 |
+
parser = argparse.ArgumentParser(
|
19 |
+
description="Compute mean and variance of spectrogtram features.")
|
20 |
+
parser.add_argument("--config_path", type=str, required=True,
|
21 |
+
help="TTS config file path to define audio processin parameters.")
|
22 |
+
parser.add_argument("--out_path", default=None, type=str,
|
23 |
+
help="directory to save the output file.")
|
24 |
+
args = parser.parse_args()
|
25 |
+
|
26 |
+
# load config
|
27 |
+
CONFIG = load_config(args.config_path)
|
28 |
+
CONFIG.audio['signal_norm'] = False # do not apply earlier normalization
|
29 |
+
CONFIG.audio['stats_path'] = None # discard pre-defined stats
|
30 |
+
|
31 |
+
# load audio processor
|
32 |
+
ap = AudioProcessor(**CONFIG.audio)
|
33 |
+
|
34 |
+
# load the meta data of target dataset
|
35 |
+
if 'data_path' in CONFIG.keys():
|
36 |
+
dataset_items = glob.glob(os.path.join(CONFIG.data_path, '**', '*.wav'), recursive=True)
|
37 |
+
else:
|
38 |
+
dataset_items = load_meta_data(CONFIG.datasets)[0] # take only train data
|
39 |
+
print(f" > There are {len(dataset_items)} files.")
|
40 |
+
|
41 |
+
mel_sum = 0
|
42 |
+
mel_square_sum = 0
|
43 |
+
linear_sum = 0
|
44 |
+
linear_square_sum = 0
|
45 |
+
N = 0
|
46 |
+
for item in tqdm(dataset_items):
|
47 |
+
# compute features
|
48 |
+
wav = ap.load_wav(item if isinstance(item, str) else item[1])
|
49 |
+
linear = ap.spectrogram(wav)
|
50 |
+
mel = ap.melspectrogram(wav)
|
51 |
+
|
52 |
+
# compute stats
|
53 |
+
N += mel.shape[1]
|
54 |
+
mel_sum += mel.sum(1)
|
55 |
+
linear_sum += linear.sum(1)
|
56 |
+
mel_square_sum += (mel ** 2).sum(axis=1)
|
57 |
+
linear_square_sum += (linear ** 2).sum(axis=1)
|
58 |
+
|
59 |
+
mel_mean = mel_sum / N
|
60 |
+
mel_scale = np.sqrt(mel_square_sum / N - mel_mean ** 2)
|
61 |
+
linear_mean = linear_sum / N
|
62 |
+
linear_scale = np.sqrt(linear_square_sum / N - linear_mean ** 2)
|
63 |
+
|
64 |
+
output_file_path = args.out_path
|
65 |
+
stats = {}
|
66 |
+
stats['mel_mean'] = mel_mean
|
67 |
+
stats['mel_std'] = mel_scale
|
68 |
+
stats['linear_mean'] = linear_mean
|
69 |
+
stats['linear_std'] = linear_scale
|
70 |
+
|
71 |
+
print(f' > Avg mel spec mean: {mel_mean.mean()}')
|
72 |
+
print(f' > Avg mel spec scale: {mel_scale.mean()}')
|
73 |
+
print(f' > Avg linear spec mean: {linear_mean.mean()}')
|
74 |
+
print(f' > Avg lienar spec scale: {linear_scale.mean()}')
|
75 |
+
|
76 |
+
# set default config values for mean-var scaling
|
77 |
+
CONFIG.audio['stats_path'] = output_file_path
|
78 |
+
CONFIG.audio['signal_norm'] = True
|
79 |
+
# remove redundant values
|
80 |
+
del CONFIG.audio['max_norm']
|
81 |
+
del CONFIG.audio['min_level_db']
|
82 |
+
del CONFIG.audio['symmetric_norm']
|
83 |
+
del CONFIG.audio['clip_norm']
|
84 |
+
stats['audio_config'] = CONFIG.audio
|
85 |
+
np.save(output_file_path, stats, allow_pickle=True)
|
86 |
+
print(f' > stats saved to {output_file_path}')
|
87 |
+
|
88 |
+
|
89 |
+
if __name__ == "__main__":
|
90 |
+
main()
|
TTS/bin/convert_melgan_tflite.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Convert Tensorflow Tacotron2 model to TF-Lite binary
|
2 |
+
|
3 |
+
import argparse
|
4 |
+
|
5 |
+
from TTS.utils.io import load_config
|
6 |
+
from TTS.vocoder.tf.utils.generic_utils import setup_generator
|
7 |
+
from TTS.vocoder.tf.utils.io import load_checkpoint
|
8 |
+
from TTS.vocoder.tf.utils.tflite import convert_melgan_to_tflite
|
9 |
+
|
10 |
+
|
11 |
+
parser = argparse.ArgumentParser()
|
12 |
+
parser.add_argument('--tf_model',
|
13 |
+
type=str,
|
14 |
+
help='Path to target torch model to be converted to TF.')
|
15 |
+
parser.add_argument('--config_path',
|
16 |
+
type=str,
|
17 |
+
help='Path to config file of torch model.')
|
18 |
+
parser.add_argument('--output_path',
|
19 |
+
type=str,
|
20 |
+
help='path to tflite output binary.')
|
21 |
+
args = parser.parse_args()
|
22 |
+
|
23 |
+
# Set constants
|
24 |
+
CONFIG = load_config(args.config_path)
|
25 |
+
|
26 |
+
# load the model
|
27 |
+
model = setup_generator(CONFIG)
|
28 |
+
model.build_inference()
|
29 |
+
model = load_checkpoint(model, args.tf_model)
|
30 |
+
|
31 |
+
# create tflite model
|
32 |
+
tflite_model = convert_melgan_to_tflite(model, output_path=args.output_path)
|
TTS/bin/convert_melgan_torch_to_tf.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from difflib import SequenceMatcher
|
3 |
+
import os
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import tensorflow as tf
|
7 |
+
import torch
|
8 |
+
|
9 |
+
from TTS.utils.io import load_config
|
10 |
+
from TTS.vocoder.tf.utils.convert_torch_to_tf_utils import (
|
11 |
+
compare_torch_tf, convert_tf_name, transfer_weights_torch_to_tf)
|
12 |
+
from TTS.vocoder.tf.utils.generic_utils import \
|
13 |
+
setup_generator as setup_tf_generator
|
14 |
+
from TTS.vocoder.tf.utils.io import save_checkpoint
|
15 |
+
from TTS.vocoder.utils.generic_utils import setup_generator
|
16 |
+
|
17 |
+
# prevent GPU use
|
18 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = ''
|
19 |
+
|
20 |
+
# define args
|
21 |
+
parser = argparse.ArgumentParser()
|
22 |
+
parser.add_argument('--torch_model_path',
|
23 |
+
type=str,
|
24 |
+
help='Path to target torch model to be converted to TF.')
|
25 |
+
parser.add_argument('--config_path',
|
26 |
+
type=str,
|
27 |
+
help='Path to config file of torch model.')
|
28 |
+
parser.add_argument(
|
29 |
+
'--output_path',
|
30 |
+
type=str,
|
31 |
+
help='path to output file including file name to save TF model.')
|
32 |
+
args = parser.parse_args()
|
33 |
+
|
34 |
+
# load model config
|
35 |
+
config_path = args.config_path
|
36 |
+
c = load_config(config_path)
|
37 |
+
num_speakers = 0
|
38 |
+
|
39 |
+
# init torch model
|
40 |
+
model = setup_generator(c)
|
41 |
+
checkpoint = torch.load(args.torch_model_path,
|
42 |
+
map_location=torch.device('cpu'))
|
43 |
+
state_dict = checkpoint['model']
|
44 |
+
model.load_state_dict(state_dict)
|
45 |
+
model.remove_weight_norm()
|
46 |
+
state_dict = model.state_dict()
|
47 |
+
|
48 |
+
# init tf model
|
49 |
+
model_tf = setup_tf_generator(c)
|
50 |
+
|
51 |
+
common_sufix = '/.ATTRIBUTES/VARIABLE_VALUE'
|
52 |
+
# get tf_model graph by passing an input
|
53 |
+
# B x D x T
|
54 |
+
dummy_input = tf.random.uniform((7, 80, 64), dtype=tf.float32)
|
55 |
+
mel_pred = model_tf(dummy_input, training=False)
|
56 |
+
|
57 |
+
# get tf variables
|
58 |
+
tf_vars = model_tf.weights
|
59 |
+
|
60 |
+
# match variable names with fuzzy logic
|
61 |
+
torch_var_names = list(state_dict.keys())
|
62 |
+
tf_var_names = [we.name for we in model_tf.weights]
|
63 |
+
var_map = []
|
64 |
+
for tf_name in tf_var_names:
|
65 |
+
# skip re-mapped layer names
|
66 |
+
if tf_name in [name[0] for name in var_map]:
|
67 |
+
continue
|
68 |
+
tf_name_edited = convert_tf_name(tf_name)
|
69 |
+
ratios = [
|
70 |
+
SequenceMatcher(None, torch_name, tf_name_edited).ratio()
|
71 |
+
for torch_name in torch_var_names
|
72 |
+
]
|
73 |
+
max_idx = np.argmax(ratios)
|
74 |
+
matching_name = torch_var_names[max_idx]
|
75 |
+
del torch_var_names[max_idx]
|
76 |
+
var_map.append((tf_name, matching_name))
|
77 |
+
|
78 |
+
# pass weights
|
79 |
+
tf_vars = transfer_weights_torch_to_tf(tf_vars, dict(var_map), state_dict)
|
80 |
+
|
81 |
+
# Compare TF and TORCH models
|
82 |
+
# check embedding outputs
|
83 |
+
model.eval()
|
84 |
+
dummy_input_torch = torch.ones((1, 80, 10))
|
85 |
+
dummy_input_tf = tf.convert_to_tensor(dummy_input_torch.numpy())
|
86 |
+
dummy_input_tf = tf.transpose(dummy_input_tf, perm=[0, 2, 1])
|
87 |
+
dummy_input_tf = tf.expand_dims(dummy_input_tf, 2)
|
88 |
+
|
89 |
+
out_torch = model.layers[0](dummy_input_torch)
|
90 |
+
out_tf = model_tf.model_layers[0](dummy_input_tf)
|
91 |
+
out_tf_ = tf.transpose(out_tf, perm=[0, 3, 2, 1])[:, :, 0, :]
|
92 |
+
|
93 |
+
assert compare_torch_tf(out_torch, out_tf_) < 1e-5
|
94 |
+
|
95 |
+
for i in range(1, len(model.layers)):
|
96 |
+
print(f"{i} -> {model.layers[i]} vs {model_tf.model_layers[i]}")
|
97 |
+
out_torch = model.layers[i](out_torch)
|
98 |
+
out_tf = model_tf.model_layers[i](out_tf)
|
99 |
+
out_tf_ = tf.transpose(out_tf, perm=[0, 3, 2, 1])[:, :, 0, :]
|
100 |
+
diff = compare_torch_tf(out_torch, out_tf_)
|
101 |
+
assert diff < 1e-5, diff
|
102 |
+
|
103 |
+
torch.manual_seed(0)
|
104 |
+
dummy_input_torch = torch.rand((1, 80, 100))
|
105 |
+
dummy_input_tf = tf.convert_to_tensor(dummy_input_torch.numpy())
|
106 |
+
model.inference_padding = 0
|
107 |
+
model_tf.inference_padding = 0
|
108 |
+
output_torch = model.inference(dummy_input_torch)
|
109 |
+
output_tf = model_tf(dummy_input_tf, training=False)
|
110 |
+
assert compare_torch_tf(output_torch, output_tf) < 1e-5, compare_torch_tf(
|
111 |
+
output_torch, output_tf)
|
112 |
+
|
113 |
+
# save tf model
|
114 |
+
save_checkpoint(model_tf, checkpoint['step'], checkpoint['epoch'],
|
115 |
+
args.output_path)
|
116 |
+
print(' > Model conversion is successfully completed :).')
|
TTS/bin/convert_tacotron2_tflite.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Convert Tensorflow Tacotron2 model to TF-Lite binary
|
2 |
+
|
3 |
+
import argparse
|
4 |
+
|
5 |
+
from TTS.utils.io import load_config
|
6 |
+
from TTS.tts.utils.text.symbols import symbols, phonemes
|
7 |
+
from TTS.tts.tf.utils.generic_utils import setup_model
|
8 |
+
from TTS.tts.tf.utils.io import load_checkpoint
|
9 |
+
from TTS.tts.tf.utils.tflite import convert_tacotron2_to_tflite
|
10 |
+
|
11 |
+
|
12 |
+
parser = argparse.ArgumentParser()
|
13 |
+
parser.add_argument('--tf_model',
|
14 |
+
type=str,
|
15 |
+
help='Path to target torch model to be converted to TF.')
|
16 |
+
parser.add_argument('--config_path',
|
17 |
+
type=str,
|
18 |
+
help='Path to config file of torch model.')
|
19 |
+
parser.add_argument('--output_path',
|
20 |
+
type=str,
|
21 |
+
help='path to tflite output binary.')
|
22 |
+
args = parser.parse_args()
|
23 |
+
|
24 |
+
# Set constants
|
25 |
+
CONFIG = load_config(args.config_path)
|
26 |
+
|
27 |
+
# load the model
|
28 |
+
c = CONFIG
|
29 |
+
num_speakers = 0
|
30 |
+
num_chars = len(phonemes) if c.use_phonemes else len(symbols)
|
31 |
+
model = setup_model(num_chars, num_speakers, c, enable_tflite=True)
|
32 |
+
model.build_inference()
|
33 |
+
model = load_checkpoint(model, args.tf_model)
|
34 |
+
model.decoder.set_max_decoder_steps(1000)
|
35 |
+
|
36 |
+
# create tflite model
|
37 |
+
tflite_model = convert_tacotron2_to_tflite(model, output_path=args.output_path)
|
TTS/bin/convert_tacotron2_torch_to_tf.py
ADDED
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# %%
|
2 |
+
# %%
|
3 |
+
import argparse
|
4 |
+
from difflib import SequenceMatcher
|
5 |
+
import os
|
6 |
+
import sys
|
7 |
+
# %%
|
8 |
+
# print variable match
|
9 |
+
from pprint import pprint
|
10 |
+
|
11 |
+
import numpy as np
|
12 |
+
import tensorflow as tf
|
13 |
+
import torch
|
14 |
+
from TTS.tts.tf.models.tacotron2 import Tacotron2
|
15 |
+
from TTS.tts.tf.utils.convert_torch_to_tf_utils import (
|
16 |
+
compare_torch_tf, convert_tf_name, transfer_weights_torch_to_tf)
|
17 |
+
from TTS.tts.tf.utils.generic_utils import save_checkpoint
|
18 |
+
from TTS.tts.utils.generic_utils import setup_model
|
19 |
+
from TTS.tts.utils.text.symbols import phonemes, symbols
|
20 |
+
from TTS.utils.io import load_config
|
21 |
+
|
22 |
+
sys.path.append('/home/erogol/Projects')
|
23 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = ''
|
24 |
+
|
25 |
+
|
26 |
+
parser = argparse.ArgumentParser()
|
27 |
+
parser.add_argument('--torch_model_path',
|
28 |
+
type=str,
|
29 |
+
help='Path to target torch model to be converted to TF.')
|
30 |
+
parser.add_argument('--config_path',
|
31 |
+
type=str,
|
32 |
+
help='Path to config file of torch model.')
|
33 |
+
parser.add_argument('--output_path',
|
34 |
+
type=str,
|
35 |
+
help='path to output file including file name to save TF model.')
|
36 |
+
args = parser.parse_args()
|
37 |
+
|
38 |
+
# load model config
|
39 |
+
config_path = args.config_path
|
40 |
+
c = load_config(config_path)
|
41 |
+
num_speakers = 0
|
42 |
+
|
43 |
+
# init torch model
|
44 |
+
num_chars = len(phonemes) if c.use_phonemes else len(symbols)
|
45 |
+
model = setup_model(num_chars, num_speakers, c)
|
46 |
+
checkpoint = torch.load(args.torch_model_path,
|
47 |
+
map_location=torch.device('cpu'))
|
48 |
+
state_dict = checkpoint['model']
|
49 |
+
model.load_state_dict(state_dict)
|
50 |
+
|
51 |
+
# init tf model
|
52 |
+
model_tf = Tacotron2(num_chars=num_chars,
|
53 |
+
num_speakers=num_speakers,
|
54 |
+
r=model.decoder.r,
|
55 |
+
postnet_output_dim=c.audio['num_mels'],
|
56 |
+
decoder_output_dim=c.audio['num_mels'],
|
57 |
+
attn_type=c.attention_type,
|
58 |
+
attn_win=c.windowing,
|
59 |
+
attn_norm=c.attention_norm,
|
60 |
+
prenet_type=c.prenet_type,
|
61 |
+
prenet_dropout=c.prenet_dropout,
|
62 |
+
forward_attn=c.use_forward_attn,
|
63 |
+
trans_agent=c.transition_agent,
|
64 |
+
forward_attn_mask=c.forward_attn_mask,
|
65 |
+
location_attn=c.location_attn,
|
66 |
+
attn_K=c.attention_heads,
|
67 |
+
separate_stopnet=c.separate_stopnet,
|
68 |
+
bidirectional_decoder=c.bidirectional_decoder)
|
69 |
+
|
70 |
+
# set initial layer mapping - these are not captured by the below heuristic approach
|
71 |
+
# TODO: set layer names so that we can remove these manual matching
|
72 |
+
common_sufix = '/.ATTRIBUTES/VARIABLE_VALUE'
|
73 |
+
var_map = [
|
74 |
+
('embedding/embeddings:0', 'embedding.weight'),
|
75 |
+
('encoder/lstm/forward_lstm/lstm_cell_1/kernel:0',
|
76 |
+
'encoder.lstm.weight_ih_l0'),
|
77 |
+
('encoder/lstm/forward_lstm/lstm_cell_1/recurrent_kernel:0',
|
78 |
+
'encoder.lstm.weight_hh_l0'),
|
79 |
+
('encoder/lstm/backward_lstm/lstm_cell_2/kernel:0',
|
80 |
+
'encoder.lstm.weight_ih_l0_reverse'),
|
81 |
+
('encoder/lstm/backward_lstm/lstm_cell_2/recurrent_kernel:0',
|
82 |
+
'encoder.lstm.weight_hh_l0_reverse'),
|
83 |
+
('encoder/lstm/forward_lstm/lstm_cell_1/bias:0',
|
84 |
+
('encoder.lstm.bias_ih_l0', 'encoder.lstm.bias_hh_l0')),
|
85 |
+
('encoder/lstm/backward_lstm/lstm_cell_2/bias:0',
|
86 |
+
('encoder.lstm.bias_ih_l0_reverse', 'encoder.lstm.bias_hh_l0_reverse')),
|
87 |
+
('attention/v/kernel:0', 'decoder.attention.v.linear_layer.weight'),
|
88 |
+
('decoder/linear_projection/kernel:0',
|
89 |
+
'decoder.linear_projection.linear_layer.weight'),
|
90 |
+
('decoder/stopnet/kernel:0', 'decoder.stopnet.1.linear_layer.weight')
|
91 |
+
]
|
92 |
+
|
93 |
+
# %%
|
94 |
+
# get tf_model graph
|
95 |
+
model_tf.build_inference()
|
96 |
+
|
97 |
+
# get tf variables
|
98 |
+
tf_vars = model_tf.weights
|
99 |
+
|
100 |
+
# match variable names with fuzzy logic
|
101 |
+
torch_var_names = list(state_dict.keys())
|
102 |
+
tf_var_names = [we.name for we in model_tf.weights]
|
103 |
+
for tf_name in tf_var_names:
|
104 |
+
# skip re-mapped layer names
|
105 |
+
if tf_name in [name[0] for name in var_map]:
|
106 |
+
continue
|
107 |
+
tf_name_edited = convert_tf_name(tf_name)
|
108 |
+
ratios = [
|
109 |
+
SequenceMatcher(None, torch_name, tf_name_edited).ratio()
|
110 |
+
for torch_name in torch_var_names
|
111 |
+
]
|
112 |
+
max_idx = np.argmax(ratios)
|
113 |
+
matching_name = torch_var_names[max_idx]
|
114 |
+
del torch_var_names[max_idx]
|
115 |
+
var_map.append((tf_name, matching_name))
|
116 |
+
|
117 |
+
pprint(var_map)
|
118 |
+
pprint(torch_var_names)
|
119 |
+
|
120 |
+
# pass weights
|
121 |
+
tf_vars = transfer_weights_torch_to_tf(tf_vars, dict(var_map), state_dict)
|
122 |
+
|
123 |
+
# Compare TF and TORCH models
|
124 |
+
# %%
|
125 |
+
# check embedding outputs
|
126 |
+
model.eval()
|
127 |
+
input_ids = torch.randint(0, 24, (1, 128)).long()
|
128 |
+
|
129 |
+
o_t = model.embedding(input_ids)
|
130 |
+
o_tf = model_tf.embedding(input_ids.detach().numpy())
|
131 |
+
assert abs(o_t.detach().numpy() -
|
132 |
+
o_tf.numpy()).sum() < 1e-5, abs(o_t.detach().numpy() -
|
133 |
+
o_tf.numpy()).sum()
|
134 |
+
|
135 |
+
# compare encoder outputs
|
136 |
+
oo_en = model.encoder.inference(o_t.transpose(1, 2))
|
137 |
+
ooo_en = model_tf.encoder(o_t.detach().numpy(), training=False)
|
138 |
+
assert compare_torch_tf(oo_en, ooo_en) < 1e-5
|
139 |
+
|
140 |
+
#pylint: disable=redefined-builtin
|
141 |
+
# compare decoder.attention_rnn
|
142 |
+
inp = torch.rand([1, 768])
|
143 |
+
inp_tf = inp.numpy()
|
144 |
+
model.decoder._init_states(oo_en, mask=None) #pylint: disable=protected-access
|
145 |
+
output, cell_state = model.decoder.attention_rnn(inp)
|
146 |
+
states = model_tf.decoder.build_decoder_initial_states(1, 512, 128)
|
147 |
+
output_tf, memory_state = model_tf.decoder.attention_rnn(inp_tf,
|
148 |
+
states[2],
|
149 |
+
training=False)
|
150 |
+
assert compare_torch_tf(output, output_tf).mean() < 1e-5
|
151 |
+
|
152 |
+
query = output
|
153 |
+
inputs = torch.rand([1, 128, 512])
|
154 |
+
query_tf = query.detach().numpy()
|
155 |
+
inputs_tf = inputs.numpy()
|
156 |
+
|
157 |
+
# compare decoder.attention
|
158 |
+
model.decoder.attention.init_states(inputs)
|
159 |
+
processes_inputs = model.decoder.attention.preprocess_inputs(inputs)
|
160 |
+
loc_attn, proc_query = model.decoder.attention.get_location_attention(
|
161 |
+
query, processes_inputs)
|
162 |
+
context = model.decoder.attention(query, inputs, processes_inputs, None)
|
163 |
+
|
164 |
+
attention_states = model_tf.decoder.build_decoder_initial_states(1, 512, 128)[-1]
|
165 |
+
model_tf.decoder.attention.process_values(tf.convert_to_tensor(inputs_tf))
|
166 |
+
loc_attn_tf, proc_query_tf = model_tf.decoder.attention.get_loc_attn(query_tf, attention_states)
|
167 |
+
context_tf, attention, attention_states = model_tf.decoder.attention(query_tf, attention_states, training=False)
|
168 |
+
|
169 |
+
assert compare_torch_tf(loc_attn, loc_attn_tf).mean() < 1e-5
|
170 |
+
assert compare_torch_tf(proc_query, proc_query_tf).mean() < 1e-5
|
171 |
+
assert compare_torch_tf(context, context_tf) < 1e-5
|
172 |
+
|
173 |
+
# compare decoder.decoder_rnn
|
174 |
+
input = torch.rand([1, 1536])
|
175 |
+
input_tf = input.numpy()
|
176 |
+
model.decoder._init_states(oo_en, mask=None) #pylint: disable=protected-access
|
177 |
+
output, cell_state = model.decoder.decoder_rnn(
|
178 |
+
input, [model.decoder.decoder_hidden, model.decoder.decoder_cell])
|
179 |
+
states = model_tf.decoder.build_decoder_initial_states(1, 512, 128)
|
180 |
+
output_tf, memory_state = model_tf.decoder.decoder_rnn(input_tf,
|
181 |
+
states[3],
|
182 |
+
training=False)
|
183 |
+
assert abs(input - input_tf).mean() < 1e-5
|
184 |
+
assert compare_torch_tf(output, output_tf).mean() < 1e-5
|
185 |
+
|
186 |
+
# compare decoder.linear_projection
|
187 |
+
input = torch.rand([1, 1536])
|
188 |
+
input_tf = input.numpy()
|
189 |
+
output = model.decoder.linear_projection(input)
|
190 |
+
output_tf = model_tf.decoder.linear_projection(input_tf, training=False)
|
191 |
+
assert compare_torch_tf(output, output_tf) < 1e-5
|
192 |
+
|
193 |
+
# compare decoder outputs
|
194 |
+
model.decoder.max_decoder_steps = 100
|
195 |
+
model_tf.decoder.set_max_decoder_steps(100)
|
196 |
+
output, align, stop = model.decoder.inference(oo_en)
|
197 |
+
states = model_tf.decoder.build_decoder_initial_states(1, 512, 128)
|
198 |
+
output_tf, align_tf, stop_tf = model_tf.decoder(ooo_en, states, training=False)
|
199 |
+
assert compare_torch_tf(output.transpose(1, 2), output_tf) < 1e-4
|
200 |
+
|
201 |
+
# compare the whole model output
|
202 |
+
outputs_torch = model.inference(input_ids)
|
203 |
+
outputs_tf = model_tf(tf.convert_to_tensor(input_ids.numpy()))
|
204 |
+
print(abs(outputs_torch[0].numpy()[:, 0] - outputs_tf[0].numpy()[:, 0]).mean())
|
205 |
+
assert compare_torch_tf(outputs_torch[2][:, 50, :],
|
206 |
+
outputs_tf[2][:, 50, :]) < 1e-5
|
207 |
+
assert compare_torch_tf(outputs_torch[0], outputs_tf[0]) < 1e-4
|
208 |
+
|
209 |
+
# %%
|
210 |
+
# save tf model
|
211 |
+
save_checkpoint(model_tf, None, checkpoint['step'], checkpoint['epoch'],
|
212 |
+
checkpoint['r'], args.output_path)
|
213 |
+
print(' > Model conversion is successfully completed :).')
|
TTS/bin/distribute.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
import os
|
5 |
+
import sys
|
6 |
+
import pathlib
|
7 |
+
import time
|
8 |
+
import subprocess
|
9 |
+
import argparse
|
10 |
+
import torch
|
11 |
+
|
12 |
+
|
13 |
+
def main():
|
14 |
+
"""
|
15 |
+
Call train.py as a new process and pass command arguments
|
16 |
+
"""
|
17 |
+
parser = argparse.ArgumentParser()
|
18 |
+
parser.add_argument(
|
19 |
+
'--script',
|
20 |
+
type=str,
|
21 |
+
help='Target training script to distibute.')
|
22 |
+
parser.add_argument(
|
23 |
+
'--continue_path',
|
24 |
+
type=str,
|
25 |
+
help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.',
|
26 |
+
default='',
|
27 |
+
required='--config_path' not in sys.argv)
|
28 |
+
parser.add_argument(
|
29 |
+
'--restore_path',
|
30 |
+
type=str,
|
31 |
+
help='Model file to be restored. Use to finetune a model.',
|
32 |
+
default='')
|
33 |
+
parser.add_argument(
|
34 |
+
'--config_path',
|
35 |
+
type=str,
|
36 |
+
help='Path to config file for training.',
|
37 |
+
required='--continue_path' not in sys.argv
|
38 |
+
)
|
39 |
+
args = parser.parse_args()
|
40 |
+
|
41 |
+
num_gpus = torch.cuda.device_count()
|
42 |
+
group_id = time.strftime("%Y_%m_%d-%H%M%S")
|
43 |
+
|
44 |
+
# set arguments for train.py
|
45 |
+
folder_path = pathlib.Path(__file__).parent.absolute()
|
46 |
+
command = [os.path.join(folder_path, args.script)]
|
47 |
+
command.append('--continue_path={}'.format(args.continue_path))
|
48 |
+
command.append('--restore_path={}'.format(args.restore_path))
|
49 |
+
command.append('--config_path={}'.format(args.config_path))
|
50 |
+
command.append('--group_id=group_{}'.format(group_id))
|
51 |
+
command.append('')
|
52 |
+
|
53 |
+
# run processes
|
54 |
+
processes = []
|
55 |
+
for i in range(num_gpus):
|
56 |
+
my_env = os.environ.copy()
|
57 |
+
my_env["PYTHON_EGG_CACHE"] = "/tmp/tmp{}".format(i)
|
58 |
+
command[-1] = '--rank={}'.format(i)
|
59 |
+
stdout = None if i == 0 else open(os.devnull, 'w')
|
60 |
+
p = subprocess.Popen(['python3'] + command, stdout=stdout, env=my_env)
|
61 |
+
processes.append(p)
|
62 |
+
print(command)
|
63 |
+
|
64 |
+
for p in processes:
|
65 |
+
p.wait()
|
66 |
+
|
67 |
+
|
68 |
+
if __name__ == '__main__':
|
69 |
+
main()
|
TTS/bin/synthesize.py
ADDED
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
import argparse
|
5 |
+
import os
|
6 |
+
import sys
|
7 |
+
import string
|
8 |
+
from argparse import RawTextHelpFormatter
|
9 |
+
# pylint: disable=redefined-outer-name, unused-argument
|
10 |
+
from pathlib import Path
|
11 |
+
|
12 |
+
from TTS.utils.manage import ModelManager
|
13 |
+
from TTS.utils.synthesizer import Synthesizer
|
14 |
+
|
15 |
+
|
16 |
+
def str2bool(v):
|
17 |
+
if isinstance(v, bool):
|
18 |
+
return v
|
19 |
+
if v.lower() in ('yes', 'true', 't', 'y', '1'):
|
20 |
+
return True
|
21 |
+
if v.lower() in ('no', 'false', 'f', 'n', '0'):
|
22 |
+
return False
|
23 |
+
raise argparse.ArgumentTypeError('Boolean value expected.')
|
24 |
+
|
25 |
+
|
26 |
+
def main():
|
27 |
+
# pylint: disable=bad-continuation
|
28 |
+
parser = argparse.ArgumentParser(description='''Synthesize speech on command line.\n\n'''
|
29 |
+
|
30 |
+
'''You can either use your trained model or choose a model from the provided list.\n'''\
|
31 |
+
|
32 |
+
'''
|
33 |
+
Example runs:
|
34 |
+
|
35 |
+
# list provided models
|
36 |
+
./TTS/bin/synthesize.py --list_models
|
37 |
+
|
38 |
+
# run a model from the list
|
39 |
+
./TTS/bin/synthesize.py --text "Text for TTS" --model_name "<language>/<dataset>/<model_name>" --vocoder_name "<language>/<dataset>/<model_name>" --output_path
|
40 |
+
|
41 |
+
# run your own TTS model (Using Griffin-Lim Vocoder)
|
42 |
+
./TTS/bin/synthesize.py --text "Text for TTS" --model_path path/to/model.pth.tar --config_path path/to/config.json --out_path output/path/speech.wav
|
43 |
+
|
44 |
+
# run your own TTS and Vocoder models
|
45 |
+
./TTS/bin/synthesize.py --text "Text for TTS" --model_path path/to/config.json --config_path path/to/model.pth.tar --out_path output/path/speech.wav
|
46 |
+
--vocoder_path path/to/vocoder.pth.tar --vocoder_config_path path/to/vocoder_config.json
|
47 |
+
|
48 |
+
''',
|
49 |
+
formatter_class=RawTextHelpFormatter)
|
50 |
+
|
51 |
+
parser.add_argument(
|
52 |
+
'--list_models',
|
53 |
+
type=str2bool,
|
54 |
+
nargs='?',
|
55 |
+
const=True,
|
56 |
+
default=False,
|
57 |
+
help='list available pre-trained tts and vocoder models.'
|
58 |
+
)
|
59 |
+
parser.add_argument(
|
60 |
+
'--text',
|
61 |
+
type=str,
|
62 |
+
default=None,
|
63 |
+
help='Text to generate speech.'
|
64 |
+
)
|
65 |
+
|
66 |
+
# Args for running pre-trained TTS models.
|
67 |
+
parser.add_argument(
|
68 |
+
'--model_name',
|
69 |
+
type=str,
|
70 |
+
default=None,
|
71 |
+
help=
|
72 |
+
'Name of one of the pre-trained tts models in format <language>/<dataset>/<model_name>'
|
73 |
+
)
|
74 |
+
parser.add_argument(
|
75 |
+
'--vocoder_name',
|
76 |
+
type=str,
|
77 |
+
default=None,
|
78 |
+
help=
|
79 |
+
'Name of one of the pre-trained vocoder models in format <language>/<dataset>/<model_name>'
|
80 |
+
)
|
81 |
+
|
82 |
+
# Args for running custom models
|
83 |
+
parser.add_argument(
|
84 |
+
'--config_path',
|
85 |
+
default=None,
|
86 |
+
type=str,
|
87 |
+
help='Path to model config file.'
|
88 |
+
)
|
89 |
+
parser.add_argument(
|
90 |
+
'--model_path',
|
91 |
+
type=str,
|
92 |
+
default=None,
|
93 |
+
help='Path to model file.',
|
94 |
+
)
|
95 |
+
parser.add_argument(
|
96 |
+
'--out_path',
|
97 |
+
type=str,
|
98 |
+
default=Path(__file__).resolve().parent,
|
99 |
+
help='Path to save final wav file. Wav file will be named as the given text.',
|
100 |
+
)
|
101 |
+
parser.add_argument(
|
102 |
+
'--use_cuda',
|
103 |
+
type=bool,
|
104 |
+
help='Run model on CUDA.',
|
105 |
+
default=False
|
106 |
+
)
|
107 |
+
parser.add_argument(
|
108 |
+
'--vocoder_path',
|
109 |
+
type=str,
|
110 |
+
help=
|
111 |
+
'Path to vocoder model file. If it is not defined, model uses GL as vocoder. Please make sure that you installed vocoder library before (WaveRNN).',
|
112 |
+
default=None,
|
113 |
+
)
|
114 |
+
parser.add_argument(
|
115 |
+
'--vocoder_config_path',
|
116 |
+
type=str,
|
117 |
+
help='Path to vocoder model config file.',
|
118 |
+
default=None)
|
119 |
+
|
120 |
+
# args for multi-speaker synthesis
|
121 |
+
parser.add_argument(
|
122 |
+
'--speakers_json',
|
123 |
+
type=str,
|
124 |
+
help="JSON file for multi-speaker model.",
|
125 |
+
default=None)
|
126 |
+
parser.add_argument(
|
127 |
+
'--speaker_idx',
|
128 |
+
type=str,
|
129 |
+
help="if the tts model is trained with x-vectors, then speaker_idx is a file present in speakers.json else speaker_idx is the speaker id corresponding to a speaker in the speaker embedding layer.",
|
130 |
+
default=None)
|
131 |
+
parser.add_argument(
|
132 |
+
'--gst_style',
|
133 |
+
help="Wav path file for GST stylereference.",
|
134 |
+
default=None)
|
135 |
+
|
136 |
+
# aux args
|
137 |
+
parser.add_argument(
|
138 |
+
'--save_spectogram',
|
139 |
+
type=bool,
|
140 |
+
help="If true save raw spectogram for further (vocoder) processing in out_path.",
|
141 |
+
default=False)
|
142 |
+
|
143 |
+
args = parser.parse_args()
|
144 |
+
|
145 |
+
# load model manager
|
146 |
+
path = Path(__file__).parent / "../.models.json"
|
147 |
+
manager = ModelManager(path)
|
148 |
+
|
149 |
+
model_path = None
|
150 |
+
config_path = None
|
151 |
+
vocoder_path = None
|
152 |
+
vocoder_config_path = None
|
153 |
+
|
154 |
+
# CASE1: list pre-trained TTS models
|
155 |
+
if args.list_models:
|
156 |
+
manager.list_models()
|
157 |
+
sys.exit()
|
158 |
+
|
159 |
+
# CASE2: load pre-trained models
|
160 |
+
if args.model_name is not None:
|
161 |
+
model_path, config_path = manager.download_model(args.model_name)
|
162 |
+
|
163 |
+
if args.vocoder_name is not None:
|
164 |
+
vocoder_path, vocoder_config_path = manager.download_model(args.vocoder_name)
|
165 |
+
|
166 |
+
# CASE3: load custome models
|
167 |
+
if args.model_path is not None:
|
168 |
+
model_path = args.model_path
|
169 |
+
config_path = args.config_path
|
170 |
+
|
171 |
+
if args.vocoder_path is not None:
|
172 |
+
vocoder_path = args.vocoder_path
|
173 |
+
vocoder_config_path = args.vocoder_config_path
|
174 |
+
|
175 |
+
# RUN THE SYNTHESIS
|
176 |
+
# load models
|
177 |
+
synthesizer = Synthesizer(model_path, config_path, vocoder_path, vocoder_config_path, args.use_cuda)
|
178 |
+
|
179 |
+
use_griffin_lim = vocoder_path is None
|
180 |
+
print(" > Text: {}".format(args.text))
|
181 |
+
|
182 |
+
# # handle multi-speaker setting
|
183 |
+
# if not model_config.use_external_speaker_embedding_file and args.speaker_idx is not None:
|
184 |
+
# if args.speaker_idx.isdigit():
|
185 |
+
# args.speaker_idx = int(args.speaker_idx)
|
186 |
+
# else:
|
187 |
+
# args.speaker_idx = None
|
188 |
+
# else:
|
189 |
+
# args.speaker_idx = None
|
190 |
+
|
191 |
+
# if args.gst_style is None:
|
192 |
+
# if 'gst' in model_config.keys() and model_config.gst['gst_style_input'] is not None:
|
193 |
+
# gst_style = model_config.gst['gst_style_input']
|
194 |
+
# else:
|
195 |
+
# gst_style = None
|
196 |
+
# else:
|
197 |
+
# # check if gst_style string is a dict, if is dict convert else use string
|
198 |
+
# try:
|
199 |
+
# gst_style = json.loads(args.gst_style)
|
200 |
+
# if max(map(int, gst_style.keys())) >= model_config.gst['gst_style_tokens']:
|
201 |
+
# raise RuntimeError("The highest value of the gst_style dictionary key must be less than the number of GST Tokens, \n Highest dictionary key value: {} \n Number of GST tokens: {}".format(max(map(int, gst_style.keys())), model_config.gst['gst_style_tokens']))
|
202 |
+
# except ValueError:
|
203 |
+
# gst_style = args.gst_style
|
204 |
+
|
205 |
+
# kick it
|
206 |
+
wav = synthesizer.tts(args.text)
|
207 |
+
|
208 |
+
# save the results
|
209 |
+
file_name = args.text.replace(" ", "_")[0:20]
|
210 |
+
file_name = file_name.translate(
|
211 |
+
str.maketrans('', '', string.punctuation.replace('_', ''))) + '.wav'
|
212 |
+
out_path = os.path.join(args.out_path, file_name)
|
213 |
+
print(" > Saving output to {}".format(out_path))
|
214 |
+
synthesizer.save_wav(wav, out_path)
|
215 |
+
|
216 |
+
|
217 |
+
if __name__ == "__main__":
|
218 |
+
main()
|
TTS/bin/train_encoder.py
ADDED
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
import argparse
|
5 |
+
import os
|
6 |
+
import sys
|
7 |
+
import time
|
8 |
+
import traceback
|
9 |
+
|
10 |
+
import torch
|
11 |
+
from torch.utils.data import DataLoader
|
12 |
+
from TTS.speaker_encoder.dataset import MyDataset
|
13 |
+
from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss
|
14 |
+
from TTS.speaker_encoder.model import SpeakerEncoder
|
15 |
+
from TTS.speaker_encoder.utils.generic_utils import \
|
16 |
+
check_config_speaker_encoder, save_best_model
|
17 |
+
from TTS.speaker_encoder.utils.visual import plot_embeddings
|
18 |
+
from TTS.tts.datasets.preprocess import load_meta_data
|
19 |
+
from TTS.utils.audio import AudioProcessor
|
20 |
+
from TTS.utils.generic_utils import (count_parameters,
|
21 |
+
create_experiment_folder, get_git_branch,
|
22 |
+
remove_experiment_folder, set_init_dict)
|
23 |
+
from TTS.utils.io import copy_model_files, load_config
|
24 |
+
from TTS.utils.radam import RAdam
|
25 |
+
from TTS.utils.tensorboard_logger import TensorboardLogger
|
26 |
+
from TTS.utils.training import NoamLR, check_update
|
27 |
+
|
28 |
+
torch.backends.cudnn.enabled = True
|
29 |
+
torch.backends.cudnn.benchmark = True
|
30 |
+
torch.manual_seed(54321)
|
31 |
+
use_cuda = torch.cuda.is_available()
|
32 |
+
num_gpus = torch.cuda.device_count()
|
33 |
+
print(" > Using CUDA: ", use_cuda)
|
34 |
+
print(" > Number of GPUs: ", num_gpus)
|
35 |
+
|
36 |
+
|
37 |
+
def setup_loader(ap: AudioProcessor, is_val: bool=False, verbose: bool=False):
|
38 |
+
if is_val:
|
39 |
+
loader = None
|
40 |
+
else:
|
41 |
+
dataset = MyDataset(ap,
|
42 |
+
meta_data_eval if is_val else meta_data_train,
|
43 |
+
voice_len=1.6,
|
44 |
+
num_utter_per_speaker=c.num_utters_per_speaker,
|
45 |
+
num_speakers_in_batch=c.num_speakers_in_batch,
|
46 |
+
skip_speakers=False,
|
47 |
+
storage_size=c.storage["storage_size"],
|
48 |
+
sample_from_storage_p=c.storage["sample_from_storage_p"],
|
49 |
+
additive_noise=c.storage["additive_noise"],
|
50 |
+
verbose=verbose)
|
51 |
+
# sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
52 |
+
loader = DataLoader(dataset,
|
53 |
+
batch_size=c.num_speakers_in_batch,
|
54 |
+
shuffle=False,
|
55 |
+
num_workers=c.num_loader_workers,
|
56 |
+
collate_fn=dataset.collate_fn)
|
57 |
+
return loader
|
58 |
+
|
59 |
+
|
60 |
+
def train(model, criterion, optimizer, scheduler, ap, global_step):
|
61 |
+
data_loader = setup_loader(ap, is_val=False, verbose=True)
|
62 |
+
model.train()
|
63 |
+
epoch_time = 0
|
64 |
+
best_loss = float('inf')
|
65 |
+
avg_loss = 0
|
66 |
+
avg_loader_time = 0
|
67 |
+
end_time = time.time()
|
68 |
+
for _, data in enumerate(data_loader):
|
69 |
+
start_time = time.time()
|
70 |
+
|
71 |
+
# setup input data
|
72 |
+
inputs = data[0]
|
73 |
+
loader_time = time.time() - end_time
|
74 |
+
global_step += 1
|
75 |
+
|
76 |
+
# setup lr
|
77 |
+
if c.lr_decay:
|
78 |
+
scheduler.step()
|
79 |
+
optimizer.zero_grad()
|
80 |
+
|
81 |
+
# dispatch data to GPU
|
82 |
+
if use_cuda:
|
83 |
+
inputs = inputs.cuda(non_blocking=True)
|
84 |
+
# labels = labels.cuda(non_blocking=True)
|
85 |
+
|
86 |
+
# forward pass model
|
87 |
+
outputs = model(inputs)
|
88 |
+
|
89 |
+
# loss computation
|
90 |
+
loss = criterion(
|
91 |
+
outputs.view(c.num_speakers_in_batch,
|
92 |
+
outputs.shape[0] // c.num_speakers_in_batch, -1))
|
93 |
+
loss.backward()
|
94 |
+
grad_norm, _ = check_update(model, c.grad_clip)
|
95 |
+
optimizer.step()
|
96 |
+
|
97 |
+
step_time = time.time() - start_time
|
98 |
+
epoch_time += step_time
|
99 |
+
|
100 |
+
# Averaged Loss and Averaged Loader Time
|
101 |
+
avg_loss = 0.01 * loss.item() \
|
102 |
+
+ 0.99 * avg_loss if avg_loss != 0 else loss.item()
|
103 |
+
avg_loader_time = 1/c.num_loader_workers * loader_time + \
|
104 |
+
(c.num_loader_workers-1) / c.num_loader_workers * avg_loader_time if avg_loader_time != 0 else loader_time
|
105 |
+
current_lr = optimizer.param_groups[0]['lr']
|
106 |
+
|
107 |
+
if global_step % c.steps_plot_stats == 0:
|
108 |
+
# Plot Training Epoch Stats
|
109 |
+
train_stats = {
|
110 |
+
"loss": avg_loss,
|
111 |
+
"lr": current_lr,
|
112 |
+
"grad_norm": grad_norm,
|
113 |
+
"step_time": step_time,
|
114 |
+
"avg_loader_time": avg_loader_time
|
115 |
+
}
|
116 |
+
tb_logger.tb_train_epoch_stats(global_step, train_stats)
|
117 |
+
figures = {
|
118 |
+
# FIXME: not constant
|
119 |
+
"UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(),
|
120 |
+
10),
|
121 |
+
}
|
122 |
+
tb_logger.tb_train_figures(global_step, figures)
|
123 |
+
|
124 |
+
if global_step % c.print_step == 0:
|
125 |
+
print(
|
126 |
+
" | > Step:{} Loss:{:.5f} AvgLoss:{:.5f} GradNorm:{:.5f} "
|
127 |
+
"StepTime:{:.2f} LoaderTime:{:.2f} AvGLoaderTime:{:.2f} LR:{:.6f}".format(
|
128 |
+
global_step, loss.item(), avg_loss, grad_norm, step_time,
|
129 |
+
loader_time, avg_loader_time, current_lr),
|
130 |
+
flush=True)
|
131 |
+
|
132 |
+
# save best model
|
133 |
+
best_loss = save_best_model(model, optimizer, avg_loss, best_loss,
|
134 |
+
OUT_PATH, global_step)
|
135 |
+
|
136 |
+
end_time = time.time()
|
137 |
+
return avg_loss, global_step
|
138 |
+
|
139 |
+
|
140 |
+
def main(args): # pylint: disable=redefined-outer-name
|
141 |
+
# pylint: disable=global-variable-undefined
|
142 |
+
global meta_data_train
|
143 |
+
global meta_data_eval
|
144 |
+
|
145 |
+
ap = AudioProcessor(**c.audio)
|
146 |
+
model = SpeakerEncoder(input_dim=c.model['input_dim'],
|
147 |
+
proj_dim=c.model['proj_dim'],
|
148 |
+
lstm_dim=c.model['lstm_dim'],
|
149 |
+
num_lstm_layers=c.model['num_lstm_layers'])
|
150 |
+
optimizer = RAdam(model.parameters(), lr=c.lr)
|
151 |
+
|
152 |
+
if c.loss == "ge2e":
|
153 |
+
criterion = GE2ELoss(loss_method='softmax')
|
154 |
+
elif c.loss == "angleproto":
|
155 |
+
criterion = AngleProtoLoss()
|
156 |
+
else:
|
157 |
+
raise Exception("The %s not is a loss supported" % c.loss)
|
158 |
+
|
159 |
+
if args.restore_path:
|
160 |
+
checkpoint = torch.load(args.restore_path)
|
161 |
+
try:
|
162 |
+
# TODO: fix optimizer init, model.cuda() needs to be called before
|
163 |
+
# optimizer restore
|
164 |
+
# optimizer.load_state_dict(checkpoint['optimizer'])
|
165 |
+
if c.reinit_layers:
|
166 |
+
raise RuntimeError
|
167 |
+
model.load_state_dict(checkpoint['model'])
|
168 |
+
except KeyError:
|
169 |
+
print(" > Partial model initialization.")
|
170 |
+
model_dict = model.state_dict()
|
171 |
+
model_dict = set_init_dict(model_dict, checkpoint, c)
|
172 |
+
model.load_state_dict(model_dict)
|
173 |
+
del model_dict
|
174 |
+
for group in optimizer.param_groups:
|
175 |
+
group['lr'] = c.lr
|
176 |
+
print(" > Model restored from step %d" % checkpoint['step'],
|
177 |
+
flush=True)
|
178 |
+
args.restore_step = checkpoint['step']
|
179 |
+
else:
|
180 |
+
args.restore_step = 0
|
181 |
+
|
182 |
+
if use_cuda:
|
183 |
+
model = model.cuda()
|
184 |
+
criterion.cuda()
|
185 |
+
|
186 |
+
if c.lr_decay:
|
187 |
+
scheduler = NoamLR(optimizer,
|
188 |
+
warmup_steps=c.warmup_steps,
|
189 |
+
last_epoch=args.restore_step - 1)
|
190 |
+
else:
|
191 |
+
scheduler = None
|
192 |
+
|
193 |
+
num_params = count_parameters(model)
|
194 |
+
print("\n > Model has {} parameters".format(num_params), flush=True)
|
195 |
+
|
196 |
+
# pylint: disable=redefined-outer-name
|
197 |
+
meta_data_train, meta_data_eval = load_meta_data(c.datasets)
|
198 |
+
|
199 |
+
global_step = args.restore_step
|
200 |
+
_, global_step = train(model, criterion, optimizer, scheduler, ap,
|
201 |
+
global_step)
|
202 |
+
|
203 |
+
|
204 |
+
if __name__ == '__main__':
|
205 |
+
parser = argparse.ArgumentParser()
|
206 |
+
parser.add_argument(
|
207 |
+
'--restore_path',
|
208 |
+
type=str,
|
209 |
+
help='Path to model outputs (checkpoint, tensorboard etc.).',
|
210 |
+
default=0)
|
211 |
+
parser.add_argument(
|
212 |
+
'--config_path',
|
213 |
+
type=str,
|
214 |
+
required=True,
|
215 |
+
help='Path to config file for training.',
|
216 |
+
)
|
217 |
+
parser.add_argument('--debug',
|
218 |
+
type=bool,
|
219 |
+
default=True,
|
220 |
+
help='Do not verify commit integrity to run training.')
|
221 |
+
parser.add_argument(
|
222 |
+
'--data_path',
|
223 |
+
type=str,
|
224 |
+
default='',
|
225 |
+
help='Defines the data path. It overwrites config.json.')
|
226 |
+
parser.add_argument('--output_path',
|
227 |
+
type=str,
|
228 |
+
help='path for training outputs.',
|
229 |
+
default='')
|
230 |
+
parser.add_argument('--output_folder',
|
231 |
+
type=str,
|
232 |
+
default='',
|
233 |
+
help='folder name for training outputs.')
|
234 |
+
args = parser.parse_args()
|
235 |
+
|
236 |
+
# setup output paths and read configs
|
237 |
+
c = load_config(args.config_path)
|
238 |
+
check_config_speaker_encoder(c)
|
239 |
+
_ = os.path.dirname(os.path.realpath(__file__))
|
240 |
+
if args.data_path != '':
|
241 |
+
c.data_path = args.data_path
|
242 |
+
|
243 |
+
if args.output_path == '':
|
244 |
+
OUT_PATH = os.path.join(_, c.output_path)
|
245 |
+
else:
|
246 |
+
OUT_PATH = args.output_path
|
247 |
+
|
248 |
+
if args.output_folder == '':
|
249 |
+
OUT_PATH = create_experiment_folder(OUT_PATH, c.run_name, args.debug)
|
250 |
+
else:
|
251 |
+
OUT_PATH = os.path.join(OUT_PATH, args.output_folder)
|
252 |
+
|
253 |
+
new_fields = {}
|
254 |
+
if args.restore_path:
|
255 |
+
new_fields["restore_path"] = args.restore_path
|
256 |
+
new_fields["github_branch"] = get_git_branch()
|
257 |
+
copy_model_files(c, args.config_path, OUT_PATH,
|
258 |
+
new_fields)
|
259 |
+
|
260 |
+
LOG_DIR = OUT_PATH
|
261 |
+
tb_logger = TensorboardLogger(LOG_DIR, model_name='Speaker_Encoder')
|
262 |
+
|
263 |
+
try:
|
264 |
+
main(args)
|
265 |
+
except KeyboardInterrupt:
|
266 |
+
remove_experiment_folder(OUT_PATH)
|
267 |
+
try:
|
268 |
+
sys.exit(0)
|
269 |
+
except SystemExit:
|
270 |
+
os._exit(0) # pylint: disable=protected-access
|
271 |
+
except Exception: # pylint: disable=broad-except
|
272 |
+
remove_experiment_folder(OUT_PATH)
|
273 |
+
traceback.print_exc()
|
274 |
+
sys.exit(1)
|
TTS/bin/train_glow_tts.py
ADDED
@@ -0,0 +1,657 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
import argparse
|
5 |
+
import glob
|
6 |
+
import os
|
7 |
+
import sys
|
8 |
+
import time
|
9 |
+
import traceback
|
10 |
+
from random import randrange
|
11 |
+
|
12 |
+
import torch
|
13 |
+
# DISTRIBUTED
|
14 |
+
from torch.nn.parallel import DistributedDataParallel as DDP_th
|
15 |
+
from torch.utils.data import DataLoader
|
16 |
+
from torch.utils.data.distributed import DistributedSampler
|
17 |
+
from TTS.tts.datasets.preprocess import load_meta_data
|
18 |
+
from TTS.tts.datasets.TTSDataset import MyDataset
|
19 |
+
from TTS.tts.layers.losses import GlowTTSLoss
|
20 |
+
from TTS.tts.utils.generic_utils import check_config_tts, setup_model
|
21 |
+
from TTS.tts.utils.io import save_best_model, save_checkpoint
|
22 |
+
from TTS.tts.utils.measures import alignment_diagonal_score
|
23 |
+
from TTS.tts.utils.speakers import parse_speakers
|
24 |
+
from TTS.tts.utils.synthesis import synthesis
|
25 |
+
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
|
26 |
+
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
27 |
+
from TTS.utils.audio import AudioProcessor
|
28 |
+
from TTS.utils.console_logger import ConsoleLogger
|
29 |
+
from TTS.utils.distribute import init_distributed, reduce_tensor
|
30 |
+
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
|
31 |
+
create_experiment_folder, get_git_branch,
|
32 |
+
remove_experiment_folder, set_init_dict)
|
33 |
+
from TTS.utils.io import copy_model_files, load_config
|
34 |
+
from TTS.utils.radam import RAdam
|
35 |
+
from TTS.utils.tensorboard_logger import TensorboardLogger
|
36 |
+
from TTS.utils.training import NoamLR, setup_torch_training_env
|
37 |
+
|
38 |
+
use_cuda, num_gpus = setup_torch_training_env(True, False)
|
39 |
+
|
40 |
+
def setup_loader(ap, r, is_val=False, verbose=False):
|
41 |
+
if is_val and not c.run_eval:
|
42 |
+
loader = None
|
43 |
+
else:
|
44 |
+
dataset = MyDataset(
|
45 |
+
r,
|
46 |
+
c.text_cleaner,
|
47 |
+
compute_linear_spec=False,
|
48 |
+
meta_data=meta_data_eval if is_val else meta_data_train,
|
49 |
+
ap=ap,
|
50 |
+
tp=c.characters if 'characters' in c.keys() else None,
|
51 |
+
add_blank=c['add_blank'] if 'add_blank' in c.keys() else False,
|
52 |
+
batch_group_size=0 if is_val else c.batch_group_size *
|
53 |
+
c.batch_size,
|
54 |
+
min_seq_len=c.min_seq_len,
|
55 |
+
max_seq_len=c.max_seq_len,
|
56 |
+
phoneme_cache_path=c.phoneme_cache_path,
|
57 |
+
use_phonemes=c.use_phonemes,
|
58 |
+
phoneme_language=c.phoneme_language,
|
59 |
+
enable_eos_bos=c.enable_eos_bos_chars,
|
60 |
+
use_noise_augment=c['use_noise_augment'] and not is_val,
|
61 |
+
verbose=verbose,
|
62 |
+
speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None)
|
63 |
+
|
64 |
+
if c.use_phonemes and c.compute_input_seq_cache:
|
65 |
+
# precompute phonemes to have a better estimate of sequence lengths.
|
66 |
+
dataset.compute_input_seq(c.num_loader_workers)
|
67 |
+
dataset.sort_items()
|
68 |
+
|
69 |
+
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
70 |
+
loader = DataLoader(
|
71 |
+
dataset,
|
72 |
+
batch_size=c.eval_batch_size if is_val else c.batch_size,
|
73 |
+
shuffle=False,
|
74 |
+
collate_fn=dataset.collate_fn,
|
75 |
+
drop_last=False,
|
76 |
+
sampler=sampler,
|
77 |
+
num_workers=c.num_val_loader_workers
|
78 |
+
if is_val else c.num_loader_workers,
|
79 |
+
pin_memory=False)
|
80 |
+
return loader
|
81 |
+
|
82 |
+
|
83 |
+
def format_data(data):
|
84 |
+
# setup input data
|
85 |
+
text_input = data[0]
|
86 |
+
text_lengths = data[1]
|
87 |
+
speaker_names = data[2]
|
88 |
+
mel_input = data[4].permute(0, 2, 1) # B x D x T
|
89 |
+
mel_lengths = data[5]
|
90 |
+
item_idx = data[7]
|
91 |
+
attn_mask = data[9]
|
92 |
+
avg_text_length = torch.mean(text_lengths.float())
|
93 |
+
avg_spec_length = torch.mean(mel_lengths.float())
|
94 |
+
|
95 |
+
if c.use_speaker_embedding:
|
96 |
+
if c.use_external_speaker_embedding_file:
|
97 |
+
# return precomputed embedding vector
|
98 |
+
speaker_c = data[8]
|
99 |
+
else:
|
100 |
+
# return speaker_id to be used by an embedding layer
|
101 |
+
speaker_c = [
|
102 |
+
speaker_mapping[speaker_name] for speaker_name in speaker_names
|
103 |
+
]
|
104 |
+
speaker_c = torch.LongTensor(speaker_c)
|
105 |
+
else:
|
106 |
+
speaker_c = None
|
107 |
+
|
108 |
+
# dispatch data to GPU
|
109 |
+
if use_cuda:
|
110 |
+
text_input = text_input.cuda(non_blocking=True)
|
111 |
+
text_lengths = text_lengths.cuda(non_blocking=True)
|
112 |
+
mel_input = mel_input.cuda(non_blocking=True)
|
113 |
+
mel_lengths = mel_lengths.cuda(non_blocking=True)
|
114 |
+
if speaker_c is not None:
|
115 |
+
speaker_c = speaker_c.cuda(non_blocking=True)
|
116 |
+
if attn_mask is not None:
|
117 |
+
attn_mask = attn_mask.cuda(non_blocking=True)
|
118 |
+
return text_input, text_lengths, mel_input, mel_lengths, speaker_c,\
|
119 |
+
avg_text_length, avg_spec_length, attn_mask, item_idx
|
120 |
+
|
121 |
+
|
122 |
+
def data_depended_init(data_loader, model, ap):
|
123 |
+
"""Data depended initialization for activation normalization."""
|
124 |
+
if hasattr(model, 'module'):
|
125 |
+
for f in model.module.decoder.flows:
|
126 |
+
if getattr(f, "set_ddi", False):
|
127 |
+
f.set_ddi(True)
|
128 |
+
else:
|
129 |
+
for f in model.decoder.flows:
|
130 |
+
if getattr(f, "set_ddi", False):
|
131 |
+
f.set_ddi(True)
|
132 |
+
|
133 |
+
model.train()
|
134 |
+
print(" > Data depended initialization ... ")
|
135 |
+
num_iter = 0
|
136 |
+
with torch.no_grad():
|
137 |
+
for _, data in enumerate(data_loader):
|
138 |
+
|
139 |
+
# format data
|
140 |
+
text_input, text_lengths, mel_input, mel_lengths, spekaer_embed,\
|
141 |
+
_, _, attn_mask, item_idx = format_data(data)
|
142 |
+
|
143 |
+
# forward pass model
|
144 |
+
_ = model.forward(
|
145 |
+
text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=spekaer_embed)
|
146 |
+
if num_iter == c.data_dep_init_iter:
|
147 |
+
break
|
148 |
+
num_iter += 1
|
149 |
+
|
150 |
+
if hasattr(model, 'module'):
|
151 |
+
for f in model.module.decoder.flows:
|
152 |
+
if getattr(f, "set_ddi", False):
|
153 |
+
f.set_ddi(False)
|
154 |
+
else:
|
155 |
+
for f in model.decoder.flows:
|
156 |
+
if getattr(f, "set_ddi", False):
|
157 |
+
f.set_ddi(False)
|
158 |
+
return model
|
159 |
+
|
160 |
+
|
161 |
+
def train(data_loader, model, criterion, optimizer, scheduler,
|
162 |
+
ap, global_step, epoch):
|
163 |
+
|
164 |
+
model.train()
|
165 |
+
epoch_time = 0
|
166 |
+
keep_avg = KeepAverage()
|
167 |
+
if use_cuda:
|
168 |
+
batch_n_iter = int(
|
169 |
+
len(data_loader.dataset) / (c.batch_size * num_gpus))
|
170 |
+
else:
|
171 |
+
batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
|
172 |
+
end_time = time.time()
|
173 |
+
c_logger.print_train_start()
|
174 |
+
scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None
|
175 |
+
for num_iter, data in enumerate(data_loader):
|
176 |
+
start_time = time.time()
|
177 |
+
|
178 |
+
# format data
|
179 |
+
text_input, text_lengths, mel_input, mel_lengths, speaker_c,\
|
180 |
+
avg_text_length, avg_spec_length, attn_mask, item_idx = format_data(data)
|
181 |
+
|
182 |
+
loader_time = time.time() - end_time
|
183 |
+
|
184 |
+
global_step += 1
|
185 |
+
optimizer.zero_grad()
|
186 |
+
|
187 |
+
# forward pass model
|
188 |
+
with torch.cuda.amp.autocast(enabled=c.mixed_precision):
|
189 |
+
z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
|
190 |
+
text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_c)
|
191 |
+
|
192 |
+
# compute loss
|
193 |
+
loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
|
194 |
+
o_dur_log, o_total_dur, text_lengths)
|
195 |
+
|
196 |
+
# backward pass with loss scaling
|
197 |
+
if c.mixed_precision:
|
198 |
+
scaler.scale(loss_dict['loss']).backward()
|
199 |
+
scaler.unscale_(optimizer)
|
200 |
+
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
|
201 |
+
c.grad_clip)
|
202 |
+
scaler.step(optimizer)
|
203 |
+
scaler.update()
|
204 |
+
else:
|
205 |
+
loss_dict['loss'].backward()
|
206 |
+
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
|
207 |
+
c.grad_clip)
|
208 |
+
optimizer.step()
|
209 |
+
|
210 |
+
# setup lr
|
211 |
+
if c.noam_schedule:
|
212 |
+
scheduler.step()
|
213 |
+
|
214 |
+
# current_lr
|
215 |
+
current_lr = optimizer.param_groups[0]['lr']
|
216 |
+
|
217 |
+
# compute alignment error (the lower the better )
|
218 |
+
align_error = 1 - alignment_diagonal_score(alignments, binary=True)
|
219 |
+
loss_dict['align_error'] = align_error
|
220 |
+
|
221 |
+
step_time = time.time() - start_time
|
222 |
+
epoch_time += step_time
|
223 |
+
|
224 |
+
# aggregate losses from processes
|
225 |
+
if num_gpus > 1:
|
226 |
+
loss_dict['log_mle'] = reduce_tensor(loss_dict['log_mle'].data, num_gpus)
|
227 |
+
loss_dict['loss_dur'] = reduce_tensor(loss_dict['loss_dur'].data, num_gpus)
|
228 |
+
loss_dict['loss'] = reduce_tensor(loss_dict['loss'] .data, num_gpus)
|
229 |
+
|
230 |
+
# detach loss values
|
231 |
+
loss_dict_new = dict()
|
232 |
+
for key, value in loss_dict.items():
|
233 |
+
if isinstance(value, (int, float)):
|
234 |
+
loss_dict_new[key] = value
|
235 |
+
else:
|
236 |
+
loss_dict_new[key] = value.item()
|
237 |
+
loss_dict = loss_dict_new
|
238 |
+
|
239 |
+
# update avg stats
|
240 |
+
update_train_values = dict()
|
241 |
+
for key, value in loss_dict.items():
|
242 |
+
update_train_values['avg_' + key] = value
|
243 |
+
update_train_values['avg_loader_time'] = loader_time
|
244 |
+
update_train_values['avg_step_time'] = step_time
|
245 |
+
keep_avg.update_values(update_train_values)
|
246 |
+
|
247 |
+
# print training progress
|
248 |
+
if global_step % c.print_step == 0:
|
249 |
+
log_dict = {
|
250 |
+
"avg_spec_length": [avg_spec_length, 1], # value, precision
|
251 |
+
"avg_text_length": [avg_text_length, 1],
|
252 |
+
"step_time": [step_time, 4],
|
253 |
+
"loader_time": [loader_time, 2],
|
254 |
+
"current_lr": current_lr,
|
255 |
+
}
|
256 |
+
c_logger.print_train_step(batch_n_iter, num_iter, global_step,
|
257 |
+
log_dict, loss_dict, keep_avg.avg_values)
|
258 |
+
|
259 |
+
if args.rank == 0:
|
260 |
+
# Plot Training Iter Stats
|
261 |
+
# reduce TB load
|
262 |
+
if global_step % c.tb_plot_step == 0:
|
263 |
+
iter_stats = {
|
264 |
+
"lr": current_lr,
|
265 |
+
"grad_norm": grad_norm,
|
266 |
+
"step_time": step_time
|
267 |
+
}
|
268 |
+
iter_stats.update(loss_dict)
|
269 |
+
tb_logger.tb_train_iter_stats(global_step, iter_stats)
|
270 |
+
|
271 |
+
if global_step % c.save_step == 0:
|
272 |
+
if c.checkpoint:
|
273 |
+
# save model
|
274 |
+
save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH,
|
275 |
+
model_loss=loss_dict['loss'])
|
276 |
+
|
277 |
+
# wait all kernels to be completed
|
278 |
+
torch.cuda.synchronize()
|
279 |
+
|
280 |
+
# Diagnostic visualizations
|
281 |
+
# direct pass on model for spec predictions
|
282 |
+
target_speaker = None if speaker_c is None else speaker_c[:1]
|
283 |
+
|
284 |
+
if hasattr(model, 'module'):
|
285 |
+
spec_pred, *_ = model.module.inference(text_input[:1], text_lengths[:1], g=target_speaker)
|
286 |
+
else:
|
287 |
+
spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1], g=target_speaker)
|
288 |
+
|
289 |
+
spec_pred = spec_pred.permute(0, 2, 1)
|
290 |
+
gt_spec = mel_input.permute(0, 2, 1)
|
291 |
+
const_spec = spec_pred[0].data.cpu().numpy()
|
292 |
+
gt_spec = gt_spec[0].data.cpu().numpy()
|
293 |
+
align_img = alignments[0].data.cpu().numpy()
|
294 |
+
|
295 |
+
figures = {
|
296 |
+
"prediction": plot_spectrogram(const_spec, ap),
|
297 |
+
"ground_truth": plot_spectrogram(gt_spec, ap),
|
298 |
+
"alignment": plot_alignment(align_img),
|
299 |
+
}
|
300 |
+
|
301 |
+
tb_logger.tb_train_figures(global_step, figures)
|
302 |
+
|
303 |
+
# Sample audio
|
304 |
+
train_audio = ap.inv_melspectrogram(const_spec.T)
|
305 |
+
tb_logger.tb_train_audios(global_step,
|
306 |
+
{'TrainAudio': train_audio},
|
307 |
+
c.audio["sample_rate"])
|
308 |
+
end_time = time.time()
|
309 |
+
|
310 |
+
# print epoch stats
|
311 |
+
c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)
|
312 |
+
|
313 |
+
# Plot Epoch Stats
|
314 |
+
if args.rank == 0:
|
315 |
+
epoch_stats = {"epoch_time": epoch_time}
|
316 |
+
epoch_stats.update(keep_avg.avg_values)
|
317 |
+
tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
|
318 |
+
if c.tb_model_param_stats:
|
319 |
+
tb_logger.tb_model_weights(model, global_step)
|
320 |
+
return keep_avg.avg_values, global_step
|
321 |
+
|
322 |
+
|
323 |
+
@torch.no_grad()
|
324 |
+
def evaluate(data_loader, model, criterion, ap, global_step, epoch):
|
325 |
+
model.eval()
|
326 |
+
epoch_time = 0
|
327 |
+
keep_avg = KeepAverage()
|
328 |
+
c_logger.print_eval_start()
|
329 |
+
if data_loader is not None:
|
330 |
+
for num_iter, data in enumerate(data_loader):
|
331 |
+
start_time = time.time()
|
332 |
+
|
333 |
+
# format data
|
334 |
+
text_input, text_lengths, mel_input, mel_lengths, speaker_c,\
|
335 |
+
_, _, attn_mask, item_idx = format_data(data)
|
336 |
+
|
337 |
+
# forward pass model
|
338 |
+
z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
|
339 |
+
text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_c)
|
340 |
+
|
341 |
+
# compute loss
|
342 |
+
loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
|
343 |
+
o_dur_log, o_total_dur, text_lengths)
|
344 |
+
|
345 |
+
# step time
|
346 |
+
step_time = time.time() - start_time
|
347 |
+
epoch_time += step_time
|
348 |
+
|
349 |
+
# compute alignment score
|
350 |
+
align_error = 1 - alignment_diagonal_score(alignments)
|
351 |
+
loss_dict['align_error'] = align_error
|
352 |
+
|
353 |
+
# aggregate losses from processes
|
354 |
+
if num_gpus > 1:
|
355 |
+
loss_dict['log_mle'] = reduce_tensor(loss_dict['log_mle'].data, num_gpus)
|
356 |
+
loss_dict['loss_dur'] = reduce_tensor(loss_dict['loss_dur'].data, num_gpus)
|
357 |
+
loss_dict['loss'] = reduce_tensor(loss_dict['loss'] .data, num_gpus)
|
358 |
+
|
359 |
+
# detach loss values
|
360 |
+
loss_dict_new = dict()
|
361 |
+
for key, value in loss_dict.items():
|
362 |
+
if isinstance(value, (int, float)):
|
363 |
+
loss_dict_new[key] = value
|
364 |
+
else:
|
365 |
+
loss_dict_new[key] = value.item()
|
366 |
+
loss_dict = loss_dict_new
|
367 |
+
|
368 |
+
# update avg stats
|
369 |
+
update_train_values = dict()
|
370 |
+
for key, value in loss_dict.items():
|
371 |
+
update_train_values['avg_' + key] = value
|
372 |
+
keep_avg.update_values(update_train_values)
|
373 |
+
|
374 |
+
if c.print_eval:
|
375 |
+
c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)
|
376 |
+
|
377 |
+
if args.rank == 0:
|
378 |
+
# Diagnostic visualizations
|
379 |
+
# direct pass on model for spec predictions
|
380 |
+
target_speaker = None if speaker_c is None else speaker_c[:1]
|
381 |
+
if hasattr(model, 'module'):
|
382 |
+
spec_pred, *_ = model.module.inference(text_input[:1], text_lengths[:1], g=target_speaker)
|
383 |
+
else:
|
384 |
+
spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1], g=target_speaker)
|
385 |
+
spec_pred = spec_pred.permute(0, 2, 1)
|
386 |
+
gt_spec = mel_input.permute(0, 2, 1)
|
387 |
+
|
388 |
+
const_spec = spec_pred[0].data.cpu().numpy()
|
389 |
+
gt_spec = gt_spec[0].data.cpu().numpy()
|
390 |
+
align_img = alignments[0].data.cpu().numpy()
|
391 |
+
|
392 |
+
eval_figures = {
|
393 |
+
"prediction": plot_spectrogram(const_spec, ap),
|
394 |
+
"ground_truth": plot_spectrogram(gt_spec, ap),
|
395 |
+
"alignment": plot_alignment(align_img)
|
396 |
+
}
|
397 |
+
|
398 |
+
# Sample audio
|
399 |
+
eval_audio = ap.inv_melspectrogram(const_spec.T)
|
400 |
+
tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio},
|
401 |
+
c.audio["sample_rate"])
|
402 |
+
|
403 |
+
# Plot Validation Stats
|
404 |
+
tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
|
405 |
+
tb_logger.tb_eval_figures(global_step, eval_figures)
|
406 |
+
|
407 |
+
if args.rank == 0 and epoch >= c.test_delay_epochs:
|
408 |
+
if c.test_sentences_file is None:
|
409 |
+
test_sentences = [
|
410 |
+
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
|
411 |
+
"Be a voice, not an echo.",
|
412 |
+
"I'm sorry Dave. I'm afraid I can't do that.",
|
413 |
+
"This cake is great. It's so delicious and moist.",
|
414 |
+
"Prior to November 22, 1963."
|
415 |
+
]
|
416 |
+
else:
|
417 |
+
with open(c.test_sentences_file, "r") as f:
|
418 |
+
test_sentences = [s.strip() for s in f.readlines()]
|
419 |
+
|
420 |
+
# test sentences
|
421 |
+
test_audios = {}
|
422 |
+
test_figures = {}
|
423 |
+
print(" | > Synthesizing test sentences")
|
424 |
+
if c.use_speaker_embedding:
|
425 |
+
if c.use_external_speaker_embedding_file:
|
426 |
+
speaker_embedding = speaker_mapping[list(speaker_mapping.keys())[randrange(len(speaker_mapping)-1)]]['embedding']
|
427 |
+
speaker_id = None
|
428 |
+
else:
|
429 |
+
speaker_id = 0
|
430 |
+
speaker_embedding = None
|
431 |
+
else:
|
432 |
+
speaker_id = None
|
433 |
+
speaker_embedding = None
|
434 |
+
|
435 |
+
style_wav = c.get("style_wav_for_test")
|
436 |
+
for idx, test_sentence in enumerate(test_sentences):
|
437 |
+
try:
|
438 |
+
wav, alignment, _, postnet_output, _, _ = synthesis(
|
439 |
+
model,
|
440 |
+
test_sentence,
|
441 |
+
c,
|
442 |
+
use_cuda,
|
443 |
+
ap,
|
444 |
+
speaker_id=speaker_id,
|
445 |
+
speaker_embedding=speaker_embedding,
|
446 |
+
style_wav=style_wav,
|
447 |
+
truncated=False,
|
448 |
+
enable_eos_bos_chars=c.enable_eos_bos_chars, #pylint: disable=unused-argument
|
449 |
+
use_griffin_lim=True,
|
450 |
+
do_trim_silence=False)
|
451 |
+
|
452 |
+
file_path = os.path.join(AUDIO_PATH, str(global_step))
|
453 |
+
os.makedirs(file_path, exist_ok=True)
|
454 |
+
file_path = os.path.join(file_path,
|
455 |
+
"TestSentence_{}.wav".format(idx))
|
456 |
+
ap.save_wav(wav, file_path)
|
457 |
+
test_audios['{}-audio'.format(idx)] = wav
|
458 |
+
test_figures['{}-prediction'.format(idx)] = plot_spectrogram(
|
459 |
+
postnet_output, ap)
|
460 |
+
test_figures['{}-alignment'.format(idx)] = plot_alignment(
|
461 |
+
alignment)
|
462 |
+
except: #pylint: disable=bare-except
|
463 |
+
print(" !! Error creating Test Sentence -", idx)
|
464 |
+
traceback.print_exc()
|
465 |
+
tb_logger.tb_test_audios(global_step, test_audios,
|
466 |
+
c.audio['sample_rate'])
|
467 |
+
tb_logger.tb_test_figures(global_step, test_figures)
|
468 |
+
return keep_avg.avg_values
|
469 |
+
|
470 |
+
|
471 |
+
# FIXME: move args definition/parsing inside of main?
|
472 |
+
def main(args): # pylint: disable=redefined-outer-name
|
473 |
+
# pylint: disable=global-variable-undefined
|
474 |
+
global meta_data_train, meta_data_eval, symbols, phonemes, speaker_mapping
|
475 |
+
# Audio processor
|
476 |
+
ap = AudioProcessor(**c.audio)
|
477 |
+
if 'characters' in c.keys():
|
478 |
+
symbols, phonemes = make_symbols(**c.characters)
|
479 |
+
|
480 |
+
# DISTRUBUTED
|
481 |
+
if num_gpus > 1:
|
482 |
+
init_distributed(args.rank, num_gpus, args.group_id,
|
483 |
+
c.distributed["backend"], c.distributed["url"])
|
484 |
+
num_chars = len(phonemes) if c.use_phonemes else len(symbols)
|
485 |
+
|
486 |
+
# load data instances
|
487 |
+
meta_data_train, meta_data_eval = load_meta_data(c.datasets)
|
488 |
+
|
489 |
+
# set the portion of the data used for training
|
490 |
+
if 'train_portion' in c.keys():
|
491 |
+
meta_data_train = meta_data_train[:int(len(meta_data_train) * c.train_portion)]
|
492 |
+
if 'eval_portion' in c.keys():
|
493 |
+
meta_data_eval = meta_data_eval[:int(len(meta_data_eval) * c.eval_portion)]
|
494 |
+
|
495 |
+
# parse speakers
|
496 |
+
num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers(c, args, meta_data_train, OUT_PATH)
|
497 |
+
|
498 |
+
# setup model
|
499 |
+
model = setup_model(num_chars, num_speakers, c, speaker_embedding_dim=speaker_embedding_dim)
|
500 |
+
optimizer = RAdam(model.parameters(), lr=c.lr, weight_decay=0, betas=(0.9, 0.98), eps=1e-9)
|
501 |
+
criterion = GlowTTSLoss()
|
502 |
+
|
503 |
+
if args.restore_path:
|
504 |
+
checkpoint = torch.load(args.restore_path, map_location='cpu')
|
505 |
+
try:
|
506 |
+
# TODO: fix optimizer init, model.cuda() needs to be called before
|
507 |
+
# optimizer restore
|
508 |
+
optimizer.load_state_dict(checkpoint['optimizer'])
|
509 |
+
if c.reinit_layers:
|
510 |
+
raise RuntimeError
|
511 |
+
model.load_state_dict(checkpoint['model'])
|
512 |
+
except: #pylint: disable=bare-except
|
513 |
+
print(" > Partial model initialization.")
|
514 |
+
model_dict = model.state_dict()
|
515 |
+
model_dict = set_init_dict(model_dict, checkpoint['model'], c)
|
516 |
+
model.load_state_dict(model_dict)
|
517 |
+
del model_dict
|
518 |
+
|
519 |
+
for group in optimizer.param_groups:
|
520 |
+
group['initial_lr'] = c.lr
|
521 |
+
print(" > Model restored from step %d" % checkpoint['step'],
|
522 |
+
flush=True)
|
523 |
+
args.restore_step = checkpoint['step']
|
524 |
+
else:
|
525 |
+
args.restore_step = 0
|
526 |
+
|
527 |
+
if use_cuda:
|
528 |
+
model.cuda()
|
529 |
+
criterion.cuda()
|
530 |
+
|
531 |
+
# DISTRUBUTED
|
532 |
+
if num_gpus > 1:
|
533 |
+
model = DDP_th(model, device_ids=[args.rank])
|
534 |
+
|
535 |
+
if c.noam_schedule:
|
536 |
+
scheduler = NoamLR(optimizer,
|
537 |
+
warmup_steps=c.warmup_steps,
|
538 |
+
last_epoch=args.restore_step - 1)
|
539 |
+
else:
|
540 |
+
scheduler = None
|
541 |
+
|
542 |
+
num_params = count_parameters(model)
|
543 |
+
print("\n > Model has {} parameters".format(num_params), flush=True)
|
544 |
+
|
545 |
+
if 'best_loss' not in locals():
|
546 |
+
best_loss = float('inf')
|
547 |
+
|
548 |
+
# define dataloaders
|
549 |
+
train_loader = setup_loader(ap, 1, is_val=False, verbose=True)
|
550 |
+
eval_loader = setup_loader(ap, 1, is_val=True, verbose=True)
|
551 |
+
|
552 |
+
global_step = args.restore_step
|
553 |
+
model = data_depended_init(train_loader, model, ap)
|
554 |
+
for epoch in range(0, c.epochs):
|
555 |
+
c_logger.print_epoch_start(epoch, c.epochs)
|
556 |
+
train_avg_loss_dict, global_step = train(train_loader, model, criterion, optimizer,
|
557 |
+
scheduler, ap, global_step,
|
558 |
+
epoch)
|
559 |
+
eval_avg_loss_dict = evaluate(eval_loader , model, criterion, ap, global_step, epoch)
|
560 |
+
c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
|
561 |
+
target_loss = train_avg_loss_dict['avg_loss']
|
562 |
+
if c.run_eval:
|
563 |
+
target_loss = eval_avg_loss_dict['avg_loss']
|
564 |
+
best_loss = save_best_model(target_loss, best_loss, model, optimizer, global_step, epoch, c.r,
|
565 |
+
OUT_PATH)
|
566 |
+
|
567 |
+
|
568 |
+
if __name__ == '__main__':
|
569 |
+
parser = argparse.ArgumentParser()
|
570 |
+
parser.add_argument(
|
571 |
+
'--continue_path',
|
572 |
+
type=str,
|
573 |
+
help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.',
|
574 |
+
default='',
|
575 |
+
required='--config_path' not in sys.argv)
|
576 |
+
parser.add_argument(
|
577 |
+
'--restore_path',
|
578 |
+
type=str,
|
579 |
+
help='Model file to be restored. Use to finetune a model.',
|
580 |
+
default='')
|
581 |
+
parser.add_argument(
|
582 |
+
'--config_path',
|
583 |
+
type=str,
|
584 |
+
help='Path to config file for training.',
|
585 |
+
required='--continue_path' not in sys.argv
|
586 |
+
)
|
587 |
+
parser.add_argument('--debug',
|
588 |
+
type=bool,
|
589 |
+
default=False,
|
590 |
+
help='Do not verify commit integrity to run training.')
|
591 |
+
|
592 |
+
# DISTRUBUTED
|
593 |
+
parser.add_argument(
|
594 |
+
'--rank',
|
595 |
+
type=int,
|
596 |
+
default=0,
|
597 |
+
help='DISTRIBUTED: process rank for distributed training.')
|
598 |
+
parser.add_argument('--group_id',
|
599 |
+
type=str,
|
600 |
+
default="",
|
601 |
+
help='DISTRIBUTED: process group id.')
|
602 |
+
args = parser.parse_args()
|
603 |
+
|
604 |
+
if args.continue_path != '':
|
605 |
+
args.output_path = args.continue_path
|
606 |
+
args.config_path = os.path.join(args.continue_path, 'config.json')
|
607 |
+
list_of_files = glob.glob(args.continue_path + "/*.pth.tar") # * means all if need specific format then *.csv
|
608 |
+
latest_model_file = max(list_of_files, key=os.path.getctime)
|
609 |
+
args.restore_path = latest_model_file
|
610 |
+
print(f" > Training continues for {args.restore_path}")
|
611 |
+
|
612 |
+
# setup output paths and read configs
|
613 |
+
c = load_config(args.config_path)
|
614 |
+
# check_config(c)
|
615 |
+
check_config_tts(c)
|
616 |
+
_ = os.path.dirname(os.path.realpath(__file__))
|
617 |
+
|
618 |
+
if c.mixed_precision:
|
619 |
+
print(" > Mixed precision enabled.")
|
620 |
+
|
621 |
+
OUT_PATH = args.continue_path
|
622 |
+
if args.continue_path == '':
|
623 |
+
OUT_PATH = create_experiment_folder(c.output_path, c.run_name, args.debug)
|
624 |
+
|
625 |
+
AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios')
|
626 |
+
|
627 |
+
c_logger = ConsoleLogger()
|
628 |
+
|
629 |
+
if args.rank == 0:
|
630 |
+
os.makedirs(AUDIO_PATH, exist_ok=True)
|
631 |
+
new_fields = {}
|
632 |
+
if args.restore_path:
|
633 |
+
new_fields["restore_path"] = args.restore_path
|
634 |
+
new_fields["github_branch"] = get_git_branch()
|
635 |
+
copy_model_files(c, args.config_path,
|
636 |
+
OUT_PATH, new_fields)
|
637 |
+
os.chmod(AUDIO_PATH, 0o775)
|
638 |
+
os.chmod(OUT_PATH, 0o775)
|
639 |
+
|
640 |
+
LOG_DIR = OUT_PATH
|
641 |
+
tb_logger = TensorboardLogger(LOG_DIR, model_name='TTS')
|
642 |
+
|
643 |
+
# write model desc to tensorboard
|
644 |
+
tb_logger.tb_add_text('model-description', c['run_description'], 0)
|
645 |
+
|
646 |
+
try:
|
647 |
+
main(args)
|
648 |
+
except KeyboardInterrupt:
|
649 |
+
remove_experiment_folder(OUT_PATH)
|
650 |
+
try:
|
651 |
+
sys.exit(0)
|
652 |
+
except SystemExit:
|
653 |
+
os._exit(0) # pylint: disable=protected-access
|
654 |
+
except Exception: # pylint: disable=broad-except
|
655 |
+
remove_experiment_folder(OUT_PATH)
|
656 |
+
traceback.print_exc()
|
657 |
+
sys.exit(1)
|
TTS/bin/train_speedy_speech.py
ADDED
@@ -0,0 +1,618 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
import argparse
|
5 |
+
import glob
|
6 |
+
import os
|
7 |
+
import sys
|
8 |
+
import time
|
9 |
+
import traceback
|
10 |
+
import numpy as np
|
11 |
+
from random import randrange
|
12 |
+
|
13 |
+
import torch
|
14 |
+
# DISTRIBUTED
|
15 |
+
from torch.nn.parallel import DistributedDataParallel as DDP_th
|
16 |
+
from torch.utils.data import DataLoader
|
17 |
+
from torch.utils.data.distributed import DistributedSampler
|
18 |
+
from TTS.tts.datasets.preprocess import load_meta_data
|
19 |
+
from TTS.tts.datasets.TTSDataset import MyDataset
|
20 |
+
from TTS.tts.layers.losses import SpeedySpeechLoss
|
21 |
+
from TTS.tts.utils.generic_utils import check_config_tts, setup_model
|
22 |
+
from TTS.tts.utils.io import save_best_model, save_checkpoint
|
23 |
+
from TTS.tts.utils.measures import alignment_diagonal_score
|
24 |
+
from TTS.tts.utils.speakers import parse_speakers
|
25 |
+
from TTS.tts.utils.synthesis import synthesis
|
26 |
+
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
|
27 |
+
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
28 |
+
from TTS.utils.audio import AudioProcessor
|
29 |
+
from TTS.utils.console_logger import ConsoleLogger
|
30 |
+
from TTS.utils.distribute import init_distributed, reduce_tensor
|
31 |
+
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
|
32 |
+
create_experiment_folder, get_git_branch,
|
33 |
+
remove_experiment_folder, set_init_dict)
|
34 |
+
from TTS.utils.io import copy_model_files, load_config
|
35 |
+
from TTS.utils.radam import RAdam
|
36 |
+
from TTS.utils.tensorboard_logger import TensorboardLogger
|
37 |
+
from TTS.utils.training import NoamLR, setup_torch_training_env
|
38 |
+
|
39 |
+
use_cuda, num_gpus = setup_torch_training_env(True, False)
|
40 |
+
|
41 |
+
|
42 |
+
def setup_loader(ap, r, is_val=False, verbose=False):
|
43 |
+
if is_val and not c.run_eval:
|
44 |
+
loader = None
|
45 |
+
else:
|
46 |
+
dataset = MyDataset(
|
47 |
+
r,
|
48 |
+
c.text_cleaner,
|
49 |
+
compute_linear_spec=False,
|
50 |
+
meta_data=meta_data_eval if is_val else meta_data_train,
|
51 |
+
ap=ap,
|
52 |
+
tp=c.characters if 'characters' in c.keys() else None,
|
53 |
+
add_blank=c['add_blank'] if 'add_blank' in c.keys() else False,
|
54 |
+
batch_group_size=0 if is_val else c.batch_group_size *
|
55 |
+
c.batch_size,
|
56 |
+
min_seq_len=c.min_seq_len,
|
57 |
+
max_seq_len=c.max_seq_len,
|
58 |
+
phoneme_cache_path=c.phoneme_cache_path,
|
59 |
+
use_phonemes=c.use_phonemes,
|
60 |
+
phoneme_language=c.phoneme_language,
|
61 |
+
enable_eos_bos=c.enable_eos_bos_chars,
|
62 |
+
use_noise_augment=not is_val,
|
63 |
+
verbose=verbose,
|
64 |
+
speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None)
|
65 |
+
|
66 |
+
if c.use_phonemes and c.compute_input_seq_cache:
|
67 |
+
# precompute phonemes to have a better estimate of sequence lengths.
|
68 |
+
dataset.compute_input_seq(c.num_loader_workers)
|
69 |
+
dataset.sort_items()
|
70 |
+
|
71 |
+
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
72 |
+
loader = DataLoader(
|
73 |
+
dataset,
|
74 |
+
batch_size=c.eval_batch_size if is_val else c.batch_size,
|
75 |
+
shuffle=False,
|
76 |
+
collate_fn=dataset.collate_fn,
|
77 |
+
drop_last=False,
|
78 |
+
sampler=sampler,
|
79 |
+
num_workers=c.num_val_loader_workers
|
80 |
+
if is_val else c.num_loader_workers,
|
81 |
+
pin_memory=False)
|
82 |
+
return loader
|
83 |
+
|
84 |
+
|
85 |
+
def format_data(data):
|
86 |
+
# setup input data
|
87 |
+
text_input = data[0]
|
88 |
+
text_lengths = data[1]
|
89 |
+
speaker_names = data[2]
|
90 |
+
mel_input = data[4].permute(0, 2, 1) # B x D x T
|
91 |
+
mel_lengths = data[5]
|
92 |
+
item_idx = data[7]
|
93 |
+
attn_mask = data[9]
|
94 |
+
avg_text_length = torch.mean(text_lengths.float())
|
95 |
+
avg_spec_length = torch.mean(mel_lengths.float())
|
96 |
+
|
97 |
+
if c.use_speaker_embedding:
|
98 |
+
if c.use_external_speaker_embedding_file:
|
99 |
+
# return precomputed embedding vector
|
100 |
+
speaker_c = data[8]
|
101 |
+
else:
|
102 |
+
# return speaker_id to be used by an embedding layer
|
103 |
+
speaker_c = [
|
104 |
+
speaker_mapping[speaker_name] for speaker_name in speaker_names
|
105 |
+
]
|
106 |
+
speaker_c = torch.LongTensor(speaker_c)
|
107 |
+
else:
|
108 |
+
speaker_c = None
|
109 |
+
# compute durations from attention mask
|
110 |
+
durations = torch.zeros(attn_mask.shape[0], attn_mask.shape[2])
|
111 |
+
for idx, am in enumerate(attn_mask):
|
112 |
+
# compute raw durations
|
113 |
+
c_idxs = am[:, :text_lengths[idx], :mel_lengths[idx]].max(1)[1]
|
114 |
+
# c_idxs, counts = torch.unique_consecutive(c_idxs, return_counts=True)
|
115 |
+
c_idxs, counts = torch.unique(c_idxs, return_counts=True)
|
116 |
+
dur = torch.ones([text_lengths[idx]]).to(counts.dtype)
|
117 |
+
dur[c_idxs] = counts
|
118 |
+
# smooth the durations and set any 0 duration to 1
|
119 |
+
# by cutting off from the largest duration indeces.
|
120 |
+
extra_frames = dur.sum() - mel_lengths[idx]
|
121 |
+
largest_idxs = torch.argsort(-dur)[:extra_frames]
|
122 |
+
dur[largest_idxs] -= 1
|
123 |
+
assert dur.sum() == mel_lengths[idx], f" [!] total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}"
|
124 |
+
durations[idx, :text_lengths[idx]] = dur
|
125 |
+
# dispatch data to GPU
|
126 |
+
if use_cuda:
|
127 |
+
text_input = text_input.cuda(non_blocking=True)
|
128 |
+
text_lengths = text_lengths.cuda(non_blocking=True)
|
129 |
+
mel_input = mel_input.cuda(non_blocking=True)
|
130 |
+
mel_lengths = mel_lengths.cuda(non_blocking=True)
|
131 |
+
if speaker_c is not None:
|
132 |
+
speaker_c = speaker_c.cuda(non_blocking=True)
|
133 |
+
attn_mask = attn_mask.cuda(non_blocking=True)
|
134 |
+
durations = durations.cuda(non_blocking=True)
|
135 |
+
return text_input, text_lengths, mel_input, mel_lengths, speaker_c,\
|
136 |
+
avg_text_length, avg_spec_length, attn_mask, durations, item_idx
|
137 |
+
|
138 |
+
|
139 |
+
def train(data_loader, model, criterion, optimizer, scheduler,
|
140 |
+
ap, global_step, epoch):
|
141 |
+
|
142 |
+
model.train()
|
143 |
+
epoch_time = 0
|
144 |
+
keep_avg = KeepAverage()
|
145 |
+
if use_cuda:
|
146 |
+
batch_n_iter = int(
|
147 |
+
len(data_loader.dataset) / (c.batch_size * num_gpus))
|
148 |
+
else:
|
149 |
+
batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
|
150 |
+
end_time = time.time()
|
151 |
+
c_logger.print_train_start()
|
152 |
+
scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None
|
153 |
+
for num_iter, data in enumerate(data_loader):
|
154 |
+
start_time = time.time()
|
155 |
+
|
156 |
+
# format data
|
157 |
+
text_input, text_lengths, mel_targets, mel_lengths, speaker_c,\
|
158 |
+
avg_text_length, avg_spec_length, _, dur_target, _ = format_data(data)
|
159 |
+
|
160 |
+
loader_time = time.time() - end_time
|
161 |
+
|
162 |
+
global_step += 1
|
163 |
+
optimizer.zero_grad()
|
164 |
+
|
165 |
+
# forward pass model
|
166 |
+
with torch.cuda.amp.autocast(enabled=c.mixed_precision):
|
167 |
+
decoder_output, dur_output, alignments = model.forward(
|
168 |
+
text_input, text_lengths, mel_lengths, dur_target, g=speaker_c)
|
169 |
+
|
170 |
+
# compute loss
|
171 |
+
loss_dict = criterion(decoder_output, mel_targets, mel_lengths, dur_output, torch.log(1 + dur_target), text_lengths)
|
172 |
+
|
173 |
+
# backward pass with loss scaling
|
174 |
+
if c.mixed_precision:
|
175 |
+
scaler.scale(loss_dict['loss']).backward()
|
176 |
+
scaler.unscale_(optimizer)
|
177 |
+
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
|
178 |
+
c.grad_clip)
|
179 |
+
scaler.step(optimizer)
|
180 |
+
scaler.update()
|
181 |
+
else:
|
182 |
+
loss_dict['loss'].backward()
|
183 |
+
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
|
184 |
+
c.grad_clip)
|
185 |
+
optimizer.step()
|
186 |
+
|
187 |
+
# setup lr
|
188 |
+
if c.noam_schedule:
|
189 |
+
scheduler.step()
|
190 |
+
|
191 |
+
# current_lr
|
192 |
+
current_lr = optimizer.param_groups[0]['lr']
|
193 |
+
|
194 |
+
# compute alignment error (the lower the better )
|
195 |
+
align_error = 1 - alignment_diagonal_score(alignments, binary=True)
|
196 |
+
loss_dict['align_error'] = align_error
|
197 |
+
|
198 |
+
step_time = time.time() - start_time
|
199 |
+
epoch_time += step_time
|
200 |
+
|
201 |
+
# aggregate losses from processes
|
202 |
+
if num_gpus > 1:
|
203 |
+
loss_dict['loss_l1'] = reduce_tensor(loss_dict['loss_l1'].data, num_gpus)
|
204 |
+
loss_dict['loss_ssim'] = reduce_tensor(loss_dict['loss_ssim'].data, num_gpus)
|
205 |
+
loss_dict['loss_dur'] = reduce_tensor(loss_dict['loss_dur'].data, num_gpus)
|
206 |
+
loss_dict['loss'] = reduce_tensor(loss_dict['loss'] .data, num_gpus)
|
207 |
+
|
208 |
+
# detach loss values
|
209 |
+
loss_dict_new = dict()
|
210 |
+
for key, value in loss_dict.items():
|
211 |
+
if isinstance(value, (int, float)):
|
212 |
+
loss_dict_new[key] = value
|
213 |
+
else:
|
214 |
+
loss_dict_new[key] = value.item()
|
215 |
+
loss_dict = loss_dict_new
|
216 |
+
|
217 |
+
# update avg stats
|
218 |
+
update_train_values = dict()
|
219 |
+
for key, value in loss_dict.items():
|
220 |
+
update_train_values['avg_' + key] = value
|
221 |
+
update_train_values['avg_loader_time'] = loader_time
|
222 |
+
update_train_values['avg_step_time'] = step_time
|
223 |
+
keep_avg.update_values(update_train_values)
|
224 |
+
|
225 |
+
# print training progress
|
226 |
+
if global_step % c.print_step == 0:
|
227 |
+
log_dict = {
|
228 |
+
|
229 |
+
"avg_spec_length": [avg_spec_length, 1], # value, precision
|
230 |
+
"avg_text_length": [avg_text_length, 1],
|
231 |
+
"step_time": [step_time, 4],
|
232 |
+
"loader_time": [loader_time, 2],
|
233 |
+
"current_lr": current_lr,
|
234 |
+
}
|
235 |
+
c_logger.print_train_step(batch_n_iter, num_iter, global_step,
|
236 |
+
log_dict, loss_dict, keep_avg.avg_values)
|
237 |
+
|
238 |
+
if args.rank == 0:
|
239 |
+
# Plot Training Iter Stats
|
240 |
+
# reduce TB load
|
241 |
+
if global_step % c.tb_plot_step == 0:
|
242 |
+
iter_stats = {
|
243 |
+
"lr": current_lr,
|
244 |
+
"grad_norm": grad_norm,
|
245 |
+
"step_time": step_time
|
246 |
+
}
|
247 |
+
iter_stats.update(loss_dict)
|
248 |
+
tb_logger.tb_train_iter_stats(global_step, iter_stats)
|
249 |
+
|
250 |
+
if global_step % c.save_step == 0:
|
251 |
+
if c.checkpoint:
|
252 |
+
# save model
|
253 |
+
save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH,
|
254 |
+
model_loss=loss_dict['loss'])
|
255 |
+
|
256 |
+
# wait all kernels to be completed
|
257 |
+
torch.cuda.synchronize()
|
258 |
+
|
259 |
+
# Diagnostic visualizations
|
260 |
+
idx = np.random.randint(mel_targets.shape[0])
|
261 |
+
pred_spec = decoder_output[idx].detach().data.cpu().numpy().T
|
262 |
+
gt_spec = mel_targets[idx].data.cpu().numpy().T
|
263 |
+
align_img = alignments[idx].data.cpu()
|
264 |
+
|
265 |
+
figures = {
|
266 |
+
"prediction": plot_spectrogram(pred_spec, ap),
|
267 |
+
"ground_truth": plot_spectrogram(gt_spec, ap),
|
268 |
+
"alignment": plot_alignment(align_img),
|
269 |
+
}
|
270 |
+
|
271 |
+
tb_logger.tb_train_figures(global_step, figures)
|
272 |
+
|
273 |
+
# Sample audio
|
274 |
+
train_audio = ap.inv_melspectrogram(pred_spec.T)
|
275 |
+
tb_logger.tb_train_audios(global_step,
|
276 |
+
{'TrainAudio': train_audio},
|
277 |
+
c.audio["sample_rate"])
|
278 |
+
end_time = time.time()
|
279 |
+
|
280 |
+
# print epoch stats
|
281 |
+
c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)
|
282 |
+
|
283 |
+
# Plot Epoch Stats
|
284 |
+
if args.rank == 0:
|
285 |
+
epoch_stats = {"epoch_time": epoch_time}
|
286 |
+
epoch_stats.update(keep_avg.avg_values)
|
287 |
+
tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
|
288 |
+
if c.tb_model_param_stats:
|
289 |
+
tb_logger.tb_model_weights(model, global_step)
|
290 |
+
return keep_avg.avg_values, global_step
|
291 |
+
|
292 |
+
|
293 |
+
@torch.no_grad()
|
294 |
+
def evaluate(data_loader, model, criterion, ap, global_step, epoch):
|
295 |
+
model.eval()
|
296 |
+
epoch_time = 0
|
297 |
+
keep_avg = KeepAverage()
|
298 |
+
c_logger.print_eval_start()
|
299 |
+
if data_loader is not None:
|
300 |
+
for num_iter, data in enumerate(data_loader):
|
301 |
+
start_time = time.time()
|
302 |
+
|
303 |
+
# format data
|
304 |
+
text_input, text_lengths, mel_targets, mel_lengths, speaker_c,\
|
305 |
+
_, _, _, dur_target, _ = format_data(data)
|
306 |
+
|
307 |
+
# forward pass model
|
308 |
+
with torch.cuda.amp.autocast(enabled=c.mixed_precision):
|
309 |
+
decoder_output, dur_output, alignments = model.forward(
|
310 |
+
text_input, text_lengths, mel_lengths, dur_target, g=speaker_c)
|
311 |
+
|
312 |
+
# compute loss
|
313 |
+
loss_dict = criterion(decoder_output, mel_targets, mel_lengths, dur_output, torch.log(1 + dur_target), text_lengths)
|
314 |
+
|
315 |
+
# step time
|
316 |
+
step_time = time.time() - start_time
|
317 |
+
epoch_time += step_time
|
318 |
+
|
319 |
+
# compute alignment score
|
320 |
+
align_error = 1 - alignment_diagonal_score(alignments, binary=True)
|
321 |
+
loss_dict['align_error'] = align_error
|
322 |
+
|
323 |
+
# aggregate losses from processes
|
324 |
+
if num_gpus > 1:
|
325 |
+
loss_dict['loss_l1'] = reduce_tensor(loss_dict['loss_l1'].data, num_gpus)
|
326 |
+
loss_dict['loss_ssim'] = reduce_tensor(loss_dict['loss_ssim'].data, num_gpus)
|
327 |
+
loss_dict['loss_dur'] = reduce_tensor(loss_dict['loss_dur'].data, num_gpus)
|
328 |
+
loss_dict['loss'] = reduce_tensor(loss_dict['loss'] .data, num_gpus)
|
329 |
+
|
330 |
+
# detach loss values
|
331 |
+
loss_dict_new = dict()
|
332 |
+
for key, value in loss_dict.items():
|
333 |
+
if isinstance(value, (int, float)):
|
334 |
+
loss_dict_new[key] = value
|
335 |
+
else:
|
336 |
+
loss_dict_new[key] = value.item()
|
337 |
+
loss_dict = loss_dict_new
|
338 |
+
|
339 |
+
# update avg stats
|
340 |
+
update_train_values = dict()
|
341 |
+
for key, value in loss_dict.items():
|
342 |
+
update_train_values['avg_' + key] = value
|
343 |
+
keep_avg.update_values(update_train_values)
|
344 |
+
|
345 |
+
if c.print_eval:
|
346 |
+
c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)
|
347 |
+
|
348 |
+
if args.rank == 0:
|
349 |
+
# Diagnostic visualizations
|
350 |
+
idx = np.random.randint(mel_targets.shape[0])
|
351 |
+
pred_spec = decoder_output[idx].detach().data.cpu().numpy().T
|
352 |
+
gt_spec = mel_targets[idx].data.cpu().numpy().T
|
353 |
+
align_img = alignments[idx].data.cpu()
|
354 |
+
|
355 |
+
eval_figures = {
|
356 |
+
"prediction": plot_spectrogram(pred_spec, ap, output_fig=False),
|
357 |
+
"ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
|
358 |
+
"alignment": plot_alignment(align_img, output_fig=False)
|
359 |
+
}
|
360 |
+
|
361 |
+
# Sample audio
|
362 |
+
eval_audio = ap.inv_melspectrogram(pred_spec.T)
|
363 |
+
tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio},
|
364 |
+
c.audio["sample_rate"])
|
365 |
+
|
366 |
+
# Plot Validation Stats
|
367 |
+
tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
|
368 |
+
tb_logger.tb_eval_figures(global_step, eval_figures)
|
369 |
+
|
370 |
+
if args.rank == 0 and epoch >= c.test_delay_epochs:
|
371 |
+
if c.test_sentences_file is None:
|
372 |
+
test_sentences = [
|
373 |
+
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
|
374 |
+
"Be a voice, not an echo.",
|
375 |
+
"I'm sorry Dave. I'm afraid I can't do that.",
|
376 |
+
"This cake is great. It's so delicious and moist.",
|
377 |
+
"Prior to November 22, 1963."
|
378 |
+
]
|
379 |
+
else:
|
380 |
+
with open(c.test_sentences_file, "r") as f:
|
381 |
+
test_sentences = [s.strip() for s in f.readlines()]
|
382 |
+
|
383 |
+
# test sentences
|
384 |
+
test_audios = {}
|
385 |
+
test_figures = {}
|
386 |
+
print(" | > Synthesizing test sentences")
|
387 |
+
if c.use_speaker_embedding:
|
388 |
+
if c.use_external_speaker_embedding_file:
|
389 |
+
speaker_embedding = speaker_mapping[list(speaker_mapping.keys())[randrange(len(speaker_mapping)-1)]]['embedding']
|
390 |
+
speaker_id = None
|
391 |
+
else:
|
392 |
+
speaker_id = 0
|
393 |
+
speaker_embedding = None
|
394 |
+
else:
|
395 |
+
speaker_id = None
|
396 |
+
speaker_embedding = None
|
397 |
+
|
398 |
+
style_wav = c.get("style_wav_for_test")
|
399 |
+
for idx, test_sentence in enumerate(test_sentences):
|
400 |
+
try:
|
401 |
+
wav, alignment, _, postnet_output, _, _ = synthesis(
|
402 |
+
model,
|
403 |
+
test_sentence,
|
404 |
+
c,
|
405 |
+
use_cuda,
|
406 |
+
ap,
|
407 |
+
speaker_id=speaker_id,
|
408 |
+
speaker_embedding=speaker_embedding,
|
409 |
+
style_wav=style_wav,
|
410 |
+
truncated=False,
|
411 |
+
enable_eos_bos_chars=c.enable_eos_bos_chars, #pylint: disable=unused-argument
|
412 |
+
use_griffin_lim=True,
|
413 |
+
do_trim_silence=False)
|
414 |
+
|
415 |
+
file_path = os.path.join(AUDIO_PATH, str(global_step))
|
416 |
+
os.makedirs(file_path, exist_ok=True)
|
417 |
+
file_path = os.path.join(file_path,
|
418 |
+
"TestSentence_{}.wav".format(idx))
|
419 |
+
ap.save_wav(wav, file_path)
|
420 |
+
test_audios['{}-audio'.format(idx)] = wav
|
421 |
+
test_figures['{}-prediction'.format(idx)] = plot_spectrogram(
|
422 |
+
postnet_output, ap)
|
423 |
+
test_figures['{}-alignment'.format(idx)] = plot_alignment(
|
424 |
+
alignment)
|
425 |
+
except: #pylint: disable=bare-except
|
426 |
+
print(" !! Error creating Test Sentence -", idx)
|
427 |
+
traceback.print_exc()
|
428 |
+
tb_logger.tb_test_audios(global_step, test_audios,
|
429 |
+
c.audio['sample_rate'])
|
430 |
+
tb_logger.tb_test_figures(global_step, test_figures)
|
431 |
+
return keep_avg.avg_values
|
432 |
+
|
433 |
+
|
434 |
+
# FIXME: move args definition/parsing inside of main?
|
435 |
+
def main(args): # pylint: disable=redefined-outer-name
|
436 |
+
# pylint: disable=global-variable-undefined
|
437 |
+
global meta_data_train, meta_data_eval, symbols, phonemes, speaker_mapping
|
438 |
+
# Audio processor
|
439 |
+
ap = AudioProcessor(**c.audio)
|
440 |
+
if 'characters' in c.keys():
|
441 |
+
symbols, phonemes = make_symbols(**c.characters)
|
442 |
+
|
443 |
+
# DISTRUBUTED
|
444 |
+
if num_gpus > 1:
|
445 |
+
init_distributed(args.rank, num_gpus, args.group_id,
|
446 |
+
c.distributed["backend"], c.distributed["url"])
|
447 |
+
num_chars = len(phonemes) if c.use_phonemes else len(symbols)
|
448 |
+
|
449 |
+
# load data instances
|
450 |
+
meta_data_train, meta_data_eval = load_meta_data(c.datasets, eval_split=True)
|
451 |
+
|
452 |
+
# set the portion of the data used for training if set in config.json
|
453 |
+
if 'train_portion' in c.keys():
|
454 |
+
meta_data_train = meta_data_train[:int(len(meta_data_train) * c.train_portion)]
|
455 |
+
if 'eval_portion' in c.keys():
|
456 |
+
meta_data_eval = meta_data_eval[:int(len(meta_data_eval) * c.eval_portion)]
|
457 |
+
|
458 |
+
# parse speakers
|
459 |
+
num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers(c, args, meta_data_train, OUT_PATH)
|
460 |
+
|
461 |
+
# setup model
|
462 |
+
model = setup_model(num_chars, num_speakers, c, speaker_embedding_dim=speaker_embedding_dim)
|
463 |
+
optimizer = RAdam(model.parameters(), lr=c.lr, weight_decay=0, betas=(0.9, 0.98), eps=1e-9)
|
464 |
+
criterion = SpeedySpeechLoss(c)
|
465 |
+
|
466 |
+
if args.restore_path:
|
467 |
+
checkpoint = torch.load(args.restore_path, map_location='cpu')
|
468 |
+
try:
|
469 |
+
# TODO: fix optimizer init, model.cuda() needs to be called before
|
470 |
+
# optimizer restore
|
471 |
+
optimizer.load_state_dict(checkpoint['optimizer'])
|
472 |
+
if c.reinit_layers:
|
473 |
+
raise RuntimeError
|
474 |
+
model.load_state_dict(checkpoint['model'])
|
475 |
+
except: #pylint: disable=bare-except
|
476 |
+
print(" > Partial model initialization.")
|
477 |
+
model_dict = model.state_dict()
|
478 |
+
model_dict = set_init_dict(model_dict, checkpoint['model'], c)
|
479 |
+
model.load_state_dict(model_dict)
|
480 |
+
del model_dict
|
481 |
+
|
482 |
+
for group in optimizer.param_groups:
|
483 |
+
group['initial_lr'] = c.lr
|
484 |
+
print(" > Model restored from step %d" % checkpoint['step'],
|
485 |
+
flush=True)
|
486 |
+
args.restore_step = checkpoint['step']
|
487 |
+
else:
|
488 |
+
args.restore_step = 0
|
489 |
+
|
490 |
+
if use_cuda:
|
491 |
+
model.cuda()
|
492 |
+
criterion.cuda()
|
493 |
+
|
494 |
+
# DISTRUBUTED
|
495 |
+
if num_gpus > 1:
|
496 |
+
model = DDP_th(model, device_ids=[args.rank])
|
497 |
+
|
498 |
+
if c.noam_schedule:
|
499 |
+
scheduler = NoamLR(optimizer,
|
500 |
+
warmup_steps=c.warmup_steps,
|
501 |
+
last_epoch=args.restore_step - 1)
|
502 |
+
else:
|
503 |
+
scheduler = None
|
504 |
+
|
505 |
+
num_params = count_parameters(model)
|
506 |
+
print("\n > Model has {} parameters".format(num_params), flush=True)
|
507 |
+
|
508 |
+
if 'best_loss' not in locals():
|
509 |
+
best_loss = float('inf')
|
510 |
+
|
511 |
+
# define dataloaders
|
512 |
+
train_loader = setup_loader(ap, 1, is_val=False, verbose=True)
|
513 |
+
eval_loader = setup_loader(ap, 1, is_val=True, verbose=True)
|
514 |
+
|
515 |
+
global_step = args.restore_step
|
516 |
+
for epoch in range(0, c.epochs):
|
517 |
+
c_logger.print_epoch_start(epoch, c.epochs)
|
518 |
+
train_avg_loss_dict, global_step = train(train_loader, model, criterion, optimizer,
|
519 |
+
scheduler, ap, global_step,
|
520 |
+
epoch)
|
521 |
+
eval_avg_loss_dict = evaluate(eval_loader , model, criterion, ap, global_step, epoch)
|
522 |
+
c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
|
523 |
+
target_loss = train_avg_loss_dict['avg_loss']
|
524 |
+
if c.run_eval:
|
525 |
+
target_loss = eval_avg_loss_dict['avg_loss']
|
526 |
+
best_loss = save_best_model(target_loss, best_loss, model, optimizer, global_step, epoch, c.r,
|
527 |
+
OUT_PATH)
|
528 |
+
|
529 |
+
|
530 |
+
if __name__ == '__main__':
|
531 |
+
parser = argparse.ArgumentParser()
|
532 |
+
parser.add_argument(
|
533 |
+
'--continue_path',
|
534 |
+
type=str,
|
535 |
+
help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.',
|
536 |
+
default='',
|
537 |
+
required='--config_path' not in sys.argv)
|
538 |
+
parser.add_argument(
|
539 |
+
'--restore_path',
|
540 |
+
type=str,
|
541 |
+
help='Model file to be restored. Use to finetune a model.',
|
542 |
+
default='')
|
543 |
+
parser.add_argument(
|
544 |
+
'--config_path',
|
545 |
+
type=str,
|
546 |
+
help='Path to config file for training.',
|
547 |
+
required='--continue_path' not in sys.argv
|
548 |
+
)
|
549 |
+
parser.add_argument('--debug',
|
550 |
+
type=bool,
|
551 |
+
default=False,
|
552 |
+
help='Do not verify commit integrity to run training.')
|
553 |
+
|
554 |
+
# DISTRUBUTED
|
555 |
+
parser.add_argument(
|
556 |
+
'--rank',
|
557 |
+
type=int,
|
558 |
+
default=0,
|
559 |
+
help='DISTRIBUTED: process rank for distributed training.')
|
560 |
+
parser.add_argument('--group_id',
|
561 |
+
type=str,
|
562 |
+
default="",
|
563 |
+
help='DISTRIBUTED: process group id.')
|
564 |
+
args = parser.parse_args()
|
565 |
+
|
566 |
+
if args.continue_path != '':
|
567 |
+
args.output_path = args.continue_path
|
568 |
+
args.config_path = os.path.join(args.continue_path, 'config.json')
|
569 |
+
list_of_files = glob.glob(args.continue_path + "/*.pth.tar") # * means all if need specific format then *.csv
|
570 |
+
latest_model_file = max(list_of_files, key=os.path.getctime)
|
571 |
+
args.restore_path = latest_model_file
|
572 |
+
print(f" > Training continues for {args.restore_path}")
|
573 |
+
|
574 |
+
# setup output paths and read configs
|
575 |
+
c = load_config(args.config_path)
|
576 |
+
# check_config(c)
|
577 |
+
check_config_tts(c)
|
578 |
+
_ = os.path.dirname(os.path.realpath(__file__))
|
579 |
+
|
580 |
+
if c.mixed_precision:
|
581 |
+
print(" > Mixed precision enabled.")
|
582 |
+
|
583 |
+
OUT_PATH = args.continue_path
|
584 |
+
if args.continue_path == '':
|
585 |
+
OUT_PATH = create_experiment_folder(c.output_path, c.run_name, args.debug)
|
586 |
+
|
587 |
+
AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios')
|
588 |
+
|
589 |
+
c_logger = ConsoleLogger()
|
590 |
+
|
591 |
+
if args.rank == 0:
|
592 |
+
os.makedirs(AUDIO_PATH, exist_ok=True)
|
593 |
+
new_fields = {}
|
594 |
+
if args.restore_path:
|
595 |
+
new_fields["restore_path"] = args.restore_path
|
596 |
+
new_fields["github_branch"] = get_git_branch()
|
597 |
+
copy_model_files(c, args.config_path, OUT_PATH, new_fields)
|
598 |
+
os.chmod(AUDIO_PATH, 0o775)
|
599 |
+
os.chmod(OUT_PATH, 0o775)
|
600 |
+
|
601 |
+
LOG_DIR = OUT_PATH
|
602 |
+
tb_logger = TensorboardLogger(LOG_DIR, model_name='TTS')
|
603 |
+
|
604 |
+
# write model desc to tensorboard
|
605 |
+
tb_logger.tb_add_text('model-description', c['run_description'], 0)
|
606 |
+
|
607 |
+
try:
|
608 |
+
main(args)
|
609 |
+
except KeyboardInterrupt:
|
610 |
+
remove_experiment_folder(OUT_PATH)
|
611 |
+
try:
|
612 |
+
sys.exit(0)
|
613 |
+
except SystemExit:
|
614 |
+
os._exit(0) # pylint: disable=protected-access
|
615 |
+
except Exception: # pylint: disable=broad-except
|
616 |
+
remove_experiment_folder(OUT_PATH)
|
617 |
+
traceback.print_exc()
|
618 |
+
sys.exit(1)
|
TTS/bin/train_tacotron.py
ADDED
@@ -0,0 +1,731 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
import argparse
|
5 |
+
import glob
|
6 |
+
import os
|
7 |
+
import sys
|
8 |
+
import time
|
9 |
+
import traceback
|
10 |
+
from random import randrange
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
from torch.utils.data import DataLoader
|
15 |
+
from TTS.tts.datasets.preprocess import load_meta_data
|
16 |
+
from TTS.tts.datasets.TTSDataset import MyDataset
|
17 |
+
from TTS.tts.layers.losses import TacotronLoss
|
18 |
+
from TTS.tts.utils.generic_utils import check_config_tts, setup_model
|
19 |
+
from TTS.tts.utils.io import save_best_model, save_checkpoint
|
20 |
+
from TTS.tts.utils.measures import alignment_diagonal_score
|
21 |
+
from TTS.tts.utils.speakers import parse_speakers
|
22 |
+
from TTS.tts.utils.synthesis import synthesis
|
23 |
+
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
|
24 |
+
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
25 |
+
from TTS.utils.audio import AudioProcessor
|
26 |
+
from TTS.utils.console_logger import ConsoleLogger
|
27 |
+
from TTS.utils.distribute import (DistributedSampler, apply_gradient_allreduce,
|
28 |
+
init_distributed, reduce_tensor)
|
29 |
+
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
|
30 |
+
create_experiment_folder, get_git_branch,
|
31 |
+
remove_experiment_folder, set_init_dict)
|
32 |
+
from TTS.utils.io import copy_model_files, load_config
|
33 |
+
from TTS.utils.radam import RAdam
|
34 |
+
from TTS.utils.tensorboard_logger import TensorboardLogger
|
35 |
+
from TTS.utils.training import (NoamLR, adam_weight_decay, check_update,
|
36 |
+
gradual_training_scheduler, set_weight_decay,
|
37 |
+
setup_torch_training_env)
|
38 |
+
|
39 |
+
use_cuda, num_gpus = setup_torch_training_env(True, False)
|
40 |
+
|
41 |
+
|
42 |
+
def setup_loader(ap, r, is_val=False, verbose=False, dataset=None):
|
43 |
+
if is_val and not c.run_eval:
|
44 |
+
loader = None
|
45 |
+
else:
|
46 |
+
if dataset is None:
|
47 |
+
dataset = MyDataset(
|
48 |
+
r,
|
49 |
+
c.text_cleaner,
|
50 |
+
compute_linear_spec=c.model.lower() == 'tacotron',
|
51 |
+
meta_data=meta_data_eval if is_val else meta_data_train,
|
52 |
+
ap=ap,
|
53 |
+
tp=c.characters if 'characters' in c.keys() else None,
|
54 |
+
add_blank=c['add_blank'] if 'add_blank' in c.keys() else False,
|
55 |
+
batch_group_size=0 if is_val else c.batch_group_size *
|
56 |
+
c.batch_size,
|
57 |
+
min_seq_len=c.min_seq_len,
|
58 |
+
max_seq_len=c.max_seq_len,
|
59 |
+
phoneme_cache_path=c.phoneme_cache_path,
|
60 |
+
use_phonemes=c.use_phonemes,
|
61 |
+
phoneme_language=c.phoneme_language,
|
62 |
+
enable_eos_bos=c.enable_eos_bos_chars,
|
63 |
+
verbose=verbose,
|
64 |
+
speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None)
|
65 |
+
|
66 |
+
if c.use_phonemes and c.compute_input_seq_cache:
|
67 |
+
# precompute phonemes to have a better estimate of sequence lengths.
|
68 |
+
dataset.compute_input_seq(c.num_loader_workers)
|
69 |
+
dataset.sort_items()
|
70 |
+
|
71 |
+
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
72 |
+
loader = DataLoader(
|
73 |
+
dataset,
|
74 |
+
batch_size=c.eval_batch_size if is_val else c.batch_size,
|
75 |
+
shuffle=False,
|
76 |
+
collate_fn=dataset.collate_fn,
|
77 |
+
drop_last=False,
|
78 |
+
sampler=sampler,
|
79 |
+
num_workers=c.num_val_loader_workers
|
80 |
+
if is_val else c.num_loader_workers,
|
81 |
+
pin_memory=False)
|
82 |
+
return loader
|
83 |
+
|
84 |
+
def format_data(data):
|
85 |
+
# setup input data
|
86 |
+
text_input = data[0]
|
87 |
+
text_lengths = data[1]
|
88 |
+
speaker_names = data[2]
|
89 |
+
linear_input = data[3] if c.model in ["Tacotron"] else None
|
90 |
+
mel_input = data[4]
|
91 |
+
mel_lengths = data[5]
|
92 |
+
stop_targets = data[6]
|
93 |
+
max_text_length = torch.max(text_lengths.float())
|
94 |
+
max_spec_length = torch.max(mel_lengths.float())
|
95 |
+
|
96 |
+
if c.use_speaker_embedding:
|
97 |
+
if c.use_external_speaker_embedding_file:
|
98 |
+
speaker_embeddings = data[8]
|
99 |
+
speaker_ids = None
|
100 |
+
else:
|
101 |
+
speaker_ids = [
|
102 |
+
speaker_mapping[speaker_name] for speaker_name in speaker_names
|
103 |
+
]
|
104 |
+
speaker_ids = torch.LongTensor(speaker_ids)
|
105 |
+
speaker_embeddings = None
|
106 |
+
else:
|
107 |
+
speaker_embeddings = None
|
108 |
+
speaker_ids = None
|
109 |
+
|
110 |
+
|
111 |
+
# set stop targets view, we predict a single stop token per iteration.
|
112 |
+
stop_targets = stop_targets.view(text_input.shape[0],
|
113 |
+
stop_targets.size(1) // c.r, -1)
|
114 |
+
stop_targets = (stop_targets.sum(2) >
|
115 |
+
0.0).unsqueeze(2).float().squeeze(2)
|
116 |
+
|
117 |
+
# dispatch data to GPU
|
118 |
+
if use_cuda:
|
119 |
+
text_input = text_input.cuda(non_blocking=True)
|
120 |
+
text_lengths = text_lengths.cuda(non_blocking=True)
|
121 |
+
mel_input = mel_input.cuda(non_blocking=True)
|
122 |
+
mel_lengths = mel_lengths.cuda(non_blocking=True)
|
123 |
+
linear_input = linear_input.cuda(non_blocking=True) if c.model in ["Tacotron"] else None
|
124 |
+
stop_targets = stop_targets.cuda(non_blocking=True)
|
125 |
+
if speaker_ids is not None:
|
126 |
+
speaker_ids = speaker_ids.cuda(non_blocking=True)
|
127 |
+
if speaker_embeddings is not None:
|
128 |
+
speaker_embeddings = speaker_embeddings.cuda(non_blocking=True)
|
129 |
+
|
130 |
+
return text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, max_text_length, max_spec_length
|
131 |
+
|
132 |
+
|
133 |
+
def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler,
|
134 |
+
ap, global_step, epoch, scaler, scaler_st):
|
135 |
+
model.train()
|
136 |
+
epoch_time = 0
|
137 |
+
keep_avg = KeepAverage()
|
138 |
+
if use_cuda:
|
139 |
+
batch_n_iter = int(
|
140 |
+
len(data_loader.dataset) / (c.batch_size * num_gpus))
|
141 |
+
else:
|
142 |
+
batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
|
143 |
+
end_time = time.time()
|
144 |
+
c_logger.print_train_start()
|
145 |
+
for num_iter, data in enumerate(data_loader):
|
146 |
+
start_time = time.time()
|
147 |
+
|
148 |
+
# format data
|
149 |
+
text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, max_text_length, max_spec_length = format_data(data)
|
150 |
+
loader_time = time.time() - end_time
|
151 |
+
|
152 |
+
global_step += 1
|
153 |
+
|
154 |
+
# setup lr
|
155 |
+
if c.noam_schedule:
|
156 |
+
scheduler.step()
|
157 |
+
|
158 |
+
optimizer.zero_grad()
|
159 |
+
if optimizer_st:
|
160 |
+
optimizer_st.zero_grad()
|
161 |
+
|
162 |
+
with torch.cuda.amp.autocast(enabled=c.mixed_precision):
|
163 |
+
# forward pass model
|
164 |
+
if c.bidirectional_decoder or c.double_decoder_consistency:
|
165 |
+
decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model(
|
166 |
+
text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings)
|
167 |
+
else:
|
168 |
+
decoder_output, postnet_output, alignments, stop_tokens = model(
|
169 |
+
text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings)
|
170 |
+
decoder_backward_output = None
|
171 |
+
alignments_backward = None
|
172 |
+
|
173 |
+
# set the [alignment] lengths wrt reduction factor for guided attention
|
174 |
+
if mel_lengths.max() % model.decoder.r != 0:
|
175 |
+
alignment_lengths = (mel_lengths + (model.decoder.r - (mel_lengths.max() % model.decoder.r))) // model.decoder.r
|
176 |
+
else:
|
177 |
+
alignment_lengths = mel_lengths // model.decoder.r
|
178 |
+
|
179 |
+
# compute loss
|
180 |
+
loss_dict = criterion(postnet_output, decoder_output, mel_input,
|
181 |
+
linear_input, stop_tokens, stop_targets,
|
182 |
+
mel_lengths, decoder_backward_output,
|
183 |
+
alignments, alignment_lengths, alignments_backward,
|
184 |
+
text_lengths)
|
185 |
+
|
186 |
+
# check nan loss
|
187 |
+
if torch.isnan(loss_dict['loss']).any():
|
188 |
+
raise RuntimeError(f'Detected NaN loss at step {global_step}.')
|
189 |
+
|
190 |
+
# optimizer step
|
191 |
+
if c.mixed_precision:
|
192 |
+
# model optimizer step in mixed precision mode
|
193 |
+
scaler.scale(loss_dict['loss']).backward()
|
194 |
+
scaler.unscale_(optimizer)
|
195 |
+
optimizer, current_lr = adam_weight_decay(optimizer)
|
196 |
+
grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True)
|
197 |
+
scaler.step(optimizer)
|
198 |
+
scaler.update()
|
199 |
+
|
200 |
+
# stopnet optimizer step
|
201 |
+
if c.separate_stopnet:
|
202 |
+
scaler_st.scale( loss_dict['stopnet_loss']).backward()
|
203 |
+
scaler.unscale_(optimizer_st)
|
204 |
+
optimizer_st, _ = adam_weight_decay(optimizer_st)
|
205 |
+
grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
|
206 |
+
scaler_st.step(optimizer)
|
207 |
+
scaler_st.update()
|
208 |
+
else:
|
209 |
+
grad_norm_st = 0
|
210 |
+
else:
|
211 |
+
# main model optimizer step
|
212 |
+
loss_dict['loss'].backward()
|
213 |
+
optimizer, current_lr = adam_weight_decay(optimizer)
|
214 |
+
grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True)
|
215 |
+
optimizer.step()
|
216 |
+
|
217 |
+
# stopnet optimizer step
|
218 |
+
if c.separate_stopnet:
|
219 |
+
loss_dict['stopnet_loss'].backward()
|
220 |
+
optimizer_st, _ = adam_weight_decay(optimizer_st)
|
221 |
+
grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
|
222 |
+
optimizer_st.step()
|
223 |
+
else:
|
224 |
+
grad_norm_st = 0
|
225 |
+
|
226 |
+
# compute alignment error (the lower the better )
|
227 |
+
align_error = 1 - alignment_diagonal_score(alignments)
|
228 |
+
loss_dict['align_error'] = align_error
|
229 |
+
|
230 |
+
step_time = time.time() - start_time
|
231 |
+
epoch_time += step_time
|
232 |
+
|
233 |
+
# aggregate losses from processes
|
234 |
+
if num_gpus > 1:
|
235 |
+
loss_dict['postnet_loss'] = reduce_tensor(loss_dict['postnet_loss'].data, num_gpus)
|
236 |
+
loss_dict['decoder_loss'] = reduce_tensor(loss_dict['decoder_loss'].data, num_gpus)
|
237 |
+
loss_dict['loss'] = reduce_tensor(loss_dict['loss'] .data, num_gpus)
|
238 |
+
loss_dict['stopnet_loss'] = reduce_tensor(loss_dict['stopnet_loss'].data, num_gpus) if c.stopnet else loss_dict['stopnet_loss']
|
239 |
+
|
240 |
+
# detach loss values
|
241 |
+
loss_dict_new = dict()
|
242 |
+
for key, value in loss_dict.items():
|
243 |
+
if isinstance(value, (int, float)):
|
244 |
+
loss_dict_new[key] = value
|
245 |
+
else:
|
246 |
+
loss_dict_new[key] = value.item()
|
247 |
+
loss_dict = loss_dict_new
|
248 |
+
|
249 |
+
# update avg stats
|
250 |
+
update_train_values = dict()
|
251 |
+
for key, value in loss_dict.items():
|
252 |
+
update_train_values['avg_' + key] = value
|
253 |
+
update_train_values['avg_loader_time'] = loader_time
|
254 |
+
update_train_values['avg_step_time'] = step_time
|
255 |
+
keep_avg.update_values(update_train_values)
|
256 |
+
|
257 |
+
# print training progress
|
258 |
+
if global_step % c.print_step == 0:
|
259 |
+
log_dict = {
|
260 |
+
"max_spec_length": [max_spec_length, 1], # value, precision
|
261 |
+
"max_text_length": [max_text_length, 1],
|
262 |
+
"step_time": [step_time, 4],
|
263 |
+
"loader_time": [loader_time, 2],
|
264 |
+
"current_lr": current_lr,
|
265 |
+
}
|
266 |
+
c_logger.print_train_step(batch_n_iter, num_iter, global_step,
|
267 |
+
log_dict, loss_dict, keep_avg.avg_values)
|
268 |
+
|
269 |
+
if args.rank == 0:
|
270 |
+
# Plot Training Iter Stats
|
271 |
+
# reduce TB load
|
272 |
+
if global_step % c.tb_plot_step == 0:
|
273 |
+
iter_stats = {
|
274 |
+
"lr": current_lr,
|
275 |
+
"grad_norm": grad_norm,
|
276 |
+
"grad_norm_st": grad_norm_st,
|
277 |
+
"step_time": step_time
|
278 |
+
}
|
279 |
+
iter_stats.update(loss_dict)
|
280 |
+
tb_logger.tb_train_iter_stats(global_step, iter_stats)
|
281 |
+
|
282 |
+
if global_step % c.save_step == 0:
|
283 |
+
if c.checkpoint:
|
284 |
+
# save model
|
285 |
+
save_checkpoint(model, optimizer, global_step, epoch, model.decoder.r, OUT_PATH,
|
286 |
+
optimizer_st=optimizer_st,
|
287 |
+
model_loss=loss_dict['postnet_loss'],
|
288 |
+
scaler=scaler.state_dict() if c.mixed_precision else None)
|
289 |
+
|
290 |
+
# Diagnostic visualizations
|
291 |
+
const_spec = postnet_output[0].data.cpu().numpy()
|
292 |
+
gt_spec = linear_input[0].data.cpu().numpy() if c.model in [
|
293 |
+
"Tacotron", "TacotronGST"
|
294 |
+
] else mel_input[0].data.cpu().numpy()
|
295 |
+
align_img = alignments[0].data.cpu().numpy()
|
296 |
+
|
297 |
+
figures = {
|
298 |
+
"prediction": plot_spectrogram(const_spec, ap, output_fig=False),
|
299 |
+
"ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
|
300 |
+
"alignment": plot_alignment(align_img, output_fig=False),
|
301 |
+
}
|
302 |
+
|
303 |
+
if c.bidirectional_decoder or c.double_decoder_consistency:
|
304 |
+
figures["alignment_backward"] = plot_alignment(alignments_backward[0].data.cpu().numpy(), output_fig=False)
|
305 |
+
|
306 |
+
tb_logger.tb_train_figures(global_step, figures)
|
307 |
+
|
308 |
+
# Sample audio
|
309 |
+
if c.model in ["Tacotron", "TacotronGST"]:
|
310 |
+
train_audio = ap.inv_spectrogram(const_spec.T)
|
311 |
+
else:
|
312 |
+
train_audio = ap.inv_melspectrogram(const_spec.T)
|
313 |
+
tb_logger.tb_train_audios(global_step,
|
314 |
+
{'TrainAudio': train_audio},
|
315 |
+
c.audio["sample_rate"])
|
316 |
+
end_time = time.time()
|
317 |
+
|
318 |
+
# print epoch stats
|
319 |
+
c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)
|
320 |
+
|
321 |
+
# Plot Epoch Stats
|
322 |
+
if args.rank == 0:
|
323 |
+
epoch_stats = {"epoch_time": epoch_time}
|
324 |
+
epoch_stats.update(keep_avg.avg_values)
|
325 |
+
tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
|
326 |
+
if c.tb_model_param_stats:
|
327 |
+
tb_logger.tb_model_weights(model, global_step)
|
328 |
+
return keep_avg.avg_values, global_step
|
329 |
+
|
330 |
+
|
331 |
+
@torch.no_grad()
|
332 |
+
def evaluate(data_loader, model, criterion, ap, global_step, epoch):
|
333 |
+
model.eval()
|
334 |
+
epoch_time = 0
|
335 |
+
keep_avg = KeepAverage()
|
336 |
+
c_logger.print_eval_start()
|
337 |
+
if data_loader is not None:
|
338 |
+
for num_iter, data in enumerate(data_loader):
|
339 |
+
start_time = time.time()
|
340 |
+
|
341 |
+
# format data
|
342 |
+
text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, _, _ = format_data(data)
|
343 |
+
assert mel_input.shape[1] % model.decoder.r == 0
|
344 |
+
|
345 |
+
# forward pass model
|
346 |
+
if c.bidirectional_decoder or c.double_decoder_consistency:
|
347 |
+
decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model(
|
348 |
+
text_input, text_lengths, mel_input, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings)
|
349 |
+
else:
|
350 |
+
decoder_output, postnet_output, alignments, stop_tokens = model(
|
351 |
+
text_input, text_lengths, mel_input, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings)
|
352 |
+
decoder_backward_output = None
|
353 |
+
alignments_backward = None
|
354 |
+
|
355 |
+
# set the alignment lengths wrt reduction factor for guided attention
|
356 |
+
if mel_lengths.max() % model.decoder.r != 0:
|
357 |
+
alignment_lengths = (mel_lengths + (model.decoder.r - (mel_lengths.max() % model.decoder.r))) // model.decoder.r
|
358 |
+
else:
|
359 |
+
alignment_lengths = mel_lengths // model.decoder.r
|
360 |
+
|
361 |
+
# compute loss
|
362 |
+
loss_dict = criterion(postnet_output, decoder_output, mel_input,
|
363 |
+
linear_input, stop_tokens, stop_targets,
|
364 |
+
mel_lengths, decoder_backward_output,
|
365 |
+
alignments, alignment_lengths, alignments_backward,
|
366 |
+
text_lengths)
|
367 |
+
|
368 |
+
# step time
|
369 |
+
step_time = time.time() - start_time
|
370 |
+
epoch_time += step_time
|
371 |
+
|
372 |
+
# compute alignment score
|
373 |
+
align_error = 1 - alignment_diagonal_score(alignments)
|
374 |
+
loss_dict['align_error'] = align_error
|
375 |
+
|
376 |
+
# aggregate losses from processes
|
377 |
+
if num_gpus > 1:
|
378 |
+
loss_dict['postnet_loss'] = reduce_tensor(loss_dict['postnet_loss'].data, num_gpus)
|
379 |
+
loss_dict['decoder_loss'] = reduce_tensor(loss_dict['decoder_loss'].data, num_gpus)
|
380 |
+
if c.stopnet:
|
381 |
+
loss_dict['stopnet_loss'] = reduce_tensor(loss_dict['stopnet_loss'].data, num_gpus)
|
382 |
+
|
383 |
+
# detach loss values
|
384 |
+
loss_dict_new = dict()
|
385 |
+
for key, value in loss_dict.items():
|
386 |
+
if isinstance(value, (int, float)):
|
387 |
+
loss_dict_new[key] = value
|
388 |
+
else:
|
389 |
+
loss_dict_new[key] = value.item()
|
390 |
+
loss_dict = loss_dict_new
|
391 |
+
|
392 |
+
# update avg stats
|
393 |
+
update_train_values = dict()
|
394 |
+
for key, value in loss_dict.items():
|
395 |
+
update_train_values['avg_' + key] = value
|
396 |
+
keep_avg.update_values(update_train_values)
|
397 |
+
|
398 |
+
if c.print_eval:
|
399 |
+
c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)
|
400 |
+
|
401 |
+
if args.rank == 0:
|
402 |
+
# Diagnostic visualizations
|
403 |
+
idx = np.random.randint(mel_input.shape[0])
|
404 |
+
const_spec = postnet_output[idx].data.cpu().numpy()
|
405 |
+
gt_spec = linear_input[idx].data.cpu().numpy() if c.model in [
|
406 |
+
"Tacotron", "TacotronGST"
|
407 |
+
] else mel_input[idx].data.cpu().numpy()
|
408 |
+
align_img = alignments[idx].data.cpu().numpy()
|
409 |
+
|
410 |
+
eval_figures = {
|
411 |
+
"prediction": plot_spectrogram(const_spec, ap, output_fig=False),
|
412 |
+
"ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
|
413 |
+
"alignment": plot_alignment(align_img, output_fig=False)
|
414 |
+
}
|
415 |
+
|
416 |
+
# Sample audio
|
417 |
+
if c.model in ["Tacotron", "TacotronGST"]:
|
418 |
+
eval_audio = ap.inv_spectrogram(const_spec.T)
|
419 |
+
else:
|
420 |
+
eval_audio = ap.inv_melspectrogram(const_spec.T)
|
421 |
+
tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio},
|
422 |
+
c.audio["sample_rate"])
|
423 |
+
|
424 |
+
# Plot Validation Stats
|
425 |
+
|
426 |
+
if c.bidirectional_decoder or c.double_decoder_consistency:
|
427 |
+
align_b_img = alignments_backward[idx].data.cpu().numpy()
|
428 |
+
eval_figures['alignment2'] = plot_alignment(align_b_img, output_fig=False)
|
429 |
+
tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
|
430 |
+
tb_logger.tb_eval_figures(global_step, eval_figures)
|
431 |
+
|
432 |
+
if args.rank == 0 and epoch > c.test_delay_epochs:
|
433 |
+
if c.test_sentences_file is None:
|
434 |
+
test_sentences = [
|
435 |
+
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
|
436 |
+
"Be a voice, not an echo.",
|
437 |
+
"I'm sorry Dave. I'm afraid I can't do that.",
|
438 |
+
"This cake is great. It's so delicious and moist.",
|
439 |
+
"Prior to November 22, 1963."
|
440 |
+
]
|
441 |
+
else:
|
442 |
+
with open(c.test_sentences_file, "r") as f:
|
443 |
+
test_sentences = [s.strip() for s in f.readlines()]
|
444 |
+
|
445 |
+
# test sentences
|
446 |
+
test_audios = {}
|
447 |
+
test_figures = {}
|
448 |
+
print(" | > Synthesizing test sentences")
|
449 |
+
speaker_id = 0 if c.use_speaker_embedding else None
|
450 |
+
speaker_embedding = speaker_mapping[list(speaker_mapping.keys())[randrange(len(speaker_mapping)-1)]]['embedding'] if c.use_external_speaker_embedding_file and c.use_speaker_embedding else None
|
451 |
+
style_wav = c.get("gst_style_input")
|
452 |
+
if style_wav is None and c.use_gst:
|
453 |
+
# inicialize GST with zero dict.
|
454 |
+
style_wav = {}
|
455 |
+
print("WARNING: You don't provided a gst style wav, for this reason we use a zero tensor!")
|
456 |
+
for i in range(c.gst['gst_style_tokens']):
|
457 |
+
style_wav[str(i)] = 0
|
458 |
+
style_wav = c.get("gst_style_input")
|
459 |
+
for idx, test_sentence in enumerate(test_sentences):
|
460 |
+
try:
|
461 |
+
wav, alignment, decoder_output, postnet_output, stop_tokens, _ = synthesis(
|
462 |
+
model,
|
463 |
+
test_sentence,
|
464 |
+
c,
|
465 |
+
use_cuda,
|
466 |
+
ap,
|
467 |
+
speaker_id=speaker_id,
|
468 |
+
speaker_embedding=speaker_embedding,
|
469 |
+
style_wav=style_wav,
|
470 |
+
truncated=False,
|
471 |
+
enable_eos_bos_chars=c.enable_eos_bos_chars, #pylint: disable=unused-argument
|
472 |
+
use_griffin_lim=True,
|
473 |
+
do_trim_silence=False)
|
474 |
+
|
475 |
+
file_path = os.path.join(AUDIO_PATH, str(global_step))
|
476 |
+
os.makedirs(file_path, exist_ok=True)
|
477 |
+
file_path = os.path.join(file_path,
|
478 |
+
"TestSentence_{}.wav".format(idx))
|
479 |
+
ap.save_wav(wav, file_path)
|
480 |
+
test_audios['{}-audio'.format(idx)] = wav
|
481 |
+
test_figures['{}-prediction'.format(idx)] = plot_spectrogram(
|
482 |
+
postnet_output, ap, output_fig=False)
|
483 |
+
test_figures['{}-alignment'.format(idx)] = plot_alignment(
|
484 |
+
alignment, output_fig=False)
|
485 |
+
except: #pylint: disable=bare-except
|
486 |
+
print(" !! Error creating Test Sentence -", idx)
|
487 |
+
traceback.print_exc()
|
488 |
+
tb_logger.tb_test_audios(global_step, test_audios,
|
489 |
+
c.audio['sample_rate'])
|
490 |
+
tb_logger.tb_test_figures(global_step, test_figures)
|
491 |
+
return keep_avg.avg_values
|
492 |
+
|
493 |
+
|
494 |
+
# FIXME: move args definition/parsing inside of main?
|
495 |
+
def main(args): # pylint: disable=redefined-outer-name
|
496 |
+
# pylint: disable=global-variable-undefined
|
497 |
+
global meta_data_train, meta_data_eval, symbols, phonemes, speaker_mapping
|
498 |
+
# Audio processor
|
499 |
+
ap = AudioProcessor(**c.audio)
|
500 |
+
if 'characters' in c.keys():
|
501 |
+
symbols, phonemes = make_symbols(**c.characters)
|
502 |
+
|
503 |
+
# DISTRUBUTED
|
504 |
+
if num_gpus > 1:
|
505 |
+
init_distributed(args.rank, num_gpus, args.group_id,
|
506 |
+
c.distributed["backend"], c.distributed["url"])
|
507 |
+
num_chars = len(phonemes) if c.use_phonemes else len(symbols)
|
508 |
+
|
509 |
+
# load data instances
|
510 |
+
meta_data_train, meta_data_eval = load_meta_data(c.datasets)
|
511 |
+
|
512 |
+
# set the portion of the data used for training
|
513 |
+
if 'train_portion' in c.keys():
|
514 |
+
meta_data_train = meta_data_train[:int(len(meta_data_train) * c.train_portion)]
|
515 |
+
if 'eval_portion' in c.keys():
|
516 |
+
meta_data_eval = meta_data_eval[:int(len(meta_data_eval) * c.eval_portion)]
|
517 |
+
|
518 |
+
# parse speakers
|
519 |
+
num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers(c, args, meta_data_train, OUT_PATH)
|
520 |
+
|
521 |
+
model = setup_model(num_chars, num_speakers, c, speaker_embedding_dim)
|
522 |
+
|
523 |
+
# scalers for mixed precision training
|
524 |
+
scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None
|
525 |
+
scaler_st = torch.cuda.amp.GradScaler() if c.mixed_precision and c.separate_stopnet else None
|
526 |
+
|
527 |
+
params = set_weight_decay(model, c.wd)
|
528 |
+
optimizer = RAdam(params, lr=c.lr, weight_decay=0)
|
529 |
+
if c.stopnet and c.separate_stopnet:
|
530 |
+
optimizer_st = RAdam(model.decoder.stopnet.parameters(),
|
531 |
+
lr=c.lr,
|
532 |
+
weight_decay=0)
|
533 |
+
else:
|
534 |
+
optimizer_st = None
|
535 |
+
|
536 |
+
# setup criterion
|
537 |
+
criterion = TacotronLoss(c, stopnet_pos_weight=10.0, ga_sigma=0.4)
|
538 |
+
|
539 |
+
if args.restore_path:
|
540 |
+
checkpoint = torch.load(args.restore_path, map_location='cpu')
|
541 |
+
try:
|
542 |
+
print(" > Restoring Model.")
|
543 |
+
model.load_state_dict(checkpoint['model'])
|
544 |
+
# optimizer restore
|
545 |
+
print(" > Restoring Optimizer.")
|
546 |
+
optimizer.load_state_dict(checkpoint['optimizer'])
|
547 |
+
if "scaler" in checkpoint and c.mixed_precision:
|
548 |
+
print(" > Restoring AMP Scaler...")
|
549 |
+
scaler.load_state_dict(checkpoint["scaler"])
|
550 |
+
if c.reinit_layers:
|
551 |
+
raise RuntimeError
|
552 |
+
except (KeyError, RuntimeError):
|
553 |
+
print(" > Partial model initialization.")
|
554 |
+
model_dict = model.state_dict()
|
555 |
+
model_dict = set_init_dict(model_dict, checkpoint['model'], c)
|
556 |
+
# torch.save(model_dict, os.path.join(OUT_PATH, 'state_dict.pt'))
|
557 |
+
# print("State Dict saved for debug in: ", os.path.join(OUT_PATH, 'state_dict.pt'))
|
558 |
+
model.load_state_dict(model_dict)
|
559 |
+
del model_dict
|
560 |
+
|
561 |
+
for group in optimizer.param_groups:
|
562 |
+
group['lr'] = c.lr
|
563 |
+
print(" > Model restored from step %d" % checkpoint['step'],
|
564 |
+
flush=True)
|
565 |
+
args.restore_step = checkpoint['step']
|
566 |
+
else:
|
567 |
+
args.restore_step = 0
|
568 |
+
|
569 |
+
if use_cuda:
|
570 |
+
model.cuda()
|
571 |
+
criterion.cuda()
|
572 |
+
|
573 |
+
# DISTRUBUTED
|
574 |
+
if num_gpus > 1:
|
575 |
+
model = apply_gradient_allreduce(model)
|
576 |
+
|
577 |
+
if c.noam_schedule:
|
578 |
+
scheduler = NoamLR(optimizer,
|
579 |
+
warmup_steps=c.warmup_steps,
|
580 |
+
last_epoch=args.restore_step - 1)
|
581 |
+
else:
|
582 |
+
scheduler = None
|
583 |
+
|
584 |
+
num_params = count_parameters(model)
|
585 |
+
print("\n > Model has {} parameters".format(num_params), flush=True)
|
586 |
+
|
587 |
+
if 'best_loss' not in locals():
|
588 |
+
best_loss = float('inf')
|
589 |
+
|
590 |
+
# define data loaders
|
591 |
+
train_loader = setup_loader(ap,
|
592 |
+
model.decoder.r,
|
593 |
+
is_val=False,
|
594 |
+
verbose=True)
|
595 |
+
eval_loader = setup_loader(ap, model.decoder.r, is_val=True)
|
596 |
+
|
597 |
+
global_step = args.restore_step
|
598 |
+
for epoch in range(0, c.epochs):
|
599 |
+
c_logger.print_epoch_start(epoch, c.epochs)
|
600 |
+
# set gradual training
|
601 |
+
if c.gradual_training is not None:
|
602 |
+
r, c.batch_size = gradual_training_scheduler(global_step, c)
|
603 |
+
c.r = r
|
604 |
+
model.decoder.set_r(r)
|
605 |
+
if c.bidirectional_decoder:
|
606 |
+
model.decoder_backward.set_r(r)
|
607 |
+
train_loader.dataset.outputs_per_step = r
|
608 |
+
eval_loader.dataset.outputs_per_step = r
|
609 |
+
train_loader = setup_loader(ap,
|
610 |
+
model.decoder.r,
|
611 |
+
is_val=False,
|
612 |
+
dataset=train_loader.dataset)
|
613 |
+
eval_loader = setup_loader(ap,
|
614 |
+
model.decoder.r,
|
615 |
+
is_val=True,
|
616 |
+
dataset=eval_loader.dataset)
|
617 |
+
print("\n > Number of output frames:", model.decoder.r)
|
618 |
+
# train one epoch
|
619 |
+
train_avg_loss_dict, global_step = train(train_loader, model,
|
620 |
+
criterion, optimizer,
|
621 |
+
optimizer_st, scheduler, ap,
|
622 |
+
global_step, epoch, scaler,
|
623 |
+
scaler_st)
|
624 |
+
# eval one epoch
|
625 |
+
eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap,
|
626 |
+
global_step, epoch)
|
627 |
+
c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
|
628 |
+
target_loss = train_avg_loss_dict['avg_postnet_loss']
|
629 |
+
if c.run_eval:
|
630 |
+
target_loss = eval_avg_loss_dict['avg_postnet_loss']
|
631 |
+
best_loss = save_best_model(
|
632 |
+
target_loss,
|
633 |
+
best_loss,
|
634 |
+
model,
|
635 |
+
optimizer,
|
636 |
+
global_step,
|
637 |
+
epoch,
|
638 |
+
c.r,
|
639 |
+
OUT_PATH,
|
640 |
+
scaler=scaler.state_dict() if c.mixed_precision else None)
|
641 |
+
|
642 |
+
|
643 |
+
if __name__ == '__main__':
|
644 |
+
parser = argparse.ArgumentParser()
|
645 |
+
parser.add_argument(
|
646 |
+
'--continue_path',
|
647 |
+
type=str,
|
648 |
+
help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.',
|
649 |
+
default='',
|
650 |
+
required='--config_path' not in sys.argv)
|
651 |
+
parser.add_argument(
|
652 |
+
'--restore_path',
|
653 |
+
type=str,
|
654 |
+
help='Model file to be restored. Use to finetune a model.',
|
655 |
+
default='')
|
656 |
+
parser.add_argument(
|
657 |
+
'--config_path',
|
658 |
+
type=str,
|
659 |
+
help='Path to config file for training.',
|
660 |
+
required='--continue_path' not in sys.argv
|
661 |
+
)
|
662 |
+
parser.add_argument('--debug',
|
663 |
+
type=bool,
|
664 |
+
default=False,
|
665 |
+
help='Do not verify commit integrity to run training.')
|
666 |
+
|
667 |
+
# DISTRUBUTED
|
668 |
+
parser.add_argument(
|
669 |
+
'--rank',
|
670 |
+
type=int,
|
671 |
+
default=0,
|
672 |
+
help='DISTRIBUTED: process rank for distributed training.')
|
673 |
+
parser.add_argument('--group_id',
|
674 |
+
type=str,
|
675 |
+
default="",
|
676 |
+
help='DISTRIBUTED: process group id.')
|
677 |
+
args = parser.parse_args()
|
678 |
+
|
679 |
+
if args.continue_path != '':
|
680 |
+
print(f" > Training continues for {args.continue_path}")
|
681 |
+
args.output_path = args.continue_path
|
682 |
+
args.config_path = os.path.join(args.continue_path, 'config.json')
|
683 |
+
list_of_files = glob.glob(args.continue_path + "/*.pth.tar") # * means all if need specific format then *.csv
|
684 |
+
latest_model_file = max(list_of_files, key=os.path.getctime)
|
685 |
+
args.restore_path = latest_model_file
|
686 |
+
|
687 |
+
# setup output paths and read configs
|
688 |
+
c = load_config(args.config_path)
|
689 |
+
check_config_tts(c)
|
690 |
+
_ = os.path.dirname(os.path.realpath(__file__))
|
691 |
+
|
692 |
+
if c.mixed_precision:
|
693 |
+
print(" > Mixed precision mode is ON")
|
694 |
+
|
695 |
+
OUT_PATH = args.continue_path
|
696 |
+
if args.continue_path == '':
|
697 |
+
OUT_PATH = create_experiment_folder(c.output_path, c.run_name, args.debug)
|
698 |
+
|
699 |
+
AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios')
|
700 |
+
|
701 |
+
c_logger = ConsoleLogger()
|
702 |
+
|
703 |
+
if args.rank == 0:
|
704 |
+
os.makedirs(AUDIO_PATH, exist_ok=True)
|
705 |
+
new_fields = {}
|
706 |
+
if args.restore_path:
|
707 |
+
new_fields["restore_path"] = args.restore_path
|
708 |
+
new_fields["github_branch"] = get_git_branch()
|
709 |
+
copy_model_files(c, args.config_path,
|
710 |
+
OUT_PATH, new_fields)
|
711 |
+
os.chmod(AUDIO_PATH, 0o775)
|
712 |
+
os.chmod(OUT_PATH, 0o775)
|
713 |
+
|
714 |
+
LOG_DIR = OUT_PATH
|
715 |
+
tb_logger = TensorboardLogger(LOG_DIR, model_name='TTS')
|
716 |
+
|
717 |
+
# write model desc to tensorboard
|
718 |
+
tb_logger.tb_add_text('model-description', c['run_description'], 0)
|
719 |
+
|
720 |
+
try:
|
721 |
+
main(args)
|
722 |
+
except KeyboardInterrupt:
|
723 |
+
remove_experiment_folder(OUT_PATH)
|
724 |
+
try:
|
725 |
+
sys.exit(0)
|
726 |
+
except SystemExit:
|
727 |
+
os._exit(0) # pylint: disable=protected-access
|
728 |
+
except Exception: # pylint: disable=broad-except
|
729 |
+
remove_experiment_folder(OUT_PATH)
|
730 |
+
traceback.print_exc()
|
731 |
+
sys.exit(1)
|
TTS/bin/train_vocoder_gan.py
ADDED
@@ -0,0 +1,664 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import glob
|
3 |
+
import os
|
4 |
+
import sys
|
5 |
+
import time
|
6 |
+
import traceback
|
7 |
+
from inspect import signature
|
8 |
+
|
9 |
+
import torch
|
10 |
+
from torch.utils.data import DataLoader
|
11 |
+
from TTS.utils.audio import AudioProcessor
|
12 |
+
from TTS.utils.console_logger import ConsoleLogger
|
13 |
+
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
|
14 |
+
create_experiment_folder, get_git_branch,
|
15 |
+
remove_experiment_folder, set_init_dict)
|
16 |
+
from TTS.utils.io import copy_model_files, load_config
|
17 |
+
from TTS.utils.radam import RAdam
|
18 |
+
from TTS.utils.tensorboard_logger import TensorboardLogger
|
19 |
+
from TTS.utils.training import setup_torch_training_env
|
20 |
+
from TTS.vocoder.datasets.gan_dataset import GANDataset
|
21 |
+
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
|
22 |
+
from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss
|
23 |
+
from TTS.vocoder.utils.generic_utils import (plot_results, setup_discriminator,
|
24 |
+
setup_generator)
|
25 |
+
from TTS.vocoder.utils.io import save_best_model, save_checkpoint
|
26 |
+
|
27 |
+
# DISTRIBUTED
|
28 |
+
from torch.nn.parallel import DistributedDataParallel as DDP_th
|
29 |
+
from torch.utils.data.distributed import DistributedSampler
|
30 |
+
from TTS.utils.distribute import init_distributed
|
31 |
+
|
32 |
+
use_cuda, num_gpus = setup_torch_training_env(True, True)
|
33 |
+
|
34 |
+
|
35 |
+
def setup_loader(ap, is_val=False, verbose=False):
|
36 |
+
if is_val and not c.run_eval:
|
37 |
+
loader = None
|
38 |
+
else:
|
39 |
+
dataset = GANDataset(ap=ap,
|
40 |
+
items=eval_data if is_val else train_data,
|
41 |
+
seq_len=c.seq_len,
|
42 |
+
hop_len=ap.hop_length,
|
43 |
+
pad_short=c.pad_short,
|
44 |
+
conv_pad=c.conv_pad,
|
45 |
+
is_training=not is_val,
|
46 |
+
return_segments=not is_val,
|
47 |
+
use_noise_augment=c.use_noise_augment,
|
48 |
+
use_cache=c.use_cache,
|
49 |
+
verbose=verbose)
|
50 |
+
dataset.shuffle_mapping()
|
51 |
+
sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None
|
52 |
+
loader = DataLoader(dataset,
|
53 |
+
batch_size=1 if is_val else c.batch_size,
|
54 |
+
shuffle=False if num_gpus > 1 else True,
|
55 |
+
drop_last=False,
|
56 |
+
sampler=sampler,
|
57 |
+
num_workers=c.num_val_loader_workers
|
58 |
+
if is_val else c.num_loader_workers,
|
59 |
+
pin_memory=False)
|
60 |
+
return loader
|
61 |
+
|
62 |
+
|
63 |
+
def format_data(data):
|
64 |
+
if isinstance(data[0], list):
|
65 |
+
# setup input data
|
66 |
+
c_G, x_G = data[0]
|
67 |
+
c_D, x_D = data[1]
|
68 |
+
|
69 |
+
# dispatch data to GPU
|
70 |
+
if use_cuda:
|
71 |
+
c_G = c_G.cuda(non_blocking=True)
|
72 |
+
x_G = x_G.cuda(non_blocking=True)
|
73 |
+
c_D = c_D.cuda(non_blocking=True)
|
74 |
+
x_D = x_D.cuda(non_blocking=True)
|
75 |
+
|
76 |
+
return c_G, x_G, c_D, x_D
|
77 |
+
|
78 |
+
# return a whole audio segment
|
79 |
+
co, x = data
|
80 |
+
if use_cuda:
|
81 |
+
co = co.cuda(non_blocking=True)
|
82 |
+
x = x.cuda(non_blocking=True)
|
83 |
+
return co, x, None, None
|
84 |
+
|
85 |
+
|
86 |
+
def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
|
87 |
+
scheduler_G, scheduler_D, ap, global_step, epoch):
|
88 |
+
data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
|
89 |
+
model_G.train()
|
90 |
+
model_D.train()
|
91 |
+
epoch_time = 0
|
92 |
+
keep_avg = KeepAverage()
|
93 |
+
if use_cuda:
|
94 |
+
batch_n_iter = int(
|
95 |
+
len(data_loader.dataset) / (c.batch_size * num_gpus))
|
96 |
+
else:
|
97 |
+
batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
|
98 |
+
end_time = time.time()
|
99 |
+
c_logger.print_train_start()
|
100 |
+
for num_iter, data in enumerate(data_loader):
|
101 |
+
start_time = time.time()
|
102 |
+
|
103 |
+
# format data
|
104 |
+
c_G, y_G, c_D, y_D = format_data(data)
|
105 |
+
loader_time = time.time() - end_time
|
106 |
+
|
107 |
+
global_step += 1
|
108 |
+
|
109 |
+
##############################
|
110 |
+
# GENERATOR
|
111 |
+
##############################
|
112 |
+
|
113 |
+
# generator pass
|
114 |
+
y_hat = model_G(c_G)
|
115 |
+
y_hat_sub = None
|
116 |
+
y_G_sub = None
|
117 |
+
y_hat_vis = y_hat # for visualization
|
118 |
+
|
119 |
+
# PQMF formatting
|
120 |
+
if y_hat.shape[1] > 1:
|
121 |
+
y_hat_sub = y_hat
|
122 |
+
y_hat = model_G.pqmf_synthesis(y_hat)
|
123 |
+
y_hat_vis = y_hat
|
124 |
+
y_G_sub = model_G.pqmf_analysis(y_G)
|
125 |
+
|
126 |
+
scores_fake, feats_fake, feats_real = None, None, None
|
127 |
+
if global_step > c.steps_to_start_discriminator:
|
128 |
+
|
129 |
+
# run D with or without cond. features
|
130 |
+
if len(signature(model_D.forward).parameters) == 2:
|
131 |
+
D_out_fake = model_D(y_hat, c_G)
|
132 |
+
else:
|
133 |
+
D_out_fake = model_D(y_hat)
|
134 |
+
D_out_real = None
|
135 |
+
|
136 |
+
if c.use_feat_match_loss:
|
137 |
+
with torch.no_grad():
|
138 |
+
D_out_real = model_D(y_G)
|
139 |
+
|
140 |
+
# format D outputs
|
141 |
+
if isinstance(D_out_fake, tuple):
|
142 |
+
scores_fake, feats_fake = D_out_fake
|
143 |
+
if D_out_real is None:
|
144 |
+
feats_real = None
|
145 |
+
else:
|
146 |
+
_, feats_real = D_out_real
|
147 |
+
else:
|
148 |
+
scores_fake = D_out_fake
|
149 |
+
|
150 |
+
# compute losses
|
151 |
+
loss_G_dict = criterion_G(y_hat, y_G, scores_fake, feats_fake,
|
152 |
+
feats_real, y_hat_sub, y_G_sub)
|
153 |
+
loss_G = loss_G_dict['G_loss']
|
154 |
+
|
155 |
+
# optimizer generator
|
156 |
+
optimizer_G.zero_grad()
|
157 |
+
loss_G.backward()
|
158 |
+
if c.gen_clip_grad > 0:
|
159 |
+
torch.nn.utils.clip_grad_norm_(model_G.parameters(),
|
160 |
+
c.gen_clip_grad)
|
161 |
+
optimizer_G.step()
|
162 |
+
if scheduler_G is not None:
|
163 |
+
scheduler_G.step()
|
164 |
+
|
165 |
+
loss_dict = dict()
|
166 |
+
for key, value in loss_G_dict.items():
|
167 |
+
if isinstance(value, int):
|
168 |
+
loss_dict[key] = value
|
169 |
+
else:
|
170 |
+
loss_dict[key] = value.item()
|
171 |
+
|
172 |
+
##############################
|
173 |
+
# DISCRIMINATOR
|
174 |
+
##############################
|
175 |
+
if global_step >= c.steps_to_start_discriminator:
|
176 |
+
# discriminator pass
|
177 |
+
with torch.no_grad():
|
178 |
+
y_hat = model_G(c_D)
|
179 |
+
|
180 |
+
# PQMF formatting
|
181 |
+
if y_hat.shape[1] > 1:
|
182 |
+
y_hat = model_G.pqmf_synthesis(y_hat)
|
183 |
+
|
184 |
+
# run D with or without cond. features
|
185 |
+
if len(signature(model_D.forward).parameters) == 2:
|
186 |
+
D_out_fake = model_D(y_hat.detach(), c_D)
|
187 |
+
D_out_real = model_D(y_D, c_D)
|
188 |
+
else:
|
189 |
+
D_out_fake = model_D(y_hat.detach())
|
190 |
+
D_out_real = model_D(y_D)
|
191 |
+
|
192 |
+
# format D outputs
|
193 |
+
if isinstance(D_out_fake, tuple):
|
194 |
+
scores_fake, feats_fake = D_out_fake
|
195 |
+
if D_out_real is None:
|
196 |
+
scores_real, feats_real = None, None
|
197 |
+
else:
|
198 |
+
scores_real, feats_real = D_out_real
|
199 |
+
else:
|
200 |
+
scores_fake = D_out_fake
|
201 |
+
scores_real = D_out_real
|
202 |
+
|
203 |
+
# compute losses
|
204 |
+
loss_D_dict = criterion_D(scores_fake, scores_real)
|
205 |
+
loss_D = loss_D_dict['D_loss']
|
206 |
+
|
207 |
+
# optimizer discriminator
|
208 |
+
optimizer_D.zero_grad()
|
209 |
+
loss_D.backward()
|
210 |
+
if c.disc_clip_grad > 0:
|
211 |
+
torch.nn.utils.clip_grad_norm_(model_D.parameters(),
|
212 |
+
c.disc_clip_grad)
|
213 |
+
optimizer_D.step()
|
214 |
+
if scheduler_D is not None:
|
215 |
+
scheduler_D.step()
|
216 |
+
|
217 |
+
for key, value in loss_D_dict.items():
|
218 |
+
if isinstance(value, (int, float)):
|
219 |
+
loss_dict[key] = value
|
220 |
+
else:
|
221 |
+
loss_dict[key] = value.item()
|
222 |
+
|
223 |
+
step_time = time.time() - start_time
|
224 |
+
epoch_time += step_time
|
225 |
+
|
226 |
+
# get current learning rates
|
227 |
+
current_lr_G = list(optimizer_G.param_groups)[0]['lr']
|
228 |
+
current_lr_D = list(optimizer_D.param_groups)[0]['lr']
|
229 |
+
|
230 |
+
# update avg stats
|
231 |
+
update_train_values = dict()
|
232 |
+
for key, value in loss_dict.items():
|
233 |
+
update_train_values['avg_' + key] = value
|
234 |
+
update_train_values['avg_loader_time'] = loader_time
|
235 |
+
update_train_values['avg_step_time'] = step_time
|
236 |
+
keep_avg.update_values(update_train_values)
|
237 |
+
|
238 |
+
# print training stats
|
239 |
+
if global_step % c.print_step == 0:
|
240 |
+
log_dict = {
|
241 |
+
'step_time': [step_time, 2],
|
242 |
+
'loader_time': [loader_time, 4],
|
243 |
+
"current_lr_G": current_lr_G,
|
244 |
+
"current_lr_D": current_lr_D
|
245 |
+
}
|
246 |
+
c_logger.print_train_step(batch_n_iter, num_iter, global_step,
|
247 |
+
log_dict, loss_dict, keep_avg.avg_values)
|
248 |
+
|
249 |
+
if args.rank == 0:
|
250 |
+
# plot step stats
|
251 |
+
if global_step % 10 == 0:
|
252 |
+
iter_stats = {
|
253 |
+
"lr_G": current_lr_G,
|
254 |
+
"lr_D": current_lr_D,
|
255 |
+
"step_time": step_time
|
256 |
+
}
|
257 |
+
iter_stats.update(loss_dict)
|
258 |
+
tb_logger.tb_train_iter_stats(global_step, iter_stats)
|
259 |
+
|
260 |
+
# save checkpoint
|
261 |
+
if global_step % c.save_step == 0:
|
262 |
+
if c.checkpoint:
|
263 |
+
# save model
|
264 |
+
save_checkpoint(model_G,
|
265 |
+
optimizer_G,
|
266 |
+
scheduler_G,
|
267 |
+
model_D,
|
268 |
+
optimizer_D,
|
269 |
+
scheduler_D,
|
270 |
+
global_step,
|
271 |
+
epoch,
|
272 |
+
OUT_PATH,
|
273 |
+
model_losses=loss_dict)
|
274 |
+
|
275 |
+
# compute spectrograms
|
276 |
+
figures = plot_results(y_hat_vis, y_G, ap, global_step,
|
277 |
+
'train')
|
278 |
+
tb_logger.tb_train_figures(global_step, figures)
|
279 |
+
|
280 |
+
# Sample audio
|
281 |
+
sample_voice = y_hat_vis[0].squeeze(0).detach().cpu().numpy()
|
282 |
+
tb_logger.tb_train_audios(global_step,
|
283 |
+
{'train/audio': sample_voice},
|
284 |
+
c.audio["sample_rate"])
|
285 |
+
end_time = time.time()
|
286 |
+
|
287 |
+
# print epoch stats
|
288 |
+
c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)
|
289 |
+
|
290 |
+
# Plot Training Epoch Stats
|
291 |
+
epoch_stats = {"epoch_time": epoch_time}
|
292 |
+
epoch_stats.update(keep_avg.avg_values)
|
293 |
+
if args.rank == 0:
|
294 |
+
tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
|
295 |
+
# TODO: plot model stats
|
296 |
+
# if c.tb_model_param_stats:
|
297 |
+
# tb_logger.tb_model_weights(model, global_step)
|
298 |
+
return keep_avg.avg_values, global_step
|
299 |
+
|
300 |
+
|
301 |
+
@torch.no_grad()
|
302 |
+
def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch):
|
303 |
+
data_loader = setup_loader(ap, is_val=True, verbose=(epoch == 0))
|
304 |
+
model_G.eval()
|
305 |
+
model_D.eval()
|
306 |
+
epoch_time = 0
|
307 |
+
keep_avg = KeepAverage()
|
308 |
+
end_time = time.time()
|
309 |
+
c_logger.print_eval_start()
|
310 |
+
for num_iter, data in enumerate(data_loader):
|
311 |
+
start_time = time.time()
|
312 |
+
|
313 |
+
# format data
|
314 |
+
c_G, y_G, _, _ = format_data(data)
|
315 |
+
loader_time = time.time() - end_time
|
316 |
+
|
317 |
+
global_step += 1
|
318 |
+
|
319 |
+
##############################
|
320 |
+
# GENERATOR
|
321 |
+
##############################
|
322 |
+
|
323 |
+
# generator pass
|
324 |
+
y_hat = model_G(c_G)
|
325 |
+
y_hat_sub = None
|
326 |
+
y_G_sub = None
|
327 |
+
|
328 |
+
# PQMF formatting
|
329 |
+
if y_hat.shape[1] > 1:
|
330 |
+
y_hat_sub = y_hat
|
331 |
+
y_hat = model_G.pqmf_synthesis(y_hat)
|
332 |
+
y_G_sub = model_G.pqmf_analysis(y_G)
|
333 |
+
|
334 |
+
scores_fake, feats_fake, feats_real = None, None, None
|
335 |
+
if global_step > c.steps_to_start_discriminator:
|
336 |
+
|
337 |
+
if len(signature(model_D.forward).parameters) == 2:
|
338 |
+
D_out_fake = model_D(y_hat, c_G)
|
339 |
+
else:
|
340 |
+
D_out_fake = model_D(y_hat)
|
341 |
+
D_out_real = None
|
342 |
+
|
343 |
+
if c.use_feat_match_loss:
|
344 |
+
with torch.no_grad():
|
345 |
+
D_out_real = model_D(y_G)
|
346 |
+
|
347 |
+
# format D outputs
|
348 |
+
if isinstance(D_out_fake, tuple):
|
349 |
+
scores_fake, feats_fake = D_out_fake
|
350 |
+
if D_out_real is None:
|
351 |
+
feats_real = None
|
352 |
+
else:
|
353 |
+
_, feats_real = D_out_real
|
354 |
+
else:
|
355 |
+
scores_fake = D_out_fake
|
356 |
+
feats_fake, feats_real = None, None
|
357 |
+
|
358 |
+
# compute losses
|
359 |
+
loss_G_dict = criterion_G(y_hat, y_G, scores_fake, feats_fake,
|
360 |
+
feats_real, y_hat_sub, y_G_sub)
|
361 |
+
|
362 |
+
loss_dict = dict()
|
363 |
+
for key, value in loss_G_dict.items():
|
364 |
+
if isinstance(value, (int, float)):
|
365 |
+
loss_dict[key] = value
|
366 |
+
else:
|
367 |
+
loss_dict[key] = value.item()
|
368 |
+
|
369 |
+
##############################
|
370 |
+
# DISCRIMINATOR
|
371 |
+
##############################
|
372 |
+
|
373 |
+
if global_step >= c.steps_to_start_discriminator:
|
374 |
+
# discriminator pass
|
375 |
+
with torch.no_grad():
|
376 |
+
y_hat = model_G(c_G)
|
377 |
+
|
378 |
+
# PQMF formatting
|
379 |
+
if y_hat.shape[1] > 1:
|
380 |
+
y_hat = model_G.pqmf_synthesis(y_hat)
|
381 |
+
|
382 |
+
# run D with or without cond. features
|
383 |
+
if len(signature(model_D.forward).parameters) == 2:
|
384 |
+
D_out_fake = model_D(y_hat.detach(), c_G)
|
385 |
+
D_out_real = model_D(y_G, c_G)
|
386 |
+
else:
|
387 |
+
D_out_fake = model_D(y_hat.detach())
|
388 |
+
D_out_real = model_D(y_G)
|
389 |
+
|
390 |
+
# format D outputs
|
391 |
+
if isinstance(D_out_fake, tuple):
|
392 |
+
scores_fake, feats_fake = D_out_fake
|
393 |
+
if D_out_real is None:
|
394 |
+
scores_real, feats_real = None, None
|
395 |
+
else:
|
396 |
+
scores_real, feats_real = D_out_real
|
397 |
+
else:
|
398 |
+
scores_fake = D_out_fake
|
399 |
+
scores_real = D_out_real
|
400 |
+
|
401 |
+
# compute losses
|
402 |
+
loss_D_dict = criterion_D(scores_fake, scores_real)
|
403 |
+
|
404 |
+
for key, value in loss_D_dict.items():
|
405 |
+
if isinstance(value, (int, float)):
|
406 |
+
loss_dict[key] = value
|
407 |
+
else:
|
408 |
+
loss_dict[key] = value.item()
|
409 |
+
|
410 |
+
step_time = time.time() - start_time
|
411 |
+
epoch_time += step_time
|
412 |
+
|
413 |
+
# update avg stats
|
414 |
+
update_eval_values = dict()
|
415 |
+
for key, value in loss_dict.items():
|
416 |
+
update_eval_values['avg_' + key] = value
|
417 |
+
update_eval_values['avg_loader_time'] = loader_time
|
418 |
+
update_eval_values['avg_step_time'] = step_time
|
419 |
+
keep_avg.update_values(update_eval_values)
|
420 |
+
|
421 |
+
# print eval stats
|
422 |
+
if c.print_eval:
|
423 |
+
c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)
|
424 |
+
|
425 |
+
if args.rank == 0:
|
426 |
+
# compute spectrograms
|
427 |
+
figures = plot_results(y_hat, y_G, ap, global_step, 'eval')
|
428 |
+
tb_logger.tb_eval_figures(global_step, figures)
|
429 |
+
|
430 |
+
# Sample audio
|
431 |
+
sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
|
432 |
+
tb_logger.tb_eval_audios(global_step, {'eval/audio': sample_voice},
|
433 |
+
c.audio["sample_rate"])
|
434 |
+
|
435 |
+
tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
|
436 |
+
|
437 |
+
# synthesize a full voice
|
438 |
+
data_loader.return_segments = False
|
439 |
+
|
440 |
+
return keep_avg.avg_values
|
441 |
+
|
442 |
+
|
443 |
+
# FIXME: move args definition/parsing inside of main?
|
444 |
+
def main(args): # pylint: disable=redefined-outer-name
|
445 |
+
# pylint: disable=global-variable-undefined
|
446 |
+
global train_data, eval_data
|
447 |
+
print(f" > Loading wavs from: {c.data_path}")
|
448 |
+
if c.feature_path is not None:
|
449 |
+
print(f" > Loading features from: {c.feature_path}")
|
450 |
+
eval_data, train_data = load_wav_feat_data(
|
451 |
+
c.data_path, c.feature_path, c.eval_split_size)
|
452 |
+
else:
|
453 |
+
eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size)
|
454 |
+
|
455 |
+
# setup audio processor
|
456 |
+
ap = AudioProcessor(**c.audio)
|
457 |
+
|
458 |
+
# DISTRUBUTED
|
459 |
+
if num_gpus > 1:
|
460 |
+
init_distributed(args.rank, num_gpus, args.group_id,
|
461 |
+
c.distributed["backend"], c.distributed["url"])
|
462 |
+
|
463 |
+
# setup models
|
464 |
+
model_gen = setup_generator(c)
|
465 |
+
model_disc = setup_discriminator(c)
|
466 |
+
|
467 |
+
# setup optimizers
|
468 |
+
optimizer_gen = RAdam(model_gen.parameters(), lr=c.lr_gen, weight_decay=0)
|
469 |
+
optimizer_disc = RAdam(model_disc.parameters(),
|
470 |
+
lr=c.lr_disc,
|
471 |
+
weight_decay=0)
|
472 |
+
|
473 |
+
# schedulers
|
474 |
+
scheduler_gen = None
|
475 |
+
scheduler_disc = None
|
476 |
+
if 'lr_scheduler_gen' in c:
|
477 |
+
scheduler_gen = getattr(torch.optim.lr_scheduler, c.lr_scheduler_gen)
|
478 |
+
scheduler_gen = scheduler_gen(
|
479 |
+
optimizer_gen, **c.lr_scheduler_gen_params)
|
480 |
+
if 'lr_scheduler_disc' in c:
|
481 |
+
scheduler_disc = getattr(torch.optim.lr_scheduler, c.lr_scheduler_disc)
|
482 |
+
scheduler_disc = scheduler_disc(
|
483 |
+
optimizer_disc, **c.lr_scheduler_disc_params)
|
484 |
+
|
485 |
+
# setup criterion
|
486 |
+
criterion_gen = GeneratorLoss(c)
|
487 |
+
criterion_disc = DiscriminatorLoss(c)
|
488 |
+
|
489 |
+
if args.restore_path:
|
490 |
+
checkpoint = torch.load(args.restore_path, map_location='cpu')
|
491 |
+
try:
|
492 |
+
print(" > Restoring Generator Model...")
|
493 |
+
model_gen.load_state_dict(checkpoint['model'])
|
494 |
+
print(" > Restoring Generator Optimizer...")
|
495 |
+
optimizer_gen.load_state_dict(checkpoint['optimizer'])
|
496 |
+
print(" > Restoring Discriminator Model...")
|
497 |
+
model_disc.load_state_dict(checkpoint['model_disc'])
|
498 |
+
print(" > Restoring Discriminator Optimizer...")
|
499 |
+
optimizer_disc.load_state_dict(checkpoint['optimizer_disc'])
|
500 |
+
if 'scheduler' in checkpoint:
|
501 |
+
print(" > Restoring Generator LR Scheduler...")
|
502 |
+
scheduler_gen.load_state_dict(checkpoint['scheduler'])
|
503 |
+
# NOTE: Not sure if necessary
|
504 |
+
scheduler_gen.optimizer = optimizer_gen
|
505 |
+
if 'scheduler_disc' in checkpoint:
|
506 |
+
print(" > Restoring Discriminator LR Scheduler...")
|
507 |
+
scheduler_disc.load_state_dict(checkpoint['scheduler_disc'])
|
508 |
+
scheduler_disc.optimizer = optimizer_disc
|
509 |
+
except RuntimeError:
|
510 |
+
# retore only matching layers.
|
511 |
+
print(" > Partial model initialization...")
|
512 |
+
model_dict = model_gen.state_dict()
|
513 |
+
model_dict = set_init_dict(model_dict, checkpoint['model'], c)
|
514 |
+
model_gen.load_state_dict(model_dict)
|
515 |
+
|
516 |
+
model_dict = model_disc.state_dict()
|
517 |
+
model_dict = set_init_dict(model_dict, checkpoint['model_disc'], c)
|
518 |
+
model_disc.load_state_dict(model_dict)
|
519 |
+
del model_dict
|
520 |
+
|
521 |
+
# reset lr if not countinuining training.
|
522 |
+
for group in optimizer_gen.param_groups:
|
523 |
+
group['lr'] = c.lr_gen
|
524 |
+
|
525 |
+
for group in optimizer_disc.param_groups:
|
526 |
+
group['lr'] = c.lr_disc
|
527 |
+
|
528 |
+
print(" > Model restored from step %d" % checkpoint['step'],
|
529 |
+
flush=True)
|
530 |
+
args.restore_step = checkpoint['step']
|
531 |
+
else:
|
532 |
+
args.restore_step = 0
|
533 |
+
|
534 |
+
if use_cuda:
|
535 |
+
model_gen.cuda()
|
536 |
+
criterion_gen.cuda()
|
537 |
+
model_disc.cuda()
|
538 |
+
criterion_disc.cuda()
|
539 |
+
|
540 |
+
# DISTRUBUTED
|
541 |
+
if num_gpus > 1:
|
542 |
+
model_gen = DDP_th(model_gen, device_ids=[args.rank])
|
543 |
+
model_disc = DDP_th(model_disc, device_ids=[args.rank])
|
544 |
+
|
545 |
+
num_params = count_parameters(model_gen)
|
546 |
+
print(" > Generator has {} parameters".format(num_params), flush=True)
|
547 |
+
num_params = count_parameters(model_disc)
|
548 |
+
print(" > Discriminator has {} parameters".format(num_params), flush=True)
|
549 |
+
|
550 |
+
if 'best_loss' not in locals():
|
551 |
+
best_loss = float('inf')
|
552 |
+
|
553 |
+
global_step = args.restore_step
|
554 |
+
for epoch in range(0, c.epochs):
|
555 |
+
c_logger.print_epoch_start(epoch, c.epochs)
|
556 |
+
_, global_step = train(model_gen, criterion_gen, optimizer_gen,
|
557 |
+
model_disc, criterion_disc, optimizer_disc,
|
558 |
+
scheduler_gen, scheduler_disc, ap, global_step,
|
559 |
+
epoch)
|
560 |
+
eval_avg_loss_dict = evaluate(model_gen, criterion_gen, model_disc, criterion_disc, ap,
|
561 |
+
global_step, epoch)
|
562 |
+
c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
|
563 |
+
target_loss = eval_avg_loss_dict[c.target_loss]
|
564 |
+
best_loss = save_best_model(target_loss,
|
565 |
+
best_loss,
|
566 |
+
model_gen,
|
567 |
+
optimizer_gen,
|
568 |
+
scheduler_gen,
|
569 |
+
model_disc,
|
570 |
+
optimizer_disc,
|
571 |
+
scheduler_disc,
|
572 |
+
global_step,
|
573 |
+
epoch,
|
574 |
+
OUT_PATH,
|
575 |
+
model_losses=eval_avg_loss_dict)
|
576 |
+
|
577 |
+
|
578 |
+
if __name__ == '__main__':
|
579 |
+
parser = argparse.ArgumentParser()
|
580 |
+
parser.add_argument(
|
581 |
+
'--continue_path',
|
582 |
+
type=str,
|
583 |
+
help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.',
|
584 |
+
default='',
|
585 |
+
required='--config_path' not in sys.argv)
|
586 |
+
parser.add_argument(
|
587 |
+
'--restore_path',
|
588 |
+
type=str,
|
589 |
+
help='Model file to be restored. Use to finetune a model.',
|
590 |
+
default='')
|
591 |
+
parser.add_argument('--config_path',
|
592 |
+
type=str,
|
593 |
+
help='Path to config file for training.',
|
594 |
+
required='--continue_path' not in sys.argv)
|
595 |
+
parser.add_argument('--debug',
|
596 |
+
type=bool,
|
597 |
+
default=False,
|
598 |
+
help='Do not verify commit integrity to run training.')
|
599 |
+
|
600 |
+
# DISTRUBUTED
|
601 |
+
parser.add_argument(
|
602 |
+
'--rank',
|
603 |
+
type=int,
|
604 |
+
default=0,
|
605 |
+
help='DISTRIBUTED: process rank for distributed training.')
|
606 |
+
parser.add_argument('--group_id',
|
607 |
+
type=str,
|
608 |
+
default="",
|
609 |
+
help='DISTRIBUTED: process group id.')
|
610 |
+
args = parser.parse_args()
|
611 |
+
|
612 |
+
if args.continue_path != '':
|
613 |
+
args.output_path = args.continue_path
|
614 |
+
args.config_path = os.path.join(args.continue_path, 'config.json')
|
615 |
+
list_of_files = glob.glob(
|
616 |
+
args.continue_path +
|
617 |
+
"/*.pth.tar") # * means all if need specific format then *.csv
|
618 |
+
latest_model_file = max(list_of_files, key=os.path.getctime)
|
619 |
+
args.restore_path = latest_model_file
|
620 |
+
print(f" > Training continues for {args.restore_path}")
|
621 |
+
|
622 |
+
# setup output paths and read configs
|
623 |
+
c = load_config(args.config_path)
|
624 |
+
# check_config(c)
|
625 |
+
_ = os.path.dirname(os.path.realpath(__file__))
|
626 |
+
|
627 |
+
OUT_PATH = args.continue_path
|
628 |
+
if args.continue_path == '':
|
629 |
+
OUT_PATH = create_experiment_folder(c.output_path, c.run_name,
|
630 |
+
args.debug)
|
631 |
+
|
632 |
+
AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios')
|
633 |
+
|
634 |
+
c_logger = ConsoleLogger()
|
635 |
+
|
636 |
+
if args.rank == 0:
|
637 |
+
os.makedirs(AUDIO_PATH, exist_ok=True)
|
638 |
+
new_fields = {}
|
639 |
+
if args.restore_path:
|
640 |
+
new_fields["restore_path"] = args.restore_path
|
641 |
+
new_fields["github_branch"] = get_git_branch()
|
642 |
+
copy_model_files(c, args.config_path,
|
643 |
+
OUT_PATH, new_fields)
|
644 |
+
os.chmod(AUDIO_PATH, 0o775)
|
645 |
+
os.chmod(OUT_PATH, 0o775)
|
646 |
+
|
647 |
+
LOG_DIR = OUT_PATH
|
648 |
+
tb_logger = TensorboardLogger(LOG_DIR, model_name='VOCODER')
|
649 |
+
|
650 |
+
# write model desc to tensorboard
|
651 |
+
tb_logger.tb_add_text('model-description', c['run_description'], 0)
|
652 |
+
|
653 |
+
try:
|
654 |
+
main(args)
|
655 |
+
except KeyboardInterrupt:
|
656 |
+
remove_experiment_folder(OUT_PATH)
|
657 |
+
try:
|
658 |
+
sys.exit(0)
|
659 |
+
except SystemExit:
|
660 |
+
os._exit(0) # pylint: disable=protected-access
|
661 |
+
except Exception: # pylint: disable=broad-except
|
662 |
+
remove_experiment_folder(OUT_PATH)
|
663 |
+
traceback.print_exc()
|
664 |
+
sys.exit(1)
|
TTS/bin/train_vocoder_wavegrad.py
ADDED
@@ -0,0 +1,511 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import glob
|
3 |
+
import os
|
4 |
+
import sys
|
5 |
+
import time
|
6 |
+
import traceback
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
import torch
|
10 |
+
# DISTRIBUTED
|
11 |
+
from torch.nn.parallel import DistributedDataParallel as DDP_th
|
12 |
+
from torch.optim import Adam
|
13 |
+
from torch.utils.data import DataLoader
|
14 |
+
from torch.utils.data.distributed import DistributedSampler
|
15 |
+
from TTS.utils.audio import AudioProcessor
|
16 |
+
from TTS.utils.console_logger import ConsoleLogger
|
17 |
+
from TTS.utils.distribute import init_distributed
|
18 |
+
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
|
19 |
+
create_experiment_folder, get_git_branch,
|
20 |
+
remove_experiment_folder, set_init_dict)
|
21 |
+
from TTS.utils.io import copy_model_files, load_config
|
22 |
+
from TTS.utils.tensorboard_logger import TensorboardLogger
|
23 |
+
from TTS.utils.training import setup_torch_training_env
|
24 |
+
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
|
25 |
+
from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
|
26 |
+
from TTS.vocoder.utils.generic_utils import plot_results, setup_generator
|
27 |
+
from TTS.vocoder.utils.io import save_best_model, save_checkpoint
|
28 |
+
|
29 |
+
use_cuda, num_gpus = setup_torch_training_env(True, True)
|
30 |
+
|
31 |
+
|
32 |
+
def setup_loader(ap, is_val=False, verbose=False):
|
33 |
+
if is_val and not c.run_eval:
|
34 |
+
loader = None
|
35 |
+
else:
|
36 |
+
dataset = WaveGradDataset(ap=ap,
|
37 |
+
items=eval_data if is_val else train_data,
|
38 |
+
seq_len=c.seq_len,
|
39 |
+
hop_len=ap.hop_length,
|
40 |
+
pad_short=c.pad_short,
|
41 |
+
conv_pad=c.conv_pad,
|
42 |
+
is_training=not is_val,
|
43 |
+
return_segments=True,
|
44 |
+
use_noise_augment=False,
|
45 |
+
use_cache=c.use_cache,
|
46 |
+
verbose=verbose)
|
47 |
+
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
48 |
+
loader = DataLoader(dataset,
|
49 |
+
batch_size=c.batch_size,
|
50 |
+
shuffle=num_gpus <= 1,
|
51 |
+
drop_last=False,
|
52 |
+
sampler=sampler,
|
53 |
+
num_workers=c.num_val_loader_workers
|
54 |
+
if is_val else c.num_loader_workers,
|
55 |
+
pin_memory=False)
|
56 |
+
|
57 |
+
|
58 |
+
return loader
|
59 |
+
|
60 |
+
|
61 |
+
def format_data(data):
|
62 |
+
# return a whole audio segment
|
63 |
+
m, x = data
|
64 |
+
x = x.unsqueeze(1)
|
65 |
+
if use_cuda:
|
66 |
+
m = m.cuda(non_blocking=True)
|
67 |
+
x = x.cuda(non_blocking=True)
|
68 |
+
return m, x
|
69 |
+
|
70 |
+
|
71 |
+
def format_test_data(data):
|
72 |
+
# return a whole audio segment
|
73 |
+
m, x = data
|
74 |
+
m = m[None, ...]
|
75 |
+
x = x[None, None, ...]
|
76 |
+
if use_cuda:
|
77 |
+
m = m.cuda(non_blocking=True)
|
78 |
+
x = x.cuda(non_blocking=True)
|
79 |
+
return m, x
|
80 |
+
|
81 |
+
|
82 |
+
def train(model, criterion, optimizer,
|
83 |
+
scheduler, scaler, ap, global_step, epoch):
|
84 |
+
data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
|
85 |
+
model.train()
|
86 |
+
epoch_time = 0
|
87 |
+
keep_avg = KeepAverage()
|
88 |
+
if use_cuda:
|
89 |
+
batch_n_iter = int(
|
90 |
+
len(data_loader.dataset) / (c.batch_size * num_gpus))
|
91 |
+
else:
|
92 |
+
batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
|
93 |
+
end_time = time.time()
|
94 |
+
c_logger.print_train_start()
|
95 |
+
# setup noise schedule
|
96 |
+
noise_schedule = c['train_noise_schedule']
|
97 |
+
betas = np.linspace(noise_schedule['min_val'], noise_schedule['max_val'], noise_schedule['num_steps'])
|
98 |
+
if hasattr(model, 'module'):
|
99 |
+
model.module.compute_noise_level(betas)
|
100 |
+
else:
|
101 |
+
model.compute_noise_level(betas)
|
102 |
+
for num_iter, data in enumerate(data_loader):
|
103 |
+
start_time = time.time()
|
104 |
+
|
105 |
+
# format data
|
106 |
+
m, x = format_data(data)
|
107 |
+
loader_time = time.time() - end_time
|
108 |
+
|
109 |
+
global_step += 1
|
110 |
+
|
111 |
+
with torch.cuda.amp.autocast(enabled=c.mixed_precision):
|
112 |
+
# compute noisy input
|
113 |
+
if hasattr(model, 'module'):
|
114 |
+
noise, x_noisy, noise_scale = model.module.compute_y_n(x)
|
115 |
+
else:
|
116 |
+
noise, x_noisy, noise_scale = model.compute_y_n(x)
|
117 |
+
|
118 |
+
# forward pass
|
119 |
+
noise_hat = model(x_noisy, m, noise_scale)
|
120 |
+
|
121 |
+
# compute losses
|
122 |
+
loss = criterion(noise, noise_hat)
|
123 |
+
loss_wavegrad_dict = {'wavegrad_loss':loss}
|
124 |
+
|
125 |
+
# check nan loss
|
126 |
+
if torch.isnan(loss).any():
|
127 |
+
raise RuntimeError(f'Detected NaN loss at step {global_step}.')
|
128 |
+
|
129 |
+
optimizer.zero_grad()
|
130 |
+
|
131 |
+
# backward pass with loss scaling
|
132 |
+
if c.mixed_precision:
|
133 |
+
scaler.scale(loss).backward()
|
134 |
+
scaler.unscale_(optimizer)
|
135 |
+
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
|
136 |
+
c.clip_grad)
|
137 |
+
scaler.step(optimizer)
|
138 |
+
scaler.update()
|
139 |
+
else:
|
140 |
+
loss.backward()
|
141 |
+
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
|
142 |
+
c.clip_grad)
|
143 |
+
optimizer.step()
|
144 |
+
|
145 |
+
# schedule update
|
146 |
+
if scheduler is not None:
|
147 |
+
scheduler.step()
|
148 |
+
|
149 |
+
# disconnect loss values
|
150 |
+
loss_dict = dict()
|
151 |
+
for key, value in loss_wavegrad_dict.items():
|
152 |
+
if isinstance(value, int):
|
153 |
+
loss_dict[key] = value
|
154 |
+
else:
|
155 |
+
loss_dict[key] = value.item()
|
156 |
+
|
157 |
+
# epoch/step timing
|
158 |
+
step_time = time.time() - start_time
|
159 |
+
epoch_time += step_time
|
160 |
+
|
161 |
+
# get current learning rates
|
162 |
+
current_lr = list(optimizer.param_groups)[0]['lr']
|
163 |
+
|
164 |
+
# update avg stats
|
165 |
+
update_train_values = dict()
|
166 |
+
for key, value in loss_dict.items():
|
167 |
+
update_train_values['avg_' + key] = value
|
168 |
+
update_train_values['avg_loader_time'] = loader_time
|
169 |
+
update_train_values['avg_step_time'] = step_time
|
170 |
+
keep_avg.update_values(update_train_values)
|
171 |
+
|
172 |
+
# print training stats
|
173 |
+
if global_step % c.print_step == 0:
|
174 |
+
log_dict = {
|
175 |
+
'step_time': [step_time, 2],
|
176 |
+
'loader_time': [loader_time, 4],
|
177 |
+
"current_lr": current_lr,
|
178 |
+
"grad_norm": grad_norm.item()
|
179 |
+
}
|
180 |
+
c_logger.print_train_step(batch_n_iter, num_iter, global_step,
|
181 |
+
log_dict, loss_dict, keep_avg.avg_values)
|
182 |
+
|
183 |
+
if args.rank == 0:
|
184 |
+
# plot step stats
|
185 |
+
if global_step % 10 == 0:
|
186 |
+
iter_stats = {
|
187 |
+
"lr": current_lr,
|
188 |
+
"grad_norm": grad_norm.item(),
|
189 |
+
"step_time": step_time
|
190 |
+
}
|
191 |
+
iter_stats.update(loss_dict)
|
192 |
+
tb_logger.tb_train_iter_stats(global_step, iter_stats)
|
193 |
+
|
194 |
+
# save checkpoint
|
195 |
+
if global_step % c.save_step == 0:
|
196 |
+
if c.checkpoint:
|
197 |
+
# save model
|
198 |
+
save_checkpoint(model,
|
199 |
+
optimizer,
|
200 |
+
scheduler,
|
201 |
+
None,
|
202 |
+
None,
|
203 |
+
None,
|
204 |
+
global_step,
|
205 |
+
epoch,
|
206 |
+
OUT_PATH,
|
207 |
+
model_losses=loss_dict,
|
208 |
+
scaler=scaler.state_dict() if c.mixed_precision else None)
|
209 |
+
|
210 |
+
end_time = time.time()
|
211 |
+
|
212 |
+
# print epoch stats
|
213 |
+
c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)
|
214 |
+
|
215 |
+
# Plot Training Epoch Stats
|
216 |
+
epoch_stats = {"epoch_time": epoch_time}
|
217 |
+
epoch_stats.update(keep_avg.avg_values)
|
218 |
+
if args.rank == 0:
|
219 |
+
tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
|
220 |
+
# TODO: plot model stats
|
221 |
+
if c.tb_model_param_stats and args.rank == 0:
|
222 |
+
tb_logger.tb_model_weights(model, global_step)
|
223 |
+
return keep_avg.avg_values, global_step
|
224 |
+
|
225 |
+
|
226 |
+
@torch.no_grad()
|
227 |
+
def evaluate(model, criterion, ap, global_step, epoch):
|
228 |
+
data_loader = setup_loader(ap, is_val=True, verbose=(epoch == 0))
|
229 |
+
model.eval()
|
230 |
+
epoch_time = 0
|
231 |
+
keep_avg = KeepAverage()
|
232 |
+
end_time = time.time()
|
233 |
+
c_logger.print_eval_start()
|
234 |
+
for num_iter, data in enumerate(data_loader):
|
235 |
+
start_time = time.time()
|
236 |
+
|
237 |
+
# format data
|
238 |
+
m, x = format_data(data)
|
239 |
+
loader_time = time.time() - end_time
|
240 |
+
|
241 |
+
global_step += 1
|
242 |
+
|
243 |
+
# compute noisy input
|
244 |
+
if hasattr(model, 'module'):
|
245 |
+
noise, x_noisy, noise_scale = model.module.compute_y_n(x)
|
246 |
+
else:
|
247 |
+
noise, x_noisy, noise_scale = model.compute_y_n(x)
|
248 |
+
|
249 |
+
|
250 |
+
# forward pass
|
251 |
+
noise_hat = model(x_noisy, m, noise_scale)
|
252 |
+
|
253 |
+
# compute losses
|
254 |
+
loss = criterion(noise, noise_hat)
|
255 |
+
loss_wavegrad_dict = {'wavegrad_loss':loss}
|
256 |
+
|
257 |
+
|
258 |
+
loss_dict = dict()
|
259 |
+
for key, value in loss_wavegrad_dict.items():
|
260 |
+
if isinstance(value, (int, float)):
|
261 |
+
loss_dict[key] = value
|
262 |
+
else:
|
263 |
+
loss_dict[key] = value.item()
|
264 |
+
|
265 |
+
step_time = time.time() - start_time
|
266 |
+
epoch_time += step_time
|
267 |
+
|
268 |
+
# update avg stats
|
269 |
+
update_eval_values = dict()
|
270 |
+
for key, value in loss_dict.items():
|
271 |
+
update_eval_values['avg_' + key] = value
|
272 |
+
update_eval_values['avg_loader_time'] = loader_time
|
273 |
+
update_eval_values['avg_step_time'] = step_time
|
274 |
+
keep_avg.update_values(update_eval_values)
|
275 |
+
|
276 |
+
# print eval stats
|
277 |
+
if c.print_eval:
|
278 |
+
c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)
|
279 |
+
|
280 |
+
if args.rank == 0:
|
281 |
+
data_loader.dataset.return_segments = False
|
282 |
+
samples = data_loader.dataset.load_test_samples(1)
|
283 |
+
m, x = format_test_data(samples[0])
|
284 |
+
|
285 |
+
# setup noise schedule and inference
|
286 |
+
noise_schedule = c['test_noise_schedule']
|
287 |
+
betas = np.linspace(noise_schedule['min_val'], noise_schedule['max_val'], noise_schedule['num_steps'])
|
288 |
+
if hasattr(model, 'module'):
|
289 |
+
model.module.compute_noise_level(betas)
|
290 |
+
# compute voice
|
291 |
+
x_pred = model.module.inference(m)
|
292 |
+
else:
|
293 |
+
model.compute_noise_level(betas)
|
294 |
+
# compute voice
|
295 |
+
x_pred = model.inference(m)
|
296 |
+
|
297 |
+
# compute spectrograms
|
298 |
+
figures = plot_results(x_pred, x, ap, global_step, 'eval')
|
299 |
+
tb_logger.tb_eval_figures(global_step, figures)
|
300 |
+
|
301 |
+
# Sample audio
|
302 |
+
sample_voice = x_pred[0].squeeze(0).detach().cpu().numpy()
|
303 |
+
tb_logger.tb_eval_audios(global_step, {'eval/audio': sample_voice},
|
304 |
+
c.audio["sample_rate"])
|
305 |
+
|
306 |
+
tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
|
307 |
+
data_loader.dataset.return_segments = True
|
308 |
+
|
309 |
+
return keep_avg.avg_values
|
310 |
+
|
311 |
+
|
312 |
+
def main(args): # pylint: disable=redefined-outer-name
|
313 |
+
# pylint: disable=global-variable-undefined
|
314 |
+
global train_data, eval_data
|
315 |
+
print(f" > Loading wavs from: {c.data_path}")
|
316 |
+
if c.feature_path is not None:
|
317 |
+
print(f" > Loading features from: {c.feature_path}")
|
318 |
+
eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path, c.eval_split_size)
|
319 |
+
else:
|
320 |
+
eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size)
|
321 |
+
|
322 |
+
# setup audio processor
|
323 |
+
ap = AudioProcessor(**c.audio)
|
324 |
+
|
325 |
+
# DISTRUBUTED
|
326 |
+
if num_gpus > 1:
|
327 |
+
init_distributed(args.rank, num_gpus, args.group_id,
|
328 |
+
c.distributed["backend"], c.distributed["url"])
|
329 |
+
|
330 |
+
# setup models
|
331 |
+
model = setup_generator(c)
|
332 |
+
|
333 |
+
# scaler for mixed_precision
|
334 |
+
scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None
|
335 |
+
|
336 |
+
# setup optimizers
|
337 |
+
optimizer = Adam(model.parameters(), lr=c.lr, weight_decay=0)
|
338 |
+
|
339 |
+
# schedulers
|
340 |
+
scheduler = None
|
341 |
+
if 'lr_scheduler' in c:
|
342 |
+
scheduler = getattr(torch.optim.lr_scheduler, c.lr_scheduler)
|
343 |
+
scheduler = scheduler(optimizer, **c.lr_scheduler_params)
|
344 |
+
|
345 |
+
# setup criterion
|
346 |
+
criterion = torch.nn.L1Loss().cuda()
|
347 |
+
|
348 |
+
if args.restore_path:
|
349 |
+
checkpoint = torch.load(args.restore_path, map_location='cpu')
|
350 |
+
try:
|
351 |
+
print(" > Restoring Model...")
|
352 |
+
model.load_state_dict(checkpoint['model'])
|
353 |
+
print(" > Restoring Optimizer...")
|
354 |
+
optimizer.load_state_dict(checkpoint['optimizer'])
|
355 |
+
if 'scheduler' in checkpoint:
|
356 |
+
print(" > Restoring LR Scheduler...")
|
357 |
+
scheduler.load_state_dict(checkpoint['scheduler'])
|
358 |
+
# NOTE: Not sure if necessary
|
359 |
+
scheduler.optimizer = optimizer
|
360 |
+
if "scaler" in checkpoint and c.mixed_precision:
|
361 |
+
print(" > Restoring AMP Scaler...")
|
362 |
+
scaler.load_state_dict(checkpoint["scaler"])
|
363 |
+
except RuntimeError:
|
364 |
+
# retore only matching layers.
|
365 |
+
print(" > Partial model initialization...")
|
366 |
+
model_dict = model.state_dict()
|
367 |
+
model_dict = set_init_dict(model_dict, checkpoint['model'], c)
|
368 |
+
model.load_state_dict(model_dict)
|
369 |
+
del model_dict
|
370 |
+
|
371 |
+
# reset lr if not countinuining training.
|
372 |
+
for group in optimizer.param_groups:
|
373 |
+
group['lr'] = c.lr
|
374 |
+
|
375 |
+
print(" > Model restored from step %d" % checkpoint['step'],
|
376 |
+
flush=True)
|
377 |
+
args.restore_step = checkpoint['step']
|
378 |
+
else:
|
379 |
+
args.restore_step = 0
|
380 |
+
|
381 |
+
if use_cuda:
|
382 |
+
model.cuda()
|
383 |
+
criterion.cuda()
|
384 |
+
|
385 |
+
# DISTRUBUTED
|
386 |
+
if num_gpus > 1:
|
387 |
+
model = DDP_th(model, device_ids=[args.rank])
|
388 |
+
|
389 |
+
num_params = count_parameters(model)
|
390 |
+
print(" > WaveGrad has {} parameters".format(num_params), flush=True)
|
391 |
+
|
392 |
+
if 'best_loss' not in locals():
|
393 |
+
best_loss = float('inf')
|
394 |
+
|
395 |
+
global_step = args.restore_step
|
396 |
+
for epoch in range(0, c.epochs):
|
397 |
+
c_logger.print_epoch_start(epoch, c.epochs)
|
398 |
+
_, global_step = train(model, criterion, optimizer,
|
399 |
+
scheduler, scaler, ap, global_step,
|
400 |
+
epoch)
|
401 |
+
eval_avg_loss_dict = evaluate(model, criterion, ap,
|
402 |
+
global_step, epoch)
|
403 |
+
c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
|
404 |
+
target_loss = eval_avg_loss_dict[c.target_loss]
|
405 |
+
best_loss = save_best_model(target_loss,
|
406 |
+
best_loss,
|
407 |
+
model,
|
408 |
+
optimizer,
|
409 |
+
scheduler,
|
410 |
+
None,
|
411 |
+
None,
|
412 |
+
None,
|
413 |
+
global_step,
|
414 |
+
epoch,
|
415 |
+
OUT_PATH,
|
416 |
+
model_losses=eval_avg_loss_dict,
|
417 |
+
scaler=scaler.state_dict() if c.mixed_precision else None)
|
418 |
+
|
419 |
+
|
420 |
+
if __name__ == '__main__':
|
421 |
+
parser = argparse.ArgumentParser()
|
422 |
+
parser.add_argument(
|
423 |
+
'--continue_path',
|
424 |
+
type=str,
|
425 |
+
help=
|
426 |
+
'Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.',
|
427 |
+
default='',
|
428 |
+
required='--config_path' not in sys.argv)
|
429 |
+
parser.add_argument(
|
430 |
+
'--restore_path',
|
431 |
+
type=str,
|
432 |
+
help='Model file to be restored. Use to finetune a model.',
|
433 |
+
default='')
|
434 |
+
parser.add_argument('--config_path',
|
435 |
+
type=str,
|
436 |
+
help='Path to config file for training.',
|
437 |
+
required='--continue_path' not in sys.argv)
|
438 |
+
parser.add_argument('--debug',
|
439 |
+
type=bool,
|
440 |
+
default=False,
|
441 |
+
help='Do not verify commit integrity to run training.')
|
442 |
+
|
443 |
+
# DISTRUBUTED
|
444 |
+
parser.add_argument(
|
445 |
+
'--rank',
|
446 |
+
type=int,
|
447 |
+
default=0,
|
448 |
+
help='DISTRIBUTED: process rank for distributed training.')
|
449 |
+
parser.add_argument('--group_id',
|
450 |
+
type=str,
|
451 |
+
default="",
|
452 |
+
help='DISTRIBUTED: process group id.')
|
453 |
+
args = parser.parse_args()
|
454 |
+
|
455 |
+
if args.continue_path != '':
|
456 |
+
args.output_path = args.continue_path
|
457 |
+
args.config_path = os.path.join(args.continue_path, 'config.json')
|
458 |
+
list_of_files = glob.glob(
|
459 |
+
args.continue_path +
|
460 |
+
"/*.pth.tar") # * means all if need specific format then *.csv
|
461 |
+
latest_model_file = max(list_of_files, key=os.path.getctime)
|
462 |
+
args.restore_path = latest_model_file
|
463 |
+
print(f" > Training continues for {args.restore_path}")
|
464 |
+
|
465 |
+
# setup output paths and read configs
|
466 |
+
c = load_config(args.config_path)
|
467 |
+
# check_config(c)
|
468 |
+
_ = os.path.dirname(os.path.realpath(__file__))
|
469 |
+
|
470 |
+
# DISTRIBUTED
|
471 |
+
if c.mixed_precision:
|
472 |
+
print(" > Mixed precision is enabled")
|
473 |
+
|
474 |
+
OUT_PATH = args.continue_path
|
475 |
+
if args.continue_path == '':
|
476 |
+
OUT_PATH = create_experiment_folder(c.output_path, c.run_name,
|
477 |
+
args.debug)
|
478 |
+
|
479 |
+
AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios')
|
480 |
+
|
481 |
+
c_logger = ConsoleLogger()
|
482 |
+
|
483 |
+
if args.rank == 0:
|
484 |
+
os.makedirs(AUDIO_PATH, exist_ok=True)
|
485 |
+
new_fields = {}
|
486 |
+
if args.restore_path:
|
487 |
+
new_fields["restore_path"] = args.restore_path
|
488 |
+
new_fields["github_branch"] = get_git_branch()
|
489 |
+
copy_model_files(c, args.config_path,
|
490 |
+
OUT_PATH, new_fields)
|
491 |
+
os.chmod(AUDIO_PATH, 0o775)
|
492 |
+
os.chmod(OUT_PATH, 0o775)
|
493 |
+
|
494 |
+
LOG_DIR = OUT_PATH
|
495 |
+
tb_logger = TensorboardLogger(LOG_DIR, model_name='VOCODER')
|
496 |
+
|
497 |
+
# write model desc to tensorboard
|
498 |
+
tb_logger.tb_add_text('model-description', c['run_description'], 0)
|
499 |
+
|
500 |
+
try:
|
501 |
+
main(args)
|
502 |
+
except KeyboardInterrupt:
|
503 |
+
remove_experiment_folder(OUT_PATH)
|
504 |
+
try:
|
505 |
+
sys.exit(0)
|
506 |
+
except SystemExit:
|
507 |
+
os._exit(0) # pylint: disable=protected-access
|
508 |
+
except Exception: # pylint: disable=broad-except
|
509 |
+
remove_experiment_folder(OUT_PATH)
|
510 |
+
traceback.print_exc()
|
511 |
+
sys.exit(1)
|
TTS/bin/train_vocoder_wavernn.py
ADDED
@@ -0,0 +1,539 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
import traceback
|
5 |
+
import time
|
6 |
+
import glob
|
7 |
+
import random
|
8 |
+
|
9 |
+
import torch
|
10 |
+
from torch.utils.data import DataLoader
|
11 |
+
|
12 |
+
# from torch.utils.data.distributed import DistributedSampler
|
13 |
+
|
14 |
+
from TTS.tts.utils.visual import plot_spectrogram
|
15 |
+
from TTS.utils.audio import AudioProcessor
|
16 |
+
from TTS.utils.radam import RAdam
|
17 |
+
from TTS.utils.io import copy_model_files, load_config
|
18 |
+
from TTS.utils.training import setup_torch_training_env
|
19 |
+
from TTS.utils.console_logger import ConsoleLogger
|
20 |
+
from TTS.utils.tensorboard_logger import TensorboardLogger
|
21 |
+
from TTS.utils.generic_utils import (
|
22 |
+
KeepAverage,
|
23 |
+
count_parameters,
|
24 |
+
create_experiment_folder,
|
25 |
+
get_git_branch,
|
26 |
+
remove_experiment_folder,
|
27 |
+
set_init_dict,
|
28 |
+
)
|
29 |
+
from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
|
30 |
+
from TTS.vocoder.datasets.preprocess import (
|
31 |
+
load_wav_data,
|
32 |
+
load_wav_feat_data
|
33 |
+
)
|
34 |
+
from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss
|
35 |
+
from TTS.vocoder.utils.generic_utils import setup_wavernn
|
36 |
+
from TTS.vocoder.utils.io import save_best_model, save_checkpoint
|
37 |
+
|
38 |
+
|
39 |
+
use_cuda, num_gpus = setup_torch_training_env(True, True)
|
40 |
+
|
41 |
+
|
42 |
+
def setup_loader(ap, is_val=False, verbose=False):
|
43 |
+
if is_val and not c.run_eval:
|
44 |
+
loader = None
|
45 |
+
else:
|
46 |
+
dataset = WaveRNNDataset(ap=ap,
|
47 |
+
items=eval_data if is_val else train_data,
|
48 |
+
seq_len=c.seq_len,
|
49 |
+
hop_len=ap.hop_length,
|
50 |
+
pad=c.padding,
|
51 |
+
mode=c.mode,
|
52 |
+
mulaw=c.mulaw,
|
53 |
+
is_training=not is_val,
|
54 |
+
verbose=verbose,
|
55 |
+
)
|
56 |
+
# sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
57 |
+
loader = DataLoader(dataset,
|
58 |
+
shuffle=True,
|
59 |
+
collate_fn=dataset.collate,
|
60 |
+
batch_size=c.batch_size,
|
61 |
+
num_workers=c.num_val_loader_workers
|
62 |
+
if is_val
|
63 |
+
else c.num_loader_workers,
|
64 |
+
pin_memory=True,
|
65 |
+
)
|
66 |
+
return loader
|
67 |
+
|
68 |
+
|
69 |
+
def format_data(data):
|
70 |
+
# setup input data
|
71 |
+
x_input = data[0]
|
72 |
+
mels = data[1]
|
73 |
+
y_coarse = data[2]
|
74 |
+
|
75 |
+
# dispatch data to GPU
|
76 |
+
if use_cuda:
|
77 |
+
x_input = x_input.cuda(non_blocking=True)
|
78 |
+
mels = mels.cuda(non_blocking=True)
|
79 |
+
y_coarse = y_coarse.cuda(non_blocking=True)
|
80 |
+
|
81 |
+
return x_input, mels, y_coarse
|
82 |
+
|
83 |
+
|
84 |
+
def train(model, optimizer, criterion, scheduler, scaler, ap, global_step, epoch):
|
85 |
+
# create train loader
|
86 |
+
data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
|
87 |
+
model.train()
|
88 |
+
epoch_time = 0
|
89 |
+
keep_avg = KeepAverage()
|
90 |
+
if use_cuda:
|
91 |
+
batch_n_iter = int(len(data_loader.dataset) /
|
92 |
+
(c.batch_size * num_gpus))
|
93 |
+
else:
|
94 |
+
batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
|
95 |
+
end_time = time.time()
|
96 |
+
c_logger.print_train_start()
|
97 |
+
# train loop
|
98 |
+
for num_iter, data in enumerate(data_loader):
|
99 |
+
start_time = time.time()
|
100 |
+
x_input, mels, y_coarse = format_data(data)
|
101 |
+
loader_time = time.time() - end_time
|
102 |
+
global_step += 1
|
103 |
+
|
104 |
+
optimizer.zero_grad()
|
105 |
+
|
106 |
+
if c.mixed_precision:
|
107 |
+
# mixed precision training
|
108 |
+
with torch.cuda.amp.autocast():
|
109 |
+
y_hat = model(x_input, mels)
|
110 |
+
if isinstance(model.mode, int):
|
111 |
+
y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
|
112 |
+
else:
|
113 |
+
y_coarse = y_coarse.float()
|
114 |
+
y_coarse = y_coarse.unsqueeze(-1)
|
115 |
+
# compute losses
|
116 |
+
loss = criterion(y_hat, y_coarse)
|
117 |
+
scaler.scale(loss).backward()
|
118 |
+
scaler.unscale_(optimizer)
|
119 |
+
if c.grad_clip > 0:
|
120 |
+
torch.nn.utils.clip_grad_norm_(
|
121 |
+
model.parameters(), c.grad_clip)
|
122 |
+
scaler.step(optimizer)
|
123 |
+
scaler.update()
|
124 |
+
else:
|
125 |
+
# full precision training
|
126 |
+
y_hat = model(x_input, mels)
|
127 |
+
if isinstance(model.mode, int):
|
128 |
+
y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
|
129 |
+
else:
|
130 |
+
y_coarse = y_coarse.float()
|
131 |
+
y_coarse = y_coarse.unsqueeze(-1)
|
132 |
+
# compute losses
|
133 |
+
loss = criterion(y_hat, y_coarse)
|
134 |
+
if loss.item() is None:
|
135 |
+
raise RuntimeError(" [!] None loss. Exiting ...")
|
136 |
+
loss.backward()
|
137 |
+
if c.grad_clip > 0:
|
138 |
+
torch.nn.utils.clip_grad_norm_(
|
139 |
+
model.parameters(), c.grad_clip)
|
140 |
+
optimizer.step()
|
141 |
+
|
142 |
+
if scheduler is not None:
|
143 |
+
scheduler.step()
|
144 |
+
|
145 |
+
# get the current learning rate
|
146 |
+
cur_lr = list(optimizer.param_groups)[0]["lr"]
|
147 |
+
|
148 |
+
step_time = time.time() - start_time
|
149 |
+
epoch_time += step_time
|
150 |
+
|
151 |
+
update_train_values = dict()
|
152 |
+
loss_dict = dict()
|
153 |
+
loss_dict["model_loss"] = loss.item()
|
154 |
+
for key, value in loss_dict.items():
|
155 |
+
update_train_values["avg_" + key] = value
|
156 |
+
update_train_values["avg_loader_time"] = loader_time
|
157 |
+
update_train_values["avg_step_time"] = step_time
|
158 |
+
keep_avg.update_values(update_train_values)
|
159 |
+
|
160 |
+
# print training stats
|
161 |
+
if global_step % c.print_step == 0:
|
162 |
+
log_dict = {"step_time": [step_time, 2],
|
163 |
+
"loader_time": [loader_time, 4],
|
164 |
+
"current_lr": cur_lr,
|
165 |
+
}
|
166 |
+
c_logger.print_train_step(batch_n_iter,
|
167 |
+
num_iter,
|
168 |
+
global_step,
|
169 |
+
log_dict,
|
170 |
+
loss_dict,
|
171 |
+
keep_avg.avg_values,
|
172 |
+
)
|
173 |
+
|
174 |
+
# plot step stats
|
175 |
+
if global_step % 10 == 0:
|
176 |
+
iter_stats = {"lr": cur_lr, "step_time": step_time}
|
177 |
+
iter_stats.update(loss_dict)
|
178 |
+
tb_logger.tb_train_iter_stats(global_step, iter_stats)
|
179 |
+
|
180 |
+
# save checkpoint
|
181 |
+
if global_step % c.save_step == 0:
|
182 |
+
if c.checkpoint:
|
183 |
+
# save model
|
184 |
+
save_checkpoint(model,
|
185 |
+
optimizer,
|
186 |
+
scheduler,
|
187 |
+
None,
|
188 |
+
None,
|
189 |
+
None,
|
190 |
+
global_step,
|
191 |
+
epoch,
|
192 |
+
OUT_PATH,
|
193 |
+
model_losses=loss_dict,
|
194 |
+
scaler=scaler.state_dict() if c.mixed_precision else None
|
195 |
+
)
|
196 |
+
|
197 |
+
# synthesize a full voice
|
198 |
+
rand_idx = random.randrange(0, len(train_data))
|
199 |
+
wav_path = train_data[rand_idx] if not isinstance(
|
200 |
+
train_data[rand_idx], (tuple, list)) else train_data[rand_idx][0]
|
201 |
+
wav = ap.load_wav(wav_path)
|
202 |
+
ground_mel = ap.melspectrogram(wav)
|
203 |
+
sample_wav = model.generate(ground_mel,
|
204 |
+
c.batched,
|
205 |
+
c.target_samples,
|
206 |
+
c.overlap_samples,
|
207 |
+
use_cuda
|
208 |
+
)
|
209 |
+
predict_mel = ap.melspectrogram(sample_wav)
|
210 |
+
|
211 |
+
# compute spectrograms
|
212 |
+
figures = {"train/ground_truth": plot_spectrogram(ground_mel.T),
|
213 |
+
"train/prediction": plot_spectrogram(predict_mel.T)
|
214 |
+
}
|
215 |
+
tb_logger.tb_train_figures(global_step, figures)
|
216 |
+
|
217 |
+
# Sample audio
|
218 |
+
tb_logger.tb_train_audios(
|
219 |
+
global_step, {
|
220 |
+
"train/audio": sample_wav}, c.audio["sample_rate"]
|
221 |
+
)
|
222 |
+
end_time = time.time()
|
223 |
+
|
224 |
+
# print epoch stats
|
225 |
+
c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)
|
226 |
+
|
227 |
+
# Plot Training Epoch Stats
|
228 |
+
epoch_stats = {"epoch_time": epoch_time}
|
229 |
+
epoch_stats.update(keep_avg.avg_values)
|
230 |
+
tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
|
231 |
+
# TODO: plot model stats
|
232 |
+
# if c.tb_model_param_stats:
|
233 |
+
# tb_logger.tb_model_weights(model, global_step)
|
234 |
+
return keep_avg.avg_values, global_step
|
235 |
+
|
236 |
+
|
237 |
+
@torch.no_grad()
|
238 |
+
def evaluate(model, criterion, ap, global_step, epoch):
|
239 |
+
# create train loader
|
240 |
+
data_loader = setup_loader(ap, is_val=True, verbose=(epoch == 0))
|
241 |
+
model.eval()
|
242 |
+
epoch_time = 0
|
243 |
+
keep_avg = KeepAverage()
|
244 |
+
end_time = time.time()
|
245 |
+
c_logger.print_eval_start()
|
246 |
+
with torch.no_grad():
|
247 |
+
for num_iter, data in enumerate(data_loader):
|
248 |
+
start_time = time.time()
|
249 |
+
# format data
|
250 |
+
x_input, mels, y_coarse = format_data(data)
|
251 |
+
loader_time = time.time() - end_time
|
252 |
+
global_step += 1
|
253 |
+
|
254 |
+
y_hat = model(x_input, mels)
|
255 |
+
if isinstance(model.mode, int):
|
256 |
+
y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
|
257 |
+
else:
|
258 |
+
y_coarse = y_coarse.float()
|
259 |
+
y_coarse = y_coarse.unsqueeze(-1)
|
260 |
+
loss = criterion(y_hat, y_coarse)
|
261 |
+
# Compute avg loss
|
262 |
+
# if num_gpus > 1:
|
263 |
+
# loss = reduce_tensor(loss.data, num_gpus)
|
264 |
+
loss_dict = dict()
|
265 |
+
loss_dict["model_loss"] = loss.item()
|
266 |
+
|
267 |
+
step_time = time.time() - start_time
|
268 |
+
epoch_time += step_time
|
269 |
+
|
270 |
+
# update avg stats
|
271 |
+
update_eval_values = dict()
|
272 |
+
for key, value in loss_dict.items():
|
273 |
+
update_eval_values["avg_" + key] = value
|
274 |
+
update_eval_values["avg_loader_time"] = loader_time
|
275 |
+
update_eval_values["avg_step_time"] = step_time
|
276 |
+
keep_avg.update_values(update_eval_values)
|
277 |
+
|
278 |
+
# print eval stats
|
279 |
+
if c.print_eval:
|
280 |
+
c_logger.print_eval_step(
|
281 |
+
num_iter, loss_dict, keep_avg.avg_values)
|
282 |
+
|
283 |
+
if epoch % c.test_every_epochs == 0 and epoch != 0:
|
284 |
+
# synthesize a full voice
|
285 |
+
rand_idx = random.randrange(0, len(eval_data))
|
286 |
+
wav_path = eval_data[rand_idx] if not isinstance(
|
287 |
+
eval_data[rand_idx], (tuple, list)) else eval_data[rand_idx][0]
|
288 |
+
wav = ap.load_wav(wav_path)
|
289 |
+
ground_mel = ap.melspectrogram(wav)
|
290 |
+
sample_wav = model.generate(ground_mel,
|
291 |
+
c.batched,
|
292 |
+
c.target_samples,
|
293 |
+
c.overlap_samples,
|
294 |
+
use_cuda
|
295 |
+
)
|
296 |
+
predict_mel = ap.melspectrogram(sample_wav)
|
297 |
+
|
298 |
+
# Sample audio
|
299 |
+
tb_logger.tb_eval_audios(
|
300 |
+
global_step, {
|
301 |
+
"eval/audio": sample_wav}, c.audio["sample_rate"]
|
302 |
+
)
|
303 |
+
|
304 |
+
# compute spectrograms
|
305 |
+
figures = {"eval/ground_truth": plot_spectrogram(ground_mel.T),
|
306 |
+
"eval/prediction": plot_spectrogram(predict_mel.T)
|
307 |
+
}
|
308 |
+
tb_logger.tb_eval_figures(global_step, figures)
|
309 |
+
|
310 |
+
tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
|
311 |
+
return keep_avg.avg_values
|
312 |
+
|
313 |
+
|
314 |
+
# FIXME: move args definition/parsing inside of main?
|
315 |
+
def main(args): # pylint: disable=redefined-outer-name
|
316 |
+
# pylint: disable=global-variable-undefined
|
317 |
+
global train_data, eval_data
|
318 |
+
|
319 |
+
# setup audio processor
|
320 |
+
ap = AudioProcessor(**c.audio)
|
321 |
+
|
322 |
+
# print(f" > Loading wavs from: {c.data_path}")
|
323 |
+
# if c.feature_path is not None:
|
324 |
+
# print(f" > Loading features from: {c.feature_path}")
|
325 |
+
# eval_data, train_data = load_wav_feat_data(
|
326 |
+
# c.data_path, c.feature_path, c.eval_split_size
|
327 |
+
# )
|
328 |
+
# else:
|
329 |
+
# mel_feat_path = os.path.join(OUT_PATH, "mel")
|
330 |
+
# feat_data = find_feat_files(mel_feat_path)
|
331 |
+
# if feat_data:
|
332 |
+
# print(f" > Loading features from: {mel_feat_path}")
|
333 |
+
# eval_data, train_data = load_wav_feat_data(
|
334 |
+
# c.data_path, mel_feat_path, c.eval_split_size
|
335 |
+
# )
|
336 |
+
# else:
|
337 |
+
# print(" > No feature data found. Preprocessing...")
|
338 |
+
# # preprocessing feature data from given wav files
|
339 |
+
# preprocess_wav_files(OUT_PATH, CONFIG, ap)
|
340 |
+
# eval_data, train_data = load_wav_feat_data(
|
341 |
+
# c.data_path, mel_feat_path, c.eval_split_size
|
342 |
+
# )
|
343 |
+
|
344 |
+
print(f" > Loading wavs from: {c.data_path}")
|
345 |
+
if c.feature_path is not None:
|
346 |
+
print(f" > Loading features from: {c.feature_path}")
|
347 |
+
eval_data, train_data = load_wav_feat_data(
|
348 |
+
c.data_path, c.feature_path, c.eval_split_size)
|
349 |
+
else:
|
350 |
+
eval_data, train_data = load_wav_data(
|
351 |
+
c.data_path, c.eval_split_size)
|
352 |
+
# setup model
|
353 |
+
model_wavernn = setup_wavernn(c)
|
354 |
+
|
355 |
+
# setup amp scaler
|
356 |
+
scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None
|
357 |
+
|
358 |
+
# define train functions
|
359 |
+
if c.mode == "mold":
|
360 |
+
criterion = discretized_mix_logistic_loss
|
361 |
+
elif c.mode == "gauss":
|
362 |
+
criterion = gaussian_loss
|
363 |
+
elif isinstance(c.mode, int):
|
364 |
+
criterion = torch.nn.CrossEntropyLoss()
|
365 |
+
|
366 |
+
if use_cuda:
|
367 |
+
model_wavernn.cuda()
|
368 |
+
if isinstance(c.mode, int):
|
369 |
+
criterion.cuda()
|
370 |
+
|
371 |
+
optimizer = RAdam(model_wavernn.parameters(), lr=c.lr, weight_decay=0)
|
372 |
+
|
373 |
+
scheduler = None
|
374 |
+
if "lr_scheduler" in c:
|
375 |
+
scheduler = getattr(torch.optim.lr_scheduler, c.lr_scheduler)
|
376 |
+
scheduler = scheduler(optimizer, **c.lr_scheduler_params)
|
377 |
+
# slow start for the first 5 epochs
|
378 |
+
# lr_lambda = lambda epoch: min(epoch / c.warmup_steps, 1)
|
379 |
+
# scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
|
380 |
+
|
381 |
+
# restore any checkpoint
|
382 |
+
if args.restore_path:
|
383 |
+
checkpoint = torch.load(args.restore_path, map_location="cpu")
|
384 |
+
try:
|
385 |
+
print(" > Restoring Model...")
|
386 |
+
model_wavernn.load_state_dict(checkpoint["model"])
|
387 |
+
print(" > Restoring Optimizer...")
|
388 |
+
optimizer.load_state_dict(checkpoint["optimizer"])
|
389 |
+
if "scheduler" in checkpoint:
|
390 |
+
print(" > Restoring Generator LR Scheduler...")
|
391 |
+
scheduler.load_state_dict(checkpoint["scheduler"])
|
392 |
+
scheduler.optimizer = optimizer
|
393 |
+
if "scaler" in checkpoint and c.mixed_precision:
|
394 |
+
print(" > Restoring AMP Scaler...")
|
395 |
+
scaler.load_state_dict(checkpoint["scaler"])
|
396 |
+
except RuntimeError:
|
397 |
+
# retore only matching layers.
|
398 |
+
print(" > Partial model initialization...")
|
399 |
+
model_dict = model_wavernn.state_dict()
|
400 |
+
model_dict = set_init_dict(model_dict, checkpoint["model"], c)
|
401 |
+
model_wavernn.load_state_dict(model_dict)
|
402 |
+
|
403 |
+
print(" > Model restored from step %d" %
|
404 |
+
checkpoint["step"], flush=True)
|
405 |
+
args.restore_step = checkpoint["step"]
|
406 |
+
else:
|
407 |
+
args.restore_step = 0
|
408 |
+
|
409 |
+
# DISTRIBUTED
|
410 |
+
# if num_gpus > 1:
|
411 |
+
# model = apply_gradient_allreduce(model)
|
412 |
+
|
413 |
+
num_parameters = count_parameters(model_wavernn)
|
414 |
+
print(" > Model has {} parameters".format(num_parameters), flush=True)
|
415 |
+
|
416 |
+
if "best_loss" not in locals():
|
417 |
+
best_loss = float("inf")
|
418 |
+
|
419 |
+
global_step = args.restore_step
|
420 |
+
for epoch in range(0, c.epochs):
|
421 |
+
c_logger.print_epoch_start(epoch, c.epochs)
|
422 |
+
_, global_step = train(model_wavernn, optimizer,
|
423 |
+
criterion, scheduler, scaler, ap, global_step, epoch)
|
424 |
+
eval_avg_loss_dict = evaluate(
|
425 |
+
model_wavernn, criterion, ap, global_step, epoch)
|
426 |
+
c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
|
427 |
+
target_loss = eval_avg_loss_dict["avg_model_loss"]
|
428 |
+
best_loss = save_best_model(
|
429 |
+
target_loss,
|
430 |
+
best_loss,
|
431 |
+
model_wavernn,
|
432 |
+
optimizer,
|
433 |
+
scheduler,
|
434 |
+
None,
|
435 |
+
None,
|
436 |
+
None,
|
437 |
+
global_step,
|
438 |
+
epoch,
|
439 |
+
OUT_PATH,
|
440 |
+
model_losses=eval_avg_loss_dict,
|
441 |
+
scaler=scaler.state_dict() if c.mixed_precision else None
|
442 |
+
)
|
443 |
+
|
444 |
+
|
445 |
+
if __name__ == "__main__":
|
446 |
+
parser = argparse.ArgumentParser()
|
447 |
+
parser.add_argument(
|
448 |
+
"--continue_path",
|
449 |
+
type=str,
|
450 |
+
help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.',
|
451 |
+
default="",
|
452 |
+
required="--config_path" not in sys.argv,
|
453 |
+
)
|
454 |
+
parser.add_argument(
|
455 |
+
"--restore_path",
|
456 |
+
type=str,
|
457 |
+
help="Model file to be restored. Use to finetune a model.",
|
458 |
+
default="",
|
459 |
+
)
|
460 |
+
parser.add_argument(
|
461 |
+
"--config_path",
|
462 |
+
type=str,
|
463 |
+
help="Path to config file for training.",
|
464 |
+
required="--continue_path" not in sys.argv,
|
465 |
+
)
|
466 |
+
parser.add_argument(
|
467 |
+
"--debug",
|
468 |
+
type=bool,
|
469 |
+
default=False,
|
470 |
+
help="Do not verify commit integrity to run training.",
|
471 |
+
)
|
472 |
+
|
473 |
+
# DISTRUBUTED
|
474 |
+
parser.add_argument(
|
475 |
+
"--rank",
|
476 |
+
type=int,
|
477 |
+
default=0,
|
478 |
+
help="DISTRIBUTED: process rank for distributed training.",
|
479 |
+
)
|
480 |
+
parser.add_argument(
|
481 |
+
"--group_id", type=str, default="", help="DISTRIBUTED: process group id."
|
482 |
+
)
|
483 |
+
args = parser.parse_args()
|
484 |
+
|
485 |
+
if args.continue_path != "":
|
486 |
+
args.output_path = args.continue_path
|
487 |
+
args.config_path = os.path.join(args.continue_path, "config.json")
|
488 |
+
list_of_files = glob.glob(
|
489 |
+
args.continue_path + "/*.pth.tar"
|
490 |
+
) # * means all if need specific format then *.csv
|
491 |
+
latest_model_file = max(list_of_files, key=os.path.getctime)
|
492 |
+
args.restore_path = latest_model_file
|
493 |
+
print(f" > Training continues for {args.restore_path}")
|
494 |
+
|
495 |
+
# setup output paths and read configs
|
496 |
+
c = load_config(args.config_path)
|
497 |
+
# check_config(c)
|
498 |
+
_ = os.path.dirname(os.path.realpath(__file__))
|
499 |
+
|
500 |
+
OUT_PATH = args.continue_path
|
501 |
+
if args.continue_path == "":
|
502 |
+
OUT_PATH = create_experiment_folder(
|
503 |
+
c.output_path, c.run_name, args.debug
|
504 |
+
)
|
505 |
+
|
506 |
+
AUDIO_PATH = os.path.join(OUT_PATH, "test_audios")
|
507 |
+
|
508 |
+
c_logger = ConsoleLogger()
|
509 |
+
|
510 |
+
if args.rank == 0:
|
511 |
+
os.makedirs(AUDIO_PATH, exist_ok=True)
|
512 |
+
new_fields = {}
|
513 |
+
if args.restore_path:
|
514 |
+
new_fields["restore_path"] = args.restore_path
|
515 |
+
new_fields["github_branch"] = get_git_branch()
|
516 |
+
copy_model_files(
|
517 |
+
c, args.config_path, OUT_PATH, new_fields
|
518 |
+
)
|
519 |
+
os.chmod(AUDIO_PATH, 0o775)
|
520 |
+
os.chmod(OUT_PATH, 0o775)
|
521 |
+
|
522 |
+
LOG_DIR = OUT_PATH
|
523 |
+
tb_logger = TensorboardLogger(LOG_DIR, model_name="VOCODER")
|
524 |
+
|
525 |
+
# write model desc to tensorboard
|
526 |
+
tb_logger.tb_add_text("model-description", c["run_description"], 0)
|
527 |
+
|
528 |
+
try:
|
529 |
+
main(args)
|
530 |
+
except KeyboardInterrupt:
|
531 |
+
remove_experiment_folder(OUT_PATH)
|
532 |
+
try:
|
533 |
+
sys.exit(0)
|
534 |
+
except SystemExit:
|
535 |
+
os._exit(0) # pylint: disable=protected-access
|
536 |
+
except Exception: # pylint: disable=broad-except
|
537 |
+
remove_experiment_folder(OUT_PATH)
|
538 |
+
traceback.print_exc()
|
539 |
+
sys.exit(1)
|
TTS/bin/tune_wavegrad.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Search a good noise schedule for WaveGrad for a given number of inferece iterations"""
|
2 |
+
import argparse
|
3 |
+
from itertools import product as cartesian_product
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from torch.utils.data import DataLoader
|
8 |
+
from tqdm import tqdm
|
9 |
+
from TTS.utils.audio import AudioProcessor
|
10 |
+
from TTS.utils.io import load_config
|
11 |
+
from TTS.vocoder.datasets.preprocess import load_wav_data
|
12 |
+
from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
|
13 |
+
from TTS.vocoder.utils.generic_utils import setup_generator
|
14 |
+
|
15 |
+
parser = argparse.ArgumentParser()
|
16 |
+
parser.add_argument('--model_path', type=str, help='Path to model checkpoint.')
|
17 |
+
parser.add_argument('--config_path', type=str, help='Path to model config file.')
|
18 |
+
parser.add_argument('--data_path', type=str, help='Path to data directory.')
|
19 |
+
parser.add_argument('--output_path', type=str, help='path for output file including file name and extension.')
|
20 |
+
parser.add_argument('--num_iter', type=int, help='Number of model inference iterations that you like to optimize noise schedule for.')
|
21 |
+
parser.add_argument('--use_cuda', type=bool, help='enable/disable CUDA.')
|
22 |
+
parser.add_argument('--num_samples', type=int, default=1, help='Number of datasamples used for inference.')
|
23 |
+
parser.add_argument('--search_depth', type=int, default=3, help='Search granularity. Increasing this increases the run-time exponentially.')
|
24 |
+
|
25 |
+
# load config
|
26 |
+
args = parser.parse_args()
|
27 |
+
config = load_config(args.config_path)
|
28 |
+
|
29 |
+
# setup audio processor
|
30 |
+
ap = AudioProcessor(**config.audio)
|
31 |
+
|
32 |
+
# load dataset
|
33 |
+
_, train_data = load_wav_data(args.data_path, 0)
|
34 |
+
train_data = train_data[:args.num_samples]
|
35 |
+
dataset = WaveGradDataset(ap=ap,
|
36 |
+
items=train_data,
|
37 |
+
seq_len=-1,
|
38 |
+
hop_len=ap.hop_length,
|
39 |
+
pad_short=config.pad_short,
|
40 |
+
conv_pad=config.conv_pad,
|
41 |
+
is_training=True,
|
42 |
+
return_segments=False,
|
43 |
+
use_noise_augment=False,
|
44 |
+
use_cache=False,
|
45 |
+
verbose=True)
|
46 |
+
loader = DataLoader(
|
47 |
+
dataset,
|
48 |
+
batch_size=1,
|
49 |
+
shuffle=False,
|
50 |
+
collate_fn=dataset.collate_full_clips,
|
51 |
+
drop_last=False,
|
52 |
+
num_workers=config.num_loader_workers,
|
53 |
+
pin_memory=False)
|
54 |
+
|
55 |
+
# setup the model
|
56 |
+
model = setup_generator(config)
|
57 |
+
if args.use_cuda:
|
58 |
+
model.cuda()
|
59 |
+
|
60 |
+
# setup optimization parameters
|
61 |
+
base_values = sorted(10 * np.random.uniform(size=args.search_depth))
|
62 |
+
print(base_values)
|
63 |
+
exponents = 10 ** np.linspace(-6, -1, num=args.num_iter)
|
64 |
+
best_error = float('inf')
|
65 |
+
best_schedule = None
|
66 |
+
total_search_iter = len(base_values)**args.num_iter
|
67 |
+
for base in tqdm(cartesian_product(base_values, repeat=args.num_iter), total=total_search_iter):
|
68 |
+
beta = exponents * base
|
69 |
+
model.compute_noise_level(beta)
|
70 |
+
for data in loader:
|
71 |
+
mel, audio = data
|
72 |
+
y_hat = model.inference(mel.cuda() if args.use_cuda else mel)
|
73 |
+
|
74 |
+
if args.use_cuda:
|
75 |
+
y_hat = y_hat.cpu()
|
76 |
+
y_hat = y_hat.numpy()
|
77 |
+
|
78 |
+
mel_hat = []
|
79 |
+
for i in range(y_hat.shape[0]):
|
80 |
+
m = ap.melspectrogram(y_hat[i, 0])[:, :-1]
|
81 |
+
mel_hat.append(torch.from_numpy(m))
|
82 |
+
|
83 |
+
mel_hat = torch.stack(mel_hat)
|
84 |
+
mse = torch.sum((mel - mel_hat) ** 2).mean()
|
85 |
+
if mse.item() < best_error:
|
86 |
+
best_error = mse.item()
|
87 |
+
best_schedule = {'beta': beta}
|
88 |
+
print(f" > Found a better schedule. - MSE: {mse.item()}")
|
89 |
+
np.save(args.output_path, best_schedule)
|
90 |
+
|
91 |
+
|
TTS/server/README.md
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## TTS example web-server
|
2 |
+
|
3 |
+
You'll need a model package (Zip file, includes TTS Python wheel, model files, server configuration, and optional nginx/uwsgi configs). Publicly available models are listed [here](https://github.com/mozilla/TTS/wiki/Released-Models#simple-packaging---self-contained-package-that-runs-an-http-api-for-a-pre-trained-tts-model).
|
4 |
+
|
5 |
+
Instructions below are based on a Ubuntu 18.04 machine, but it should be simple to adapt the package names to other distros if needed. Python 3.6 is recommended, as some of the dependencies' versions predate Python 3.7 and will force building from source, which requires extra dependencies and is not guaranteed to work.
|
6 |
+
|
7 |
+
#### Development server:
|
8 |
+
|
9 |
+
##### Using server.py
|
10 |
+
If you have the environment set already for TTS, then you can directly call ```server.py```.
|
11 |
+
|
12 |
+
**Note:** After installing TTS as a package you can use ```tts-server``` to call the commands below.
|
13 |
+
|
14 |
+
Examples runs:
|
15 |
+
|
16 |
+
List officially released models.
|
17 |
+
```python TTS/server/server.py --list_models ```
|
18 |
+
|
19 |
+
Run the server with the official models.
|
20 |
+
```python TTS/server/server.py --model_name tts_models/en/ljspeech/tacotron2-DCA --vocoder_name vocoder_models/en/ljspeech/mulitband-melgan```
|
21 |
+
|
22 |
+
Run the server with the official models on a GPU.
|
23 |
+
```CUDA_VISIBLE_DEVICES="0" python TTS/server/server.py --model_name tts_models/en/ljspeech/tacotron2-DCA --vocoder_name vocoder_models/en/ljspeech/mulitband-melgan --use_cuda True```
|
24 |
+
|
25 |
+
Run the server with a custom models.
|
26 |
+
```python TTS/server/server.py --tts_checkpoint /path/to/tts/model.pth.tar --tts_config /path/to/tts/config.json --vocoder_checkpoint /path/to/vocoder/model.pth.tar --vocoder_config /path/to/vocoder/config.json```
|
27 |
+
|
28 |
+
##### Using .whl
|
29 |
+
1. apt-get install -y espeak libsndfile1 python3-venv
|
30 |
+
2. python3 -m venv /tmp/venv
|
31 |
+
3. source /tmp/venv/bin/activate
|
32 |
+
4. pip install -U pip setuptools wheel
|
33 |
+
5. pip install -U https//example.com/url/to/python/package.whl
|
34 |
+
6. python -m TTS.server.server
|
35 |
+
|
36 |
+
You can now open http://localhost:5002 in a browser
|
37 |
+
|
38 |
+
#### Running with nginx/uwsgi:
|
39 |
+
|
40 |
+
**Note:** This method uses an old TTS model, so quality might be low.
|
41 |
+
|
42 |
+
1. apt-get install -y uwsgi uwsgi-plugin-python3 nginx espeak libsndfile1 python3-venv
|
43 |
+
2. python3 -m venv /tmp/venv
|
44 |
+
3. source /tmp/venv/bin/activate
|
45 |
+
4. pip install -U pip setuptools wheel
|
46 |
+
5. pip install -U https//example.com/url/to/python/package.whl
|
47 |
+
6. curl -LO https://github.com/reuben/TTS/releases/download/t2-ljspeech-mold/t2-ljspeech-mold-nginx-uwsgi.zip
|
48 |
+
7. unzip *-nginx-uwsgi.zip
|
49 |
+
8. cp tts_site_nginx /etc/nginx/sites-enabled/default
|
50 |
+
9. service nginx restart
|
51 |
+
10. uwsgi --ini uwsgi.ini
|
52 |
+
|
53 |
+
You can now open http://localhost:80 in a browser (edit the port in /etc/nginx/sites-enabled/tts_site_nginx).
|
54 |
+
Configure number of workers (number of requests that will be processed in parallel) by editing the `uwsgi.ini` file, specifically the `processes` setting.
|
55 |
+
|
56 |
+
#### Creating a server package with an embedded model
|
57 |
+
|
58 |
+
[setup.py](../setup.py) was extended with two new parameters when running the `bdist_wheel` command:
|
59 |
+
|
60 |
+
- `--checkpoint <path to checkpoint file>` - path to model checkpoint file you want to embed in the package
|
61 |
+
- `--model_config <path to config.json file>` - path to corresponding config.json file for the checkpoint
|
62 |
+
|
63 |
+
To create a package, run `python setup.py bdist_wheel --checkpoint /path/to/checkpoint --model_config /path/to/config.json`.
|
64 |
+
|
65 |
+
A Python `.whl` file will be created in the `dist/` folder with the checkpoint and config embedded in it.
|
TTS/server/__init__.py
ADDED
File without changes
|
TTS/server/conf.json
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"tts_path":"/media/erogol/data_ssd/Models/libri_tts/5049/", // tts model root folder
|
3 |
+
"tts_file":"best_model.pth.tar", // tts checkpoint file
|
4 |
+
"tts_config":"config.json", // tts config.json file
|
5 |
+
"tts_speakers": null, // json file listing speaker ids. null if no speaker embedding.
|
6 |
+
"vocoder_config":null,
|
7 |
+
"vocoder_file": null,
|
8 |
+
"is_wavernn_batched":true,
|
9 |
+
"port": 5002,
|
10 |
+
"use_cuda": true,
|
11 |
+
"debug": true
|
12 |
+
}
|
TTS/server/server.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!flask/bin/python
|
2 |
+
import argparse
|
3 |
+
import os
|
4 |
+
import sys
|
5 |
+
import io
|
6 |
+
from pathlib import Path
|
7 |
+
|
8 |
+
from flask import Flask, render_template, request, send_file
|
9 |
+
from TTS.utils.synthesizer import Synthesizer
|
10 |
+
from TTS.utils.manage import ModelManager
|
11 |
+
from TTS.utils.io import load_config
|
12 |
+
|
13 |
+
|
14 |
+
def create_argparser():
|
15 |
+
def convert_boolean(x):
|
16 |
+
return x.lower() in ['true', '1', 'yes']
|
17 |
+
|
18 |
+
parser = argparse.ArgumentParser()
|
19 |
+
parser.add_argument('--list_models', type=convert_boolean, nargs='?', const=True, default=False, help='list available pre-trained tts and vocoder models.')
|
20 |
+
parser.add_argument('--model_name', type=str, help='name of one of the released tts models.')
|
21 |
+
parser.add_argument('--vocoder_name', type=str, help='name of one of the released vocoder models.')
|
22 |
+
parser.add_argument('--tts_checkpoint', type=str, help='path to custom tts checkpoint file')
|
23 |
+
parser.add_argument('--tts_config', type=str, help='path to custom tts config.json file')
|
24 |
+
parser.add_argument('--tts_speakers', type=str, help='path to JSON file containing speaker ids, if speaker ids are used in the model')
|
25 |
+
parser.add_argument('--vocoder_config', type=str, default=None, help='path to vocoder config file.')
|
26 |
+
parser.add_argument('--vocoder_checkpoint', type=str, default=None, help='path to vocoder checkpoint file.')
|
27 |
+
parser.add_argument('--port', type=int, default=5002, help='port to listen on.')
|
28 |
+
parser.add_argument('--use_cuda', type=convert_boolean, default=False, help='true to use CUDA.')
|
29 |
+
parser.add_argument('--debug', type=convert_boolean, default=False, help='true to enable Flask debug mode.')
|
30 |
+
parser.add_argument('--show_details', type=convert_boolean, default=False, help='Generate model detail page.')
|
31 |
+
return parser
|
32 |
+
|
33 |
+
synthesizer = None
|
34 |
+
|
35 |
+
embedded_models_folder = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'model')
|
36 |
+
|
37 |
+
embedded_tts_folder = os.path.join(embedded_models_folder, 'tts')
|
38 |
+
tts_checkpoint_file = os.path.join(embedded_tts_folder, 'checkpoint.pth.tar')
|
39 |
+
tts_config_file = os.path.join(embedded_tts_folder, 'config.json')
|
40 |
+
|
41 |
+
embedded_vocoder_folder = os.path.join(embedded_models_folder, 'vocoder')
|
42 |
+
vocoder_checkpoint_file = os.path.join(embedded_vocoder_folder, 'checkpoint.pth.tar')
|
43 |
+
vocoder_config_file = os.path.join(embedded_vocoder_folder, 'config.json')
|
44 |
+
|
45 |
+
# These models are soon to be deprecated
|
46 |
+
embedded_wavernn_folder = os.path.join(embedded_models_folder, 'wavernn')
|
47 |
+
wavernn_checkpoint_file = os.path.join(embedded_wavernn_folder, 'checkpoint.pth.tar')
|
48 |
+
wavernn_config_file = os.path.join(embedded_wavernn_folder, 'config.json')
|
49 |
+
|
50 |
+
args = create_argparser().parse_args()
|
51 |
+
|
52 |
+
path = Path(__file__).parent / "../.models.json"
|
53 |
+
manager = ModelManager(path)
|
54 |
+
|
55 |
+
if args.list_models:
|
56 |
+
manager.list_models()
|
57 |
+
sys.exit()
|
58 |
+
|
59 |
+
# set models by the released models
|
60 |
+
if args.model_name is not None:
|
61 |
+
tts_checkpoint_file, tts_config_file = manager.download_model(args.model_name)
|
62 |
+
|
63 |
+
if args.vocoder_name is not None:
|
64 |
+
vocoder_checkpoint_file, vocoder_config_file = manager.download_model(args.vocoder_name)
|
65 |
+
|
66 |
+
# If these were not specified in the CLI args, use default values with embedded model files
|
67 |
+
if not args.tts_checkpoint and os.path.isfile(tts_checkpoint_file):
|
68 |
+
args.tts_checkpoint = tts_checkpoint_file
|
69 |
+
if not args.tts_config and os.path.isfile(tts_config_file):
|
70 |
+
args.tts_config = tts_config_file
|
71 |
+
|
72 |
+
if not args.vocoder_checkpoint and os.path.isfile(vocoder_checkpoint_file):
|
73 |
+
args.vocoder_checkpoint = vocoder_checkpoint_file
|
74 |
+
if not args.vocoder_config and os.path.isfile(vocoder_config_file):
|
75 |
+
args.vocoder_config = vocoder_config_file
|
76 |
+
|
77 |
+
synthesizer = Synthesizer(args.tts_checkpoint, args.tts_config, args.vocoder_checkpoint, args.vocoder_config, args.use_cuda)
|
78 |
+
|
79 |
+
app = Flask(__name__)
|
80 |
+
|
81 |
+
|
82 |
+
@app.route('/')
|
83 |
+
def index():
|
84 |
+
return render_template('index.html', show_details=args.show_details)
|
85 |
+
|
86 |
+
@app.route('/details')
|
87 |
+
def details():
|
88 |
+
model_config = load_config(args.tts_config)
|
89 |
+
if args.vocoder_config is not None and os.path.isfile(args.vocoder_config):
|
90 |
+
vocoder_config = load_config(args.vocoder_config)
|
91 |
+
else:
|
92 |
+
vocoder_config = None
|
93 |
+
|
94 |
+
return render_template('details.html',
|
95 |
+
show_details=args.show_details
|
96 |
+
, model_config=model_config
|
97 |
+
, vocoder_config=vocoder_config
|
98 |
+
, args=args.__dict__
|
99 |
+
)
|
100 |
+
|
101 |
+
@app.route('/api/tts', methods=['GET'])
|
102 |
+
def tts():
|
103 |
+
text = request.args.get('text')
|
104 |
+
print(" > Model input: {}".format(text))
|
105 |
+
wavs = synthesizer.tts(text)
|
106 |
+
out = io.BytesIO()
|
107 |
+
synthesizer.save_wav(wavs, out)
|
108 |
+
return send_file(out, mimetype='audio/wav')
|
109 |
+
|
110 |
+
|
111 |
+
def main():
|
112 |
+
app.run(debug=args.debug, host='0.0.0.0', port=args.port)
|
113 |
+
|
114 |
+
|
115 |
+
if __name__ == '__main__':
|
116 |
+
main()
|
TTS/server/static/TTS_circle.png
ADDED
TTS/server/templates/details.html
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<!DOCTYPE html>
|
2 |
+
<html lang="en">
|
3 |
+
|
4 |
+
<head>
|
5 |
+
|
6 |
+
<meta charset="utf-8">
|
7 |
+
<meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no">
|
8 |
+
<meta name="description" content="">
|
9 |
+
<meta name="author" content="">
|
10 |
+
|
11 |
+
<title>TTS engine</title>
|
12 |
+
|
13 |
+
<!-- Bootstrap core CSS -->
|
14 |
+
<link href="https://stackpath.bootstrapcdn.com/bootstrap/4.1.1/css/bootstrap.min.css"
|
15 |
+
integrity="sha384-WskhaSGFgHYWDcbwN70/dfYBj47jz9qbsMId/iRN3ewGhXQFZCSftd1LZCfmhktB" crossorigin="anonymous"
|
16 |
+
rel="stylesheet">
|
17 |
+
|
18 |
+
<!-- Custom styles for this template -->
|
19 |
+
<style>
|
20 |
+
body {
|
21 |
+
padding-top: 54px;
|
22 |
+
}
|
23 |
+
|
24 |
+
@media (min-width: 992px) {
|
25 |
+
body {
|
26 |
+
padding-top: 56px;
|
27 |
+
}
|
28 |
+
}
|
29 |
+
</style>
|
30 |
+
</head>
|
31 |
+
|
32 |
+
<body>
|
33 |
+
<a href="https://github.com/mozilla/TTS"><img style="position: absolute; z-index:1000; top: 0; left: 0; border: 0;"
|
34 |
+
src="https://s3.amazonaws.com/github/ribbons/forkme_left_darkblue_121621.png" alt="Fork me on GitHub"></a>
|
35 |
+
|
36 |
+
{% if show_details == true %}
|
37 |
+
|
38 |
+
<div class="container">
|
39 |
+
<b>Model details</b>
|
40 |
+
</div>
|
41 |
+
|
42 |
+
<div class="container">
|
43 |
+
<details>
|
44 |
+
<summary>CLI arguments:</summary>
|
45 |
+
<table border="1" align="center" width="75%">
|
46 |
+
<tr>
|
47 |
+
<td> CLI key </td>
|
48 |
+
<td> Value </td>
|
49 |
+
</tr>
|
50 |
+
|
51 |
+
{% for key, value in args.items() %}
|
52 |
+
|
53 |
+
<tr>
|
54 |
+
<td>{{ key }}</td>
|
55 |
+
<td>{{ value }}</td>
|
56 |
+
</tr>
|
57 |
+
|
58 |
+
{% endfor %}
|
59 |
+
</table>
|
60 |
+
</details>
|
61 |
+
</div></br>
|
62 |
+
|
63 |
+
<div class="container">
|
64 |
+
|
65 |
+
{% if model_config != None %}
|
66 |
+
|
67 |
+
<details>
|
68 |
+
<summary>Model config:</summary>
|
69 |
+
|
70 |
+
<table border="1" align="center" width="75%">
|
71 |
+
<tr>
|
72 |
+
<td> Key </td>
|
73 |
+
<td> Value </td>
|
74 |
+
</tr>
|
75 |
+
|
76 |
+
|
77 |
+
{% for key, value in model_config.items() %}
|
78 |
+
|
79 |
+
<tr>
|
80 |
+
<td>{{ key }}</td>
|
81 |
+
<td>{{ value }}</td>
|
82 |
+
</tr>
|
83 |
+
|
84 |
+
{% endfor %}
|
85 |
+
|
86 |
+
</table>
|
87 |
+
</details>
|
88 |
+
|
89 |
+
{% endif %}
|
90 |
+
|
91 |
+
</div></br>
|
92 |
+
|
93 |
+
|
94 |
+
|
95 |
+
<div class="container">
|
96 |
+
{% if vocoder_config != None %}
|
97 |
+
<details>
|
98 |
+
<summary>Vocoder model config:</summary>
|
99 |
+
|
100 |
+
<table border="1" align="center" width="75%">
|
101 |
+
<tr>
|
102 |
+
<td> Key </td>
|
103 |
+
<td> Value </td>
|
104 |
+
</tr>
|
105 |
+
|
106 |
+
|
107 |
+
{% for key, value in vocoder_config.items() %}
|
108 |
+
|
109 |
+
<tr>
|
110 |
+
<td>{{ key }}</td>
|
111 |
+
<td>{{ value }}</td>
|
112 |
+
</tr>
|
113 |
+
|
114 |
+
{% endfor %}
|
115 |
+
|
116 |
+
|
117 |
+
</table>
|
118 |
+
</details>
|
119 |
+
{% endif %}
|
120 |
+
</div></br>
|
121 |
+
|
122 |
+
{% else %}
|
123 |
+
<div class="container">
|
124 |
+
<b>Please start server with --show_details=true to see details.</b>
|
125 |
+
</div>
|
126 |
+
|
127 |
+
{% endif %}
|
128 |
+
|
129 |
+
</body>
|
130 |
+
|
131 |
+
</html>
|
TTS/server/templates/index.html
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<!DOCTYPE html>
|
2 |
+
<html lang="en">
|
3 |
+
|
4 |
+
<head>
|
5 |
+
|
6 |
+
<meta charset="utf-8">
|
7 |
+
<meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no">
|
8 |
+
<meta name="description" content="">
|
9 |
+
<meta name="author" content="">
|
10 |
+
|
11 |
+
<title>TTS engine</title>
|
12 |
+
|
13 |
+
<!-- Bootstrap core CSS -->
|
14 |
+
<link href="https://stackpath.bootstrapcdn.com/bootstrap/4.1.1/css/bootstrap.min.css"
|
15 |
+
integrity="sha384-WskhaSGFgHYWDcbwN70/dfYBj47jz9qbsMId/iRN3ewGhXQFZCSftd1LZCfmhktB" crossorigin="anonymous" rel="stylesheet">
|
16 |
+
|
17 |
+
<!-- Custom styles for this template -->
|
18 |
+
<style>
|
19 |
+
body {
|
20 |
+
padding-top: 54px;
|
21 |
+
}
|
22 |
+
@media (min-width: 992px) {
|
23 |
+
body {
|
24 |
+
padding-top: 56px;
|
25 |
+
}
|
26 |
+
}
|
27 |
+
|
28 |
+
</style>
|
29 |
+
</head>
|
30 |
+
|
31 |
+
<body>
|
32 |
+
<a href="https://github.com/mozilla/TTS"><img style="position: absolute; z-index:1000; top: 0; left: 0; border: 0;" src="https://s3.amazonaws.com/github/ribbons/forkme_left_darkblue_121621.png" alt="Fork me on GitHub"></a>
|
33 |
+
|
34 |
+
<!-- Navigation -->
|
35 |
+
<!--
|
36 |
+
<nav class="navbar navbar-expand-lg navbar-dark bg-dark fixed-top">
|
37 |
+
<div class="container">
|
38 |
+
<a class="navbar-brand" href="#">Mozilla TTS</a>
|
39 |
+
<button class="navbar-toggler" type="button" data-toggle="collapse" data-target="#navbarResponsive" aria-controls="navbarResponsive" aria-expanded="false" aria-label="Toggle navigation">
|
40 |
+
<span class="navbar-toggler-icon"></span>
|
41 |
+
</button>
|
42 |
+
<div class="collapse navbar-collapse" id="navbarResponsive">
|
43 |
+
<ul class="navbar-nav ml-auto">
|
44 |
+
<li class="nav-item active">
|
45 |
+
<a class="nav-link" href="#">Home
|
46 |
+
<span class="sr-only">(current)</span>
|
47 |
+
</a>
|
48 |
+
</li>
|
49 |
+
</ul>
|
50 |
+
</div>
|
51 |
+
</div>
|
52 |
+
</nav>
|
53 |
+
-->
|
54 |
+
|
55 |
+
<!-- Page Content -->
|
56 |
+
<div class="container">
|
57 |
+
<div class="row">
|
58 |
+
<div class="col-lg-12 text-center">
|
59 |
+
<img class="mt-5" src="{{url_for('static', filename='TTS_circle.png')}}" align="middle" />
|
60 |
+
|
61 |
+
<ul class="list-unstyled">
|
62 |
+
</ul>
|
63 |
+
<input id="text" placeholder="Type here..." size=45 type="text" name="text">
|
64 |
+
<button id="speak-button" name="speak">Speak</button><br/><br/>
|
65 |
+
{%if show_details%}
|
66 |
+
<button id="details-button" onclick="location.href = 'details'" name="model-details">Model Details</button><br/><br/>
|
67 |
+
{%endif%}
|
68 |
+
<audio id="audio" controls autoplay hidden></audio>
|
69 |
+
<p id="message"></p>
|
70 |
+
</div>
|
71 |
+
</div>
|
72 |
+
</div>
|
73 |
+
|
74 |
+
<!-- Bootstrap core JavaScript -->
|
75 |
+
<script>
|
76 |
+
function q(selector) {return document.querySelector(selector)}
|
77 |
+
q('#text').focus()
|
78 |
+
function do_tts(e) {
|
79 |
+
text = q('#text').value
|
80 |
+
if (text) {
|
81 |
+
q('#message').textContent = 'Synthesizing...'
|
82 |
+
q('#speak-button').disabled = true
|
83 |
+
q('#audio').hidden = true
|
84 |
+
synthesize(text)
|
85 |
+
}
|
86 |
+
e.preventDefault()
|
87 |
+
return false
|
88 |
+
}
|
89 |
+
q('#speak-button').addEventListener('click', do_tts)
|
90 |
+
q('#text').addEventListener('keyup', function(e) {
|
91 |
+
if (e.keyCode == 13) { // enter
|
92 |
+
do_tts(e)
|
93 |
+
}
|
94 |
+
})
|
95 |
+
function synthesize(text) {
|
96 |
+
fetch('/api/tts?text=' + encodeURIComponent(text), {cache: 'no-cache'})
|
97 |
+
.then(function(res) {
|
98 |
+
if (!res.ok) throw Error(res.statusText)
|
99 |
+
return res.blob()
|
100 |
+
}).then(function(blob) {
|
101 |
+
q('#message').textContent = ''
|
102 |
+
q('#speak-button').disabled = false
|
103 |
+
q('#audio').src = URL.createObjectURL(blob)
|
104 |
+
q('#audio').hidden = false
|
105 |
+
}).catch(function(err) {
|
106 |
+
q('#message').textContent = 'Error: ' + err.message
|
107 |
+
q('#speak-button').disabled = false
|
108 |
+
})
|
109 |
+
}
|
110 |
+
</script>
|
111 |
+
|
112 |
+
</body>
|
113 |
+
|
114 |
+
</html>
|
TTS/speaker_encoder/README.md
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
### Speaker Encoder
|
2 |
+
|
3 |
+
This is an implementation of https://arxiv.org/abs/1710.10467. This model can be used for voice and speaker embedding.
|
4 |
+
|
5 |
+
With the code here you can generate d-vectors for both multi-speaker and single-speaker TTS datasets, then visualise and explore them along with the associated audio files in an interactive chart.
|
6 |
+
|
7 |
+
Below is an example showing embedding results of various speakers. You can generate the same plot with the provided notebook as demonstrated in [this video](https://youtu.be/KW3oO7JVa7Q).
|
8 |
+
|
9 |
+
![](umap.png)
|
10 |
+
|
11 |
+
Download a pretrained model from [Released Models](https://github.com/mozilla/TTS/wiki/Released-Models) page.
|
12 |
+
|
13 |
+
To run the code, you need to follow the same flow as in TTS.
|
14 |
+
|
15 |
+
- Define 'config.json' for your needs. Note that, audio parameters should match your TTS model.
|
16 |
+
- Example training call ```python speaker_encoder/train.py --config_path speaker_encoder/config.json --data_path ~/Data/Libri-TTS/train-clean-360```
|
17 |
+
- Generate embedding vectors ```python speaker_encoder/compute_embeddings.py --use_cuda true /model/path/best_model.pth.tar model/config/path/config.json dataset/path/ output_path``` . This code parses all .wav files at the given dataset path and generates the same folder structure under the output path with the generated embedding files.
|
18 |
+
- Watch training on Tensorboard as in TTS
|
TTS/speaker_encoder/__init__.py
ADDED
File without changes
|
TTS/speaker_encoder/config.json
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
{
|
3 |
+
"run_name": "mueller91",
|
4 |
+
"run_description": "train speaker encoder with voxceleb1, voxceleb2 and libriSpeech ",
|
5 |
+
"audio":{
|
6 |
+
// Audio processing parameters
|
7 |
+
"num_mels": 40, // size of the mel spec frame.
|
8 |
+
"fft_size": 400, // number of stft frequency levels. Size of the linear spectogram frame.
|
9 |
+
"sample_rate": 16000, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled.
|
10 |
+
"win_length": 400, // stft window length in ms.
|
11 |
+
"hop_length": 160, // stft window hop-lengh in ms.
|
12 |
+
"frame_length_ms": null, // stft window length in ms.If null, 'win_length' is used.
|
13 |
+
"frame_shift_ms": null, // stft window hop-lengh in ms. If null, 'hop_length' is used.
|
14 |
+
"preemphasis": 0.98, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis.
|
15 |
+
"min_level_db": -100, // normalization range
|
16 |
+
"ref_level_db": 20, // reference level db, theoretically 20db is the sound of air.
|
17 |
+
"power": 1.5, // value to sharpen wav signals after GL algorithm.
|
18 |
+
"griffin_lim_iters": 60,// #griffin-lim iterations. 30-60 is a good range. Larger the value, slower the generation.
|
19 |
+
// Normalization parameters
|
20 |
+
"signal_norm": true, // normalize the spec values in range [0, 1]
|
21 |
+
"symmetric_norm": true, // move normalization to range [-1, 1]
|
22 |
+
"max_norm": 4.0, // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
|
23 |
+
"clip_norm": true, // clip normalized values into the range.
|
24 |
+
"mel_fmin": 0.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
|
25 |
+
"mel_fmax": 8000.0, // maximum freq level for mel-spec. Tune for dataset!!
|
26 |
+
"do_trim_silence": true, // enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true)
|
27 |
+
"trim_db": 60 // threshold for timming silence. Set this according to your dataset.
|
28 |
+
},
|
29 |
+
"reinit_layers": [],
|
30 |
+
"loss": "angleproto", // "ge2e" to use Generalized End-to-End loss and "angleproto" to use Angular Prototypical loss (new SOTA)
|
31 |
+
"grad_clip": 3.0, // upper limit for gradients for clipping.
|
32 |
+
"epochs": 1000, // total number of epochs to train.
|
33 |
+
"lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate.
|
34 |
+
"lr_decay": false, // if true, Noam learning rate decaying is applied through training.
|
35 |
+
"warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr"
|
36 |
+
"tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
|
37 |
+
"steps_plot_stats": 10, // number of steps to plot embeddings.
|
38 |
+
"num_speakers_in_batch": 64, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
|
39 |
+
"num_utters_per_speaker": 10, //
|
40 |
+
"num_loader_workers": 8, // number of training data loader processes. Don't set it too big. 4-8 are good values.
|
41 |
+
"wd": 0.000001, // Weight decay weight.
|
42 |
+
"checkpoint": true, // If true, it saves checkpoints per "save_step"
|
43 |
+
"save_step": 1000, // Number of training steps expected to save traning stats and checkpoints.
|
44 |
+
"print_step": 20, // Number of steps to log traning on console.
|
45 |
+
"output_path": "../../MozillaTTSOutput/checkpoints/voxceleb_librispeech/speaker_encoder/", // DATASET-RELATED: output path for all training outputs.
|
46 |
+
"model": {
|
47 |
+
"input_dim": 40,
|
48 |
+
"proj_dim": 256,
|
49 |
+
"lstm_dim": 768,
|
50 |
+
"num_lstm_layers": 3,
|
51 |
+
"use_lstm_with_projection": true
|
52 |
+
},
|
53 |
+
"storage": {
|
54 |
+
"sample_from_storage_p": 0.66, // the probability with which we'll sample from the DataSet in-memory storage
|
55 |
+
"storage_size": 15, // the size of the in-memory storage with respect to a single batch
|
56 |
+
"additive_noise": 1e-5 // add very small gaussian noise to the data in order to increase robustness
|
57 |
+
},
|
58 |
+
"datasets":
|
59 |
+
[
|
60 |
+
{
|
61 |
+
"name": "vctk_slim",
|
62 |
+
"path": "../../../audio-datasets/en/VCTK-Corpus/",
|
63 |
+
"meta_file_train": null,
|
64 |
+
"meta_file_val": null
|
65 |
+
},
|
66 |
+
{
|
67 |
+
"name": "libri_tts",
|
68 |
+
"path": "../../../audio-datasets/en/LibriTTS/train-clean-100",
|
69 |
+
"meta_file_train": null,
|
70 |
+
"meta_file_val": null
|
71 |
+
},
|
72 |
+
{
|
73 |
+
"name": "libri_tts",
|
74 |
+
"path": "../../../audio-datasets/en/LibriTTS/train-clean-360",
|
75 |
+
"meta_file_train": null,
|
76 |
+
"meta_file_val": null
|
77 |
+
},
|
78 |
+
{
|
79 |
+
"name": "libri_tts",
|
80 |
+
"path": "../../../audio-datasets/en/LibriTTS/train-other-500",
|
81 |
+
"meta_file_train": null,
|
82 |
+
"meta_file_val": null
|
83 |
+
},
|
84 |
+
{
|
85 |
+
"name": "voxceleb1",
|
86 |
+
"path": "../../../audio-datasets/en/voxceleb1/",
|
87 |
+
"meta_file_train": null,
|
88 |
+
"meta_file_val": null
|
89 |
+
},
|
90 |
+
{
|
91 |
+
"name": "voxceleb2",
|
92 |
+
"path": "../../../audio-datasets/en/voxceleb2/",
|
93 |
+
"meta_file_train": null,
|
94 |
+
"meta_file_val": null
|
95 |
+
},
|
96 |
+
{
|
97 |
+
"name": "common_voice",
|
98 |
+
"path": "../../../audio-datasets/en/MozillaCommonVoice",
|
99 |
+
"meta_file_train": "train.tsv",
|
100 |
+
"meta_file_val": "test.tsv"
|
101 |
+
}
|
102 |
+
]
|
103 |
+
}
|
TTS/speaker_encoder/dataset.py
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy
|
2 |
+
import numpy as np
|
3 |
+
import queue
|
4 |
+
import torch
|
5 |
+
import random
|
6 |
+
from torch.utils.data import Dataset
|
7 |
+
from tqdm import tqdm
|
8 |
+
|
9 |
+
|
10 |
+
class MyDataset(Dataset):
|
11 |
+
def __init__(self, ap, meta_data, voice_len=1.6, num_speakers_in_batch=64,
|
12 |
+
storage_size=1, sample_from_storage_p=0.5, additive_noise=0,
|
13 |
+
num_utter_per_speaker=10, skip_speakers=False, verbose=False):
|
14 |
+
"""
|
15 |
+
Args:
|
16 |
+
ap (TTS.tts.utils.AudioProcessor): audio processor object.
|
17 |
+
meta_data (list): list of dataset instances.
|
18 |
+
seq_len (int): voice segment length in seconds.
|
19 |
+
verbose (bool): print diagnostic information.
|
20 |
+
"""
|
21 |
+
self.items = meta_data
|
22 |
+
self.sample_rate = ap.sample_rate
|
23 |
+
self.voice_len = voice_len
|
24 |
+
self.seq_len = int(voice_len * self.sample_rate)
|
25 |
+
self.num_speakers_in_batch = num_speakers_in_batch
|
26 |
+
self.num_utter_per_speaker = num_utter_per_speaker
|
27 |
+
self.skip_speakers = skip_speakers
|
28 |
+
self.ap = ap
|
29 |
+
self.verbose = verbose
|
30 |
+
self.__parse_items()
|
31 |
+
self.storage = queue.Queue(maxsize=storage_size*num_speakers_in_batch)
|
32 |
+
self.sample_from_storage_p = float(sample_from_storage_p)
|
33 |
+
self.additive_noise = float(additive_noise)
|
34 |
+
if self.verbose:
|
35 |
+
print("\n > DataLoader initialization")
|
36 |
+
print(f" | > Speakers per Batch: {num_speakers_in_batch}")
|
37 |
+
print(f" | > Storage Size: {self.storage.maxsize} speakers, each with {num_utter_per_speaker} utters")
|
38 |
+
print(f" | > Sample_from_storage_p : {self.sample_from_storage_p}")
|
39 |
+
print(f" | > Noise added : {self.additive_noise}")
|
40 |
+
print(f" | > Number of instances : {len(self.items)}")
|
41 |
+
print(f" | > Sequence length: {self.seq_len}")
|
42 |
+
print(f" | > Num speakers: {len(self.speakers)}")
|
43 |
+
|
44 |
+
def load_wav(self, filename):
|
45 |
+
audio = self.ap.load_wav(filename, sr=self.ap.sample_rate)
|
46 |
+
return audio
|
47 |
+
|
48 |
+
def load_data(self, idx):
|
49 |
+
text, wav_file, speaker_name = self.items[idx]
|
50 |
+
wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
|
51 |
+
mel = self.ap.melspectrogram(wav).astype("float32")
|
52 |
+
# sample seq_len
|
53 |
+
|
54 |
+
assert text.size > 0, self.items[idx][1]
|
55 |
+
assert wav.size > 0, self.items[idx][1]
|
56 |
+
|
57 |
+
sample = {
|
58 |
+
"mel": mel,
|
59 |
+
"item_idx": self.items[idx][1],
|
60 |
+
"speaker_name": speaker_name,
|
61 |
+
}
|
62 |
+
return sample
|
63 |
+
|
64 |
+
def __parse_items(self):
|
65 |
+
self.speaker_to_utters = {}
|
66 |
+
for i in self.items:
|
67 |
+
path_ = i[1]
|
68 |
+
speaker_ = i[2]
|
69 |
+
if speaker_ in self.speaker_to_utters.keys():
|
70 |
+
self.speaker_to_utters[speaker_].append(path_)
|
71 |
+
else:
|
72 |
+
self.speaker_to_utters[speaker_] = [path_, ]
|
73 |
+
|
74 |
+
if self.skip_speakers:
|
75 |
+
self.speaker_to_utters = {k: v for (k, v) in self.speaker_to_utters.items() if
|
76 |
+
len(v) >= self.num_utter_per_speaker}
|
77 |
+
|
78 |
+
self.speakers = [k for (k, v) in self.speaker_to_utters.items()]
|
79 |
+
|
80 |
+
# def __parse_items(self):
|
81 |
+
# """
|
82 |
+
# Find unique speaker ids and create a dict mapping utterances from speaker id
|
83 |
+
# """
|
84 |
+
# speakers = list({item[-1] for item in self.items})
|
85 |
+
# self.speaker_to_utters = {}
|
86 |
+
# self.speakers = []
|
87 |
+
# for speaker in speakers:
|
88 |
+
# speaker_utters = [item[1] for item in self.items if item[2] == speaker]
|
89 |
+
# if len(speaker_utters) < self.num_utter_per_speaker and self.skip_speakers:
|
90 |
+
# print(
|
91 |
+
# f" [!] Skipped speaker {speaker}. Not enough utterances {self.num_utter_per_speaker} vs {len(speaker_utters)}."
|
92 |
+
# )
|
93 |
+
# else:
|
94 |
+
# self.speakers.append(speaker)
|
95 |
+
# self.speaker_to_utters[speaker] = speaker_utters
|
96 |
+
|
97 |
+
def __len__(self):
|
98 |
+
return int(1e10)
|
99 |
+
|
100 |
+
def __sample_speaker(self):
|
101 |
+
speaker = random.sample(self.speakers, 1)[0]
|
102 |
+
if self.num_utter_per_speaker > len(self.speaker_to_utters[speaker]):
|
103 |
+
utters = random.choices(
|
104 |
+
self.speaker_to_utters[speaker], k=self.num_utter_per_speaker
|
105 |
+
)
|
106 |
+
else:
|
107 |
+
utters = random.sample(
|
108 |
+
self.speaker_to_utters[speaker], self.num_utter_per_speaker
|
109 |
+
)
|
110 |
+
return speaker, utters
|
111 |
+
|
112 |
+
def __sample_speaker_utterances(self, speaker):
|
113 |
+
"""
|
114 |
+
Sample all M utterances for the given speaker.
|
115 |
+
"""
|
116 |
+
wavs = []
|
117 |
+
labels = []
|
118 |
+
for _ in range(self.num_utter_per_speaker):
|
119 |
+
# TODO:dummy but works
|
120 |
+
while True:
|
121 |
+
if len(self.speaker_to_utters[speaker]) > 0:
|
122 |
+
utter = random.sample(self.speaker_to_utters[speaker], 1)[0]
|
123 |
+
else:
|
124 |
+
self.speakers.remove(speaker)
|
125 |
+
speaker, _ = self.__sample_speaker()
|
126 |
+
continue
|
127 |
+
wav = self.load_wav(utter)
|
128 |
+
if wav.shape[0] - self.seq_len > 0:
|
129 |
+
break
|
130 |
+
self.speaker_to_utters[speaker].remove(utter)
|
131 |
+
|
132 |
+
wavs.append(wav)
|
133 |
+
labels.append(speaker)
|
134 |
+
return wavs, labels
|
135 |
+
|
136 |
+
def __getitem__(self, idx):
|
137 |
+
speaker, _ = self.__sample_speaker()
|
138 |
+
return speaker
|
139 |
+
|
140 |
+
def collate_fn(self, batch):
|
141 |
+
labels = []
|
142 |
+
feats = []
|
143 |
+
for speaker in batch:
|
144 |
+
if random.random() < self.sample_from_storage_p and self.storage.full():
|
145 |
+
# sample from storage (if full), ignoring the speaker
|
146 |
+
wavs_, labels_ = random.choice(self.storage.queue)
|
147 |
+
else:
|
148 |
+
# don't sample from storage, but from HDD
|
149 |
+
wavs_, labels_ = self.__sample_speaker_utterances(speaker)
|
150 |
+
# if storage is full, remove an item
|
151 |
+
if self.storage.full():
|
152 |
+
_ = self.storage.get_nowait()
|
153 |
+
# put the newly loaded item into storage
|
154 |
+
self.storage.put_nowait((wavs_, labels_))
|
155 |
+
|
156 |
+
# add random gaussian noise
|
157 |
+
if self.additive_noise > 0:
|
158 |
+
noises_ = [numpy.random.normal(0, self.additive_noise, size=len(w)) for w in wavs_]
|
159 |
+
wavs_ = [wavs_[i] + noises_[i] for i in range(len(wavs_))]
|
160 |
+
|
161 |
+
# get a random subset of each of the wavs and convert to MFCC.
|
162 |
+
offsets_ = [random.randint(0, wav.shape[0] - self.seq_len) for wav in wavs_]
|
163 |
+
mels_ = [self.ap.melspectrogram(wavs_[i][offsets_[i]: offsets_[i] + self.seq_len]) for i in range(len(wavs_))]
|
164 |
+
feats_ = [torch.FloatTensor(mel) for mel in mels_]
|
165 |
+
|
166 |
+
labels.append(labels_)
|
167 |
+
feats.extend(feats_)
|
168 |
+
feats = torch.stack(feats)
|
169 |
+
return feats.transpose(1, 2), labels
|
TTS/speaker_encoder/losses.py
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
# adapted from https://github.com/cvqluu/GE2E-Loss
|
7 |
+
class GE2ELoss(nn.Module):
|
8 |
+
def __init__(self, init_w=10.0, init_b=-5.0, loss_method="softmax"):
|
9 |
+
"""
|
10 |
+
Implementation of the Generalized End-to-End loss defined in https://arxiv.org/abs/1710.10467 [1]
|
11 |
+
Accepts an input of size (N, M, D)
|
12 |
+
where N is the number of speakers in the batch,
|
13 |
+
M is the number of utterances per speaker,
|
14 |
+
and D is the dimensionality of the embedding vector (e.g. d-vector)
|
15 |
+
Args:
|
16 |
+
- init_w (float): defines the initial value of w in Equation (5) of [1]
|
17 |
+
- init_b (float): definies the initial value of b in Equation (5) of [1]
|
18 |
+
"""
|
19 |
+
super(GE2ELoss, self).__init__()
|
20 |
+
# pylint: disable=E1102
|
21 |
+
self.w = nn.Parameter(torch.tensor(init_w))
|
22 |
+
# pylint: disable=E1102
|
23 |
+
self.b = nn.Parameter(torch.tensor(init_b))
|
24 |
+
self.loss_method = loss_method
|
25 |
+
|
26 |
+
print(' > Initialised Generalized End-to-End loss')
|
27 |
+
|
28 |
+
assert self.loss_method in ["softmax", "contrast"]
|
29 |
+
|
30 |
+
if self.loss_method == "softmax":
|
31 |
+
self.embed_loss = self.embed_loss_softmax
|
32 |
+
if self.loss_method == "contrast":
|
33 |
+
self.embed_loss = self.embed_loss_contrast
|
34 |
+
|
35 |
+
# pylint: disable=R0201
|
36 |
+
def calc_new_centroids(self, dvecs, centroids, spkr, utt):
|
37 |
+
"""
|
38 |
+
Calculates the new centroids excluding the reference utterance
|
39 |
+
"""
|
40 |
+
excl = torch.cat((dvecs[spkr, :utt], dvecs[spkr, utt + 1 :]))
|
41 |
+
excl = torch.mean(excl, 0)
|
42 |
+
new_centroids = []
|
43 |
+
for i, centroid in enumerate(centroids):
|
44 |
+
if i == spkr:
|
45 |
+
new_centroids.append(excl)
|
46 |
+
else:
|
47 |
+
new_centroids.append(centroid)
|
48 |
+
return torch.stack(new_centroids)
|
49 |
+
|
50 |
+
def calc_cosine_sim(self, dvecs, centroids):
|
51 |
+
"""
|
52 |
+
Make the cosine similarity matrix with dims (N,M,N)
|
53 |
+
"""
|
54 |
+
cos_sim_matrix = []
|
55 |
+
for spkr_idx, speaker in enumerate(dvecs):
|
56 |
+
cs_row = []
|
57 |
+
for utt_idx, utterance in enumerate(speaker):
|
58 |
+
new_centroids = self.calc_new_centroids(
|
59 |
+
dvecs, centroids, spkr_idx, utt_idx
|
60 |
+
)
|
61 |
+
# vector based cosine similarity for speed
|
62 |
+
cs_row.append(
|
63 |
+
torch.clamp(
|
64 |
+
torch.mm(
|
65 |
+
utterance.unsqueeze(1).transpose(0, 1),
|
66 |
+
new_centroids.transpose(0, 1),
|
67 |
+
)
|
68 |
+
/ (torch.norm(utterance) * torch.norm(new_centroids, dim=1)),
|
69 |
+
1e-6,
|
70 |
+
)
|
71 |
+
)
|
72 |
+
cs_row = torch.cat(cs_row, dim=0)
|
73 |
+
cos_sim_matrix.append(cs_row)
|
74 |
+
return torch.stack(cos_sim_matrix)
|
75 |
+
|
76 |
+
# pylint: disable=R0201
|
77 |
+
def embed_loss_softmax(self, dvecs, cos_sim_matrix):
|
78 |
+
"""
|
79 |
+
Calculates the loss on each embedding $L(e_{ji})$ by taking softmax
|
80 |
+
"""
|
81 |
+
N, M, _ = dvecs.shape
|
82 |
+
L = []
|
83 |
+
for j in range(N):
|
84 |
+
L_row = []
|
85 |
+
for i in range(M):
|
86 |
+
L_row.append(-F.log_softmax(cos_sim_matrix[j, i], 0)[j])
|
87 |
+
L_row = torch.stack(L_row)
|
88 |
+
L.append(L_row)
|
89 |
+
return torch.stack(L)
|
90 |
+
|
91 |
+
# pylint: disable=R0201
|
92 |
+
def embed_loss_contrast(self, dvecs, cos_sim_matrix):
|
93 |
+
"""
|
94 |
+
Calculates the loss on each embedding $L(e_{ji})$ by contrast loss with closest centroid
|
95 |
+
"""
|
96 |
+
N, M, _ = dvecs.shape
|
97 |
+
L = []
|
98 |
+
for j in range(N):
|
99 |
+
L_row = []
|
100 |
+
for i in range(M):
|
101 |
+
centroids_sigmoids = torch.sigmoid(cos_sim_matrix[j, i])
|
102 |
+
excl_centroids_sigmoids = torch.cat(
|
103 |
+
(centroids_sigmoids[:j], centroids_sigmoids[j + 1 :])
|
104 |
+
)
|
105 |
+
L_row.append(
|
106 |
+
1.0
|
107 |
+
- torch.sigmoid(cos_sim_matrix[j, i, j])
|
108 |
+
+ torch.max(excl_centroids_sigmoids)
|
109 |
+
)
|
110 |
+
L_row = torch.stack(L_row)
|
111 |
+
L.append(L_row)
|
112 |
+
return torch.stack(L)
|
113 |
+
|
114 |
+
def forward(self, dvecs):
|
115 |
+
"""
|
116 |
+
Calculates the GE2E loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats)
|
117 |
+
"""
|
118 |
+
centroids = torch.mean(dvecs, 1)
|
119 |
+
cos_sim_matrix = self.calc_cosine_sim(dvecs, centroids)
|
120 |
+
torch.clamp(self.w, 1e-6)
|
121 |
+
cos_sim_matrix = self.w * cos_sim_matrix + self.b
|
122 |
+
L = self.embed_loss(dvecs, cos_sim_matrix)
|
123 |
+
return L.mean()
|
124 |
+
|
125 |
+
# adapted from https://github.com/clovaai/voxceleb_trainer/blob/master/loss/angleproto.py
|
126 |
+
class AngleProtoLoss(nn.Module):
|
127 |
+
"""
|
128 |
+
Implementation of the Angular Prototypical loss defined in https://arxiv.org/abs/2003.11982
|
129 |
+
Accepts an input of size (N, M, D)
|
130 |
+
where N is the number of speakers in the batch,
|
131 |
+
M is the number of utterances per speaker,
|
132 |
+
and D is the dimensionality of the embedding vector
|
133 |
+
Args:
|
134 |
+
- init_w (float): defines the initial value of w
|
135 |
+
- init_b (float): definies the initial value of b
|
136 |
+
"""
|
137 |
+
def __init__(self, init_w=10.0, init_b=-5.0):
|
138 |
+
super(AngleProtoLoss, self).__init__()
|
139 |
+
# pylint: disable=E1102
|
140 |
+
self.w = nn.Parameter(torch.tensor(init_w))
|
141 |
+
# pylint: disable=E1102
|
142 |
+
self.b = nn.Parameter(torch.tensor(init_b))
|
143 |
+
self.criterion = torch.nn.CrossEntropyLoss()
|
144 |
+
|
145 |
+
print(' > Initialised Angular Prototypical loss')
|
146 |
+
|
147 |
+
def forward(self, x):
|
148 |
+
"""
|
149 |
+
Calculates the AngleProto loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats)
|
150 |
+
"""
|
151 |
+
out_anchor = torch.mean(x[:, 1:, :], 1)
|
152 |
+
out_positive = x[:, 0, :]
|
153 |
+
num_speakers = out_anchor.size()[0]
|
154 |
+
|
155 |
+
cos_sim_matrix = F.cosine_similarity(out_positive.unsqueeze(-1).expand(-1, -1, num_speakers), out_anchor.unsqueeze(-1).expand(-1, -1, num_speakers).transpose(0, 2))
|
156 |
+
torch.clamp(self.w, 1e-6)
|
157 |
+
cos_sim_matrix = cos_sim_matrix * self.w + self.b
|
158 |
+
label = torch.from_numpy(np.asarray(range(0, num_speakers))).to(cos_sim_matrix.device)
|
159 |
+
L = self.criterion(cos_sim_matrix, label)
|
160 |
+
return L
|
TTS/speaker_encoder/model.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
|
4 |
+
|
5 |
+
class LSTMWithProjection(nn.Module):
|
6 |
+
def __init__(self, input_size, hidden_size, proj_size):
|
7 |
+
super().__init__()
|
8 |
+
self.input_size = input_size
|
9 |
+
self.hidden_size = hidden_size
|
10 |
+
self.proj_size = proj_size
|
11 |
+
self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
|
12 |
+
self.linear = nn.Linear(hidden_size, proj_size, bias=False)
|
13 |
+
|
14 |
+
def forward(self, x):
|
15 |
+
self.lstm.flatten_parameters()
|
16 |
+
o, (_, _) = self.lstm(x)
|
17 |
+
return self.linear(o)
|
18 |
+
|
19 |
+
class LSTMWithoutProjection(nn.Module):
|
20 |
+
def __init__(self, input_dim, lstm_dim, proj_dim, num_lstm_layers):
|
21 |
+
super().__init__()
|
22 |
+
self.lstm = nn.LSTM(input_size=input_dim,
|
23 |
+
hidden_size=lstm_dim,
|
24 |
+
num_layers=num_lstm_layers,
|
25 |
+
batch_first=True)
|
26 |
+
self.linear = nn.Linear(lstm_dim, proj_dim, bias=True)
|
27 |
+
self.relu = nn.ReLU()
|
28 |
+
def forward(self, x):
|
29 |
+
_, (hidden, _) = self.lstm(x)
|
30 |
+
return self.relu(self.linear(hidden[-1]))
|
31 |
+
|
32 |
+
class SpeakerEncoder(nn.Module):
|
33 |
+
def __init__(self, input_dim, proj_dim=256, lstm_dim=768, num_lstm_layers=3, use_lstm_with_projection=True):
|
34 |
+
super().__init__()
|
35 |
+
self.use_lstm_with_projection = use_lstm_with_projection
|
36 |
+
layers = []
|
37 |
+
# choise LSTM layer
|
38 |
+
if use_lstm_with_projection:
|
39 |
+
layers.append(LSTMWithProjection(input_dim, lstm_dim, proj_dim))
|
40 |
+
for _ in range(num_lstm_layers - 1):
|
41 |
+
layers.append(LSTMWithProjection(proj_dim, lstm_dim, proj_dim))
|
42 |
+
self.layers = nn.Sequential(*layers)
|
43 |
+
else:
|
44 |
+
self.layers = LSTMWithoutProjection(input_dim, lstm_dim, proj_dim, num_lstm_layers)
|
45 |
+
|
46 |
+
self._init_layers()
|
47 |
+
|
48 |
+
def _init_layers(self):
|
49 |
+
for name, param in self.layers.named_parameters():
|
50 |
+
if "bias" in name:
|
51 |
+
nn.init.constant_(param, 0.0)
|
52 |
+
elif "weight" in name:
|
53 |
+
nn.init.xavier_normal_(param)
|
54 |
+
|
55 |
+
def forward(self, x):
|
56 |
+
# TODO: implement state passing for lstms
|
57 |
+
d = self.layers(x)
|
58 |
+
if self.use_lstm_with_projection:
|
59 |
+
d = torch.nn.functional.normalize(d[:, -1], p=2, dim=1)
|
60 |
+
else:
|
61 |
+
d = torch.nn.functional.normalize(d, p=2, dim=1)
|
62 |
+
return d
|
63 |
+
|
64 |
+
@torch.no_grad()
|
65 |
+
def inference(self, x):
|
66 |
+
d = self.layers.forward(x)
|
67 |
+
if self.use_lstm_with_projection:
|
68 |
+
d = torch.nn.functional.normalize(d[:, -1], p=2, dim=1)
|
69 |
+
else:
|
70 |
+
d = torch.nn.functional.normalize(d, p=2, dim=1)
|
71 |
+
return d
|
72 |
+
|
73 |
+
def compute_embedding(self, x, num_frames=160, overlap=0.5):
|
74 |
+
"""
|
75 |
+
Generate embeddings for a batch of utterances
|
76 |
+
x: 1xTxD
|
77 |
+
"""
|
78 |
+
num_overlap = int(num_frames * overlap)
|
79 |
+
max_len = x.shape[1]
|
80 |
+
embed = None
|
81 |
+
cur_iter = 0
|
82 |
+
for offset in range(0, max_len, num_frames - num_overlap):
|
83 |
+
cur_iter += 1
|
84 |
+
end_offset = min(x.shape[1], offset + num_frames)
|
85 |
+
frames = x[:, offset:end_offset]
|
86 |
+
if embed is None:
|
87 |
+
embed = self.inference(frames)
|
88 |
+
else:
|
89 |
+
embed += self.inference(frames)
|
90 |
+
return embed / cur_iter
|
91 |
+
|
92 |
+
def batch_compute_embedding(self, x, seq_lens, num_frames=160, overlap=0.5):
|
93 |
+
"""
|
94 |
+
Generate embeddings for a batch of utterances
|
95 |
+
x: BxTxD
|
96 |
+
"""
|
97 |
+
num_overlap = num_frames * overlap
|
98 |
+
max_len = x.shape[1]
|
99 |
+
embed = None
|
100 |
+
num_iters = seq_lens / (num_frames - num_overlap)
|
101 |
+
cur_iter = 0
|
102 |
+
for offset in range(0, max_len, num_frames - num_overlap):
|
103 |
+
cur_iter += 1
|
104 |
+
end_offset = min(x.shape[1], offset + num_frames)
|
105 |
+
frames = x[:, offset:end_offset]
|
106 |
+
if embed is None:
|
107 |
+
embed = self.inference(frames)
|
108 |
+
else:
|
109 |
+
embed[cur_iter <= num_iters, :] += self.inference(
|
110 |
+
frames[cur_iter <= num_iters, :, :]
|
111 |
+
)
|
112 |
+
return embed / num_iters
|
TTS/speaker_encoder/requirements.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
umap-learn
|
2 |
+
numpy>=1.17.0
|
TTS/speaker_encoder/umap.png
ADDED