KasugaiSakura committed on
Commit
58fbdee
·
verified ·
1 Parent(s): 1e3e261

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +9 -35
  2. .github/ISSUE_TEMPLATE/ask_for_help.yaml +124 -0
  3. .github/ISSUE_TEMPLATE/ask_for_help_en_US.yaml +124 -0
  4. .github/ISSUE_TEMPLATE/bug_report.yaml +92 -0
  5. .github/ISSUE_TEMPLATE/bug_report_en_US.yaml +92 -0
  6. .github/ISSUE_TEMPLATE/config.yml +5 -0
  7. .github/ISSUE_TEMPLATE/default.md +7 -0
  8. .github/workflows/reviewdog.yml +18 -0
  9. .github/workflows/ruff.yml +8 -0
  10. .gitignore +165 -0
  11. .ruff.toml +4 -0
  12. LICENSE +661 -0
  13. README.md +579 -5
  14. README_zh_CN.md +575 -0
  15. app.py +312 -0
  16. cluster/__init__.py +29 -0
  17. cluster/kmeans.py +204 -0
  18. cluster/train_cluster.py +85 -0
  19. compress_model.py +72 -0
  20. configs/diffusion.yaml +0 -0
  21. configs_template/config_template.json +79 -0
  22. configs_template/config_tiny_template.json +79 -0
  23. configs_template/diffusion_template.yaml +51 -0
  24. data_utils.py +185 -0
  25. diffusion/__init__.py +0 -0
  26. diffusion/data_loaders.py +288 -0
  27. diffusion/diffusion.py +396 -0
  28. diffusion/diffusion_onnx.py +614 -0
  29. diffusion/dpm_solver_pytorch.py +1307 -0
  30. diffusion/how to export onnx.md +4 -0
  31. diffusion/infer_gt_mel.py +74 -0
  32. diffusion/logger/__init__.py +0 -0
  33. diffusion/logger/saver.py +145 -0
  34. diffusion/logger/utils.py +127 -0
  35. diffusion/onnx_export.py +235 -0
  36. diffusion/solver.py +200 -0
  37. diffusion/uni_pc.py +733 -0
  38. diffusion/unit2mel.py +167 -0
  39. diffusion/vocoder.py +95 -0
  40. diffusion/wavenet.py +108 -0
  41. edgetts/tts.py +48 -0
  42. edgetts/tts_voices.py +306 -0
  43. export_index_for_onnx.py +20 -0
  44. flask_api.py +60 -0
  45. flask_api_full_song.py +55 -0
  46. inference/__init__.py +0 -0
  47. inference/infer_tool.py +546 -0
  48. inference/infer_tool_grad.py +156 -0
  49. inference/slicer.py +142 -0
  50. inference_main.py +155 -0
.gitattributes CHANGED
@@ -1,35 +1,9 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ * text=auto eol=lf
2
+ pretrain/checkpoint_best_legacy_500.pt filter=lfs diff=lfs merge=lfs -text
3
+ pretrain/fcpe.pt filter=lfs diff=lfs merge=lfs -text
4
+ pretrain/nsf_hifigan/model filter=lfs diff=lfs merge=lfs -text
5
+ pretrain/rmvpe.pt filter=lfs diff=lfs merge=lfs -text
6
+ trained/G_58400.pth filter=lfs diff=lfs merge=lfs -text
7
+ trained/feature_and_index.pkl filter=lfs diff=lfs merge=lfs -text
8
+ trained/kmeans_10000.pt filter=lfs diff=lfs merge=lfs -text
9
+ trained/model_14000.pt filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.github/ISSUE_TEMPLATE/ask_for_help.yaml ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: 请求帮助
2
+ description: 遇到了无法自行解决的错误
3
+ title: '[Help]: '
4
+ labels: [ "help wanted" ]
5
+
6
+ body:
7
+ - type: markdown
8
+ attributes:
9
+ value: |
10
+ #### 提问前请先自己去尝试解决,比如查看[本仓库wiki](https://github.com/svc-develop-team/so-vits-svc/wiki),也可以借助chatgpt或一些搜索引擎(谷歌/必应/New Bing/StackOverflow等等)。如果实在无法自己解决再发issue,在提issue之前,请先了解《[提问的智慧](https://github.com/ryanhanwu/How-To-Ask-Questions-The-Smart-Way/blob/main/README-zh_CN.md)》。
11
+ ---
12
+ ### 什么样的issue会被直接close
13
+ 1. 伸手党
14
+ 2. 一键包/环境包相关
15
+ 3. 提供的信息不全
16
+ 4. 低级的如缺少依赖而导致无法运行的问题
17
+ 5. 所用的数据集是无授权数据集(游戏角色/二次元人物暂不归为此类,但是训练时候也要小心谨慎。如果能联系到官方,必须先和官方联系并核实清楚)
18
+ ---
19
+
20
+ - type: checkboxes
21
+ id: Clause
22
+ attributes:
23
+ label: 请勾选下方的确认框。
24
+ options:
25
+ - label: "我已仔细阅读[README.md](https://github.com/svc-develop-team/so-vits-svc/blob/4.0/README_zh_CN.md)和[wiki中的Quick solution](https://github.com/svc-develop-team/so-vits-svc/wiki/Quick-solution)。"
26
+ required: true
27
+ - label: "我已通过各种搜索引擎排查问题,我要提出的问题并不常见。"
28
+ required: true
29
+ - label: "我未在使用由第三方用户提供的一键包/环境包。"
30
+ required: true
31
+
32
+ - type: markdown
33
+ attributes:
34
+ value: |
35
+ # 请根据实际使用环境填写以下信息
36
+
37
+ - type: input
38
+ id: System
39
+ attributes:
40
+ label: 系统平台版本号
41
+ description: Windows执行`winver` | Linux执行`uname -a`
42
+ validations:
43
+ required: true
44
+
45
+ - type: input
46
+ id: GPU
47
+ attributes:
48
+ label: GPU 型号
49
+ description: 执行`nvidia-smi`
50
+ validations:
51
+ required: true
52
+
53
+ - type: input
54
+ id: PythonVersion
55
+ attributes:
56
+ label: Python版本
57
+ description: 执行`python -V`
58
+ validations:
59
+ required: true
60
+
61
+ - type: input
62
+ id: PyTorchVersion
63
+ attributes:
64
+ label: PyTorch版本
65
+ description: 执行`pip show torch`
66
+ validations:
67
+ required: true
68
+
69
+ - type: dropdown
70
+ id: Branch
71
+ attributes:
72
+ label: sovits分支
73
+ options:
74
+ - 4.0(默认)
75
+ - 4.0-v2
76
+ - 3.0-32k
77
+ - 3.0-48k
78
+ validations:
79
+ required: true
80
+
81
+ - type: input
82
+ id: DatasetSource
83
+ attributes:
84
+ label: 数据集来源(用于判断数据集质量)
85
+ description: 如:UVR处理过的vtb直播音频、录音棚录制
86
+ validations:
87
+ required: true
88
+
89
+ - type: input
90
+ id: WhereOccurs
91
+ attributes:
92
+ label: 出现问题的环节或执行的命令
93
+ description: 如:预处理、训练、`python preprocess_hubert_f0.py`
94
+ validations:
95
+ required: true
96
+
97
+ - type: textarea
98
+ id: Description
99
+ attributes:
100
+ label: 问题描述
101
+ description: 在这里描述自己的问题,越详细越好
102
+ validations:
103
+ required: true
104
+
105
+ - type: textarea
106
+ id: Log
107
+ attributes:
108
+ label: 日志
109
+ description: 将从执行命令到执行完毕输出的所有信息(包括你所执行的命令)粘贴到[pastebin.com](https://pastebin.com/)并把剪贴板链接贴到这里,日志量少的话也可以直接贴在下面
110
+ render: python
111
+ validations:
112
+ required: true
113
+
114
+ - type: textarea
115
+ id: ValidOneClick
116
+ attributes:
117
+ label: 截图`so-vits-svc`、`logs/44k`文件夹并粘贴到此处
118
+ validations:
119
+ required: true
120
+
121
+ - type: textarea
122
+ id: Supplementary
123
+ attributes:
124
+ label: 补充说明
.github/ISSUE_TEMPLATE/ask_for_help_en_US.yaml ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Ask for help
2
+ description: Encountered an error that you cannot resolve by yourself
3
+ title: '[Help]: '
4
+ labels: [ "help wanted" ]
5
+
6
+ body:
7
+ - type: markdown
8
+ attributes:
9
+ value: |
10
+ #### Please try to solve the problem yourself before asking for help. First, read the *[repo wiki](https://github.com/svc-develop-team/so-vits-svc/wiki)*. Then try ChatGPT or a search engine (Google, Bing, New Bing, StackOverflow, etc.). Only raise an issue if you really cannot solve the problem by yourself, and before you do, please read *[How To Ask Questions The Smart Way](http://www.catb.org/~esr/faqs/smart-questions.html)*.
11
+ ---
12
+ ### What kind of issue will be closed immediately
13
+ 1. Beggars or Free Riders
14
+ 2. One click package / Environment package (Not using `pip install -r requirements.txt`)
15
+ 3. Incomplete information
16
+ 4. Trivial issues such as a missing dependency package
17
+ 5. Using an unlicensed dataset (Game characters / anime characters are not included in this category for now, but you still need to be careful. If you can contact the rights holder, you must contact them and verify the licensing first.)
18
+ ---
19
+
20
+ - type: checkboxes
21
+ id: Clause
22
+ attributes:
23
+ label: Please check the checkboxes below.
24
+ options:
25
+ - label: "I have read *[README.md](https://github.com/svc-develop-team/so-vits-svc/blob/4.0/README.md)* and *[Quick solution in wiki](https://github.com/svc-develop-team/so-vits-svc/wiki/Quick-solution)* carefully."
26
+ required: true
27
+ - label: "I have been troubleshooting issues through various search engines. The questions I want to ask are not common."
28
+ required: true
29
+ - label: "I am NOT using one click package / environment package."
30
+ required: true
31
+
32
+ - type: markdown
33
+ attributes:
34
+ value: |
35
+ # Please fill in the following information according to your actual environment
36
+
37
+ - type: input
38
+ id: System
39
+ attributes:
40
+ label: OS version
41
+ description: Windows run `winver` | Linux run `uname -a`
42
+ validations:
43
+ required: true
44
+
45
+ - type: input
46
+ id: GPU
47
+ attributes:
48
+ label: GPU
49
+ description: Run `nvidia-smi`
50
+ validations:
51
+ required: true
52
+
53
+ - type: input
54
+ id: PythonVersion
55
+ attributes:
56
+ label: Python version
57
+ description: Run `python -V`
58
+ validations:
59
+ required: true
60
+
61
+ - type: input
62
+ id: PyTorchVersion
63
+ attributes:
64
+ label: PyTorch version
65
+ description: Run `pip show torch`
66
+ validations:
67
+ required: true
68
+
69
+ - type: dropdown
70
+ id: Branch
71
+ attributes:
72
+ label: Branch of sovits
73
+ options:
74
+ - 4.0(Default)
75
+ - 4.0-v2
76
+ - 3.0-32k
77
+ - 3.0-48k
78
+ validations:
79
+ required: true
80
+
81
+ - type: input
82
+ id: DatasetSource
83
+ attributes:
84
+ label: Dataset source (Used to judge the dataset quality)
85
+ description: Such as UVR-processed streaming audio / Recorded in recording studio
86
+ validations:
87
+ required: true
88
+
89
+ - type: input
90
+ id: WhereOccurs
91
+ attributes:
92
+ label: Where the problem occurs or what command you executed
93
+ description: Such as Preprocessing / Training / `python preprocess_hubert_f0.py`
94
+ validations:
95
+ required: true
96
+
97
+ - type: textarea
98
+ id: Description
99
+ attributes:
100
+ label: Problem description
101
+ description: Describe your problem here, the more detailed the better.
102
+ validations:
103
+ required: true
104
+
105
+ - type: textarea
106
+ id: Log
107
+ attributes:
108
+ label: Log
109
+ description: All information output from the command you executed to the end of execution (including the command). You can paste it to [pastebin.com](https://pastebin.com/) and then paste the link here. It can also be posted directly below if there is only a little text.
110
+ render: python
111
+ validations:
112
+ required: true
113
+
114
+ - type: textarea
115
+ id: ValidOneClick
116
+ attributes:
117
+ label: Screenshot `so-vits-svc` and `logs/44k` folders and paste here
118
+ validations:
119
+ required: true
120
+
121
+ - type: textarea
122
+ id: Supplementary
123
+ attributes:
124
+ label: Supplementary description
.github/ISSUE_TEMPLATE/bug_report.yaml ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: 问题回报
2
+ description: 遇到了BUG?!
3
+ title: '[Bug]: '
4
+ labels: [ "bug?" ]
5
+
6
+ body:
7
+ - type: markdown
8
+ attributes:
9
+ value: |
10
+ # 请根据实际使用环境填写以下信息
11
+
12
+ - type: input
13
+ id: System
14
+ attributes:
15
+ label: 系统平台版本号
16
+ description: Windows执行`winver` | Linux执行`uname -a`
17
+ validations:
18
+ required: true
19
+
20
+ - type: input
21
+ id: GPU
22
+ attributes:
23
+ label: GPU 型号
24
+ description: 执行`nvidia-smi`
25
+ validations:
26
+ required: true
27
+
28
+ - type: input
29
+ id: PythonVersion
30
+ attributes:
31
+ label: Python版本
32
+ description: 执行`python -V`
33
+ validations:
34
+ required: true
35
+
36
+ - type: input
37
+ id: PyTorchVersion
38
+ attributes:
39
+ label: PyTorch版本
40
+ description: 执行`pip show torch`
41
+ validations:
42
+ required: true
43
+
44
+ - type: dropdown
45
+ id: Branch
46
+ attributes:
47
+ label: sovits分支
48
+ options:
49
+ - 4.0(默认)
50
+ - 4.0-v2
51
+ - 3.0-32k
52
+ - 3.0-48k
53
+ validations:
54
+ required: true
55
+
56
+ - type: input
57
+ id: DatasetSource
58
+ attributes:
59
+ label: 数据集来源(用于判断数据集质量)
60
+ description: 如:UVR处理过的vtb直播音频、录音棚录制
61
+ validations:
62
+ required: true
63
+
64
+ - type: input
65
+ id: WhereOccurs
66
+ attributes:
67
+ label: 出现问题的环节或执行的命令
68
+ description: 如:预处理、训练、`python preprocess_hubert_f0.py`
69
+ validations:
70
+ required: true
71
+
72
+ - type: textarea
73
+ id: Description
74
+ attributes:
75
+ label: 情况描述
76
+ description: 在这里描述遇到的情况,越详细越好
77
+ validations:
78
+ required: true
79
+
80
+ - type: textarea
81
+ id: Log
82
+ attributes:
83
+ label: 日志
84
+ description: 将从执行命令到执行完毕输出的所有信息(包括你所执行的命令)粘贴到[pastebin.com](https://pastebin.com/)并把剪贴板链接贴到这里,日志量少的话也可以直接贴在下面
85
+ render: python
86
+ validations:
87
+ required: true
88
+
89
+ - type: textarea
90
+ id: Supplementary
91
+ attributes:
92
+ label: 补充说明
.github/ISSUE_TEMPLATE/bug_report_en_US.yaml ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Bug report
2
+ description: Encountered a bug?!
3
+ title: '[Bug]: '
4
+ labels: [ "bug?" ]
5
+
6
+ body:
7
+ - type: markdown
8
+ attributes:
9
+ value: |
10
+ # Please fill in the following information according to your actual environment
11
+
12
+ - type: input
13
+ id: System
14
+ attributes:
15
+ label: OS version
16
+ description: Windows run `winver` | Linux run `uname -a`
17
+ validations:
18
+ required: true
19
+
20
+ - type: input
21
+ id: GPU
22
+ attributes:
23
+ label: GPU
24
+ description: Run `nvidia-smi`
25
+ validations:
26
+ required: true
27
+
28
+ - type: input
29
+ id: PythonVersion
30
+ attributes:
31
+ label: Python version
32
+ description: Run `python -V`
33
+ validations:
34
+ required: true
35
+
36
+ - type: input
37
+ id: PyTorchVersion
38
+ attributes:
39
+ label: PyTorch version
40
+ description: Run `pip show torch`
41
+ validations:
42
+ required: true
43
+
44
+ - type: dropdown
45
+ id: Branch
46
+ attributes:
47
+ label: Branch of sovits
48
+ options:
49
+ - 4.0(Default)
50
+ - 4.0-v2
51
+ - 3.0-32k
52
+ - 3.0-48k
53
+ validations:
54
+ required: true
55
+
56
+ - type: input
57
+ id: DatasetSource
58
+ attributes:
59
+ label: Dataset source (Used to judge the dataset quality)
60
+ description: Such as UVR-processed streaming audio / Recorded in recording studio
61
+ validations:
62
+ required: true
63
+
64
+ - type: input
65
+ id: WhereOccurs
66
+ attributes:
67
+ label: Where the problem occurs or what command you executed
68
+ description: Such as Preprocessing / Training / `python preprocess_hubert_f0.py`
69
+ validations:
70
+ required: true
71
+
72
+ - type: textarea
73
+ id: Description
74
+ attributes:
75
+ label: Situation description
76
+ description: Describe your situation here, the more detailed the better.
77
+ validations:
78
+ required: true
79
+
80
+ - type: textarea
81
+ id: Log
82
+ attributes:
83
+ label: Log
84
+ description: All information output from the command you executed to the end of execution (including the command). You can paste it to [pastebin.com](https://pastebin.com/) and then paste the link here. It can also be posted directly below if there is only a little text.
85
+ render: python
86
+ validations:
87
+ required: true
88
+
89
+ - type: textarea
90
+ id: Supplementary
91
+ attributes:
92
+ label: Supplementary description
.github/ISSUE_TEMPLATE/config.yml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ blank_issues_enabled: false
2
+ contact_links:
3
+ - name: 讨论区 / Discussions
4
+ url: https://github.com/svc-develop-team/so-vits-svc/discussions
5
+ about: 简单的询问/讨论请转至讨论区或发起一个低优先级的Default issue / For simple inquiries / discussions, please go to the discussions or raise a low priority Default issue
.github/ISSUE_TEMPLATE/default.md ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: Default issue
3
+ about: 如果模板中没有你想发起的issue类型,可以选择此项,但这个issue也许会获得一个较低的处理优先级 / If none of the templates matches the type of issue you want to raise, you can use this one, but it may be handled with a lower priority.
4
+ title: ''
5
+ labels: 'not urgent'
6
+ assignees: ''
7
+ ---
.github/workflows/reviewdog.yml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Ruff Autofix
2
+ on: [pull_request]
3
+ jobs:
4
+ ruff:
5
+ permissions:
6
+ checks: write
7
+ contents: read
8
+ pull-requests: write
9
+ runs-on: ubuntu-latest
10
+ steps:
11
+ - uses: actions/checkout@v3
12
+ - uses: chartboost/ruff-action@v1
13
+ with:
14
+ args: --fix -e
15
+ - uses: reviewdog/action-suggester@v1
16
+ with:
17
+ tool_name: ruff
18
+
.github/workflows/ruff.yml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ name: Ruff
2
+ on: [push, pull_request]
3
+ jobs:
4
+ ruff:
5
+ runs-on: ubuntu-latest
6
+ steps:
7
+ - uses: actions/checkout@v3
8
+ - uses: chartboost/ruff-action@v1
.gitignore ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # Created by https://www.toptal.com/developers/gitignore/api/python
3
+ # Edit at https://www.toptal.com/developers/gitignore?templates=python
4
+
5
+ ### Python ###
6
+ # Byte-compiled / optimized / DLL files
7
+ __pycache__/
8
+ *.py[cod]
9
+ *$py.class
10
+
11
+ # C extensions
12
+ *.so
13
+ checkpoints/
14
+ # Distribution / packaging
15
+ .Python
16
+ build/
17
+ develop-eggs/
18
+ dist/
19
+ downloads/
20
+ eggs/
21
+ .eggs/
22
+ lib/
23
+ lib64/
24
+ parts/
25
+ sdist/
26
+ var/
27
+ wheels/
28
+ pip-wheel-metadata/
29
+ share/python-wheels/
30
+ *.egg-info/
31
+ .installed.cfg
32
+ *.egg
33
+ MANIFEST
34
+
35
+ # PyInstaller
36
+ # Usually these files are written by a python script from a template
37
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
38
+ *.manifest
39
+ *.spec
40
+
41
+ # Installer logs
42
+ pip-log.txt
43
+ pip-delete-this-directory.txt
44
+
45
+ # Unit test / coverage reports
46
+ htmlcov/
47
+ .tox/
48
+ .nox/
49
+ .coverage
50
+ .coverage.*
51
+ .cache
52
+ nosetests.xml
53
+ coverage.xml
54
+ *.cover
55
+ *.py,cover
56
+ .hypothesis/
57
+ .pytest_cache/
58
+ pytestdebug.log
59
+
60
+ # Translations
61
+ *.mo
62
+ *.pot
63
+
64
+ # Django stuff:
65
+ *.log
66
+ local_settings.py
67
+ db.sqlite3
68
+ db.sqlite3-journal
69
+
70
+ # Flask stuff:
71
+ instance/
72
+ .webassets-cache
73
+
74
+ # Scrapy stuff:
75
+ .scrapy
76
+
77
+ # Sphinx documentation
78
+ docs/_build/
79
+ doc/_build/
80
+
81
+ # PyBuilder
82
+ target/
83
+
84
+ # Jupyter Notebook
85
+ .ipynb_checkpoints
86
+
87
+ # IPython
88
+ profile_default/
89
+ ipython_config.py
90
+
91
+ # pyenv
92
+ .python-version
93
+
94
+ # pipenv
95
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
97
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
98
+ # install all needed dependencies.
99
+ #Pipfile.lock
100
+
101
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
102
+ __pypackages__/
103
+
104
+ # Celery stuff
105
+ celerybeat-schedule
106
+ celerybeat.pid
107
+
108
+ # SageMath parsed files
109
+ *.sage.py
110
+
111
+ # Environments
112
+ .env
113
+ .venv
114
+ env/
115
+ venv/
116
+ ENV/
117
+ env.bak/
118
+ venv.bak/
119
+
120
+ # Spyder project settings
121
+ .spyderproject
122
+ .spyproject
123
+
124
+ # Rope project settings
125
+ .ropeproject
126
+
127
+ # mkdocs documentation
128
+ /site
129
+
130
+ # mypy
131
+ .mypy_cache/
132
+ .dmypy.json
133
+ dmypy.json
134
+
135
+ # Pyre type checker
136
+ .pyre/
137
+
138
+ # pytype static type analyzer
139
+ .pytype/
140
+
141
+ # End of https://www.toptal.com/developers/gitignore/api/python
142
+
143
+ /shelf/
144
+ /workspace.xml
145
+
146
+ dataset
147
+ dataset_raw
148
+ raw
149
+ results
150
+ inference/chunks_temp.json
151
+ logs
152
+ hubert/checkpoint_best_legacy_500.pt
153
+ configs/config.json
154
+ configs/config.yaml
155
+ filelists/test.txt
156
+ filelists/train.txt
157
+ filelists/val.txt
158
+ .idea/
159
+ .vscode/
160
+ .idea/modules.xml
161
+ .idea/so-vits-svc.iml
162
+ .idea/vcs.xml
163
+ .idea/inspectionProfiles/profiles_settings.xml
164
+ .idea/inspectionProfiles/Project_Default.xml
165
+ .vscode/launch.json
.ruff.toml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ select = ["E", "F", "I"]
2
+
3
+ # Never enforce `E501` (line length violations) or `E741` (ambiguous variable names).
4
+ ignore = ["E501", "E741"]
LICENSE ADDED
@@ -0,0 +1,661 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ GNU AFFERO GENERAL PUBLIC LICENSE
2
+ Version 3, 19 November 2007
3
+
4
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5
+ Everyone is permitted to copy and distribute verbatim copies
6
+ of this license document, but changing it is not allowed.
7
+
8
+ Preamble
9
+
10
+ The GNU Affero General Public License is a free, copyleft license for
11
+ software and other kinds of works, specifically designed to ensure
12
+ cooperation with the community in the case of network server software.
13
+
14
+ The licenses for most software and other practical works are designed
15
+ to take away your freedom to share and change the works. By contrast,
16
+ our General Public Licenses are intended to guarantee your freedom to
17
+ share and change all versions of a program--to make sure it remains free
18
+ software for all its users.
19
+
20
+ When we speak of free software, we are referring to freedom, not
21
+ price. Our General Public Licenses are designed to make sure that you
22
+ have the freedom to distribute copies of free software (and charge for
23
+ them if you wish), that you receive source code or can get it if you
24
+ want it, that you can change the software or use pieces of it in new
25
+ free programs, and that you know you can do these things.
26
+
27
+ Developers that use our General Public Licenses protect your rights
28
+ with two steps: (1) assert copyright on the software, and (2) offer
29
+ you this License which gives you legal permission to copy, distribute
30
+ and/or modify the software.
31
+
32
+ A secondary benefit of defending all users' freedom is that
33
+ improvements made in alternate versions of the program, if they
34
+ receive widespread use, become available for other developers to
35
+ incorporate. Many developers of free software are heartened and
36
+ encouraged by the resulting cooperation. However, in the case of
37
+ software used on network servers, this result may fail to come about.
38
+ The GNU General Public License permits making a modified version and
39
+ letting the public access it on a server without ever releasing its
40
+ source code to the public.
41
+
42
+ The GNU Affero General Public License is designed specifically to
43
+ ensure that, in such cases, the modified source code becomes available
44
+ to the community. It requires the operator of a network server to
45
+ provide the source code of the modified version running there to the
46
+ users of that server. Therefore, public use of a modified version, on
47
+ a publicly accessible server, gives the public access to the source
48
+ code of the modified version.
49
+
50
+ An older license, called the Affero General Public License and
51
+ published by Affero, was designed to accomplish similar goals. This is
52
+ a different license, not a version of the Affero GPL, but Affero has
53
+ released a new version of the Affero GPL which permits relicensing under
54
+ this license.
55
+
56
+ The precise terms and conditions for copying, distribution and
57
+ modification follow.
58
+
59
+ TERMS AND CONDITIONS
60
+
61
+ 0. Definitions.
62
+
63
+ "This License" refers to version 3 of the GNU Affero General Public License.
64
+
65
+ "Copyright" also means copyright-like laws that apply to other kinds of
66
+ works, such as semiconductor masks.
67
+
68
+ "The Program" refers to any copyrightable work licensed under this
69
+ License. Each licensee is addressed as "you". "Licensees" and
70
+ "recipients" may be individuals or organizations.
71
+
72
+ To "modify" a work means to copy from or adapt all or part of the work
73
+ in a fashion requiring copyright permission, other than the making of an
74
+ exact copy. The resulting work is called a "modified version" of the
75
+ earlier work or a work "based on" the earlier work.
76
+
77
+ A "covered work" means either the unmodified Program or a work based
78
+ on the Program.
79
+
80
+ To "propagate" a work means to do anything with it that, without
81
+ permission, would make you directly or secondarily liable for
82
+ infringement under applicable copyright law, except executing it on a
83
+ computer or modifying a private copy. Propagation includes copying,
84
+ distribution (with or without modification), making available to the
85
+ public, and in some countries other activities as well.
86
+
87
+ To "convey" a work means any kind of propagation that enables other
88
+ parties to make or receive copies. Mere interaction with a user through
89
+ a computer network, with no transfer of a copy, is not conveying.
90
+
91
+ An interactive user interface displays "Appropriate Legal Notices"
92
+ to the extent that it includes a convenient and prominently visible
93
+ feature that (1) displays an appropriate copyright notice, and (2)
94
+ tells the user that there is no warranty for the work (except to the
95
+ extent that warranties are provided), that licensees may convey the
96
+ work under this License, and how to view a copy of this License. If
97
+ the interface presents a list of user commands or options, such as a
98
+ menu, a prominent item in the list meets this criterion.
99
+
100
+ 1. Source Code.
101
+
102
+ The "source code" for a work means the preferred form of the work
103
+ for making modifications to it. "Object code" means any non-source
104
+ form of a work.
105
+
106
+ A "Standard Interface" means an interface that either is an official
107
+ standard defined by a recognized standards body, or, in the case of
108
+ interfaces specified for a particular programming language, one that
109
+ is widely used among developers working in that language.
110
+
111
+ The "System Libraries" of an executable work include anything, other
112
+ than the work as a whole, that (a) is included in the normal form of
113
+ packaging a Major Component, but which is not part of that Major
114
+ Component, and (b) serves only to enable use of the work with that
115
+ Major Component, or to implement a Standard Interface for which an
116
+ implementation is available to the public in source code form. A
117
+ "Major Component", in this context, means a major essential component
118
+ (kernel, window system, and so on) of the specific operating system
119
+ (if any) on which the executable work runs, or a compiler used to
120
+ produce the work, or an object code interpreter used to run it.
121
+
122
+ The "Corresponding Source" for a work in object code form means all
123
+ the source code needed to generate, install, and (for an executable
124
+ work) run the object code and to modify the work, including scripts to
125
+ control those activities. However, it does not include the work's
126
+ System Libraries, or general-purpose tools or generally available free
127
+ programs which are used unmodified in performing those activities but
128
+ which are not part of the work. For example, Corresponding Source
129
+ includes interface definition files associated with source files for
130
+ the work, and the source code for shared libraries and dynamically
131
+ linked subprograms that the work is specifically designed to require,
132
+ such as by intimate data communication or control flow between those
133
+ subprograms and other parts of the work.
134
+
135
+ The Corresponding Source need not include anything that users
136
+ can regenerate automatically from other parts of the Corresponding
137
+ Source.
138
+
139
+ The Corresponding Source for a work in source code form is that
140
+ same work.
141
+
142
+ 2. Basic Permissions.
143
+
144
+ All rights granted under this License are granted for the term of
145
+ copyright on the Program, and are irrevocable provided the stated
146
+ conditions are met. This License explicitly affirms your unlimited
147
+ permission to run the unmodified Program. The output from running a
148
+ covered work is covered by this License only if the output, given its
149
+ content, constitutes a covered work. This License acknowledges your
150
+ rights of fair use or other equivalent, as provided by copyright law.
151
+
152
+ You may make, run and propagate covered works that you do not
153
+ convey, without conditions so long as your license otherwise remains
154
+ in force. You may convey covered works to others for the sole purpose
155
+ of having them make modifications exclusively for you, or provide you
156
+ with facilities for running those works, provided that you comply with
157
+ the terms of this License in conveying all material for which you do
158
+ not control copyright. Those thus making or running the covered works
159
+ for you must do so exclusively on your behalf, under your direction
160
+ and control, on terms that prohibit them from making any copies of
161
+ your copyrighted material outside their relationship with you.
162
+
163
+ Conveying under any other circumstances is permitted solely under
164
+ the conditions stated below. Sublicensing is not allowed; section 10
165
+ makes it unnecessary.
166
+
167
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
168
+
169
+ No covered work shall be deemed part of an effective technological
170
+ measure under any applicable law fulfilling obligations under article
171
+ 11 of the WIPO copyright treaty adopted on 20 December 1996, or
172
+ similar laws prohibiting or restricting circumvention of such
173
+ measures.
174
+
175
+ When you convey a covered work, you waive any legal power to forbid
176
+ circumvention of technological measures to the extent such circumvention
177
+ is effected by exercising rights under this License with respect to
178
+ the covered work, and you disclaim any intention to limit operation or
179
+ modification of the work as a means of enforcing, against the work's
180
+ users, your or third parties' legal rights to forbid circumvention of
181
+ technological measures.
182
+
183
+ 4. Conveying Verbatim Copies.
184
+
185
+ You may convey verbatim copies of the Program's source code as you
186
+ receive it, in any medium, provided that you conspicuously and
187
+ appropriately publish on each copy an appropriate copyright notice;
188
+ keep intact all notices stating that this License and any
189
+ non-permissive terms added in accord with section 7 apply to the code;
190
+ keep intact all notices of the absence of any warranty; and give all
191
+ recipients a copy of this License along with the Program.
192
+
193
+ You may charge any price or no price for each copy that you convey,
194
+ and you may offer support or warranty protection for a fee.
195
+
196
+ 5. Conveying Modified Source Versions.
197
+
198
+ You may convey a work based on the Program, or the modifications to
199
+ produce it from the Program, in the form of source code under the
200
+ terms of section 4, provided that you also meet all of these conditions:
201
+
202
+ a) The work must carry prominent notices stating that you modified
203
+ it, and giving a relevant date.
204
+
205
+ b) The work must carry prominent notices stating that it is
206
+ released under this License and any conditions added under section
207
+ 7. This requirement modifies the requirement in section 4 to
208
+ "keep intact all notices".
209
+
210
+ c) You must license the entire work, as a whole, under this
211
+ License to anyone who comes into possession of a copy. This
212
+ License will therefore apply, along with any applicable section 7
213
+ additional terms, to the whole of the work, and all its parts,
214
+ regardless of how they are packaged. This License gives no
215
+ permission to license the work in any other way, but it does not
216
+ invalidate such permission if you have separately received it.
217
+
218
+ d) If the work has interactive user interfaces, each must display
219
+ Appropriate Legal Notices; however, if the Program has interactive
220
+ interfaces that do not display Appropriate Legal Notices, your
221
+ work need not make them do so.
222
+
223
+ A compilation of a covered work with other separate and independent
224
+ works, which are not by their nature extensions of the covered work,
225
+ and which are not combined with it such as to form a larger program,
226
+ in or on a volume of a storage or distribution medium, is called an
227
+ "aggregate" if the compilation and its resulting copyright are not
228
+ used to limit the access or legal rights of the compilation's users
229
+ beyond what the individual works permit. Inclusion of a covered work
230
+ in an aggregate does not cause this License to apply to the other
231
+ parts of the aggregate.
232
+
233
+ 6. Conveying Non-Source Forms.
234
+
235
+ You may convey a covered work in object code form under the terms
236
+ of sections 4 and 5, provided that you also convey the
237
+ machine-readable Corresponding Source under the terms of this License,
238
+ in one of these ways:
239
+
240
+ a) Convey the object code in, or embodied in, a physical product
241
+ (including a physical distribution medium), accompanied by the
242
+ Corresponding Source fixed on a durable physical medium
243
+ customarily used for software interchange.
244
+
245
+ b) Convey the object code in, or embodied in, a physical product
246
+ (including a physical distribution medium), accompanied by a
247
+ written offer, valid for at least three years and valid for as
248
+ long as you offer spare parts or customer support for that product
249
+ model, to give anyone who possesses the object code either (1) a
250
+ copy of the Corresponding Source for all the software in the
251
+ product that is covered by this License, on a durable physical
252
+ medium customarily used for software interchange, for a price no
253
+ more than your reasonable cost of physically performing this
254
+ conveying of source, or (2) access to copy the
255
+ Corresponding Source from a network server at no charge.
256
+
257
+ c) Convey individual copies of the object code with a copy of the
258
+ written offer to provide the Corresponding Source. This
259
+ alternative is allowed only occasionally and noncommercially, and
260
+ only if you received the object code with such an offer, in accord
261
+ with subsection 6b.
262
+
263
+ d) Convey the object code by offering access from a designated
264
+ place (gratis or for a charge), and offer equivalent access to the
265
+ Corresponding Source in the same way through the same place at no
266
+ further charge. You need not require recipients to copy the
267
+ Corresponding Source along with the object code. If the place to
268
+ copy the object code is a network server, the Corresponding Source
269
+ may be on a different server (operated by you or a third party)
270
+ that supports equivalent copying facilities, provided you maintain
271
+ clear directions next to the object code saying where to find the
272
+ Corresponding Source. Regardless of what server hosts the
273
+ Corresponding Source, you remain obligated to ensure that it is
274
+ available for as long as needed to satisfy these requirements.
275
+
276
+ e) Convey the object code using peer-to-peer transmission, provided
277
+ you inform other peers where the object code and Corresponding
278
+ Source of the work are being offered to the general public at no
279
+ charge under subsection 6d.
280
+
281
+ A separable portion of the object code, whose source code is excluded
282
+ from the Corresponding Source as a System Library, need not be
283
+ included in conveying the object code work.
284
+
285
+ A "User Product" is either (1) a "consumer product", which means any
286
+ tangible personal property which is normally used for personal, family,
287
+ or household purposes, or (2) anything designed or sold for incorporation
288
+ into a dwelling. In determining whether a product is a consumer product,
289
+ doubtful cases shall be resolved in favor of coverage. For a particular
290
+ product received by a particular user, "normally used" refers to a
291
+ typical or common use of that class of product, regardless of the status
292
+ of the particular user or of the way in which the particular user
293
+ actually uses, or expects or is expected to use, the product. A product
294
+ is a consumer product regardless of whether the product has substantial
295
+ commercial, industrial or non-consumer uses, unless such uses represent
296
+ the only significant mode of use of the product.
297
+
298
+ "Installation Information" for a User Product means any methods,
299
+ procedures, authorization keys, or other information required to install
300
+ and execute modified versions of a covered work in that User Product from
301
+ a modified version of its Corresponding Source. The information must
302
+ suffice to ensure that the continued functioning of the modified object
303
+ code is in no case prevented or interfered with solely because
304
+ modification has been made.
305
+
306
+ If you convey an object code work under this section in, or with, or
307
+ specifically for use in, a User Product, and the conveying occurs as
308
+ part of a transaction in which the right of possession and use of the
309
+ User Product is transferred to the recipient in perpetuity or for a
310
+ fixed term (regardless of how the transaction is characterized), the
311
+ Corresponding Source conveyed under this section must be accompanied
312
+ by the Installation Information. But this requirement does not apply
313
+ if neither you nor any third party retains the ability to install
314
+ modified object code on the User Product (for example, the work has
315
+ been installed in ROM).
316
+
317
+ The requirement to provide Installation Information does not include a
318
+ requirement to continue to provide support service, warranty, or updates
319
+ for a work that has been modified or installed by the recipient, or for
320
+ the User Product in which it has been modified or installed. Access to a
321
+ network may be denied when the modification itself materially and
322
+ adversely affects the operation of the network or violates the rules and
323
+ protocols for communication across the network.
324
+
325
+ Corresponding Source conveyed, and Installation Information provided,
326
+ in accord with this section must be in a format that is publicly
327
+ documented (and with an implementation available to the public in
328
+ source code form), and must require no special password or key for
329
+ unpacking, reading or copying.
330
+
331
+ 7. Additional Terms.
332
+
333
+ "Additional permissions" are terms that supplement the terms of this
334
+ License by making exceptions from one or more of its conditions.
335
+ Additional permissions that are applicable to the entire Program shall
336
+ be treated as though they were included in this License, to the extent
337
+ that they are valid under applicable law. If additional permissions
338
+ apply only to part of the Program, that part may be used separately
339
+ under those permissions, but the entire Program remains governed by
340
+ this License without regard to the additional permissions.
341
+
342
+ When you convey a copy of a covered work, you may at your option
343
+ remove any additional permissions from that copy, or from any part of
344
+ it. (Additional permissions may be written to require their own
345
+ removal in certain cases when you modify the work.) You may place
346
+ additional permissions on material, added by you to a covered work,
347
+ for which you have or can give appropriate copyright permission.
348
+
349
+ Notwithstanding any other provision of this License, for material you
350
+ add to a covered work, you may (if authorized by the copyright holders of
351
+ that material) supplement the terms of this License with terms:
352
+
353
+ a) Disclaiming warranty or limiting liability differently from the
354
+ terms of sections 15 and 16 of this License; or
355
+
356
+ b) Requiring preservation of specified reasonable legal notices or
357
+ author attributions in that material or in the Appropriate Legal
358
+ Notices displayed by works containing it; or
359
+
360
+ c) Prohibiting misrepresentation of the origin of that material, or
361
+ requiring that modified versions of such material be marked in
362
+ reasonable ways as different from the original version; or
363
+
364
+ d) Limiting the use for publicity purposes of names of licensors or
365
+ authors of the material; or
366
+
367
+ e) Declining to grant rights under trademark law for use of some
368
+ trade names, trademarks, or service marks; or
369
+
370
+ f) Requiring indemnification of licensors and authors of that
371
+ material by anyone who conveys the material (or modified versions of
372
+ it) with contractual assumptions of liability to the recipient, for
373
+ any liability that these contractual assumptions directly impose on
374
+ those licensors and authors.
375
+
376
+ All other non-permissive additional terms are considered "further
377
+ restrictions" within the meaning of section 10. If the Program as you
378
+ received it, or any part of it, contains a notice stating that it is
379
+ governed by this License along with a term that is a further
380
+ restriction, you may remove that term. If a license document contains
381
+ a further restriction but permits relicensing or conveying under this
382
+ License, you may add to a covered work material governed by the terms
383
+ of that license document, provided that the further restriction does
384
+ not survive such relicensing or conveying.
385
+
386
+ If you add terms to a covered work in accord with this section, you
387
+ must place, in the relevant source files, a statement of the
388
+ additional terms that apply to those files, or a notice indicating
389
+ where to find the applicable terms.
390
+
391
+ Additional terms, permissive or non-permissive, may be stated in the
392
+ form of a separately written license, or stated as exceptions;
393
+ the above requirements apply either way.
394
+
395
+ 8. Termination.
396
+
397
+ You may not propagate or modify a covered work except as expressly
398
+ provided under this License. Any attempt otherwise to propagate or
399
+ modify it is void, and will automatically terminate your rights under
400
+ this License (including any patent licenses granted under the third
401
+ paragraph of section 11).
402
+
403
+ However, if you cease all violation of this License, then your
404
+ license from a particular copyright holder is reinstated (a)
405
+ provisionally, unless and until the copyright holder explicitly and
406
+ finally terminates your license, and (b) permanently, if the copyright
407
+ holder fails to notify you of the violation by some reasonable means
408
+ prior to 60 days after the cessation.
409
+
410
+ Moreover, your license from a particular copyright holder is
411
+ reinstated permanently if the copyright holder notifies you of the
412
+ violation by some reasonable means, this is the first time you have
413
+ received notice of violation of this License (for any work) from that
414
+ copyright holder, and you cure the violation prior to 30 days after
415
+ your receipt of the notice.
416
+
417
+ Termination of your rights under this section does not terminate the
418
+ licenses of parties who have received copies or rights from you under
419
+ this License. If your rights have been terminated and not permanently
420
+ reinstated, you do not qualify to receive new licenses for the same
421
+ material under section 10.
422
+
423
+ 9. Acceptance Not Required for Having Copies.
424
+
425
+ You are not required to accept this License in order to receive or
426
+ run a copy of the Program. Ancillary propagation of a covered work
427
+ occurring solely as a consequence of using peer-to-peer transmission
428
+ to receive a copy likewise does not require acceptance. However,
429
+ nothing other than this License grants you permission to propagate or
430
+ modify any covered work. These actions infringe copyright if you do
431
+ not accept this License. Therefore, by modifying or propagating a
432
+ covered work, you indicate your acceptance of this License to do so.
433
+
434
+ 10. Automatic Licensing of Downstream Recipients.
435
+
436
+ Each time you convey a covered work, the recipient automatically
437
+ receives a license from the original licensors, to run, modify and
438
+ propagate that work, subject to this License. You are not responsible
439
+ for enforcing compliance by third parties with this License.
440
+
441
+ An "entity transaction" is a transaction transferring control of an
442
+ organization, or substantially all assets of one, or subdividing an
443
+ organization, or merging organizations. If propagation of a covered
444
+ work results from an entity transaction, each party to that
445
+ transaction who receives a copy of the work also receives whatever
446
+ licenses to the work the party's predecessor in interest had or could
447
+ give under the previous paragraph, plus a right to possession of the
448
+ Corresponding Source of the work from the predecessor in interest, if
449
+ the predecessor has it or can get it with reasonable efforts.
450
+
451
+ You may not impose any further restrictions on the exercise of the
452
+ rights granted or affirmed under this License. For example, you may
453
+ not impose a license fee, royalty, or other charge for exercise of
454
+ rights granted under this License, and you may not initiate litigation
455
+ (including a cross-claim or counterclaim in a lawsuit) alleging that
456
+ any patent claim is infringed by making, using, selling, offering for
457
+ sale, or importing the Program or any portion of it.
458
+
459
+ 11. Patents.
460
+
461
+ A "contributor" is a copyright holder who authorizes use under this
462
+ License of the Program or a work on which the Program is based. The
463
+ work thus licensed is called the contributor's "contributor version".
464
+
465
+ A contributor's "essential patent claims" are all patent claims
466
+ owned or controlled by the contributor, whether already acquired or
467
+ hereafter acquired, that would be infringed by some manner, permitted
468
+ by this License, of making, using, or selling its contributor version,
469
+ but do not include claims that would be infringed only as a
470
+ consequence of further modification of the contributor version. For
471
+ purposes of this definition, "control" includes the right to grant
472
+ patent sublicenses in a manner consistent with the requirements of
473
+ this License.
474
+
475
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
476
+ patent license under the contributor's essential patent claims, to
477
+ make, use, sell, offer for sale, import and otherwise run, modify and
478
+ propagate the contents of its contributor version.
479
+
480
+ In the following three paragraphs, a "patent license" is any express
481
+ agreement or commitment, however denominated, not to enforce a patent
482
+ (such as an express permission to practice a patent or covenant not to
483
+ sue for patent infringement). To "grant" such a patent license to a
484
+ party means to make such an agreement or commitment not to enforce a
485
+ patent against the party.
486
+
487
+ If you convey a covered work, knowingly relying on a patent license,
488
+ and the Corresponding Source of the work is not available for anyone
489
+ to copy, free of charge and under the terms of this License, through a
490
+ publicly available network server or other readily accessible means,
491
+ then you must either (1) cause the Corresponding Source to be so
492
+ available, or (2) arrange to deprive yourself of the benefit of the
493
+ patent license for this particular work, or (3) arrange, in a manner
494
+ consistent with the requirements of this License, to extend the patent
495
+ license to downstream recipients. "Knowingly relying" means you have
496
+ actual knowledge that, but for the patent license, your conveying the
497
+ covered work in a country, or your recipient's use of the covered work
498
+ in a country, would infringe one or more identifiable patents in that
499
+ country that you have reason to believe are valid.
500
+
501
+ If, pursuant to or in connection with a single transaction or
502
+ arrangement, you convey, or propagate by procuring conveyance of, a
503
+ covered work, and grant a patent license to some of the parties
504
+ receiving the covered work authorizing them to use, propagate, modify
505
+ or convey a specific copy of the covered work, then the patent license
506
+ you grant is automatically extended to all recipients of the covered
507
+ work and works based on it.
508
+
509
+ A patent license is "discriminatory" if it does not include within
510
+ the scope of its coverage, prohibits the exercise of, or is
511
+ conditioned on the non-exercise of one or more of the rights that are
512
+ specifically granted under this License. You may not convey a covered
513
+ work if you are a party to an arrangement with a third party that is
514
+ in the business of distributing software, under which you make payment
515
+ to the third party based on the extent of your activity of conveying
516
+ the work, and under which the third party grants, to any of the
517
+ parties who would receive the covered work from you, a discriminatory
518
+ patent license (a) in connection with copies of the covered work
519
+ conveyed by you (or copies made from those copies), or (b) primarily
520
+ for and in connection with specific products or compilations that
521
+ contain the covered work, unless you entered into that arrangement,
522
+ or that patent license was granted, prior to 28 March 2007.
523
+
524
+ Nothing in this License shall be construed as excluding or limiting
525
+ any implied license or other defenses to infringement that may
526
+ otherwise be available to you under applicable patent law.
527
+
528
+ 12. No Surrender of Others' Freedom.
529
+
530
+ If conditions are imposed on you (whether by court order, agreement or
531
+ otherwise) that contradict the conditions of this License, they do not
532
+ excuse you from the conditions of this License. If you cannot convey a
533
+ covered work so as to satisfy simultaneously your obligations under this
534
+ License and any other pertinent obligations, then as a consequence you may
535
+ not convey it at all. For example, if you agree to terms that obligate you
536
+ to collect a royalty for further conveying from those to whom you convey
537
+ the Program, the only way you could satisfy both those terms and this
538
+ License would be to refrain entirely from conveying the Program.
539
+
540
+ 13. Remote Network Interaction; Use with the GNU General Public License.
541
+
542
+ Notwithstanding any other provision of this License, if you modify the
543
+ Program, your modified version must prominently offer all users
544
+ interacting with it remotely through a computer network (if your version
545
+ supports such interaction) an opportunity to receive the Corresponding
546
+ Source of your version by providing access to the Corresponding Source
547
+ from a network server at no charge, through some standard or customary
548
+ means of facilitating copying of software. This Corresponding Source
549
+ shall include the Corresponding Source for any work covered by version 3
550
+ of the GNU General Public License that is incorporated pursuant to the
551
+ following paragraph.
552
+
553
+ Notwithstanding any other provision of this License, you have
554
+ permission to link or combine any covered work with a work licensed
555
+ under version 3 of the GNU General Public License into a single
556
+ combined work, and to convey the resulting work. The terms of this
557
+ License will continue to apply to the part which is the covered work,
558
+ but the work with which it is combined will remain governed by version
559
+ 3 of the GNU General Public License.
560
+
561
+ 14. Revised Versions of this License.
562
+
563
+ The Free Software Foundation may publish revised and/or new versions of
564
+ the GNU Affero General Public License from time to time. Such new versions
565
+ will be similar in spirit to the present version, but may differ in detail to
566
+ address new problems or concerns.
567
+
568
+ Each version is given a distinguishing version number. If the
569
+ Program specifies that a certain numbered version of the GNU Affero General
570
+ Public License "or any later version" applies to it, you have the
571
+ option of following the terms and conditions either of that numbered
572
+ version or of any later version published by the Free Software
573
+ Foundation. If the Program does not specify a version number of the
574
+ GNU Affero General Public License, you may choose any version ever published
575
+ by the Free Software Foundation.
576
+
577
+ If the Program specifies that a proxy can decide which future
578
+ versions of the GNU Affero General Public License can be used, that proxy's
579
+ public statement of acceptance of a version permanently authorizes you
580
+ to choose that version for the Program.
581
+
582
+ Later license versions may give you additional or different
583
+ permissions. However, no additional obligations are imposed on any
584
+ author or copyright holder as a result of your choosing to follow a
585
+ later version.
586
+
587
+ 15. Disclaimer of Warranty.
588
+
589
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
590
+ APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
591
+ HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
592
+ OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
593
+ THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
594
+ PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
595
+ IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
596
+ ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
597
+
598
+ 16. Limitation of Liability.
599
+
600
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
601
+ WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
602
+ THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
603
+ GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
604
+ USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
605
+ DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
606
+ PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
607
+ EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
608
+ SUCH DAMAGES.
609
+
610
+ 17. Interpretation of Sections 15 and 16.
611
+
612
+ If the disclaimer of warranty and limitation of liability provided
613
+ above cannot be given local legal effect according to their terms,
614
+ reviewing courts shall apply local law that most closely approximates
615
+ an absolute waiver of all civil liability in connection with the
616
+ Program, unless a warranty or assumption of liability accompanies a
617
+ copy of the Program in return for a fee.
618
+
619
+ END OF TERMS AND CONDITIONS
620
+
621
+ How to Apply These Terms to Your New Programs
622
+
623
+ If you develop a new program, and you want it to be of the greatest
624
+ possible use to the public, the best way to achieve this is to make it
625
+ free software which everyone can redistribute and change under these terms.
626
+
627
+ To do so, attach the following notices to the program. It is safest
628
+ to attach them to the start of each source file to most effectively
629
+ state the exclusion of warranty; and each file should have at least
630
+ the "copyright" line and a pointer to where the full notice is found.
631
+
632
+ <one line to give the program's name and a brief idea of what it does.>
633
+ Copyright (C) <year> <name of author>
634
+
635
+ This program is free software: you can redistribute it and/or modify
636
+ it under the terms of the GNU Affero General Public License as published
637
+ by the Free Software Foundation, either version 3 of the License, or
638
+ (at your option) any later version.
639
+
640
+ This program is distributed in the hope that it will be useful,
641
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
642
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
643
+ GNU Affero General Public License for more details.
644
+
645
+ You should have received a copy of the GNU Affero General Public License
646
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
647
+
648
+ Also add information on how to contact you by electronic and paper mail.
649
+
650
+ If your software can interact with users remotely through a computer
651
+ network, you should also make sure that it provides a way for users to
652
+ get its source. For example, if your program is a web application, its
653
+ interface could display a "Source" link that leads users to an archive
654
+ of the code. There are many ways you could offer source, and different
655
+ solutions will be better for different programs; see section 13 for the
656
+ specific requirements.
657
+
658
+ You should also get your employer (if you work as a programmer) or school,
659
+ if any, to sign a "copyright disclaimer" for the program, if necessary.
660
+ For more information on this, and how to apply and follow the GNU AGPL, see
661
+ <https://www.gnu.org/licenses/>.
README.md CHANGED
@@ -1,12 +1,586 @@
1
  ---
2
- title: So Vits Svc Sora
3
- emoji: 🦀
4
  colorFrom: gray
5
- colorTo: red
6
  sdk: gradio
7
- sdk_version: 4.21.0
8
  app_file: app.py
9
  pinned: false
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: AI Sora SVC
3
+ emoji: 🎤
4
  colorFrom: gray
5
+ colorTo: pink
6
  sdk: gradio
 
7
  app_file: app.py
8
  pinned: false
9
  ---
10
 
11
+ <div align="center">
12
+ <img alt="LOGO" src="https://avatars.githubusercontent.com/u/127122328?s=400&u=5395a98a4f945a3a50cb0cc96c2747505d190dbc&v=4" width="300" height="300" />
13
+
14
+ # SoftVC VITS Singing Voice Conversion
15
+
16
+ [**English**](./README.md) | [**中文简体**](./README_zh_CN.md)
17
+
18
+ [![Open In Colab](https://img.shields.io/badge/Colab-F9AB00?style=for-the-badge&logo=googlecolab&color=525252)](https://colab.research.google.com/github/svc-develop-team/so-vits-svc/blob/4.1-Stable/sovits4_for_colab.ipynb)
19
+ [![Licence](https://img.shields.io/badge/LICENSE-AGPL3.0-green.svg?style=for-the-badge)](https://github.com/svc-develop-team/so-vits-svc/blob/4.1-Stable/LICENSE)
20
+
21
+ This round of limited-time updates is coming to an end. The repository will then enter the Archive state; please take note.
22
+
23
+ </div>
24
+
25
+ > ✨ A studio that contains a visual f0 editor, a speaker mix timeline editor, and other features (where the ONNX models are used): [MoeVoiceStudio](https://github.com/NaruseMioShirakana/MoeVoiceStudio)
26
+
27
+ > ✨ A fork with a greatly improved user interface: [34j/so-vits-svc-fork](https://github.com/34j/so-vits-svc-fork)
28
+
29
+ > ✨ A client that supports real-time conversion: [w-okada/voice-changer](https://github.com/w-okada/voice-changer)
30
+
31
+ **This project differs fundamentally from VITS, as it focuses on Singing Voice Conversion (SVC) rather than Text-to-Speech (TTS). In this project, TTS functionality is not supported, and VITS is incapable of performing SVC tasks. It's important to note that the models used in these two projects are not interchangeable or universally applicable.**
32
+
33
+ ## Announcement
34
+
35
+ The purpose of this project was to enable developers to have their beloved anime characters perform singing tasks. The developers' intention was to focus solely on fictional characters and avoid any involvement of real individuals, anything related to real individuals deviates from the developer's original intention.
36
+
37
+ ## Disclaimer
38
+
39
+ This project is an open-source, offline endeavor, and all members of SvcDevelopTeam, as well as other developers and maintainers involved (hereinafter referred to as contributors), have no control over the project. The contributors have never provided any form of assistance to any organization or individual, including but not limited to dataset extraction, dataset processing, computing support, training support, inference, and so on. The contributors do not and cannot be aware of the purposes for which users utilize the project. Therefore, any AI models and synthesized audio produced through the training of this project are unrelated to the contributors. Any issues or consequences arising from their use are the sole responsibility of the user.
40
+
41
+ This project is run completely offline and does not collect any user information or gather user input data. Therefore, contributors to this project are not aware of all user input and models and therefore are not responsible for any user input.
42
+
43
+ This project serves as a framework only and does not possess speech synthesis functionality by itself. All functionalities require users to train the models independently. Furthermore, this project does not come bundled with any models, and any secondary distributed projects are independent of the contributors of this project.
44
+
45
+ ## 📏 Terms of Use
46
+
47
+ # Warning: Please ensure that you address any authorization issues related to the dataset on your own. You bear full responsibility for any problems arising from the usage of non-authorized datasets for training, as well as any resulting consequences. The repository and its maintainer, svc develop team, disclaim any association with or liability for the consequences.
48
+
49
+ 1. This project is exclusively established for academic purposes, aiming to facilitate communication and learning. It is not intended for deployment in production environments.
50
+ 2. Any sovits-based video posted to a video platform must clearly specify in the introduction the input source vocals and audio used for the voice changer conversion, e.g., if you use someone else's video/audio and convert it by separating the vocals as the input source, you must give a clear link to the original video or music; if you use your own vocals or a voice synthesized by another voice synthesis engine as the input source, you must also state this in your introduction.
51
+ 3. You are solely responsible for any infringement issues caused by the input source and all consequences. When using other commercial vocal synthesis software as an input source, please ensure that you comply with the regulations of that software, noting that the regulations of many vocal synthesis engines explicitly state that they cannot be used to convert input sources!
52
+ 4. Engaging in illegal activities, as well as religious and political activities, is strictly prohibited when using this project. The project developers vehemently oppose the aforementioned activities. If you disagree with this provision, the usage of the project is prohibited.
53
+ 5. If you continue to use the program, you will be deemed to have agreed to the terms and conditions set forth in README and README has discouraged you and is not responsible for any subsequent problems.
54
+ 6. If you intend to employ this project for any other purposes, kindly contact and inform the maintainers of this repository in advance.
55
+
56
+ ## 📝 Model Introduction
57
+
58
+ The singing voice conversion model uses SoftVC content encoder to extract speech features from the source audio. These feature vectors are directly fed into VITS without the need for conversion to a text-based intermediate representation. As a result, the pitch and intonations of the original audio are preserved. Meanwhile, the vocoder was replaced with [NSF HiFiGAN](https://github.com/openvpi/DiffSinger/tree/refactor/modules/nsf_hifigan) to solve the problem of sound interruption.
59
+
60
+ ### 🆕 4.1-Stable Version Update Content
61
+
62
+ - Feature input is changed to the 12th layer of the [Content Vec](https://github.com/auspicious3000/contentvec) Transformer output, and it remains compatible with 4.0 branches.
63
+ - Update the shallow diffusion, you can use the shallow diffusion model to improve the sound quality.
64
+ - Added Whisper-PPG encoder support
65
+ - Added static/dynamic sound fusion
66
+ - Added loudness embedding
67
+ - Added Functionality of feature retrieval from [RVC](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI)
68
+
69
+ ### 🆕 Questions about compatibility with the 4.0 model
70
+
71
+ - To support the 4.0 model and incorporate the speech encoder, you can make modifications to the `config.json` file. Add the `speech_encoder` field to the "model" section as shown below:
72
+
73
+ ```
74
+ "model": {
75
+ .........
76
+ "ssl_dim": 256,
77
+ "n_speakers": 200,
78
+ "speech_encoder":"vec256l9"
79
+ }
80
+ ```
81
+
82
+ ### 🆕 Shallow diffusion
83
+ ![Diagram](shadowdiffusion.png)
84
+
85
+ ## 💬 Python Version
86
+
87
+ Based on our testing, we have determined that the project runs stably on `Python 3.8.9`.
88
+
89
+ ## 📥 Pre-trained Model Files
90
+
91
+ #### **Required**
92
+
93
+ **You need to select one encoder from the list below**
94
+
95
+ ##### **1. If using contentvec as speech encoder(recommended)**
96
+
97
+ `vec768l12` and `vec256l9` require the encoder
98
+
99
+ - ContentVec: [checkpoint_best_legacy_500.pt](https://ibm.box.com/s/z1wgl1stco8ffooyatzdwsqn2psd9lrr)
100
+ - Place it under the `pretrain` directory
101
+
102
+ Or download the following ContentVec, which is only 199MB in size but has the same effect:
103
+ - ContentVec: [hubert_base.pt](https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/hubert_base.pt)
104
+ - Change the file name to `checkpoint_best_legacy_500.pt` and place it in the `pretrain` directory
105
+
106
+ ```shell
107
+ # contentvec
108
+ wget -O pretrain/checkpoint_best_legacy_500.pt https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/hubert_base.pt
109
+ # Alternatively, you can manually download and place it in the pretrain directory
110
+ ```
111
+
112
+ ##### **2. If hubertsoft is used as the speech encoder**
113
+ - soft vc hubert: [hubert-soft-0d54a1f4.pt](https://github.com/bshall/hubert/releases/download/v0.1/hubert-soft-0d54a1f4.pt)
114
+ - Place it under the `pretrain` directory
115
+
116
+ ##### **3. If whisper-ppg as the encoder**
117
+ - download model at [medium.pt](https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt), the model fits `whisper-ppg`
118
+ - or download model at [large-v2.pt](https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt), the model fits `whisper-ppg-large`
119
+ - Place it under the `pretrain` directory
120
+
121
+ ##### **4. If cnhubertlarge as the encoder**
122
+ - download model at [chinese-hubert-large-fairseq-ckpt.pt](https://huggingface.co/TencentGameMate/chinese-hubert-large/resolve/main/chinese-hubert-large-fairseq-ckpt.pt)
123
+ - Place it under the `pretrain` directory
124
+
125
+ ##### **5. If dphubert as the encoder**
126
+ - download model at [DPHuBERT-sp0.75.pth](https://huggingface.co/pyf98/DPHuBERT/resolve/main/DPHuBERT-sp0.75.pth)
127
+ - Place it under the `pretrain` directory
128
+
129
+ ##### **6. If WavLM is used as the encoder**
130
+ - download model at [WavLM-Base+.pt](https://valle.blob.core.windows.net/share/wavlm/WavLM-Base+.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D), the model fits `wavlmbase+`
131
+ - Place it under the `pretrain` directory
132
+
133
+ ##### **7. If OnnxHubert/ContentVec as the encoder**
134
+ - download model at [MoeSS-SUBModel](https://huggingface.co/NaruseMioShirakana/MoeSS-SUBModel/tree/main)
135
+ - Place it under the `pretrain` directory
136
+
137
+ #### **List of Encoders**
138
+ - "vec768l12"
139
+ - "vec256l9"
140
+ - "vec256l9-onnx"
141
+ - "vec256l12-onnx"
142
+ - "vec768l9-onnx"
143
+ - "vec768l12-onnx"
144
+ - "hubertsoft-onnx"
145
+ - "hubertsoft"
146
+ - "whisper-ppg"
147
+ - "cnhubertlarge"
148
+ - "dphubert"
149
+ - "whisper-ppg-large"
150
+ - "wavlmbase+"
151
+
152
+ #### **Optional(Strongly recommend)**
153
+
154
+ - Pre-trained model files: `G_0.pth` `D_0.pth`
155
+ - Place them under the `logs/44k` directory
156
+
157
+ - Diffusion model pretraining base model file: `model_0.pt`
158
+ - Put it in the `logs/44k/diffusion` directory
159
+
160
+ Get Sovits Pre-trained model from svc-develop-team(TBD) or anywhere else.
161
+
162
+ Diffusion model references [Diffusion-SVC](https://github.com/CNChTu/Diffusion-SVC) diffusion model. The pre-trained diffusion model is universal with the DDSP-SVC's. You can go to [Diffusion-SVC](https://github.com/CNChTu/Diffusion-SVC)'s repo to get the pre-trained diffusion model.
163
+
164
+ While the pretrained model typically does not pose copyright concerns, it is essential to remain vigilant. It is advisable to consult with the author beforehand or carefully review the description to ascertain the permissible usage of the model. This helps ensure compliance with any specified guidelines or restrictions regarding its utilization.
165
+
166
+ #### **Optional(Select as Required)**
167
+
168
+ ##### NSF-HIFIGAN
169
+
170
+ If you are using the `NSF-HIFIGAN enhancer` or `shallow diffusion`, you will need to download the pre-trained NSF-HIFIGAN model.
171
+
172
+ - Pre-trained NSF-HIFIGAN Vocoder: [nsf_hifigan_20221211.zip](https://github.com/openvpi/vocoders/releases/download/nsf-hifigan-v1/nsf_hifigan_20221211.zip)
173
+ - Unzip and place the four files under the `pretrain/nsf_hifigan` directory
174
+
175
+ ```shell
176
+ # nsf_hifigan
177
+ wget -P pretrain/ https://github.com/openvpi/vocoders/releases/download/nsf-hifigan-v1/nsf_hifigan_20221211.zip
178
+ unzip -od pretrain/nsf_hifigan pretrain/nsf_hifigan_20221211.zip
179
+ # Alternatively, you can manually download and place it in the pretrain/nsf_hifigan directory
180
+ # URL: https://github.com/openvpi/vocoders/releases/tag/nsf-hifigan-v1
181
+ ```
182
+
183
+ ##### RMVPE
184
+
185
+ If you are using the `rmvpe` F0 Predictor, you will need to download the pre-trained RMVPE model.
186
+
187
+ + download model at [rmvpe.zip](https://github.com/yxlllc/RMVPE/releases/download/230917/rmvpe.zip), this weight is recommended.
188
+ + unzip `rmvpe.zip`,and rename the `model.pt` file to `rmvpe.pt` and place it under the `pretrain` directory.
189
+
190
+ - ~~download model at [rmvpe.pt](https://huggingface.co/datasets/ylzz1997/rmvpe_pretrain_model/resolve/main/rmvpe.pt)~~
191
+ - ~~Place it under the `pretrain` directory~~
192
+
193
+ ##### FCPE(Preview version)
194
+
195
+ [FCPE(Fast Context-base Pitch Estimator)](https://github.com/CNChTu/MelPE) is a dedicated F0 predictor designed for real-time voice conversion and will become the preferred F0 predictor for sovits real-time voice conversion in the future.(The paper is being written)
196
+
197
+ If you are using the `fcpe` F0 Predictor, you will need to download the pre-trained FCPE model.
198
+
199
+ - download model at [fcpe.pt](https://huggingface.co/datasets/ylzz1997/rmvpe_pretrain_model/resolve/main/fcpe.pt)
200
+ - Place it under the `pretrain` directory
201
+
202
+ ## 📊 Dataset Preparation
203
+
204
+ Simply place the dataset in the `dataset_raw` directory with the following file structure:
205
+
206
+ ```
207
+ dataset_raw
208
+ ├───speaker0
209
+ │ ├───xxx1-xxx1.wav
210
+ │ ├───...
211
+ │ └───Lxx-0xx8.wav
212
+ └───speaker1
213
+ ├───xx2-0xxx2.wav
214
+ ├───...
215
+ └───xxx7-xxx007.wav
216
+ ```
217
+ There are no specific restrictions on the format of the name for each audio file (naming conventions such as `000001.wav` to `999999.wav` are also valid), but the file type must be `WAV`.
218
+
219
+ You can customize the speaker's name as shown below:
220
+
221
+ ```
222
+ dataset_raw
223
+ └───suijiSUI
224
+ ├───1.wav
225
+ ├───...
226
+ └───25788785-20221210-200143-856_01_(Vocals)_0_0.wav
227
+ ```
228
+
229
+ ## 🛠️ Preprocessing
230
+
231
+ ### 0. Slice audio
232
+
233
+ To avoid video memory overflow during training or pre-processing, it is recommended to limit the length of audio clips. Cutting the audio to a length of "5s - 15s" is more recommended. Slightly longer times are acceptable, however, excessively long clips may cause problems such as `torch.cuda.OutOfMemoryError`.
234
+
235
+ To facilitate the slicing process, you can use [audio-slicer-GUI](https://github.com/flutydeer/audio-slicer) or [audio-slicer-CLI](https://github.com/openvpi/audio-slicer)
236
+
237
+ In general, only the `Minimum Interval` needs to be adjusted. For spoken audio, the default value usually suffices, while for singing audio, it can be adjusted to around `100` or even `50`, depending on the specific requirements.
238
+
239
+ After slicing, it is recommended to remove any audio clips that are excessively long or too short.
240
+
241
+ **If you are using the whisper-ppg encoder for training, the audio clips must be shorter than 30s.**
242
+
243
+ ### 1. Resample to 44100Hz and mono
244
+
245
+ ```shell
246
+ python resample.py
247
+ ```
248
+
249
+ #### Cautions
250
+
251
+ Although this project has a resample.py script for resampling, mono conversion and loudness matching, the default loudness matching is to match to 0db. This can cause damage to the sound quality. In addition, Python's loudness matching package pyloudnorm does not limit the level, which can lead to clipping. Therefore, it is recommended to consider using professional sound processing software, such as `adobe audition`, for loudness matching. If you are already using other software for loudness matching, add the parameter `--skip_loudnorm` to the run command:
252
+
253
+ ```shell
254
+ python resample.py --skip_loudnorm
255
+ ```
256
+
257
+ ### 2. Automatically split the dataset into training and validation sets, and generate configuration files.
258
+
259
+ ```shell
260
+ python preprocess_flist_config.py --speech_encoder vec768l12
261
+ ```
262
+
263
+ speech_encoder has the following options
264
+
265
+ ```
266
+ vec768l12
267
+ vec256l9
268
+ hubertsoft
269
+ whisper-ppg
270
+ cnhubertlarge
271
+ dphubert
272
+ whisper-ppg-large
273
+ wavlmbase+
274
+ ```
275
+
276
+ If the speech_encoder argument is omitted, the default value is `vec768l12`
277
+
278
+ **Use loudness embedding**
279
+
280
+ Add `--vol_aug` if you want to enable loudness embedding:
281
+
282
+ ```shell
283
+ python preprocess_flist_config.py --speech_encoder vec768l12 --vol_aug
284
+ ```
285
+
286
+ After enabling loudness embedding, the trained model will match the loudness of the input source; otherwise, it will match the loudness of the training set.
287
+
288
+ #### You can modify some parameters in the generated config.json and diffusion.yaml
289
+
290
+ * `keep_ckpts`: The number of previous models to keep during training. Set to `0` to keep them all. Default is `3`.
291
+
292
+ * `all_in_mem`: Load all dataset to RAM. It can be enabled when the disk IO of some platforms is too low and the system memory is **much larger** than your dataset.
293
+
294
+ * `batch_size`: The amount of data loaded to the GPU for a single training session can be adjusted to a size lower than the GPU memory capacity.
295
+
296
+ * `vocoder_name`: Select a vocoder. The default is `nsf-hifigan`.
297
+
298
+ ##### diffusion.yaml
299
+
300
+ * `cache_all_data`: Load all dataset to RAM. It can be enabled when the disk IO of some platforms is too low and the system memory is **much larger** than your dataset.
301
+
302
+ * `duration`: The duration of the audio slicing during training, can be adjusted according to the size of the video memory, **Note: this value must be less than the minimum time of the audio in the training set!**
303
+
304
+ * `batch_size`: The amount of data loaded to the GPU for a single training session can be adjusted to a size lower than the video memory capacity.
305
+
306
+ * `timesteps`: The total number of steps in the diffusion model, which defaults to 1000.
307
+
308
+ * `k_step_max`: Training can be limited to only `k_step_max` diffusion steps to save training time. Note that the value must be less than `timesteps`; 0 means training the entire diffusion model. **Note: if you do not train the entire diffusion model, you will not be able to use only_diffusion!**
309
+
310
+ ##### **List of Vocoders**
311
+
312
+ ```
313
+ nsf-hifigan
314
+ nsf-snake-hifigan
315
+ ```
316
+
317
+ ### 3. Generate hubert and f0
318
+
319
+ ```shell
320
+ python preprocess_hubert_f0.py --f0_predictor dio
321
+ ```
322
+
323
+ f0_predictor has the following options
324
+
325
+ ```
326
+ crepe
327
+ dio
328
+ pm
329
+ harvest
330
+ rmvpe
331
+ fcpe
332
+ ```
333
+
334
+ If the training set is too noisy, it is recommended to use `crepe` to handle f0
335
+
336
+ If the f0_predictor parameter is omitted, the default value is `rmvpe`
337
+
338
+ If you want shallow diffusion (optional), you need to add the `--use_diff` parameter, for example:
339
+
340
+ ```shell
341
+ python preprocess_hubert_f0.py --f0_predictor dio --use_diff
342
+ ```
343
+
344
+ **Speed Up preprocess**
345
+
346
+ If your dataset is pretty large, you can increase the `--num_processes` parameter like this:
347
+
348
+ ```shell
349
+ python preprocess_hubert_f0.py --f0_predictor dio --num_processes 8
350
+ ```
351
+ The workers will be distributed across different GPUs if you have more than one GPU.
352
+
353
+ After completing the above steps, the dataset directory will contain the preprocessed data, and the dataset_raw folder can be deleted.
354
+
355
+ ## 🏋️‍ Training
356
+
357
+ ### Sovits Model
358
+
359
+ ```shell
360
+ python train.py -c configs/config.json -m 44k
361
+ ```
362
+
363
+ ### Diffusion Model (optional)
364
+
365
+ If the shallow diffusion function is needed, the diffusion model needs to be trained. The diffusion model training method is as follows:
366
+
367
+ ```shell
368
+ python train_diff.py -c configs/diffusion.yaml
369
+ ```
370
+
371
+ During training, the model files will be saved to `logs/44k`, and the diffusion model will be saved to `logs/44k/diffusion`
372
+
373
+ ## 🤖 Inference
374
+
375
+ Use [inference_main.py](https://github.com/svc-develop-team/so-vits-svc/blob/4.0/inference_main.py)
376
+
377
+ ```shell
378
+ # Example
379
+ python inference_main.py -m "logs/44k/G_30400.pth" -c "configs/config.json" -n "君の知らない物語-src.wav" -t 0 -s "nen"
380
+ ```
381
+
382
+ Required parameters:
383
+ - `-m` | `--model_path`: path to the model.
384
+ - `-c` | `--config_path`: path to the configuration file.
385
+ - `-n` | `--clean_names`: a list of wav file names located in the `raw` folder.
386
+ - `-t` | `--trans`: pitch shift, supports positive and negative (semitone) values.
387
+ - `-s` | `--spk_list`: Select the speaker ID to use for conversion.
388
+ - `-cl` | `--clip`: Forced audio clipping, set to 0 to disable(default), setting it to a non-zero value (duration in seconds) to enable.
389
+
390
+ Optional parameters: see the next section
391
+ - `-lg` | `--linear_gradient`: The cross fade length of two audio slices in seconds. If there is a discontinuous voice after forced slicing, you can adjust this value. Otherwise, it is recommended to use the default value of 0.
392
+ - `-f0p` | `--f0_predictor`: Select an F0 predictor, options are `crepe`, `pm`, `dio`, `harvest`, `rmvpe`, `fcpe`, default value is `pm` (note: f0 mean pooling will be enabled when using `crepe`)
393
+ - `-a` | `--auto_predict_f0`: automatic pitch prediction, do not enable this when converting singing voices as it can cause serious pitch issues.
394
+ - `-cm` | `--cluster_model_path`: Cluster model or feature retrieval index path, if left blank, it will be automatically set as the default path of these models. If there is no training cluster or feature retrieval, fill in at will.
395
+ - `-cr` | `--cluster_infer_ratio`: The proportion of clustering scheme or feature retrieval ranges from 0 to 1. If there is no training clustering model or feature retrieval, the default is 0.
396
+ - `-eh` | `--enhance`: Whether to use NSF_HIFIGAN enhancer, this option has certain effect on sound quality enhancement for some models with few training sets, but has negative effect on well-trained models, so it is disabled by default.
397
+ - `-shd` | `--shallow_diffusion`: Whether to use shallow diffusion, which can solve some electrical sound problems after use. This option is disabled by default. When this option is enabled, NSF_HIFIGAN enhancer will be disabled
398
+ - `-usm` | `--use_spk_mix`: whether to use dynamic voice fusion
399
+ - `-lea` | `--loudness_envelope_adjustment`: The fusion ratio between the input source's loudness envelope and the output loudness envelope. The closer to 1, the more the output loudness envelope is used
400
+ - `-fr` | `--feature_retrieval`: Whether to use feature retrieval. If enabled, the clustering model will be disabled, and the `cm` and `cr` parameters will become the index path and mixing ratio of feature retrieval
401
+
402
+ Shallow diffusion settings:
403
+ - `-dm` | `--diffusion_model_path`: Diffusion model path
404
+ - `-dc` | `--diffusion_config_path`: Diffusion config file path
405
+ - `-ks` | `--k_step`: The larger the number of k_steps, the closer it is to the result of the diffusion model. The default is 100
406
+ - `-od` | `--only_diffusion`: Whether to use Only diffusion mode, which does not load the sovits model to only use diffusion model inference
407
+ - `-se` | `--second_encoding`: Whether to use second encoding, which involves applying an additional encoding to the original audio before shallow diffusion. This option can yield varying results - sometimes positive and sometimes negative.
408
+
409
+ ### Cautions
410
+
411
+ If inferencing using `whisper-ppg` speech encoder, you need to set `--clip` to 25 and `-lg` to 1. Otherwise it will fail to infer properly.
412
+
413
+ ## 🤔 Optional Settings
414
+
415
+ If you are satisfied with the previous results, or if you do not feel you understand what follows, you can skip it and it will have no effect on the use of the model. The impact of these optional settings mentioned is relatively small, and while they may have some impact on specific datasets, in most cases the difference may not be significant.
416
+
417
+ ### Automatic f0 prediction
418
+
419
+ During the training of the 4.0 model, an f0 predictor is also trained, which enables automatic pitch prediction during voice conversion. However, if the results are not satisfactory, manual pitch prediction can be used instead. Please note that when converting singing voices, it is advised not to enable this feature as it may cause significant pitch shifting.
420
+
421
+ - Set `auto_predict_f0` to `true` in `inference_main.py`.
422
+
423
+ ### Cluster-based timbre leakage control
424
+
425
+ Introduction: The clustering scheme implemented in this model aims to reduce timbre leakage and enhance the similarity of the trained model to the target's timbre, although the effect may not be very pronounced. However, relying solely on clustering can reduce the model's clarity and make it sound less distinct. Therefore, a fusion method is adopted in this model to control the balance between the clustering and non-clustering approaches. This allows manual adjustment of the trade-off between "sounding like the target's timbre" and "have clear enunciation" to find an optimal balance.
426
+
427
+ No changes are required in the existing steps. Simply train an additional clustering model, which incurs relatively low training costs.
428
+
429
+ - Training process:
430
+ - Train on a machine with good CPU performance. According to extant experience, it takes about 4 minutes to train each speaker on a Tencent Cloud machine with 6-core CPU.
431
+ - Execute `python cluster/train_cluster.py`. The output model will be saved in `logs/44k/kmeans_10000.pt`.
432
+ - The clustering model can currently be trained using the gpu by executing `python cluster/train_cluster.py --gpu`
433
+ - Inference process:
434
+ - Specify `cluster_model_path` in `inference_main.py`. If not specified, the default is `logs/44k/kmeans_10000.pt`.
435
+ - Specify `cluster_infer_ratio` in `inference_main.py`, where `0` means not using clustering at all, `1` means only using clustering, and usually `0.5` is sufficient.
436
+
437
+ ### Feature retrieval
438
+
439
+ Introduction: As with the clustering scheme, the timbre leakage can be reduced, the enunciation is slightly better than clustering, but it will reduce the inference speed. By employing the fusion method, it becomes possible to linearly control the balance between feature retrieval and non-feature retrieval, allowing for fine-tuning of the desired proportion.
440
+
441
+ - Training process:
442
+ First, it needs to be executed after generating hubert and f0:
443
+
444
+ ```shell
445
+ python train_index.py -c configs/config.json
446
+ ```
447
+
448
+ The output of the model will be in `logs/44k/feature_and_index.pkl`
449
+
450
+ - Inference process:
451
+ - The `--feature_retrieval` option needs to be specified first, and the clustering mode automatically switches to the feature retrieval mode.
452
+ - Specify `cluster_model_path` in `inference_main.py`. If not specified, the default is `logs/44k/feature_and_index.pkl`.
453
+ - Specify `cluster_infer_ratio` in `inference_main.py`, where `0` means not using feature retrieval at all, `1` means only using feature retrieval, and usually `0.5` is sufficient.
454
+
455
+ ## 🗜️ Model compression
456
+
457
+ The generated model contains data that is needed for further training. If you confirm that the model is final and will not be used in further training, it is safe to remove this data to get a smaller file size (about 1/3).
458
+
459
+ ```shell
460
+ # Example
461
+ python compress_model.py -c="configs/config.json" -i="logs/44k/G_30400.pth" -o="logs/44k/release.pth"
462
+ ```
463
+
464
+ ## 👨‍🔧 Timbre mixing
465
+
466
+ ### Static Tone Mixing
467
+
468
+ **Refer to `webUI.py` file for stable Timbre mixing of the gadget/lab feature.**
469
+
470
+ Introduction: This function can combine multiple models into one model (convex combination or linear combination of multiple model parameters) to create mixed voices that do not exist in reality
471
+
472
+ **Note:**
473
+ 1. This feature is only supported for single-speaker models
474
+ 2. If you force a multi-speaker model, it is critical to make sure there are the same number of speakers in each model. This will ensure that sounds with the same SpeakerID can be mixed correctly.
475
+ 3. Ensure that the `model` fields in config.json of all models to be mixed are the same
476
+ 4. The mixed model can use any config.json file from the models being synthesized. However, the clustering model will not be functional after mixed.
477
+ 5. When batch uploading models, it is best to put the models into a folder and upload them together after selecting them
478
+ 6. It is suggested to adjust the mixing ratio between 0 and 100, or to other numbers, but unknown effects will occur in the linear combination mode
479
+ 7. After mixing, the file named output.pth will be saved in the root directory of the project
480
+ 8. Convex combination mode will perform Softmax to add the mix ratio to 1, while linear combination mode will not
481
+
482
+ ### Dynamic timbre mixing
483
+
484
+ **Refer to the `spkmix.py` file for an introduction to dynamic timbre mixing**
485
+
486
+ Character mix track writing rules:
487
+
488
+ Role ID: \[\[Start time 1, end time 1, start value 1, end value 1], [Start time 2, end time 2, start value 2, end value 2]]
489
+
490
+ The start time must be the same as the end time of the previous one. The first start time must be 0, and the last end time must be 1 (time ranges from 0 to 1).
491
+
492
+ All roles must be filled in. For unused roles, fill \[\[0., 1., 0., 0.]]
493
+
494
+ The fusion value can be filled in arbitrarily; it changes linearly from the start value to the end value within the specified period of time.
495
+
496
+ The internal linear combination will be automatically normalized so that it sums to 1 (the convex combination condition), so it can be used safely.
497
+
498
+ Use the `--use_spk_mix` parameter during inference to enable dynamic timbre mixing
499
+
500
+ ## 📤 Exporting to Onnx
501
+
502
+ Use [onnx_export.py](https://github.com/svc-develop-team/so-vits-svc/blob/4.0/onnx_export.py)
503
+
504
+ - Create a folder named `checkpoints` and open it
505
+ - Create a folder in the `checkpoints` folder as your project folder, naming it after your project, for example `aziplayer`
506
+ - Rename your model as `model.pth`, the configuration file as `config.json`, and place them in the `aziplayer` folder you just created
507
+ - Modify `"NyaruTaffy"` in `path = "NyaruTaffy"` in [onnx_export.py](https://github.com/svc-develop-team/so-vits-svc/blob/4.0/onnx_export.py) to your project name, `path = "aziplayer"` (`onnx_export_speaker_mix` allows you to mix speakers' voices)
508
+ - Run [onnx_export.py](https://github.com/svc-develop-team/so-vits-svc/blob/4.0/onnx_export.py)
509
+ - Wait for it to finish running. A `model.onnx` will be generated in your project folder, which is the exported model.
510
+
511
+ Note: For Hubert Onnx models, please use the models provided by MoeSS. Currently, they cannot be exported on their own (Hubert in fairseq has many unsupported operators and things involving constants that can cause errors or result in problems with the input/output shape and results when exported.)
512
+
513
+
514
+ ## 📎 Reference
515
+
516
+ | URL | Designation | Title | Implementation Source |
517
+ | --- | ----------- | ----- | --------------------- |
518
+ |[2106.06103](https://arxiv.org/abs/2106.06103) | VITS (Synthesizer)| Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech | [jaywalnut310/vits](https://github.com/jaywalnut310/vits) |
519
+ |[2111.02392](https://arxiv.org/abs/2111.02392) | SoftVC (Speech Encoder)| A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion | [bshall/hubert](https://github.com/bshall/hubert) |
520
+ |[2204.09224](https://arxiv.org/abs/2204.09224) | ContentVec (Speech Encoder)| ContentVec: An Improved Self-Supervised Speech Representation by Disentangling Speakers | [auspicious3000/contentvec](https://github.com/auspicious3000/contentvec) |
521
+ |[2212.04356](https://arxiv.org/abs/2212.04356) | Whisper (Speech Encoder) | Robust Speech Recognition via Large-Scale Weak Supervision | [openai/whisper](https://github.com/openai/whisper) |
522
+ |[2110.13900](https://arxiv.org/abs/2110.13900) | WavLM (Speech Encoder) | WavLM: Large-Scale Self-Supervised Pre-Training for Full Stack Speech Processing | [microsoft/unilm/wavlm](https://github.com/microsoft/unilm/tree/master/wavlm) |
523
+ |[2305.17651](https://arxiv.org/abs/2305.17651) | DPHubert (Speech Encoder) | DPHuBERT: Joint Distillation and Pruning of Self-Supervised Speech Models | [pyf98/DPHuBERT](https://github.com/pyf98/DPHuBERT) |
524
+ |[DOI:10.21437/Interspeech.2017-68](http://dx.doi.org/10.21437/Interspeech.2017-68) | Harvest (F0 Predictor) | Harvest: A high-performance fundamental frequency estimator from speech signals | [mmorise/World/harvest](https://github.com/mmorise/World/blob/master/src/harvest.cpp) |
525
+ |[aes35-000039](https://www.aes.org/e-lib/online/browse.cfm?elib=15165) | Dio (F0 Predictor) | Fast and reliable F0 estimation method based on the period extraction of vocal fold vibration of singing voice and speech | [mmorise/World/dio](https://github.com/mmorise/World/blob/master/src/dio.cpp) |
526
+ |[8461329](https://ieeexplore.ieee.org/document/8461329) | Crepe (F0 Predictor) | Crepe: A Convolutional Representation for Pitch Estimation | [maxrmorrison/torchcrepe](https://github.com/maxrmorrison/torchcrepe) |
527
+ |[DOI:10.1016/j.wocn.2018.07.001](https://doi.org/10.1016/j.wocn.2018.07.001) | Parselmouth (F0 Predictor) | Introducing Parselmouth: A Python interface to Praat | [YannickJadoul/Parselmouth](https://github.com/YannickJadoul/Parselmouth) |
528
+ |[2306.15412v2](https://arxiv.org/abs/2306.15412v2) | RMVPE (F0 Predictor) | RMVPE: A Robust Model for Vocal Pitch Estimation in Polyphonic Music | [Dream-High/RMVPE](https://github.com/Dream-High/RMVPE) |
529
+ |[2010.05646](https://arxiv.org/abs/2010.05646) | HIFIGAN (Vocoder) | HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis | [jik876/hifi-gan](https://github.com/jik876/hifi-gan) |
530
+ |[1810.11946](https://arxiv.org/abs/1810.11946.pdf) | NSF (Vocoder) | Neural source-filter-based waveform model for statistical parametric speech synthesis | [openvpi/DiffSinger/modules/nsf_hifigan](https://github.com/openvpi/DiffSinger/tree/refactor/modules/nsf_hifigan)
531
+ |[2006.08195](https://arxiv.org/abs/2006.08195) | Snake (Vocoder) | Neural Networks Fail to Learn Periodic Functions and How to Fix It | [EdwardDixon/snake](https://github.com/EdwardDixon/snake)
532
+ |[2105.02446v3](https://arxiv.org/abs/2105.02446v3) | Shallow Diffusion (PostProcessing)| DiffSinger: Singing Voice Synthesis via Shallow Diffusion Mechanism | [CNChTu/Diffusion-SVC](https://github.com/CNChTu/Diffusion-SVC) |
533
+ |[K-means](https://citeseerx.ist.psu.edu/viewdoc/download;jsessionid=01D65490BADCC216F350D06F84D721AD?doi=10.1.1.308.8619&rep=rep1&type=pdf) | Feature K-means Clustering (PreProcessing)| Some methods for classification and analysis of multivariate observations | This repo |
534
+ | | Feature TopK Retrieval (PreProcessing)| Retrieval based Voice Conversion | [RVC-Project/Retrieval-based-Voice-Conversion-WebUI](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI) |
535
+ | | whisper ppg| whisper ppg | [PlayVoice/whisper_ppg](https://github.com/PlayVoice/whisper_ppg) |
536
+ | | bigvgan| bigvgan | [PlayVoice/so-vits-svc-5.0](https://github.com/PlayVoice/so-vits-svc-5.0/tree/bigvgan-mix-v2/vits_decoder/alias) |
537
+
538
+
539
+ ## ☀️ Previous contributors
540
+
541
+ For some reason, the author deleted the original repository. Due to an oversight by the organization members, all files were directly re-uploaded when this repository was rebuilt, which cleared the contributor list. A list of previous contributors has therefore been added to README.md.
542
+
543
+ *Some members are not listed, in accordance with their personal wishes.*
544
+
545
+ <table>
546
+ <tr>
547
+ <td align="center"><a href="https://github.com/MistEO"><img src="https://avatars.githubusercontent.com/u/18511905?v=4" width="100px;" alt=""/><br /><sub><b>MistEO</b></sub></a><br /></td>
548
+ <td align="center"><a href="https://github.com/XiaoMiku01"><img src="https://avatars.githubusercontent.com/u/54094119?v=4" width="100px;" alt=""/><br /><sub><b>XiaoMiku01</b></sub></a><br /></td>
549
+ <td align="center"><a href="https://github.com/ForsakenRei"><img src="https://avatars.githubusercontent.com/u/23041178?v=4" width="100px;" alt=""/><br /><sub><b>しぐれ</b></sub></a><br /></td>
550
+ <td align="center"><a href="https://github.com/TomoGaSukunai"><img src="https://avatars.githubusercontent.com/u/25863522?v=4" width="100px;" alt=""/><br /><sub><b>TomoGaSukunai</b></sub></a><br /></td>
551
+ <td align="center"><a href="https://github.com/Plachtaa"><img src="https://avatars.githubusercontent.com/u/112609742?v=4" width="100px;" alt=""/><br /><sub><b>Plachtaa</b></sub></a><br /></td>
552
+ <td align="center"><a href="https://github.com/zdxiaoda"><img src="https://avatars.githubusercontent.com/u/45501959?v=4" width="100px;" alt=""/><br /><sub><b>zd小达</b></sub></a><br /></td>
553
+ <td align="center"><a href="https://github.com/Archivoice"><img src="https://avatars.githubusercontent.com/u/107520869?v=4" width="100px;" alt=""/><br /><sub><b>凍聲響世</b></sub></a><br /></td>
554
+ </tr>
555
+ </table>
556
+
557
+ ## 📚 Some legal provisions for reference
558
+
559
+ #### Any country, region, organization, or individual using this project must comply with the following laws.
560
+
561
+ #### 《民法典》
562
+
563
+ ##### 第一千零一十九条
564
+
565
+ 任何组织或者个人不得以丑化、污损,或者利用信息技术手段伪造等方式侵害他人的肖像权。未经肖像权人同意,不得制作、使用、公开肖像权人的肖像,但是法律另有规定的除外。未经肖像权人同意,肖像作品权利人不得以发表、复制、发行、出租、展览等方式使用或者公开肖像权人的肖像。对自然人声音的保护,参照适用肖像权保护的有关规定。
566
+
567
+ ##### 第一千零二十四条
568
+
569
+ 【名誉权】民事主体享有名誉权。任何组织或者个人不得以侮辱、诽谤等方式侵害他人的名誉权。
570
+
571
+ ##### 第一千零二十七条
572
+
573
+ 【作品侵害名誉权】行为人发表的文学、艺术作品以真人真事或者特定人为描述对象,含有侮辱、诽谤内容,侵害他人名誉权的,受害人有权依法请求该行为人承担民事责任。行为人发表的文学、艺术作品不以特定人为描述对象,仅其中的情节与该特定人的情况相似的,不承担民事责任。
574
+
575
+ #### 《[中华人民共和国宪法](http://www.gov.cn/guoqing/2018-03/22/content_5276318.htm)》
576
+
577
+ #### 《[中华人民共和国刑法](http://gongbao.court.gov.cn/Details/f8e30d0689b23f57bfc782d21035c3.html?sw=%E4%B8%AD%E5%8D%8E%E4%BA%BA%E6%B0%91%E5%85%B1%E5%92%8C%E5%9B%BD%E5%88%91%E6%B3%95)》
578
+
579
+ #### 《[中华人民共和国民法典](http://gongbao.court.gov.cn/Details/51eb6750b8361f79be8f90d09bc202.html)》
580
+
581
+ #### 《[中华人民共和国合同法](http://www.npc.gov.cn/zgrdw/npc/lfzt/rlyw/2016-07/01/content_1992739.htm)》
582
+
583
+ ## 💪 Thanks to all contributors for their efforts
584
+ <a href="https://github.com/svc-develop-team/so-vits-svc/graphs/contributors" target="_blank">
585
+ <img src="https://contrib.rocks/image?repo=svc-develop-team/so-vits-svc" />
586
+ </a>
README_zh_CN.md ADDED
@@ -0,0 +1,575 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div align="center">
2
+ <img alt="LOGO" src="https://avatars.githubusercontent.com/u/127122328?s=400&u=5395a98a4f945a3a50cb0cc96c2747505d190dbc&v=4" width="300" height="300" />
3
+
4
+ # SoftVC VITS Singing Voice Conversion
5
+
6
+ [**English**](./README.md) | [**中文简体**](./README_zh_CN.md)
7
+
8
+ [![在Google Colab中打开](https://img.shields.io/badge/Colab-F9AB00?style=for-the-badge&logo=googlecolab&color=525252)](https://colab.research.google.com/github/svc-develop-team/so-vits-svc/blob/4.1-Stable/sovits4_for_colab.ipynb)
9
+ [![LICENSE](https://img.shields.io/badge/LICENSE-AGPL3.0-green.svg?style=for-the-badge)](https://github.com/svc-develop-team/so-vits-svc/blob/4.1-Stable/LICENSE)
10
+
11
+ 本轮限时更新即将结束,仓库将进入Archive状态,望周知
12
+
13
+ </div>
14
+
15
+
16
+ #### ✨ 带有 F0 曲线编辑器,角色混合时间轴编辑器的推理端 (Onnx 模型的用途): [MoeVoiceStudio](https://github.com/NaruseMioShirakana/MoeVoiceStudio)
17
+
18
+ #### ✨ 改善了交互的一个分支推荐: [34j/so-vits-svc-fork](https://github.com/34j/so-vits-svc-fork)
19
+
20
+ #### ✨ 支持实时转换的一个客户端: [w-okada/voice-changer](https://github.com/w-okada/voice-changer)
21
+
22
+ **本项目与 Vits 有着根本上的不同。Vits 是 TTS,本项目是 SVC。本项目无法实现 TTS,Vits 也无法实现 SVC,这两个项目的模型是完全不通用的。**
23
+
24
+ ## 重要通知
25
+
26
+ 这个项目是为了让开发者最喜欢的动画角色唱歌而开发的,任何涉及真人的东西都与开发者的意图背道而驰。
27
+
28
+ ## 声明
29
+
30
+ 本项目为开源、离线的项目,SvcDevelopTeam 的所有成员与本项目的所有开发者以及维护者(以下简称贡献者)对本项目没有控制力。本项目的贡献者从未向任何组织或个人提供包括但不限于数据集提取、数据集加工、算力支持、训练支持、推理等一切形式的帮助;本项目的贡献者不知晓也无法知晓使用者使用该项目的用途。故一切基于本项目训练的 AI 模型和合成的音频都与本项目贡献者无关。一切由此造成的问题由使用者自行承担。
31
+
32
+ 此项目完全离线运行,不能收集任何用户信息或获取用户输入数据。因此,这个项目的贡献者不知道所有的用户输入和模型,因此不负责任何用户输入。
33
+
34
+ 本项目只是一个框架项目,本身并没有语音合成的功能,所有的功能都需要用户自己训练模型。同时,这个项目没有任何模型,任何二次分发的项目都与这个项目的贡献者无关。
35
+
36
+ ## 📏 使用规约
37
+
38
+ # Warning:请自行解决数据集授权问题,禁止使用非授权数据集进行训练!任何由于使用非授权数据集进行训练造成的问题,需自行承担全部责任和后果!与仓库、仓库维护者、svc develop team 无关!
39
+
40
+ 1. 本项目是基于学术交流目的建立,仅供交流与学习使用,并非为生产环境准备。
41
+ 2. 任何发布到视频平台的基于 sovits 制作的视频,都必须要在简介明确指明用于变声器转换的输入源歌声、音频,例如:使用他人发布的视频 / 音频,通过分离的人声作为输入源进行转换的,必须要给出明确的原视频、音乐链接;若使用是自己的人声,或是使用其他歌声合成引擎合成的声音作为输入源进行转换的,也必须在简介加以说明。
42
+ 3. 由输入源造成的侵权问题需自行承担全部责任和一切后果。使用其他商用歌声合成软件作为输入源时,请确保遵守该软件的使用条例,注意,许多歌声合成引擎使用条例中明确指明不可用于输入源进行转换!
43
+ 4. 禁止使用该项目从事违法行为与宗教、政治等活动,该项目维护者坚决抵制上述行为,不同意此条则禁止使用该项目。
44
+ 5. 继续使用视为已同意本仓库 README 所述相关条例,本仓库 README 已进行劝导义务,不对后续可能存在问题负责。
45
+ 6. 如果将此项目用于任何其他企划,请提前联系并告知本仓库作者,十分感谢。
46
+
47
+ ## 📝 模型简介
48
+
49
+ 歌声音色转换模型,通过 SoftVC 内容编码器提取源音频语音特征,与 F0 同时输入 VITS 替换原本的文本输入达到歌声转换的效果。同时,更换声码器为 [NSF HiFiGAN](https://github.com/openvpi/DiffSinger/tree/refactor/modules/nsf_hifigan) 解决断音问题。
50
+
51
+ ### 🆕 4.1-Stable 版本更新内容
52
+
53
+ + 特征输入更换为 [Content Vec](https://github.com/auspicious3000/contentvec) 的第 12 层 Transformer 输出,并兼容 4.0 分支
54
+ + 更新浅层扩散,可以使用浅层扩散模型提升音质
55
+ + 增加 whisper 语音编码器的支持
56
+ + 增加静态/动态声线融合
57
+ + 增加响度嵌入
58
+ + 增加特征检索,来自于 [RVC](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI)
59
+
60
+ ### 🆕 关于兼容 4.0 模型的问题
61
+
62
+ + 可通过修改 4.0 模型的 config.json 对 4.0 的模型进行支持,需要在 config.json 的 model 字段中添加 speech_encoder 字段,具体见下
63
+
64
+ ```
65
+ "model": {
66
+ .........
67
+ "ssl_dim": 256,
68
+ "n_speakers": 200,
69
+ "speech_encoder":"vec256l9"
70
+ }
71
+ ```
72
+
73
+ ### 🆕 关于浅扩散
74
+ ![Diagram](shadowdiffusion.png)
75
+
76
+ ## 💬 关于 Python 版本问题
77
+
78
+ 在进行测试后,我们认为`Python 3.8.9`能够稳定地运行该项目
79
+
80
+ ## 📥 预先下载的模型文件
81
+
82
+ #### **必须项**
83
+
84
+ **以下编码器需要选择一个使用**
85
+
86
+ ##### **1. 若使用 contentvec 作为声音编码器(推荐)**
87
+
88
+ `vec768l12`与`vec256l9` 需要该编码器
89
+
90
+ + contentvec :[checkpoint_best_legacy_500.pt](https://ibm.box.com/s/z1wgl1stco8ffooyatzdwsqn2psd9lrr)
91
+ + 放在`pretrain`目录下
92
+
93
+ 或者下载下面的 ContentVec,大小只有 199MB,但效果相同:
94
+ + contentvec :[hubert_base.pt](https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/hubert_base.pt)
95
+ + 将文件名改为`checkpoint_best_legacy_500.pt`后,放在`pretrain`目录下
96
+
97
+ ```shell
98
+ # contentvec
99
+ wget -O pretrain/checkpoint_best_legacy_500.pt https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/hubert_base.pt
100
+ # 也可手动下载放在 pretrain 目录
101
+ ```
102
+
103
+ ##### **2. 若使用 hubertsoft 作为声音编码器**
104
+ + soft vc hubert:[hubert-soft-0d54a1f4.pt](https://github.com/bshall/hubert/releases/download/v0.1/hubert-soft-0d54a1f4.pt)
105
+ + 放在`pretrain`目录下
106
+
107
+ ##### **3. 若使用 Whisper-ppg 作为声音编码器**
108
+ + 下载模型 [medium.pt](https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt), 该模型适配`whisper-ppg`
109
+ + 下载模型 [large-v2.pt](https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt), 该模型适配`whisper-ppg-large`
110
+ + 放在`pretrain`目录下
111
+
112
+ ##### **4. 若使用 cnhubertlarge 作为声音编码器**
113
+ + 下载模型 [chinese-hubert-large-fairseq-ckpt.pt](https://huggingface.co/TencentGameMate/chinese-hubert-large/resolve/main/chinese-hubert-large-fairseq-ckpt.pt)
114
+ + 放在`pretrain`目录下
115
+
116
+ ##### **5. 若使用 dphubert 作为声音编码器**
117
+ + 下载模型 [DPHuBERT-sp0.75.pth](https://huggingface.co/pyf98/DPHuBERT/resolve/main/DPHuBERT-sp0.75.pth)
118
+ + 放在`pretrain`目录下
119
+
120
+ ##### **6. 若使用 WavLM 作为声音编码器**
121
+ + 下载模型 [WavLM-Base+.pt](https://valle.blob.core.windows.net/share/wavlm/WavLM-Base+.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D), 该模型适配`wavlmbase+`
122
+ + 放在`pretrain`目录下
123
+
124
+ ##### **7. 若使用 OnnxHubert/ContentVec 作为声音编码器**
125
+ + 下载模型 [MoeSS-SUBModel](https://huggingface.co/NaruseMioShirakana/MoeSS-SUBModel/tree/main)
126
+ + 放在`pretrain`目录下
127
+
128
+ #### **编码器列表**
129
+ - "vec768l12"
130
+ - "vec256l9"
131
+ - "vec256l9-onnx"
132
+ - "vec256l12-onnx"
133
+ - "vec768l9-onnx"
134
+ - "vec768l12-onnx"
135
+ - "hubertsoft-onnx"
136
+ - "hubertsoft"
137
+ - "whisper-ppg"
138
+ - "cnhubertlarge"
139
+ - "dphubert"
140
+ - "whisper-ppg-large"
141
+ - "wavlmbase+"
142
+
143
+ #### **可选项(强烈建议使用)**
144
+
145
+ + 预训练底模文件: `G_0.pth` `D_0.pth`
146
+ + 放在`logs/44k`目录下
147
+
148
+ + 扩散模型预训练底模文件: `model_0.pt`
149
+ + 放在`logs/44k/diffusion`目录下
150
+
151
+ 从 svc-develop-team(待定)或任何其他地方获取 Sovits 底模
152
+
153
+ 扩散模型引用了 [Diffusion-SVC](https://github.com/CNChTu/Diffusion-SVC) 的 Diffusion Model,底模与 [Diffusion-SVC](https://github.com/CNChTu/Diffusion-SVC) 的扩散模型底模通用,可以去 [Diffusion-SVC](https://github.com/CNChTu/Diffusion-SVC) 获取扩散模型的底模
154
+
155
+ 虽然底模一般不会引起什么版权问题,但还是请注意一下,比如事先询问作者,又或者作者在模型描述中明确写明了可行的用途
156
+
157
+ #### **可选项(根据情况选择)**
158
+
159
+ ##### NSF-HIFIGAN
160
+
161
+ 如果使用`NSF-HIFIGAN 增强器`或`浅层扩散`的话,需要下载预训练的 NSF-HIFIGAN 模型,如果不需要可以不下载
162
+
163
+ + 预训练的 NSF-HIFIGAN 声码器 :[nsf_hifigan_20221211.zip](https://github.com/openvpi/vocoders/releases/download/nsf-hifigan-v1/nsf_hifigan_20221211.zip)
164
+ + 解压后,将四个文件放在`pretrain/nsf_hifigan`目录下
165
+
166
+ ```shell
167
+ # nsf_hifigan
168
+ wget -P pretrain/ https://github.com/openvpi/vocoders/releases/download/nsf-hifigan-v1/nsf_hifigan_20221211.zip
169
+ unzip -od pretrain/nsf_hifigan pretrain/nsf_hifigan_20221211.zip
170
+ # 也可手动下载放在 pretrain/nsf_hifigan 目录
171
+ # 地址:https://github.com/openvpi/vocoders/releases/tag/nsf-hifigan-v1
172
+ ```
173
+
174
+ ##### RMVPE
175
+
176
+ 如果使用`rmvpe`F0预测器的话,需要下载预训练的 RMVPE 模型
177
+
178
+ + 下载模型[rmvpe.zip](https://github.com/yxlllc/RMVPE/releases/download/230917/rmvpe.zip),目前首推该权重。
179
+ + 解压缩`rmvpe.zip`,并将其中的`model.pt`文件改名为`rmvpe.pt`并放在`pretrain`目录下
180
+
181
+ + ~~下载模型 [rmvpe.pt](https://huggingface.co/datasets/ylzz1997/rmvpe_pretrain_model/resolve/main/rmvpe.pt)~~
182
+ + ~~放在`pretrain`目录下~~
183
+
184
+ ##### FCPE(预览版)
185
+
186
+ > 你说的对,但是[FCPE](https://github.com/CNChTu/MelPE)是由svc-develop-team自主研发的一款全新的F0预测器,后面忘了
187
+
188
+ [FCPE(Fast Context-base Pitch Estimator)](https://github.com/CNChTu/MelPE)是一个为实时语音转换所设计的专用F0预测器,他将在未来成为Sovits实时语音转换的首选F0预测器.(论文未来会有的)
189
+
190
+ 如果使用 `fcpe` F0预测器的话,需要下载预训练的 FCPE 模型
191
+
192
+ + 下载模型 [fcpe.pt](https://huggingface.co/datasets/ylzz1997/rmvpe_pretrain_model/resolve/main/fcpe.pt)
193
+ + 放在`pretrain`目录下
194
+
195
+
196
+ ## 📊 数据集准备
197
+
198
+ 仅需要以以下文件结构将数据集放入 dataset_raw 目录即可。
199
+
200
+ ```
201
+ dataset_raw
202
+ ├───speaker0
203
+ │ ├───xxx1-xxx1.wav
204
+ │ ├───...
205
+ │ └───Lxx-0xx8.wav
206
+ └───speaker1
207
+ ├───xx2-0xxx2.wav
208
+ ├───...
209
+ └───xxx7-xxx007.wav
210
+ ```
211
+ 对于每一个音频文件的名称并没有格式的限制(`000001.wav`~`999999.wav`之类的命名方式也是合法的),不过文件类型必须是`wav`。
212
+
213
+ 可以自定义说话人名称
214
+
215
+ ```
216
+ dataset_raw
217
+ └───suijiSUI
218
+ ├───1.wav
219
+ ├───...
220
+ └───25788785-20221210-200143-856_01_(Vocals)_0_0.wav
221
+ ```
222
+
223
+ ## 🛠️ 数据预处理
224
+
225
+ ### 0. 音频切片
226
+
227
+ 将音频切片至`5s - 15s`, 稍微长点也无伤大雅,实在太长可能会导致训练中途甚至预处理就爆显存
228
+
229
+ 可以使用 [audio-slicer-GUI](https://github.com/flutydeer/audio-slicer)、[audio-slicer-CLI](https://github.com/openvpi/audio-slicer)
230
+
231
+ 一般情况下只需调整其中的`Minimum Interval`,普通陈述素材通常保持默认即可,歌唱素材可以调整至`100`甚至`50`
232
+
233
+ 切完之后手动删除过长过短的音频
234
+
235
+ **如果你使用 Whisper-ppg 声音编码器进行训练,所有的切片长度必须小于 30s**
236
+
237
+ ### 1. 重采样至 44100Hz 单声道
238
+
239
+ ```shell
240
+ python resample.py
241
+ ```
242
+
243
+ #### 注意
244
+
245
+ 虽然本项目拥有重采样、转换单声道与响度匹配的脚本 resample.py,但是默认的响度匹配是匹配到 0db。这可能会造成音质的受损。而 python 的响度匹配包 pyloudnorm 无法对电平进行压限,这会导致爆音。所以建议可以考虑使用专业声音处理软件如`adobe audition`等软件做响度匹配处理。若已经使用其他软件做响度匹配,可以在运行上述命令时添加`--skip_loudnorm`跳过响度匹配步骤。如:
246
+
247
+ ```shell
248
+ python resample.py --skip_loudnorm
249
+ ```
250
+
251
+ ### 2. 自动划分训练集、验证集,以及自动生成配置文件
252
+
253
+ ```shell
254
+ python preprocess_flist_config.py --speech_encoder vec768l12
255
+ ```
256
+
257
+ speech_encoder 拥有以下选择
258
+
259
+ ```
260
+ vec768l12
261
+ vec256l9
262
+ hubertsoft
263
+ whisper-ppg
264
+ whisper-ppg-large
265
+ cnhubertlarge
266
+ dphubert
267
+ wavlmbase+
268
+ ```
269
+
270
+ 如果省略 speech_encoder 参数,默认值为 vec768l12
271
+
272
+ **使用响度嵌入**
273
+
274
+ 若使用响度嵌入,需要增加`--vol_aug`参数,比如:
275
+
276
+ ```shell
277
+ python preprocess_flist_config.py --speech_encoder vec768l12 --vol_aug
278
+ ```
279
+ 使用后训练出的模型将匹配到输入源响度,否则为训练集响度。
280
+
281
+ #### 此时可以在生成的 config.json 与 diffusion.yaml 修改部分参数
282
+
283
+ ##### config.json
284
+
285
+ * `keep_ckpts`:训练时保留最后几个模型,`0`为保留所有,默认只保留最后`3`个
286
+
287
+ * `all_in_mem`:加载所有数据集到内存中,某些平台的硬盘 IO 过于低下、同时内存容量 **远大于** 数据集体积时可以启用
288
+
289
+ * `batch_size`:单次训练加载到 GPU 的数据量,调整到低于显存容量的大小即可
290
+
291
+ * `vocoder_name` : 选择一种声码器,默认为`nsf-hifigan`.
292
+
293
+ ##### diffusion.yaml
294
+
295
+ * `cache_all_data`:加载所有数据集到内存中,某些平台的硬盘 IO 过于低下、同时内存容量 **远大于** 数据集体积时可以启用
296
+
297
+ * `duration`:训练时音频切片时长,可根据显存大小调整,**注意,该值必须小于训练集内音频的最短时间!**
298
+
299
+ * `batch_size`:单次训练加载到 GPU 的数据量,调整到低于显存容量的大小即可
300
+
301
+ * `timesteps` : 扩散模型总步数,默认为 1000.
302
+
303
+ * `k_step_max` : 训练时可仅训练`k_step_max`步扩散以节约训练时间,注意,该值必须小于`timesteps`,0 为训练整个扩散模型,**注意,如果不训练整个扩散模型将无法使用仅扩散模型推理!**
304
+
305
+ ##### **声码器列表**
306
+
307
+ ```
308
+ nsf-hifigan
309
+ nsf-snake-hifigan
310
+ ```
311
+
312
+ ### 3. 生成 hubert 与 f0
313
+
314
+ ```shell
315
+ python preprocess_hubert_f0.py --f0_predictor dio
316
+ ```
317
+
318
+ f0_predictor 拥有以下选择
319
+
320
+ ```
321
+ crepe
322
+ dio
323
+ pm
324
+ harvest
325
+ rmvpe
326
+ fcpe
327
+ ```
328
+
329
+ 如果训练集过于嘈杂,请使用 crepe 处理 f0
330
+
331
+ 如果省略 f0_predictor 参数,默认值为 rmvpe
332
+
333
+ 倘若需要浅扩散功能(可选),需要增加--use_diff 参数,比如
334
+
335
+ ```shell
336
+ python preprocess_hubert_f0.py --f0_predictor dio --use_diff
337
+ ```
338
+
339
+ **加速预处理**
340
+ 如若您的数据集比较大,可以尝试添加`--num_processes`参数:
341
+ ```shell
342
+ python preprocess_hubert_f0.py --f0_predictor dio --use_diff --num_processes 8
343
+ ```
344
+ 所有的Workers会被自动分配到多个线程上
345
+
346
+ 执行完以上步骤后 dataset 目录便是预处理完成的数据,可以删除 dataset_raw 文件夹了
347
+
348
+ ## 🏋️‍ 训练
349
+
350
+ ### 主模型训练
351
+
352
+ ```shell
353
+ python train.py -c configs/config.json -m 44k
354
+ ```
355
+
356
+ ### 扩散模型(可选)
357
+
358
+ 倘若需要浅扩散功能,需要训练扩散模型,扩散模型训练方法为:
359
+
360
+ ```shell
361
+ python train_diff.py -c configs/diffusion.yaml
362
+ ```
363
+
364
+ 模型训练结束后,模型文件保存在`logs/44k`目录下,扩散模型在`logs/44k/diffusion`下
365
+
366
+ ## 🤖 推理
367
+
368
+ 使用 [inference_main.py](inference_main.py)
369
+
370
+ ```shell
371
+ # 例
372
+ python inference_main.py -m "logs/44k/G_30400.pth" -c "configs/config.json" -n "君の知らない物語-src.wav" -t 0 -s "nen"
373
+ ```
374
+
375
+ 必填项部分:
376
+ + `-m` | `--model_path`:模型路径
377
+ + `-c` | `--config_path`:配置文件路径
378
+ + `-n` | `--clean_names`:wav 文件名列表,放在 raw 文件夹下
379
+ + `-t` | `--trans`:音高调整,支持正负(半音)
380
+ + `-s` | `--spk_list`:合成目标说话人名称
381
+ + `-cl` | `--clip`:音频强制切片,默认 0 为自动切片,单位为秒/s
382
+
383
+ 可选项部分:部分具体见下一节
384
+ + `-lg` | `--linear_gradient`:两段音频切片的交叉淡入长度,如果强制切片后出现人声不连贯可调整该数值,如果连贯建议采用默认值 0,单位为秒
385
+ + `-f0p` | `--f0_predictor`:选择 F0 预测器,可选择 crepe,pm,dio,harvest,rmvpe,fcpe, 默认为 pm(注意:crepe 为原 F0 使用均值滤波器)
386
+ + `-a` | `--auto_predict_f0`:语音转换自动预测音高,转换歌声时不要打开这个,会严重跑调
387
+ + `-cm` | `--cluster_model_path`:聚类模型或特征检索索引路径,留空则自动设为各方案模型的默认路径,如果没有训练聚类或特征检索则随便填
388
+ + `-cr` | `--cluster_infer_ratio`:聚类方案或特征检索占比,范围 0-1,若没有训练聚类模型或特征检索则默认 0 即可
389
+ + `-eh` | `--enhance`:是否使用 NSF_HIFIGAN 增强器,该选项对部分训练集少的模型有一定的音质增强效果,但是对训练好的模型有反面效果,默认关闭
390
+ + `-shd` | `--shallow_diffusion`:是否使用浅层扩散,使用后可解决一部分电音问题,默认关闭,该选项打开时,NSF_HIFIGAN 增强器将会被禁止
391
+ + `-usm` | `--use_spk_mix`:是否使用角色融合/动态声线融合
392
+ + `-lea` | `--loudness_envelope_adjustment`:输入源响度包络替换输出响度包络融合比例,越靠近 1 越使用输出响度包络
393
+ + `-fr` | `--feature_retrieval`:是否使用特征检索,如果使用聚类模型将被禁用,且 cm 与 cr 参数将会变成特征检索的索引路径与混合比例
394
+
395
+ 浅扩散设置:
396
+ + `-dm` | `--diffusion_model_path`:扩散模型路径
397
+ + `-dc` | `--diffusion_config_path`:扩散模型配置文件路径
398
+ + `-ks` | `--k_step`:扩散步数,越大越接近扩散模型的结果,默认 100
399
+ + `-od` | `--only_diffusion`:纯扩散模式,该模式不会加载 sovits 模型,以扩散模型推理
400
+ + `-se` | `--second_encoding`:二次编码,浅扩散前会对原始音频进行二次编码,玄学选项,有时候效果好,有时候效果差
401
+
402
+ ### 注意!
403
+
404
+ 如果使用`whisper-ppg` 声音编码器进行推理,需要将`--clip`设置为 25,`-lg`设置为 1。否则将无法正常推理。
405
+
406
+ ## 🤔 可选项
407
+
408
+ 如果前面的效果已经满意,或者没看明白下面在讲啥,那后面的内容都可以忽略,不影响模型使用(这些可选项影响比较小,可能在某些特定数据上有点效果,但大部分情况似乎都感知不太明显)
409
+
410
+ ### 自动 f0 预测
411
+
412
+ 4.0 模型训练过程会训练一个 f0 预测器,对于语音转换可以开启自动音高预测,如果效果不好也可以使用手动的,但转换歌声时请不要启用此功能!!!会严重跑调!!
413
+ + 在 inference_main 中设置 auto_predict_f0 为 true 即可
414
+
415
+ ### 聚类音色泄漏控制
416
+
417
+ 介绍:聚类方案可以减小音色泄漏,使得模型训练出来更像目标的音色(但其实不是特别明显),但是单纯的聚类方案会降低模型的咬字(会口齿不清)(这个很明显),本模型采用了融合的方式,可以线性控制聚类方案与非聚类方案的占比,也就是可以手动在"像目标音色" 和 "咬字清晰" 之间调整比例,找到合适的折中点
418
+
419
+ 使用聚类前面的已有步骤不用进行任何的变动,只需要额外训练一个聚类模型,虽然效果比较有限,但训练成本也比较低
420
+
421
+ + 训练过程:
422
+ + 使用 cpu 性能较好的机器训练,据我的经验在腾讯云 6 核 cpu 训练每个 speaker 需要约 4 分钟即可完成训练
423
+ + 执行`python cluster/train_cluster.py`,模型的输出会在`logs/44k/kmeans_10000.pt`
424
+ + 聚类模型目前可以使用 gpu 进行训练,执行`python cluster/train_cluster.py --gpu`
425
+ + 推理过程:
426
+ + `inference_main.py`中指定`cluster_model_path` 为模型输出文件,留空则默认为`logs/44k/kmeans_10000.pt`
427
+ + `inference_main.py`中指定`cluster_infer_ratio`,`0`为完全不使用聚类,`1`为只使用聚类,通常设置`0.5`即可
428
+
429
+ ### 特征检索
430
+
431
+ 介绍:跟聚类方案一样可以减小音色泄漏,咬字比聚类稍好,但会降低推理速度,采用了融合的方式,可以线性控制特征检索与非特征检索的占比。
432
+
433
+ + 训练过程:
434
+ 首先需要在生成 hubert 与 f0 后执行:
435
+
436
+ ```shell
437
+ python train_index.py -c configs/config.json
438
+ ```
439
+
440
+ 模型的输出会在`logs/44k/feature_and_index.pkl`
441
+
442
+ + 推理过程:
443
+ + 需要首先指定`--feature_retrieval`,此时聚类方案会自动切换到特征检索方案
444
+ + `inference_main.py`中指定`cluster_model_path` 为模型输出文件,留空则默认为`logs/44k/feature_and_index.pkl`
445
+ + `inference_main.py`中指定`cluster_infer_ratio`,`0`为完全不使用特征检索,`1`为只使用特征检索,通常设置`0.5`即可
446
+
447
+
448
+ ## 🗜️ 模型压缩
449
+
450
+ 生成的模型含有继续训练所需的信息。如果确认不再训练,可以移除模型中此部分信息,得到约 1/3 大小的最终模型。
451
+
452
+ 使用 [compress_model.py](compress_model.py)
453
+
454
+ ```shell
455
+ # 例
456
+ python compress_model.py -c="configs/config.json" -i="logs/44k/G_30400.pth" -o="logs/44k/release.pth"
457
+ ```
458
+
459
+ ## 👨‍🔧 声线混合
460
+
461
+ ### 静态声线混合
462
+
463
+ **参考`webUI.py`文件中,小工具/实验室特性的静态声线融合。**
464
+
465
+ 介绍:该功能可以将多个声音模型合成为一个声音模型(多个模型参数的凸组合或线性组合),从而制造出现实中不存在的声线
466
+ **注意:**
467
+
468
+ 1. 该功能仅支持单说话人的模型
469
+ 2. 如果强行使用多说话人模型,需要保证多个模型的说话人数量相同,这样可以混合同一个 SpeakerID 下的声音
470
+ 3. 保证所有待混合模型的 config.json 中的 model 字段是相同的
471
+ 4. 输出的混合模型可以使用待合成模型的任意一个 config.json,但聚类模型将不能使用
472
+ 5. 批量上传模型的时候最好把模型放到一个文件夹选中后一起上传
473
+ 6. 混合比例调整建议大小在 0-100 之间,也可以调为其他数字,但在线性组合模式下会出现未知的效果
474
+ 7. 混合完毕后,文件将会保存在项目根目录中,文件名为 output.pth
475
+ 8. 凸组合模式会将混合比例执行 Softmax 使混合比例相加为 1,而线性组合模式不会
476
+
477
+ ### 动态声线混合
478
+
479
+ **参考`spkmix.py`文件中关于动态声线混合的介绍**
480
+
481
+ 角色混合轨道 编写规则:
482
+
483
+ 角色 ID : \[\[起始时间 1, 终止时间 1, 起始数值 1, 终止数值 1], [起始时间 2, 终止时间 2, 起始数值 2, 终止数值 2]]
484
+
485
+ 起始时间和前一个的终止时间必须相同,第一个起始时间必须为 0,最后一个终止时间必须为 1 (时间的范围为 0-1)
486
+
487
+ 全部角色必须填写,不使用的角色填 [\[0., 1., 0., 0.]] 即可
488
+
489
+ 融合数值可以随便填,在指定的时间段内从起始数值线性变化为终止数值,内部会自动确保线性组合为 1(凸组合条件),可以放心使用
490
+
491
+ 推理的时候使用`--use_spk_mix`参数即可启用动态声线混合
492
+
493
+ ## 📤 Onnx 导出
494
+
495
+ 使用 [onnx_export.py](onnx_export.py)
496
+
497
+ + 新建文件夹:`checkpoints` 并打开
498
+ + 在`checkpoints`文件夹中新建一个文件夹作为项目文件夹,文件夹名为你的项目名称,比如`aziplayer`
499
+ + 将你的模型更名为`model.pth`,配置文件更名为`config.json`,并放置到刚才创建的`aziplayer`文件夹下
500
+ + 将 [onnx_export.py](onnx_export.py) 中`path = "NyaruTaffy"` 的 `"NyaruTaffy"` 修改为你的项目名称,`path = "aziplayer" (onnx_export_speaker_mix,为支持角色混合的 onnx 导出)`
501
+ + 运行 [onnx_export.py](onnx_export.py)
502
+ + 等待执行完毕,在你的项目文件夹下会生成一个`model.onnx`,即为导出的模型
503
+
504
+ 注意:Hubert Onnx 模型请使用 MoeSS 提供的模型,目前无法自行导出(fairseq 中 Hubert 有不少 onnx 不支持的算子和涉及到常量的东西,在导出时会报错或者导出的模型输入输出 shape 和结果都有问题)
505
+
506
+ ## 📎 引用及论文
507
+
508
+ | URL | 名称 | 标题 | 源码 |
509
+ | --- | ----------- | ----- | --------------------- |
510
+ |[2106.06103](https://arxiv.org/abs/2106.06103) | VITS (Synthesizer)| Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech | [jaywalnut310/vits](https://github.com/jaywalnut310/vits) |
511
+ |[2111.02392](https://arxiv.org/abs/2111.02392) | SoftVC (Speech Encoder)| A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion | [bshall/hubert](https://github.com/bshall/hubert) |
512
+ |[2204.09224](https://arxiv.org/abs/2204.09224) | ContentVec (Speech Encoder)| ContentVec: An Improved Self-Supervised Speech Representation by Disentangling Speakers | [auspicious3000/contentvec](https://github.com/auspicious3000/contentvec) |
513
+ |[2212.04356](https://arxiv.org/abs/2212.04356) | Whisper (Speech Encoder) | Robust Speech Recognition via Large-Scale Weak Supervision | [openai/whisper](https://github.com/openai/whisper) |
514
+ |[2110.13900](https://arxiv.org/abs/2110.13900) | WavLM (Speech Encoder) | WavLM: Large-Scale Self-Supervised Pre-Training for Full Stack Speech Processing | [microsoft/unilm/wavlm](https://github.com/microsoft/unilm/tree/master/wavlm) |
515
+ |[2305.17651](https://arxiv.org/abs/2305.17651) | DPHubert (Speech Encoder) | DPHuBERT: Joint Distillation and Pruning of Self-Supervised Speech Models | [pyf98/DPHuBERT](https://github.com/pyf98/DPHuBERT) |
516
+ |[DOI:10.21437/Interspeech.2017-68](http://dx.doi.org/10.21437/Interspeech.2017-68) | Harvest (F0 Predictor) | Harvest: A high-performance fundamental frequency estimator from speech signals | [mmorise/World/harvest](https://github.com/mmorise/World/blob/master/src/harvest.cpp) |
517
+ |[aes35-000039](https://www.aes.org/e-lib/online/browse.cfm?elib=15165) | Dio (F0 Predictor) | Fast and reliable F0 estimation method based on the period extraction of vocal fold vibration of singing voice and speech | [mmorise/World/dio](https://github.com/mmorise/World/blob/master/src/dio.cpp) |
518
+ |[8461329](https://ieeexplore.ieee.org/document/8461329) | Crepe (F0 Predictor) | Crepe: A Convolutional Representation for Pitch Estimation | [maxrmorrison/torchcrepe](https://github.com/maxrmorrison/torchcrepe) |
519
+ |[DOI:10.1016/j.wocn.2018.07.001](https://doi.org/10.1016/j.wocn.2018.07.001) | Parselmouth (F0 Predictor) | Introducing Parselmouth: A Python interface to Praat | [YannickJadoul/Parselmouth](https://github.com/YannickJadoul/Parselmouth) |
520
+ |[2306.15412v2](https://arxiv.org/abs/2306.15412v2) | RMVPE (F0 Predictor) | RMVPE: A Robust Model for Vocal Pitch Estimation in Polyphonic Music | [Dream-High/RMVPE](https://github.com/Dream-High/RMVPE) |
521
+ |[2010.05646](https://arxiv.org/abs/2010.05646) | HIFIGAN (Vocoder) | HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis | [jik876/hifi-gan](https://github.com/jik876/hifi-gan) |
522
+ |[1810.11946](https://arxiv.org/abs/1810.11946.pdf) | NSF (Vocoder) | Neural source-filter-based waveform model for statistical parametric speech synthesis | [openvpi/DiffSinger/modules/nsf_hifigan](https://github.com/openvpi/DiffSinger/tree/refactor/modules/nsf_hifigan)
523
+ |[2006.08195](https://arxiv.org/abs/2006.08195) | Snake (Vocoder) | Neural Networks Fail to Learn Periodic Functions and How to Fix It | [EdwardDixon/snake](https://github.com/EdwardDixon/snake)
524
+ |[2105.02446v3](https://arxiv.org/abs/2105.02446v3) | Shallow Diffusion (PostProcessing)| DiffSinger: Singing Voice Synthesis via Shallow Diffusion Mechanism | [CNChTu/Diffusion-SVC](https://github.com/CNChTu/Diffusion-SVC) |
525
+ |[K-means](https://citeseerx.ist.psu.edu/viewdoc/download;jsessionid=01D65490BADCC216F350D06F84D721AD?doi=10.1.1.308.8619&rep=rep1&type=pdf) | Feature K-means Clustering (PreProcessing)| Some methods for classification and analysis of multivariate observations | 本代码库 |
526
+ | | Feature TopK Retrieval (PreProcessing)| Retrieval based Voice Conversion | [RVC-Project/Retrieval-based-Voice-Conversion-WebUI](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI) |
527
+
528
+ ## ☀️ 旧贡献者
529
+
530
+ 因为某些原因原作者进行了删库处理,本仓库重建之初由于组织成员疏忽直接重新上传了所有文件导致以前的 contributors 全部木大,现在在 README 里重新添加一个旧贡献者列表
531
+
532
+ *某些成员已根据其个人意愿不将其列出*
533
+
534
+ <table>
535
+ <tr>
536
+ <td align="center"><a href="https://github.com/MistEO"><img src="https://avatars.githubusercontent.com/u/18511905?v=4" width="100px;" alt=""/><br /><sub><b>MistEO</b></sub></a><br /></td>
537
+ <td align="center"><a href="https://github.com/XiaoMiku01"><img src="https://avatars.githubusercontent.com/u/54094119?v=4" width="100px;" alt=""/><br /><sub><b>XiaoMiku01</b></sub></a><br /></td>
538
+ <td align="center"><a href="https://github.com/ForsakenRei"><img src="https://avatars.githubusercontent.com/u/23041178?v=4" width="100px;" alt=""/><br /><sub><b>しぐれ</b></sub></a><br /></td>
539
+ <td align="center"><a href="https://github.com/TomoGaSukunai"><img src="https://avatars.githubusercontent.com/u/25863522?v=4" width="100px;" alt=""/><br /><sub><b>TomoGaSukunai</b></sub></a><br /></td>
540
+ <td align="center"><a href="https://github.com/Plachtaa"><img src="https://avatars.githubusercontent.com/u/112609742?v=4" width="100px;" alt=""/><br /><sub><b>Plachtaa</b></sub></a><br /></td>
541
+ <td align="center"><a href="https://github.com/zdxiaoda"><img src="https://avatars.githubusercontent.com/u/45501959?v=4" width="100px;" alt=""/><br /><sub><b>zd 小达</b></sub></a><br /></td>
542
+ <td align="center"><a href="https://github.com/Archivoice"><img src="https://avatars.githubusercontent.com/u/107520869?v=4" width="100px;" alt=""/><br /><sub><b>凍聲響世</b></sub></a><br /></td>
543
+ </tr>
544
+ </table>
545
+
546
+ ## 📚 一些法律条例参考
547
+
548
+ #### 任何国家,地区,组织和个人使用此项目必须遵守以下法律
549
+
550
+ #### 《民法典》
551
+
552
+ ##### 第一千零一十九条
553
+
554
+ 任何组织或者个人不得以丑化、污损,或者利用信息技术手段伪造等方式侵害他人的肖像权。未经肖像权人同意,不得制作、使用、公开肖像权人的肖像,但是法律另有规定的除外。未经肖像权人同意,肖像作品权利人不得以发表、复制、发行、出租、展览等方式使用或者公开肖像权人的肖像。对自然人声音的保护,参照适用肖像权保护的有关规定。
555
+
556
+ ##### 第一千零二十四条
557
+
558
+ 【名誉权】民事主体享有名誉权。任何组织或者个人不得以侮辱、诽谤等方式侵害他人的名誉权。
559
+
560
+ ##### 第一千零二十七条
561
+
562
+ 【作品侵害名誉权】行为人发表的文学、艺术作品以真人真事或者特定人为描述对象,含有侮辱、诽谤内容,侵害他人名誉权的,受害人有权依法请求该行为人承担民事责任。行为人发表的文学、艺术作品不以特定人为描述对象,仅其中的情节与该特定人的情况相似的,不承担民事责任。
563
+
564
+ #### 《[中华人民共和国宪法](http://www.gov.cn/guoqing/2018-03/22/content_5276318.htm)》
565
+
566
+ #### 《[中华人民共和国刑法](http://gongbao.court.gov.cn/Details/f8e30d0689b23f57bfc782d21035c3.html?sw=中华人民共和国刑法)》
567
+
568
+ #### 《[中华人民共和国民法典](http://gongbao.court.gov.cn/Details/51eb6750b8361f79be8f90d09bc202.html)》
569
+
570
+ #### 《[中华人民共和国合同法](http://www.npc.gov.cn/zgrdw/npc/lfzt/rlyw/2016-07/01/content_1992739.htm)》
571
+
572
+ ## 💪 感谢所有的贡献者
573
+ <a href="https://github.com/svc-develop-team/so-vits-svc/graphs/contributors" target="_blank">
574
+ <img src="https://contrib.rocks/image?repo=svc-develop-team/so-vits-svc" />
575
+ </a>
app.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import shutil
4
+ import subprocess
5
+ import sys
6
+ import tempfile
7
+ from functools import partial
8
+
9
+ import gradio as gr
10
+ import librosa
11
+ import numpy as np
12
+ import soundfile
13
+
14
+ from edgetts.tts_voices import SUPPORTED_LANGUAGES
15
+ from inference.infer_tool import Svc
16
+
17
+ MAXOCTAVE = 2
18
TEMPDIR = None


def generate_tempfile(suffix=None, prefix=None):
    """Create an empty temporary file inside ``TEMPDIR`` and return its path.

    ``tempfile.mkstemp`` returns an already-open OS-level file descriptor
    together with the path.  Callers in this file only use the path (they
    reopen the file by name), so the descriptor is closed right away;
    otherwise each call would leak one descriptor for the lifetime of the
    process.

    Args:
        suffix: Optional file name suffix (e.g. ``".wav"``).
        prefix: Optional file name prefix.

    Returns:
        Path of the newly created file.  The caller is responsible for
        deleting it.
    """
    fd, filepath = tempfile.mkstemp(suffix=suffix, prefix=prefix, dir=TEMPDIR)
    os.close(fd)
    return filepath
24
+
25
def find_sovits_model(dirpath):
    """Return the path of a ``*.pth`` checkpoint inside ``dirpath``.

    ``os.listdir`` yields entries in arbitrary order, so the names are sorted
    first: when several ``.pth`` files are present the alphabetically first
    one is chosen on every run instead of a platform-dependent one.

    Args:
        dirpath: Directory to search (not recursive).

    Returns:
        ``dirpath`` joined with the selected file name, or ``None`` when the
        directory contains no ``.pth`` file.
    """
    for filename in sorted(os.listdir(dirpath)):
        if filename.endswith(".pth"):
            return os.path.join(dirpath, filename)
    return None
30
+
31
def find_diffusion_model(dirpath):
    """Return the path of a ``model*.pt`` checkpoint inside ``dirpath``.

    Only files whose name starts with ``"model"`` and ends with ``".pt"``
    qualify, so other ``.pt`` files in the same directory (for example
    ``kmeans_10000.pt``) are ignored.  Names are sorted before scanning
    because ``os.listdir`` order is arbitrary; the alphabetically first match
    is returned on every run.

    Args:
        dirpath: Directory to search (not recursive).

    Returns:
        ``dirpath`` joined with the selected file name, or ``None`` when
        nothing matches.
    """
    for filename in sorted(os.listdir(dirpath)):
        if filename.startswith("model") and filename.endswith(".pt"):
            return os.path.join(dirpath, filename)
    return None
36
+
37
def find_static_file(dirpath, filename):
    """Return ``dirpath/filename`` if that path exists, otherwise ``None``."""
    candidate = os.path.join(dirpath, filename)
    if not os.path.exists(candidate):
        return None
    return candidate
40
+
41
def model_fn(modeldir, model, leakctrl, diffonly, enhancer):
    """(Re)load the voice conversion model stored in ``modeldir``.

    Args:
        modeldir: Directory holding the trained model files.
        model: Currently loaded ``Svc`` instance, or ``None``.  When given it
            is unloaded before the new one is created.
        leakctrl: Selected timbre leakage control method; the string
            ``"Feature retrieval"`` selects ``feature_and_index.pkl``, any
            other value selects ``kmeans_10000.pt``.
        diffonly: Forwarded to ``Svc`` as ``only_diffusion``.
        enhancer: Forwarded to ``Svc`` as ``nsf_hifigan_enhance``.

    Returns:
        A 4-tuple ``(model, button_label, status_message, speaker_dropdown)``.
    """
    # Release the previous model first so two copies are not alive at once.
    if model is not None:
        model.unload_model()

    # locate trained models (any of these may be None when the file is absent)
    sovits_model_path = find_sovits_model(modeldir)
    sovits_config_path = find_static_file(modeldir, "config.json")
    diffusion_model_path = find_diffusion_model(modeldir)
    diffusion_config_path = find_static_file(modeldir, "config.yaml")
    kmeans_model_path = find_static_file(modeldir, "kmeans_10000.pt")
    feature_index_path = find_static_file(modeldir, "feature_and_index.pkl")

    feature_retrieval = leakctrl == "Feature retrieval"
    cluster_model_path = feature_index_path if feature_retrieval else kmeans_model_path

    model = Svc(
        sovits_model_path,
        sovits_config_path,
        cluster_model_path=cluster_model_path,
        feature_retrieval=feature_retrieval,
        diffusion_model_path=diffusion_model_path,
        diffusion_config_path=diffusion_config_path,
        shallow_diffusion=True,
        only_diffusion=diffonly,
        nsf_hifigan_enhance=enhancer,
    )
    speakers = list(model.spk2id.keys())

    return (
        model,
        "Reload Model",
        f"Successfully loaded model into device {model.dev}",
        # An empty speaker map must not crash the UI with an IndexError.
        gr.Dropdown(choices=speakers, value=speakers[0] if speakers else None),
    )
75
+
76
def preset_fn(preset):
    """Return the default value of every inference option for ``preset``.

    ``"Singing"`` keeps the input pitch (F0 predictor ``"none"``) and mixes in
    timbre leakage control at 0.5; any other preset uses the ``"rmvpe"``
    predictor with leakage control disabled.  All remaining options are the
    same for both presets.

    Returns:
        A 14-tuple whose element order is given by the per-line comments.
    """
    singing = preset == "Singing"
    chosen_predictor = "none" if singing else "rmvpe"
    chosen_ratio = 0.5 if singing else 0
    return (
        chosen_predictor,  # f0_predictor
        0,                 # pitch_shift
        chosen_ratio,      # leakctrl_ratio
        100,               # diff_steps
        0.4,               # noise_scale
        0.5,               # silent_padding
        -40,               # db_threshold
        0,                 # auto_clip
        0,                 # clip_overlap
        0.75,              # cross_fade
        0,                 # adaptive_key
        0.05,              # crepe_f0
        0,                 # loudness_ratio
        False,             # reencode_audio
    )
93
+
94
def tts_fn(text, gender, lang, rate, volume):
    """Synthesize ``text`` by running ``edgetts/tts.py`` and return the audio.

    The helper script is run in a child process and writes a WAV file, which
    is then loaded and resampled to 44100 Hz.

    Args:
        text: Text to speak.
        gender: Voice gender, passed through to the script.
        lang: Language, passed through to the script.
        rate: Relative speed as a fraction (``0.5`` -> ``"+50%"``).
        volume: Relative volume as a fraction (``-0.2`` -> ``"-20%"``).

    Returns:
        ``(sample_rate, samples)`` with ``sample_rate == 44100``.

    Raises:
        subprocess.CalledProcessError: If the helper script exits with a
            non-zero status.
    """
    def to_percent(x):
        # Always emit an explicit sign, including for values that truncate
        # to zero (previously a small negative value produced "0%").
        return f"{int(x * 100):+d}%"

    rate = to_percent(rate)
    volume = to_percent(volume)

    outfile = generate_tempfile(suffix=".wav")
    try:
        # Argument list (no shell), so user text is never shell-interpreted.
        subprocess.run(
            [sys.executable, "edgetts/tts.py", text, lang, rate, volume, gender, outfile],
            check=True,
        )
        result, orig_sr = librosa.load(outfile)
    finally:
        # Remove the temporary file even when synthesis or loading fails.
        os.remove(outfile)

    target_sr = 44100
    resampled = librosa.resample(result, orig_sr=orig_sr, target_sr=target_sr)
    return target_sr, resampled
109
+
110
def inference_fn(
    model, speaker, input_audio,
    f0_predictor, pitch_shift, leakctrl_ratio, diff_steps, noise_scale,
    silent_padding, db_threshold, auto_clip, clip_overlap, cross_fade,
    adaptive_key, crepe_f0, loudness_ratio, reencode_audio,
):
    """Run voice conversion on ``input_audio`` and return the result file.

    Args:
        model: Loaded ``Svc`` instance, or ``None`` if nothing is loaded yet.
        speaker: Target speaker name.
        input_audio: ``(sample_rate, samples)`` pair, or ``None``.
        f0_predictor: F0 predictor name; ``"none"`` disables automatic F0
            prediction (``"crepe"`` is then passed as the predictor name).
        Remaining arguments are forwarded to ``model.slice_inference``.

    Returns:
        ``(message, filepath)``.  On a missing model or missing audio the
        message starts with ``"Error:"`` and ``filepath`` is ``None``;
        otherwise the message is ``"Success"`` and ``filepath`` points at the
        converted WAV file.
    """
    if model is None:
        return "Error: please load model first", None
    if input_audio is None:
        return "Error: please upload an audio", None

    sample_rate, audio = input_audio
    # Integer PCM is scaled by the dtype's maximum value into float32.
    if np.issubdtype(audio.dtype, np.integer):
        audio = (audio / np.iinfo(audio.dtype).max).astype(np.float32)
    # Multi-channel input arrives as (samples, channels); mix it down to mono.
    if len(audio.shape) > 1:
        audio = librosa.to_mono(audio.transpose(1, 0))

    infile = generate_tempfile(suffix=".wav")
    try:
        soundfile.write(infile, audio, sample_rate, format="wav")

        result = model.slice_inference(
            infile,
            speaker,
            pitch_shift,
            db_threshold,
            leakctrl_ratio,
            f0_predictor != "none",
            noise_scale,
            pad_seconds=silent_padding,
            clip_seconds=auto_clip,
            lg_num=clip_overlap,
            lgr_num=cross_fade,
            f0_predictor="crepe" if f0_predictor == "none" else f0_predictor,
            enhancer_adaptive_key=adaptive_key,
            cr_threshold=crepe_f0,
            k_step=diff_steps,
            use_spk_mix=False,
            second_encoding=reencode_audio,
            loudness_envelope_adjustment=loudness_ratio,
        )
        model.clear_empty()
    finally:
        # Do not leave the temporary input behind when inference fails.
        os.remove(infile)

    # gr.Audio force normalize the audio if supplied as a numpy array
    # we must write to a temporary file and return the filepath here
    prefix = f"{speaker}_{f0_predictor}_pitch{pitch_shift}_timbre{leakctrl_ratio}_diff{diff_steps}_"
    outfile = generate_tempfile(suffix=".wav", prefix=prefix)
    soundfile.write(outfile, result, model.target_sample, format="wav")
    return "Success", outfile
159
+
160
+ if __name__ == "__main__":
161
+ parser = argparse.ArgumentParser(description="so-vits-svc WebUI")
162
+ parser.add_argument("-m", "--model", default="./trained")
163
+ parser.add_argument("-t", "--temp", default="./workspace")
164
+ args = parser.parse_args()
165
+
166
+ shutil.rmtree(args.temp, ignore_errors=True)
167
+ os.makedirs(args.temp, exist_ok=True)
168
+ TEMPDIR = args.temp
169
+
170
+ with gr.Blocks() as app:
171
+
172
+ with gr.Row():
173
+ with gr.Column():
174
+ title = gr.Markdown(value="""# AI Sora Singing Voice Conversion""")
175
+ with gr.Column():
176
+ with gr.Accordion(label="About", open=False):
177
+ about = gr.Markdown(value="""Space by [KasugaiSakura](https://huggingface.co/KasugaiSakura)<br/>Based on a modified version of [so-vits-svc](https://github.com/meimisaki/so-vits-svc/tree/4.1-Stable)<br/>Voice copyright belongs to [CUFFS/Sphere](https://www.cuffs.co.jp/)""")
178
+
179
+ with gr.Row():
180
+ with gr.Column():
181
+ with gr.Accordion(label="Model setup", open=True):
182
+ leakctrl = gr.Radio(
183
+ label="Timbre leakage control method",
184
+ choices=["Feature retrieval", "K-means clustering"],
185
+ value="Feature retrieval",
186
+ )
187
+ diffonly = gr.Checkbox(label="Diffusion only mode")
188
+ enhancer = gr.Checkbox(label="NSF-HiFiGAN enhancer (not recommended)")
189
+ modelptr = gr.State(None)
190
+ modelbtn = gr.Button(value="Load Model", variant="primary")
191
+ modelmsg = gr.Textbox(label="Model info")
192
+ speaker = gr.Dropdown(label="Speaker", interactive=True)
193
+
194
+ with gr.Accordion(label="Text to speech", open=False):
195
+ tts_text = gr.Textbox(label="Text", placeholder="Enter text here")
196
+ tts_gender = gr.Radio(label="Gender", choices=["Male","Female"], value="Male")
197
+ tts_lang = gr.Dropdown(label="Language", choices=SUPPORTED_LANGUAGES, value="Auto")
198
+ tts_rate = gr.Slider(
199
+ label="Relative speed",
200
+ minimum=-1, maximum=3, value=0, step=0.1
201
+ )
202
+ tts_volume = gr.Slider(
203
+ label="Relative volume",
204
+ minimum=-1, maximum=1.5, value=0, step=0.1
205
+ )
206
+ tts_btn = gr.Button(value="Synthesize")
207
+
208
+ with gr.Accordion(label="Voice conversion", open=True):
209
+ input_audio = gr.Audio(label="Input audio", type="numpy")
210
+ inference_btn = gr.Button(value="Inference")
211
+ output_msg = gr.Textbox(label="Output message")
212
+ output_audio = gr.Audio(label="Output audio", type="filepath")
213
+
214
+ with gr.Column():
215
+ with gr.Accordion(label="Inference options", open=True):
216
+ inference_preset = gr.Radio(
217
+ label="Preset",
218
+ choices=["Singing", "Speaking"],
219
+ value="Singing",
220
+ interactive=True,
221
+ )
222
+ f0_predictor = gr.Dropdown(
223
+ label="F0 predictor",
224
+ choices=["none", "crepe", "dio", "harvest", "pm", "rmvpe"],
225
+ value="none",
226
+ )
227
+ pitch_shift = gr.Slider(
228
+ label="Pitch shift (in semitones, 12 in an octave)",
229
+ minimum=-12*MAXOCTAVE, maximum=12*MAXOCTAVE, value=0, step=1,
230
+ )
231
+ leakctrl_ratio = gr.Slider(
232
+ label="Timbre leakage control mix ratio (set to 0 to disable it)",
233
+ minimum=0, maximum=1, value=0.5, step=0.1,
234
+ )
235
+ diff_steps = gr.Slider(
236
+ label="Shallow diffusion steps",
237
+ minimum=0, maximum=1000, value=100, step=10,
238
+ )
239
+ noise_scale = gr.Slider(
240
+ label="Noise scale (try NOT to modify this parameter)",
241
+ minimum=0, maximum=1, value=0.4, step=0.01,
242
+ )
243
+ silent_padding = gr.Slider(
244
+ label="Add silent padding to workaround noise caused by unknown reason (in seconds)",
245
+ minimum=0, maximum=3, value=0.5, step=0.01,
246
+ )
247
+ db_threshold = gr.Slider(
248
+ label="Silence dB threshold (for slicing audio into chunks)",
249
+ minimum=-100, maximum=0, value=-40, step=1,
250
+ )
251
+ auto_clip = gr.Slider(
252
+ label="Apply auto clip to reduce memory consumption (in seconds)",
253
+ minimum=0, maximum=100, value=0, step=1,
254
+ )
255
+ clip_overlap = gr.Slider(
256
+ label="Overlap duration between auto clips (in seconds)",
257
+ minimum=0, maximum=3, value=0, step=0.01,
258
+ )
259
+ cross_fade = gr.Slider(
260
+ label="Cross fade ratio of overlapping regions",
261
+ minimum=0, maximum=1, value=0.75, step=0.01,
262
+ )
263
+ adaptive_key = gr.Slider(
264
+ label="Enhancer adaptive key (in semitones, 12 in an octave)",
265
+ minimum=-12*MAXOCTAVE, maximum=12*MAXOCTAVE, value=0, step=1,
266
+ )
267
+ crepe_f0 = gr.Slider(
268
+ label="CREPE F0 threshold (increase to reduce noise but may result in out-of-tune)",
269
+ minimum=0, maximum=1, value=0.05, step=0.01,
270
+ )
271
+ loudness_ratio = gr.Slider(
272
+ label="Loudness envelope mix ratio of input and output (0 is input and 1 is output)",
273
+ minimum=0, maximum=1, value=0, step=0.01,
274
+ )
275
+ reencode_audio = gr.Checkbox(
276
+ label="Re-encode audio before shallow diffusion, with unknown impact on final result"
277
+ )
278
+
279
+ modelbtn.click(
280
+ partial(model_fn, args.model),
281
+ inputs=[modelptr, leakctrl, diffonly, enhancer],
282
+ outputs=[modelptr, modelbtn, modelmsg, speaker],
283
+ )
284
+
285
+ inference_preset.change(
286
+ preset_fn,
287
+ inputs=[inference_preset],
288
+ outputs=[
289
+ f0_predictor, pitch_shift, leakctrl_ratio, diff_steps, noise_scale,
290
+ silent_padding, db_threshold, auto_clip, clip_overlap, cross_fade,
291
+ adaptive_key, crepe_f0, loudness_ratio, reencode_audio,
292
+ ],
293
+ )
294
+
295
+ tts_btn.click(
296
+ tts_fn,
297
+ inputs=[tts_text, tts_gender, tts_lang, tts_rate, tts_volume],
298
+ outputs=[input_audio],
299
+ )
300
+
301
+ inference_btn.click(
302
+ inference_fn,
303
+ inputs=[
304
+ modelptr, speaker, input_audio,
305
+ f0_predictor, pitch_shift, leakctrl_ratio, diff_steps, noise_scale,
306
+ silent_padding, db_threshold, auto_clip, clip_overlap, cross_fade,
307
+ adaptive_key, crepe_f0, loudness_ratio, reencode_audio,
308
+ ],
309
+ outputs=[output_msg, output_audio],
310
+ )
311
+
312
+ app.launch(debug=True, share=True)
cluster/__init__.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from sklearn.cluster import KMeans
3
+
4
+
5
def get_cluster_model(ckpt_path):
    """Rebuild per-speaker ``KMeans`` objects from a saved checkpoint.

    The checkpoint maps each speaker name to a dict with the keys
    ``"n_features_in_"``, ``"_n_threads"`` and ``"cluster_centers_"``.  The
    estimators are not re-fitted; the stored attributes are injected directly
    so that ``predict`` can be used.

    Args:
        ckpt_path: Path of the checkpoint file.

    Returns:
        Dict mapping speaker name to a ``KMeans`` instance.
    """
    # NOTE(review): torch.load unpickles arbitrary objects; only load
    # checkpoints from a trusted source.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    kmeans_dict = {}
    for spk, ckpt in checkpoint.items():
        centers = ckpt["cluster_centers_"]
        # n_clusters is the number of centroids, not the feature dimension.
        km = KMeans(n_clusters=len(centers))
        km.__dict__["n_features_in_"] = ckpt["n_features_in_"]
        km.__dict__["_n_threads"] = ckpt["_n_threads"]
        km.__dict__["cluster_centers_"] = centers
        kmeans_dict[spk] = km
    return kmeans_dict
15
+
16
def get_cluster_result(model, x, speaker):
    """Return the cluster label predicted for each row of ``x``.

    ``model`` maps speaker names to fitted estimators; ``x`` is documented by
    the original author as an ``np.array`` of shape ``[t, 256]``.
    """
    speaker_model = model[speaker]
    return speaker_model.predict(x)
22
+
23
def get_cluster_center_result(model, x,speaker):
    """Return, for each row of ``x``, the centroid of its predicted cluster.

    ``x`` is documented by the original author as an ``np.array`` of shape
    ``[t, 256]``.
    """
    labels = model[speaker].predict(x)
    centers = model[speaker].cluster_centers_
    return centers[labels]
27
+
28
def get_center(model, x,speaker):
    """Return the centroid(s) of ``speaker``'s model selected by index ``x``."""
    centers = model[speaker].cluster_centers_
    return centers[x]
cluster/kmeans.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from time import time
2
+
3
+ import numpy as np
4
+ import pynvml
5
+ import torch
6
+ from torch.nn.functional import normalize
7
+
8
+
9
+ # device=torch.device("cuda:0")
10
+ def _kpp(data: torch.Tensor, k: int, sample_size: int = -1):
11
+ """ Picks k points in the data based on the kmeans++ method.
12
+
13
+ Parameters
14
+ ----------
15
+ data : torch.Tensor
16
+ Expect a rank 1 or 2 array. Rank 1 is assumed to describe 1-D
17
+ data, rank 2 multidimensional data, in which case one
18
+ row is one observation.
19
+ k : int
20
+ Number of samples to generate.
21
+ sample_size : int
22
+ sample data to avoid memory overflow during calculation
23
+
24
+ Returns
25
+ -------
26
+ init : ndarray
27
+ A 'k' by 'N' containing the initial centroids.
28
+
29
+ References
30
+ ----------
31
+ .. [1] D. Arthur and S. Vassilvitskii, "k-means++: the advantages of
32
+ careful seeding", Proceedings of the Eighteenth Annual ACM-SIAM Symposium
33
+ on Discrete Algorithms, 2007.
34
+ .. [2] scipy/cluster/vq.py: _kpp
35
+ """
36
+ batch_size=data.shape[0]
37
+ if batch_size>sample_size:
38
+ data = data[torch.randint(0, batch_size,[sample_size], device=data.device)]
39
+ dims = data.shape[1] if len(data.shape) > 1 else 1
40
+ init = torch.zeros((k, dims)).to(data.device)
41
+ r = torch.distributions.uniform.Uniform(0, 1)
42
+ for i in range(k):
43
+ if i == 0:
44
+ init[i, :] = data[torch.randint(data.shape[0], [1])]
45
+ else:
46
+ D2 = torch.cdist(init[:i, :][None, :], data[None, :], p=2)[0].amin(dim=0)
47
+ probs = D2 / torch.sum(D2)
48
+ cumprobs = torch.cumsum(probs, dim=0)
49
+ init[i, :] = data[torch.searchsorted(cumprobs, r.sample([1]).to(data.device))]
50
+ return init
51
class KMeansGPU:
    '''
    Kmeans clustering algorithm implemented with PyTorch

    Parameters:
      n_clusters: int,
        Number of clusters

      max_iter: int, default: 200
        Maximum number of iterations

      tol: float, default: 0.0001
        Tolerance

      verbose: int, default: 0
        Verbosity

      mode: {'euclidean', 'cosine'}, default: 'euclidean'
        Type of distance measure

      device: torch.device, default: cuda:0
        Device the clustering runs on; its free memory (queried through
        NVML) determines ``minibatch``

    Attributes:
      centroids: torch.Tensor, shape: [n_clusters, n_features]
        cluster centroids (set by ``fit_predict``)

      minibatch: int
        number of samples processed per step, derived from the free GPU
        memory and ``n_clusters``
    '''
    def __init__(self, n_clusters, max_iter=200, tol=1e-4, verbose=0, mode="euclidean",device=torch.device("cuda:0")):
        self.n_clusters = n_clusters
        self.max_iter = max_iter
        self.tol = tol
        self.verbose = verbose
        self.mode = mode
        self.device=device
        pynvml.nvmlInit()
        # A bare torch.device("cuda") has index None; NVML needs an integer.
        gpu_index = device.index if device.index is not None else torch.cuda.current_device()
        gpu_handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_index)
        info = pynvml.nvmlDeviceGetMemoryInfo(gpu_handle)
        # Heuristic: 33e6 / n_clusters samples per free GiB of GPU memory.
        self.minibatch=int(33e6/self.n_clusters*info.free/ 1024 / 1024 / 1024)
        print("free_mem/GB:",info.free/ 1024 / 1024 / 1024,"minibatch:",self.minibatch)

    @staticmethod
    def cos_sim(a, b):
        """
          Compute cosine similarity of 2 sets of vectors

          Parameters:
          a: torch.Tensor, shape: [m, n_features]

          b: torch.Tensor, shape: [n, n_features]
        """
        return normalize(a, dim=-1) @ normalize(b, dim=-1).transpose(-2, -1)

    @staticmethod
    def euc_sim(a, b):
        """
          Compute euclidean similarity of 2 sets of vectors
          (the negated squared euclidean distance, so larger means closer)
          Parameters:
          a: torch.Tensor, shape: [m, n_features]
          b: torch.Tensor, shape: [n, n_features]
        """
        return 2 * a @ b.transpose(-2, -1) -(a**2).sum(dim=1)[..., :, None] - (b**2).sum(dim=1)[..., None, :]

    def max_sim(self, a, b):
        """
          Compute maximum similarity (or minimum distance) of each vector
          in a with all of the vectors in b
          Parameters:
          a: torch.Tensor, shape: [m, n_features]
          b: torch.Tensor, shape: [n, n_features]
          Return:
          (max_sim_v, max_sim_i): best similarity per row of a and the index
          of the matching row of b
          Raises:
          ValueError: if self.mode is neither 'cosine' nor 'euclidean'
        """
        if self.mode == 'cosine':
            sim_func = self.cos_sim
        elif self.mode == 'euclidean':
            sim_func = self.euc_sim
        else:
            raise ValueError(f"unknown mode {self.mode!r}, expected 'cosine' or 'euclidean'")
        sim = sim_func(a, b)
        max_sim_v, max_sim_i = sim.max(dim=-1)
        return max_sim_v, max_sim_i

    def fit_predict(self, X):
        """
          Fit the centroids on X and return the cluster label of the samples
          used in the last iteration.
          Parameters:
          X: torch.Tensor, shape: [n_samples, n_features]
          Return:
          labels: torch.Tensor of integer cluster indices for the last
          processed batch, or None when max_iter is 0

          Author's sizing notes:
          mini_=33kk/k*remain
          mini=min(mini_,fea_shape)
          offset=log2(k/1000)*1.5
          kpp_all=min(mini_*10/offset,fea_shape)
          kpp_sample=min(mini_/12/offset,fea_shape)
          NOTE(review): the code below computes offset as
          1.5**ln(k/1000)/ln(2), which differs from the formula in these
          notes - confirm which one is intended.
        """
        assert isinstance(X, torch.Tensor), "input must be torch.Tensor"
        assert X.dtype in [torch.half, torch.float, torch.double], "input must be floating point"
        assert X.ndim == 2, "input must be a 2d tensor with shape: [n_samples, n_features] "
        # print("verbose:%s"%self.verbose)

        offset = np.power(1.5,np.log(self.n_clusters / 1000))/np.log(2)
        # int16 labels save memory but overflow past 32768 clusters.
        if self.n_clusters - 1 <= torch.iinfo(torch.int16).max:
            label_dtype = torch.int16
        else:
            label_dtype = torch.int32
        with torch.no_grad():
            batch_size= X.shape[0]
            # print(self.minibatch, int(self.minibatch * 10 / offset), batch_size)
            start_time = time()
            # Seed the centroids with k-means++ on (a sample of) the data.
            if (self.minibatch*10//offset< batch_size):
                x = X[torch.randint(0, batch_size,[int(self.minibatch*10/offset)])].to(self.device)
            else:
                x = X.to(self.device)
            # print(x.device)
            self.centroids = _kpp(x, self.n_clusters, min(int(self.minibatch/12/offset),batch_size))
            del x
            torch.cuda.empty_cache()
            # self.centroids = self.centroids.to(self.device)
            num_points_in_clusters = torch.ones(self.n_clusters, device=self.device, dtype=X.dtype)  # all ones
            closest = None
            # When at least half of the data fits in one minibatch, move the
            # working set to the device once instead of on every iteration.
            if(self.minibatch>=batch_size//2 and self.minibatch<batch_size):
                X = X[torch.randint(0, batch_size,[self.minibatch])].to(self.device)
            elif(self.minibatch>=batch_size):
                X=X.to(self.device)
            i = -1  # keeps the summary below valid when max_iter is 0
            for i in range(self.max_iter):
                iter_time = time()
                if self.minibatch<batch_size//2:
                    # The minibatch is small compared with the data set:
                    # resample from host memory and copy to the device
                    # on every iteration.
                    x = X[torch.randint(0, batch_size, [self.minibatch])].to(self.device)
                else:
                    # Otherwise the whole working set is already cached.
                    x = X

                closest = self.max_sim(a=x, b=self.centroids)[1].to(label_dtype)
                matched_clusters, counts = closest.unique(return_counts=True)
                expanded_closest = closest[None].expand(self.n_clusters, -1)  # [n_clusters, n_samples]
                # One-hot membership mask; the arange on the right is int64.
                mask = (expanded_closest==torch.arange(self.n_clusters, device=self.device)[:, None]).to(X.dtype)
                # Mean of the samples assigned to each cluster.
                c_grad = mask @ x / mask.sum(-1)[..., :, None]
                c_grad[c_grad!=c_grad] = 0 # remove NaNs
                error = (c_grad - self.centroids).pow(2).sum()
                # self.minibatch is always an int here, so the per-cluster
                # learning rate is always used.
                lr = 1/num_points_in_clusters[:,None] * 0.9 + 0.1
                # Tensors used as indices must be long, byte or bool.
                matched_clusters=matched_clusters.long()
                num_points_in_clusters[matched_clusters] += counts
                self.centroids = self.centroids * (1-lr) + c_grad * lr
                if self.verbose >= 2:
                    print('iter:', i, 'error:', error.item(), 'time spent:', round(time()-iter_time, 4))
                if error <= self.tol:
                    break

            if self.verbose >= 1:
                print(f'used {i+1} iterations ({round(time()-start_time, 4)}s) to cluster {batch_size} items into {self.n_clusters} clusters')
        return closest
cluster/train_cluster.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+ import os
4
+ import time
5
+ from pathlib import Path
6
+
7
+ import numpy as np
8
+ import torch
9
+ import tqdm
10
+ from kmeans import KMeansGPU
11
+ from sklearn.cluster import KMeans, MiniBatchKMeans
12
+
13
+ logging.basicConfig(level=logging.INFO)
14
+ logger = logging.getLogger(__name__)
15
+
16
def train_cluster(in_dir, n_clusters, use_minibatch=True, verbose=False,use_gpu=False):
    """Fit a k-means model on the ``*.soft.pt`` features found in ``in_dir``.

    GPU minibatch mode performs poorly according to the original author;
    although the library supports it, it is not considered here.

    Args:
        in_dir: ``pathlib.Path`` of a speaker directory.
        n_clusters: Number of clusters to fit.
        use_minibatch: On CPU, use ``MiniBatchKMeans`` instead of ``KMeans``.
        verbose: Enable verbose output of the underlying implementation.
        use_gpu: Use ``KMeansGPU`` instead of scikit-learn.

    Returns:
        Dict with the keys ``"n_features_in_"``, ``"_n_threads"`` and
        ``"cluster_centers_"``.

    Raises:
        ValueError: If ``in_dir`` contains no ``*.soft.pt`` file.
    """
    # NOTE(review): this branch only logs; the directory is still processed.
    if str(in_dir).endswith(".ipynb_checkpoints"):
        logger.info(f"Ignore {in_dir}")

    logger.info(f"Loading features from {in_dir}")
    features = []
    for path in tqdm.tqdm(in_dir.glob("*.soft.pt")):
        # Each file is loaded on CPU, squeezed on dim 0 and transposed so
        # that the arrays can be concatenated along axis 0.
        features.append(torch.load(path,map_location="cpu").squeeze(0).numpy().T)
    # Number of feature files loaded (previously always reported as 0).
    nums = len(features)
    if nums == 0:
        raise ValueError(f"No *.soft.pt feature files found in {in_dir}")
    features = np.concatenate(features, axis=0)
    print(nums, features.nbytes/ 1024**2, "MB , shape:",features.shape, features.dtype)
    features = features.astype(np.float32)
    logger.info(f"Clustering features of shape: {features.shape}")
    t = time.time()
    if(use_gpu is False):
        if use_minibatch:
            kmeans = MiniBatchKMeans(n_clusters=n_clusters,verbose=verbose, batch_size=4096, max_iter=80).fit(features)
        else:
            kmeans = KMeans(n_clusters=n_clusters,verbose=verbose).fit(features)
    else:
        kmeans = KMeansGPU(n_clusters=n_clusters, mode='euclidean', verbose=2 if verbose else 0,max_iter=500,tol=1e-2)
        features=torch.from_numpy(features)
        kmeans.fit_predict(features)

    print(time.time()-t, "s")

    # The GPU implementation has no scikit-learn attributes, so the
    # equivalent values are filled in by hand.
    x = {
        "n_features_in_": kmeans.n_features_in_ if use_gpu is False else features.shape[1],
        "_n_threads": kmeans._n_threads if use_gpu is False else 4,
        "cluster_centers_": kmeans.cluster_centers_ if use_gpu is False else kmeans.centroids.cpu().numpy(),
    }
    print("end")

    return x
53
+
54
if __name__ == "__main__":
    # Train one k-means model per speaker directory and save them together.
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=Path, default="./dataset/44k",
                        help='path of training data directory')
    parser.add_argument('--output', type=Path, default="logs/44k",
                        help='path of model output directory')
    parser.add_argument('--gpu',action='store_true', default=False ,
                        help='to use GPU')

    args = parser.parse_args()

    checkpoint_dir = args.output
    dataset = args.dataset
    use_gpu = args.gpu
    n_clusters = 10000

    ckpt = {}
    # Sorted so the processing order is the same on every platform.
    for spk in sorted(os.listdir(dataset)):
        in_dir = dataset/spk
        if not os.path.isdir(in_dir):
            continue
        # Jupyter's hidden checkpoint folder is not a speaker; training on
        # it would abort the whole run because it holds no features.
        if spk == ".ipynb_checkpoints":
            continue
        print(f"train kmeans for {spk}...")
        ckpt[spk] = train_cluster(in_dir, n_clusters,use_minibatch=False,verbose=False,use_gpu=use_gpu)

    checkpoint_path = checkpoint_dir / f"kmeans_{n_clusters}.pt"
    checkpoint_path.parent.mkdir(exist_ok=True, parents=True)
    torch.save(
        ckpt,
        checkpoint_path,
    )
85
+
compress_model.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+
3
+ import torch
4
+
5
+ import utils
6
+ from models import SynthesizerTrn
7
+
8
+
9
def copyStateDict(state_dict):
    """Return a copy of ``state_dict`` without a leading ``module`` prefix.

    If the first key starts with ``"module"``, the first dot-separated
    segment of every key is dropped; otherwise keys are copied unchanged.
    Key segments are re-joined with ``"."`` (the previous ``","`` separator
    corrupted every dotted key).

    Args:
        state_dict: Mapping of parameter names to values.

    Returns:
        ``OrderedDict`` with the same values in the same order.
    """
    new_state_dict = OrderedDict()
    if not state_dict:
        # Nothing to inspect; avoids an IndexError on the first-key lookup.
        return new_state_dict
    first_key = next(iter(state_dict))
    start_idx = 1 if first_key.startswith('module') else 0
    for k, v in state_dict.items():
        name = '.'.join(k.split('.')[start_idx:])
        new_state_dict[name] = v
    return new_state_dict
19
+
20
+
21
def removeOptimizer(config: str, input_model: str, ishalf: bool, output_model: str):
    """Write a slimmed-down copy of the checkpoint ``input_model``.

    Every weight whose name contains ``"enc_q"`` is dropped, the remaining
    weights are optionally converted with ``.half()``, and the result is
    saved together with the state of a freshly created AdamW optimizer,
    ``iteration`` 0 and ``learning_rate`` 0.0001.

    Args:
        config: Path of the hyper-parameter file.
        input_model: Path of the checkpoint to read.
        ishalf: Convert the kept weights with ``.half()`` when true.
        output_model: Path of the checkpoint to write.
    """
    hps = utils.get_hparams_from_file(config)

    spec_channels = hps.data.filter_length // 2 + 1
    segment_frames = hps.train.segment_size // hps.data.hop_length
    net_g = SynthesizerTrn(spec_channels, segment_frames, **hps.model)

    # Only the (empty) state of this new optimizer is stored in the output.
    optim_g = torch.optim.AdamW(
        net_g.parameters(),
        hps.train.learning_rate,
        betas=hps.train.betas,
        eps=hps.train.eps,
    )

    checkpoint = copyStateDict(torch.load(input_model, map_location="cpu"))
    weights = checkpoint['model']
    kept_names = [name for name in weights if "enc_q" not in name]

    if ishalf:
        slim_weights = {name: weights[name].half() for name in kept_names}
    else:
        slim_weights = {name: weights[name] for name in kept_names}

    payload = {
        'model': slim_weights,
        'iteration': 0,
        'optimizer': optim_g.state_dict(),
        'learning_rate': 0.0001
    }
    torch.save(payload, output_model)
49
+
50
+
51
if __name__ == "__main__":
    # Command-line entry point: compress one checkpoint for release.
    import argparse
    import os.path

    parser = argparse.ArgumentParser()
    parser.add_argument("-c",
                        "--config",
                        type=str,
                        default='configs/config.json')
    # The input is mandatory: without it the script previously failed later
    # with a TypeError instead of a usage message.
    parser.add_argument("-i", "--input", type=str, required=True)
    parser.add_argument("-o", "--output", type=str, default=None)
    parser.add_argument('-hf', '--half', action='store_true', default=False, help='Save as FP16')

    args = parser.parse_args()

    output = args.output

    if output is None:
        # Derive "<input>_release[_half]<ext>" next to the input file.
        filename, ext = os.path.splitext(args.input)
        half = "_half" if args.half else ""
        output = filename + "_release" + half + ext

    removeOptimizer(args.config, args.input, args.half, output)
configs/diffusion.yaml ADDED
File without changes
configs_template/config_template.json ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "train": {
3
+ "log_interval": 200,
4
+ "eval_interval": 800,
5
+ "seed": 1234,
6
+ "epochs": 10000,
7
+ "learning_rate": 0.0001,
8
+ "betas": [
9
+ 0.8,
10
+ 0.99
11
+ ],
12
+ "eps": 1e-09,
13
+ "batch_size": 6,
14
+ "fp16_run": false,
15
+ "half_type": "fp16",
16
+ "lr_decay": 0.999875,
17
+ "segment_size": 10240,
18
+ "init_lr_ratio": 1,
19
+ "warmup_epochs": 0,
20
+ "c_mel": 45,
21
+ "c_kl": 1.0,
22
+ "use_sr": true,
23
+ "max_speclen": 512,
24
+ "port": "8001",
25
+ "keep_ckpts": 3,
26
+ "all_in_mem": false,
27
+ "vol_aug":false
28
+ },
29
+ "data": {
30
+ "training_files": "filelists/train.txt",
31
+ "validation_files": "filelists/val.txt",
32
+ "max_wav_value": 32768.0,
33
+ "sampling_rate": 44100,
34
+ "filter_length": 2048,
35
+ "hop_length": 512,
36
+ "win_length": 2048,
37
+ "n_mel_channels": 80,
38
+ "mel_fmin": 0.0,
39
+ "mel_fmax": 22050,
40
+ "unit_interpolate_mode":"nearest"
41
+ },
42
+ "model": {
43
+ "inter_channels": 192,
44
+ "hidden_channels": 192,
45
+ "filter_channels": 768,
46
+ "n_heads": 2,
47
+ "n_layers": 6,
48
+ "kernel_size": 3,
49
+ "p_dropout": 0.1,
50
+ "resblock": "1",
51
+ "resblock_kernel_sizes": [3,7,11],
52
+ "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
53
+ "upsample_rates": [ 8, 8, 2, 2, 2],
54
+ "upsample_initial_channel": 512,
55
+ "upsample_kernel_sizes": [16,16, 4, 4, 4],
56
+ "n_layers_q": 3,
57
+ "n_layers_trans_flow": 3,
58
+ "n_flow_layer": 4,
59
+ "use_spectral_norm": false,
60
+ "gin_channels": 768,
61
+ "ssl_dim": 768,
62
+ "n_speakers": 200,
63
+ "vocoder_name":"nsf-hifigan",
64
+ "speech_encoder":"vec768l12",
65
+ "speaker_embedding":false,
66
+ "vol_embedding":false,
67
+ "use_depthwise_conv":false,
68
+ "flow_share_parameter": false,
69
+ "use_automatic_f0_prediction": true,
70
+ "use_transformer_flow": false
71
+ },
72
+ "spk": {
73
+ "nyaru": 0,
74
+ "huiyu": 1,
75
+ "nen": 2,
76
+ "paimon": 3,
77
+ "yunhao": 4
78
+ }
79
+ }
configs_template/config_tiny_template.json ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "train": {
3
+ "log_interval": 200,
4
+ "eval_interval": 800,
5
+ "seed": 1234,
6
+ "epochs": 10000,
7
+ "learning_rate": 0.0001,
8
+ "betas": [
9
+ 0.8,
10
+ 0.99
11
+ ],
12
+ "eps": 1e-09,
13
+ "batch_size": 6,
14
+ "fp16_run": false,
15
+ "half_type": "fp16",
16
+ "lr_decay": 0.999875,
17
+ "segment_size": 10240,
18
+ "init_lr_ratio": 1,
19
+ "warmup_epochs": 0,
20
+ "c_mel": 45,
21
+ "c_kl": 1.0,
22
+ "use_sr": true,
23
+ "max_speclen": 512,
24
+ "port": "8001",
25
+ "keep_ckpts": 3,
26
+ "all_in_mem": false,
27
+ "vol_aug":false
28
+ },
29
+ "data": {
30
+ "training_files": "filelists/train.txt",
31
+ "validation_files": "filelists/val.txt",
32
+ "max_wav_value": 32768.0,
33
+ "sampling_rate": 44100,
34
+ "filter_length": 2048,
35
+ "hop_length": 512,
36
+ "win_length": 2048,
37
+ "n_mel_channels": 80,
38
+ "mel_fmin": 0.0,
39
+ "mel_fmax": 22050,
40
+ "unit_interpolate_mode":"nearest"
41
+ },
42
+ "model": {
43
+ "inter_channels": 192,
44
+ "hidden_channels": 192,
45
+ "filter_channels": 512,
46
+ "n_heads": 2,
47
+ "n_layers": 6,
48
+ "kernel_size": 3,
49
+ "p_dropout": 0.1,
50
+ "resblock": "1",
51
+ "resblock_kernel_sizes": [3,7,11],
52
+ "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
53
+ "upsample_rates": [ 8, 8, 2, 2, 2],
54
+ "upsample_initial_channel": 400,
55
+ "upsample_kernel_sizes": [16,16, 4, 4, 4],
56
+ "n_layers_q": 3,
57
+ "n_layers_trans_flow": 3,
58
+ "n_flow_layer": 4,
59
+ "use_spectral_norm": false,
60
+ "gin_channels": 768,
61
+ "ssl_dim": 768,
62
+ "n_speakers": 200,
63
+ "vocoder_name":"nsf-hifigan",
64
+ "speech_encoder":"vec768l12",
65
+ "speaker_embedding":false,
66
+ "vol_embedding":false,
67
+ "use_depthwise_conv":true,
68
+ "flow_share_parameter": true,
69
+ "use_automatic_f0_prediction": true,
70
+ "use_transformer_flow": false
71
+ },
72
+ "spk": {
73
+ "nyaru": 0,
74
+ "huiyu": 1,
75
+ "nen": 2,
76
+ "paimon": 3,
77
+ "yunhao": 4
78
+ }
79
+ }
configs_template/diffusion_template.yaml ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data:
2
+ sampling_rate: 44100
3
+ block_size: 512 # Equal to hop_length
4
+ duration: 2 # Audio duration during training, must be less than the duration of the shortest audio clip
5
+ encoder: 'vec768l12' # 'hubertsoft', 'vec256l9', 'vec768l12'
6
+ cnhubertsoft_gate: 10
7
+ encoder_sample_rate: 16000
8
+ encoder_hop_size: 320
9
+ encoder_out_channels: 768 # 256 if using 'hubertsoft'
10
+ training_files: "filelists/train.txt"
11
+ validation_files: "filelists/val.txt"
12
+ extensions: # List of file extensions included in the data collection
13
+ - wav
14
+ unit_interpolate_mode: "nearest"
15
+ model:
16
+ type: 'Diffusion'
17
+ n_layers: 20
18
+ n_chans: 512
19
+ n_hidden: 256
20
+ use_pitch_aug: true
21
+ timesteps : 1000
22
+ k_step_max: 0 # must be <= timesteps; if it is 0, train all steps
23
+ n_spk: 1 # max number of different speakers
24
+ device: cuda
25
+ vocoder:
26
+ type: 'nsf-hifigan'
27
+ ckpt: 'pretrain/nsf_hifigan/model'
28
+ infer:
29
+ speedup: 10
30
+ method: 'dpm-solver++' # 'pndm' or 'dpm-solver' or 'ddim' or 'unipc' or 'dpm-solver++'
31
+ env:
32
+ expdir: logs/44k/diffusion
33
+ gpu_id: 0
34
+ train:
35
+ num_workers: 4 # If your CPU and GPU are both very strong, setting this to 0 may be faster!
36
+ amp_dtype: fp32 # fp32, fp16 or bf16 (fp16 or bf16 may be faster if it is supported by your gpu)
37
+ batch_size: 48
38
+ cache_all_data: true # Setting this to false saves RAM or GPU memory, but may be slow
39
+ cache_device: 'cpu' # Set to 'cuda' to cache the data into the Graphics-Memory, fastest speed for strong gpu
40
+ cache_fp16: true
41
+ epochs: 100000
42
+ interval_log: 10
43
+ interval_val: 2000
44
+ interval_force_save: 5000
45
+ lr: 0.0001
46
+ decay_step: 100000
47
+ gamma: 0.5
48
+ weight_decay: 0
49
+ save_opt: false
50
+ spk:
51
+ 'nyaru': 0
data_utils.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.utils.data
7
+
8
+ import utils
9
+ from modules.mel_processing import spectrogram_torch
10
+ from utils import load_filepaths_and_text, load_wav_to_torch
11
+
12
+ # import h5py
13
+
14
+
15
+ """Multi speaker version"""
16
+
17
+
18
class TextAudioSpeakerLoader(torch.utils.data.Dataset):
    """Multi-speaker dataset built from a filelist of audio paths.

    Each item is the tuple ``(c, f0, spec, audio_norm, spk, uv, volume)``
    produced by :meth:`get_audio` and then cropped by :meth:`random_slice`.
    The side-car files ``<wav>.f0.npy``, ``<wav>.soft.pt`` and (when volume
    embedding is enabled) ``<wav>.vol.npy`` must already exist next to the
    audio; the linear spectrogram is cached as ``*.spec.pt`` on first use.
    """

    def __init__(self, audiopaths, hparams, all_in_mem: bool = False, vol_aug: bool = True):
        """Read the filelist, copy the needed hyper-parameters and shuffle.

        Args:
            audiopaths: path of the filelist handed to ``load_filepaths_and_text``.
            hparams: config object exposing ``data``, ``train``, ``model`` and ``spk``.
            all_in_mem: when true, every item is loaded eagerly into ``self.cache``.
            vol_aug: volume augmentation is active only if this flag and
                ``hparams.train.vol_aug`` are both truthy.
        """
        self.audiopaths = load_filepaths_and_text(audiopaths)
        self.hparams = hparams

        data_cfg = hparams.data
        self.max_wav_value = data_cfg.max_wav_value
        self.sampling_rate = data_cfg.sampling_rate
        self.filter_length = data_cfg.filter_length
        self.hop_length = data_cfg.hop_length
        self.win_length = data_cfg.win_length
        self.unit_interpolate_mode = data_cfg.unit_interpolate_mode
        self.sampling_rate = data_cfg.sampling_rate

        self.use_sr = hparams.train.use_sr
        self.spec_len = hparams.train.max_speclen
        self.spk_map = hparams.spk
        self.vol_emb = hparams.model.vol_embedding
        self.vol_aug = hparams.train.vol_aug and vol_aug

        # Fixed seed: the shuffled order is the same on every run.
        # Note that this reseeds the global ``random`` module.
        random.seed(1234)
        random.shuffle(self.audiopaths)

        self.all_in_mem = all_in_mem
        if self.all_in_mem:
            self.cache = [self.get_audio(entry[0]) for entry in self.audiopaths]

    def get_audio(self, filename):
        """Load one utterance and all of its pre-computed features.

        Raises:
            ValueError: if the file's sample rate differs from the configured one.
            AssertionError: if feature lengths disagree by 3 frames or more.
        """
        filename = filename.replace("\\", "/")
        audio, sampling_rate = load_wav_to_torch(filename)
        if sampling_rate != self.sampling_rate:
            raise ValueError(
                "Sample Rate not match. Expect {} but got {} from {}".format(
                    self.sampling_rate, sampling_rate, filename))
        audio_norm = (audio / self.max_wav_value).unsqueeze(0)

        # Reuse the cached spectrogram when present, otherwise compute and store it.
        spec_filename = filename.replace(".wav", ".spec.pt")
        if not os.path.exists(spec_filename):
            spec = spectrogram_torch(audio_norm, self.filter_length,
                                     self.sampling_rate, self.hop_length, self.win_length,
                                     center=False)
            spec = torch.squeeze(spec, 0)
            torch.save(spec, spec_filename)
        else:
            spec = torch.load(spec_filename)

        # The speaker name is the directory that directly contains the file.
        speaker_name = filename.split("/")[-2]
        spk = torch.LongTensor([self.spk_map[speaker_name]])

        f0, uv = np.load(filename + ".f0.npy", allow_pickle=True)
        f0 = torch.FloatTensor(np.array(f0, dtype=float))
        uv = torch.FloatTensor(np.array(uv, dtype=float))

        c = torch.load(filename + ".soft.pt")
        c = utils.repeat_expand_2d(c.squeeze(0), f0.shape[0], mode=self.unit_interpolate_mode)

        volume = None
        if self.vol_emb:
            volume_path = filename + ".vol.npy"
            volume = torch.from_numpy(np.load(volume_path)).float()

        # Trim every stream to the shortest common frame count.
        c_frames, spec_frames = c.size(-1), spec.size(-1)
        lmin = min(c_frames, spec_frames)
        assert abs(c_frames - spec_frames) < 3, (c_frames, spec_frames, f0.shape, filename)
        assert abs(audio_norm.shape[1]-lmin * self.hop_length) < 3 * self.hop_length
        spec = spec[:, :lmin]
        c = c[:, :lmin]
        f0 = f0[:lmin]
        uv = uv[:lmin]
        audio_norm = audio_norm[:, :lmin * self.hop_length]
        if volume is not None:
            volume = volume[:lmin]
        return c, f0, spec, audio_norm, spk, uv, volume

    def random_slice(self, c, f0, spec, audio_norm, spk, uv, volume):
        """Optionally apply a random gain, then crop long items to 790 frames."""
        # The coin is always flipped first so the RNG stream does not depend
        # on whether augmentation is enabled.
        coin = random.choice([True, False])
        if coin and self.vol_aug and volume is not None:
            peak = float(torch.max(torch.abs(audio_norm))) + 1e-5
            upper = min(1, np.log10(1/peak))
            log10_vol_shift = random.uniform(-1, upper)
            gain = 10 ** log10_vol_shift
            audio_norm = audio_norm * gain
            volume = volume * gain
            data_cfg = self.hparams.data
            spec = spectrogram_torch(audio_norm,
                                     data_cfg.filter_length,
                                     data_cfg.sampling_rate,
                                     data_cfg.hop_length,
                                     data_cfg.win_length,
                                     center=False)[0]

        n_frames = spec.shape[1]
        if n_frames > 800:
            start = random.randint(0, n_frames - 800)
            end = start + 790
            frames = slice(start, end)
            spec, c, f0, uv = spec[:, frames], c[:, frames], f0[frames], uv[frames]
            audio_norm = audio_norm[:, start * self.hop_length : end * self.hop_length]
            if volume is not None:
                volume = volume[frames]
        return c, f0, spec, audio_norm, spk, uv, volume

    def __getitem__(self, index):
        """Return a (possibly augmented and cropped) item for ``index``."""
        if self.all_in_mem:
            item = self.cache[index]
        else:
            item = self.get_audio(self.audiopaths[index][0])
        return self.random_slice(*item)

    def __len__(self):
        """Number of entries in the filelist."""
        return len(self.audiopaths)
129
+
130
+
131
class TextAudioCollate:
    """Collate items of ``TextAudioSpeakerLoader`` into zero-padded batch tensors."""

    def __call__(self, batch):
        """Pad and stack a list of ``(c, f0, spec, wav, spk, uv, volume)`` items.

        ``None`` entries are dropped and rows are ordered by decreasing
        content length (``c.shape[1]``).

        Returns:
            ``(c_padded, f0_padded, spec_padded, wav_padded, spkids, lengths,
            uv_padded, volume_padded)``. ``volume_padded`` is ``None`` as soon
            as any item carries no volume.

        Raises:
            ValueError: if no usable item is left in the batch.
        """
        samples = [item for item in batch if item is not None]
        if not samples:
            raise ValueError("TextAudioCollate received an empty batch")

        _, ids_sorted_decreasing = torch.sort(
            torch.LongTensor([item[0].shape[1] for item in samples]),
            dim=0, descending=True)

        n_items = len(samples)
        max_c_len = max(item[0].size(1) for item in samples)
        max_wav_len = max(item[3].size(1) for item in samples)

        # Decide once, up front, whether the batch has volume. The previous
        # code set the buffer to None inside the loop and then crashed with a
        # TypeError if a later row did carry a volume tensor.
        has_volume = all(item[6] is not None for item in samples)

        # Allocate already-zeroed buffers (float32 / int64, as before).
        lengths = torch.zeros(n_items, dtype=torch.long)
        c_padded = torch.zeros(n_items, samples[0][0].shape[0], max_c_len, dtype=torch.float32)
        f0_padded = torch.zeros(n_items, max_c_len, dtype=torch.float32)
        spec_padded = torch.zeros(n_items, samples[0][2].shape[0], max_c_len, dtype=torch.float32)
        wav_padded = torch.zeros(n_items, 1, max_wav_len, dtype=torch.float32)
        spkids = torch.zeros(n_items, 1, dtype=torch.long)
        uv_padded = torch.zeros(n_items, max_c_len, dtype=torch.float32)
        volume_padded = torch.zeros(n_items, max_c_len, dtype=torch.float32) if has_volume else None

        for i, src in enumerate(ids_sorted_decreasing.tolist()):
            c, f0, spec, wav, spk, uv, volume = samples[src]

            c_padded[i, :, :c.size(1)] = c
            lengths[i] = c.size(1)
            f0_padded[i, :f0.size(0)] = f0
            spec_padded[i, :, :spec.size(1)] = spec
            wav_padded[i, :, :wav.size(1)] = wav
            spkids[i, 0] = spk
            uv_padded[i, :uv.size(0)] = uv
            if has_volume:
                volume_padded[i, :volume.size(0)] = volume

        return c_padded, f0_padded, spec_padded, wav_padded, spkids, lengths, uv_padded, volume_padded
diffusion/__init__.py ADDED
File without changes
diffusion/data_loaders.py ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+
4
+ import librosa
5
+ import numpy as np
6
+ import torch
7
+ from torch.utils.data import Dataset
8
+ from tqdm import tqdm
9
+
10
+ from utils import repeat_expand_2d
11
+
12
+
13
def traverse_dir(
        root_dir,
        extensions,
        amount=None,
        str_include=None,
        str_exclude=None,
        is_pure=False,
        is_sort=False,
        is_ext=True):
    """Recursively collect files under ``root_dir`` with a matching extension.

    Args:
        root_dir: directory to walk.
        extensions: iterable of extensions without the leading dot (e.g. ``["wav"]``).
        amount: stop once this many paths were collected (``None`` = no limit).
        str_include: keep only paths that contain this substring.
        str_exclude: drop paths that contain this substring.
        is_pure: return paths relative to ``root_dir`` instead of joined paths.
        is_sort: sort the result before returning it.
        is_ext: when false, strip the final extension from every returned path.

    Returns:
        list of path strings.
    """
    suffixes = tuple(f".{ext}" for ext in extensions)
    file_list = []
    for root, _, files in os.walk(root_dir):
        for file in files:
            if not file.endswith(suffixes):
                continue

            # path
            mix_path = os.path.join(root, file)
            # relpath instead of slicing by len(root_dir) + 1: the slice cut
            # off a real character when root_dir ended with a separator.
            pure_path = os.path.relpath(mix_path, root_dir) if is_pure else mix_path

            # amount
            if amount is not None and len(file_list) == amount:
                if is_sort:
                    file_list.sort()
                return file_list

            # check string
            if str_include is not None and str_include not in pure_path:
                continue
            if str_exclude is not None and str_exclude in pure_path:
                continue

            if not is_ext:
                pure_path = os.path.splitext(pure_path)[0]
            file_list.append(pure_path)
    if is_sort:
        file_list.sort()
    return file_list
52
+
53
+
54
def get_data_loaders(args, whole_audio=False):
    """Build the training and validation ``DataLoader`` objects from ``args``.

    Args:
        args: config object exposing ``data``, ``train``, ``model`` and ``spk``.
        whole_audio: when true the training set yields whole clips and the
            training batch size is forced to 1.

    Returns:
        ``(loader_train, loader_valid)``.
    """
    data_cfg = args.data
    train_cfg = args.train
    cache_on_cpu = train_cfg.cache_device == 'cpu'

    # Worker processes and pinned memory are used only for a CPU cache.
    worker_count = train_cfg.num_workers if cache_on_cpu else 0
    train_batch_size = 1 if whole_audio else train_cfg.batch_size

    data_train = AudioDataset(
        filelists=data_cfg.training_files,
        waveform_sec=data_cfg.duration,
        hop_size=data_cfg.block_size,
        sample_rate=data_cfg.sampling_rate,
        load_all_data=train_cfg.cache_all_data,
        whole_audio=whole_audio,
        extensions=data_cfg.extensions,
        n_spk=args.model.n_spk,
        spk=args.spk,
        device=train_cfg.cache_device,
        fp16=train_cfg.cache_fp16,
        unit_interpolate_mode=data_cfg.unit_interpolate_mode,
        use_aug=True,
    )
    loader_train = torch.utils.data.DataLoader(
        data_train,
        batch_size=train_batch_size,
        shuffle=True,
        num_workers=worker_count,
        persistent_workers=worker_count > 0,
        pin_memory=cache_on_cpu,
    )

    # Validation always runs on whole clips, one per batch, with the dataset's
    # default device, fp16 and augmentation settings.
    data_valid = AudioDataset(
        filelists=data_cfg.validation_files,
        waveform_sec=data_cfg.duration,
        hop_size=data_cfg.block_size,
        sample_rate=data_cfg.sampling_rate,
        load_all_data=train_cfg.cache_all_data,
        whole_audio=True,
        spk=args.spk,
        extensions=data_cfg.extensions,
        unit_interpolate_mode=data_cfg.unit_interpolate_mode,
        n_spk=args.model.n_spk,
    )
    loader_valid = torch.utils.data.DataLoader(
        data_valid,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
    )
    return loader_train, loader_valid
96
+
97
+
98
class AudioDataset(Dataset):
    """Dataset of pre-extracted features for training the diffusion model.

    For every path listed in ``filelists`` the side-car files ``.f0.npy``,
    ``.vol.npy``, ``.aug_vol.npy`` and ``.aug_mel.npy`` are read at
    construction time. With ``load_all_data`` the mel, augmented mel and unit
    features are cached as well; otherwise they are read from disk per item.
    """

    def __init__(
        self,
        filelists,
        waveform_sec,
        hop_size,
        sample_rate,
        spk,
        load_all_data=True,
        whole_audio=False,
        extensions=['wav'],  # not read in this class; kept for call compatibility
        n_spk=1,
        device='cpu',
        fp16=False,
        use_aug=False,
        unit_interpolate_mode = 'left'
    ):
        """Index the filelist and pre-load the per-file features.

        Args:
            filelists: text file with one audio path per line.
            waveform_sec: length in seconds of a training crop.
            hop_size: samples per feature frame.
            sample_rate: audio sample rate in Hz.
            spk: mapping from speaker name to speaker id.
            load_all_data: also cache mel / aug_mel / units in memory.
            whole_audio: return whole clips instead of random crops.
            n_spk: number of speakers; speaker ids are resolved only if > 1.
            device: device the cached tensors are moved to.
            fp16: store cached mel / aug_mel / units in half precision.
            use_aug: allow the random choice of the augmented mel / volume.
            unit_interpolate_mode: mode passed to ``repeat_expand_2d``.

        Raises:
            ValueError: if a resolved speaker id is outside ``[0, n_spk)``.
        """
        super().__init__()

        self.waveform_sec = waveform_sec
        self.sample_rate = sample_rate
        self.hop_size = hop_size
        self.filelists = filelists
        self.whole_audio = whole_audio
        self.use_aug = use_aug
        self.data_buffer={}
        self.pitch_aug_dict = {}
        self.unit_interpolate_mode = unit_interpolate_mode
        if load_all_data:
            print('Load all the data filelists:', filelists)
        else:
            print('Load the f0, volume data filelists:', filelists)
        with open(filelists,"r") as f:
            self.paths = f.read().splitlines()
        for name_ext in tqdm(self.paths, total=len(self.paths)):
            path_audio = name_ext
            # NOTE(review): newer librosa releases appear to rename the
            # `filename` keyword to `path` - confirm against the pinned version.
            duration = librosa.get_duration(filename = path_audio, sr = self.sample_rate)

            # NOTE: allow_pickle=True unpickles objects; only load trusted,
            # locally generated feature files.
            path_f0 = name_ext + ".f0.npy"
            f0,_ = np.load(path_f0,allow_pickle=True)
            f0 = torch.from_numpy(np.array(f0,dtype=float)).float().unsqueeze(-1).to(device)

            path_volume = name_ext + ".vol.npy"
            volume = np.load(path_volume)
            volume = torch.from_numpy(volume).float().unsqueeze(-1).to(device)

            path_augvol = name_ext + ".aug_vol.npy"
            aug_vol = np.load(path_augvol)
            aug_vol = torch.from_numpy(aug_vol).float().unsqueeze(-1).to(device)

            if n_spk is not None and n_spk > 1:
                # The speaker name is the parent directory. Backslashes are
                # normalised first, as data_utils.py does, so Windows-style
                # filelist entries resolve to the same speaker.
                spk_name = name_ext.replace("\\", "/").split("/")[-2]
                spk_id = spk[spk_name] if spk_name in spk else 0
                if spk_id < 0 or spk_id >= n_spk:
                    raise ValueError(' [x] Muiti-speaker traing error : spk_id must be a positive integer from 0 to n_spk-1 ')
            else:
                spk_id = 0
            spk_id = torch.LongTensor(np.array([spk_id])).to(device)

            path_augmel = name_ext + ".aug_mel.npy"
            if load_all_data:
                path_mel = name_ext + ".mel.npy"
                mel = np.load(path_mel)
                mel = torch.from_numpy(mel).to(device)

                aug_mel,keyshift = np.load(path_augmel, allow_pickle=True)
                aug_mel = np.array(aug_mel,dtype=float)
                aug_mel = torch.from_numpy(aug_mel).to(device)
                self.pitch_aug_dict[name_ext] = keyshift

                path_units = name_ext + ".soft.pt"
                units = torch.load(path_units).to(device)
                units = units[0]
                units = repeat_expand_2d(units,f0.size(0),unit_interpolate_mode).transpose(0,1)

                if fp16:
                    mel = mel.half()
                    aug_mel = aug_mel.half()
                    units = units.half()

                self.data_buffer[name_ext] = {
                    'duration': duration,
                    'mel': mel,
                    'aug_mel': aug_mel,
                    'units': units,
                    'f0': f0,
                    'volume': volume,
                    'aug_vol': aug_vol,
                    'spk_id': spk_id
                }
            else:
                # Only the key shift is kept; the augmented mel itself is read
                # again from disk in get_data.
                _, keyshift = np.load(path_augmel, allow_pickle=True)
                self.pitch_aug_dict[name_ext] = keyshift
                self.data_buffer[name_ext] = {
                    'duration': duration,
                    'f0': f0,
                    'volume': volume,
                    'aug_vol': aug_vol,
                    'spk_id': spk_id
                }

    def __getitem__(self, file_idx):
        """Return the item at ``file_idx``, skipping clips that are too short.

        A clip shorter than ``waveform_sec + 0.1`` seconds is replaced by the
        next clip, wrapping around the filelist.

        Raises:
            RuntimeError: if no clip in the filelist is long enough. The old
                recursive version hit the recursion limit in that case.
        """
        min_duration = self.waveform_sec + 0.1
        idx = file_idx
        # max(..., 1) keeps the IndexError of the first lookup on an empty list.
        for _ in range(max(len(self.paths), 1)):
            name_ext = self.paths[idx]
            data_buffer = self.data_buffer[name_ext]
            if data_buffer['duration'] >= min_duration:
                return self.get_data(name_ext, data_buffer)
            idx = (idx + 1) % len(self.paths)
        raise RuntimeError(
            'No audio clip in {} is at least {} seconds long'.format(self.filelists, min_duration))

    def get_data(self, name_ext, data_buffer):
        """Assemble one sample (a random crop, or the whole clip).

        Returns:
            dict with keys ``mel``, ``f0``, ``volume``, ``units``, ``spk_id``,
            ``aug_shift``, ``name`` and ``name_ext``.
        """
        name = os.path.splitext(name_ext)[0]
        frame_resolution = self.hop_size / self.sample_rate
        duration = data_buffer['duration']
        waveform_sec = duration if self.whole_audio else self.waveform_sec

        # choose the crop
        idx_from = 0 if self.whole_audio else random.uniform(0, duration - waveform_sec - 0.1)
        start_frame = int(idx_from / frame_resolution)
        units_frame_len = int(waveform_sec / frame_resolution)
        frames = slice(start_frame, start_frame + units_frame_len)
        aug_flag = random.choice([True, False]) and self.use_aug

        # load mel
        mel_key = 'aug_mel' if aug_flag else 'mel'
        mel = data_buffer.get(mel_key)
        if mel is None:
            # Not cached: read the file that matches the augmentation flag.
            # The old code always read ".mel.npy" here, pairing an unshifted
            # mel with the shifted f0 and the augmented volume.
            if aug_flag:
                aug_mel, _ = np.load(name_ext + ".aug_mel.npy", allow_pickle=True)
                mel = np.array(aug_mel, dtype=float)
            else:
                mel = np.load(name_ext + ".mel.npy")
            mel = torch.from_numpy(mel[frames]).float()
        else:
            mel = mel[frames]

        # load f0
        f0 = data_buffer.get('f0')
        aug_shift = 0
        if aug_flag:
            aug_shift = self.pitch_aug_dict[name_ext]
        # aug_shift is in semitones; 2 ** (n / 12) is the frequency ratio.
        f0_frames = 2 ** (aug_shift / 12) * f0[frames]

        # load units
        units = data_buffer.get('units')
        if units is None:
            path_units = name_ext + ".soft.pt"
            units = torch.load(path_units)
            units = units[0]
            units = repeat_expand_2d(units,f0.size(0),self.unit_interpolate_mode).transpose(0,1)

        units = units[frames]

        # load volume
        vol_key = 'aug_vol' if aug_flag else 'volume'
        volume = data_buffer.get(vol_key)
        volume_frames = volume[frames]

        # load spk_id
        spk_id = data_buffer.get('spk_id')

        # load shift
        aug_shift = torch.from_numpy(np.array([[aug_shift]])).float()

        return dict(mel=mel, f0=f0_frames, volume=volume_frames, units=units, spk_id=spk_id, aug_shift=aug_shift, name=name, name_ext=name_ext)

    def __len__(self):
        """Number of paths in the filelist."""
        return len(self.paths)
diffusion/diffusion.py ADDED
@@ -0,0 +1,396 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import deque
2
+ from functools import partial
3
+ from inspect import isfunction
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch import nn
9
+ from tqdm import tqdm
10
+
11
+
12
def exists(x):
    """Return ``True`` unless ``x`` is ``None``."""
    if x is None:
        return False
    return True
14
+
15
+
16
def default(val, d):
    """Return ``val`` if it is set, otherwise the fallback ``d``.

    A fallback that is a plain Python function is called to produce the value.
    """
    if not exists(val):
        return d() if isfunction(d) else d
    return val
20
+
21
+
22
def extract(a, t, x_shape):
    """Gather ``a`` at the indices ``t`` and reshape for broadcasting.

    The result has shape ``(len(t), 1, ..., 1)`` with as many dimensions as
    ``x_shape``.
    """
    batch = t.shape[0]
    picked = a.gather(-1, t)
    trailing_ones = (1,) * (len(x_shape) - 1)
    return picked.reshape(batch, *trailing_ones)
26
+
27
+
28
def noise_like(shape, device, repeat=False):
    """Draw standard-normal noise of the given ``shape`` on ``device``.

    With ``repeat=True`` one sample is drawn and copied along the first
    dimension, so every batch entry gets identical noise.
    """
    if repeat:
        single = torch.randn((1, *shape[1:]), device=device)
        tile = (shape[0],) + (1,) * (len(shape) - 1)
        return single.repeat(*tile)
    return torch.randn(shape, device=device)
34
+
35
+
36
def linear_beta_schedule(timesteps, max_beta=0.02):
    """Linear schedule: ``timesteps`` betas evenly spaced from 1e-4 to ``max_beta``."""
    return np.linspace(1e-4, max_beta, timesteps)
42
+
43
+
44
def cosine_beta_schedule(timesteps, s=0.008):
    """Cosine schedule as proposed in https://openreview.net/forum?id=-NEXDKk8gZ.

    Returns ``timesteps`` betas clipped to the range ``[0, 0.999]``.
    """
    steps = timesteps + 1
    grid = np.linspace(0, steps, steps)
    alphas_cumprod = np.cos(((grid / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    ratios = alphas_cumprod[1:] / alphas_cumprod[:-1]
    return np.clip(1 - ratios, a_min=0, a_max=0.999)
55
+
56
+
57
# Registry mapping a schedule name to the function that builds its beta array.
beta_schedule = {
    "cosine": cosine_beta_schedule,
    "linear": linear_beta_schedule,
}
61
+
62
+
63
class GaussianDiffusion(nn.Module):
    """Gaussian diffusion wrapper around a noise-predicting ``denoise_fn``.

    ``forward`` returns the training loss when ``infer=False`` and a sampled,
    de-normalised spectrogram when ``infer=True``. The noise schedule is always
    the linear one from ``beta_schedule``.
    """

    def __init__(self,
                 denoise_fn,
                 out_dims=128,
                 timesteps=1000,
                 k_step=1000,
                 max_beta=0.02,
                 spec_min=-12,
                 spec_max=2):
        """Pre-compute the schedule constants and register them as buffers.

        Args:
            denoise_fn: callable ``(x, t, cond)`` that predicts the added noise.
            out_dims: number of spectrogram bins produced when sampling.
            timesteps: length of the diffusion chain.
            k_step: upper bound for the sampled step; any value outside
                ``(0, timesteps)`` falls back to ``timesteps``.
            max_beta: last value of the linear beta schedule.
            spec_min: lower bound used to normalise spectrograms to ``[-1, 1]``.
            spec_max: upper bound used to normalise spectrograms to ``[-1, 1]``.
        """
        super().__init__()
        self.denoise_fn = denoise_fn
        self.out_dims = out_dims
        betas = beta_schedule['linear'](timesteps, max_beta=max_beta)

        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.k_step = k_step if k_step>0 and k_step<timesteps else timesteps

        # History of noise predictions used by the PLMS sampler.
        self.noise_list = deque(maxlen=4)

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer('betas', to_torch(betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.register_buffer('posterior_variance', to_torch(posterior_variance))
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
        self.register_buffer('posterior_mean_coef1', to_torch(
            betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
        self.register_buffer('posterior_mean_coef2', to_torch(
            (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))

        self.register_buffer('spec_min', torch.FloatTensor([spec_min])[None, None, :out_dims])
        self.register_buffer('spec_max', torch.FloatTensor([spec_max])[None, None, :out_dims])

    def q_mean_variance(self, x_start, t):
        """Return mean, variance and log-variance of ``q(x_t | x_0)``."""
        mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
        log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
        return mean, variance, log_variance

    def predict_start_from_noise(self, x_t, t, noise):
        """Recover the ``x_0`` estimate implied by ``x_t`` and a noise prediction."""
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def q_posterior(self, x_start, x_t, t):
        """Return mean, variance and clipped log-variance of ``q(x_{t-1} | x_t, x_0)``."""
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, x, t, cond):
        """Posterior statistics using the model's (clamped) ``x_0`` estimate."""
        noise_pred = self.denoise_fn(x, t, cond=cond)
        x_recon = self.predict_start_from_noise(x, t=t, noise=noise_pred)

        # The estimate is always clamped to the normalised range [-1, 1].
        x_recon.clamp_(-1., 1.)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
        return model_mean, posterior_variance, posterior_log_variance

    @torch.no_grad()
    def p_sample_ddim(self, x, t, interval, cond):
        """Take one deterministic DDIM-style step from ``t`` to ``max(t - interval, 0)``."""
        a_t = extract(self.alphas_cumprod, t, x.shape)
        a_prev = extract(self.alphas_cumprod, torch.max(t - interval, torch.zeros_like(t)), x.shape)

        noise_pred = self.denoise_fn(x, t, cond=cond)
        x_prev = a_prev.sqrt() * (x / a_t.sqrt() + (((1 - a_prev) / a_prev).sqrt()-((1 - a_t) / a_t).sqrt()) * noise_pred)
        return x_prev

    @torch.no_grad()
    def p_sample(self, x, t, cond, clip_denoised=True, repeat_noise=False):
        """Take one ancestral sampling step.

        ``clip_denoised`` is accepted for interface compatibility but is not
        read: ``p_mean_variance`` always clamps.
        """
        b, device = x.shape[0], x.device
        model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, cond=cond)
        noise = noise_like(x.shape, device, repeat_noise)
        # no noise when t == 0
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

    @torch.no_grad()
    def p_sample_plms(self, x, t, interval, cond, clip_denoised=True, repeat_noise=False):
        """
        Use the PLMS method from
        [Pseudo Numerical Methods for Diffusion Models on Manifolds](https://arxiv.org/abs/2202.09778).

        ``clip_denoised`` and ``repeat_noise`` are accepted but not read.
        The history in ``self.noise_list`` selects the multistep formula.
        """
        # Element-wise clamp so a batch of timesteps works. The built-in max()
        # used before raised for batch sizes above one and could hand a plain
        # int to denoise_fn.
        t_prev = torch.max(t - interval, torch.zeros_like(t))

        def get_x_pred(x, noise_t, t):
            a_t = extract(self.alphas_cumprod, t, x.shape)
            a_prev = extract(self.alphas_cumprod, torch.max(t - interval, torch.zeros_like(t)), x.shape)
            a_t_sq, a_prev_sq = a_t.sqrt(), a_prev.sqrt()

            x_delta = (a_prev - a_t) * ((1 / (a_t_sq * (a_t_sq + a_prev_sq))) * x - 1 / (
                a_t_sq * (((1 - a_prev) * a_t).sqrt() + ((1 - a_t) * a_prev).sqrt())) * noise_t)
            x_pred = x + x_delta

            return x_pred

        noise_list = self.noise_list
        noise_pred = self.denoise_fn(x, t, cond=cond)

        if len(noise_list) == 0:
            # No history yet: average the prediction with a look-ahead one.
            x_pred = get_x_pred(x, noise_pred, t)
            noise_pred_prev = self.denoise_fn(x_pred, t_prev, cond=cond)
            noise_pred_prime = (noise_pred + noise_pred_prev) / 2
        elif len(noise_list) == 1:
            noise_pred_prime = (3 * noise_pred - noise_list[-1]) / 2
        elif len(noise_list) == 2:
            noise_pred_prime = (23 * noise_pred - 16 * noise_list[-1] + 5 * noise_list[-2]) / 12
        else:
            noise_pred_prime = (55 * noise_pred - 59 * noise_list[-1] + 37 * noise_list[-2] - 9 * noise_list[-3]) / 24

        x_prev = get_x_pred(x, noise_pred_prime, t)
        noise_list.append(noise_pred)

        return x_prev

    def q_sample(self, x_start, t, noise=None):
        """Diffuse ``x_start`` to step ``t`` (fresh Gaussian noise if none is given)."""
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    def p_losses(self, x_start, t, cond, noise=None, loss_type='l2'):
        """Noise-prediction loss at step ``t``.

        Raises:
            NotImplementedError: for a ``loss_type`` other than ``'l1'`` / ``'l2'``.
        """
        noise = default(noise, lambda: torch.randn_like(x_start))

        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        x_recon = self.denoise_fn(x_noisy, t, cond)

        if loss_type == 'l1':
            loss = (noise - x_recon).abs().mean()
        elif loss_type == 'l2':
            loss = F.mse_loss(noise, x_recon)
        else:
            raise NotImplementedError()

        return loss

    @staticmethod
    def _progress(steps, total, use_tqdm):
        """Wrap the step iterable in a tqdm bar when requested."""
        if use_tqdm:
            return tqdm(steps, desc='sample time step', total=total)
        return steps

    def _track_calls(self, fn, use_tqdm):
        """Wrap ``fn`` so every call advances ``self.bar`` when tqdm is on."""
        def wrapped(x, t, **kwargs):
            ret = fn(x, t, **kwargs)
            if use_tqdm:
                self.bar.update(1)
            return ret

        return wrapped

    def _run_solver(self, solver, x, steps, use_tqdm):
        """Run a multistep order-2 solver, managing the optional progress bar."""
        if use_tqdm:
            self.bar = tqdm(desc="sample time step", total=steps)
        x = solver.sample(
            x,
            steps=steps,
            order=2,
            skip_type="time_uniform",
            method="multistep",
        )
        if use_tqdm:
            self.bar.close()
        return x

    def forward(self,
                condition,
                gt_spec=None,
                infer=True,
                infer_speedup=10,
                method='dpm-solver',
                k_step=300,
                use_tqdm=True):
        """
        Conditional diffusion: training loss or sampling.

        Args:
            condition: conditioning features; dimensions 1 and 2 are swapped
                before being passed to ``denoise_fn``.
            gt_spec: ground-truth spectrogram. Required for training; at
                inference, if given, sampling starts from it diffused to
                ``k_step`` instead of from pure noise.
            infer: sample when true, return the loss when false.
            infer_speedup: step stride; values <= 1 use plain ancestral sampling.
            method: ``'dpm-solver'``, ``'dpm-solver++'``, ``'pndm'``, ``'ddim'``
                or ``'unipc'``; ``None`` uses plain ancestral sampling.
            k_step: diffusion depth used only together with ``gt_spec``.
            use_tqdm: show a progress bar while sampling.

        Raises:
            NotImplementedError: for an unknown ``method``.
        """
        cond = condition.transpose(1, 2)
        b, device = condition.shape[0], condition.device

        if not infer:
            spec = self.norm_spec(gt_spec)
            t = torch.randint(0, self.k_step, (b,), device=device).long()
            norm_spec = spec.transpose(1, 2)[:, None, :, :]  # [B, 1, M, T]
            return self.p_losses(norm_spec, t, cond=cond)

        shape = (cond.shape[0], 1, self.out_dims, cond.shape[2])

        if gt_spec is None:
            t = self.k_step
            x = torch.randn(shape, device=device)
        else:
            t = k_step
            norm_spec = self.norm_spec(gt_spec)
            norm_spec = norm_spec.transpose(1, 2)[:, None, :, :]
            x = self.q_sample(x_start=norm_spec, t=torch.tensor([t - 1], device=device).long())

        if method is not None and infer_speedup > 1:
            if method == 'dpm-solver' or method == 'dpm-solver++':
                from .dpm_solver_pytorch import (
                    DPM_Solver,
                    NoiseScheduleVP,
                    model_wrapper,
                )
                # 1. Define the noise schedule.
                noise_schedule = NoiseScheduleVP(schedule='discrete', betas=self.betas[:t])

                # 2. Convert the discrete-time noise-prediction model to the
                # continuous-time form expected by the solver.
                model_fn = model_wrapper(
                    self._track_calls(self.denoise_fn, use_tqdm),
                    noise_schedule,
                    model_type="noise",  # or "x_start" or "v" or "score"
                    model_kwargs={"cond": cond}
                )

                # 3. Define dpm-solver and sample with it.
                algorithm_type = "dpmsolver" if method == 'dpm-solver' else "dpmsolver++"
                dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type=algorithm_type)
                x = self._run_solver(dpm_solver, x, t // infer_speedup, use_tqdm)
            elif method == 'pndm':
                # Start every sampling run with an empty PLMS history.
                self.noise_list = deque(maxlen=4)
                for i in self._progress(reversed(range(0, t, infer_speedup)), t // infer_speedup, use_tqdm):
                    x = self.p_sample_plms(
                        x, torch.full((b,), i, device=device, dtype=torch.long),
                        infer_speedup, cond=cond
                    )
            elif method == 'ddim':
                for i in self._progress(reversed(range(0, t, infer_speedup)), t // infer_speedup, use_tqdm):
                    x = self.p_sample_ddim(
                        x, torch.full((b,), i, device=device, dtype=torch.long),
                        infer_speedup, cond=cond
                    )
            elif method == 'unipc':
                from .uni_pc import NoiseScheduleVP, UniPC, model_wrapper
                # 1. Define the noise schedule.
                noise_schedule = NoiseScheduleVP(schedule='discrete', betas=self.betas[:t])

                # 2. Convert the discrete-time noise-prediction model to the
                # continuous-time form expected by the solver.
                model_fn = model_wrapper(
                    self._track_calls(self.denoise_fn, use_tqdm),
                    noise_schedule,
                    model_type="noise",  # or "x_start" or "v" or "score"
                    model_kwargs={"cond": cond}
                )

                # 3. Define uni_pc and sample by multistep UniPC.
                uni_pc = UniPC(model_fn, noise_schedule, variant='bh2')
                x = self._run_solver(uni_pc, x, t // infer_speedup, use_tqdm)
            else:
                raise NotImplementedError(method)
        else:
            for i in self._progress(reversed(range(0, t)), t, use_tqdm):
                x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond)
        x = x.squeeze(1).transpose(1, 2)  # [B, T, M]
        return self.denorm_spec(x)

    def norm_spec(self, x):
        """Map ``[spec_min, spec_max]`` linearly onto ``[-1, 1]``."""
        return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1

    def denorm_spec(self, x):
        """Inverse of :meth:`norm_spec`."""
        return (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min
diffusion/diffusion_onnx.py ADDED
@@ -0,0 +1,614 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from collections import deque
3
+ from functools import partial
4
+ from inspect import isfunction
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from torch import nn
10
+ from torch.nn import Conv1d, Mish
11
+ from tqdm import tqdm
12
+
13
+
14
def exists(x):
    """Return ``True`` for any value other than ``None`` (falsy values included)."""
    if x is None:
        return False
    return True
16
+
17
+
18
def default(val, d):
    """Return ``val`` unless it is ``None``, otherwise the fallback ``d``.

    If the fallback is a plain Python function (as judged by
    ``inspect.isfunction``) it is called with no arguments and its result is
    returned, so an expensive fallback is only built when needed.
    """
    if val is None:
        return d() if isfunction(d) else d
    return val
22
+
23
+
24
def extract(a, t, x_shape=None):
    """Look up the schedule coefficients ``a`` at timestep index/indices ``t``.

    Args:
        a: 1-D tensor of per-timestep coefficients.
        t: Integer tensor of timestep indices.
        x_shape: Optional shape of the tensor the result is multiplied with.
            Several call sites in this file pass ``x.shape`` as a third
            argument; without this parameter those calls raise ``TypeError``.

    Returns:
        With ``x_shape`` omitted: ``a[t]`` reshaped to ``(1, 1, 1, 1)``
        (the original behaviour, which requires a single timestep).
        With ``x_shape`` given: a tensor of shape ``(len(t), 1, ..., 1)`` with
        as many dimensions as ``x_shape``, so it broadcasts per batch item.
    """
    if x_shape is None:
        # Legacy two-argument form: single timestep, fixed 4-D result.
        return a[t].reshape((1, 1, 1, 1))
    batch = t.shape[0]
    picked = a.gather(-1, t)
    return picked.reshape(batch, *((1,) * (len(x_shape) - 1)))
26
+
27
+
28
def noise_like(shape, device, repeat=False):
    """Draw standard-normal noise of the given ``shape`` on ``device``.

    With ``repeat=True`` a single sample of shape ``(1, *shape[1:])`` is drawn
    and tiled along the first dimension, so every batch item gets the same
    noise.
    """
    if not repeat:
        return torch.randn(shape, device=device)
    single = torch.randn((1, *shape[1:]), device=device)
    tiling = (shape[0],) + (1,) * (len(shape) - 1)
    return single.repeat(*tiling)
34
+
35
+
36
def linear_beta_schedule(timesteps, max_beta=0.02):
    """Linear schedule: ``timesteps`` betas evenly spaced from 1e-4 to ``max_beta``."""
    return np.linspace(1e-4, max_beta, timesteps)
42
+
43
+
44
def cosine_beta_schedule(timesteps, s=0.008):
    """Cosine schedule.

    As proposed in https://openreview.net/forum?id=-NEXDKk8gZ

    Returns an array of ``timesteps`` betas, each clipped to ``[0, 0.999]``.
    """
    n_points = timesteps + 1
    grid = np.linspace(0, n_points, n_points)
    phase = ((grid / n_points) + s) / (1 + s) * np.pi * 0.5
    cumulative = np.cos(phase) ** 2
    # Normalise so the cumulative product starts at exactly 1.
    cumulative = cumulative / cumulative[0]
    step_ratio = cumulative[1:] / cumulative[:-1]
    return np.clip(1 - step_ratio, a_min=0, a_max=0.999)
55
+
56
+
57
# Registry mapping a schedule name to the function that builds its betas.
beta_schedule = dict(
    cosine=cosine_beta_schedule,
    linear=linear_beta_schedule,
)
61
+
62
+
63
def extract_1(a, t):
    """Index ``a`` at ``t`` and reshape the result to ``(1, 1, 1, 1)``."""
    picked = a[t]
    return picked.reshape((1, 1, 1, 1))
65
+
66
+
67
def predict_stage0(noise_pred, noise_pred_prev):
    """PLMS warm-up step: average the two noise predictions."""
    total = noise_pred + noise_pred_prev
    return total / 2
69
+
70
+
71
def predict_stage1(noise_pred, noise_list):
    """Combine the current prediction with the latest stored one (weights 3, -1 over 2)."""
    latest = noise_list[-1]
    weighted = noise_pred * 3 - latest
    return weighted / 2
74
+
75
+
76
def predict_stage2(noise_pred, noise_list):
    """Combine the current prediction with the two latest stored ones (weights 23, -16, 5 over 12)."""
    latest, older = noise_list[-1], noise_list[-2]
    weighted = noise_pred * 23 - latest * 16 + older * 5
    return weighted / 12
80
+
81
+
82
def predict_stage3(noise_pred, noise_list):
    """Combine the current prediction with the three latest stored ones (weights 55, -59, 37, -9 over 24)."""
    latest, older, oldest = noise_list[-1], noise_list[-2], noise_list[-3]
    weighted = noise_pred * 55 - latest * 59 + older * 37 - oldest * 9
    return weighted / 24
87
+
88
+
89
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of a scalar step into a ``dim``-sized vector.

    The frequency table is computed once in the constructor and kept as a
    plain tensor attribute (``self.emb``), not as a registered buffer.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.half_dim = dim // 2
        # 9.21034037 is numerically ln(10000).
        rate = 9.21034037 / (self.half_dim - 1)
        freqs = torch.exp(torch.arange(self.half_dim) * torch.tensor(-rate))
        self.emb = freqs.unsqueeze(0).cpu()

    def forward(self, x):
        """Return ``[sin(x * f), cos(x * f)]`` concatenated on the last axis."""
        scaled = self.emb * x
        return torch.cat((scaled.sin(), scaled.cos()), dim=-1)
102
+
103
+
104
class ResidualBlock(nn.Module):
    """Gated dilated-convolution block returning a residual and a skip output."""

    def __init__(self, encoder_hidden, residual_channels, dilation):
        super().__init__()
        self.residual_channels = residual_channels
        doubled = 2 * residual_channels
        # Sub-modules are created in this order to keep parameter
        # initialisation and state-dict ordering unchanged.
        self.dilated_conv = Conv1d(residual_channels, doubled, 3, padding=dilation, dilation=dilation)
        self.diffusion_projection = nn.Linear(residual_channels, residual_channels)
        self.conditioner_projection = Conv1d(encoder_hidden, doubled, 1)
        self.output_projection = Conv1d(residual_channels, doubled, 1)

    def forward(self, x, conditioner, diffusion_step):
        """Return ``((x + residual) / 1.41421356, skip)``.

        The step embedding is projected and added to ``x`` before the dilated
        convolution; the projected conditioner is added after it. The result
        is split into gate/filter halves along dim 1, gated, projected, and
        split again into the residual and skip halves.
        """
        halves = [self.residual_channels, self.residual_channels]
        step_bias = self.diffusion_projection(diffusion_step).unsqueeze(-1)
        cond_bias = self.conditioner_projection(conditioner)
        hidden = self.dilated_conv(x + step_bias) + cond_bias

        gate, filter_1 = torch.split(hidden, halves, dim=1)
        projected = self.output_projection(torch.sigmoid(gate) * torch.tanh(filter_1))

        residual, skip = torch.split(projected, halves, dim=1)
        return (x + residual) / 1.41421356, skip
127
+
128
+
129
class DiffNet(nn.Module):
    """Denoiser: a stack of ``ResidualBlock`` layers conditioned on a step embedding."""

    def __init__(self, in_dims, n_layers, n_chans, n_hidden):
        super().__init__()
        self.encoder_hidden = n_hidden
        self.residual_channels = n_chans
        # Sub-modules are created in this order to keep parameter
        # initialisation and state-dict ordering unchanged.
        self.input_projection = Conv1d(in_dims, n_chans, 1)
        self.diffusion_embedding = SinusoidalPosEmb(n_chans)
        self.mlp = nn.Sequential(
            nn.Linear(n_chans, n_chans * 4),
            Mish(),
            nn.Linear(n_chans * 4, n_chans)
        )
        # Every block is built with dilation 1.
        blocks = [ResidualBlock(n_hidden, n_chans, 1) for _ in range(n_layers)]
        self.residual_layers = nn.ModuleList(blocks)
        self.skip_projection = Conv1d(n_chans, n_chans, 1)
        self.output_projection = Conv1d(n_chans, in_dims, 1)
        # The final projection starts with zero weights (its bias is left as is).
        nn.init.zeros_(self.output_projection.weight)

    def forward(self, spec, diffusion_step, cond):
        """Predict noise for ``spec`` at ``diffusion_step`` given ``cond``.

        ``spec`` has its dim 0 squeezed before the input projection and the
        result gets a new dim 1 on the way out. Skip outputs of all blocks are
        summed and divided by ``sqrt(number of blocks)``.
        """
        hidden = F.relu(self.input_projection(spec.squeeze(0)))  # [B, residual_channel, T]
        step_emb = self.mlp(self.diffusion_embedding(diffusion_step.float()))

        blocks = self.residual_layers
        hidden, skip_total = blocks[0](hidden, cond, step_emb)
        # noinspection PyTypeChecker
        for block in blocks[1:]:
            hidden, skip_part = block.forward(hidden, cond, step_emb)
            skip_total = skip_total + skip_part

        out = skip_total / math.sqrt(len(blocks))
        out = F.relu(self.skip_projection(out))
        out = self.output_projection(out)  # [B, 80, T]
        return out.unsqueeze(1)
170
+
171
+
172
class AfterDiffusion(nn.Module):
    """Post-processing step: de-normalise the sampled spectrogram."""

    def __init__(self, spec_max, spec_min, v_type='a'):
        super().__init__()
        self.spec_max = spec_max
        self.spec_min = spec_min
        self.type = v_type

    def forward(self, x):
        """Map ``x`` from [-1, 1] back to the ``[spec_min, spec_max]`` range.

        Dim 1 of ``x`` is squeezed away; the output has the same axis order as
        the squeezed input (the permute and the final transpose cancel out).
        """
        frames_first = x.squeeze(1).permute(0, 2, 1)
        span = self.spec_max - self.spec_min
        mel_out = (frames_first + 1) / 2 * span + self.spec_min
        if self.type == 'nsf-hifigan-log10':
            # 0.434294 is numerically log10(e); presumably converts a
            # natural-log mel to log10 -- confirm against the vocoder.
            mel_out = mel_out * 0.434294
        return mel_out.transpose(2, 1)
185
+
186
+
187
class Pred(nn.Module):
    """Single update step ``x_t -> x_{t_prev}`` from a noise estimate.

    Wrapped as a module so it can be exported on its own; all intermediate
    tensors are moved to the CPU.
    """

    def __init__(self, alphas_cumprod):
        super().__init__()
        self.alphas_cumprod = alphas_cumprod

    def forward(self, x_1, noise_t, t_1, t_prev):
        """Return ``x_1`` advanced from timestep ``t_1`` to ``t_prev``."""
        # Coefficient lookups, each reshaped to (1, 1, 1, 1) for broadcasting.
        a_t = self.alphas_cumprod[t_1].reshape((1, 1, 1, 1)).cpu()
        a_prev = self.alphas_cumprod[t_prev].reshape((1, 1, 1, 1)).cpu()
        root_t = a_t.sqrt().cpu()
        root_prev = a_prev.sqrt().cpu()

        x_coef = 1 / (root_t * (root_t + root_prev))
        noise_coef = 1 / (
            root_t * (((1 - a_prev) * a_t).sqrt() + ((1 - a_t) * a_prev).sqrt()))
        x_delta = (a_prev - a_t) * (x_coef * x_1 - noise_coef * noise_t)
        return x_1 + x_delta.cpu()
201
+
202
+
203
class GaussianDiffusion(nn.Module):
    """Gaussian diffusion around a ``DiffNet`` denoiser, with ONNX export helpers.

    The class offers a training loss (``p_losses``), PyTorch samplers
    (``org_forward``: ancestral, PLMS/"pndm" and DPM-Solver), an exporter
    (``OnnxExport``) that writes the denoiser, the update step and the
    post-processing as separate ONNX files, and ``forward``, a PLMS loop
    written with tensor-only state.
    """

    def __init__(self,
                 out_dims=128,
                 n_layers=20,
                 n_chans=384,
                 n_hidden=256,
                 timesteps=1000,
                 k_step=1000,
                 max_beta=0.02,
                 spec_min=-12,
                 spec_max=2):
        super().__init__()
        self.denoise_fn = DiffNet(out_dims, n_layers, n_chans, n_hidden)
        self.out_dims = out_dims
        self.mel_bins = out_dims
        self.n_hidden = n_hidden
        betas = beta_schedule['linear'](timesteps, max_beta=max_beta)

        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.k_step = k_step

        # History of past noise predictions used by the PLMS sampler.
        self.noise_list = deque(maxlen=4)

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer('betas', to_torch(betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.register_buffer('posterior_variance', to_torch(posterior_variance))
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
        self.register_buffer('posterior_mean_coef1', to_torch(
            betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
        self.register_buffer('posterior_mean_coef2', to_torch(
            (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))

        self.register_buffer('spec_min', torch.FloatTensor([spec_min])[None, None, :out_dims])
        self.register_buffer('spec_max', torch.FloatTensor([spec_max])[None, None, :out_dims])
        # Export-friendly sub-modules: post-processing and the single update step.
        self.ad = AfterDiffusion(self.spec_max, self.spec_min)
        self.xp = Pred(self.alphas_cumprod)

    @staticmethod
    def _extract_like(a, t, x_shape):
        """Gather ``a`` at indices ``t`` and reshape to broadcast against ``x_shape``.

        The module-level ``extract`` helper is called elsewhere in this class
        with two arguments only, so per-batch lookups that need the target
        shape go through this helper instead.
        """
        batch = t.shape[0]
        return a.gather(-1, t).reshape(batch, *((1,) * (len(x_shape) - 1)))

    def q_mean_variance(self, x_start, t):
        """Return mean, variance and log-variance of ``q(x_t | x_0)``."""
        mean = self._extract_like(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        variance = self._extract_like(1. - self.alphas_cumprod, t, x_start.shape)
        log_variance = self._extract_like(self.log_one_minus_alphas_cumprod, t, x_start.shape)
        return mean, variance, log_variance

    def predict_start_from_noise(self, x_t, t, noise):
        """Recover the ``x_0`` estimate from ``x_t`` and a noise prediction."""
        return (
            self._extract_like(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
            self._extract_like(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def q_posterior(self, x_start, x_t, t):
        """Return mean, variance and clipped log-variance of ``q(x_{t-1} | x_t, x_0)``."""
        posterior_mean = (
            self._extract_like(self.posterior_mean_coef1, t, x_t.shape) * x_start +
            self._extract_like(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = self._extract_like(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = self._extract_like(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, x, t, cond):
        """Run the denoiser and return the parameters of ``p(x_{t-1} | x_t)``."""
        noise_pred = self.denoise_fn(x, t, cond=cond)
        x_recon = self.predict_start_from_noise(x, t=t, noise=noise_pred)

        # The reconstruction is always clamped to the normalised range.
        x_recon.clamp_(-1., 1.)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
        return model_mean, posterior_variance, posterior_log_variance

    @torch.no_grad()
    def p_sample(self, x, t, cond, clip_denoised=True, repeat_noise=False):
        """One ancestral sampling step. ``clip_denoised`` is accepted but not used."""
        b, *_, device = *x.shape, x.device
        model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, cond=cond)
        noise = noise_like(x.shape, device, repeat_noise)
        # no noise when t == 0
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

    @torch.no_grad()
    def p_sample_plms(self, x, t, interval, cond, clip_denoised=True, repeat_noise=False):
        """
        Use the PLMS method from
        [Pseudo Numerical Methods for Diffusion Models on Manifolds](https://arxiv.org/abs/2202.09778).

        One step from ``t`` to ``max(t - interval, 0)``. Noise predictions are
        accumulated in ``self.noise_list``. ``clip_denoised`` and
        ``repeat_noise`` are accepted but not used.
        """
        # Keep the previous timestep a tensor: the builtin ``max`` would
        # return the Python int 0 on underflow, which the denoiser cannot
        # consume, and is ambiguous for multi-element ``t``.
        t_prev = torch.clamp(t - interval, min=0)

        noise_list = self.noise_list
        noise_pred = self.denoise_fn(x, t, cond=cond)

        if len(noise_list) == 0:
            x_pred = self.get_x_pred(x, noise_pred, t, t_prev)
            noise_pred_prev = self.denoise_fn(x_pred, t_prev, cond=cond)
            noise_pred_prime = (noise_pred + noise_pred_prev) / 2
        elif len(noise_list) == 1:
            noise_pred_prime = (3 * noise_pred - noise_list[-1]) / 2
        elif len(noise_list) == 2:
            noise_pred_prime = (23 * noise_pred - 16 * noise_list[-1] + 5 * noise_list[-2]) / 12
        else:
            noise_pred_prime = (55 * noise_pred - 59 * noise_list[-1] + 37 * noise_list[-2] - 9 * noise_list[-3]) / 24

        x_prev = self.get_x_pred(x, noise_pred_prime, t, t_prev)
        noise_list.append(noise_pred)

        return x_prev

    def q_sample(self, x_start, t, noise=None):
        """Diffuse ``x_start`` to timestep ``t``; noise is drawn if not supplied."""
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (
            self._extract_like(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            self._extract_like(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    def p_losses(self, x_start, t, cond, noise=None, loss_type='l2'):
        """Noise-prediction loss at timestep ``t``.

        Raises:
            NotImplementedError: if ``loss_type`` is neither ``'l1'`` nor ``'l2'``.
        """
        noise = default(noise, lambda: torch.randn_like(x_start))

        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        x_recon = self.denoise_fn(x_noisy, t, cond)

        if loss_type == 'l1':
            loss = (noise - x_recon).abs().mean()
        elif loss_type == 'l2':
            loss = F.mse_loss(noise, x_recon)
        else:
            raise NotImplementedError()

        return loss

    def org_forward(self,
                    condition,
                    init_noise=None,
                    gt_spec=None,
                    infer=True,
                    infer_speedup=100,
                    method='pndm',
                    k_step=1000,
                    use_tqdm=True):
        """
        conditioning diffusion, use fastspeech2 encoder output as the condition

        With ``infer=False`` the training loss for ``gt_spec`` is returned.
        Otherwise a spectrogram is sampled: from ``init_noise`` (or fresh
        noise) over ``self.k_step`` steps when ``gt_spec`` is ``None``, or
        from ``gt_spec`` diffused to step ``k_step - 1`` otherwise.

        Raises:
            NotImplementedError: for an unknown ``method`` when
                ``infer_speedup > 1``.
        """
        cond = condition
        b, device = condition.shape[0], condition.device
        if not infer:
            spec = self.norm_spec(gt_spec)
            t = torch.randint(0, self.k_step, (b,), device=device).long()
            norm_spec = spec.transpose(1, 2)[:, None, :, :]  # [B, 1, M, T]
            return self.p_losses(norm_spec, t, cond=cond)

        shape = (cond.shape[0], 1, self.out_dims, cond.shape[2])

        if gt_spec is None:
            t = self.k_step
            if init_noise is None:
                x = torch.randn(shape, device=device)
            else:
                x = init_noise
        else:
            t = k_step
            norm_spec = self.norm_spec(gt_spec)
            norm_spec = norm_spec.transpose(1, 2)[:, None, :, :]
            x = self.q_sample(x_start=norm_spec, t=torch.tensor([t - 1], device=device).long())

        if method is not None and infer_speedup > 1:
            if method == 'dpm-solver':
                from .dpm_solver_pytorch import (
                    DPM_Solver,
                    NoiseScheduleVP,
                    model_wrapper,
                )
                # 1. Define the noise schedule.
                noise_schedule = NoiseScheduleVP(schedule='discrete', betas=self.betas[:t])

                # 2. Convert your discrete-time `model` to the continuous-time
                # noise prediction model. Here is an example for a diffusion model
                # `model` with the noise prediction type ("noise") .
                def my_wrapper(fn):
                    def wrapped(x, t, **kwargs):
                        ret = fn(x, t, **kwargs)
                        if use_tqdm:
                            self.bar.update(1)
                        return ret

                    return wrapped

                model_fn = model_wrapper(
                    my_wrapper(self.denoise_fn),
                    noise_schedule,
                    model_type="noise",  # or "x_start" or "v" or "score"
                    model_kwargs={"cond": cond}
                )

                # 3. Define dpm-solver and sample by singlestep DPM-Solver.
                # (We recommend singlestep DPM-Solver for unconditional sampling)
                # You can adjust the `steps` to balance the computation
                # costs and the sample quality.
                dpm_solver = DPM_Solver(model_fn, noise_schedule)

                steps = t // infer_speedup
                if use_tqdm:
                    self.bar = tqdm(desc="sample time step", total=steps)
                x = dpm_solver.sample(
                    x,
                    steps=steps,
                    order=3,
                    skip_type="time_uniform",
                    method="singlestep",
                )
                if use_tqdm:
                    self.bar.close()
            elif method == 'pndm':
                # Start each sampling run with an empty PLMS history.
                self.noise_list = deque(maxlen=4)
                schedule = reversed(range(0, t, infer_speedup))
                if use_tqdm:
                    schedule = tqdm(schedule, desc='sample time step', total=t // infer_speedup)
                for i in schedule:
                    x = self.p_sample_plms(
                        x, torch.full((b,), i, device=device, dtype=torch.long),
                        infer_speedup, cond=cond
                    )
            else:
                raise NotImplementedError(method)
        else:
            schedule = reversed(range(0, t))
            if use_tqdm:
                schedule = tqdm(schedule, desc='sample time step', total=t)
            for i in schedule:
                x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond)
        x = x.squeeze(1).transpose(1, 2)  # [B, T, M]
        return self.denorm_spec(x).transpose(2, 1)

    def norm_spec(self, x):
        """Linearly rescale ``x`` so ``spec_min`` maps to -1 and ``spec_max`` to 1."""
        return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1

    def denorm_spec(self, x):
        """Inverse of ``norm_spec``."""
        return (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min

    def get_x_pred(self, x_1, noise_t, t_1, t_prev):
        """Advance ``x_1`` from timestep ``t_1`` to ``t_prev`` using ``noise_t``.

        Uses the two-argument ``extract``, which reshapes to ``(1, 1, 1, 1)``,
        so ``t_1`` and ``t_prev`` must each select a single timestep.
        """
        a_t = extract(self.alphas_cumprod, t_1)
        a_prev = extract(self.alphas_cumprod, t_prev)
        a_t_sq, a_prev_sq = a_t.sqrt(), a_prev.sqrt()
        x_delta = (a_prev - a_t) * ((1 / (a_t_sq * (a_t_sq + a_prev_sq))) * x_1 - 1 / (
            a_t_sq * (((1 - a_prev) * a_t).sqrt() + ((1 - a_t) * a_prev).sqrt())) * noise_t)
        x_pred = x_1 + x_delta
        return x_pred

    def OnnxExport(self, project_name=None, init_noise=None, hidden_channels=256, export_denoise=True, export_pred=True, export_after=True):
        """Export the sampler pieces to ONNX and replay a PLMS run in PyTorch.

        Writes ``{project_name}_denoise.onnx``, ``{project_name}_pred.onnx``
        and ``{project_name}_after.onnx`` (each guarded by its ``export_*``
        flag), then prints whether the replayed result equals the
        ``org_forward`` result for the same inputs. ``hidden_channels`` is
        accepted but not used. Returns the post-processed spectrogram.
        """
        # Random conditioning with 10 frames, used as the tracing input.
        cond = torch.randn([1, self.n_hidden, 10]).cpu()
        if init_noise is None:
            x = torch.randn((1, 1, self.mel_bins, cond.shape[2]), dtype=torch.float32).cpu()
        else:
            x = init_noise
        # Step interval; matches the default ``infer_speedup`` of ``org_forward``.
        pndms = 100

        # Reference result from the regular sampler, for the final comparison.
        org_y_x = self.org_forward(cond, init_noise=x)

        device = cond.device
        n_frames = cond.shape[2]
        step_range = torch.arange(0, self.k_step, pndms, dtype=torch.long, device=device).flip(0)
        plms_noise_stage = torch.tensor(0, dtype=torch.long, device=device)
        noise_list = torch.zeros((0, 1, 1, self.mel_bins, n_frames), device=device)

        ot = step_range[0]
        ot_1 = torch.full((1,), ot, device=device, dtype=torch.long)
        if export_denoise:
            torch.onnx.export(
                self.denoise_fn,
                (x.cpu(), ot_1.cpu(), cond.cpu()),
                f"{project_name}_denoise.onnx",
                input_names=["noise", "time", "condition"],
                output_names=["noise_pred"],
                dynamic_axes={
                    "noise": [3],
                    "condition": [2]
                },
                opset_version=16
            )

        for t in step_range:
            t_1 = torch.full((1,), t, device=device, dtype=torch.long)
            noise_pred = self.denoise_fn(x, t_1, cond)
            t_prev = t_1 - pndms
            # Zero out negative timesteps while keeping a tensor.
            t_prev = t_prev * (t_prev > 0)
            if plms_noise_stage == 0:
                # Stage 0 only occurs on the first iteration, so the update
                # step is exported once.
                if export_pred:
                    torch.onnx.export(
                        self.xp,
                        (x.cpu(), noise_pred.cpu(), t_1.cpu(), t_prev.cpu()),
                        f"{project_name}_pred.onnx",
                        input_names=["noise", "noise_pred", "time", "time_prev"],
                        output_names=["noise_pred_o"],
                        dynamic_axes={
                            "noise": [3],
                            "noise_pred": [3]
                        },
                        opset_version=16
                    )

                x_pred = self.get_x_pred(x, noise_pred, t_1, t_prev)
                noise_pred_prev = self.denoise_fn(x_pred, t_prev, cond=cond)
                noise_pred_prime = predict_stage0(noise_pred, noise_pred_prev)

            elif plms_noise_stage == 1:
                noise_pred_prime = predict_stage1(noise_pred, noise_list)

            elif plms_noise_stage == 2:
                noise_pred_prime = predict_stage2(noise_pred, noise_list)

            else:
                noise_pred_prime = predict_stage3(noise_pred, noise_list)

            noise_pred = noise_pred.unsqueeze(0)

            if plms_noise_stage < 3:
                noise_list = torch.cat((noise_list, noise_pred), dim=0)
                plms_noise_stage = plms_noise_stage + 1

            else:
                # Keep only the three most recent predictions.
                noise_list = torch.cat((noise_list[-2:], noise_pred), dim=0)

            x = self.get_x_pred(x, noise_pred_prime, t_1, t_prev)
        if export_after:
            torch.onnx.export(
                self.ad,
                x.cpu(),
                f"{project_name}_after.onnx",
                input_names=["x"],
                output_names=["mel_out"],
                dynamic_axes={
                    "x": [3]
                },
                opset_version=16
            )
        x = self.ad(x)

        print((x == org_y_x).all())
        return x

    def forward(self, condition=None, init_noise=None, pndms=None, k_step=None):
        """PLMS sampling loop that keeps its history in a tensor.

        ``pndms`` (step interval) and ``k_step`` (number of diffusion steps)
        are read with ``.item()``, so they must be single-element tensors.
        Returns the post-processed spectrogram from ``self.ad``.
        """
        cond = condition
        x = init_noise

        device = cond.device
        n_frames = cond.shape[2]
        step_range = torch.arange(0, k_step.item(), pndms.item(), dtype=torch.long, device=device).flip(0)
        plms_noise_stage = torch.tensor(0, dtype=torch.long, device=device)
        noise_list = torch.zeros((0, 1, 1, self.mel_bins, n_frames), device=device)

        for t in step_range:
            t_1 = torch.full((1,), t, device=device, dtype=torch.long)
            noise_pred = self.denoise_fn(x, t_1, cond)
            t_prev = t_1 - pndms
            # Zero out negative timesteps while keeping a tensor.
            t_prev = t_prev * (t_prev > 0)
            if plms_noise_stage == 0:
                x_pred = self.get_x_pred(x, noise_pred, t_1, t_prev)
                noise_pred_prev = self.denoise_fn(x_pred, t_prev, cond=cond)
                noise_pred_prime = predict_stage0(noise_pred, noise_pred_prev)

            elif plms_noise_stage == 1:
                noise_pred_prime = predict_stage1(noise_pred, noise_list)

            elif plms_noise_stage == 2:
                noise_pred_prime = predict_stage2(noise_pred, noise_list)

            else:
                noise_pred_prime = predict_stage3(noise_pred, noise_list)

            noise_pred = noise_pred.unsqueeze(0)

            if plms_noise_stage < 3:
                noise_list = torch.cat((noise_list, noise_pred), dim=0)
                plms_noise_stage = plms_noise_stage + 1

            else:
                # Keep only the three most recent predictions.
                noise_list = torch.cat((noise_list[-2:], noise_pred), dim=0)

            x = self.get_x_pred(x, noise_pred_prime, t_1, t_prev)
        x = self.ad(x)
        return x
diffusion/dpm_solver_pytorch.py ADDED
@@ -0,0 +1,1307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ class NoiseScheduleVP:
5
+ def __init__(
6
+ self,
7
+ schedule='discrete',
8
+ betas=None,
9
+ alphas_cumprod=None,
10
+ continuous_beta_0=0.1,
11
+ continuous_beta_1=20.,
12
+ dtype=torch.float32,
13
+ ):
14
+ """Create a wrapper class for the forward SDE (VP type).
15
+
16
+ ***
17
+ Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t.
18
+ We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images.
19
+ ***
20
+
21
+ The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
22
+ We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
23
+ Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
24
+
25
+ log_alpha_t = self.marginal_log_mean_coeff(t)
26
+ sigma_t = self.marginal_std(t)
27
+ lambda_t = self.marginal_lambda(t)
28
+
29
+ Moreover, as lambda(t) is an invertible function, we also support its inverse function:
30
+
31
+ t = self.inverse_lambda(lambda_t)
32
+
33
+ ===============================================================
34
+
35
+ We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
36
+
37
+ 1. For discrete-time DPMs:
38
+
39
+ For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
40
+ t_i = (i + 1) / N
41
+ e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
42
+ We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
43
+
44
+ Args:
45
+ betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
46
+ alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
47
+
48
+ Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
49
+
50
+ **Important**: Please pay special attention for the args for `alphas_cumprod`:
51
+ The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that
52
+ q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
53
+ Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
54
+ alpha_{t_n} = \sqrt{\hat{alpha_n}},
55
+ and
56
+ log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
57
+
58
+
59
+ 2. For continuous-time DPMs:
60
+
61
+ We support the linear VPSDE for the continuous time setting. The hyperparameters for the noise
62
+ schedule are the default settings in Yang Song's ScoreSDE:
63
+
64
+ Args:
65
+ beta_min: A `float` number. The smallest beta for the linear schedule.
66
+ beta_max: A `float` number. The largest beta for the linear schedule.
67
+ T: A `float` number. The ending time of the forward process.
68
+
69
+ ===============================================================
70
+
71
+ Args:
72
+ schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
73
+ 'linear' for continuous-time DPMs.
74
+ Returns:
75
+ A wrapper object of the forward SDE (VP type).
76
+
77
+ ===============================================================
78
+
79
+ Example:
80
+
81
+ # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
82
+ >>> ns = NoiseScheduleVP('discrete', betas=betas)
83
+
84
+ # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
85
+ >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
86
+
87
+ # For continuous-time DPMs (VPSDE), linear schedule:
88
+ >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
89
+
90
+ """
91
+
92
+ if schedule not in ['discrete', 'linear']:
93
+ raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear'".format(schedule))
94
+
95
+ self.schedule = schedule
96
+ if schedule == 'discrete':
97
+ if betas is not None:
98
+ log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
99
+ else:
100
+ assert alphas_cumprod is not None
101
+ log_alphas = 0.5 * torch.log(alphas_cumprod)
102
+ self.T = 1.
103
+ self.log_alpha_array = self.numerical_clip_alpha(log_alphas).reshape((1, -1,)).to(dtype=dtype)
104
+ self.total_N = self.log_alpha_array.shape[1]
105
+ self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)).to(dtype=dtype)
106
+ else:
107
+ self.T = 1.
108
+ self.total_N = 1000
109
+ self.beta_0 = continuous_beta_0
110
+ self.beta_1 = continuous_beta_1
111
+
112
+ def numerical_clip_alpha(self, log_alphas, clipped_lambda=-5.1):
113
+ """
114
+ For some beta schedules such as cosine schedule, the log-SNR has numerical isssues.
115
+ We clip the log-SNR near t=T within -5.1 to ensure the stability.
116
+ Such a trick is very useful for diffusion models with the cosine schedule, such as i-DDPM, guided-diffusion and GLIDE.
117
+ """
118
+ log_sigmas = 0.5 * torch.log(1. - torch.exp(2. * log_alphas))
119
+ lambs = log_alphas - log_sigmas
120
+ idx = torch.searchsorted(torch.flip(lambs, [0]), clipped_lambda)
121
+ if idx > 0:
122
+ log_alphas = log_alphas[:-idx]
123
+ return log_alphas
124
+
125
def marginal_log_mean_coeff(self, t):
    """
    Compute log(alpha_t) of a given continuous-time label t in [0, T].

    For the 'discrete' schedule the value is obtained by calling `interpolate_fn` on the
    stored (`t_array`, `log_alpha_array`) tables; for the 'linear' schedule it is the
    closed-form expression in `beta_0` and `beta_1`.
    """
    kind = self.schedule
    if kind == 'discrete':
        query = t.reshape((-1, 1))
        table_t = self.t_array.to(t.device)
        table_log_alpha = self.log_alpha_array.to(t.device)
        return interpolate_fn(query, table_t, table_log_alpha).reshape((-1))
    if kind == 'linear':
        beta_span = self.beta_1 - self.beta_0
        return -0.25 * t ** 2 * beta_span - 0.5 * t * self.beta_0
133
+
134
def marginal_alpha(self, t):
    """
    Compute alpha_t of a given continuous-time label t in [0, T].

    This is the exponential of `marginal_log_mean_coeff(t)`.
    """
    log_alpha_t = self.marginal_log_mean_coeff(t)
    return torch.exp(log_alpha_t)
139
+
140
def marginal_std(self, t):
    """
    Compute sigma_t of a given continuous-time label t in [0, T].

    Evaluated as sqrt(1 - alpha_t^2), with alpha_t^2 taken from `marginal_log_mean_coeff(t)`.
    """
    alpha_squared = torch.exp(2. * self.marginal_log_mean_coeff(t))
    return torch.sqrt(1. - alpha_squared)
145
+
146
def marginal_lambda(self, t):
    """
    Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
    """
    log_alpha_t = self.marginal_log_mean_coeff(t)
    # log(sigma_t) with sigma_t^2 = 1 - alpha_t^2
    log_sigma_t = 0.5 * torch.log(1. - torch.exp(2. * log_alpha_t))
    return log_alpha_t - log_sigma_t
153
+
154
def inverse_lambda(self, lamb):
    """
    Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.

    The 'linear' schedule is inverted in closed form; the 'discrete' schedule is inverted by
    calling `interpolate_fn` on the flipped (`log_alpha_array`, `t_array`) tables.
    """
    kind = self.schedule
    if kind == 'linear':
        beta_span = self.beta_1 - self.beta_0
        zero = torch.zeros((1,)).to(lamb)
        scaled = 2. * beta_span * torch.logaddexp(-2. * lamb, zero)
        discriminant = self.beta_0**2 + scaled
        return scaled / (torch.sqrt(discriminant) + self.beta_0) / beta_span
    if kind == 'discrete':
        zero = torch.zeros((1,)).to(lamb.device)
        log_alpha = -0.5 * torch.logaddexp(zero, -2. * lamb)
        flipped_log_alpha = torch.flip(self.log_alpha_array.to(lamb.device), [1])
        flipped_t = torch.flip(self.t_array.to(lamb.device), [1])
        t = interpolate_fn(log_alpha.reshape((-1, 1)), flipped_log_alpha, flipped_t)
        return t.reshape((-1,))
166
+
167
+
168
def model_wrapper(
    model,
    noise_schedule,
    model_type="noise",
    model_kwargs=None,
    guidance_type="uncond",
    condition=None,
    unconditional_condition=None,
    guidance_scale=1.,
    classifier_fn=None,
    classifier_kwargs=None,
):
    """Create a wrapper function for the noise prediction model.

    DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we
    first wrap the model function into a noise prediction model that accepts the continuous time as the input.

    We support four types of the diffusion model by setting `model_type`:

        1. "noise": noise prediction model. (Trained by predicting noise).

        2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).

        3. "v": velocity prediction model. (Trained by predicting the velocity).
            The "v" prediction is derived in Appendix D of [1], and is used in Imagen-Video [2].

            [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
                arXiv preprint arXiv:2202.00512 (2022).
            [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
                arXiv preprint arXiv:2210.02303 (2022).

        4. "score": marginal score function. (Trained by denoising score matching).
            The score function and the noise prediction model follow a simple relationship:
            ```
                noise(x_t, t) = -sigma_t * score(x_t, t)
            ```

    We support three types of guided sampling by DPMs by setting `guidance_type`:

        1. "uncond": unconditional sampling by DPMs.
            The input `model` has the following format:
            ```
                model(x, t_input, **model_kwargs) -> noise | x_start | v | score
            ```

        2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
            The input `model` has the following format:
            ```
                model(x, t_input, **model_kwargs) -> noise | x_start | v | score
            ```
            The input `classifier_fn` has the following format:
            ```
                classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
            ```

            [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
                in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.

        3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
            The input `model` has the following format:
            ```
                model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
            ```
            And if cond == `unconditional_condition`, the model output is the unconditional DPM output.

            [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
                arXiv preprint arXiv:2207.12598 (2022).

    The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
    or continuous-time labels (i.e. epsilon to T).

    We wrap the model function to accept only `x` and `t_continuous` as inputs, and output the predicted noise:
    ```
        def model_fn(x, t_continuous) -> noise:
            t_input = get_model_input_time(t_continuous)
            return noise_pred(model, x, t_input, **model_kwargs)
    ```
    where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.

    ===============================================================

    Args:
        model: A diffusion model with the corresponding format described above.
        noise_schedule: A noise schedule object, such as NoiseScheduleVP.
        model_type: A `str`. The parameterization type of the diffusion model.
                    "noise" or "x_start" or "v" or "score".
        model_kwargs: A `dict` or None. The other inputs of the model function. None means no extra inputs.
        guidance_type: A `str`. The type of the guidance for sampling.
                    "uncond" or "classifier" or "classifier-free".
        condition: A pytorch tensor. The condition for the guided sampling.
                    Only used for "classifier" or "classifier-free" guidance type.
        unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
                    Only used for "classifier-free" guidance type.
        guidance_scale: A `float`. The scale for the guided sampling.
        classifier_fn: A classifier function. Only used for the classifier guidance.
        classifier_kwargs: A `dict` or None. The other inputs of the classifier function.
                    None means no extra inputs.
    Returns:
        A noise prediction model that accepts the noised data and the continuous time as the inputs.
    Raises:
        AssertionError: If `model_type` or `guidance_type` is not one of the supported values.
    """
    # Fail fast on an unsupported configuration, before any closure is built.
    assert model_type in ["noise", "x_start", "v", "score"]
    assert guidance_type in ["uncond", "classifier", "classifier-free"]

    # `None` replaces the former mutable `{}` defaults, which were shared across calls.
    # A caller-supplied dict is used as-is (not copied).
    model_kwargs = {} if model_kwargs is None else model_kwargs
    classifier_kwargs = {} if classifier_kwargs is None else classifier_kwargs

    def get_model_input_time(t_continuous):
        """
        Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
        For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, N - 1],
        where N is `noise_schedule.total_N`.
        For continuous-time DPMs, we just use `t_continuous`.
        """
        if noise_schedule.schedule == 'discrete':
            return (t_continuous - 1. / noise_schedule.total_N) * noise_schedule.total_N
        else:
            return t_continuous

    def noise_pred_fn(x, t_continuous, cond=None):
        """
        Evaluate `model` and convert its output to a noise prediction according to `model_type`.
        """
        t_input = get_model_input_time(t_continuous)
        if cond is None:
            output = model(x, t_input, **model_kwargs)
        else:
            output = model(x, t_input, cond, **model_kwargs)
        if model_type == "noise":
            return output
        elif model_type == "x_start":
            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
            return (x - expand_dims(alpha_t, x.dim()) * output) / expand_dims(sigma_t, x.dim())
        elif model_type == "v":
            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
            return expand_dims(alpha_t, x.dim()) * output + expand_dims(sigma_t, x.dim()) * x
        elif model_type == "score":
            sigma_t = noise_schedule.marginal_std(t_continuous)
            return -expand_dims(sigma_t, x.dim()) * output

    def cond_grad_fn(x, t_input):
        """
        Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
        """
        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
            return torch.autograd.grad(log_prob.sum(), x_in)[0]

    def model_fn(x, t_continuous):
        """
        The noise prediction model function that is used for DPM-Solver.
        """
        if guidance_type == "uncond":
            return noise_pred_fn(x, t_continuous)
        elif guidance_type == "classifier":
            assert classifier_fn is not None
            t_input = get_model_input_time(t_continuous)
            cond_grad = cond_grad_fn(x, t_input)
            sigma_t = noise_schedule.marginal_std(t_continuous)
            noise = noise_pred_fn(x, t_continuous)
            return noise - guidance_scale * expand_dims(sigma_t, x.dim()) * cond_grad
        elif guidance_type == "classifier-free":
            if guidance_scale == 1. or unconditional_condition is None:
                return noise_pred_fn(x, t_continuous, cond=condition)
            else:
                # Evaluate the unconditional and conditional branches in one batched call.
                x_in = torch.cat([x] * 2)
                t_in = torch.cat([t_continuous] * 2)
                c_in = torch.cat([unconditional_condition, condition])
                noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
                return noise_uncond + guidance_scale * (noise - noise_uncond)

    return model_fn
333
+
334
+
335
+ class DPM_Solver:
336
+ def __init__(
337
+ self,
338
+ model_fn,
339
+ noise_schedule,
340
+ algorithm_type="dpmsolver++",
341
+ correcting_x0_fn=None,
342
+ correcting_xt_fn=None,
343
+ thresholding_max_val=1.,
344
+ dynamic_thresholding_ratio=0.995,
345
+ ):
346
+ """Construct a DPM-Solver.
347
+
348
+ We support both DPM-Solver (`algorithm_type="dpmsolver"`) and DPM-Solver++ (`algorithm_type="dpmsolver++"`).
349
+
350
+ We also support the "dynamic thresholding" method in Imagen[1]. For pixel-space diffusion models, you
351
+ can set both `algorithm_type="dpmsolver++"` and `correcting_x0_fn="dynamic_thresholding"` to use the
352
+ dynamic thresholding. The "dynamic thresholding" can greatly improve the sample quality for pixel-space
353
+ DPMs with large guidance scales. Note that the thresholding method is **unsuitable** for latent-space
354
+ DPMs (such as stable-diffusion).
355
+
356
+ To support advanced algorithms in image-to-image applications, we also support corrector functions for
357
+ both x0 and xt.
358
+
359
+ Args:
360
+ model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]):
361
+ ``
362
+ def model_fn(x, t_continuous):
363
+ return noise
364
+ ``
365
+ The shape of `x` is `(batch_size, **shape)`, and the shape of `t_continuous` is `(batch_size,)`.
366
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
367
+ algorithm_type: A `str`. Either "dpmsolver" or "dpmsolver++".
368
+ correcting_x0_fn: A `str` or a function with the following format:
369
+ ```
370
+ def correcting_x0_fn(x0, t):
371
+ x0_new = ...
372
+ return x0_new
373
+ ```
374
+ This function is to correct the outputs of the data prediction model at each sampling step. e.g.,
375
+ ```
376
+ x0_pred = data_pred_model(xt, t)
377
+ if correcting_x0_fn is not None:
378
+ x0_pred = correcting_x0_fn(x0_pred, t)
379
+ xt_1 = update(x0_pred, xt, t)
380
+ ```
381
+ If `correcting_x0_fn="dynamic_thresholding"`, we use the dynamic thresholding proposed in Imagen[1].
382
+ correcting_xt_fn: A function with the following format:
383
+ ```
384
+ def correcting_xt_fn(xt, t, step):
385
+ x_new = ...
386
+ return x_new
387
+ ```
388
+ This function is to correct the intermediate samples xt at each sampling step. e.g.,
389
+ ```
390
+ xt = ...
391
+ xt = correcting_xt_fn(xt, t, step)
392
+ ```
393
+ thresholding_max_val: A `float`. The max value for thresholding.
394
+ Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`.
395
+ dynamic_thresholding_ratio: A `float`. The ratio for dynamic thresholding (see Imagen[1] for details).
396
+ Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`.
397
+
398
+ [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour,
399
+ Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models
400
+ with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.
401
+ """
402
+ self.model = lambda x, t: model_fn(x, t.expand((x.shape[0])))
403
+ self.noise_schedule = noise_schedule
404
+ assert algorithm_type in ["dpmsolver", "dpmsolver++"]
405
+ self.algorithm_type = algorithm_type
406
+ if correcting_x0_fn == "dynamic_thresholding":
407
+ self.correcting_x0_fn = self.dynamic_thresholding_fn
408
+ else:
409
+ self.correcting_x0_fn = correcting_x0_fn
410
+ self.correcting_xt_fn = correcting_xt_fn
411
+ self.dynamic_thresholding_ratio = dynamic_thresholding_ratio
412
+ self.thresholding_max_val = thresholding_max_val
413
+
414
def dynamic_thresholding_fn(self, x0, t):
    """
    The dynamic thresholding method.

    Takes the `dynamic_thresholding_ratio` quantile of |x0| per sample, raises it to at least
    `thresholding_max_val`, then clamps `x0` to that range and divides by it. `t` is unused.
    """
    per_sample_abs = torch.abs(x0).reshape((x0.shape[0], -1))
    s = torch.quantile(per_sample_abs, self.dynamic_thresholding_ratio, dim=1)
    lower_bound = self.thresholding_max_val * torch.ones_like(s).to(s.device)
    s = expand_dims(torch.maximum(s, lower_bound), x0.dim())
    return torch.clamp(x0, -s, s) / s
424
+
425
def noise_prediction_fn(self, x, t):
    """
    Return the noise prediction model output for `x` at time `t` (a direct call to `self.model`).
    """
    predicted_noise = self.model(x, t)
    return predicted_noise
430
+
431
def data_prediction_fn(self, x, t):
    """
    Return the data prediction model (with corrector).

    The predicted noise is converted to x0 = (x - sigma_t * noise) / alpha_t, and
    `correcting_x0_fn` is applied to the result when one is configured.
    """
    predicted_noise = self.noise_prediction_fn(x, t)
    schedule = self.noise_schedule
    alpha_t = schedule.marginal_alpha(t)
    sigma_t = schedule.marginal_std(t)
    x0 = (x - sigma_t * predicted_noise) / alpha_t
    corrector = self.correcting_x0_fn
    if corrector is None:
        return x0
    return corrector(x0, t)
441
+
442
def model_fn(self, x, t):
    """
    Convert the model to the noise prediction model or the data prediction model.

    DPM-Solver++ uses the data prediction; any other algorithm type uses the noise prediction.
    """
    wants_data_prediction = self.algorithm_type == "dpmsolver++"
    predictor = self.data_prediction_fn if wants_data_prediction else self.noise_prediction_fn
    return predictor(x, t)
450
+
451
def get_time_steps(self, skip_type, t_T, t_0, N, device):
    """Compute the intermediate time steps for sampling.

    Args:
        skip_type: A `str`. The type for the spacing of the time steps. We support three types:
            - 'logSNR': uniform logSNR for the time steps.
            - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
            - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
        t_T: A `float`. The starting time of the sampling (default is T).
        t_0: A `float`. The ending time of the sampling (default is epsilon).
        N: A `int`. The total number of the spacing of the time steps.
        device: A torch device.
    Returns:
        A pytorch tensor of the time steps, with the shape (N + 1,).
    Raises:
        ValueError: If `skip_type` is not one of the supported spacing types.
    """
    num_points = N + 1
    if skip_type == 'time_uniform':
        return torch.linspace(t_T, t_0, num_points).to(device)
    if skip_type == 'time_quadratic':
        power = 2
        # Uniform grid in t^(1/power), mapped back by raising to `power`.
        root_grid = torch.linspace(t_T**(1. / power), t_0**(1. / power), num_points)
        return root_grid.pow(power).to(device)
    if skip_type == 'logSNR':
        schedule = self.noise_schedule
        lambda_T = schedule.marginal_lambda(torch.tensor(t_T).to(device))
        lambda_0 = schedule.marginal_lambda(torch.tensor(t_0).to(device))
        lambda_grid = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), num_points).to(device)
        return schedule.inverse_lambda(lambda_grid)
    raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
479
+
480
def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
    """
    Get the order of each step, and the outer time steps, for sampling by the singlestep DPM-Solver.

    We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast".
    Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is:
        - If order == 1:
            We take `steps` of DPM-Solver-1 (i.e. DDIM).
        - If order == 2:
            - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling.
            - If steps % 2 == 0, we use K steps of DPM-Solver-2.
            - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1.
        - If order == 3:
            - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
            - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
            - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
            - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.

    ============================================
    Args:
        order: A `int`. The max order for the solver (1, 2 or 3).
        steps: A `int`. The total number of function evaluations (NFE).
        skip_type: A `str`. The type for the spacing of the time steps. We support three types:
            - 'logSNR': uniform logSNR for the time steps.
            - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
            - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
        t_T: A `float`. The starting time of the sampling (default is T).
        t_0: A `float`. The ending time of the sampling (default is epsilon).
        device: A torch device.
    Returns:
        timesteps_outer: A pytorch tensor of the boundary times of the solver steps,
            with the shape (len(orders) + 1,).
        orders: A list of the solver order of each step.
    Raises:
        ValueError: If `order` is not 1, 2 or 3.
    """
    if order == 3:
        K = steps // 3 + 1
        remainder = steps % 3
        if remainder == 0:
            orders = [3,] * (K - 2) + [2, 1]
        elif remainder == 1:
            orders = [3,] * (K - 1) + [1]
        else:
            orders = [3,] * (K - 1) + [2]
    elif order == 2:
        if steps % 2 == 0:
            K = steps // 2
            orders = [2,] * K
        else:
            K = steps // 2 + 1
            orders = [2,] * (K - 1) + [1]
    elif order == 1:
        # One outer interval per first-order step. K must equal len(orders); otherwise the
        # 'logSNR' branch below returns only 2 time steps for `steps` solver steps.
        K = steps
        orders = [1,] * steps
    else:
        raise ValueError("'order' must be '1' or '2' or '3'.")
    if skip_type == 'logSNR':
        # To reproduce the results in DPM-Solver paper
        timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
    else:
        # Take the fine grid with `steps` intervals and keep the boundary of each solver step.
        boundaries = torch.cumsum(torch.tensor([0,] + orders), 0).to(device)
        timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[boundaries]
    return timesteps_outer, orders
538
+
539
def denoise_to_zero_fn(self, x, s):
    """
    Denoise at the final step, which is equivalent to solving the ODE from lambda_s to infty by
    first-order discretization. Implemented as a single data prediction at time `s`.
    """
    denoised = self.data_prediction_fn(x, s)
    return denoised
544
+
545
def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False):
    """
    DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`.

    Args:
        x: A pytorch tensor. The initial value at time `s`.
        s: A pytorch tensor. The starting time, with the shape (1,).
        t: A pytorch tensor. The ending time, with the shape (1,).
        model_s: A pytorch tensor. The model function evaluated at time `s`.
            If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
        return_intermediate: A `bool`. If true, also return the model value at time `s`.
    Returns:
        x_t: A pytorch tensor. The approximated solution at time `t`.
            When `return_intermediate` is true, a tuple `(x_t, {'model_s': model_s})` is returned instead.
    """
    schedule = self.noise_schedule
    lambda_s = schedule.marginal_lambda(s)
    lambda_t = schedule.marginal_lambda(t)
    h = lambda_t - lambda_s
    log_alpha_s = schedule.marginal_log_mean_coeff(s)
    log_alpha_t = schedule.marginal_log_mean_coeff(t)
    sigma_s = schedule.marginal_std(s)
    sigma_t = schedule.marginal_std(t)

    if model_s is None:
        model_s = self.model_fn(x, s)

    if self.algorithm_type == "dpmsolver++":
        # Data-prediction form.
        alpha_t = torch.exp(log_alpha_t)
        phi_1 = torch.expm1(-h)
        x_t = sigma_t / sigma_s * x - alpha_t * phi_1 * model_s
    else:
        # Noise-prediction form.
        phi_1 = torch.expm1(h)
        x_t = torch.exp(log_alpha_t - log_alpha_s) * x - (sigma_t * phi_1) * model_s

    if return_intermediate:
        return x_t, {'model_s': model_s}
    return x_t
590
+
591
def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, solver_type='dpmsolver'):
    """
    Singlestep solver DPM-Solver-2 from time `s` to time `t`.

    Args:
        x: A pytorch tensor. The initial value at time `s`.
        s: A pytorch tensor. The starting time, with the shape (1,).
        t: A pytorch tensor. The ending time, with the shape (1,).
        r1: A `float`. The hyperparameter of the second-order solver. None falls back to 0.5.
        model_s: A pytorch tensor. The model function evaluated at time `s`.
            If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
        return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time).
        solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
            The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
    Returns:
        x_t: A pytorch tensor. The approximated solution at time `t`.
            When `return_intermediate` is true, a tuple `(x_t, {'model_s': ..., 'model_s1': ...})` is returned.
    Raises:
        ValueError: If `solver_type` is neither 'dpmsolver' nor 'taylor'.
    """
    if solver_type not in ['dpmsolver', 'taylor']:
        raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type))
    if r1 is None:
        r1 = 0.5

    schedule = self.noise_schedule
    lambda_s = schedule.marginal_lambda(s)
    lambda_t = schedule.marginal_lambda(t)
    h = lambda_t - lambda_s
    # Intermediate time s1 sits at fraction r1 of the step, measured in half-logSNR.
    s1 = schedule.inverse_lambda(lambda_s + r1 * h)
    log_alpha_s = schedule.marginal_log_mean_coeff(s)
    log_alpha_s1 = schedule.marginal_log_mean_coeff(s1)
    log_alpha_t = schedule.marginal_log_mean_coeff(t)
    sigma_s = schedule.marginal_std(s)
    sigma_s1 = schedule.marginal_std(s1)
    sigma_t = schedule.marginal_std(t)
    use_dpmsolver_form = solver_type == 'dpmsolver'

    if model_s is None:
        model_s = self.model_fn(x, s)

    if self.algorithm_type == "dpmsolver++":
        alpha_s1 = torch.exp(log_alpha_s1)
        alpha_t = torch.exp(log_alpha_t)
        phi_11 = torch.expm1(-r1 * h)
        phi_1 = torch.expm1(-h)

        x_s1 = (sigma_s1 / sigma_s) * x - (alpha_s1 * phi_11) * model_s
        model_s1 = self.model_fn(x_s1, s1)
        first_order = (sigma_t / sigma_s) * x - (alpha_t * phi_1) * model_s
        if use_dpmsolver_form:
            x_t = first_order - (0.5 / r1) * (alpha_t * phi_1) * (model_s1 - model_s)
        else:
            x_t = first_order + (1. / r1) * (alpha_t * (phi_1 / h + 1.)) * (model_s1 - model_s)
    else:
        phi_11 = torch.expm1(r1 * h)
        phi_1 = torch.expm1(h)

        x_s1 = torch.exp(log_alpha_s1 - log_alpha_s) * x - (sigma_s1 * phi_11) * model_s
        model_s1 = self.model_fn(x_s1, s1)
        first_order = torch.exp(log_alpha_t - log_alpha_s) * x - (sigma_t * phi_1) * model_s
        if use_dpmsolver_form:
            x_t = first_order - (0.5 / r1) * (sigma_t * phi_1) * (model_s1 - model_s)
        else:
            x_t = first_order - (1. / r1) * (sigma_t * (phi_1 / h - 1.)) * (model_s1 - model_s)

    if return_intermediate:
        return x_t, {'model_s': model_s, 'model_s1': model_s1}
    return x_t
671
+
672
def singlestep_dpm_solver_third_update(self, x, s, t, r1=1./3., r2=2./3., model_s=None, model_s1=None, return_intermediate=False, solver_type='dpmsolver'):
    """
    Singlestep solver DPM-Solver-3 from time `s` to time `t`.

    Args:
        x: A pytorch tensor. The initial value at time `s`.
        s: A pytorch tensor. The starting time, with the shape (1,).
        t: A pytorch tensor. The ending time, with the shape (1,).
        r1: A `float`. The hyperparameter of the third-order solver. None falls back to 1/3.
        r2: A `float`. The hyperparameter of the third-order solver. None falls back to 2/3.
        model_s: A pytorch tensor. The model function evaluated at time `s`.
            If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
        model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`).
            If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it.
        return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
        solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
            The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
    Returns:
        x_t: A pytorch tensor. The approximated solution at time `t`.
            When `return_intermediate` is true, a tuple
            `(x_t, {'model_s': ..., 'model_s1': ..., 'model_s2': ...})` is returned.
    Raises:
        ValueError: If `solver_type` is neither 'dpmsolver' nor 'taylor'.
    """
    if solver_type not in ['dpmsolver', 'taylor']:
        raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type))
    if r1 is None:
        r1 = 1. / 3.
    if r2 is None:
        r2 = 2. / 3.

    schedule = self.noise_schedule
    lambda_s = schedule.marginal_lambda(s)
    lambda_t = schedule.marginal_lambda(t)
    h = lambda_t - lambda_s
    # Two intermediate times at fractions r1 and r2 of the step, measured in half-logSNR.
    s1 = schedule.inverse_lambda(lambda_s + r1 * h)
    s2 = schedule.inverse_lambda(lambda_s + r2 * h)
    log_alpha_s = schedule.marginal_log_mean_coeff(s)
    log_alpha_s1 = schedule.marginal_log_mean_coeff(s1)
    log_alpha_s2 = schedule.marginal_log_mean_coeff(s2)
    log_alpha_t = schedule.marginal_log_mean_coeff(t)
    sigma_s = schedule.marginal_std(s)
    sigma_s1 = schedule.marginal_std(s1)
    sigma_s2 = schedule.marginal_std(s2)
    sigma_t = schedule.marginal_std(t)
    use_dpmsolver_form = solver_type == 'dpmsolver'

    def taylor_differences(value_s, value_s1, value_s2):
        # Difference terms D1 and D2 used by the 'taylor' variant.
        D1_0 = (1. / r1) * (value_s1 - value_s)
        D1_1 = (1. / r2) * (value_s2 - value_s)
        D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
        D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
        return D1, D2

    if model_s is None:
        model_s = self.model_fn(x, s)

    if self.algorithm_type == "dpmsolver++":
        alpha_s1 = torch.exp(log_alpha_s1)
        alpha_s2 = torch.exp(log_alpha_s2)
        alpha_t = torch.exp(log_alpha_t)
        phi_11 = torch.expm1(-r1 * h)
        phi_12 = torch.expm1(-r2 * h)
        phi_1 = torch.expm1(-h)
        phi_22 = phi_12 / (r2 * h) + 1.
        phi_2 = phi_1 / h + 1.
        phi_3 = phi_2 / h - 0.5

        if model_s1 is None:
            x_s1 = (sigma_s1 / sigma_s) * x - (alpha_s1 * phi_11) * model_s
            model_s1 = self.model_fn(x_s1, s1)
        x_s2 = (
            (sigma_s2 / sigma_s) * x
            - (alpha_s2 * phi_12) * model_s
            + r2 / r1 * (alpha_s2 * phi_22) * (model_s1 - model_s)
        )
        model_s2 = self.model_fn(x_s2, s2)
        first_order = (sigma_t / sigma_s) * x - (alpha_t * phi_1) * model_s
        if use_dpmsolver_form:
            x_t = first_order + (1. / r2) * (alpha_t * phi_2) * (model_s2 - model_s)
        else:
            D1, D2 = taylor_differences(model_s, model_s1, model_s2)
            x_t = first_order + (alpha_t * phi_2) * D1 - (alpha_t * phi_3) * D2
    else:
        phi_11 = torch.expm1(r1 * h)
        phi_12 = torch.expm1(r2 * h)
        phi_1 = torch.expm1(h)
        phi_22 = phi_12 / (r2 * h) - 1.
        phi_2 = phi_1 / h - 1.
        phi_3 = phi_2 / h - 0.5

        if model_s1 is None:
            x_s1 = (torch.exp(log_alpha_s1 - log_alpha_s)) * x - (sigma_s1 * phi_11) * model_s
            model_s1 = self.model_fn(x_s1, s1)
        x_s2 = (
            (torch.exp(log_alpha_s2 - log_alpha_s)) * x
            - (sigma_s2 * phi_12) * model_s
            - r2 / r1 * (sigma_s2 * phi_22) * (model_s1 - model_s)
        )
        model_s2 = self.model_fn(x_s2, s2)
        first_order = (torch.exp(log_alpha_t - log_alpha_s)) * x - (sigma_t * phi_1) * model_s
        if use_dpmsolver_form:
            x_t = first_order - (1. / r2) * (sigma_t * phi_2) * (model_s2 - model_s)
        else:
            D1, D2 = taylor_differences(model_s, model_s1, model_s2)
            x_t = first_order - (sigma_t * phi_2) * D1 - (sigma_t * phi_3) * D2

    if return_intermediate:
        return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2}
    return x_t
792
+
793
def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"):
    """
    Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.

    Args:
        x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
        model_prev_list: A list of pytorch tensor. The previous computed model values.
            Only the last two entries are used.
        t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,).
            Only the last two entries are used.
        t: A pytorch tensor. The ending time, with the shape (1,).
        solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
            The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
    Returns:
        x_t: A pytorch tensor. The approximated solution at time `t`.
    Raises:
        ValueError: If `solver_type` is neither 'dpmsolver' nor 'taylor'.
    """
    if solver_type not in ['dpmsolver', 'taylor']:
        raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type))

    schedule = self.noise_schedule
    model_prev_1 = model_prev_list[-2]
    model_prev_0 = model_prev_list[-1]
    t_prev_1 = t_prev_list[-2]
    t_prev_0 = t_prev_list[-1]
    lambda_prev_1 = schedule.marginal_lambda(t_prev_1)
    lambda_prev_0 = schedule.marginal_lambda(t_prev_0)
    lambda_t = schedule.marginal_lambda(t)
    log_alpha_prev_0 = schedule.marginal_log_mean_coeff(t_prev_0)
    log_alpha_t = schedule.marginal_log_mean_coeff(t)
    sigma_prev_0 = schedule.marginal_std(t_prev_0)
    sigma_t = schedule.marginal_std(t)

    h_0 = lambda_prev_0 - lambda_prev_1
    h = lambda_t - lambda_prev_0
    # Ratio of the previous step size to the current one, both in half-logSNR.
    r0 = h_0 / h
    D1_0 = (1. / r0) * (model_prev_0 - model_prev_1)
    use_dpmsolver_form = solver_type == 'dpmsolver'

    if self.algorithm_type == "dpmsolver++":
        alpha_t = torch.exp(log_alpha_t)
        phi_1 = torch.expm1(-h)
        first_order = (sigma_t / sigma_prev_0) * x - (alpha_t * phi_1) * model_prev_0
        if use_dpmsolver_form:
            return first_order - 0.5 * (alpha_t * phi_1) * D1_0
        return first_order + (alpha_t * (phi_1 / h + 1.)) * D1_0

    phi_1 = torch.expm1(h)
    first_order = (torch.exp(log_alpha_t - log_alpha_prev_0)) * x - (sigma_t * phi_1) * model_prev_0
    if use_dpmsolver_form:
        return first_order - 0.5 * (sigma_t * phi_1) * D1_0
    return first_order - (sigma_t * (phi_1 / h - 1.)) * D1_0
850
+
851
def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpmsolver'):
    """
    Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.

    Args:
        x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
        model_prev_list: A list of pytorch tensor. The previous computed model values.
            It is unpacked into exactly three entries, oldest first.
        t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,).
            It is unpacked into exactly three entries, oldest first.
        t: A pytorch tensor. The ending time, with the shape (1,).
        solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
            Accepted for signature parity with the other updates; this method does not read it.
    Returns:
        x_t: A pytorch tensor. The approximated solution at time `t`.
    """
    schedule = self.noise_schedule
    model_prev_2, model_prev_1, model_prev_0 = model_prev_list
    t_prev_2, t_prev_1, t_prev_0 = t_prev_list
    lambda_prev_2 = schedule.marginal_lambda(t_prev_2)
    lambda_prev_1 = schedule.marginal_lambda(t_prev_1)
    lambda_prev_0 = schedule.marginal_lambda(t_prev_0)
    lambda_t = schedule.marginal_lambda(t)
    log_alpha_prev_0 = schedule.marginal_log_mean_coeff(t_prev_0)
    log_alpha_t = schedule.marginal_log_mean_coeff(t)
    sigma_prev_0 = schedule.marginal_std(t_prev_0)
    sigma_t = schedule.marginal_std(t)

    h_1 = lambda_prev_1 - lambda_prev_2
    h_0 = lambda_prev_0 - lambda_prev_1
    h = lambda_t - lambda_prev_0
    # Ratios of the two previous step sizes to the current one, all in half-logSNR.
    r0 = h_0 / h
    r1 = h_1 / h
    D1_0 = (1. / r0) * (model_prev_0 - model_prev_1)
    D1_1 = (1. / r1) * (model_prev_1 - model_prev_2)
    D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
    D2 = (1. / (r0 + r1)) * (D1_0 - D1_1)

    if self.algorithm_type == "dpmsolver++":
        alpha_t = torch.exp(log_alpha_t)
        phi_1 = torch.expm1(-h)
        phi_2 = phi_1 / h + 1.
        phi_3 = phi_2 / h - 0.5
        first_order = (sigma_t / sigma_prev_0) * x - (alpha_t * phi_1) * model_prev_0
        return first_order + (alpha_t * phi_2) * D1 - (alpha_t * phi_3) * D2

    phi_1 = torch.expm1(h)
    phi_2 = phi_1 / h - 1.
    phi_3 = phi_2 / h - 0.5
    first_order = (torch.exp(log_alpha_t - log_alpha_prev_0)) * x - (sigma_t * phi_1) * model_prev_0
    return first_order - (sigma_t * phi_2) * D1 - (sigma_t * phi_3) * D2
902
+
903
def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpmsolver', r1=None, r2=None):
    """
    Dispatch one singlestep DPM-Solver step of the requested `order` from time `s` to time `t`.

    Args:
        x: A pytorch tensor. The initial value at time `s`.
        s: A pytorch tensor. The starting time.
        t: A pytorch tensor. The ending time.
        order: A `int`. Must be 1, 2 or 3.
        return_intermediate: A `bool`. Forwarded unchanged to the selected update.
        solver_type: either 'dpmsolver' or 'taylor'. Forwarded to the second- and third-order updates.
        r1: A `float`. Hyperparameter forwarded to the second- and third-order updates.
        r2: A `float`. Hyperparameter forwarded to the third-order update.
    Returns:
        Whatever the selected update returns (the approximated solution at time `t`).
    Raises:
        ValueError: if `order` is not 1, 2 or 3.
    """
    if order == 1:
        # The first-order update takes neither solver_type nor r1/r2.
        return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate)

    if order == 2 or order == 3:
        shared = dict(return_intermediate=return_intermediate, solver_type=solver_type, r1=r1)
        if order == 2:
            return self.singlestep_dpm_solver_second_update(x, s, t, **shared)
        return self.singlestep_dpm_solver_third_update(x, s, t, r2=r2, **shared)

    raise ValueError(f"Solver order must be 1 or 2 or 3, got {order}")
928
+
929
def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpmsolver'):
    """
    Dispatch one multistep DPM-Solver step of the requested `order` from `t_prev_list[-1]` to `t`.

    Args:
        x: A pytorch tensor. The current value, at time `t_prev_list[-1]`.
        model_prev_list: A list of pytorch tensor. The previous computed model values.
        t_prev_list: A list of pytorch tensor. The previous times.
        t: A pytorch tensor. The ending time.
        order: A `int`. Must be 1, 2 or 3.
        solver_type: either 'dpmsolver' or 'taylor'. Forwarded to the second- and third-order updates.
    Returns:
        x_t: A pytorch tensor. The approximated solution at time `t`.
    Raises:
        ValueError: if `order` is not 1, 2 or 3.
    """
    if order == 1:
        # First order needs only the latest history entry, passed as the cached model value.
        return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1])

    if order == 2:
        higher_order_update = self.multistep_dpm_solver_second_update
    elif order == 3:
        higher_order_update = self.multistep_dpm_solver_third_update
    else:
        raise ValueError(f"Solver order must be 1 or 2 or 3, got {order}")

    return higher_order_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
952
+
953
def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, solver_type='dpmsolver'):
    """
    Adaptive step size solver built from a pair of singlestep DPM-Solver updates.

    Each iteration proposes a step of size `h` in lambda, computes a lower-order and a
    higher-order estimate, and accepts the higher-order one only if their scaled difference
    is at most 1. The step size is then rescaled from that error either way.

    Args:
        x: A pytorch tensor. The initial value at time `t_T`.
        order: A `int`. The (higher) order of the solver. We only support order == 2 or 3.
        t_T: A `float`. The starting time of the sampling.
        t_0: A `float`. The ending time of the sampling.
        h_init: A `float`. The initial step size (for logSNR).
        atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, following [1].
        rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
        theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, following [1].
        t_err: A `float`. The tolerance for the time. The loop runs until the mean absolute difference
            between the current time and `t_0` is no larger than `t_err`. The default setting is 1e-5.
        solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
    Returns:
        x_0: A pytorch tensor. The approximated solution at time `t_0`.
    Raises:
        ValueError: if `order` is not 2 or 3.

    [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
    """
    schedule = self.noise_schedule
    cur_t = t_T * torch.ones((1,)).to(x)
    lam_cur = schedule.marginal_lambda(cur_t)
    lam_end = schedule.marginal_lambda(t_0 * torch.ones_like(cur_t).to(x))
    step = h_init * torch.ones_like(cur_t).to(x)
    prev_lower = x
    nfe = 0

    # Pick the (lower, higher) estimator pair. The lower one also returns the model
    # evaluations it made, which the higher one receives as keyword arguments.
    if order == 2:
        r1 = 0.5

        def coarse(x, s, t):
            return self.dpm_solver_first_update(x, s, t, return_intermediate=True)

        def fine(x, s, t, **kwargs):
            return self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, solver_type=solver_type, **kwargs)
    elif order == 3:
        r1, r2 = 1. / 3., 2. / 3.

        def coarse(x, s, t):
            return self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type)

        def fine(x, s, t, **kwargs):
            return self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs)
    else:
        raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order))

    def rms_per_sample(v):
        # Root-mean-square over everything but the leading axis.
        return torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True))

    while torch.abs((cur_t - t_0)).mean() > t_err:
        next_t = schedule.inverse_lambda(lam_cur + step)
        x_low, shared_evals = coarse(x, cur_t, next_t)
        x_high = fine(x, cur_t, next_t, **shared_evals)
        tolerance = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_low), torch.abs(prev_lower)))
        err = rms_per_sample((x_high - x_low) / tolerance).max()
        if torch.all(err <= 1.):
            # Accept: move to the proposed time with the higher-order estimate.
            x, cur_t, prev_lower = x_high, next_t, x_low
            lam_cur = schedule.marginal_lambda(cur_t)
        # Rescale the step (also after a rejection) and never step past the end time.
        step = torch.min(theta * step * torch.float_power(err, -1. / order).float(), lam_end - lam_cur)
        nfe += order
    print('adaptive solver nfe', nfe)
    return x
1013
+
1014
def add_noise(self, x, t, noise=None):
    """
    Compute the noised input xt = alpha_t * x + sigma_t * noise.

    Args:
        x: A `torch.Tensor` with shape `(batch_size, *shape)`.
        t: A `torch.Tensor` with shape `(t_size,)`.
        noise: Optional noise tensor. When omitted, standard normal noise with shape
            `(t_size, batch_size, *shape)` is drawn on `x.device`.
    Returns:
        xt with shape `(t_size, batch_size, *shape)`; the leading axis is squeezed away when `t_size == 1`.
    """
    schedule = self.noise_schedule
    alpha_t = schedule.marginal_alpha(t)
    sigma_t = schedule.marginal_std(t)
    num_times = t.shape[0]

    if noise is None:
        noise = torch.randn((num_times, *x.shape), device=x.device)

    # Give x a leading axis so it broadcasts against the per-time coefficients.
    x = x.reshape((-1, *x.shape))
    rank = x.dim()
    xt = expand_dims(alpha_t, rank) * x + expand_dims(sigma_t, rank) * noise
    return xt.squeeze(0) if num_times == 1 else xt
1033
+
1034
def inverse(self, x, steps=20, t_start=None, t_end=None, order=2, skip_type='time_uniform',
    method='multistep', lower_order_final=True, denoise_to_zero=False, solver_type='dpmsolver',
    atol=0.0078, rtol=0.05, return_intermediate=False,
):
    """
    Invert the sample `x` from time `t_start` to `t_end` by DPM-Solver.

    This delegates to `sample` with the defaults swapped: `t_start` falls back to
    `1 / total_N` of the noise schedule and `t_end` falls back to the schedule's `T`.
    For discrete-time DPMs, we use `t_start=1/N`, where `N` is the total time steps during training.
    """
    begin = t_start if t_start is not None else 1. / self.noise_schedule.total_N
    finish = t_end if t_end is not None else self.noise_schedule.T
    assert begin > 0 and finish > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"
    return self.sample(
        x,
        steps=steps,
        t_start=begin,
        t_end=finish,
        order=order,
        skip_type=skip_type,
        method=method,
        lower_order_final=lower_order_final,
        denoise_to_zero=denoise_to_zero,
        solver_type=solver_type,
        atol=atol,
        rtol=rtol,
        return_intermediate=return_intermediate,
    )
1048
+
1049
def sample(self, x, steps=20, t_start=None, t_end=None, order=2, skip_type='time_uniform',
    method='multistep', lower_order_final=True, denoise_to_zero=False, solver_type='dpmsolver',
    atol=0.0078, rtol=0.05, return_intermediate=False,
):
    """
    Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.

    =====================================================

    We support the following algorithms for both noise prediction model and data prediction model:
        - 'singlestep':
            Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver.
            We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps).
            The total number of function evaluations (NFE) == `steps`.
            Given a fixed NFE == `steps`, the sampling procedure is:
                - If `order` == 1:
                    - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM).
                - If `order` == 2:
                    - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling.
                    - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2.
                    - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
                - If `order` == 3:
                    - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
                    - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
                    - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1.
                    - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2.
        - 'multistep':
            Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`.
            We initialize the first `order` values by lower order multistep solvers.
            Given a fixed NFE == `steps`, the sampling procedure is:
                Denote K = steps.
                - If `order` == 1:
                    - We use K steps of DPM-Solver-1 (i.e. DDIM).
                - If `order` == 2:
                    - We first use 1 step of DPM-Solver-1, then use (K - 1) steps of multistep DPM-Solver-2.
                - If `order` == 3:
                    - We first use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) steps of multistep DPM-Solver-3.
        - 'singlestep_fixed':
            Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3).
            We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE.
        - 'adaptive':
            Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper).
            We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`.
            You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computation costs
            (NFE) and the sample quality.
                - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2.
                - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3.

    =====================================================

    Some advice for choosing the algorithm:
        - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:
            Use singlestep DPM-Solver or DPM-Solver++ ("DPM-Solver-fast" in the paper) with `order = 3`.
            e.g., DPM-Solver:
                >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver")
                >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
                        skip_type='time_uniform', method='singlestep')
            e.g., DPM-Solver++:
                >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
                >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
                        skip_type='time_uniform', method='singlestep')
        - For **guided sampling with large guidance scale** by DPMs:
            Use multistep DPM-Solver with `algorithm_type="dpmsolver++"` and `order = 2`.
            e.g.
                >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
                >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,
                        skip_type='time_uniform', method='multistep')

    We support three types of `skip_type`:
        - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolution images**
        - 'time_uniform': uniform time for the time steps. **Recommended for high-resolution images**.
        - 'time_quadratic': quadratic time for the time steps.

    =====================================================
    Args:
        x: A pytorch tensor. The initial value at time `t_start`
            e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution.
        steps: A `int`. The total number of function evaluations (NFE).
        t_start: A `float`. The starting time of the sampling.
            If `t_start` is None, we use self.noise_schedule.T.
        t_end: A `float`. The ending time of the sampling.
            If `t_end` is None, we use 1. / self.noise_schedule.total_N.
            e.g. if total_N == 1000, we have `t_end` == 1e-3.
            For discrete-time DPMs:
                - We recommend `t_end` == 1. / self.noise_schedule.total_N.
            For continuous-time DPMs:
                - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15.
        order: A `int`. The order of DPM-Solver.
        skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.
        method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
        denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step.
            Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1).

            This trick was first proposed by DDPM (https://arxiv.org/abs/2006.11239) and
            score_sde (https://arxiv.org/abs/2011.13456). Such a trick can improve the FID
            for diffusion models sampling by diffusion SDEs for low-resolution images
            (such as CIFAR-10). However, we observed that such a trick does not matter for
            high-resolution images. As it needs an additional NFE, we do not recommend
            it for high-resolution images.
        lower_order_final: A `bool`. Whether to use lower order solvers at the final steps.
            Only takes effect for `method=multistep` and `steps < 10` (the threshold checked in the code below).
            We empirically find that this trick is a key to stabilizing the sampling by DPM-Solver
            with very few steps (especially for steps <= 10). So we recommend setting it to `True`.
        solver_type: A `str`. The taylor expansion type for the solver. `dpmsolver` or `taylor`. We recommend `dpmsolver`.
        atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
        rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
        return_intermediate: A `bool`. Whether to save the xt at each step.
            When set to `True`, method returns a tuple (x0, intermediates); when set to False, method returns only x0.
    Returns:
        x_end: A pytorch tensor. The approximated solution at time `t_end`.
        When `return_intermediate` is True, the tuple `(x_end, intermediates)` instead.
    Raises:
        ValueError: if `method` is not one of the four supported names.
        AssertionError: if a time bound is not positive, if the adaptive method is combined with
            `return_intermediate` or a `correcting_xt_fn`, or if `steps < order` for the multistep method.

    """
    # Naming follows the solver convention: t_0 is the END of sampling, t_T is the START.
    t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
    t_T = self.noise_schedule.T if t_start is None else t_start
    assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"
    if return_intermediate:
        assert method in ['multistep', 'singlestep', 'singlestep_fixed'], "Cannot use adaptive solver when saving intermediate values"
    if self.correcting_xt_fn is not None:
        assert method in ['multistep', 'singlestep', 'singlestep_fixed'], "Cannot use adaptive solver when correcting_xt_fn is not None"
    device = x.device
    intermediates = []
    with torch.no_grad():
        if method == 'adaptive':
            x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, solver_type=solver_type)
        elif method == 'multistep':
            assert steps >= order
            timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
            assert timesteps.shape[0] - 1 == steps
            # Init the initial values.
            step = 0
            t = timesteps[step]
            # History of times / model outputs, oldest first; grows to `order` entries below.
            t_prev_list = [t]
            model_prev_list = [self.model_fn(x, t)]
            if self.correcting_xt_fn is not None:
                x = self.correcting_xt_fn(x, t, step)
            if return_intermediate:
                intermediates.append(x)
            # Init the first `order` values by lower order multistep DPM-Solver.
            for step in range(1, order):
                t = timesteps[step]
                # The update order equals `step` here, i.e. the number of history entries available.
                x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, t, step, solver_type=solver_type)
                if self.correcting_xt_fn is not None:
                    x = self.correcting_xt_fn(x, t, step)
                if return_intermediate:
                    intermediates.append(x)
                t_prev_list.append(t)
                model_prev_list.append(self.model_fn(x, t))
            # Compute the remaining values by `order`-th order multistep DPM-Solver.
            for step in range(order, steps + 1):
                t = timesteps[step]
                # We only use lower order for steps < 10
                if lower_order_final and steps < 10:
                    # Order shrinks towards the end: never larger than the number of steps remaining.
                    step_order = min(order, steps + 1 - step)
                else:
                    step_order = order
                x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, t, step_order, solver_type=solver_type)
                if self.correcting_xt_fn is not None:
                    x = self.correcting_xt_fn(x, t, step)
                if return_intermediate:
                    intermediates.append(x)
                # Slide the history window one slot to the left, then write the newest entry last.
                for i in range(order - 1):
                    t_prev_list[i] = t_prev_list[i + 1]
                    model_prev_list[i] = model_prev_list[i + 1]
                t_prev_list[-1] = t
                # We do not need to evaluate the final model value.
                if step < steps:
                    model_prev_list[-1] = self.model_fn(x, t)
        elif method in ['singlestep', 'singlestep_fixed']:
            if method == 'singlestep':
                timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order, skip_type=skip_type, t_T=t_T, t_0=t_0, device=device)
            elif method == 'singlestep_fixed':
                K = steps // order
                orders = [order,] * K
                timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
            # NOTE(review): this loop rebinds the `order` parameter to the per-interval order.
            for step, order in enumerate(orders):
                s, t = timesteps_outer[step], timesteps_outer[step + 1]
                # Subdivide [s, t] with the same spacing rule to derive the intermediate ratios r1 / r2.
                timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=s.item(), t_0=t.item(), N=order, device=device)
                lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
                h = lambda_inner[-1] - lambda_inner[0]
                r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
                r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
                x = self.singlestep_dpm_solver_update(x, s, t, order, solver_type=solver_type, r1=r1, r2=r2)
                if self.correcting_xt_fn is not None:
                    x = self.correcting_xt_fn(x, t, step)
                if return_intermediate:
                    intermediates.append(x)
        else:
            raise ValueError("Got wrong method {}".format(method))
        if denoise_to_zero:
            t = torch.ones((1,)).to(device) * t_0
            x = self.denoise_to_zero_fn(x, t)
            if self.correcting_xt_fn is not None:
                # NOTE(review): `step` is whatever the loops above last bound; the adaptive path never
                # binds it, but the assertion at the top rules out adaptive + correcting_xt_fn.
                x = self.correcting_xt_fn(x, t, step + 1)
            if return_intermediate:
                intermediates.append(x)
    if return_intermediate:
        return x, intermediates
    else:
        return x
1248
+
1249
+
1250
+
1251
+ #############################################################
1252
+ # other utility functions
1253
+ #############################################################
1254
+
1255
def interpolate_fn(x, xp, yp):
    """
    Evaluate the piecewise linear function through the keypoints (xp, yp) at x.

    The implementation uses only differentiable tensor ops (i.e. applicable for autograd).
    Queries outside the range of `xp` are extrapolated along the line through the two
    outermost keypoints on that side.

    Args:
        x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
        xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
        yp: PyTorch tensor with shape [C, K].
    Returns:
        The function values f(x), with shape [N, C].
    """
    batch, num_keys = x.shape[0], xp.shape[1]
    dev = x.device

    # Sort each query together with its keypoints; the query's rank in the sorted row
    # tells between which two keypoints it falls.
    merged = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((batch, 1, 1))], dim=2)
    merged_sorted, source_pos = torch.sort(merged, dim=2)
    rank = torch.argmin(source_pos, dim=2)  # the query came from position 0 of `merged`
    left = rank - 1
    is_below = torch.eq(rank, 0)
    is_above = torch.eq(rank, num_keys)

    # Left neighbour, clamped to the second-to-last keypoint for queries above the range.
    clamped_left = torch.where(is_above, torch.tensor(num_keys - 2, device=dev), left)

    # Segment endpoints as positions in the SORTED row (which still contains the query).
    lo_sorted = torch.where(is_below, torch.tensor(1, device=dev), clamped_left)
    # For an interior query the right neighbour is two slots on, skipping the query itself.
    hi_sorted = torch.where(torch.eq(lo_sorted, left), lo_sorted + 2, lo_sorted + 1)
    x_lo = torch.gather(merged_sorted, dim=2, index=lo_sorted.unsqueeze(2)).squeeze(2)
    x_hi = torch.gather(merged_sorted, dim=2, index=hi_sorted.unsqueeze(2)).squeeze(2)

    # Matching segment start as a position in the KEYPOINT arrays.
    lo_keys = torch.where(is_below, torch.tensor(0, device=dev), clamped_left)
    yp_batched = yp.unsqueeze(0).expand(batch, -1, -1)
    y_lo = torch.gather(yp_batched, dim=2, index=lo_keys.unsqueeze(2)).squeeze(2)
    y_hi = torch.gather(yp_batched, dim=2, index=(lo_keys + 1).unsqueeze(2)).squeeze(2)

    return y_lo + (x - x_lo) * (y_hi - y_lo) / (x_hi - x_lo)
1295
+
1296
+
1297
def expand_dims(v, dims):
    """
    Append trailing singleton axes to `v` until it has `dims` dimensions.

    Args:
        `v`: a PyTorch tensor with shape [N].
        `dims`: a `int`. The total number of dimensions wanted.
    Returns:
        a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
    """
    # Indexing with None inserts a new axis of size 1; the Ellipsis keeps the existing axes.
    trailing_axes = (None,) * (dims - 1)
    return v[(Ellipsis,) + trailing_axes]
diffusion/how to export onnx.md ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ - Open [onnx_export](onnx_export.py)
2
+ - Find the line `project_name = "dddsp"` and change the value of `project_name` to your project name
3
+ - Find the line `model_path = f'{project_name}/model_500000.pt'` and change `model_path` to the path of your model checkpoint
4
+ - Run
diffusion/infer_gt_mel.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ from diffusion.unit2mel import load_model_vocoder
5
+
6
+
7
class DiffGtMel:
    """
    Runs a diffusion model on top of a ground-truth mel spectrogram and vocodes the result.

    The model, vocoder and their config are loaded lazily by `flush_model`; `__call__` and
    `infer` require that a model has been loaded first.
    """

    def __init__(self, project_path=None, device=None):
        """
        Args:
            project_path: Path of the model project this instance is associated with.
                Nothing is loaded here.
            device: Device passed to the loader. When omitted, 'cuda' is used if available,
                otherwise 'cpu'.
        """
        self.project_path = project_path
        if device is not None:
            self.device = device
        else:
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = None
        self.vocoder = None
        self.args = None

    def flush_model(self, project_path, ddsp_config=None):
        """
        Load the model at `project_path` unless that exact model is already loaded.

        Args:
            project_path: Path handed to `load_model_vocoder`.
            ddsp_config: Optional config whose `data.block_size`, `data.sampling_rate` and
                `data.encoder` must match the loaded model's config. When None, the
                compatibility check is skipped.
        Raises:
            ValueError: if `ddsp_config` is given and disagrees with the loaded config.
        """
        if (self.model is not None) and (project_path == self.project_path):
            return
        model, vocoder, args = load_model_vocoder(project_path, device=self.device)
        # A missing DDSP config means there is nothing to compare against.
        if ddsp_config is None or self.check_args(ddsp_config, args):
            self.model = model
            self.vocoder = vocoder
            self.args = args
            # Remember which project is loaded; otherwise a later call with the path given
            # at construction time would wrongly be treated as "already loaded".
            self.project_path = project_path

    def check_args(self, args1, args2):
        """
        Verify that two configs agree on block size, sampling rate and encoder.

        Returns:
            True when all three fields match.
        Raises:
            ValueError: naming the first field that differs.
        """
        for field in ("block_size", "sampling_rate", "encoder"):
            if getattr(args1.data, field) != getattr(args2.data, field):
                raise ValueError(f"DDSP与DIFF模型的{field}不一致")
        return True

    def __call__(self, audio, f0, hubert, volume, acc=1, spk_id=1, k_step=0, method='pndm',
                 spk_mix_dict=None, start_frame=0):
        """
        Extract a mel from `audio`, refine it with the diffusion model and vocode it.

        Args:
            audio: Waveform handed to `vocoder.extract` together with the configured sampling rate.
            f0, hubert, volume: Conditioning tensors forwarded to the model. `f0` is sliced as
                `[:, start_frame:, :]`, i.e. it is 3-D with frames on axis 1.
            acc: Forwarded to the model as `infer_speedup`.
            spk_id, spk_mix_dict, k_step, method: Forwarded to the model unchanged.
            start_frame: Number of leading frames to drop before vocoding. The output is
                left-padded with `start_frame * vocoder_hop_size` zeros to compensate.
        Returns:
            The vocoder output, left-padded when `start_frame > 0`.
        """
        input_mel = self.vocoder.extract(audio, self.args.data.sampling_rate)
        out_mel = self.model(
            hubert,
            f0,
            volume,
            spk_id=spk_id,
            spk_mix_dict=spk_mix_dict,
            gt_spec=input_mel,
            infer=True,
            infer_speedup=acc,
            method=method,
            k_step=k_step,
            use_tqdm=False)
        if start_frame > 0:
            out_mel = out_mel[:, start_frame:, :]
            f0 = f0[:, start_frame:, :]
        output = self.vocoder.infer(out_mel, f0)
        if start_frame > 0:
            output = F.pad(output, (start_frame * self.vocoder.vocoder_hop_size, 0))
        return output

    def infer(self, audio, f0, hubert, volume, acc=1, spk_id=1, k_step=0, method='pndm', silence_front=0,
              use_silence=False, spk_mix_dict=None):
        """
        Like `__call__`, but skips the first `silence_front` seconds.

        Args:
            silence_front: Leading duration in seconds, converted to frames with the vocoder's
                sample rate and hop size.
            use_silence: When True, the leading frames are cut from every input before the
                model runs. When False, the full inputs go through the model and only the
                vocoder skips the leading frames.
            Remaining arguments: see `__call__`.
        Returns:
            The synthesized audio, left-padded with zeros over the skipped span in both modes.
        """
        hop = self.vocoder.vocoder_hop_size
        start_frame = int(silence_front * self.vocoder.vocoder_sample_rate / hop)
        if use_silence:
            audio = audio[:, start_frame * hop:]
            f0 = f0[:, start_frame:, :]
            hubert = hubert[:, start_frame:, :]
            volume = volume[:, start_frame:, :]
            _start_frame = 0
        else:
            _start_frame = start_frame
        audio = self(audio, f0, hubert, volume, acc=acc, spk_id=spk_id, k_step=k_step,
                     method=method, spk_mix_dict=spk_mix_dict, start_frame=_start_frame)
        if use_silence and start_frame > 0:
            # The inputs were trimmed up front, so restore the original length here.
            audio = F.pad(audio, (start_frame * hop, 0))
        return audio
diffusion/logger/__init__.py ADDED
File without changes
diffusion/logger/saver.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ author: wayn391@mastertones
3
+ '''
4
+
5
+ import datetime
6
+ import os
7
+ import time
8
+
9
+ import matplotlib.pyplot as plt
10
+ import torch
11
+ import yaml
12
+ from torch.utils.tensorboard import SummaryWriter
13
+
14
+
15
class Saver(object):
    """
    Experiment bookkeeping rooted at one directory: a plain-text log, TensorBoard
    summaries, a dump of the config, and model checkpoints.
    """

    def __init__(
            self,
            args,
            initial_global_step=-1):
        """
        Args:
            args: Config object; `args.env.expdir` and `args.data.sampling_rate` are read,
                and `dict(args)` is dumped to `config.yaml` in the experiment directory.
            initial_global_step: Starting value of the step counter.
        """
        exp_dir = args.env.expdir
        self.expdir = exp_dir
        self.sample_rate = args.data.sampling_rate

        # Step counter and wall-clock reference points.
        self.global_step = initial_global_step
        self.init_time = time.time()
        self.last_time = time.time()

        # Experiment directory and the text log inside it.
        os.makedirs(exp_dir, exist_ok=True)
        self.path_log_info = os.path.join(exp_dir, 'log_info.txt')

        # Checkpoints live in the same directory.
        os.makedirs(exp_dir, exist_ok=True)

        # TensorBoard writer.
        self.writer = SummaryWriter(os.path.join(exp_dir, 'logs'))

        # Keep a copy of the config next to the results.
        with open(os.path.join(exp_dir, 'config.yaml'), "w") as out_config:
            yaml.dump(dict(args), out_config)

    def log_info(self, msg):
        """Print `msg` and append it to the text log; a dict is rendered one `key: value` per line."""
        if isinstance(msg, dict):
            # Integers get thousands separators.
            msg_str = '\n'.join(
                f'{key}: {value:,}' if isinstance(value, int) else f'{key}: {value}'
                for key, value in msg.items()
            )
        else:
            msg_str = msg

        print(msg_str)

        with open(self.path_log_info, 'a') as fp:
            fp.write(msg_str + '\n')

    def log_value(self, dict):
        """Write every `tag: scalar` pair to TensorBoard at the current global step."""
        for tag, scalar in dict.items():
            self.writer.add_scalar(tag, scalar, self.global_step)

    def log_spec(self, name, spec, spec_out, vmin=-14, vmax=3.5):
        """
        Log a figure of the first batch item showing, concatenated along the last axis,
        the absolute error (offset by `vmin`), `spec`, and `spec_out`.
        """
        panels = torch.cat([(spec_out - spec).abs() + vmin, spec, spec_out], -1)
        image = panels[0]
        if isinstance(image, torch.Tensor):
            image = image.cpu().numpy()
        fig = plt.figure(figsize=(12, 9))
        plt.pcolor(image.T, vmin=vmin, vmax=vmax)
        plt.tight_layout()
        self.writer.add_figure(name, fig, self.global_step)

    def log_audio(self, dict):
        """Write every `tag: audio` pair to TensorBoard at the current global step."""
        for tag, waveform in dict.items():
            self.writer.add_audio(tag, waveform, global_step=self.global_step, sample_rate=self.sample_rate)

    def get_interval_time(self, update=True):
        """Return seconds since the last reference point; move the reference to now when `update`."""
        now = time.time()
        elapsed = now - self.last_time
        if update:
            self.last_time = now
        return elapsed

    def get_total_time(self, to_str=True):
        """Return the time since construction, as seconds or as a trimmed `H:MM:SS.f` string."""
        elapsed = time.time() - self.init_time
        if not to_str:
            return elapsed
        # Drop the last five characters of the timedelta text.
        return str(datetime.timedelta(seconds=elapsed))[:-5]

    def _checkpoint_path(self, name, postfix):
        """Build `<expdir>/<name>[_<postfix>].pt`."""
        if postfix:
            postfix = '_' + postfix
        return os.path.join(self.expdir, name + postfix + '.pt')

    def save_model(
            self,
            model,
            optimizer,
            name='model',
            postfix='',
            to_json=False):
        """
        Save a checkpoint holding the global step, the model state and, when an optimizer
        is given, the optimizer state. `to_json` is accepted but not used.
        """
        path_pt = self._checkpoint_path(name, postfix)

        print(' [*] model checkpoint saved: {}'.format(path_pt))

        payload = {
            'global_step': self.global_step,
            'model': model.state_dict()}
        if optimizer is not None:
            payload['optimizer'] = optimizer.state_dict()
        torch.save(payload, path_pt)

    def delete_model(self, name='model', postfix=''):
        """Delete the checkpoint with the given name/postfix if it exists."""
        path_pt = self._checkpoint_path(name, postfix)
        if not os.path.exists(path_pt):
            return
        os.remove(path_pt)
        print(' [*] model checkpoint deleted: {}'.format(path_pt))

    def global_step_increment(self):
        """Advance the global step by one."""
        self.global_step += 1
144
+
145
+
diffusion/logger/utils.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+
4
+ import torch
5
+ import yaml
6
+
7
+
8
def traverse_dir(
        root_dir,
        extensions,
        amount=None,
        str_include=None,
        str_exclude=None,
        is_pure=False,
        is_sort=False,
        is_ext=True):
    """Recursively collect files under ``root_dir`` by extension.

    Args:
        root_dir: Directory to walk.
        extensions: Iterable of extensions without the leading dot
            (e.g. ``['wav', 'flac']``).
        amount: Stop as soon as this many paths were collected and another
            matching file is encountered. ``None`` means no limit.
        str_include: If given, keep only paths containing this substring.
        str_exclude: If given, drop paths containing this substring.
        is_pure: Return paths relative to ``root_dir`` instead of the joined
            walk path.
        is_sort: Sort the result before returning it.
        is_ext: Keep the file extension; when false the final ``.ext`` is
            stripped from each returned path.

    Returns:
        A list of path strings.
    """
    # Materialise once: works for one-shot iterables and lets
    # ``str.endswith`` test every suffix in a single call.
    suffixes = tuple(f".{ext}" for ext in extensions)

    file_list = []
    for root, _, files in os.walk(root_dir):
        for file in files:
            if not file.endswith(suffixes):
                continue

            mix_path = os.path.join(root, file)
            # relpath is independent of a trailing separator on root_dir,
            # unlike slicing by len(root_dir) + 1.
            pure_path = os.path.relpath(mix_path, root_dir) if is_pure else mix_path

            # amount: checked before the include/exclude filters
            if amount is not None and len(file_list) == amount:
                if is_sort:
                    file_list.sort()
                return file_list

            # check string
            if str_include is not None and str_include not in pure_path:
                continue
            if str_exclude is not None and str_exclude in pure_path:
                continue

            if not is_ext:
                ext = pure_path.split('.')[-1]
                pure_path = pure_path[:-(len(ext) + 1)]
            file_list.append(pure_path)

    if is_sort:
        file_list.sort()
    return file_list
47
+
48
+
49
+
50
class DotDict(dict):
    """``dict`` subclass whose keys can also be read, set and deleted as attributes.

    Reading a key that is absent yields ``None`` rather than raising.
    A value whose type is exactly ``dict`` is wrapped in a new ``DotDict``
    on access so that nested keys can be chained with dots.
    """

    def __getattr__(self, key):
        value = dict.get(self, key)
        if type(value) is dict:
            return DotDict(value)
        return value

    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
57
+
58
+
59
def get_network_paras_amount(model_dict):
    """Count trainable parameters per model.

    Args:
        model_dict: Mapping from a model name to an object exposing
            ``parameters()``.

    Returns:
        A dict mapping each name to the total ``numel()`` of its parameters
        that have ``requires_grad`` set.
    """
    return {
        model_name: sum(
            param.numel() for param in model.parameters() if param.requires_grad
        )
        for model_name, model in model_dict.items()
    }
67
+
68
+
69
def load_config(path_config):
    """Parse the YAML file at ``path_config`` and return it as a ``DotDict``."""
    with open(path_config, "r") as stream:
        parsed = yaml.safe_load(stream)
    return DotDict(parsed)
75
+
76
def save_config(path_config, config):
    """Dump ``config`` to ``path_config`` as YAML.

    ``config`` is converted to a plain ``dict`` before the file is opened.
    """
    plain = dict(config)
    with open(path_config, "w") as stream:
        yaml.dump(plain, stream)
80
+
81
def to_json(path_params, path_json):
    """Convert a saved mapping of tensors into a JSON file of flat lists.

    Args:
        path_params: File readable by ``torch.load``; its content is iterated
            with ``.items()`` and every value is flattened.
        path_json: Destination file; written with tab indentation.
    """
    params = torch.load(path_params, map_location=torch.device('cpu'))
    raw_state_dict = {
        key: tensor.flatten().numpy().tolist()
        for key, tensor in params.items()
    }

    with open(path_json, 'w') as outfile:
        json.dump(raw_state_dict, outfile, indent="\t")
90
+
91
+
92
def convert_tensor_to_numpy(tensor, is_squeeze=True):
    """Return ``tensor`` as a NumPy array.

    Args:
        tensor: A ``torch.Tensor`` on any device, with or without gradient
            tracking.
        is_squeeze: Drop all size-1 dimensions first.

    Returns:
        The tensor's data as a ``numpy.ndarray``.
    """
    if is_squeeze:
        tensor = tensor.squeeze()
    # detach() drops autograd tracking and cpu() is a no-op for host tensors,
    # so chaining them unconditionally covers every device type instead of
    # special-casing CUDA only.
    return tensor.detach().cpu().numpy()
100
+
101
+
102
def load_model(
        expdir,
        model,
        optimizer,
        name='model',
        postfix='',
        device='cpu'):
    """Restore the newest checkpoint found in ``expdir``.

    Checkpoints are looked up as ``<expdir>/<name><postfix><step>.pt``; an
    empty ``postfix`` becomes ``'_'``. The file with the highest numeric step
    is loaded. If ``.pt`` files exist but none carries a numeric step behind
    that prefix, ``<prefix>best.pt`` is tried instead.

    Args:
        expdir: Directory that is searched (recursively) for ``.pt`` files.
        model: Module receiving ``ckpt['model']`` via a non-strict
            ``load_state_dict``.
        optimizer: Optimizer receiving ``ckpt['optimizer']`` when that entry
            is present; may be ``None`` to skip restoring it.
        name: Checkpoint base name.
        postfix: Text between ``name`` and the step number.
        device: ``map_location`` device for ``torch.load``.

    Returns:
        ``(global_step, model, optimizer)``; ``global_step`` is 0 when no
        ``.pt`` file was found.
    """
    if postfix == '':
        postfix = '_' + postfix
    path = os.path.join(expdir, name + postfix)
    candidates = traverse_dir(expdir, ['pt'], is_ext=False)
    global_step = 0
    if candidates:
        # Only accept files that really start with the expected prefix;
        # slicing by length alone would misread unrelated files or files in
        # sub-directories whose path happens to end in digits.
        steps = [
            int(candidate[len(path):])
            for candidate in candidates
            if candidate.startswith(path) and candidate[len(path):].isdigit()
        ]
        if steps:
            path_pt = path + str(max(steps)) + '.pt'
        else:
            path_pt = path + 'best.pt'
        print(' [*] restoring model from', path_pt)
        ckpt = torch.load(path_pt, map_location=torch.device(device))
        global_step = ckpt['global_step']
        model.load_state_dict(ckpt['model'], strict=False)
        if optimizer is not None and ckpt.get("optimizer") is not None:
            optimizer.load_state_dict(ckpt['optimizer'])
    return global_step, model, optimizer
diffusion/onnx_export.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import yaml
8
+ from diffusion_onnx import GaussianDiffusion
9
+
10
+
11
class DotDict(dict):
    """``dict`` subclass whose keys can also be read, set and deleted as attributes.

    Reading a key that is absent yields ``None`` rather than raising.
    A value whose type is exactly ``dict`` is wrapped in a new ``DotDict``
    on access so that nested keys can be chained with dots.
    """

    def __getattr__(self, key):
        value = dict.get(self, key)
        if type(value) is dict:
            return DotDict(value)
        return value

    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
18
+
19
+
20
def load_model_vocoder(
        model_path,
        device='cpu'):
    """Build a ``Unit2Mel`` from the ``config.yaml`` beside ``model_path`` and load its weights.

    Args:
        model_path: Path to a checkpoint whose ``'model'`` entry holds the
            state dict. ``config.yaml`` is read from the same directory.
        device: Device the checkpoint is mapped to and the model is moved to.

    Returns:
        ``(model, args)``: the model in eval mode and the parsed config as a
        ``DotDict``.
    """
    model_dir = os.path.split(model_path)[0]
    config_file = os.path.join(model_dir, 'config.yaml')
    with open(config_file, "r") as stream:
        args = DotDict(yaml.safe_load(stream))

    # load model
    model_cfg = args.model
    model = Unit2Mel(
        args.data.encoder_out_channels,
        model_cfg.n_spk,
        model_cfg.use_pitch_aug,
        128,
        model_cfg.n_layers,
        model_cfg.n_chans,
        model_cfg.n_hidden,
        model_cfg.timesteps,
        model_cfg.k_step_max)

    print(' [Loading] ' + model_path)
    ckpt = torch.load(model_path, map_location=torch.device(device))
    model.to(device)
    model.load_state_dict(ckpt['model'])
    model.eval()
    return model, args
46
+
47
+
48
class Unit2Mel(nn.Module):
    """Conditioning encoder plus diffusion decoder, laid out for ONNX export.

    ``forward`` is the traced encoder: it embeds units, f0 and volume and,
    for multi-speaker models, adds a per-frame speaker mix looked up from
    ``self.speaker_map``. ``init_spkembed`` must run once beforehand to fill
    that map from the trained ``spk_embed`` table.
    """

    def __init__(
            self,
            input_channel,
            n_spk,
            use_pitch_aug=False,
            out_dims=128,
            n_layers=20,
            n_chans=384,
            n_hidden=256,
            timesteps=1000,
            k_step_max=1000):
        super().__init__()

        self.unit_embed = nn.Linear(input_channel, n_hidden)
        self.f0_embed = nn.Linear(1, n_hidden)
        self.volume_embed = nn.Linear(1, n_hidden)
        if use_pitch_aug:
            self.aug_shift_embed = nn.Linear(1, n_hidden, bias=False)
        else:
            self.aug_shift_embed = None
        self.n_spk = n_spk
        if n_spk is not None and n_spk > 1:
            self.spk_embed = nn.Embedding(n_spk, n_hidden)

        self.timesteps = timesteps if timesteps is not None else 1000
        # k_step_max outside (0, timesteps) falls back to the full schedule
        self.k_step_max = (
            k_step_max
            if k_step_max is not None and 0 < k_step_max < self.timesteps
            else self.timesteps
        )

        # diffusion
        self.decoder = GaussianDiffusion(out_dims, n_layers, n_chans, n_hidden, self.timesteps, self.k_step_max)
        self.hidden_size = n_hidden
        # Plain attribute (not a buffer) so checkpoints load unchanged.
        # n_spk may be None for single-speaker configs; torch.zeros cannot
        # take None as a size, so fall back to one row.
        map_rows = n_spk if n_spk is not None else 1
        self.speaker_map = torch.zeros((map_rows, 1, 1, n_hidden))

    def forward(self, units, mel2ph, f0, volume, g=None):
        '''
        input:
            B x n_frames x n_unit
        return:
            dict of B x n_frames x feat
        '''
        decoder_inp = F.pad(units, [0, 0, 1, 0])
        mel2ph_ = mel2ph.unsqueeze(2).repeat([1, 1, units.shape[-1]])
        units = torch.gather(decoder_inp, 1, mel2ph_)  # [B, T, H]

        x = self.unit_embed(units) + self.f0_embed((1 + f0.unsqueeze(-1) / 700).log()) + self.volume_embed(volume.unsqueeze(-1))

        if self.n_spk is not None and self.n_spk > 1:  # [N, S] * [S, B, 1, H]
            g = g.reshape((g.shape[0], g.shape[1], 1, 1, 1))  # [N, S, B, 1, 1]
            g = g * self.speaker_map  # [N, S, B, 1, H]
            g = torch.sum(g, dim=1)  # [N, 1, B, 1, H]
            g = g.transpose(0, -1).transpose(0, -2).squeeze(0)  # [B, H, N]
            x = x.transpose(1, 2) + g
            return x
        else:
            return x.transpose(1, 2)

    def init_spkembed(self, units, f0, volume, spk_id=None, spk_mix_dict=None, aug_shift=None,
                      gt_spec=None, infer=True, infer_speedup=10, method='dpm-solver', k_step=300, use_tqdm=True):
        '''
        input:
            B x n_frames x n_unit
        return:
            dict of B x n_frames x feat

        Besides computing the conditioning, this copies the embedding of
        every speaker listed in ``spk_mix_dict`` into ``self.speaker_map``
        and leaves the map with a leading singleton axis, detached from
        autograd, as ``forward`` expects. The remaining keyword arguments are
        accepted for signature compatibility and are not used here.
        '''
        x = self.unit_embed(units) + self.f0_embed((1 + f0 / 700).log()) + self.volume_embed(volume)
        # Always index a 4-D view so that calling this method more than once
        # does not keep stacking leading axes onto the map.
        speaker_map = self.speaker_map.reshape(-1, 1, 1, self.hidden_size)
        if self.n_spk is not None and self.n_spk > 1:
            if spk_mix_dict is not None:
                spk_embed_mix = torch.zeros((1, 1, self.hidden_size))
                for k, v in spk_mix_dict.items():
                    spk_id_torch = torch.LongTensor(np.array([[k]])).to(units.device)
                    spk_embeddd = self.spk_embed(spk_id_torch)
                    # the map is detached below anyway; store plain values
                    speaker_map[k] = spk_embeddd.detach()
                    spk_embed_mix = spk_embed_mix + v * spk_embeddd
                x = x + spk_embed_mix
            else:
                x = x + self.spk_embed(spk_id - 1)
        self.speaker_map = speaker_map.unsqueeze(0).detach()
        return x.transpose(1, 2)

    def _make_export_inputs(self, n_frames=100):
        """Create the random dummy inputs shared by both export entry points.

        Returns ``(hubert, mel2ph, f0, volume, spk_mix, spks)`` where
        ``spk_mix`` is ``[n_frames, n_spk]`` with uniform weights and
        ``spks`` is the matching ``{speaker_index: weight}`` dict.
        """
        # Feature width must match the trained unit embedding, which is not
        # 768 for every encoder.
        hubert_hidden_size = self.unit_embed.in_features
        hubert = torch.randn((1, n_frames, hubert_hidden_size))
        mel2ph = torch.arange(end=n_frames).unsqueeze(0).long()
        f0 = torch.randn((1, n_frames))
        volume = torch.randn((1, n_frames))
        spk_mix = []
        spks = {}
        if self.n_spk is not None and self.n_spk > 1:
            weight = 1.0 / float(self.n_spk)
            for i in range(self.n_spk):
                spk_mix.append(weight)
                spks[i] = weight
        spk_mix = torch.tensor(spk_mix)
        spk_mix = spk_mix.repeat(n_frames, 1)
        return hubert, mel2ph, f0, volume, spk_mix, spks

    def OnnxExport(self, project_name=None, init_noise=None, export_encoder=True, export_denoise=True, export_pred=True, export_after=True):
        """Export the encoder and the decoder's separate sampling stages.

        The encoder is written to ``{project_name}_encoder.onnx`` when
        ``export_encoder`` is true; the remaining flags and ``init_noise``
        are forwarded to ``self.decoder.OnnxExport``.
        """
        n_frames = 100
        hubert, mel2ph, f0, volume, spk_mix, spks = self._make_export_inputs(n_frames)
        self.init_spkembed(hubert, f0.unsqueeze(-1), volume.unsqueeze(-1), spk_mix_dict=spks)
        self.forward(hubert, mel2ph, f0, volume, spk_mix)
        if export_encoder:
            torch.onnx.export(
                self,
                (hubert, mel2ph, f0, volume, spk_mix),
                f"{project_name}_encoder.onnx",
                input_names=["hubert", "mel2ph", "f0", "volume", "spk_mix"],
                output_names=["mel_pred"],
                dynamic_axes={
                    "hubert": [1],
                    "f0": [1],
                    "volume": [1],
                    "mel2ph": [1],
                    "spk_mix": [0],
                },
                opset_version=16
            )

        self.decoder.OnnxExport(project_name, init_noise=init_noise, export_denoise=export_denoise, export_pred=export_pred, export_after=export_after)

    def ExportOnnx(self, project_name=None):
        """Export the encoder and the whole scripted diffusion decoder.

        Writes ``{project_name}_encoder.onnx`` and
        ``{project_name}_diffusion.onnx``. Note that ``self.decoder`` is
        replaced by its ``torch.jit.script`` version as a side effect.
        """
        n_frames = 100
        hubert, mel2ph, f0, volume, spk_mix, spks = self._make_export_inputs(n_frames)
        # Was ``self.orgforward``, which does not exist on this class; the
        # speaker map is initialised by ``init_spkembed``.
        self.init_spkembed(hubert, f0.unsqueeze(-1), volume.unsqueeze(-1), spk_mix_dict=spks)
        self.forward(hubert, mel2ph, f0, volume, spk_mix)

        torch.onnx.export(
            self,
            (hubert, mel2ph, f0, volume, spk_mix),
            f"{project_name}_encoder.onnx",
            input_names=["hubert", "mel2ph", "f0", "volume", "spk_mix"],
            output_names=["mel_pred"],
            dynamic_axes={
                "hubert": [1],
                "f0": [1],
                "volume": [1],
                "mel2ph": [1],
                "spk_mix": [0],
            },
            opset_version=16
        )

        condition = torch.randn(1, self.decoder.n_hidden, n_frames)
        noise = torch.randn((1, 1, self.decoder.mel_bins, condition.shape[2]), dtype=torch.float32)
        pndm_speedup = torch.LongTensor([100])
        K_steps = torch.LongTensor([1000])
        self.decoder = torch.jit.script(self.decoder)
        self.decoder(condition, noise, pndm_speedup, K_steps)

        torch.onnx.export(
            self.decoder,
            (condition, noise, pndm_speedup, K_steps),
            f"{project_name}_diffusion.onnx",
            input_names=["condition", "noise", "pndm_speedup", "K_steps"],
            output_names=["mel"],
            dynamic_axes={
                "condition": [2],
                "noise": [3],
            },
            opset_version=16
        )
222
+
223
+
224
if __name__ == "__main__":
    project_name = "dddsp"
    model_path = f'{project_name}/model_500000.pt'

    model = load_model_vocoder(model_path)[0]

    # Export the diffusion stages separately (requires MoeSS/MoeVoiceStudio,
    # or your own PNDM/DPM sampling loop).
    export_flags = dict(
        export_encoder=True,
        export_denoise=True,
        export_pred=True,
        export_after=True,
    )
    model.OnnxExport(project_name, **export_flags)

    # Export the diffusion as a single graph (encoder and diffusion are still
    # separate files; feed the encoder output and the initial noise straight
    # into the diffusion model).
    # model.ExportOnnx(project_name)
235
+
diffusion/solver.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+
3
+ import librosa
4
+ import numpy as np
5
+ import torch
6
+ from torch import autocast
7
+ from torch.cuda.amp import GradScaler
8
+
9
+ from diffusion.logger import utils
10
+ from diffusion.logger.saver import Saver
11
+
12
+
13
def test(args, model, vocoder, loader_test, saver):
    """Run the validation set, log spectrograms/audio and return the mean loss.

    For every batch the model is sampled in inference mode, the result is
    vocoded, and the real-time factor is recorded. The training loss is then
    evaluated ``args.train.batch_size`` times on the same batch and summed.

    Returns:
        The summed loss divided by ``args.train.batch_size`` and by the
        number of batches in ``loader_test``.
    """
    print(' [*] testing...')
    model.eval()

    # accumulated loss
    test_loss = 0.

    # initialization
    num_batches = len(loader_test)
    rtf_all = []

    # run
    with torch.no_grad():
        for bidx, data in enumerate(loader_test):
            name_parts = data['name'][0].split("/")
            fn = name_parts[-1]
            speaker = name_parts[-2]
            print('--------')
            print('{}/{} - {}'.format(bidx, num_batches, fn))

            # unpack data: everything except the name fields goes to the device
            for key in data.keys():
                if key.startswith('name'):
                    continue
                data[key] = data[key].to(args.device)
            print('>>', data['name'][0])

            # forward; the ground-truth mel is only passed when the model
            # does not cover the full diffusion schedule
            use_full_schedule = model.k_step_max == model.timesteps
            st_time = time.time()
            mel = model(
                data['units'],
                data['f0'],
                data['volume'],
                data['spk_id'],
                gt_spec=None if use_full_schedule else data['mel'],
                infer=True,
                infer_speedup=args.infer.speedup,
                method=args.infer.method,
                k_step=model.k_step_max
            )
            signal = vocoder.infer(mel, data['f0'])
            ed_time = time.time()

            # RTF
            run_time = ed_time - st_time
            song_time = signal.shape[-1] / args.data.sampling_rate
            rtf = run_time / song_time
            print('RTF: {} | {} / {}'.format(rtf, run_time, song_time))
            rtf_all.append(rtf)

            # loss
            for _ in range(args.train.batch_size):
                loss = model(
                    data['units'],
                    data['f0'],
                    data['volume'],
                    data['spk_id'],
                    gt_spec=data['mel'],
                    infer=False,
                    k_step=model.k_step_max)
                test_loss += loss.item()

            # log mel
            saver.log_spec(f"{speaker}_{fn}.wav", data['mel'], mel)

            # log audio
            path_audio = data['name_ext'][0]
            audio, sr = librosa.load(path_audio, sr=args.data.sampling_rate)
            if len(audio.shape) > 1:
                audio = librosa.to_mono(audio)
            audio = torch.from_numpy(audio).unsqueeze(0).to(signal)
            saver.log_audio({
                f"{speaker}_{fn}_gt.wav": audio,
                f"{speaker}_{fn}_pred.wav": signal,
            })

    # report
    test_loss /= args.train.batch_size
    test_loss /= num_batches

    # check
    print(' [test_loss] test_loss:', test_loss)
    print(' Real Time Factor', np.mean(rtf_all))
    return test_loss
91
+
92
+
93
def train(args, initial_global_step, model, optimizer, scheduler, vocoder, loader_train, loader_test):
    """Run the training loop with periodic logging, checkpointing and validation.

    Args:
        args: Config object; this function reads ``args.device``,
            ``args.env.expdir`` and the ``args.train`` fields ``amp_dtype``,
            ``epochs``, ``interval_log``, ``interval_val``,
            ``interval_force_save`` and ``save_opt``.
        initial_global_step: Step counter the ``Saver`` starts from.
        model: Module returning the loss when called with ``infer=False``.
        optimizer: Optimizer stepped once per batch.
        scheduler: LR scheduler stepped once per batch.
        vocoder: Passed through to ``test``.
        loader_train: Iterable of training batches (dicts of tensors plus
            ``name*`` entries that stay on the host).
        loader_test: Validation loader passed to ``test``.

    Raises:
        ValueError: If ``args.train.amp_dtype`` is not ``fp32``/``fp16``/
            ``bf16``, or the loss becomes NaN.
    """
    # saver
    saver = Saver(args, initial_global_step=initial_global_step)

    # model size
    params_count = utils.get_network_paras_amount({'model': model})
    saver.log_info('--- model size ---')
    saver.log_info(params_count)

    # run
    num_batches = len(loader_train)
    model.train()
    saver.log_info('======= start training =======')
    scaler = GradScaler()
    amp_dtypes = {
        'fp32': torch.float32,
        'fp16': torch.float16,
        'bf16': torch.bfloat16,
    }
    if args.train.amp_dtype not in amp_dtypes:
        raise ValueError(' [x] Unknown amp_dtype: ' + args.train.amp_dtype)
    dtype = amp_dtypes[args.train.amp_dtype]
    use_amp = dtype != torch.float32
    # autocast wants the bare device type ('cuda'), not an indexed device
    # string such as 'cuda:0'.
    autocast_device_type = torch.device(args.device).type

    saver.log_info("epoch|batch_idx/num_batches|output_dir|batch/s|lr|time|step")
    for epoch in range(args.train.epochs):
        for batch_idx, data in enumerate(loader_train):
            saver.global_step_increment()
            optimizer.zero_grad()

            # unpack data
            for k in data.keys():
                if not k.startswith('name'):
                    data[k] = data[k].to(args.device)

            # forward
            if not use_amp:
                loss = model(data['units'].float(), data['f0'], data['volume'], data['spk_id'],
                             aug_shift=data['aug_shift'], gt_spec=data['mel'].float(), infer=False, k_step=model.k_step_max)
            else:
                with autocast(device_type=autocast_device_type, dtype=dtype):
                    loss = model(data['units'], data['f0'], data['volume'], data['spk_id'],
                                 aug_shift=data['aug_shift'], gt_spec=data['mel'], infer=False, k_step=model.k_step_max)

            # handle nan loss
            if torch.isnan(loss):
                raise ValueError(' [x] nan loss ')

            # backpropagate
            if not use_amp:
                loss.backward()
                optimizer.step()
            else:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            scheduler.step()

            # log loss
            if saver.global_step % args.train.interval_log == 0:
                current_lr = optimizer.param_groups[0]['lr']
                saver.log_info(
                    'epoch: {} | {:3d}/{:3d} | {} | batch/s: {:.2f} | lr: {:.6} | loss: {:.3f} | time: {} | step: {}'.format(
                        epoch,
                        batch_idx,
                        num_batches,
                        args.env.expdir,
                        args.train.interval_log/saver.get_interval_time(),
                        current_lr,
                        loss.item(),
                        saver.get_total_time(),
                        saver.global_step
                    )
                )

                saver.log_value({
                    'train/loss': loss.item()
                })

                saver.log_value({
                    'train/lr': current_lr
                })

            # validation
            if saver.global_step % args.train.interval_val == 0:
                optimizer_save = optimizer if args.train.save_opt else None

                # save latest, then drop the previous validation checkpoint
                # unless it falls on a forced-save interval
                saver.save_model(model, optimizer_save, postfix=f'{saver.global_step}')
                last_val_step = saver.global_step - args.train.interval_val
                if last_val_step % args.train.interval_force_save != 0:
                    saver.delete_model(postfix=f'{last_val_step}')

                # run testing set
                test_loss = test(args, model, vocoder, loader_test, saver)

                # log loss
                saver.log_info(
                    ' --- <validation> --- \nloss: {:.3f}. '.format(
                        test_loss,
                    )
                )

                saver.log_value({
                    'validation/loss': test_loss
                })

                # test() switched the model to eval mode
                model.train()
199
+
200
+
diffusion/uni_pc.py ADDED
@@ -0,0 +1,733 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+
5
+
6
class NoiseScheduleVP:
    def __init__(
            self,
            schedule='discrete',
            betas=None,
            alphas_cumprod=None,
            continuous_beta_0=0.1,
            continuous_beta_1=20.,
            dtype=torch.float32,
        ):
        r"""Create a wrapper class for the forward SDE (VP type).

        ***
        Update: We support discrete-time diffusion models by implementing a piecewise linear interpolation for log_alpha_t.
        We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images.
        ***

        The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
        We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
        Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:

            log_alpha_t = self.marginal_log_mean_coeff(t)
            sigma_t = self.marginal_std(t)
            lambda_t = self.marginal_lambda(t)

        Moreover, as lambda(t) is an invertible function, we also support its inverse function:

            t = self.inverse_lambda(lambda_t)

        ===============================================================

        We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).

        1. For discrete-time DPMs:

            For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
                t_i = (i + 1) / N
            e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
            We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.

            Args:
                betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
                alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)

            Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.

            **Important**: Please pay special attention for the args for `alphas_cumprod`:
                The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that
                    q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
                Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
                    alpha_{t_n} = \sqrt{\hat{alpha_n}},
                and
                    log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).

        2. For continuous-time DPMs:

            We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM).

            Args:
                continuous_beta_0: A `float` number. The smallest beta for the linear schedule.
                continuous_beta_1: A `float` number. The largest beta for the linear schedule.

            The cosine schedule uses fixed hyperparameters that are not constructor
            arguments: cosine_s = 0.008, cosine_beta_max = 999., and an ending time
            T = 0.9946 (T = 1 for the linear schedule).

        ===============================================================

        Args:
            schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
                    'linear' or 'cosine' for continuous-time DPMs.
            dtype: The dtype the discrete time/log-alpha tables are stored in.
        Returns:
            A wrapper object of the forward SDE (VP type).

        ===============================================================

        Example:

        # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
        >>> ns = NoiseScheduleVP('discrete', betas=betas)

        # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
        >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)

        # For continuous-time DPMs (VPSDE), linear schedule:
        >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)

        """

        if schedule not in ['discrete', 'linear', 'cosine']:
            raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(schedule))

        self.schedule = schedule
        if schedule == 'discrete':
            if betas is not None:
                log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
            else:
                assert alphas_cumprod is not None
                log_alphas = 0.5 * torch.log(alphas_cumprod)
            self.total_N = len(log_alphas)
            self.T = 1.
            # t_i = (i + 1) / N, stored as a row vector for interpolation
            self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)).to(dtype=dtype)
            self.log_alpha_array = log_alphas.reshape((1, -1,)).to(dtype=dtype)
        else:
            self.total_N = 1000
            self.beta_0 = continuous_beta_0
            self.beta_1 = continuous_beta_1
            self.cosine_s = 0.008
            self.cosine_beta_max = 999.
            self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
            self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
            if schedule == 'cosine':
                # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
                # Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
                self.T = 0.9946
            else:
                self.T = 1.

    def marginal_log_mean_coeff(self, t):
        """
        Compute log(alpha_t) of a given continuous-time label t in [0, T].
        """
        if self.schedule == 'discrete':
            return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1))
        elif self.schedule == 'linear':
            return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
        elif self.schedule == 'cosine':
            def log_alpha_fn(s):
                return torch.log(torch.cos((s + self.cosine_s) / (1.0 + self.cosine_s) * math.pi / 2.0))
            log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
            return log_alpha_t

    def marginal_alpha(self, t):
        """
        Compute alpha_t of a given continuous-time label t in [0, T].
        """
        return torch.exp(self.marginal_log_mean_coeff(t))

    def marginal_std(self, t):
        """
        Compute sigma_t of a given continuous-time label t in [0, T].
        """
        return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))

    def marginal_lambda(self, t):
        """
        Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
        """
        log_mean_coeff = self.marginal_log_mean_coeff(t)
        log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
        return log_mean_coeff - log_std

    def inverse_lambda(self, lamb):
        """
        Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
        """
        if self.schedule == 'linear':
            tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
            Delta = self.beta_0**2 + tmp
            return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
        elif self.schedule == 'discrete':
            log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
            # tables are flipped so the interpolation keys are ascending
            t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1]))
            return t.reshape((-1,))
        else:
            log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
            def t_fn(log_alpha_t):
                return torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2.0 * (1.0 + self.cosine_s) / math.pi - self.cosine_s
            t = t_fn(log_alpha)
            return t
155
+
156
+
157
def model_wrapper(
        model,
        noise_schedule,
        model_type="noise",
        model_kwargs=None,
        guidance_type="uncond",
        condition=None,
        unconditional_condition=None,
        guidance_scale=1.,
        classifier_fn=None,
        classifier_kwargs=None,
    ):
    """Create a wrapper function for the noise prediction model.

    Args:
        model: Callable ``model(x, t_input, [cond], **model_kwargs)``.
        noise_schedule: Object exposing ``schedule``, ``total_N``,
            ``marginal_alpha(t)`` and ``marginal_std(t)``.
        model_type: What ``model`` predicts: ``"noise"``, ``"x_start"``,
            ``"v"`` or ``"score"``. The output is converted to a noise
            prediction.
        model_kwargs: Extra keyword arguments for ``model``.
        guidance_type: ``"uncond"``, ``"classifier"`` or
            ``"classifier-free"``.
        condition: Condition passed to the classifier or the conditional
            model.
        unconditional_condition: Null condition for classifier-free guidance.
        guidance_scale: Guidance strength.
        classifier_fn: ``classifier_fn(x, t_input, condition, **kwargs)``
            returning log-probabilities; required for classifier guidance.
        classifier_kwargs: Extra keyword arguments for ``classifier_fn``.

    Returns:
        ``model_fn(x, t_continuous)`` returning the (guided) noise prediction.
    """
    assert model_type in ["noise", "x_start", "v", "score"]
    assert guidance_type in ["uncond", "classifier", "classifier-free"]

    # None defaults avoid sharing one dict object between all calls.
    model_kwargs = {} if model_kwargs is None else model_kwargs
    classifier_kwargs = {} if classifier_kwargs is None else classifier_kwargs

    def _per_sample(coef, x):
        """Append singleton axes so a per-sample coefficient broadcasts along the batch axis of ``x``."""
        return coef.reshape(tuple(coef.shape) + (1,) * (x.dim() - coef.dim()))

    def get_model_input_time(t_continuous):
        """
        Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
        For discrete-time DPMs, `t_continuous` in [1 / N, 1] is mapped to `t_input` in [0, N - 1].
        For continuous-time DPMs, we just use `t_continuous`.
        """
        if noise_schedule.schedule == 'discrete':
            return (t_continuous - 1. / noise_schedule.total_N) * noise_schedule.total_N
        else:
            return t_continuous

    def noise_pred_fn(x, t_continuous, cond=None):
        t_input = get_model_input_time(t_continuous)
        if cond is None:
            output = model(x, t_input, **model_kwargs)
        else:
            output = model(x, t_input, cond, **model_kwargs)
        if model_type == "noise":
            return output
        elif model_type == "x_start":
            alpha_t = _per_sample(noise_schedule.marginal_alpha(t_continuous), x)
            sigma_t = _per_sample(noise_schedule.marginal_std(t_continuous), x)
            return (x - alpha_t * output) / sigma_t
        elif model_type == "v":
            alpha_t = _per_sample(noise_schedule.marginal_alpha(t_continuous), x)
            sigma_t = _per_sample(noise_schedule.marginal_std(t_continuous), x)
            return alpha_t * output + sigma_t * x
        elif model_type == "score":
            sigma_t = _per_sample(noise_schedule.marginal_std(t_continuous), x)
            return -sigma_t * output

    def cond_grad_fn(x, t_input):
        """
        Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
        """
        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
            return torch.autograd.grad(log_prob.sum(), x_in)[0]

    def model_fn(x, t_continuous):
        """
        The noise prediction model function that is used for DPM-Solver.
        """
        if guidance_type == "uncond":
            return noise_pred_fn(x, t_continuous)
        elif guidance_type == "classifier":
            assert classifier_fn is not None
            t_input = get_model_input_time(t_continuous)
            cond_grad = cond_grad_fn(x, t_input)
            sigma_t = _per_sample(noise_schedule.marginal_std(t_continuous), x)
            noise = noise_pred_fn(x, t_continuous)
            return noise - guidance_scale * sigma_t * cond_grad
        elif guidance_type == "classifier-free":
            if guidance_scale == 1. or unconditional_condition is None:
                return noise_pred_fn(x, t_continuous, cond=condition)
            else:
                # one batched call for the unconditional and conditional halves
                x_in = torch.cat([x] * 2)
                t_in = torch.cat([t_continuous] * 2)
                c_in = torch.cat([unconditional_condition, condition])
                noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
                return noise_uncond + guidance_scale * (noise - noise_uncond)

    return model_fn
236
+
237
+
238
class UniPC:
    """Unified predictor-corrector (UniPC) multistep sampler for diffusion ODEs.

    The sampler wraps a noise-prediction callable ``model_fn(x, t)`` and a
    noise-schedule object.  The schedule must expose the methods used below
    (``marginal_log_mean_coeff``, ``marginal_alpha``, ``marginal_std``,
    ``marginal_lambda`` and, for ``skip_type='logSNR'``, ``inverse_lambda``)
    plus the attributes ``T`` and ``total_N``.

    Samples may have any rank ``(B, ...)``; the first axis is the batch.
    """

    def __init__(
        self,
        model_fn,
        noise_schedule,
        algorithm_type="data_prediction",
        correcting_x0_fn=None,
        correcting_xt_fn=None,
        thresholding_max_val=1.,
        dynamic_thresholding_ratio=0.995,
        variant='bh1'
    ):
        """Construct a UniPC sampler.

        Args:
            model_fn: callable ``(x, t) -> noise prediction``; ``t`` is passed
                expanded to one entry per batch element.
            noise_schedule: schedule object (see class docstring).
            algorithm_type: ``"data_prediction"`` or ``"noise_prediction"``.
            correcting_x0_fn: optional callable applied to the predicted x0,
                or the string ``"dynamic_thresholding"`` to use
                :meth:`dynamic_thresholding_fn`.
            correcting_xt_fn: optional callable ``(x, t, step) -> x`` applied
                after every sampling step.
            thresholding_max_val: lower bound of the dynamic threshold.
            dynamic_thresholding_ratio: quantile used by dynamic thresholding.
            variant: ``'bh1'``, ``'bh2'`` or ``'vary_coeff'``.
        """
        # Broadcast the scalar / one-element time to the batch size.
        self.model = lambda x, t: model_fn(x, t.expand((x.shape[0])))
        self.noise_schedule = noise_schedule
        assert algorithm_type in ["data_prediction", "noise_prediction"]

        if correcting_x0_fn == "dynamic_thresholding":
            self.correcting_x0_fn = self.dynamic_thresholding_fn
        else:
            self.correcting_x0_fn = correcting_x0_fn

        self.correcting_xt_fn = correcting_xt_fn
        self.dynamic_thresholding_ratio = dynamic_thresholding_ratio
        self.thresholding_max_val = thresholding_max_val

        self.variant = variant
        self.predict_x0 = algorithm_type == "data_prediction"

    def dynamic_thresholding_fn(self, x0, t=None):
        """Clamp ``x0`` per sample to a quantile-based threshold and rescale.

        The threshold is the ``dynamic_thresholding_ratio`` quantile of
        ``|x0|`` over all non-batch elements, but never smaller than
        ``thresholding_max_val``.  ``t`` is accepted and ignored.
        """
        dims = x0.dim()
        p = self.dynamic_thresholding_ratio
        s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
        floor = self.thresholding_max_val * torch.ones_like(s)
        s = expand_dims(torch.maximum(s, floor), dims)
        return torch.clamp(x0, -s, s) / s

    def noise_prediction_fn(self, x, t):
        """Return the wrapped model's noise prediction at ``(x, t)``."""
        return self.model(x, t)

    def data_prediction_fn(self, x, t):
        """Return the predicted clean sample x0, after the optional corrector."""
        noise = self.noise_prediction_fn(x, t)
        alpha_t = self.noise_schedule.marginal_alpha(t)
        sigma_t = self.noise_schedule.marginal_std(t)
        x0 = (x - sigma_t * noise) / alpha_t
        if self.correcting_x0_fn is not None:
            x0 = self.correcting_x0_fn(x0)
        return x0

    def model_fn(self, x, t):
        """Evaluate the model in the parameterization chosen at construction."""
        if self.predict_x0:
            return self.data_prediction_fn(x, t)
        return self.noise_prediction_fn(x, t)

    def get_time_steps(self, skip_type, t_T, t_0, N, device):
        """Return the ``N + 1`` time steps from ``t_T`` to ``t_0``.

        ``skip_type`` selects the spacing: ``'logSNR'`` (uniform in
        ``marginal_lambda``), ``'time_uniform'`` or ``'time_quadratic'``.

        Raises:
            ValueError: for any other ``skip_type``.
        """
        if skip_type == 'logSNR':
            lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
            lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
            logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
            return self.noise_schedule.inverse_lambda(logSNR_steps)
        elif skip_type == 'time_uniform':
            return torch.linspace(t_T, t_0, N + 1).to(device)
        elif skip_type == 'time_quadratic':
            t_order = 2
            return torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device)
        else:
            raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))

    def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
        """Split ``steps`` model evaluations into single-step solver orders.

        Returns:
            ``(timesteps_outer, orders)``: the outer time grid and the list of
            per-step orders, whose sum equals ``steps``.

        Raises:
            ValueError: if ``order`` is not 1, 2 or 3.
        """
        if order == 3:
            K = steps // 3 + 1
            if steps % 3 == 0:
                orders = [3,] * (K - 2) + [2, 1]
            elif steps % 3 == 1:
                orders = [3,] * (K - 1) + [1]
            else:
                orders = [3,] * (K - 1) + [2]
        elif order == 2:
            if steps % 2 == 0:
                K = steps // 2
                orders = [2,] * K
            else:
                K = steps // 2 + 1
                orders = [2,] * (K - 1) + [1]
        elif order == 1:
            K = steps
            orders = [1,] * steps
        else:
            raise ValueError("'order' must be '1' or '2' or '3'.")
        if skip_type == 'logSNR':
            # To reproduce the results in DPM-Solver paper
            timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
        else:
            # Pick the fine-grid entries at the cumulative order boundaries.
            boundaries = torch.cumsum(torch.tensor([0,] + orders), 0).to(device)
            timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[boundaries]
        return timesteps_outer, orders

    def denoise_to_zero_fn(self, x, s):
        """Final denoising step: return the data prediction at time ``s``."""
        return self.data_prediction_fn(x, s)

    def multistep_uni_pc_update(self, x, model_prev_list, t_prev_list, t, order, **kwargs):
        """Dispatch one multistep update to the configured variant.

        Variants containing ``'bh'`` use :meth:`multistep_uni_pc_bh_update`;
        otherwise the variant must be ``'vary_coeff'``.
        """
        if len(t.shape) == 0:
            # The update formulas expect a 1-D time tensor.
            t = t.view(-1)
        if 'bh' in self.variant:
            return self.multistep_uni_pc_bh_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
        else:
            assert self.variant == 'vary_coeff'
            return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs)

    def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True):
        """One predictor(-corrector) step of the ``vary_coeff`` variant.

        Args:
            x: current sample at time ``t_prev_list[-1]``, shape ``(B, ...)``.
            model_prev_list: previous model outputs, most recent last.
            t_prev_list: times matching ``model_prev_list``.
            t: target time (1-D tensor).
            order: solver order, at most ``len(model_prev_list)``.
            use_corrector: if true, evaluate the model at the predicted sample
                and apply the corrector.

        Returns:
            ``(x_t, model_t)``; ``model_t`` is ``None`` when no corrector ran.
        """
        ns = self.noise_schedule
        assert order <= len(model_prev_list)

        # first compute rks
        t_prev_0 = t_prev_list[-1]
        lambda_prev_0 = ns.marginal_lambda(t_prev_0)
        lambda_t = ns.marginal_lambda(t)
        model_prev_0 = model_prev_list[-1]
        sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
        log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
        alpha_t = torch.exp(log_alpha_t)

        h = lambda_t - lambda_prev_0

        rks = []
        D1s = []
        for i in range(1, order):
            t_prev_i = t_prev_list[-(i + 1)]
            model_prev_i = model_prev_list[-(i + 1)]
            lambda_prev_i = ns.marginal_lambda(t_prev_i)
            rk = (lambda_prev_i - lambda_prev_0) / h
            rks.append(rk)
            D1s.append((model_prev_i - model_prev_0) / rk)

        rks.append(1.)
        rks = torch.tensor(rks, device=x.device)

        K = len(rks)
        # build C matrix
        C = []
        col = torch.ones_like(rks)
        for k in range(1, K + 1):
            C.append(col)
            col = col * rks / (k + 1)
        C = torch.stack(C, dim=1)

        if len(D1s) > 0:
            D1s = torch.stack(D1s, dim=1)  # (B, K - 1, ...)
            A_p = torch.linalg.inv(C[:-1, :-1])

        if use_corrector:
            A_c = torch.linalg.inv(C)

        hh = -h if self.predict_x0 else h
        h_phi_1 = torch.expm1(hh)
        h_phi_ks = []
        factorial_k = 1
        h_phi_k = h_phi_1
        for k in range(1, K + 2):
            h_phi_ks.append(h_phi_k)
            h_phi_k = h_phi_k / hh - 1 / factorial_k
            factorial_k *= (k + 1)

        # The two parameterizations differ only in the first-order term and in
        # the factor that scales every higher-order residual.
        if self.predict_x0:
            x_t_ = (
                sigma_t / sigma_prev_0 * x
                - alpha_t * h_phi_1 * model_prev_0
            )
            scale = alpha_t
        else:
            x_t_ = (
                (torch.exp(log_alpha_t - log_alpha_prev_0)) * x
                - (sigma_t * h_phi_1) * model_prev_0
            )
            scale = sigma_t

        # now predictor
        x_t = x_t_
        if len(D1s) > 0:
            for k in range(K - 1):
                x_t = x_t - scale * h_phi_ks[k + 1] * torch.einsum('bk...,k->b...', D1s, A_p[k])

        # now corrector
        model_t = None
        if use_corrector:
            model_t = self.model_fn(x_t, t)
            D1_t = (model_t - model_prev_0)
            x_t = x_t_
            k = 0
            for k in range(K - 1):
                x_t = x_t - scale * h_phi_ks[k + 1] * torch.einsum('bk...,k->b...', D1s, A_c[k][:-1])
            # NOTE: `k` deliberately keeps its last loop value (0 if the loop
            # body never ran), exactly as in the original implementation.
            x_t = x_t - scale * h_phi_ks[K] * (D1_t * A_c[k][-1])
        return x_t, model_t

    def multistep_uni_pc_bh_update(self, x, model_prev_list, t_prev_list, t, order, x_t=None, use_corrector=True):
        """One predictor(-corrector) step of the ``bh1`` / ``bh2`` variants.

        Args:
            x: current sample at time ``t_prev_list[-1]``, shape ``(B, ...)``.
            model_prev_list: previous model outputs, most recent last.
            t_prev_list: times matching ``model_prev_list``.
            t: target time (1-D tensor).
            order: solver order, at most ``len(model_prev_list)``.
            x_t: optional precomputed predictor result; when given, the
                predictor is skipped.
            use_corrector: if true, evaluate the model at the predicted sample
                and apply the corrector.

        Returns:
            ``(x_t, model_t)``; ``model_t`` is ``None`` when no corrector ran.

        Raises:
            NotImplementedError: if ``self.variant`` is neither ``'bh1'`` nor
                ``'bh2'``.
        """
        ns = self.noise_schedule
        assert order <= len(model_prev_list)

        # first compute rks
        t_prev_0 = t_prev_list[-1]
        lambda_prev_0 = ns.marginal_lambda(t_prev_0)
        lambda_t = ns.marginal_lambda(t)
        model_prev_0 = model_prev_list[-1]
        sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
        log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
        alpha_t = torch.exp(log_alpha_t)

        h = lambda_t - lambda_prev_0

        rks = []
        D1s = []
        for i in range(1, order):
            t_prev_i = t_prev_list[-(i + 1)]
            model_prev_i = model_prev_list[-(i + 1)]
            lambda_prev_i = ns.marginal_lambda(t_prev_i)
            rk = (lambda_prev_i - lambda_prev_0) / h
            rks.append(rk)
            D1s.append((model_prev_i - model_prev_0) / rk)

        rks.append(1.)
        rks = torch.tensor(rks, device=x.device)

        hh = -h if self.predict_x0 else h
        h_phi_1 = torch.expm1(hh)  # h\phi_1(h) = e^h - 1
        h_phi_k = h_phi_1 / hh - 1

        if self.variant == 'bh1':
            B_h = hh
        elif self.variant == 'bh2':
            B_h = torch.expm1(hh)
        else:
            raise NotImplementedError()

        R = []
        b = []
        factorial_i = 1
        for i in range(1, order + 1):
            R.append(torch.pow(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= (i + 1)
            h_phi_k = h_phi_k / hh - 1 / factorial_i

        R = torch.stack(R)
        b = torch.cat(b)

        # now predictor
        use_predictor = len(D1s) > 0 and x_t is None
        if len(D1s) > 0:
            D1s = torch.stack(D1s, dim=1)  # (B, K, ...)
            if x_t is None:
                # for order 2, we use a simplified version
                if order == 2:
                    rhos_p = torch.tensor([0.5], device=b.device)
                else:
                    rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
        else:
            D1s = None

        if use_corrector:
            # for order 1, we use a simplified version
            if order == 1:
                rhos_c = torch.tensor([0.5], device=b.device)
            else:
                rhos_c = torch.linalg.solve(R, b)

        # The two parameterizations differ only in the first-order term and in
        # the factor that scales the higher-order residuals.
        if self.predict_x0:
            x_t_ = (
                sigma_t / sigma_prev_0 * x
                - alpha_t * h_phi_1 * model_prev_0
            )
            scale = alpha_t
        else:
            x_t_ = (
                torch.exp(log_alpha_t - log_alpha_prev_0) * x
                - sigma_t * h_phi_1 * model_prev_0
            )
            scale = sigma_t

        if x_t is None:
            if use_predictor:
                pred_res = torch.einsum('k,bk...->b...', rhos_p, D1s)
            else:
                pred_res = 0
            x_t = x_t_ - scale * B_h * pred_res

        model_t = None
        if use_corrector:
            model_t = self.model_fn(x_t, t)
            if D1s is not None:
                corr_res = torch.einsum('k,bk...->b...', rhos_c[:-1], D1s)
            else:
                corr_res = 0
            D1_t = (model_t - model_prev_0)
            x_t = x_t_ - scale * B_h * (corr_res + rhos_c[-1] * D1_t)
        return x_t, model_t

    def sample(self, x, steps=20, t_start=None, t_end=None, order=2, skip_type='time_uniform',
        method='multistep', lower_order_final=True, denoise_to_zero=False, atol=0.0078, rtol=0.05, return_intermediate=False,
    ):
        """Compute the sample at time ``t_end`` by UniPC, starting from ``x`` at ``t_start``.

        Args:
            x: initial sample, shape ``(B, ...)``.
            steps: number of solver steps.
            t_start: start time; defaults to ``noise_schedule.T``.
            t_end: end time; defaults to ``1 / noise_schedule.total_N``.
            order: maximum solver order (``steps`` must be at least ``order``).
            skip_type: time-grid spacing, see :meth:`get_time_steps`.
            method: only ``'multistep'`` is implemented here.
            lower_order_final: lower the order for the last steps.
            denoise_to_zero: apply :meth:`denoise_to_zero_fn` at the end.
            atol, rtol: accepted for interface compatibility; unused here.
            return_intermediate: also return the list of intermediate samples.

        Returns:
            The final sample, or ``(sample, intermediates)`` when
            ``return_intermediate`` is true.

        Raises:
            ValueError: if ``method`` is not ``'multistep'``.
        """
        t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
        t_T = self.noise_schedule.T if t_start is None else t_start
        assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"
        if return_intermediate:
            assert method in ['multistep', 'singlestep', 'singlestep_fixed'], "Cannot use adaptive solver when saving intermediate values"
        if self.correcting_xt_fn is not None:
            assert method in ['multistep', 'singlestep', 'singlestep_fixed'], "Cannot use adaptive solver when correcting_xt_fn is not None"
        device = x.device
        intermediates = []
        with torch.no_grad():
            if method == 'multistep':
                assert steps >= order
                timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
                assert timesteps.shape[0] - 1 == steps
                # Init the initial values.
                step = 0
                t = timesteps[step]
                t_prev_list = [t]
                model_prev_list = [self.model_fn(x, t)]
                if self.correcting_xt_fn is not None:
                    x = self.correcting_xt_fn(x, t, step)
                if return_intermediate:
                    intermediates.append(x)

                # Init the first `order` values by lower order multistep UniPC.
                for step in range(1, order):
                    t = timesteps[step]
                    x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, t, step, use_corrector=True)
                    if model_x is None:
                        model_x = self.model_fn(x, t)
                    if self.correcting_xt_fn is not None:
                        x = self.correcting_xt_fn(x, t, step)
                    if return_intermediate:
                        intermediates.append(x)
                    t_prev_list.append(t)
                    model_prev_list.append(model_x)

                # Compute the remaining values by `order`-th order multistep UniPC.
                for step in range(order, steps + 1):
                    t = timesteps[step]
                    if lower_order_final:
                        step_order = min(order, steps + 1 - step)
                    else:
                        step_order = order
                    # Do not run the corrector at the last step.
                    use_corrector = step != steps
                    x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, t, step_order, use_corrector=use_corrector)
                    if self.correcting_xt_fn is not None:
                        x = self.correcting_xt_fn(x, t, step)
                    if return_intermediate:
                        intermediates.append(x)
                    # Shift the history window by one entry.
                    for i in range(order - 1):
                        t_prev_list[i] = t_prev_list[i + 1]
                        model_prev_list[i] = model_prev_list[i + 1]
                    t_prev_list[-1] = t
                    # We do not need to evaluate the final model value.
                    if step < steps:
                        if model_x is None:
                            model_x = self.model_fn(x, t)
                        model_prev_list[-1] = model_x
            else:
                raise ValueError("Got wrong method {}".format(method))

            if denoise_to_zero:
                t = torch.ones((1,)).to(device) * t_0
                x = self.denoise_to_zero_fn(x, t)
                if self.correcting_xt_fn is not None:
                    x = self.correcting_xt_fn(x, t, step + 1)
                if return_intermediate:
                    intermediates.append(x)
        if return_intermediate:
            return x, intermediates
        else:
            return x
675
+
676
+
677
+ #############################################################
678
+ # other utility functions
679
+ #############################################################
680
+
681
def interpolate_fn(x, xp, yp):
    """
    Evaluate a piecewise-linear function y = f(x) defined by keypoints (xp, yp).

    The implementation uses only differentiable tensor ops (sort, gather,
    where).  Outside the range of ``xp`` the outermost segment is extended
    linearly.

    Args:
        x: PyTorch tensor with shape [N, C], where N is the batch size and C the number of channels.
        xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
        yp: PyTorch tensor with shape [C, K].
    Returns:
        The function values f(x), with shape [N, C].
    """
    batch, num_keys = x.shape[0], xp.shape[1]

    # Insert each query into its channel's keypoints and sort; the position the
    # query lands at tells which segment it belongs to.
    merged = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((batch, 1, 1))], dim=2)
    merged_sorted, sort_order = torch.sort(merged, dim=2)
    query_pos = torch.argmin(sort_order, dim=2)
    below = query_pos - 1

    left_of_all = torch.eq(query_pos, 0)
    # Index shared by both lookups whenever the query is not left of all keypoints.
    not_left = torch.where(
        torch.eq(query_pos, num_keys),
        torch.tensor(num_keys - 2, device=x.device),
        below,
    )

    # Segment end points on the x axis, read from the merged sorted array
    # (skipping over the slot occupied by the query itself).
    lo_x_idx = torch.where(left_of_all, torch.tensor(1, device=x.device), not_left)
    hi_x_idx = torch.where(torch.eq(lo_x_idx, below), lo_x_idx + 2, lo_x_idx + 1)
    lo_x = torch.gather(merged_sorted, dim=2, index=lo_x_idx.unsqueeze(2)).squeeze(2)
    hi_x = torch.gather(merged_sorted, dim=2, index=hi_x_idx.unsqueeze(2)).squeeze(2)

    # Matching end points on the y axis, read from the original keypoints.
    lo_y_idx = torch.where(left_of_all, torch.tensor(0, device=x.device), not_left)
    yp_batched = yp.unsqueeze(0).expand(batch, -1, -1)
    lo_y = torch.gather(yp_batched, dim=2, index=lo_y_idx.unsqueeze(2)).squeeze(2)
    hi_y = torch.gather(yp_batched, dim=2, index=(lo_y_idx + 1).unsqueeze(2)).squeeze(2)

    return lo_y + (x - lo_x) * (hi_y - lo_y) / (hi_x - lo_x)
721
+
722
+
723
def expand_dims(v, dims):
    """
    Append trailing singleton axes to `v` until it has `dims` dimensions.

    Args:
        `v`: a PyTorch tensor with shape [N].
        `dims`: an `int`, the total number of dimensions of the result.
    Returns:
        a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
    """
    trailing_axes = (None,) * (dims - 1)
    index = (Ellipsis,) + trailing_axes
    return v[index]
diffusion/unit2mel.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import yaml
7
+
8
+ from .diffusion import GaussianDiffusion
9
+ from .vocoder import Vocoder
10
+ from .wavenet import WaveNet
11
+
12
+
13
class DotDict(dict):
    """A dict whose keys can also be read, written and deleted as attributes.

    Reading a missing key yields ``None``; values that are plain ``dict``
    objects are wrapped in ``DotDict`` on access so lookups can be chained.
    """

    def __getattr__(self, name):
        value = self.get(name)
        # Only exact `dict` instances are wrapped; everything else is returned as-is.
        if type(value) is dict:
            return DotDict(value)
        return value

    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
20
+
21
+
22
def load_model_vocoder(
        model_path,
        device='cpu',
        config_path = None
        ):
    """Load a diffusion checkpoint together with its vocoder and config.

    Args:
        model_path: path to the checkpoint; the weights are read from its
            ``'model'`` entry.
        device: device the vocoder and model are placed on.
        config_path: optional path to the YAML config.  When ``None``, a
            ``config.yaml`` next to ``model_path`` is used.

    Returns:
        ``(model, vocoder, args)``: the ``Unit2Mel`` model in eval mode, the
        ``Vocoder`` and the parsed config as a ``DotDict``.
    """
    if config_path is None:
        config_file = os.path.join(os.path.dirname(model_path), 'config.yaml')
    else:
        config_file = config_path

    # YAML is read as UTF-8 explicitly so parsing does not depend on the
    # platform's locale encoding.
    with open(config_file, "r", encoding="utf-8") as config:
        args = yaml.safe_load(config)
    args = DotDict(args)

    # load vocoder
    vocoder = Vocoder(args.vocoder.type, args.vocoder.ckpt, device=device)

    # load model
    model = Unit2Mel(
                args.data.encoder_out_channels,
                args.model.n_spk,
                args.model.use_pitch_aug,
                vocoder.dimension,
                args.model.n_layers,
                args.model.n_chans,
                args.model.n_hidden,
                args.model.timesteps,
                args.model.k_step_max
                )

    print(' [Loading] ' + model_path)
    ckpt = torch.load(model_path, map_location=torch.device(device))
    model.to(device)
    model.load_state_dict(ckpt['model'])
    model.eval()
    print(f'Loaded diffusion model, sampler is {args.infer.method}, speedup: {args.infer.speedup} ')
    return model, vocoder, args
59
+
60
+
61
class Unit2Mel(nn.Module):
    """Map content units, f0 and volume to a mel spectrogram via diffusion.

    The inputs are embedded linearly into ``n_hidden`` channels, optionally
    summed with a speaker embedding and a pitch-augmentation embedding, and
    passed as conditioning to a ``GaussianDiffusion`` decoder wrapping a
    ``WaveNet`` denoiser.
    """

    def __init__(
            self,
            input_channel,
            n_spk,
            use_pitch_aug=False,
            out_dims=128,
            n_layers=20,
            n_chans=384,
            n_hidden=256,
            timesteps=1000,
            k_step_max=1000
            ):
        """Build the embeddings and the diffusion decoder.

        Args:
            input_channel: feature size of the content units.
            n_spk: number of speakers; a speaker embedding table is created
                only when this is greater than 1.
            use_pitch_aug: create the pitch-shift embedding.
            out_dims: number of output (mel) bins.
            n_layers, n_chans, n_hidden: WaveNet depth, width and conditioning size.
            timesteps: diffusion steps; ``None`` falls back to 1000.
            k_step_max: maximum shallow-diffusion step; values that are
                ``None``, non-positive or not below ``timesteps`` fall back
                to ``timesteps``.
        """
        super().__init__()
        self.unit_embed = nn.Linear(input_channel, n_hidden)
        self.f0_embed = nn.Linear(1, n_hidden)
        self.volume_embed = nn.Linear(1, n_hidden)
        if use_pitch_aug:
            self.aug_shift_embed = nn.Linear(1, n_hidden, bias=False)
        else:
            self.aug_shift_embed = None
        self.n_spk = n_spk
        if n_spk is not None and n_spk > 1:
            self.spk_embed = nn.Embedding(n_spk, n_hidden)

        self.timesteps = timesteps if timesteps is not None else 1000
        self.k_step_max = k_step_max if k_step_max is not None and k_step_max>0 and k_step_max<self.timesteps else self.timesteps

        self.n_hidden = n_hidden
        # diffusion
        self.decoder = GaussianDiffusion(WaveNet(out_dims, n_layers, n_chans, n_hidden),timesteps=self.timesteps,k_step=self.k_step_max, out_dims=out_dims)
        self.input_channel = input_channel

    def init_spkembed(self, units, f0, volume, spk_id = None, spk_mix_dict = None, aug_shift = None,
                gt_spec=None, infer=True, infer_speedup=10, method='dpm-solver', k_step=300, use_tqdm=True):
        """Compute the conditioning tensor and record per-speaker embeddings.

        With ``spk_mix_dict`` given, every speaker's embedding is written
        into ``self.speaker_map`` (which ``init_spkmix`` allocates) and the
        weighted mix is added to the conditioning.  Afterwards
        ``self.speaker_map`` gains a leading axis and is detached.

        Args:
            units: content units, ``B x n_frames x n_unit``.
            f0, volume: tensors with a trailing axis of size 1.
            spk_id: speaker id tensor, used when ``spk_mix_dict`` is ``None``.
            spk_mix_dict: mapping ``speaker index -> weight``.
            aug_shift, gt_spec, infer, infer_speedup, method, k_step, use_tqdm:
                accepted for signature parity with ``forward``; unused here.

        Returns:
            The conditioning tensor with its last two axes swapped.
        """
        x = self.unit_embed(units) + self.f0_embed((1+ f0 / 700).log()) + self.volume_embed(volume)
        if self.n_spk is not None and self.n_spk > 1:
            if spk_mix_dict is not None:
                # The hidden size is stored as `n_hidden`; the accumulator is
                # created on the input's device so it can be summed with it.
                spk_embed_mix = torch.zeros((1, 1, self.n_hidden), device=units.device)
                for k, v in spk_mix_dict.items():
                    spk_id_torch = torch.tensor([[k]], dtype=torch.long, device=units.device)
                    spk_embeddd = self.spk_embed(spk_id_torch)
                    self.speaker_map[k] = spk_embeddd
                    spk_embed_mix = spk_embed_mix + v * spk_embeddd
                x = x + spk_embed_mix
            else:
                x = x + self.spk_embed(spk_id - 1)
        self.speaker_map = self.speaker_map.unsqueeze(0)
        self.speaker_map = self.speaker_map.detach()
        return x.transpose(1, 2)

    def init_spkmix(self, n_spk):
        """Allocate ``self.speaker_map`` and fill it from the embedding table.

        Runs ``init_spkembed`` once on random dummy inputs with an equal
        weight for each of the ``n_spk`` speakers.
        """
        self.speaker_map = torch.zeros((n_spk,1,1,self.n_hidden))
        hubert_hidden_size = self.input_channel
        n_frames = 10
        hubert = torch.randn((1, n_frames, hubert_hidden_size))
        f0 = torch.randn((1, n_frames))
        volume = torch.randn((1, n_frames))
        spks = {i: 1.0/float(self.n_spk) for i in range(n_spk)}
        self.init_spkembed(hubert, f0.unsqueeze(-1), volume.unsqueeze(-1), spk_mix_dict=spks)

    def forward(self, units, f0, volume, spk_id = None, spk_mix_dict = None, aug_shift = None,
                gt_spec=None, infer=True, infer_speedup=10, method='dpm-solver', k_step=300, use_tqdm=True):
        """Embed the inputs and run the diffusion decoder.

        Args:
            units: content units, ``B x n_frames x n_unit``.
            f0, volume: tensors with a trailing axis of size 1.
            spk_id: speaker id tensor; when its second axis is larger than 1
                it is treated as per-speaker weights applied to
                ``self.speaker_map``.
            spk_mix_dict: mapping ``speaker index -> weight``; takes
                precedence over ``spk_id``.
            aug_shift: optional pitch-shift input, used only when the
                pitch-augmentation embedding exists.
            gt_spec, infer, infer_speedup, method, k_step, use_tqdm:
                forwarded unchanged to the decoder.

        Returns:
            Whatever ``self.decoder`` returns for the given mode.

        Raises:
            Exception: in eval mode, if ``gt_spec`` is given and ``k_step``
                exceeds ``k_step_max``, or if ``gt_spec`` is ``None`` while
                the model was built for shallow diffusion only.
        """

        if not self.training and gt_spec is not None and k_step>self.k_step_max:
            raise Exception("The shallow diffusion k_step is greater than the maximum diffusion k_step(k_step_max)!")

        if not self.training and gt_spec is None and self.k_step_max!=self.timesteps:
            raise Exception("This model can only be used for shallow diffusion and can not infer alone!")

        x = self.unit_embed(units) + self.f0_embed((1+ f0 / 700).log()) + self.volume_embed(volume)
        if self.n_spk is not None and self.n_spk > 1:
            if spk_mix_dict is not None:
                for k, v in spk_mix_dict.items():
                    spk_id_torch = torch.LongTensor(np.array([[k]])).to(units.device)
                    x = x + v * self.spk_embed(spk_id_torch)
            else:
                if spk_id.shape[1] > 1:
                    g = spk_id.reshape((spk_id.shape[0], spk_id.shape[1], 1, 1, 1))  # [N, S, B, 1, 1]
                    g = g * self.speaker_map  # [N, S, B, 1, H]
                    g = torch.sum(g, dim=1) # [N, 1, B, 1, H]
                    g = g.transpose(0, -1).transpose(0, -2).squeeze(0) # [B, H, N]
                    x = x + g
                else:
                    x = x + self.spk_embed(spk_id)
        if self.aug_shift_embed is not None and aug_shift is not None:
            x = x + self.aug_shift_embed(aug_shift / 5)
        x = self.decoder(x, gt_spec=gt_spec, infer=infer, infer_speedup=infer_speedup, method=method, k_step=k_step, use_tqdm=use_tqdm)

        return x
167
+
diffusion/vocoder.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchaudio.transforms import Resample
3
+
4
+ from vdecoder.nsf_hifigan.models import load_config, load_model
5
+ from vdecoder.nsf_hifigan.nvSTFT import STFT
6
+
7
+
8
class Vocoder:
    """Facade over the supported NSF-HiFiGAN vocoder back ends.

    Handles back-end selection, resampling of input audio to the back end's
    sample rate, mel extraction and waveform synthesis.
    """

    def __init__(self, vocoder_type, vocoder_ckpt, device = None):
        """Create the back end named by ``vocoder_type``.

        ``device`` defaults to CUDA when available, otherwise CPU.

        Raises:
            ValueError: if ``vocoder_type`` is not a known back end.
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device

        if vocoder_type == 'nsf-hifigan':
            backend_cls = NsfHifiGAN
        elif vocoder_type == 'nsf-hifigan-log10':
            backend_cls = NsfHifiGANLog10
        else:
            raise ValueError(f" [x] Unknown vocoder: {vocoder_type}")
        self.vocoder = backend_cls(vocoder_ckpt, device = device)

        # Resamplers are cached per source sample rate (keyed by its string form).
        self.resample_kernel = {}
        self.vocoder_sample_rate = self.vocoder.sample_rate()
        self.vocoder_hop_size = self.vocoder.hop_size()
        self.dimension = self.vocoder.dimension()

    def extract(self, audio, sample_rate, keyshift=0):
        """Return the back end's mel features for ``audio``.

        Audio whose ``sample_rate`` differs from the back end's rate is
        resampled first, using a cached ``Resample`` module.
        """
        if sample_rate != self.vocoder_sample_rate:
            cache_key = str(sample_rate)
            if cache_key not in self.resample_kernel:
                self.resample_kernel[cache_key] = Resample(sample_rate, self.vocoder_sample_rate, lowpass_filter_width = 128).to(self.device)
            audio = self.resample_kernel[cache_key](audio)

        # B, n_frames, bins
        return self.vocoder.extract(audio, keyshift=keyshift)

    def infer(self, mel, f0):
        """Synthesize audio from ``mel`` and ``f0``.

        ``f0`` is cut to the number of mel frames and its last axis is
        dropped (index 0) before the back end is called.
        """
        n_frames = mel.size(1)
        f0_frames = f0[:, :n_frames, 0]  # B, n_frames
        return self.vocoder(mel, f0_frames)
45
+
46
+
47
class NsfHifiGAN(torch.nn.Module):
    """NSF-HiFiGAN vocoder wrapper with lazily loaded generator weights.

    The config and the STFT front end are created in the constructor; the
    generator itself is loaded on the first ``forward`` call.
    """

    def __init__(self, model_path, device=None):
        """Read the config for ``model_path`` and build the mel front end.

        ``device`` defaults to CUDA when available, otherwise CPU.
        """
        super().__init__()
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device
        self.model_path = model_path
        self.model = None  # loaded on first use in forward()
        self.h = load_config(model_path)
        cfg = self.h
        self.stft = STFT(
            cfg.sampling_rate,
            cfg.num_mels,
            cfg.n_fft,
            cfg.win_size,
            cfg.hop_size,
            cfg.fmin,
            cfg.fmax)

    def sample_rate(self):
        """Return the sampling rate from the vocoder config."""
        return self.h.sampling_rate

    def hop_size(self):
        """Return the hop size from the vocoder config."""
        return self.h.hop_size

    def dimension(self):
        """Return the number of mel bins from the vocoder config."""
        return self.h.num_mels

    def extract(self, audio, keyshift=0):
        """Return the mel features of ``audio`` with the last two axes swapped."""
        mel = self.stft.get_mel(audio, keyshift=keyshift)
        return mel.transpose(1, 2)  # B, n_frames, bins

    def forward(self, mel, f0):
        """Synthesize audio from ``mel`` and ``f0`` without tracking gradients."""
        if self.model is None:
            print('| Load HifiGAN: ', self.model_path)
            self.model, self.h = load_model(self.model_path, device=self.device)
        with torch.no_grad():
            # The generator is fed the mel with its last two axes swapped.
            features = mel.transpose(1, 2)
            return self.model(features, f0)
86
+
87
class NsfHifiGANLog10(NsfHifiGAN):
    """NSF-HiFiGAN variant that rescales the mel by 0.434294 before synthesis.

    NOTE(review): 0.434294 is approximately log10(e), which would convert a
    natural-log mel to base 10 as the class name suggests; confirm against
    the checkpoint's training setup.
    """

    def forward(self, mel, f0):
        """Synthesize audio from the rescaled ``mel`` and ``f0`` without gradients."""
        if self.model is None:
            print('| Load HifiGAN: ', self.model_path)
            self.model, self.h = load_model(self.model_path, device=self.device)
        with torch.no_grad():
            features = 0.434294 * mel.transpose(1, 2)
            return self.model(features, f0)
diffusion/wavenet.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from math import sqrt
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from torch.nn import Mish
8
+
9
+
10
class Conv1d(nn.Conv1d):
    """1-D convolution whose weight is re-initialized with Kaiming-normal."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Overwrite the default weight init; the bias keeps its default.
        nn.init.kaiming_normal_(self.weight)
14
+
15
+
16
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of a 1-D tensor of positions / diffusion steps.

    The output concatenates ``dim // 2`` sines followed by ``dim // 2``
    cosines, with frequencies spaced geometrically from 1 down to 1/10000.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Return the embedding of ``x`` (shape ``[B]``) with shape ``[B, 2 * (dim // 2)]``."""
        half_dim = self.dim // 2
        log_spacing = math.log(10000) / (half_dim - 1)
        freqs = torch.exp(torch.arange(half_dim, device=x.device) * -log_spacing)
        angles = x[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
29
+
30
+
31
class ResidualBlock(nn.Module):
    """Gated residual block conditioned on encoder features and a step embedding.

    ``forward`` returns both the updated residual stream and a skip output.
    """

    def __init__(self, encoder_hidden, residual_channels, dilation):
        super().__init__()
        self.residual_channels = residual_channels
        # padding == dilation with kernel_size 3 keeps the time length unchanged.
        self.dilated_conv = nn.Conv1d(
            residual_channels,
            2 * residual_channels,
            kernel_size=3,
            padding=dilation,
            dilation=dilation
        )
        self.diffusion_projection = nn.Linear(residual_channels, residual_channels)
        self.conditioner_projection = nn.Conv1d(encoder_hidden, 2 * residual_channels, 1)
        self.output_projection = nn.Conv1d(residual_channels, 2 * residual_channels, 1)

    def forward(self, x, conditioner, diffusion_step):
        """Apply the block.

        Args:
            x: residual stream, ``[B, residual_channels, T]``.
            conditioner: conditioning features, ``[B, encoder_hidden, T]``.
            diffusion_step: step embedding, ``[B, residual_channels]``.

        Returns:
            ``(residual_out, skip)``, both ``[B, residual_channels, T]``.
        """
        channels = self.residual_channels
        step_bias = self.diffusion_projection(diffusion_step).unsqueeze(-1)
        cond = self.conditioner_projection(conditioner)

        hidden = self.dilated_conv(x + step_bias) + cond

        # Using torch.split instead of torch.chunk to avoid using onnx::Slice
        gate, filt = torch.split(hidden, [channels, channels], dim=1)
        hidden = self.output_projection(torch.sigmoid(gate) * torch.tanh(filt))

        # Using torch.split instead of torch.chunk to avoid using onnx::Slice
        residual, skip = torch.split(hidden, [channels, channels], dim=1)
        return (x + residual) / math.sqrt(2.0), skip
62
+
63
+
64
class WaveNet(nn.Module):
    """Non-causal WaveNet-style denoiser over mel spectrograms.

    A stack of ``ResidualBlock`` layers (all with dilation 1) processes the
    projected input; the layers' skip outputs are summed, normalized by
    ``sqrt(n_layers)`` and projected back to ``in_dims`` channels.
    """

    def __init__(self, in_dims=128, n_layers=20, n_chans=384, n_hidden=256):
        super().__init__()
        self.input_projection = Conv1d(in_dims, n_chans, 1)
        self.diffusion_embedding = SinusoidalPosEmb(n_chans)
        self.mlp = nn.Sequential(
            nn.Linear(n_chans, n_chans * 4),
            Mish(),
            nn.Linear(n_chans * 4, n_chans)
        )
        blocks = [
            ResidualBlock(
                encoder_hidden=n_hidden,
                residual_channels=n_chans,
                dilation=1
            )
            for _ in range(n_layers)
        ]
        self.residual_layers = nn.ModuleList(blocks)
        self.skip_projection = Conv1d(n_chans, n_chans, 1)
        self.output_projection = Conv1d(n_chans, in_dims, 1)
        # Zero the output weights so the initial prediction is just the bias.
        nn.init.zeros_(self.output_projection.weight)

    def forward(self, spec, diffusion_step, cond):
        """
        :param spec: [B, 1, M, T]
        :param diffusion_step: documented upstream as [B, 1]; NOTE(review): the
            sinusoidal embedding indexes it as a 1-D tensor, so [B] looks
            like the shape actually expected. Confirm against the caller.
        :param cond: documented upstream as [B, M, T]; it is fed to each residual
            block's conditioner projection
        :return: prediction with shape [B, 1, M, T]
        """
        hidden = self.input_projection(spec.squeeze(1))  # [B, residual_channel, T]
        hidden = F.relu(hidden)

        step_emb = self.mlp(self.diffusion_embedding(diffusion_step))

        skips = []
        for block in self.residual_layers:
            hidden, skip_out = block(hidden, cond, step_emb)
            skips.append(skip_out)

        out = torch.sum(torch.stack(skips), dim=0) / math.sqrt(len(self.residual_layers))
        out = F.relu(self.skip_projection(out))
        out = self.output_projection(out)  # [B, mel_bins, T]
        return out[:, None, :, :]
edgetts/tts.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import random
3
+ import sys
4
+
5
+ import edge_tts
6
+ from edge_tts import VoicesManager
7
+ from langdetect import DetectorFactory, detect
8
+
9
+ DetectorFactory.seed = 0
10
+
11
async def _main() -> None:
    """Synthesize speech with edge-tts from command-line arguments.

    Positional arguments: text, language (or ``"Auto"``), rate, volume,
    optional gender, optional output file (default ``"tts.wav"``).

    With language ``"Auto"`` the language is detected from the text and a
    random matching voice is picked; otherwise the language argument is
    used directly as the voice name.

    Raises:
        ValueError: if language is ``"Auto"`` and no voice matches the
            detected language (and gender, when given).
    """
    TEXT = sys.argv[1]
    LANG = sys.argv[2]
    RATE = sys.argv[3]
    VOLUME = sys.argv[4]
    GENDER = sys.argv[5] if 5 < len(sys.argv) else None
    OUTPUT_FILE = sys.argv[6] if 6 < len(sys.argv) else "tts.wav"

    print("Running TTS...")
    print(f"Text: {TEXT}, Language: {LANG}, Gender: {GENDER}, Rate: {RATE}, Volume: {VOLUME}")

    voices = await VoicesManager.create()
    if LANG == "Auto":
        LANG = detect(TEXT)
        # `find` treats every keyword as an exact-match filter, so the gender
        # is only added when one was actually supplied; `Gender=None` would
        # match no voice at all.
        criteria = {}
        if GENDER is not None:
            criteria["Gender"] = GENDER
        # From "zh-cn" to "zh-CN" etc.
        if LANG == "zh-cn" or LANG == "zh-tw":
            criteria["Locale"] = LANG[:-2] + LANG[-2:].upper()
        else:
            criteria["Language"] = LANG
        voice = voices.find(**criteria)
        if not voice:
            raise ValueError(f"No edge-tts voice found for {criteria}")
        VOICE = random.choice(voice)["Name"]
        print(f"Using random {LANG} voice: {VOICE}")
    else:
        VOICE = LANG

    communicate = edge_tts.Communicate(text = TEXT, voice = VOICE, rate = RATE, volume = VOLUME)
    await communicate.save(OUTPUT_FILE)
38
+
39
if __name__ == "__main__":
    # On Windows the selector event loop policy is selected explicitly, as
    # the original script did.
    if sys.platform.startswith("win"):
        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
    # asyncio.run creates a fresh loop, runs the coroutine and closes the
    # loop, replacing the deprecated policy-level get_event_loop() path.
    asyncio.run(_main())
edgetts/tts_voices.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# List of supported voices for edge_TTS.
#
# Every voice name has the form "<locale>-<Speaker>Neural"; the locale is the
# name with its last "-" separated segment removed, so the mapping below is
# derived from the names instead of being spelled out twice.
_VOICE_NAMES = (
    'zh-CN-XiaoxiaoNeural',
    'zh-CN-XiaoyiNeural',
    'zh-CN-YunjianNeural',
    'zh-CN-YunxiNeural',
    'zh-CN-YunxiaNeural',
    'zh-CN-YunyangNeural',
    'zh-HK-HiuGaaiNeural',
    'zh-HK-HiuMaanNeural',
    'zh-HK-WanLungNeural',
    'zh-TW-HsiaoChenNeural',
    'zh-TW-YunJheNeural',
    'zh-TW-HsiaoYuNeural',
    'af-ZA-AdriNeural',
    'af-ZA-WillemNeural',
    'am-ET-AmehaNeural',
    'am-ET-MekdesNeural',
    'ar-AE-FatimaNeural',
    'ar-AE-HamdanNeural',
    'ar-BH-AliNeural',
    'ar-BH-LailaNeural',
    'ar-DZ-AminaNeural',
    'ar-DZ-IsmaelNeural',
    'ar-EG-SalmaNeural',
    'ar-EG-ShakirNeural',
    'ar-IQ-BasselNeural',
    'ar-IQ-RanaNeural',
    'ar-JO-SanaNeural',
    'ar-JO-TaimNeural',
    'ar-KW-FahedNeural',
    'ar-KW-NouraNeural',
    'ar-LB-LaylaNeural',
    'ar-LB-RamiNeural',
    'ar-LY-ImanNeural',
    'ar-LY-OmarNeural',
    'ar-MA-JamalNeural',
    'ar-MA-MounaNeural',
    'ar-OM-AbdullahNeural',
    'ar-OM-AyshaNeural',
    'ar-QA-AmalNeural',
    'ar-QA-MoazNeural',
    'ar-SA-HamedNeural',
    'ar-SA-ZariyahNeural',
    'ar-SY-AmanyNeural',
    'ar-SY-LaithNeural',
    'ar-TN-HediNeural',
    'ar-TN-ReemNeural',
    'ar-YE-MaryamNeural',
    'ar-YE-SalehNeural',
    'az-AZ-BabekNeural',
    'az-AZ-BanuNeural',
    'bg-BG-BorislavNeural',
    'bg-BG-KalinaNeural',
    'bn-BD-NabanitaNeural',
    'bn-BD-PradeepNeural',
    'bn-IN-BashkarNeural',
    'bn-IN-TanishaaNeural',
    'bs-BA-GoranNeural',
    'bs-BA-VesnaNeural',
    'ca-ES-EnricNeural',
    'ca-ES-JoanaNeural',
    'cs-CZ-AntoninNeural',
    'cs-CZ-VlastaNeural',
    'cy-GB-AledNeural',
    'cy-GB-NiaNeural',
    'da-DK-ChristelNeural',
    'da-DK-JeppeNeural',
    'de-AT-IngridNeural',
    'de-AT-JonasNeural',
    'de-CH-JanNeural',
    'de-CH-LeniNeural',
    'de-DE-AmalaNeural',
    'de-DE-ConradNeural',
    'de-DE-KatjaNeural',
    'de-DE-KillianNeural',
    'el-GR-AthinaNeural',
    'el-GR-NestorasNeural',
    'en-AU-NatashaNeural',
    'en-AU-WilliamNeural',
    'en-CA-ClaraNeural',
    'en-CA-LiamNeural',
    'en-GB-LibbyNeural',
    'en-GB-MaisieNeural',
    'en-GB-RyanNeural',
    'en-GB-SoniaNeural',
    'en-GB-ThomasNeural',
    'en-HK-SamNeural',
    'en-HK-YanNeural',
    'en-IE-ConnorNeural',
    'en-IE-EmilyNeural',
    'en-IN-NeerjaNeural',
    'en-IN-PrabhatNeural',
    'en-KE-AsiliaNeural',
    'en-KE-ChilembaNeural',
    'en-NG-AbeoNeural',
    'en-NG-EzinneNeural',
    'en-NZ-MitchellNeural',
    'en-NZ-MollyNeural',
    'en-PH-JamesNeural',
    'en-PH-RosaNeural',
    'en-SG-LunaNeural',
    'en-SG-WayneNeural',
    'en-TZ-ElimuNeural',
    'en-TZ-ImaniNeural',
    'en-US-AnaNeural',
    'en-US-AriaNeural',
    'en-US-ChristopherNeural',
    'en-US-EricNeural',
    'en-US-GuyNeural',
    'en-US-JennyNeural',
    'en-US-MichelleNeural',
    'en-ZA-LeahNeural',
    'en-ZA-LukeNeural',
    'es-AR-ElenaNeural',
    'es-AR-TomasNeural',
    'es-BO-MarceloNeural',
    'es-BO-SofiaNeural',
    'es-CL-CatalinaNeural',
    'es-CL-LorenzoNeural',
    'es-CO-GonzaloNeural',
    'es-CO-SalomeNeural',
    'es-CR-JuanNeural',
    'es-CR-MariaNeural',
    'es-CU-BelkysNeural',
    'es-CU-ManuelNeural',
    'es-DO-EmilioNeural',
    'es-DO-RamonaNeural',
    'es-EC-AndreaNeural',
    'es-EC-LuisNeural',
    'es-ES-AlvaroNeural',
    'es-ES-ElviraNeural',
    'es-ES-ManuelEsCUNeural',
    'es-GQ-JavierNeural',
    'es-GQ-TeresaNeural',
    'es-GT-AndresNeural',
    'es-GT-MartaNeural',
    'es-HN-CarlosNeural',
    'es-HN-KarlaNeural',
    'es-MX-DaliaNeural',
    'es-MX-JorgeNeural',
    'es-MX-LorenzoEsCLNeural',
    'es-NI-FedericoNeural',
    'es-NI-YolandaNeural',
    'es-PA-MargaritaNeural',
    'es-PA-RobertoNeural',
    'es-PE-AlexNeural',
    'es-PE-CamilaNeural',
    'es-PR-KarinaNeural',
    'es-PR-VictorNeural',
    'es-PY-MarioNeural',
    'es-PY-TaniaNeural',
    'es-SV-LorenaNeural',
    'es-SV-RodrigoNeural',
    'es-US-AlonsoNeural',
    'es-US-PalomaNeural',
    'es-UY-MateoNeural',
    'es-UY-ValentinaNeural',
    'es-VE-PaolaNeural',
    'es-VE-SebastianNeural',
    'et-EE-AnuNeural',
    'et-EE-KertNeural',
    'fa-IR-DilaraNeural',
    'fa-IR-FaridNeural',
    'fi-FI-HarriNeural',
    'fi-FI-NooraNeural',
    'fil-PH-AngeloNeural',
    'fil-PH-BlessicaNeural',
    'fr-BE-CharlineNeural',
    'fr-BE-GerardNeural',
    'fr-CA-AntoineNeural',
    'fr-CA-JeanNeural',
    'fr-CA-SylvieNeural',
    'fr-CH-ArianeNeural',
    'fr-CH-FabriceNeural',
    'fr-FR-DeniseNeural',
    'fr-FR-EloiseNeural',
    'fr-FR-HenriNeural',
    'ga-IE-ColmNeural',
    'ga-IE-OrlaNeural',
    'gl-ES-RoiNeural',
    'gl-ES-SabelaNeural',
    'gu-IN-DhwaniNeural',
    'gu-IN-NiranjanNeural',
    'he-IL-AvriNeural',
    'he-IL-HilaNeural',
    'hi-IN-MadhurNeural',
    'hi-IN-SwaraNeural',
    'hr-HR-GabrijelaNeural',
    'hr-HR-SreckoNeural',
    'hu-HU-NoemiNeural',
    'hu-HU-TamasNeural',
    'id-ID-ArdiNeural',
    'id-ID-GadisNeural',
    'is-IS-GudrunNeural',
    'is-IS-GunnarNeural',
    'it-IT-DiegoNeural',
    'it-IT-ElsaNeural',
    'it-IT-IsabellaNeural',
    'ja-JP-KeitaNeural',
    'ja-JP-NanamiNeural',
    'jv-ID-DimasNeural',
    'jv-ID-SitiNeural',
    'ka-GE-EkaNeural',
    'ka-GE-GiorgiNeural',
    'kk-KZ-AigulNeural',
    'kk-KZ-DauletNeural',
    'km-KH-PisethNeural',
    'km-KH-SreymomNeural',
    'kn-IN-GaganNeural',
    'kn-IN-SapnaNeural',
    'ko-KR-InJoonNeural',
    'ko-KR-SunHiNeural',
    'lo-LA-ChanthavongNeural',
    'lo-LA-KeomanyNeural',
    'lt-LT-LeonasNeural',
    'lt-LT-OnaNeural',
    'lv-LV-EveritaNeural',
    'lv-LV-NilsNeural',
    'mk-MK-AleksandarNeural',
    'mk-MK-MarijaNeural',
    'ml-IN-MidhunNeural',
    'ml-IN-SobhanaNeural',
    'mn-MN-BataaNeural',
    'mn-MN-YesuiNeural',
    'mr-IN-AarohiNeural',
    'mr-IN-ManoharNeural',
    'ms-MY-OsmanNeural',
    'ms-MY-YasminNeural',
    'mt-MT-GraceNeural',
    'mt-MT-JosephNeural',
    'my-MM-NilarNeural',
    'my-MM-ThihaNeural',
    'nb-NO-FinnNeural',
    'nb-NO-PernilleNeural',
    'ne-NP-HemkalaNeural',
    'ne-NP-SagarNeural',
    'nl-BE-ArnaudNeural',
    'nl-BE-DenaNeural',
    'nl-NL-ColetteNeural',
    'nl-NL-FennaNeural',
    'nl-NL-MaartenNeural',
    'pl-PL-MarekNeural',
    'pl-PL-ZofiaNeural',
    'ps-AF-GulNawazNeural',
    'ps-AF-LatifaNeural',
    'pt-BR-AntonioNeural',
    'pt-BR-FranciscaNeural',
    'pt-PT-DuarteNeural',
    'pt-PT-RaquelNeural',
    'ro-RO-AlinaNeural',
    'ro-RO-EmilNeural',
    'ru-RU-DmitryNeural',
    'ru-RU-SvetlanaNeural',
    'si-LK-SameeraNeural',
    'si-LK-ThiliniNeural',
    'sk-SK-LukasNeural',
    'sk-SK-ViktoriaNeural',
    'sl-SI-PetraNeural',
    'sl-SI-RokNeural',
    'so-SO-MuuseNeural',
    'so-SO-UbaxNeural',
    'sq-AL-AnilaNeural',
    'sq-AL-IlirNeural',
    'sr-RS-NicholasNeural',
    'sr-RS-SophieNeural',
    'su-ID-JajangNeural',
    'su-ID-TutiNeural',
    'sv-SE-MattiasNeural',
    'sv-SE-SofieNeural',
    'sw-KE-RafikiNeural',
    'sw-KE-ZuriNeural',
    'sw-TZ-DaudiNeural',
    'sw-TZ-RehemaNeural',
    'ta-IN-PallaviNeural',
    'ta-IN-ValluvarNeural',
    'ta-LK-KumarNeural',
    'ta-LK-SaranyaNeural',
    'ta-MY-KaniNeural',
    'ta-MY-SuryaNeural',
    'ta-SG-AnbuNeural',
    'ta-SG-VenbaNeural',
    'te-IN-MohanNeural',
    'te-IN-ShrutiNeural',
    'th-TH-NiwatNeural',
    'th-TH-PremwadeeNeural',
    'tr-TR-AhmetNeural',
    'tr-TR-EmelNeural',
    'uk-UA-OstapNeural',
    'uk-UA-PolinaNeural',
    'ur-IN-GulNeural',
    'ur-IN-SalmanNeural',
    'ur-PK-AsadNeural',
    'ur-PK-UzmaNeural',
    'uz-UZ-MadinaNeural',
    'uz-UZ-SardorNeural',
    'vi-VN-HoaiMyNeural',
    'vi-VN-NamMinhNeural',
    'zu-ZA-ThandoNeural',
    'zu-ZA-ThembaNeural',
)

# Voice name -> locale, in the same order as the names above.
SUPPORTED_VOICES = {name: name.rsplit('-', 1)[0] for name in _VOICE_NAMES}

# "Auto" followed by every voice name, in table order.
SUPPORTED_LANGUAGES = ["Auto"] + list(SUPPORTED_VOICES)
export_index_for_onnx.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pickle
3
+
4
+ import faiss
5
+
6
# Sub-directory of ``checkpoints/`` that holds the pickled indexes and that
# receives the exported ``Index-<key>.index`` files.
path = "crs"
indexs_file_path = f"checkpoints/{path}/feature_and_index.pkl"
indexs_out_dir = f"checkpoints/{path}/"

# Read the pickle from the configured location; the previous code computed
# ``indexs_file_path`` but then opened a bare "feature_and_index.pkl" from the
# current working directory.
# NOTE: pickle.load can execute arbitrary code; only load files you trust.
with open(indexs_file_path, mode="rb") as f:
    indexs = pickle.load(f)

# Write one faiss index file per key of the loaded container.
for k in indexs:
    print(f"Save {k} index")
    faiss.write_index(
        indexs[k],
        os.path.join(indexs_out_dir, f"Index-{k}.index")
    )

print("Saved all index")
flask_api.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import logging
3
+
4
+ import soundfile
5
+ import torch
6
+ import torchaudio
7
+ from flask import Flask, request, send_file
8
+ from flask_cors import CORS
9
+
10
+ from inference.infer_tool import RealTimeVC, Svc
11
+
12
+ app = Flask(__name__)
13
+
14
+ CORS(app)
15
+
16
+ logging.getLogger('numba').setLevel(logging.WARNING)
17
+
18
+
19
@app.route("/voiceChangeModel", methods=["POST"])
def voice_change_model():
    """Convert an uploaded wav with the loaded model and return the result as wav.

    Form fields: ``fPitchChange`` (pitch shift), ``sampleRate`` (sample rate
    the caller wants back) and ``sSpeakId`` (speaker id).  The audio is sent
    as the ``sample`` file upload.

    Returns a 400 response when the upload or a usable sample rate is missing,
    otherwise the converted audio as an attachment named ``temp.wav``.
    """
    request_form = request.form
    wave_file = request.files.get("sample", None)
    if wave_file is None:
        return "Missing 'sample' file upload", 400
    # Pitch-shift information
    f_pitch_change = float(request_form.get("fPitchChange", 0))
    # Sample rate required by the DAW
    daw_sample = int(float(request_form.get("sampleRate", 0)))
    if daw_sample <= 0:
        return "Form field 'sampleRate' must be a positive number", 400
    speaker_id = int(float(request_form.get("sSpeakId", 0)))
    # Wrap the wav received over HTTP in a file-like object
    input_wav_path = io.BytesIO(wave_file.read())

    # Model inference
    if raw_infer:
        # Svc.infer returns (audio, audio_length, n_frames); only the audio is
        # needed here, so the remaining values are discarded.
        out_audio, *_ = svc_model.infer(speaker_id, f_pitch_change, input_wav_path, cluster_infer_ratio=0,
                                        auto_predict_f0=False, noice_scale=0.4, f0_filter=False)
        tar_audio = torchaudio.functional.resample(out_audio, svc_model.target_sample, daw_sample)
    else:
        out_audio = svc.process(svc_model, speaker_id, f_pitch_change, input_wav_path, cluster_infer_ratio=0,
                                auto_predict_f0=False, noice_scale=0.4, f0_filter=False)
        tar_audio = torchaudio.functional.resample(torch.from_numpy(out_audio), svc_model.target_sample, daw_sample)
    # Return the audio
    out_wav_path = io.BytesIO()
    soundfile.write(out_wav_path, tar_audio.cpu().numpy(), daw_sample, format="wav")
    out_wav_path.seek(0)
    return send_file(out_wav_path, download_name="temp.wav", as_attachment=True)
46
+
47
+
48
if __name__ == '__main__':
    # True: each slice is synthesised directly; False: slices are cross-faded.
    # Setting the VST plugin's slice time to 0.3-0.5 s lowers latency; direct
    # slicing can pop at the joins, cross-fading gives slightly overlapping sound.
    # Pick whichever is acceptable, or raise the VST maximum slice time to 1 s.
    # True is used here: higher latency, somewhat more stable quality.
    raw_infer = True
    # Each model is paired with exactly one config.
    model_name = "logs/32k/G_174000-Copy1.pth"
    config_name = "configs/config.json"
    cluster_model_path = "logs/44k/kmeans_10000.pt"
    svc_model = Svc(
        net_g_path=model_name,
        config_path=config_name,
        cluster_model_path=cluster_model_path,
    )
    svc = RealTimeVC()
    # These settings match the VST plugin; changing them is not recommended.
    app.run(host="0.0.0.0", port=6842, debug=False, threaded=False)
flask_api_full_song.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+
3
+ import numpy as np
4
+ import soundfile
5
+ from flask import Flask, request, send_file
6
+
7
+ from inference import infer_tool, slicer
8
+
9
+ app = Flask(__name__)
10
+
11
+
12
@app.route("/wav2wav", methods=["POST"])
def wav2wav():
    """Convert the audio file named by ``audio_path`` and return the result.

    Form fields: ``audio_path`` (path of the input file on this machine),
    ``tran`` (pitch shift), ``spk`` (speaker id or name, depending on the
    config) and ``wav_format`` (format of the returned file, ``wav`` by
    default).

    Returns a 400 response when ``audio_path`` is missing, otherwise the
    converted audio as an attachment named ``temp.<wav_format>``.
    """
    request_form = request.form
    # NOTE(review): this is a path on the server supplied by the client; only
    # expose this endpoint to trusted callers.
    audio_path = request_form.get("audio_path", None)  # path of the wav file
    if not audio_path:
        return "Missing 'audio_path' form field", 400
    tran = int(float(request_form.get("tran", 0)))  # pitch shift
    spk = request_form.get("spk", 0)  # speaker (id or name, see your config)
    wav_format = request_form.get("wav_format", 'wav')  # format of the returned file
    infer_tool.format_wav(audio_path)
    chunks = slicer.cut(audio_path, db_thresh=-40)
    audio_data, audio_sr = slicer.chunks2audio(audio_path, chunks)

    audio = []
    for (slice_tag, data) in audio_data:
        print(f'#=====segment start, {round(len(data) / audio_sr, 3)}s======')

        # Length of this segment once expressed at the model's sample rate.
        length = int(np.ceil(len(data) / audio_sr * svc_model.target_sample))
        if slice_tag:
            print('jump empty segment')
            _audio = np.zeros(length)
        else:
            # Pad half a second of silence on both sides before inference ...
            pad_len = int(audio_sr * 0.5)
            data = np.concatenate([np.zeros([pad_len]), data, np.zeros([pad_len])])
            raw_path = io.BytesIO()
            soundfile.write(raw_path, data, audio_sr, format="wav")
            raw_path.seek(0)
            # Svc.infer returns (audio, audio_length, n_frames); only the
            # audio is needed, so the remaining values are discarded.
            out_audio, *_ = svc_model.infer(spk, tran, raw_path)
            svc_model.clear_empty()
            _audio = out_audio.cpu().numpy()
            # ... and trim the same half second from the model output.
            pad_len = int(svc_model.target_sample * 0.5)
            _audio = _audio[pad_len:-pad_len]

        audio.extend(list(infer_tool.pad_array(_audio, length)))
    out_wav_path = io.BytesIO()
    soundfile.write(out_wav_path, audio, svc_model.target_sample, format=wav_format)
    out_wav_path.seek(0)
    return send_file(out_wav_path, download_name=f"temp.{wav_format}", as_attachment=True)
49
+
50
+
51
if __name__ == '__main__':
    # Path of the model checkpoint and of its config file.
    model_name = "logs/44k/G_60000.pth"
    config_name = "configs/config.json"
    svc_model = infer_tool.Svc(net_g_path=model_name, config_path=config_name)
    app.run(host="0.0.0.0", port=1145, debug=False, threaded=False)
inference/__init__.py ADDED
File without changes
inference/infer_tool.py ADDED
@@ -0,0 +1,546 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import hashlib
3
+ import io
4
+ import json
5
+ import logging
6
+ import os
7
+ import pickle
8
+ import time
9
+ from pathlib import Path
10
+
11
+ import librosa
12
+ import numpy as np
13
+
14
+ # import onnxruntime
15
+ import soundfile
16
+ import torch
17
+ import torchaudio
18
+
19
+ import cluster
20
+ import utils
21
+ from diffusion.unit2mel import load_model_vocoder
22
+ from inference import slicer
23
+ from models import SynthesizerTrn
24
+
25
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
26
+
27
+
28
def read_temp(file_name, max_size=50 * 1024 * 1024, max_age=14 * 24 * 3600):
    """Load the JSON cache stored in *file_name*, creating or rebuilding it if needed.

    Args:
        file_name: path of the JSON file.
        max_size: file size in bytes above which old entries are dropped.
        max_age: age in seconds (by an entry's ``"time"`` value) above which
            an entry is dropped during that cleanup.

    Returns:
        ``{}`` when the file did not exist (a fresh file is written),
        the parsed dictionary when it could be read, or
        ``{"info": "temp_dict"}`` when reading or parsing failed.
    """
    if not os.path.exists(file_name):
        with open(file_name, "w") as f:
            f.write(json.dumps({"info": "temp_dict"}))
        return {}
    try:
        with open(file_name, "r") as f:
            data_dict = json.loads(f.read())
        if os.path.getsize(file_name) > max_size:
            f_name = file_name.replace("\\", "/").split("/")[-1]
            print(f"clean {f_name}")
            now = int(time.time())
            for wav_hash in list(data_dict.keys()):
                entry = data_dict[wav_hash]
                # Skip values without a timestamp, such as the "info" marker
                # string; indexing them used to raise and wipe the whole cache.
                if not isinstance(entry, dict) or "time" not in entry:
                    continue
                if now - int(entry["time"]) > max_age:
                    del data_dict[wav_hash]
    except Exception as e:
        # Deliberate best effort: an unreadable cache is replaced, not fatal.
        print(e)
        print(f"{file_name} error,auto rebuild file")
        data_dict = {"info": "temp_dict"}
    return data_dict
49
+
50
+
51
def write_temp(file_name, data):
    """Overwrite *file_name* with *data* serialised as JSON."""
    with open(file_name, "w") as out:
        payload = json.dumps(data)
        out.write(payload)
54
+
55
+
56
def timeit(func):
    """Decorate *func* so that every call prints how long it took.

    The wrapper forwards all arguments, returns the wrapped function's result
    unchanged, and keeps its ``__name__``/``__doc__`` via ``functools.wraps``.
    """
    import functools

    @functools.wraps(func)
    def run(*args, **kwargs):
        t = time.time()
        res = func(*args, **kwargs)
        print('executing \'%s\' costed %.3fs' % (func.__name__, time.time() - t))
        return res

    return run
64
+
65
+
66
def format_wav(audio_path):
    """Write a mono ``.wav`` copy next to *audio_path* unless its suffix is already ``.wav``."""
    source = Path(audio_path)
    if source.suffix != '.wav':
        # sr=None keeps the file's own sample rate.
        samples, sample_rate = librosa.load(audio_path, mono=True, sr=None)
        soundfile.write(source.with_suffix(".wav"), samples, sample_rate)
71
+
72
+
73
def get_end_file(dir_path, end):
    """Recursively list files under *dir_path* whose names end with *end*.

    Files and directories whose names start with ``.`` are ignored.  Returned
    paths use ``/`` as separator.
    """
    matches = []
    for root, dirs, files in os.walk(dir_path):
        # Prune in place so that os.walk does not descend into hidden folders.
        dirs[:] = [d for d in dirs if not d.startswith('.')]
        matches.extend(
            os.path.join(root, name).replace("\\", "/")
            for name in files
            if not name.startswith('.') and name.endswith(end)
        )
    return matches
82
+
83
+
84
def get_md5(content):
    """Return the hexadecimal MD5 digest of the bytes in *content*."""
    hasher = hashlib.new("md5")
    hasher.update(content)
    return hasher.hexdigest()
86
+
87
def fill_a_to_b(a, b):
    """Extend list *a* in place with copies of its first item until it is as long as *b*."""
    missing = len(b) - len(a)
    if missing > 0:
        a.extend([a[0]] * missing)
91
+
92
def mkdir(paths: list):
    """Create every directory in *paths* that does not exist yet (parents are not created)."""
    for directory in paths:
        if os.path.exists(directory):
            continue
        os.mkdir(directory)
96
+
97
def pad_array(arr, target_length):
    """Zero-pad *arr* on both sides up to *target_length*.

    An array that is already long enough is returned unchanged; otherwise the
    padding is split evenly, with the extra sample (if any) going to the right.
    """
    shortfall = target_length - arr.shape[0]
    if shortfall <= 0:
        return arr
    left = shortfall // 2
    right = shortfall - left
    return np.pad(arr, (left, right), 'constant', constant_values=(0, 0))
107
+
108
def split_list_by_n(list_collection, n, pre=0):
    """Yield consecutive slices of *n* items, each reaching back *pre* items.

    A slice whose look-back would start before index 0 starts at its own
    offset instead.
    """
    for offset in range(0, len(list_collection), n):
        begin = offset - pre
        if begin < 0:
            begin = offset
        yield list_collection[begin:offset + n]
111
+
112
+
113
class F0FilterException(Exception):
    """Raised when f0 filtering is enabled and the extracted f0 sums to zero."""
115
+
116
+ class Svc(object):
117
+ def __init__(self, net_g_path, config_path,
118
+ device=None,
119
+ cluster_model_path="logs/44k/kmeans_10000.pt",
120
+ nsf_hifigan_enhance = False,
121
+ diffusion_model_path="logs/44k/diffusion/model_0.pt",
122
+ diffusion_config_path="configs/diffusion.yaml",
123
+ shallow_diffusion = False,
124
+ only_diffusion = False,
125
+ spk_mix_enable = False,
126
+ feature_retrieval = False
127
+ ):
128
+ self.net_g_path = net_g_path
129
+ self.only_diffusion = only_diffusion
130
+ self.shallow_diffusion = shallow_diffusion
131
+ self.feature_retrieval = feature_retrieval
132
+ if device is None:
133
+ self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
134
+ else:
135
+ self.dev = torch.device(device)
136
+ self.net_g_ms = None
137
+ if not self.only_diffusion:
138
+ self.hps_ms = utils.get_hparams_from_file(config_path,True)
139
+ self.target_sample = self.hps_ms.data.sampling_rate
140
+ self.hop_size = self.hps_ms.data.hop_length
141
+ self.spk2id = self.hps_ms.spk
142
+ self.unit_interpolate_mode = self.hps_ms.data.unit_interpolate_mode if self.hps_ms.data.unit_interpolate_mode is not None else 'left'
143
+ self.vol_embedding = self.hps_ms.model.vol_embedding if self.hps_ms.model.vol_embedding is not None else False
144
+ self.speech_encoder = self.hps_ms.model.speech_encoder if self.hps_ms.model.speech_encoder is not None else 'vec768l12'
145
+
146
+ self.nsf_hifigan_enhance = nsf_hifigan_enhance
147
+ if self.shallow_diffusion or self.only_diffusion:
148
+ if os.path.exists(diffusion_model_path) and os.path.exists(diffusion_model_path):
149
+ self.diffusion_model,self.vocoder,self.diffusion_args = load_model_vocoder(diffusion_model_path,self.dev,config_path=diffusion_config_path)
150
+ if self.only_diffusion:
151
+ self.target_sample = self.diffusion_args.data.sampling_rate
152
+ self.hop_size = self.diffusion_args.data.block_size
153
+ self.spk2id = self.diffusion_args.spk
154
+ self.dtype = torch.float32
155
+ self.speech_encoder = self.diffusion_args.data.encoder
156
+ self.unit_interpolate_mode = self.diffusion_args.data.unit_interpolate_mode if self.diffusion_args.data.unit_interpolate_mode is not None else 'left'
157
+ if spk_mix_enable:
158
+ self.diffusion_model.init_spkmix(len(self.spk2id))
159
+ else:
160
+ print("No diffusion model or config found. Shallow diffusion mode will False")
161
+ self.shallow_diffusion = self.only_diffusion = False
162
+
163
+ # load hubert and model
164
+ if not self.only_diffusion:
165
+ self.load_model(spk_mix_enable)
166
+ self.hubert_model = utils.get_speech_encoder(self.speech_encoder,device=self.dev)
167
+ self.volume_extractor = utils.Volume_Extractor(self.hop_size)
168
+ else:
169
+ self.hubert_model = utils.get_speech_encoder(self.diffusion_args.data.encoder,device=self.dev)
170
+ self.volume_extractor = utils.Volume_Extractor(self.diffusion_args.data.block_size)
171
+
172
+ if os.path.exists(cluster_model_path):
173
+ if self.feature_retrieval:
174
+ with open(cluster_model_path,"rb") as f:
175
+ self.cluster_model = pickle.load(f)
176
+ self.big_npy = None
177
+ self.now_spk_id = -1
178
+ else:
179
+ self.cluster_model = cluster.get_cluster_model(cluster_model_path)
180
+ else:
181
+ self.feature_retrieval=False
182
+
183
+ if self.shallow_diffusion :
184
+ self.nsf_hifigan_enhance = False
185
+ if self.nsf_hifigan_enhance:
186
+ from modules.enhancer import Enhancer
187
+ self.enhancer = Enhancer('nsf-hifigan', 'pretrain/nsf_hifigan/model',device=self.dev)
188
+
189
+ def load_model(self, spk_mix_enable=False):
190
+ # get model configuration
191
+ self.net_g_ms = SynthesizerTrn(
192
+ self.hps_ms.data.filter_length // 2 + 1,
193
+ self.hps_ms.train.segment_size // self.hps_ms.data.hop_length,
194
+ **self.hps_ms.model)
195
+ _ = utils.load_checkpoint(self.net_g_path, self.net_g_ms, None)
196
+ self.dtype = list(self.net_g_ms.parameters())[0].dtype
197
+ if "half" in self.net_g_path and torch.cuda.is_available():
198
+ _ = self.net_g_ms.half().eval().to(self.dev)
199
+ else:
200
+ _ = self.net_g_ms.eval().to(self.dev)
201
+ if spk_mix_enable:
202
+ self.net_g_ms.EnableCharacterMix(len(self.spk2id), self.dev)
203
+
204
    def get_unit_f0(self, wav, tran, cluster_infer_ratio, speaker, f0_filter ,f0_predictor,cr_threshold=0.05):
        """Extract speech-encoder units, f0 and the uv signal from *wav*.

        Args:
            wav: audio samples as a numpy array (converted with
                ``torch.from_numpy``) — presumably at ``self.target_sample`` Hz.
            tran: pitch shift; f0 is multiplied by ``2 ** (tran / 12)``.
            cluster_infer_ratio: weight of the cluster / retrieved features
                when blending with the encoder output; ``0`` skips blending.
            speaker: speaker name or id used to pick the cluster or index.
            f0_filter: raise when the extracted f0 sums to zero.
            f0_predictor: name of the f0 predictor to use.
            cr_threshold: threshold forwarded to the f0 predictor.

        Returns:
            ``(c, f0, uv)`` — the unit features and the f0 / uv tensors, each
            with a leading dimension added by ``unsqueeze(0)``.

        Raises:
            F0FilterException: if *f0_filter* is set and f0 sums to zero.
            RuntimeError: if feature retrieval cannot resolve *speaker*.
        """

        # (Re)build the predictor only when none exists or a different one is requested.
        if not hasattr(self,"f0_predictor_object") or self.f0_predictor_object is None or f0_predictor != self.f0_predictor_object.name:
            self.f0_predictor_object = utils.get_f0_predictor(f0_predictor,hop_length=self.hop_size,sampling_rate=self.target_sample,device=self.dev,threshold=cr_threshold)
        f0, uv = self.f0_predictor_object.compute_f0_uv(wav)

        if f0_filter and sum(f0) == 0:
            raise F0FilterException("No voice detected")
        f0 = torch.FloatTensor(f0).to(self.dev)
        uv = torch.FloatTensor(uv).to(self.dev)

        # Apply the pitch shift, then add a leading dimension.
        f0 = f0 * 2 ** (tran / 12)
        f0 = f0.unsqueeze(0)
        uv = uv.unsqueeze(0)

        # The speech encoder is fed audio resampled to 16000 Hz; the resampler is cached.
        wav = torch.from_numpy(wav).to(self.dev)
        if not hasattr(self,"audio16k_resample_transform"):
            self.audio16k_resample_transform = torchaudio.transforms.Resample(self.target_sample, 16000).to(self.dev)
        wav16k = self.audio16k_resample_transform(wav[None,:])[0]

        # Encode and stretch the units to the number of f0 frames.
        c = self.hubert_model.encoder(wav16k)
        c = utils.repeat_expand_2d(c.squeeze(0), f0.shape[1],self.unit_interpolate_mode)

        if cluster_infer_ratio !=0:
            if self.feature_retrieval:
                # Resolve the speaker: by name first, then fall back to an integer id.
                speaker_id = self.spk2id.get(speaker)
                if not speaker_id and type(speaker) is int:
                    if len(self.spk2id.__dict__) >= speaker:
                        speaker_id = speaker
                if speaker_id is None:
                    raise RuntimeError("The name you entered is not in the speaker list!")
                feature_index = self.cluster_model[speaker_id]
                feat_np = np.ascontiguousarray(c.transpose(0,1).cpu().numpy())
                # Reconstruct the index vectors only when the speaker changes.
                if self.big_npy is None or self.now_spk_id != speaker_id:
                    self.big_npy = feature_index.reconstruct_n(0, feature_index.ntotal)
                    self.now_spk_id = speaker_id
                print("starting feature retrieval...")
                # Take the 8 nearest entries and weight them by the inverse
                # square of their score, normalised per frame.
                score, ix = feature_index.search(feat_np, k=8)
                weight = np.square(1 / score)
                weight /= weight.sum(axis=1, keepdims=True)
                npy = np.sum(self.big_npy[ix] * np.expand_dims(weight, axis=2), axis=1)
                # Blend retrieved and original features.
                c = cluster_infer_ratio * npy + (1 - cluster_infer_ratio) * feat_np
                c = torch.FloatTensor(c).to(self.dev).transpose(0,1)
                print("end feature retrieval...")
            else:
                # Blend the cluster-centre features with the original features.
                cluster_c = cluster.get_cluster_center_result(self.cluster_model, c.cpu().numpy().T, speaker).T
                cluster_c = torch.FloatTensor(cluster_c).to(self.dev)
                c = cluster_infer_ratio * cluster_c + (1 - cluster_infer_ratio) * c

        c = c.unsqueeze(0)
        return c, f0, uv
255
+
256
+ def infer(self, speaker, tran, raw_path,
257
+ cluster_infer_ratio=0,
258
+ auto_predict_f0=False,
259
+ noice_scale=0.4,
260
+ f0_filter=False,
261
+ f0_predictor='pm',
262
+ enhancer_adaptive_key = 0,
263
+ cr_threshold = 0.05,
264
+ k_step = 100,
265
+ frame = 0,
266
+ spk_mix = False,
267
+ second_encoding = False,
268
+ loudness_envelope_adjustment = 1
269
+ ):
270
+ torchaudio.set_audio_backend("soundfile")
271
+ wav, sr = torchaudio.load(raw_path)
272
+ if not hasattr(self,"audio_resample_transform") or self.audio16k_resample_transform.orig_freq != sr:
273
+ self.audio_resample_transform = torchaudio.transforms.Resample(sr,self.target_sample)
274
+ wav = self.audio_resample_transform(wav).numpy()[0]
275
+ if spk_mix:
276
+ c, f0, uv = self.get_unit_f0(wav, tran, 0, None, f0_filter,f0_predictor,cr_threshold=cr_threshold)
277
+ n_frames = f0.size(1)
278
+ sid = speaker[:, frame:frame+n_frames].transpose(0,1)
279
+ else:
280
+ speaker_id = self.spk2id.get(speaker)
281
+ if not speaker_id and type(speaker) is int:
282
+ if len(self.spk2id.__dict__) >= speaker:
283
+ speaker_id = speaker
284
+ if speaker_id is None:
285
+ raise RuntimeError("The name you entered is not in the speaker list!")
286
+ sid = torch.LongTensor([int(speaker_id)]).to(self.dev).unsqueeze(0)
287
+ c, f0, uv = self.get_unit_f0(wav, tran, cluster_infer_ratio, speaker, f0_filter,f0_predictor,cr_threshold=cr_threshold)
288
+ n_frames = f0.size(1)
289
+ c = c.to(self.dtype)
290
+ f0 = f0.to(self.dtype)
291
+ uv = uv.to(self.dtype)
292
+ with torch.no_grad():
293
+ start = time.time()
294
+ vol = None
295
+ if not self.only_diffusion:
296
+ vol = self.volume_extractor.extract(torch.FloatTensor(wav).to(self.dev)[None,:])[None,:].to(self.dev) if self.vol_embedding else None
297
+ audio,f0 = self.net_g_ms.infer(c, f0=f0, g=sid, uv=uv, predict_f0=auto_predict_f0, noice_scale=noice_scale,vol=vol)
298
+ audio = audio[0,0].data.float()
299
+ audio_mel = self.vocoder.extract(audio[None,:],self.target_sample) if self.shallow_diffusion else None
300
+ else:
301
+ audio = torch.FloatTensor(wav).to(self.dev)
302
+ audio_mel = None
303
+ if self.dtype != torch.float32:
304
+ c = c.to(torch.float32)
305
+ f0 = f0.to(torch.float32)
306
+ uv = uv.to(torch.float32)
307
+ if self.only_diffusion or self.shallow_diffusion:
308
+ vol = self.volume_extractor.extract(audio[None,:])[None,:,None].to(self.dev) if vol is None else vol[:,:,None]
309
+ if self.shallow_diffusion and second_encoding:
310
+ if not hasattr(self,"audio16k_resample_transform"):
311
+ self.audio16k_resample_transform = torchaudio.transforms.Resample(self.target_sample, 16000).to(self.dev)
312
+ audio16k = self.audio16k_resample_transform(audio[None,:])[0]
313
+ c = self.hubert_model.encoder(audio16k)
314
+ c = utils.repeat_expand_2d(c.squeeze(0), f0.shape[1],self.unit_interpolate_mode)
315
+ f0 = f0[:,:,None]
316
+ c = c.transpose(-1,-2)
317
+ audio_mel = self.diffusion_model(
318
+ c,
319
+ f0,
320
+ vol,
321
+ spk_id = sid,
322
+ spk_mix_dict = None,
323
+ gt_spec=audio_mel,
324
+ infer=True,
325
+ infer_speedup=self.diffusion_args.infer.speedup,
326
+ method=self.diffusion_args.infer.method,
327
+ k_step=k_step)
328
+ audio = self.vocoder.infer(audio_mel, f0).squeeze()
329
+ if self.nsf_hifigan_enhance:
330
+ audio, _ = self.enhancer.enhance(
331
+ audio[None,:],
332
+ self.target_sample,
333
+ f0[:,:,None],
334
+ self.hps_ms.data.hop_length,
335
+ adaptive_key = enhancer_adaptive_key)
336
+ if loudness_envelope_adjustment != 1:
337
+ audio = utils.change_rms(wav,self.target_sample,audio,self.target_sample,loudness_envelope_adjustment)
338
+ use_time = time.time() - start
339
+ print("vits use time:{}".format(use_time))
340
+ return audio, audio.shape[-1], n_frames
341
+
342
    def clear_empty(self):
        """Ask PyTorch to release its cached CUDA memory."""
        # clean up vram
        torch.cuda.empty_cache()
345
+
346
def unload_model(self):
    """Move the loaded networks to CPU and drop this object's references to them.

    Both the generator (``net_g_ms``) and the optional enhancer are handled
    only when the attribute exists, so the call is safe when the generator
    was never loaded and safe to repeat. A garbage collection pass is run
    afterwards.
    """
    # The generator may be missing (never loaded, or already unloaded);
    # guard it the same way the enhancer is guarded.
    if hasattr(self, "net_g_ms"):
        self.net_g_ms = self.net_g_ms.to("cpu")
        del self.net_g_ms
    if hasattr(self, "enhancer"):
        self.enhancer.enhancer = self.enhancer.enhancer.to("cpu")
        del self.enhancer.enhancer
        del self.enhancer
    gc.collect()
355
+
356
def slice_inference(self,
                    raw_audio_path,
                    spk,
                    tran,
                    slice_db,
                    cluster_infer_ratio,
                    auto_predict_f0,
                    noice_scale,
                    pad_seconds=0.5,
                    clip_seconds=0,
                    lg_num=0,
                    lgr_num=0.75,
                    f0_predictor='pm',
                    enhancer_adaptive_key=0,
                    cr_threshold=0.05,
                    k_step=100,
                    use_spk_mix=False,
                    second_encoding=False,
                    loudness_envelope_adjustment=1
                    ):
    """Convert a whole audio file by slicing it, inferring each slice and stitching the results.

    The file is cut on silence with ``slicer.cut``; segments tagged as silent
    are replaced by zeros, all other segments are (optionally) split into
    fixed-length clips, padded with silence, passed through ``self.infer``
    and concatenated, with an optional linear cross-fade between clips.

    Args:
        raw_audio_path: Path of the input audio; its suffix is replaced by ``.wav``.
        spk: Target speaker. When ``use_spk_mix`` is active it must be
            indexable by ``0..len(spk)-1``, each entry being a sequence of
            ``[begin, end, start_value, end_value]`` items where ``begin`` and
            ``end`` are fractions of the total frame count.
        tran: Pitch shift forwarded to ``self.infer``.
        slice_db: Silence threshold forwarded to ``slicer.cut``.
        cluster_infer_ratio, auto_predict_f0, noice_scale, f0_predictor,
        enhancer_adaptive_key, cr_threshold, k_step, second_encoding,
        loudness_envelope_adjustment: Forwarded unchanged to ``self.infer``.
        pad_seconds: Seconds of silence added before and after every clip and
            trimmed again from the converted output.
        clip_seconds: Forced clip length in seconds; ``0`` disables forced clipping.
        lg_num: Cross-fade length in seconds between forced clips; ``0`` disables it.
        lgr_num: Proportion of the cross-fade length that is actually blended.
        use_spk_mix: Whether to build a per-frame speaker mixing tensor.

    Returns:
        A 1-D ``numpy`` array holding the converted audio.

    Raises:
        RuntimeError: If the speaker-mix description is inconsistent.
    """
    if use_spk_mix:
        if len(self.spk2id) == 1:
            # dict views are not subscriptable on Python 3; take the first key explicitly.
            spk = next(iter(self.spk2id))
            use_spk_mix = False
    wav_path = Path(raw_audio_path).with_suffix('.wav')
    chunks = slicer.cut(wav_path, db_thresh=slice_db)
    audio_data, audio_sr = slicer.chunks2audio(wav_path, chunks)
    per_size = int(clip_seconds * audio_sr)
    lg_size = int(lg_num * audio_sr)
    lg_size_r = int(lg_size * lgr_num)
    lg_size_c_l = (lg_size - lg_size_r) // 2
    lg_size_c_r = lg_size - lg_size_r - lg_size_c_l
    lg = np.linspace(0, 1, lg_size_r) if lg_size != 0 else 0

    if use_spk_mix:
        assert len(self.spk2id) == len(spk)
        # First pass: count how many frames the whole conversion will produce.
        audio_length = 0
        input_pad_len = int(audio_sr * pad_seconds)
        for (slice_tag, data) in audio_data:
            aud_length = int(np.ceil(len(data) / audio_sr * self.target_sample))
            if slice_tag:
                audio_length += aud_length // self.hop_size
                continue
            if per_size != 0:
                datas = split_list_by_n(data, per_size, lg_size)
            else:
                datas = [data]
            for dat in datas:
                per_length = int(np.ceil(len(dat) / audio_sr * self.target_sample))
                a_length = per_length + 2 * input_pad_len
                audio_length += a_length // self.hop_size
        audio_length += len(audio_data)
        spk_mix_tensor = torch.zeros(size=(len(spk), audio_length)).to(self.dev)
        for i in range(len(spk)):
            last_end = None
            for mix in spk[i]:
                if mix[3] < 0. or mix[2] < 0.:
                    raise RuntimeError("mix value must higer Than zero!")
                begin = int(audio_length * mix[0])
                end = int(audio_length * mix[1])
                length = end - begin
                if length <= 0:
                    raise RuntimeError("begin Must lower Than end!")
                step = (mix[3] - mix[2]) / length
                if last_end is not None:
                    if last_end != begin:
                        raise RuntimeError("[i]EndTime Must Equal [i+1]BeginTime!")
                last_end = end
                if step == 0.:
                    spk_mix_data = torch.zeros(length).to(self.dev) + mix[2]
                else:
                    spk_mix_data = torch.arange(mix[2], mix[3], step).to(self.dev)
                if len(spk_mix_data) < length:
                    num_pad = length - len(spk_mix_data)
                    spk_mix_data = torch.nn.functional.pad(spk_mix_data, [0, num_pad], mode="reflect").to(self.dev)
                spk_mix_tensor[i][begin:end] = spk_mix_data[:length]

        spk_mix_ten = torch.sum(spk_mix_tensor, dim=0).unsqueeze(0).to(self.dev)
        # Frames where no speaker has any weight get a uniform mix; done with a
        # mask instead of a Python loop over every frame.
        empty_frames = spk_mix_ten[0] == 0.0
        spk_mix_ten[0, empty_frames] = 1.0
        spk_mix_tensor[:, empty_frames] = 1.0 / len(spk)
        spk_mix_tensor = spk_mix_tensor / spk_mix_ten
        # Compare the absolute deviation so sums below 1 are rejected as well.
        if not ((torch.sum(spk_mix_tensor, dim=0) - 1.).abs() < 0.0001).all():
            raise RuntimeError("sum(spk_mix_tensor) not equal 1")
        spk = spk_mix_tensor

    global_frame = 0
    audio = []
    for (slice_tag, data) in audio_data:
        print(f'#=====segment start, {round(len(data) / audio_sr, 3)}s======')
        length = int(np.ceil(len(data) / audio_sr * self.target_sample))
        if slice_tag:
            # Silent segment: emit zeros of the matching length instead of inferring.
            print('jump empty segment')
            _audio = np.zeros(length)
            audio.extend(list(pad_array(_audio, length)))
            global_frame += length // self.hop_size
            continue
        if per_size != 0:
            datas = split_list_by_n(data, per_size, lg_size)
        else:
            datas = [data]
        for k, dat in enumerate(datas):
            per_length = int(np.ceil(len(dat) / audio_sr * self.target_sample)) if clip_seconds != 0 else length
            if clip_seconds != 0:
                print(f'###=====segment clip start, {round(len(dat) / audio_sr, 3)}s======')
            # Pad the clip with silence on both sides before inference.
            pad_len = int(audio_sr * pad_seconds)
            dat = np.concatenate([np.zeros([pad_len]), dat, np.zeros([pad_len])])
            raw_path = io.BytesIO()
            soundfile.write(raw_path, dat, audio_sr, format="wav")
            raw_path.seek(0)
            out_audio, out_sr, out_frame = self.infer(spk, tran, raw_path,
                                                      cluster_infer_ratio=cluster_infer_ratio,
                                                      auto_predict_f0=auto_predict_f0,
                                                      noice_scale=noice_scale,
                                                      f0_predictor=f0_predictor,
                                                      enhancer_adaptive_key=enhancer_adaptive_key,
                                                      cr_threshold=cr_threshold,
                                                      k_step=k_step,
                                                      frame=global_frame,
                                                      spk_mix=use_spk_mix,
                                                      second_encoding=second_encoding,
                                                      loudness_envelope_adjustment=loudness_envelope_adjustment
                                                      )
            global_frame += out_frame
            _audio = out_audio.cpu().numpy()
            # Remove the padding again, now measured at the output sample rate.
            # An explicit stop index is used because ``x[0:-0]`` would be empty
            # when pad_seconds is 0.
            pad_len = int(self.target_sample * pad_seconds)
            _audio = _audio[pad_len:len(_audio) - pad_len]
            _audio = pad_array(_audio, per_length)
            if lg_size != 0 and k != 0:
                # Linear cross-fade between the tail of the audio produced so
                # far and the head of the new clip.
                lg1 = audio[-(lg_size_r + lg_size_c_r):-lg_size_c_r] if lgr_num != 1 else audio[-lg_size:]
                lg2 = _audio[lg_size_c_l:lg_size_c_l + lg_size_r] if lgr_num != 1 else _audio[0:lg_size]
                lg_pre = lg1 * (1 - lg) + lg2 * lg
                audio = audio[0:-(lg_size_r + lg_size_c_r)] if lgr_num != 1 else audio[0:-lg_size]
                audio.extend(lg_pre)
                _audio = _audio[lg_size_c_l + lg_size_r:] if lgr_num != 1 else _audio[lg_size:]
            audio.extend(list(_audio))
    return np.array(audio)
497
+
498
class RealTimeVC:
    """Chunk-by-chunk voice conversion that cross-fades consecutive outputs."""

    def __init__(self):
        self.last_chunk = None  # tail (pre_len samples) kept from the previous call
        self.last_o = None  # full converted output of the previous call
        self.chunk_len = 16000  # chunk length
        self.pre_len = 3840  # cross fade length, multiples of 640

    def process(self, svc_model, speaker_id, f_pitch_change, input_wav_path,
                cluster_infer_ratio=0,
                auto_predict_f0=False,
                noice_scale=0.4,
                f0_filter=False):
        """Convert one chunk of audio and return the part to be played back.

        Args:
            svc_model: Model object exposing ``infer``; the first item of its
                return value must be the converted audio tensor.
            speaker_id: Target speaker forwarded to ``svc_model.infer``.
            f_pitch_change: Pitch shift forwarded to ``svc_model.infer``.
            input_wav_path: Seekable file-like object holding the chunk as wav
                data (it is read with ``torchaudio.load`` and rewound with ``seek``).
            cluster_infer_ratio, auto_predict_f0, noice_scale, f0_filter:
                Forwarded unchanged to ``svc_model.infer``.

        Returns:
            A 1-D numpy array: the last ``chunk_len`` samples on the first
            call, afterwards a ``chunk_len`` window of the cross-faded signal.
        """
        import maad
        audio, sr = torchaudio.load(input_wav_path)
        audio = audio.cpu().numpy()[0]
        if self.last_chunk is None:
            input_wav_path.seek(0)

            # infer() returns more than two values (audio first); keep only the
            # audio instead of unpacking a fixed number of items.
            audio, *_ = svc_model.infer(speaker_id, f_pitch_change, input_wav_path,
                                        cluster_infer_ratio=cluster_infer_ratio,
                                        auto_predict_f0=auto_predict_f0,
                                        noice_scale=noice_scale,
                                        f0_filter=f0_filter)

            audio = audio.cpu().numpy()
            self.last_chunk = audio[-self.pre_len:]
            self.last_o = audio
            return audio[-self.chunk_len:]

        # Prepend the stored tail so the model sees context across the chunk boundary.
        audio = np.concatenate([self.last_chunk, audio])
        temp_wav = io.BytesIO()
        soundfile.write(temp_wav, audio, sr, format="wav")
        temp_wav.seek(0)

        audio, *_ = svc_model.infer(speaker_id, f_pitch_change, temp_wav,
                                    cluster_infer_ratio=cluster_infer_ratio,
                                    auto_predict_f0=auto_predict_f0,
                                    noice_scale=noice_scale,
                                    f0_filter=f0_filter)

        audio = audio.cpu().numpy()
        ret = maad.util.crossfade(self.last_o, audio, self.pre_len)
        self.last_chunk = audio[-self.pre_len:]
        self.last_o = audio
        return ret[self.chunk_len:2 * self.chunk_len]
546
+
inference/infer_tool_grad.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import logging
3
+ import os
4
+
5
+ import librosa
6
+ import numpy as np
7
+ import parselmouth
8
+ import soundfile
9
+ import torch
10
+ import torchaudio
11
+
12
+ import utils
13
+ from inference import slicer
14
+ from models import SynthesizerTrn
15
+
16
+ logging.getLogger('numba').setLevel(logging.WARNING)
17
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
18
+
19
def resize2d_f0(x, target_len):
    """Linearly resample the f0 sequence ``x`` to ``target_len`` points.

    Values below 0.001 are turned into NaN before interpolation, and every
    NaN in the interpolated result is mapped back to 0.
    """
    contour = np.array(x)
    contour[contour < 0.001] = np.nan
    n_src = len(contour)
    sample_points = np.arange(0, n_src * target_len, n_src) / target_len
    resampled = np.interp(sample_points, np.arange(0, n_src), contour)
    return np.nan_to_num(resampled)
26
+
27
def get_f0(x, p_len, f0_up_key=0):
    """Extract an f0 contour from ``x`` with Praat and quantise it on a mel scale.

    Args:
        x: Waveform samples; passed to ``parselmouth.Sound`` with a sampling
            rate of 16000.
        p_len: Number of frames the returned contour is zero-padded to.
        f0_up_key: Pitch shift in semitones (``f0`` is scaled by ``2 ** (key / 12)``).

    Returns:
        ``(f0_coarse, f0)``: the contour quantised to integers in ``[1, 255]``
        and the shifted contour itself.
    """
    time_step = 160 / 16000 * 1000
    f0_min = 50
    f0_max = 1100
    f0_mel_min = 1127 * np.log(1 + f0_min / 700)
    f0_mel_max = 1127 * np.log(1 + f0_max / 700)

    f0 = parselmouth.Sound(x, 16000).to_pitch_ac(
        time_step=time_step / 1000, voicing_threshold=0.6,
        pitch_floor=f0_min, pitch_ceiling=f0_max).selected_array['frequency']

    # Zero-pad both ends so the contour has exactly p_len frames.
    pad_size = (p_len - len(f0) + 1) // 2
    if pad_size > 0 or p_len - len(f0) - pad_size > 0:
        f0 = np.pad(f0, [[pad_size, p_len - len(f0) - pad_size]], mode='constant')

    f0 *= pow(2, f0_up_key / 12)
    f0_mel = 1127 * np.log(1 + f0 / 700)
    f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * 254 / (f0_mel_max - f0_mel_min) + 1
    f0_mel[f0_mel <= 1] = 1
    f0_mel[f0_mel > 255] = 255
    # ``np.int`` was only an alias of the builtin and no longer exists in
    # current NumPy; the builtin ``int`` gives the same dtype.
    f0_coarse = np.rint(f0_mel).astype(int)
    return f0_coarse, f0
50
+
51
def clean_pitch(input_pitch):
    """Set every entry to 1 when more than 90% of the entries already equal 1.

    The array is modified in place and also returned.
    """
    ones_ratio = np.sum(input_pitch == 1) / len(input_pitch)
    if ones_ratio > 0.9:
        input_pitch[input_pitch != 1] = 1
    return input_pitch
56
+
57
+
58
def plt_pitch(input_pitch):
    """Return a float copy of ``input_pitch`` in which every 1 is replaced by NaN."""
    as_float = input_pitch.astype(float)
    as_float[as_float == 1] = np.nan
    return as_float
62
+
63
+
64
def f0_to_pitch(ff):
    """Convert a frequency in Hz to a pitch number where 440 Hz maps to 69 and an octave spans 12."""
    semitones_from_a4 = 12 * np.log2(ff / 440)
    return 69 + semitones_from_a4
67
+
68
+
69
def fill_a_to_b(a, b):
    """Grow list ``a`` in place to the length of ``b`` by repeating its first element.

    Nothing happens when ``a`` is already at least as long as ``b``.
    """
    shortfall = len(b) - len(a)
    if shortfall > 0:
        a.extend([a[0]] * shortfall)
73
+
74
+
75
def mkdir(paths: list):
    """Create every directory in ``paths`` that does not exist yet (single level, via ``os.mkdir``)."""
    for target in paths:
        if os.path.exists(target):
            continue
        os.mkdir(target)
79
+
80
+
81
class VitsSvc(object):
    """Wrapper bundling a HuBERT content encoder and a SynthesizerTrn model for inference."""

    def __init__(self):
        """Pick a default device and load the HuBERT model; no checkpoint is loaded yet."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.SVCVITS = None
        self.hps = None
        self.speakers = None
        self.hubert_soft = utils.get_hubert_model()

    def set_device(self, device):
        """Move the HuBERT model and, if loaded, the synthesizer to ``device``."""
        self.device = torch.device(device)
        self.hubert_soft.to(self.device)
        if self.SVCVITS is not None:
            self.SVCVITS.to(self.device)

    def loadCheckpoint(self, path):
        """Load ``checkpoints/<path>/config.json`` and ``checkpoints/<path>/model.pth``."""
        self.hps = utils.get_hparams_from_file(f"checkpoints/{path}/config.json")
        self.SVCVITS = SynthesizerTrn(
            self.hps.data.filter_length // 2 + 1,
            self.hps.train.segment_size // self.hps.data.hop_length,
            **self.hps.model)
        _ = utils.load_checkpoint(f"checkpoints/{path}/model.pth", self.SVCVITS, None)
        _ = self.SVCVITS.eval().to(self.device)
        self.speakers = self.hps.spk

    def get_units(self, source, sr):
        """Return the HuBERT units for ``source``; ``sr`` is accepted but not used."""
        source = source.unsqueeze(0).to(self.device)
        with torch.inference_mode():
            units = self.hubert_soft.units(source)
            return units

    def get_unit_pitch(self, in_path, tran):
        """Load audio, resample it to 16 kHz and return ``(units, f0)``.

        ``tran`` is the pitch shift in semitones passed on to ``get_f0``.
        """
        source, sr = torchaudio.load(in_path)
        source = torchaudio.functional.resample(source, sr, 16000)
        # Average over dim 0 so a single row remains regardless of the channel count.
        if len(source.shape) == 2 and source.shape[1] >= 2:
            source = torch.mean(source, dim=0).unsqueeze(0)
        soft = self.get_units(source, sr).squeeze(0).cpu().numpy()
        # Units are repeated twice along time in infer(), hence twice as many f0 frames.
        f0_coarse, f0 = get_f0(source.cpu().numpy()[0], soft.shape[0] * 2, tran)
        return soft, f0

    def infer(self, speaker_id, tran, raw_path):
        """Convert the audio in ``raw_path`` to speaker ``speaker_id``.

        Returns:
            ``(audio, length)`` where ``audio`` is a 1-D float tensor and
            ``length`` its number of samples.
        """
        speaker_id = self.speakers[speaker_id]
        sid = torch.LongTensor([int(speaker_id)]).to(self.device).unsqueeze(0)
        soft, pitch = self.get_unit_pitch(raw_path, tran)
        f0 = torch.FloatTensor(clean_pitch(pitch)).unsqueeze(0).to(self.device)
        stn_tst = torch.FloatTensor(soft)
        with torch.no_grad():
            x_tst = stn_tst.unsqueeze(0).to(self.device)
            x_tst = torch.repeat_interleave(x_tst, repeats=2, dim=1).transpose(1, 2)
            result = self.SVCVITS.infer(x_tst, f0=f0, g=sid)
            # NOTE(review): the synthesizer may return either the audio tensor
            # or a tuple whose first item is the audio - confirm against
            # models.SynthesizerTrn. Unpacking the indexed waveform itself into
            # two names, as done before, can never succeed.
            if isinstance(result, (tuple, list)):
                result = result[0]
            audio = result[0, 0].data.float()
        return audio, audio.shape[-1]

    def inference(self, srcaudio, chara, tran, slice_db):
        """Convert a ``(sampling_rate, samples)`` pair and return ``(sampling_rate, int16 samples)``.

        The input is normalised, mixed down to mono, resampled to 16 kHz and
        written to ``tmpwav.wav`` in the working directory, then sliced on
        silence; silent slices become zeros, the others go through ``infer``.
        """
        sampling_rate, audio = srcaudio
        if np.issubdtype(audio.dtype, np.integer):
            audio = (audio / np.iinfo(audio.dtype).max).astype(np.float32)
        else:
            # np.iinfo rejects float dtypes; float input is taken as already normalised.
            audio = audio.astype(np.float32)
        if len(audio.shape) > 1:
            audio = librosa.to_mono(audio.transpose(1, 0))
        if sampling_rate != 16000:
            audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=16000)
        soundfile.write("tmpwav.wav", audio, 16000, format="wav")
        chunks = slicer.cut("tmpwav.wav", db_thresh=slice_db)
        audio_data, audio_sr = slicer.chunks2audio("tmpwav.wav", chunks)
        audio = []
        for (slice_tag, data) in audio_data:
            if slice_tag:
                length = int(np.ceil(len(data) / audio_sr * self.hps.data.sampling_rate))
                _audio = np.zeros(length)
            else:
                raw_path = io.BytesIO()
                soundfile.write(raw_path, data, audio_sr, format="wav")
                raw_path.seek(0)
                out_audio, out_sr = self.infer(chara, tran, raw_path)
                _audio = out_audio.cpu().numpy()
            audio.extend(list(_audio))
        # Clip before the cast: a full-scale sample times 32768 does not fit in int16.
        audio = np.clip(np.array(audio) * 32768.0, -32768, 32767).astype('int16')
        return (self.hps.data.sampling_rate, audio)
inference/slicer.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import librosa
2
+ import torch
3
+ import torchaudio
4
+
5
+
6
class Slicer:
    """Locate silent stretches in a waveform from its frame-wise RMS energy."""

    def __init__(self,
                 sr: int,
                 threshold: float = -40.,
                 min_length: int = 5000,
                 min_interval: int = 300,
                 hop_size: int = 20,
                 max_sil_kept: int = 5000):
        """Store the slicing parameters converted from dB / milliseconds.

        Args:
            sr: Sampling rate of the audio that will be sliced.
            threshold: Silence threshold in dB (converted to a linear amplitude).
            min_length: Minimum clip length in milliseconds.
            min_interval: Minimum silence length in milliseconds.
            hop_size: RMS hop length in milliseconds.
            max_sil_kept: Maximum silence length in milliseconds kept around a cut.

        Raises:
            ValueError: If ``min_length >= min_interval >= hop_size`` or
                ``max_sil_kept >= hop_size`` does not hold.
        """
        if not min_length >= min_interval >= hop_size:
            raise ValueError('The following condition must be satisfied: min_length >= min_interval >= hop_size')
        if not max_sil_kept >= hop_size:
            raise ValueError('The following condition must be satisfied: max_sil_kept >= hop_size')
        min_interval = sr * min_interval / 1000
        self.threshold = 10 ** (threshold / 20.)
        self.hop_size = round(sr * hop_size / 1000)
        self.win_size = min(round(min_interval), 4 * self.hop_size)
        # The three values below are expressed in RMS frames, not in samples.
        self.min_length = round(sr * min_length / 1000 / self.hop_size)
        self.min_interval = round(min_interval / self.hop_size)
        self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size)

    def _apply_slice(self, waveform, begin, end):
        """Return the samples of ``waveform`` between frame ``begin`` and frame ``end``."""
        if len(waveform.shape) > 1:
            return waveform[:, begin * self.hop_size: min(waveform.shape[1], end * self.hop_size)]
        else:
            return waveform[begin * self.hop_size: min(waveform.shape[0], end * self.hop_size)]

    # @timeit
    def slice(self, waveform):
        """Split ``waveform`` into alternating voiced and silent chunks.

        Returns:
            A dict keyed ``"0"``, ``"1"``, ... in time order. Each value is
            ``{"slice": bool, "split_time": "start,end"}`` where ``slice`` is
            True for silent chunks and the boundaries are sample indices.
        """
        if len(waveform.shape) > 1:
            samples = librosa.to_mono(waveform)
        else:
            samples = waveform
        # Sample count of the mono signal. len(waveform) would give the first
        # axis instead, which is not the sample count for 2-D input.
        total_samples = samples.shape[0]
        if samples.shape[0] <= self.min_length:
            return {"0": {"slice": False, "split_time": f"0,{total_samples}"}}
        rms_list = librosa.feature.rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0)
        sil_tags = []
        silence_start = None
        clip_start = 0
        for i, rms in enumerate(rms_list):
            # Keep looping while frame is silent.
            if rms < self.threshold:
                # Record start of silent frames.
                if silence_start is None:
                    silence_start = i
                continue
            # Keep looping while frame is not silent and silence start has not been recorded.
            if silence_start is None:
                continue
            # Clear recorded silence start if interval is not enough or clip is too short
            is_leading_silence = silence_start == 0 and i > self.max_sil_kept
            need_slice_middle = i - silence_start >= self.min_interval and i - clip_start >= self.min_length
            if not is_leading_silence and not need_slice_middle:
                silence_start = None
                continue
            # Need slicing. Record the range of silent frames to be removed.
            if i - silence_start <= self.max_sil_kept:
                pos = rms_list[silence_start: i + 1].argmin() + silence_start
                if silence_start == 0:
                    sil_tags.append((0, pos))
                else:
                    sil_tags.append((pos, pos))
                clip_start = pos
            elif i - silence_start <= self.max_sil_kept * 2:
                pos = rms_list[i - self.max_sil_kept: silence_start + self.max_sil_kept + 1].argmin()
                pos += i - self.max_sil_kept
                pos_l = rms_list[silence_start: silence_start + self.max_sil_kept + 1].argmin() + silence_start
                pos_r = rms_list[i - self.max_sil_kept: i + 1].argmin() + i - self.max_sil_kept
                if silence_start == 0:
                    sil_tags.append((0, pos_r))
                    clip_start = pos_r
                else:
                    sil_tags.append((min(pos_l, pos), max(pos_r, pos)))
                    clip_start = max(pos_r, pos)
            else:
                pos_l = rms_list[silence_start: silence_start + self.max_sil_kept + 1].argmin() + silence_start
                pos_r = rms_list[i - self.max_sil_kept: i + 1].argmin() + i - self.max_sil_kept
                if silence_start == 0:
                    sil_tags.append((0, pos_r))
                else:
                    sil_tags.append((pos_l, pos_r))
                clip_start = pos_r
            silence_start = None
        # Deal with trailing silence.
        total_frames = rms_list.shape[0]
        if silence_start is not None and total_frames - silence_start >= self.min_interval:
            silence_end = min(total_frames, silence_start + self.max_sil_kept)
            pos = rms_list[silence_start: silence_end + 1].argmin() + silence_start
            sil_tags.append((pos, total_frames + 1))
        # Apply and return slices.
        if len(sil_tags) == 0:
            return {"0": {"slice": False, "split_time": f"0,{total_samples}"}}

        chunks = []
        # The first silence does not start at the beginning: add the leading voiced chunk.
        if sil_tags[0][0]:
            chunks.append(
                {"slice": False, "split_time": f"0,{min(total_samples, sil_tags[0][0] * self.hop_size)}"})
        for i in range(0, len(sil_tags)):
            # Voiced chunk between the previous silence and this one (skipped for the first silence).
            if i:
                chunks.append({"slice": False,
                               "split_time": f"{sil_tags[i - 1][1] * self.hop_size},{min(total_samples, sil_tags[i][0] * self.hop_size)}"})
            # Every silent stretch becomes a chunk flagged with slice=True.
            chunks.append({"slice": True,
                           "split_time": f"{sil_tags[i][0] * self.hop_size},{min(total_samples, sil_tags[i][1] * self.hop_size)}"})
        # The last silence does not reach the end: add the trailing voiced chunk.
        if sil_tags[-1][1] * self.hop_size < total_samples:
            chunks.append({"slice": False, "split_time": f"{sil_tags[-1][1] * self.hop_size},{total_samples}"})
        return {str(i): chunk for i, chunk in enumerate(chunks)}
118
+
119
+
120
def cut(audio_path, db_thresh=-30, min_len=5000):
    """Load ``audio_path`` at its native sampling rate and return the chunk dict from ``Slicer.slice``.

    ``db_thresh`` is used as the slicer's silence threshold and ``min_len``
    as its minimum clip length.
    """
    waveform, sample_rate = librosa.load(audio_path, sr=None)
    splitter = Slicer(sr=sample_rate, threshold=db_thresh, min_length=min_len)
    return splitter.slice(waveform)
129
+
130
+
131
def chunks2audio(audio_path, chunks):
    """Cut the audio file into the sample ranges described by ``chunks``.

    Returns:
        ``(segments, sr)`` where ``segments`` is a list of
        ``(slice_flag, samples)`` tuples, one per chunk whose start and end
        differ, and ``sr`` is the sampling rate reported by ``torchaudio.load``.
    """
    chunk_map = dict(chunks)
    waveform, sr = torchaudio.load(audio_path)
    if len(waveform.shape) == 2 and waveform.shape[1] >= 2:
        waveform = torch.mean(waveform, dim=0).unsqueeze(0)
    samples = waveform.cpu().numpy()[0]
    segments = []
    for info in chunk_map.values():
        bounds = info["split_time"].split(",")
        # Zero-length chunks carry no audio and are dropped.
        if bounds[0] == bounds[1]:
            continue
        segments.append((info["slice"], samples[int(bounds[0]):int(bounds[1])]))
    return segments, sr
inference_main.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+ import soundfile
4
+
5
+ from inference import infer_tool
6
+ from inference.infer_tool import Svc
7
+ from spkmix import spk_mix_map
8
+
9
+ logging.getLogger('numba').setLevel(logging.WARNING)
10
+ chunks_dict = infer_tool.read_temp("inference/chunks_temp.json")
11
+
12
+
13
+
14
def main():
    """Command-line entry point: parse the options, load the model and convert every requested file.

    For each input name under ``raw/`` and each target speaker the audio is
    converted with ``Svc.slice_inference`` and written to ``results/``.
    """
    import argparse

    parser = argparse.ArgumentParser(description='sovits4 inference')

    # Settings that must be provided
    parser.add_argument('-m', '--model_path', type=str, default="logs/44k/G_37600.pth", help='模型路径')
    parser.add_argument('-c', '--config_path', type=str, default="logs/44k/config.json", help='配置文件路径')
    parser.add_argument('-cl', '--clip', type=float, default=0, help='音频强制切片,默认0为自动切片,单位为秒/s')
    parser.add_argument('-n', '--clean_names', type=str, nargs='+', default=["君の知らない物語-src.wav"], help='wav文件名列表,放在raw文件夹下')
    parser.add_argument('-t', '--trans', type=int, nargs='+', default=[0], help='音高调整,支持正负(半音)')
    parser.add_argument('-s', '--spk_list', type=str, nargs='+', default=['buyizi'], help='合成目标说话人名称')

    # Optional settings
    parser.add_argument('-a', '--auto_predict_f0', action='store_true', default=False, help='语音转换自动预测音高,转换歌声时不要打开这个会严重跑调')
    parser.add_argument('-cm', '--cluster_model_path', type=str, default="", help='聚类模型或特征检索索引路径,留空则自动设为各方案模型的默认路径,如果没有训练聚类或特征检索则随便填')
    parser.add_argument('-cr', '--cluster_infer_ratio', type=float, default=0, help='聚类方案或特征检索占比,范围0-1,若没有训练聚类模型或特征检索则默认0即可')
    parser.add_argument('-lg', '--linear_gradient', type=float, default=0, help='两段音频切片的交叉淡入长度,如果强制切片后出现人声不连贯可调整该数值,如果连贯建议采用默认值0,单位为秒')
    parser.add_argument('-f0p', '--f0_predictor', type=str, default="pm", help='选择F0预测器,可选择crepe,pm,dio,harvest,rmvpe,fcpe默认为pm(注意:crepe为原F0使用均值滤波器)')
    parser.add_argument('-eh', '--enhance', action='store_true', default=False, help='是否使用NSF_HIFIGAN增强器,该选项对部分训练集少的模型有一定的音质增强效果,但是对训练好的模型有反面效果,默认关闭')
    parser.add_argument('-shd', '--shallow_diffusion', action='store_true', default=False, help='是否使用浅层扩散,使用后可解决一部分电音问题,默认关闭,该选项打开时,NSF_HIFIGAN增强器将会被禁止')
    parser.add_argument('-usm', '--use_spk_mix', action='store_true', default=False, help='是否使用角色融合')
    parser.add_argument('-lea', '--loudness_envelope_adjustment', type=float, default=1, help='输入源响度包络替换输出响度包络融合比例,越靠近1越使用输出响度包络')
    parser.add_argument('-fr', '--feature_retrieval', action='store_true', default=False, help='是否使用特征检索,如果使用聚类模型将被禁用,且cm与cr参数将会变成特征检索的索引路径与混合比例')

    # Shallow diffusion settings
    parser.add_argument('-dm', '--diffusion_model_path', type=str, default="logs/44k/diffusion/model_0.pt", help='扩散模型路径')
    parser.add_argument('-dc', '--diffusion_config_path', type=str, default="logs/44k/diffusion/config.yaml", help='扩散模型配置文件路径')
    parser.add_argument('-ks', '--k_step', type=int, default=100, help='扩散步数,越大越接近扩散模型的结果,默认100')
    parser.add_argument('-se', '--second_encoding', action='store_true', default=False, help='二次编码,浅扩散前会对原始音频进行二次编码,玄学选项,有时候效果好,有时候效果差')
    parser.add_argument('-od', '--only_diffusion', action='store_true', default=False, help='纯扩散模式,该模式不会加载sovits模型,以扩散模型推理')

    # Settings that normally do not need to be changed
    parser.add_argument('-sd', '--slice_db', type=int, default=-40, help='默认-40,嘈杂的音频可以-30,干声保留呼吸可以-50')
    parser.add_argument('-d', '--device', type=str, default=None, help='推理设备,None则为自动选择cpu和gpu')
    parser.add_argument('-ns', '--noice_scale', type=float, default=0.4, help='噪音级别,会影响咬字和音质,较为玄学')
    parser.add_argument('-p', '--pad_seconds', type=float, default=0.5, help='推理音频pad秒数,由于未知原因开头结尾会有异响,pad一小段静音段后就不会出现')
    parser.add_argument('-wf', '--wav_format', type=str, default='flac', help='音频输出格式')
    parser.add_argument('-lgr', '--linear_gradient_retain', type=float, default=0.75, help='自动音频切片后,需要舍弃每段切片的头尾。该参数设置交叉长度保留的比例,范围0-1,左开右闭')
    parser.add_argument('-eak', '--enhancer_adaptive_key', type=int, default=0, help='使增强器适应更高的音域(单位为半音数)|默认为0')
    parser.add_argument('-ft', '--f0_filter_threshold', type=float, default=0.05,help='F0过滤阈值,只有使用crepe时有效. 数值范围从0-1. 降低该值可减少跑调概率,但会增加哑音')

    args = parser.parse_args()

    # Resolve the cluster / feature-retrieval model path.
    cluster_model_path = args.cluster_model_path
    if args.cluster_infer_ratio == 0:
        # No ratio given: clear the path, whatever was passed, so no model is loaded later.
        cluster_model_path = ""
    elif cluster_model_path == "":
        # A ratio was given without a path: fall back to the default path of
        # the selected scheme (feature retrieval or k-means clustering).
        if args.feature_retrieval:
            cluster_model_path = "logs/44k/feature_and_index.pkl"
        else:
            cluster_model_path = "logs/44k/kmeans_10000.pt"

    svc_model = Svc(args.model_path,
                    args.config_path,
                    args.device,
                    cluster_model_path,
                    args.enhance,
                    args.diffusion_model_path,
                    args.diffusion_config_path,
                    args.shallow_diffusion,
                    args.only_diffusion,
                    args.use_spk_mix,
                    args.feature_retrieval)

    infer_tool.mkdir(["raw", "results"])

    # Speaker mixing is only kept when the mix map holds more than one entry.
    use_spk_mix = args.use_spk_mix if len(spk_mix_map) > 1 else False
    spk_list = [spk_mix_map] if use_spk_mix else args.spk_list

    # Name fragments that are identical for every output file.
    cluster_name = "" if args.cluster_infer_ratio == 0 else f"_{args.cluster_infer_ratio}"
    if args.only_diffusion:
        isdiffusion = "diff"
    elif args.shallow_diffusion:
        isdiffusion = "sovdiff"
    else:
        isdiffusion = "sovits"

    trans = args.trans
    clean_names = args.clean_names
    infer_tool.fill_a_to_b(trans, clean_names)
    for clean_name, tran in zip(clean_names, trans):
        raw_audio_path = f"raw/{clean_name}"
        if "." not in raw_audio_path:
            raw_audio_path += ".wav"
        infer_tool.format_wav(raw_audio_path)
        for spk in spk_list:
            kwarg = {
                "raw_audio_path": raw_audio_path,
                "spk": spk,
                "tran": tran,
                "slice_db": args.slice_db,
                "cluster_infer_ratio": args.cluster_infer_ratio,
                "auto_predict_f0": args.auto_predict_f0,
                "noice_scale": args.noice_scale,
                "pad_seconds": args.pad_seconds,
                "clip_seconds": args.clip,
                "lg_num": args.linear_gradient,
                "lgr_num": args.linear_gradient_retain,
                "f0_predictor": args.f0_predictor,
                "enhancer_adaptive_key": args.enhancer_adaptive_key,
                "cr_threshold": args.f0_filter_threshold,
                "k_step": args.k_step,
                "use_spk_mix": use_spk_mix,
                "second_encoding": args.second_encoding,
                "loudness_envelope_adjustment": args.loudness_envelope_adjustment
            }
            audio = svc_model.slice_inference(**kwarg)
            key = "auto" if args.auto_predict_f0 else f"{tran}key"
            spk_tag = "spk_mix" if use_spk_mix else spk
            res_path = f'results/{clean_name}_{key}_{spk_tag}{cluster_name}_{isdiffusion}_{args.f0_predictor}.{args.wav_format}'
            soundfile.write(res_path, audio, svc_model.target_sample, format=args.wav_format)
            svc_model.clear_empty()
154
+ if __name__ == '__main__':
155
+ main()