nevreal committed on
Commit
e07bb32
1 Parent(s): b576d8c

Delete infer/modules/train

Browse files
infer/modules/train/extract/extract_f0_print.py DELETED
@@ -1,175 +0,0 @@
1
- import os
2
- import sys
3
- import traceback
4
-
5
- import parselmouth
6
-
7
- now_dir = os.getcwd()
8
- sys.path.append(now_dir)
9
- import logging
10
-
11
- import numpy as np
12
- import pyworld
13
-
14
- from infer.lib.audio import load_audio
15
-
16
- logging.getLogger("numba").setLevel(logging.WARNING)
17
- from multiprocessing import Process
18
-
19
- exp_dir = sys.argv[1]
20
- f = open("%s/extract_f0_feature.log" % exp_dir, "a+")
21
-
22
-
23
def printt(strr):
    """Print ``strr`` and append it, newline-terminated, to the log handle ``f``."""
    print(strr)
    log_line = "%s\n" % strr
    f.write(log_line)
    # Flush on every call so the log file is up to date after each message.
    f.flush()
27
-
28
-
29
- n_p = int(sys.argv[2])
30
- f0method = sys.argv[3]
31
-
32
-
33
class FeatureInput(object):
    """Compute F0 (pitch) curves for audio files and save raw and coarse versions.

    ``go`` is the batch entry point; it calls ``compute_f0`` and ``coarse_f0``
    for every input and stores both results with ``np.save``.
    """

    # Values of ``f0_method`` that ``compute_f0`` implements.
    F0_METHODS = ("pm", "harvest", "dio", "rmvpe")

    def __init__(self, samplerate=16000, hop_size=160):
        """Store analysis settings.

        Args:
            samplerate: sample rate passed to ``load_audio`` and the trackers.
            hop_size: distance between F0 frames, in samples.
        """
        self.fs = samplerate
        self.hop = hop_size

        self.f0_bin = 256
        self.f0_max = 1100.0
        self.f0_min = 50.0
        # Mel-scale images of the F0 limits, used by coarse_f0 for binning.
        self.f0_mel_min = 1127 * np.log(1 + self.f0_min / 700)
        self.f0_mel_max = 1127 * np.log(1 + self.f0_max / 700)

    def compute_f0(self, path, f0_method):
        """Return the F0 curve of the audio at ``path``.

        Args:
            path: audio file handed to ``load_audio``.
            f0_method: one of ``F0_METHODS``.

        Raises:
            ValueError: if ``f0_method`` is not supported.
        """
        # Validate first: previously an unknown method decoded the whole file
        # and then failed with UnboundLocalError on ``return f0``.
        if f0_method not in self.F0_METHODS:
            raise ValueError(
                "unsupported f0_method %r, expected one of %s"
                % (f0_method, ", ".join(self.F0_METHODS))
            )
        x = load_audio(path, self.fs)
        p_len = x.shape[0] // self.hop
        if f0_method == "pm":
            # Use the instance settings instead of the hard-coded 160/16000 and
            # 50/1100; with the default constructor the values are identical.
            f0 = (
                parselmouth.Sound(x, self.fs)
                .to_pitch_ac(
                    time_step=self.hop / self.fs,
                    voicing_threshold=0.6,
                    pitch_floor=self.f0_min,
                    pitch_ceiling=self.f0_max,
                )
                .selected_array["frequency"]
            )
            # Zero-pad both ends so the curve has p_len frames.
            pad_size = (p_len - len(f0) + 1) // 2
            if pad_size > 0 or p_len - len(f0) - pad_size > 0:
                f0 = np.pad(
                    f0, [[pad_size, p_len - len(f0) - pad_size]], mode="constant"
                )
        elif f0_method == "harvest":
            f0, t = pyworld.harvest(
                x.astype(np.double),
                fs=self.fs,
                f0_ceil=self.f0_max,
                f0_floor=self.f0_min,
                frame_period=1000 * self.hop / self.fs,
            )
            f0 = pyworld.stonemask(x.astype(np.double), f0, t, self.fs)
        elif f0_method == "dio":
            f0, t = pyworld.dio(
                x.astype(np.double),
                fs=self.fs,
                f0_ceil=self.f0_max,
                f0_floor=self.f0_min,
                frame_period=1000 * self.hop / self.fs,
            )
            f0 = pyworld.stonemask(x.astype(np.double), f0, t, self.fs)
        else:  # "rmvpe"
            if not hasattr(self, "model_rmvpe"):
                # Imported and built lazily, only when this method is used.
                from infer.lib.rmvpe import RMVPE

                print("Loading rmvpe model")
                self.model_rmvpe = RMVPE(
                    "assets/rmvpe/rmvpe.pt", is_half=False, device="cpu"
                )
            f0 = self.model_rmvpe.infer_from_audio(x, thred=0.03)
        return f0

    def coarse_f0(self, f0):
        """Quantise an F0 curve (numpy array) into integer bins ``1..f0_bin - 1``.

        Entries whose mel value is not positive end up in bin 1. An empty
        curve yields an empty integer array.
        """
        f0_mel = 1127 * np.log(1 + f0 / 700)
        positive = f0_mel > 0
        f0_mel[positive] = (f0_mel[positive] - self.f0_mel_min) * (
            self.f0_bin - 2
        ) / (self.f0_mel_max - self.f0_mel_min) + 1

        # Clamp to the valid bin range before rounding.
        f0_mel[f0_mel <= 1] = 1
        f0_mel[f0_mel > self.f0_bin - 1] = self.f0_bin - 1
        f0_coarse = np.rint(f0_mel).astype(int)
        # max()/min() raise on empty arrays, so only check non-empty curves.
        if f0_coarse.size:
            assert f0_coarse.max() <= self.f0_bin - 1 and f0_coarse.min() >= 1, (
                f0_coarse.max(),
                f0_coarse.min(),
            )
        return f0_coarse

    def go(self, paths, f0_method):
        """Extract and save F0 for every ``(inp_path, opt_path1, opt_path2)``.

        The raw curve is saved under ``opt_path2`` and the coarse curve under
        ``opt_path1``; inputs whose two ``.npy`` outputs already exist are
        skipped. Failures are logged per file and do not stop the batch.
        """
        if len(paths) == 0:
            printt("no-f0-todo")
            return
        printt("todo-f0-%s" % len(paths))
        n = max(len(paths) // 5, 1)  # print at most about 5 progress lines per process
        for idx, (inp_path, opt_path1, opt_path2) in enumerate(paths):
            try:
                if idx % n == 0:
                    printt("f0ing,now-%s,all-%s,-%s" % (idx, len(paths), inp_path))
                if os.path.exists(opt_path1 + ".npy") and os.path.exists(
                    opt_path2 + ".npy"
                ):
                    continue
                featur_pit = self.compute_f0(inp_path, f0_method)
                np.save(
                    opt_path2,
                    featur_pit,
                    allow_pickle=False,
                )  # nsf
                coarse_pit = self.coarse_f0(featur_pit)
                np.save(
                    opt_path1,
                    coarse_pit,
                    allow_pickle=False,
                )  # ori
            except Exception:
                # ``Exception`` rather than a bare except, so Ctrl-C still
                # interrupts the batch.
                printt("f0fail-%s-%s-%s" % (idx, inp_path, traceback.format_exc()))
140
-
141
-
142
if __name__ == "__main__":
    # Script entry: list <exp_dir>/1_16k_wavs and extract F0 with n_p worker
    # processes, each taking every n_p-th file.
    printt(" ".join(sys.argv))
    featureInput = FeatureInput()
    paths = []
    inp_root = "%s/1_16k_wavs" % (exp_dir)
    opt_root1 = "%s/2a_f0" % (exp_dir)
    opt_root2 = "%s/2b-f0nsf" % (exp_dir)

    os.makedirs(opt_root1, exist_ok=True)
    os.makedirs(opt_root2, exist_ok=True)
    for name in sorted(os.listdir(inp_root)):
        # Test the file name only: matching against the full path skipped
        # every file whenever exp_dir itself contained "spec".
        if "spec" in name:
            continue
        inp_path = "%s/%s" % (inp_root, name)
        opt_path1 = "%s/%s" % (opt_root1, name)
        opt_path2 = "%s/%s" % (opt_root2, name)
        paths.append([inp_path, opt_path1, opt_path2])

    ps = []
    for i in range(n_p):
        p = Process(
            target=featureInput.go,
            args=(
                paths[i::n_p],
                f0method,
            ),
        )
        ps.append(p)
        p.start()
    for p in ps:
        p.join()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
infer/modules/train/extract/extract_f0_rmvpe.py DELETED
@@ -1,141 +0,0 @@
1
- import os
2
- import sys
3
- import traceback
4
-
5
- import parselmouth
6
-
7
- now_dir = os.getcwd()
8
- sys.path.append(now_dir)
9
- import logging
10
-
11
- import numpy as np
12
- import pyworld
13
-
14
- from infer.lib.audio import load_audio
15
-
16
- logging.getLogger("numba").setLevel(logging.WARNING)
17
-
18
# Command line: <n_part> <i_part> <i_gpu> <exp_dir> <is_half>
n_part = int(sys.argv[1])
i_part = int(sys.argv[2])
i_gpu = sys.argv[3]
os.environ["CUDA_VISIBLE_DEVICES"] = str(i_gpu)
exp_dir = sys.argv[4]
# Parse the flag into a real bool. The raw argv string is truthy even when it
# reads "False", so it must not be passed on as ``is_half`` unparsed.
is_half = sys.argv[5].lower() == "true"
f = open("%s/extract_f0_feature.log" % exp_dir, "a+")
25
-
26
-
27
def printt(strr):
    """Print ``strr`` and append it, newline-terminated, to the log handle ``f``."""
    print(strr)
    log_line = "%s\n" % strr
    f.write(log_line)
    # Flush on every call so the log file is up to date after each message.
    f.flush()
31
-
32
-
33
class FeatureInput(object):
    """Compute RMVPE F0 curves on CUDA and save raw and coarse versions.

    ``go`` is the batch entry point; it calls ``compute_f0`` and ``coarse_f0``
    for every input and stores both results with ``np.save``.
    """

    # Values of ``f0_method`` that ``compute_f0`` implements.
    F0_METHODS = ("rmvpe",)

    def __init__(self, samplerate=16000, hop_size=160):
        """Store analysis settings.

        Args:
            samplerate: sample rate passed to ``load_audio``.
            hop_size: stored as ``self.hop``; not read by this class's methods.
        """
        self.fs = samplerate
        self.hop = hop_size

        self.f0_bin = 256
        self.f0_max = 1100.0
        self.f0_min = 50.0
        # Mel-scale images of the F0 limits, used by coarse_f0 for binning.
        self.f0_mel_min = 1127 * np.log(1 + self.f0_min / 700)
        self.f0_mel_max = 1127 * np.log(1 + self.f0_max / 700)

    def compute_f0(self, path, f0_method):
        """Return the F0 curve of the audio at ``path``.

        Raises:
            ValueError: if ``f0_method`` is not in ``F0_METHODS``.
        """
        # Validate first: previously any other method decoded the file and
        # then failed with UnboundLocalError on ``return f0``.
        if f0_method not in self.F0_METHODS:
            raise ValueError(
                "unsupported f0_method %r, expected one of %s"
                % (f0_method, ", ".join(self.F0_METHODS))
            )
        x = load_audio(path, self.fs)
        if not hasattr(self, "model_rmvpe"):
            # Imported and built lazily, on the first file only.
            from infer.lib.rmvpe import RMVPE

            print("Loading rmvpe model")
            # ``is_half`` is the module-level value read from the command line.
            self.model_rmvpe = RMVPE(
                "assets/rmvpe/rmvpe.pt", is_half=is_half, device="cuda"
            )
        f0 = self.model_rmvpe.infer_from_audio(x, thred=0.03)
        return f0

    def coarse_f0(self, f0):
        """Quantise an F0 curve (numpy array) into integer bins ``1..f0_bin - 1``.

        Entries whose mel value is not positive end up in bin 1. An empty
        curve yields an empty integer array.
        """
        f0_mel = 1127 * np.log(1 + f0 / 700)
        positive = f0_mel > 0
        f0_mel[positive] = (f0_mel[positive] - self.f0_mel_min) * (
            self.f0_bin - 2
        ) / (self.f0_mel_max - self.f0_mel_min) + 1

        # Clamp to the valid bin range before rounding.
        f0_mel[f0_mel <= 1] = 1
        f0_mel[f0_mel > self.f0_bin - 1] = self.f0_bin - 1
        f0_coarse = np.rint(f0_mel).astype(int)
        # max()/min() raise on empty arrays, so only check non-empty curves.
        if f0_coarse.size:
            assert f0_coarse.max() <= self.f0_bin - 1 and f0_coarse.min() >= 1, (
                f0_coarse.max(),
                f0_coarse.min(),
            )
        return f0_coarse

    def go(self, paths, f0_method):
        """Extract and save F0 for every ``(inp_path, opt_path1, opt_path2)``.

        The raw curve is saved under ``opt_path2`` and the coarse curve under
        ``opt_path1``; inputs whose two ``.npy`` outputs already exist are
        skipped. Failures are logged per file and do not stop the batch.
        """
        if len(paths) == 0:
            printt("no-f0-todo")
            return
        printt("todo-f0-%s" % len(paths))
        n = max(len(paths) // 5, 1)  # print at most about 5 progress lines per process
        for idx, (inp_path, opt_path1, opt_path2) in enumerate(paths):
            try:
                if idx % n == 0:
                    printt("f0ing,now-%s,all-%s,-%s" % (idx, len(paths), inp_path))
                if os.path.exists(opt_path1 + ".npy") and os.path.exists(
                    opt_path2 + ".npy"
                ):
                    continue
                featur_pit = self.compute_f0(inp_path, f0_method)
                np.save(
                    opt_path2,
                    featur_pit,
                    allow_pickle=False,
                )  # nsf
                coarse_pit = self.coarse_f0(featur_pit)
                np.save(
                    opt_path1,
                    coarse_pit,
                    allow_pickle=False,
                )  # ori
            except Exception:
                # ``Exception`` rather than a bare except, so Ctrl-C still
                # interrupts the batch.
                printt("f0fail-%s-%s-%s" % (idx, inp_path, traceback.format_exc()))
103
-
104
-
105
if __name__ == "__main__":
    # Script entry: list <exp_dir>/1_16k_wavs and extract RMVPE F0 for this
    # process's share of the files (every n_part-th file from i_part).
    printt(" ".join(sys.argv))
    featureInput = FeatureInput()
    paths = []
    inp_root = "%s/1_16k_wavs" % (exp_dir)
    opt_root1 = "%s/2a_f0" % (exp_dir)
    opt_root2 = "%s/2b-f0nsf" % (exp_dir)

    os.makedirs(opt_root1, exist_ok=True)
    os.makedirs(opt_root2, exist_ok=True)
    for name in sorted(os.listdir(inp_root)):
        # Test the file name only: matching against the full path skipped
        # every file whenever exp_dir itself contained "spec".
        if "spec" in name:
            continue
        inp_path = "%s/%s" % (inp_root, name)
        opt_path1 = "%s/%s" % (opt_root1, name)
        opt_path2 = "%s/%s" % (opt_root2, name)
        paths.append([inp_path, opt_path1, opt_path2])
    try:
        featureInput.go(paths[i_part::n_part], "rmvpe")
    except Exception:
        # Top-level boundary: log the failure, but let Ctrl-C/SystemExit through.
        printt("f0_all_fail-%s" % (traceback.format_exc()))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
infer/modules/train/extract/extract_f0_rmvpe_dml.py DELETED
@@ -1,139 +0,0 @@
1
- import os
2
- import sys
3
- import traceback
4
-
5
- import parselmouth
6
-
7
- now_dir = os.getcwd()
8
- sys.path.append(now_dir)
9
- import logging
10
-
11
- import numpy as np
12
- import pyworld
13
-
14
- from infer.lib.audio import load_audio
15
-
16
- logging.getLogger("numba").setLevel(logging.WARNING)
17
-
18
- exp_dir = sys.argv[1]
19
- import torch_directml
20
-
21
- device = torch_directml.device(torch_directml.default_device())
22
- f = open("%s/extract_f0_feature.log" % exp_dir, "a+")
23
-
24
-
25
def printt(strr):
    """Print ``strr`` and append it, newline-terminated, to the log handle ``f``."""
    print(strr)
    log_line = "%s\n" % strr
    f.write(log_line)
    # Flush on every call so the log file is up to date after each message.
    f.flush()
29
-
30
-
31
class FeatureInput(object):
    """Compute RMVPE F0 curves on the DirectML device and save raw and coarse versions.

    ``go`` is the batch entry point; it calls ``compute_f0`` and ``coarse_f0``
    for every input and stores both results with ``np.save``.
    """

    # Values of ``f0_method`` that ``compute_f0`` implements.
    F0_METHODS = ("rmvpe",)

    def __init__(self, samplerate=16000, hop_size=160):
        """Store analysis settings.

        Args:
            samplerate: sample rate passed to ``load_audio``.
            hop_size: stored as ``self.hop``; not read by this class's methods.
        """
        self.fs = samplerate
        self.hop = hop_size

        self.f0_bin = 256
        self.f0_max = 1100.0
        self.f0_min = 50.0
        # Mel-scale images of the F0 limits, used by coarse_f0 for binning.
        self.f0_mel_min = 1127 * np.log(1 + self.f0_min / 700)
        self.f0_mel_max = 1127 * np.log(1 + self.f0_max / 700)

    def compute_f0(self, path, f0_method):
        """Return the F0 curve of the audio at ``path``.

        Raises:
            ValueError: if ``f0_method`` is not in ``F0_METHODS``.
        """
        # Validate first: previously any other method decoded the file and
        # then failed with UnboundLocalError on ``return f0``.
        if f0_method not in self.F0_METHODS:
            raise ValueError(
                "unsupported f0_method %r, expected one of %s"
                % (f0_method, ", ".join(self.F0_METHODS))
            )
        x = load_audio(path, self.fs)
        if not hasattr(self, "model_rmvpe"):
            # Imported and built lazily, on the first file only.
            from infer.lib.rmvpe import RMVPE

            print("Loading rmvpe model")
            # ``device`` is the module-level torch_directml device.
            self.model_rmvpe = RMVPE(
                "assets/rmvpe/rmvpe.pt", is_half=False, device=device
            )
        f0 = self.model_rmvpe.infer_from_audio(x, thred=0.03)
        return f0

    def coarse_f0(self, f0):
        """Quantise an F0 curve (numpy array) into integer bins ``1..f0_bin - 1``.

        Entries whose mel value is not positive end up in bin 1. An empty
        curve yields an empty integer array.
        """
        f0_mel = 1127 * np.log(1 + f0 / 700)
        positive = f0_mel > 0
        f0_mel[positive] = (f0_mel[positive] - self.f0_mel_min) * (
            self.f0_bin - 2
        ) / (self.f0_mel_max - self.f0_mel_min) + 1

        # Clamp to the valid bin range before rounding.
        f0_mel[f0_mel <= 1] = 1
        f0_mel[f0_mel > self.f0_bin - 1] = self.f0_bin - 1
        f0_coarse = np.rint(f0_mel).astype(int)
        # max()/min() raise on empty arrays, so only check non-empty curves.
        if f0_coarse.size:
            assert f0_coarse.max() <= self.f0_bin - 1 and f0_coarse.min() >= 1, (
                f0_coarse.max(),
                f0_coarse.min(),
            )
        return f0_coarse

    def go(self, paths, f0_method):
        """Extract and save F0 for every ``(inp_path, opt_path1, opt_path2)``.

        The raw curve is saved under ``opt_path2`` and the coarse curve under
        ``opt_path1``; inputs whose two ``.npy`` outputs already exist are
        skipped. Failures are logged per file and do not stop the batch.
        """
        if len(paths) == 0:
            printt("no-f0-todo")
            return
        printt("todo-f0-%s" % len(paths))
        n = max(len(paths) // 5, 1)  # print at most about 5 progress lines per process
        for idx, (inp_path, opt_path1, opt_path2) in enumerate(paths):
            try:
                if idx % n == 0:
                    printt("f0ing,now-%s,all-%s,-%s" % (idx, len(paths), inp_path))
                if os.path.exists(opt_path1 + ".npy") and os.path.exists(
                    opt_path2 + ".npy"
                ):
                    continue
                featur_pit = self.compute_f0(inp_path, f0_method)
                np.save(
                    opt_path2,
                    featur_pit,
                    allow_pickle=False,
                )  # nsf
                coarse_pit = self.coarse_f0(featur_pit)
                np.save(
                    opt_path1,
                    coarse_pit,
                    allow_pickle=False,
                )  # ori
            except Exception:
                # ``Exception`` rather than a bare except, so Ctrl-C still
                # interrupts the batch.
                printt("f0fail-%s-%s-%s" % (idx, inp_path, traceback.format_exc()))
101
-
102
-
103
if __name__ == "__main__":
    # Script entry: list <exp_dir>/1_16k_wavs and extract RMVPE F0 for every
    # file in this single process.
    printt(" ".join(sys.argv))
    featureInput = FeatureInput()
    paths = []
    inp_root = "%s/1_16k_wavs" % (exp_dir)
    opt_root1 = "%s/2a_f0" % (exp_dir)
    opt_root2 = "%s/2b-f0nsf" % (exp_dir)

    os.makedirs(opt_root1, exist_ok=True)
    os.makedirs(opt_root2, exist_ok=True)
    for name in sorted(os.listdir(inp_root)):
        # Test the file name only: matching against the full path skipped
        # every file whenever exp_dir itself contained "spec".
        if "spec" in name:
            continue
        inp_path = "%s/%s" % (inp_root, name)
        opt_path1 = "%s/%s" % (opt_root1, name)
        opt_path2 = "%s/%s" % (opt_root2, name)
        paths.append([inp_path, opt_path1, opt_path2])
    try:
        featureInput.go(paths, "rmvpe")
    except Exception:
        # Top-level boundary: log the failure, but let Ctrl-C/SystemExit through.
        printt("f0_all_fail-%s" % (traceback.format_exc()))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
infer/modules/train/extract_feature_print.py DELETED
@@ -1,142 +0,0 @@
1
- import os
2
- import sys
3
- import traceback
4
-
5
- os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
6
- os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"
7
-
8
- device = sys.argv[1]
9
- n_part = int(sys.argv[2])
10
- i_part = int(sys.argv[3])
11
- if len(sys.argv) == 7:
12
- exp_dir = sys.argv[4]
13
- version = sys.argv[5]
14
- is_half = sys.argv[6].lower() == "true"
15
- else:
16
- i_gpu = sys.argv[4]
17
- exp_dir = sys.argv[5]
18
- os.environ["CUDA_VISIBLE_DEVICES"] = str(i_gpu)
19
- version = sys.argv[6]
20
- is_half = sys.argv[7].lower() == "true"
21
- import fairseq
22
- import numpy as np
23
- import soundfile as sf
24
- import torch
25
- import torch.nn.functional as F
26
-
27
# Resolve the torch device. A "privateuseone" request selects DirectML;
# anything else is replaced by cuda, then mps, then cpu, by availability.
if "privateuseone" in device:
    import torch_directml

    device = torch_directml.device(torch_directml.default_device())

    def forward_dml(ctx, x, scale):
        """Replacement GradMultiply.forward: record ``scale``, return a detached copy of ``x``."""
        ctx.scale = scale
        return x.clone().detach()

    # Swap in the replacement forward for fairseq's GradMultiply.
    fairseq.modules.grad_multiply.GradMultiply.forward = forward_dml
elif torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
44
-
45
- f = open("%s/extract_f0_feature.log" % exp_dir, "a+")
46
-
47
-
48
def printt(strr):
    """Print ``strr`` and append it, newline-terminated, to the log handle ``f``."""
    print(strr)
    log_line = "%s\n" % strr
    f.write(log_line)
    # Flush on every call so the log file is up to date after each message.
    f.flush()
52
-
53
-
54
- printt(" ".join(sys.argv))
55
- model_path = "assets/hubert/hubert_base.pt"
56
-
57
- printt("exp_dir: " + exp_dir)
58
- wavPath = "%s/1_16k_wavs" % exp_dir
59
- outPath = (
60
- "%s/3_feature256" % exp_dir if version == "v1" else "%s/3_feature768" % exp_dir
61
- )
62
- os.makedirs(outPath, exist_ok=True)
63
-
64
-
65
- # wave must be 16k, hop_size=320
66
def readwave(wav_path, normalize=False):
    """Load a 16 kHz wav file as a float tensor of shape ``(1, n_samples)``.

    Two-dimensional input is averaged over its last axis. When ``normalize``
    is true the signal is passed through ``F.layer_norm`` over its full length.
    Asserts that the file's sample rate is 16000.
    """
    samples, sample_rate = sf.read(wav_path)
    assert sample_rate == 16000
    waveform = torch.from_numpy(samples).float()
    if waveform.dim() == 2:  # two channels: average them
        waveform = waveform.mean(-1)
    assert waveform.dim() == 1, waveform.dim()
    if normalize:
        with torch.no_grad():
            waveform = F.layer_norm(waveform, waveform.shape)
    return waveform.view(1, -1)
78
-
79
-
80
- # HuBERT model
81
printt("load model(s) from {}".format(model_path))
# Stop early if the HuBERT checkpoint is missing.
if not os.path.exists(model_path):
    printt(
        "Error: Extracting is shut down because %s does not exist, you may download it from https://huggingface.co/lj1995/VoiceConversionWebUI/tree/main"
        % model_path
    )
    # Exit with a non-zero status: the previous ``exit(0)`` reported success
    # for a failed run and relied on the site-provided ``exit`` builtin.
    sys.exit(1)
models, saved_cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
    [model_path],
    suffix="",
)
model = models[0]
model = model.to(device)
printt("move model to %s" % device)
# Half precision is only applied when the device is neither "mps" nor "cpu".
if is_half and device not in ["mps", "cpu"]:
    model = model.half()
model.eval()
100
-
101
# This process handles every n_part-th entry of the sorted wav directory.
todo = sorted(os.listdir(wavPath))[i_part::n_part]
n = max(1, len(todo) // 10)  # print at most about ten progress lines
if len(todo) == 0:
    printt("no-feature-todo")
else:
    printt("all-feature-%s" % len(todo))
    for idx, file in enumerate(todo):
        try:
            if not file.endswith(".wav"):
                continue
            wav_path = "%s/%s" % (wavPath, file)
            # Swap only the extension: ``file.replace("wav", "npy")`` also
            # rewrote any "wav" inside the base name.
            out_path = "%s/%s" % (outPath, os.path.splitext(file)[0] + ".npy")

            if os.path.exists(out_path):
                continue

            feats = readwave(wav_path, normalize=saved_cfg.task.normalize)
            padding_mask = torch.BoolTensor(feats.shape).fill_(False)
            inputs = {
                "source": (
                    feats.half().to(device)
                    if is_half and device not in ["mps", "cpu"]
                    else feats.to(device)
                ),
                "padding_mask": padding_mask.to(device),
                "output_layer": 9 if version == "v1" else 12,  # layer 9 for v1, 12 otherwise
            }
            with torch.no_grad():
                logits = model.extract_features(**inputs)
                feats = (
                    model.final_proj(logits[0]) if version == "v1" else logits[0]
                )

            feats = feats.squeeze(0).float().cpu().numpy()
            # Never store features containing NaN; report the file instead.
            if np.isnan(feats).sum() == 0:
                np.save(out_path, feats, allow_pickle=False)
            else:
                printt("%s-contains nan" % file)
            if idx % n == 0:
                # "now" is the current index and "all" the total; the two
                # arguments used to be passed in swapped order.
                printt("now-%s,all-%s,%s,%s" % (idx, len(todo), file, feats.shape))
        except Exception:
            # Log and continue with the next file; Ctrl-C is not swallowed.
            printt(traceback.format_exc())
    printt("all-feature-done")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
infer/modules/train/preprocess.py DELETED
@@ -1,142 +0,0 @@
1
- import multiprocessing
2
- import os
3
- import sys
4
-
5
- from scipy import signal
6
-
7
- now_dir = os.getcwd()
8
- sys.path.append(now_dir)
9
- print(*sys.argv[1:])
10
- inp_root = sys.argv[1]
11
- sr = int(sys.argv[2])
12
- n_p = int(sys.argv[3])
13
- exp_dir = sys.argv[4]
14
- noparallel = sys.argv[5] == "True"
15
- per = float(sys.argv[6])
16
- import os
17
- import traceback
18
-
19
- import librosa
20
- import numpy as np
21
- from scipy.io import wavfile
22
-
23
- from infer.lib.audio import load_audio
24
- from infer.lib.slicer2 import Slicer
25
-
26
- f = open("%s/preprocess.log" % exp_dir, "a+")
27
-
28
-
29
def println(strr):
    """Print ``strr`` and append it, newline-terminated, to the log handle ``f``."""
    print(strr)
    log_line = "%s\n" % strr
    f.write(log_line)
    # Flush on every call so the log file is up to date after each message.
    f.flush()
33
-
34
-
35
class PreProcess:
    """Slice training audio into normalised segments.

    Each segment is written twice: at the working sample rate under
    ``<exp_dir>/0_gt_wavs`` and resampled to 16 kHz under ``<exp_dir>/1_16k_wavs``.
    """

    def __init__(self, sr, exp_dir, per=3.7):
        """Build the slicer and high-pass filter and create the output folders.

        Args:
            sr: working sample rate.
            exp_dir: experiment directory that receives the output folders.
            per: window length in seconds for each written segment.
        """
        self.slicer = Slicer(
            sr=sr,
            threshold=-42,
            min_length=1500,
            min_interval=400,
            hop_size=15,
            max_sil_kept=500,
        )
        self.sr = sr
        self.bh, self.ah = signal.butter(N=5, Wn=48, btype="high", fs=self.sr)
        self.per = per
        self.overlap = 0.3
        # A remainder no longer than ``tail`` seconds is kept as one segment.
        self.tail = self.per + self.overlap
        self.max = 0.9
        self.alpha = 0.75
        self.exp_dir = exp_dir
        self.gt_wavs_dir = "%s/0_gt_wavs" % exp_dir
        self.wavs16k_dir = "%s/1_16k_wavs" % exp_dir
        os.makedirs(self.exp_dir, exist_ok=True)
        os.makedirs(self.gt_wavs_dir, exist_ok=True)
        os.makedirs(self.wavs16k_dir, exist_ok=True)

    def norm_write(self, tmp_audio, idx0, idx1):
        """Normalise one segment and write it as ``<idx0>_<idx1>.wav`` in both folders.

        Segments that are empty, all-zero, or whose peak exceeds 2.5 are
        skipped and nothing is written.
        """
        if len(tmp_audio) == 0:
            return
        tmp_max = np.abs(tmp_audio).max()
        if tmp_max > 2.5:
            print("%s-%s-%s-filtered" % (idx0, idx1, tmp_max))
            return
        if tmp_max == 0:
            # Dividing by a zero peak would produce NaN samples.
            print("%s-%s-silent-skipped" % (idx0, idx1))
            return
        # Blend the peak-normalised signal with a share of the original.
        tmp_audio = (tmp_audio / tmp_max * (self.max * self.alpha)) + (
            1 - self.alpha
        ) * tmp_audio
        wavfile.write(
            "%s/%s_%s.wav" % (self.gt_wavs_dir, idx0, idx1),
            self.sr,
            tmp_audio.astype(np.float32),
        )
        tmp_audio = librosa.resample(
            tmp_audio, orig_sr=self.sr, target_sr=16000
        )  # , res_type="soxr_vhq"
        wavfile.write(
            "%s/%s_%s.wav" % (self.wavs16k_dir, idx0, idx1),
            16000,
            tmp_audio.astype(np.float32),
        )

    def _split_segments(self, audio):
        """Yield overlapping ``per``-second windows of ``audio``, then the remaining tail."""
        i = 0
        while True:
            start = int(self.sr * (self.per - self.overlap) * i)
            i += 1
            if len(audio[start:]) > self.tail * self.sr:
                yield audio[start : start + int(self.per * self.sr)]
            else:
                yield audio[start:]
                return

    def pipeline(self, path, idx0):
        """Load, high-pass filter, slice and write one input file.

        Success or the traceback of a failure is logged through ``println``;
        exceptions do not propagate.
        """
        try:
            audio = load_audio(path, self.sr)
            # zero phased digital filter cause pre-ringing noise...
            # audio = signal.filtfilt(self.bh, self.ah, audio)
            audio = signal.lfilter(self.bh, self.ah, audio)

            idx1 = 0
            for piece in self.slicer.slice(audio):
                for segment in self._split_segments(piece):
                    # Write, then advance: the tail used to be written after
                    # an extra increment, so the next slice's first window
                    # reused its index and overwrote the tail file.
                    self.norm_write(segment, idx0, idx1)
                    idx1 += 1
            println("%s\t-> Success" % path)
        except Exception:
            println("%s\t-> %s" % (path, traceback.format_exc()))

    def pipeline_mp(self, infos):
        """Run ``pipeline`` for every ``(path, idx0)`` pair in ``infos``."""
        for path, idx0 in infos:
            self.pipeline(path, idx0)

    def pipeline_mp_inp_dir(self, inp_root, n_p):
        """Process every entry of ``inp_root``, split into ``n_p`` interleaved shares.

        Shares run one after another in this process when the module-level
        ``noparallel`` flag is set, otherwise in ``n_p`` worker processes.
        """
        try:
            infos = [
                ("%s/%s" % (inp_root, name), idx)
                for idx, name in enumerate(sorted(os.listdir(inp_root)))
            ]
            if noparallel:
                for i in range(n_p):
                    self.pipeline_mp(infos[i::n_p])
            else:
                ps = []
                for i in range(n_p):
                    p = multiprocessing.Process(
                        target=self.pipeline_mp, args=(infos[i::n_p],)
                    )
                    ps.append(p)
                    p.start()
                for p in ps:
                    p.join()
        except Exception:
            println("Fail. %s" % traceback.format_exc())
132
-
133
-
134
def preprocess_trainset(inp_root, sr, n_p, exp_dir, per):
    """Preprocess every entry of ``inp_root`` into ``exp_dir``, logging start and end."""
    preprocessor = PreProcess(sr, exp_dir, per)
    println("start preprocess")
    preprocessor.pipeline_mp_inp_dir(inp_root, n_p)
    println("end preprocess")
139
-
140
-
141
- if __name__ == "__main__":
142
- preprocess_trainset(inp_root, sr, n_p, exp_dir, per)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
infer/modules/train/train.py DELETED
@@ -1,640 +0,0 @@
1
- import os
2
- import sys
3
- import logging
4
-
5
- logger = logging.getLogger(__name__)
6
-
7
- now_dir = os.getcwd()
8
- sys.path.append(os.path.join(now_dir))
9
-
10
- import datetime
11
-
12
- from infer.lib.train import utils
13
-
14
- hps = utils.get_hparams()
15
- os.environ["CUDA_VISIBLE_DEVICES"] = hps.gpus.replace("-", ",")
16
- n_gpus = len(hps.gpus.split("-"))
17
- from random import randint, shuffle
18
-
19
- import torch
20
-
21
- try:
22
- import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
23
-
24
- if torch.xpu.is_available():
25
- from infer.modules.ipex import ipex_init
26
- from infer.modules.ipex.gradscaler import gradscaler_init
27
- from torch.xpu.amp import autocast
28
-
29
- GradScaler = gradscaler_init()
30
- ipex_init()
31
- else:
32
- from torch.cuda.amp import GradScaler, autocast
33
- except Exception:
34
- from torch.cuda.amp import GradScaler, autocast
35
-
36
- torch.backends.cudnn.deterministic = False
37
- torch.backends.cudnn.benchmark = False
38
- from time import sleep
39
- from time import time as ttime
40
-
41
- import torch.distributed as dist
42
- import torch.multiprocessing as mp
43
- from torch.nn import functional as F
44
- from torch.nn.parallel import DistributedDataParallel as DDP
45
- from torch.utils.data import DataLoader
46
- from torch.utils.tensorboard import SummaryWriter
47
-
48
- from infer.lib.infer_pack import commons
49
- from infer.lib.train.data_utils import (
50
- DistributedBucketSampler,
51
- TextAudioCollate,
52
- TextAudioCollateMultiNSFsid,
53
- TextAudioLoader,
54
- TextAudioLoaderMultiNSFsid,
55
- )
56
-
57
- if hps.version == "v1":
58
- from infer.lib.infer_pack.models import MultiPeriodDiscriminator
59
- from infer.lib.infer_pack.models import SynthesizerTrnMs256NSFsid as RVC_Model_f0
60
- from infer.lib.infer_pack.models import (
61
- SynthesizerTrnMs256NSFsid_nono as RVC_Model_nof0,
62
- )
63
- else:
64
- from infer.lib.infer_pack.models import (
65
- SynthesizerTrnMs768NSFsid as RVC_Model_f0,
66
- SynthesizerTrnMs768NSFsid_nono as RVC_Model_nof0,
67
- MultiPeriodDiscriminatorV2 as MultiPeriodDiscriminator,
68
- )
69
-
70
- from infer.lib.train.losses import (
71
- discriminator_loss,
72
- feature_loss,
73
- generator_loss,
74
- kl_loss,
75
- )
76
- from infer.lib.train.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
77
- from infer.lib.train.process_ckpt import savee
78
-
79
- global_step = 0
80
-
81
-
82
class EpochRecorder:
    """Report the wall-clock time elapsed between successive ``record`` calls."""

    def __init__(self):
        # Reference point for the first ``record`` call.
        self.last_time = ttime()

    def record(self):
        """Return ``"[<timestamp>] | (<elapsed>)"`` and restart the timer."""
        previous, self.last_time = self.last_time, ttime()
        elapsed = datetime.timedelta(seconds=self.last_time - previous)
        stamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        return f"[{stamp}] | ({elapsed})"
93
-
94
-
95
def main():
    """Start one ``run`` process per device and wait for all of them to finish."""
    n_gpus = torch.cuda.device_count()

    # Without CUDA but with MPS available, run a single worker.
    if not torch.cuda.is_available() and torch.backends.mps.is_available():
        n_gpus = 1
    if n_gpus < 1:
        # patch to unblock people without gpus. there is probably a better way.
        print("NO GPU DETECTED: falling back to CPU - this may take a while")
        n_gpus = 1
    # Rendezvous settings for the "env://" init method used in ``run``.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(randint(20000, 55555))
    logger = utils.get_logger(hps.model_dir)
    children = []
    for rank in range(n_gpus):
        worker = mp.Process(target=run, args=(rank, n_gpus, hps, logger))
        children.append(worker)
        worker.start()

    for worker in children:
        worker.join()
118
-
119
-
120
def run(rank, n_gpus, hps, logger: logging.Logger):
    """Training worker for one process.

    Joins the process group, builds the dataset, loader, generator,
    discriminator and optimizers, resumes from the latest checkpoints in
    ``hps.model_dir`` (falling back to the configured pretrained weights),
    then calls ``train_and_evaluate`` once per epoch.

    Args:
        rank: index of this process; rank 0 also creates the summary writers.
        n_gpus: number of processes in the group.
        hps: hyper-parameter object.
        logger: logger shared by the workers.
    """
    global global_step
    is_main = rank == 0
    if is_main:
        logger.info(hps)
        writer = SummaryWriter(log_dir=hps.model_dir)
        writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval"))

    dist.init_process_group(
        backend="gloo", init_method="env://", world_size=n_gpus, rank=rank
    )
    torch.manual_seed(hps.train.seed)
    has_cuda = torch.cuda.is_available()
    if has_cuda:
        torch.cuda.set_device(rank)

    uses_f0 = hps.if_f0 == 1
    if uses_f0:
        train_dataset = TextAudioLoaderMultiNSFsid(hps.data.training_files, hps.data)
    else:
        train_dataset = TextAudioLoader(hps.data.training_files, hps.data)
    train_sampler = DistributedBucketSampler(
        train_dataset,
        hps.train.batch_size * n_gpus,
        # [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000, 1200,1400],  # 16s
        [100, 200, 300, 400, 500, 600, 700, 800, 900],  # 16s
        num_replicas=n_gpus,
        rank=rank,
        shuffle=True,
    )
    # It is possible that dataloader's workers are out of shared memory.
    # Please try to raise your shared memory limit.
    # num_workers=8 -> num_workers=4
    collate_fn = TextAudioCollateMultiNSFsid() if uses_f0 else TextAudioCollate()
    train_loader = DataLoader(
        train_dataset,
        num_workers=4,
        shuffle=False,
        pin_memory=True,
        collate_fn=collate_fn,
        batch_sampler=train_sampler,
        persistent_workers=True,
        prefetch_factor=8,
    )

    # Generator: the f0 variant additionally receives the sample rate.
    if uses_f0:
        net_g = RVC_Model_f0(
            hps.data.filter_length // 2 + 1,
            hps.train.segment_size // hps.data.hop_length,
            **hps.model,
            is_half=hps.train.fp16_run,
            sr=hps.sample_rate,
        )
    else:
        net_g = RVC_Model_nof0(
            hps.data.filter_length // 2 + 1,
            hps.train.segment_size // hps.data.hop_length,
            **hps.model,
            is_half=hps.train.fp16_run,
        )
    if has_cuda:
        net_g = net_g.cuda(rank)
    net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm)
    if has_cuda:
        net_d = net_d.cuda(rank)

    optim_g = torch.optim.AdamW(
        net_g.parameters(),
        hps.train.learning_rate,
        betas=hps.train.betas,
        eps=hps.train.eps,
    )
    optim_d = torch.optim.AdamW(
        net_d.parameters(),
        hps.train.learning_rate,
        betas=hps.train.betas,
        eps=hps.train.eps,
    )

    # Wrap in DDP unless an XPU is available; pin to ``rank`` only on CUDA.
    on_xpu = hasattr(torch, "xpu") and torch.xpu.is_available()
    if not on_xpu:
        if has_cuda:
            net_g = DDP(net_g, device_ids=[rank])
            net_d = DDP(net_d, device_ids=[rank])
        else:
            net_g = DDP(net_g)
            net_d = DDP(net_d)

    def _load_pretrained(weights_path, net):
        # Load only the "model" entry of the checkpoint (no optimizer state)
        # and log the result returned by load_state_dict.
        if is_main:
            logger.info("loaded pretrained %s" % (weights_path))
        target = net.module if hasattr(net, "module") else net
        logger.info(
            target.load_state_dict(
                torch.load(weights_path, map_location="cpu")["model"]
            )
        )

    try:  # resume automatically when checkpoints can be loaded
        _, _, _, epoch_str = utils.load_checkpoint(
            utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d, optim_d
        )  # loading D usually succeeds
        if is_main:
            logger.info("loaded D")
        _, _, _, epoch_str = utils.load_checkpoint(
            utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g
        )
        global_step = (epoch_str - 1) * len(train_loader)
    except:  # nothing to resume from (first run): load the pretrained weights
        epoch_str = 1
        global_step = 0
        if hps.pretrainG != "":
            _load_pretrained(hps.pretrainG, net_g)
        if hps.pretrainD != "":
            _load_pretrained(hps.pretrainD, net_d)

    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
        optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
    )
    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
        optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
    )

    scaler = GradScaler(enabled=hps.train.fp16_run)

    cache = []
    for epoch in range(epoch_str, hps.train.epochs + 1):
        # Only rank 0 passes a logger and summary writers.
        train_and_evaluate(
            rank,
            epoch,
            hps,
            [net_g, net_d],
            [optim_g, optim_d],
            [scheduler_g, scheduler_d],
            scaler,
            [train_loader, None],
            logger if is_main else None,
            [writer, writer_eval] if is_main else None,
            cache,
        )
        scheduler_g.step()
        scheduler_d.step()
297
-
298
-
299
def _unpack_batch(info, if_f0):
    """Split one collated batch into its named tensors.

    With ``if_f0 == 1`` the batch must hold nine entries (including ``pitch``
    and ``pitchf``); otherwise it must hold seven, and both pitch slots are
    returned as ``None``. A batch of any other length raises ``ValueError``
    from the tuple unpacking.
    """
    if if_f0 == 1:
        (
            phone,
            phone_lengths,
            pitch,
            pitchf,
            spec,
            spec_lengths,
            wave,
            wave_lengths,
            sid,
        ) = info
    else:
        phone, phone_lengths, spec, spec_lengths, wave, wave_lengths, sid = info
        pitch = pitchf = None
    return (
        phone,
        phone_lengths,
        pitch,
        pitchf,
        spec,
        spec_lengths,
        wave,
        wave_lengths,
        sid,
    )


def _batch_to_cuda(info, rank, if_f0, move_wave_lengths=True):
    """Return ``info`` as a tuple with its entries moved to CUDA device ``rank``.

    ``wave_lengths`` is the second-to-last entry in both batch layouts; it is
    left where it is when ``move_wave_lengths`` is false.
    """
    wave_lengths_index = 7 if if_f0 == 1 else 5
    return tuple(
        entry
        if (index == wave_lengths_index and not move_wave_lengths)
        else entry.cuda(rank, non_blocking=True)
        for index, entry in enumerate(info)
    )


def _generator_state_dict(net_g):
    """Return the generator weights, looking through a ``.module`` wrapper if present."""
    if hasattr(net_g, "module"):
        return net_g.module.state_dict()
    return net_g.state_dict()


def train_and_evaluate(
    rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers, cache
):
    """Train the generator/discriminator pair for one epoch.

    Args:
        rank: Process rank; also used as the CUDA device index. Only rank 0
            logs, writes summaries and saves checkpoints.
        epoch: Current epoch number (passed to the batch sampler and used in
            log messages and checkpoint decisions).
        hps: Hyper-parameter object (``if_f0``, ``if_cache_data_in_gpu``,
            ``train.*``, ``data.*``, ``save_every_epoch``, ...).
        nets: ``[net_g, net_d]``.
        optims: ``[optim_g, optim_d]``.
        schedulers: Accepted for interface compatibility; not used here.
        scaler: Gradient scaler shared by both optimizers.
        loaders: ``[train_loader, eval_loader]``; only the first is used.
        logger: Logger used on rank 0 (may be ``None`` on other ranks).
        writers: ``[writer, writer_eval]`` or ``None``; only ``writer`` is used.
        cache: List owned by the caller. When ``hps.if_cache_data_in_gpu`` is
            truthy it is filled with ``(batch_idx, batch)`` pairs on first use
            and shuffled in place on later calls.

    Side effects:
        Increments the module-level ``global_step`` once per batch. On rank 0,
        once ``epoch >= hps.total_epoch``, saves the final weights and
        terminates the process via ``os._exit``.
    """
    global global_step

    net_g, net_d = nets
    optim_g, optim_d = optims
    train_loader, _ = loaders
    writer = None
    if writers is not None:
        writer, _ = writers

    train_loader.batch_sampler.set_epoch(epoch)

    net_g.train()
    net_d.train()

    # One decision drives both "build/use the cache" and "move each batch in
    # the loop", so a batch is always handled by exactly one of the two paths.
    use_cache = bool(hps.if_cache_data_in_gpu)
    cuda_available = torch.cuda.is_available()

    # Prepare data iterator
    if use_cache:
        if not cache:
            # First use: materialise every batch once (on the GPU if possible).
            for batch_idx, info in enumerate(train_loader):
                _unpack_batch(info, hps.if_f0)  # fail early on a malformed batch
                if cuda_available:
                    info = _batch_to_cuda(info, rank, hps.if_f0)
                else:
                    info = tuple(info)
                cache.append((batch_idx, info))
        else:
            # Later epochs: reuse the cached batches in a new order.
            shuffle(cache)
        data_iterator = cache
    else:
        data_iterator = enumerate(train_loader)

    # Run steps
    epoch_recorder = EpochRecorder()
    for batch_idx, info in data_iterator:
        # Data
        (
            phone,
            phone_lengths,
            pitch,
            pitchf,
            spec,
            spec_lengths,
            wave,
            _wave_lengths,
            sid,
        ) = _unpack_batch(info, hps.if_f0)
        if not use_cache and cuda_available:
            # wave_lengths is not needed below, so it is not transferred.
            (
                phone,
                phone_lengths,
                pitch,
                pitchf,
                spec,
                spec_lengths,
                wave,
                _wave_lengths,
                sid,
            ) = _unpack_batch(
                _batch_to_cuda(info, rank, hps.if_f0, move_wave_lengths=False),
                hps.if_f0,
            )

        # Calculate
        with autocast(enabled=hps.train.fp16_run):
            if hps.if_f0 == 1:
                generator_out = net_g(
                    phone, phone_lengths, pitch, pitchf, spec, spec_lengths, sid
                )
            else:
                generator_out = net_g(phone, phone_lengths, spec, spec_lengths, sid)
            (
                y_hat,
                ids_slice,
                _x_mask,
                z_mask,
                (_z, z_p, m_p, logs_p, _m_q, logs_q),
            ) = generator_out
            mel = spec_to_mel_torch(
                spec,
                hps.data.filter_length,
                hps.data.n_mel_channels,
                hps.data.sampling_rate,
                hps.data.mel_fmin,
                hps.data.mel_fmax,
            )
            y_mel = commons.slice_segments(
                mel, ids_slice, hps.train.segment_size // hps.data.hop_length
            )
            with autocast(enabled=False):
                y_hat_mel = mel_spectrogram_torch(
                    y_hat.float().squeeze(1),
                    hps.data.filter_length,
                    hps.data.n_mel_channels,
                    hps.data.sampling_rate,
                    hps.data.hop_length,
                    hps.data.win_length,
                    hps.data.mel_fmin,
                    hps.data.mel_fmax,
                )
            if hps.train.fp16_run:
                y_hat_mel = y_hat_mel.half()
            # Cut the reference waveform to the same segment the generator produced.
            wave = commons.slice_segments(
                wave, ids_slice * hps.data.hop_length, hps.train.segment_size
            )

            # Discriminator (generator output detached so only net_d gets gradients)
            y_d_hat_r, y_d_hat_g, _, _ = net_d(wave, y_hat.detach())
            with autocast(enabled=False):
                loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
                    y_d_hat_r, y_d_hat_g
                )
        optim_d.zero_grad()
        scaler.scale(loss_disc).backward()
        scaler.unscale_(optim_d)
        grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
        scaler.step(optim_d)

        with autocast(enabled=hps.train.fp16_run):
            # Generator
            y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(wave, y_hat)
            with autocast(enabled=False):
                loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
                loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
                loss_fm = feature_loss(fmap_r, fmap_g)
                loss_gen, losses_gen = generator_loss(y_d_hat_g)
                loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl
        optim_g.zero_grad()
        scaler.scale(loss_gen_all).backward()
        scaler.unscale_(optim_g)
        grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
        scaler.step(optim_g)
        # The scaler is updated once per iteration, after both optimizers stepped.
        scaler.update()

        if rank == 0 and global_step % hps.train.log_interval == 0:
            lr = optim_g.param_groups[0]["lr"]
            logger.info(
                "Train Epoch: {} [{:.0f}%]".format(
                    epoch, 100.0 * batch_idx / len(train_loader)
                )
            )
            # Clamp only the values that are displayed/logged; the loss
            # tensors used for optimisation are left untouched.
            mel_shown = 75 if loss_mel > 75 else loss_mel
            kl_shown = 9 if loss_kl > 9 else loss_kl

            logger.info([global_step, lr])
            logger.info(
                f"loss_disc={loss_disc:.3f}, loss_gen={loss_gen:.3f}, loss_fm={loss_fm:.3f},loss_mel={mel_shown:.3f}, loss_kl={kl_shown:.3f}"
            )
            scalar_dict = {
                "loss/g/total": loss_gen_all,
                "loss/d/total": loss_disc,
                "learning_rate": lr,
                "grad_norm_d": grad_norm_d,
                "grad_norm_g": grad_norm_g,
                "loss/g/fm": loss_fm,
                "loss/g/mel": mel_shown,
                "loss/g/kl": kl_shown,
            }
            scalar_dict.update(
                {"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)}
            )
            scalar_dict.update(
                {"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)}
            )
            scalar_dict.update(
                {"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)}
            )
            image_dict = {
                "slice/mel_org": utils.plot_spectrogram_to_numpy(
                    y_mel[0].data.cpu().numpy()
                ),
                "slice/mel_gen": utils.plot_spectrogram_to_numpy(
                    y_hat_mel[0].data.cpu().numpy()
                ),
                "all/mel": utils.plot_spectrogram_to_numpy(
                    mel[0].data.cpu().numpy()
                ),
            }
            utils.summarize(
                writer=writer,
                global_step=global_step,
                images=image_dict,
                scalars=scalar_dict,
            )
        global_step += 1
    # /Run steps

    if epoch % hps.save_every_epoch == 0 and rank == 0:
        # With if_latest == 0 every save gets its own step-numbered file;
        # otherwise the same fixed-name pair is overwritten each time.
        checkpoint_tag = global_step if hps.if_latest == 0 else 2333333
        utils.save_checkpoint(
            net_g,
            optim_g,
            hps.train.learning_rate,
            epoch,
            os.path.join(hps.model_dir, "G_{}.pth".format(checkpoint_tag)),
        )
        utils.save_checkpoint(
            net_d,
            optim_d,
            hps.train.learning_rate,
            epoch,
            os.path.join(hps.model_dir, "D_{}.pth".format(checkpoint_tag)),
        )
        if hps.save_every_weights == "1":
            ckpt = _generator_state_dict(net_g)
            logger.info(
                "saving ckpt %s_e%s:%s"
                % (
                    hps.name,
                    epoch,
                    savee(
                        ckpt,
                        hps.sample_rate,
                        hps.if_f0,
                        hps.name + "_e%s_s%s" % (epoch, global_step),
                        epoch,
                        hps.version,
                        hps,
                    ),
                )
            )

    if rank == 0:
        logger.info("====> Epoch: {} {}".format(epoch, epoch_recorder.record()))
        if epoch >= hps.total_epoch:
            logger.info("Training is done. The program is closed.")

            ckpt = _generator_state_dict(net_g)
            logger.info(
                "saving final ckpt:%s"
                % (
                    savee(
                        ckpt, hps.sample_rate, hps.if_f0, hps.name, epoch, hps.version, hps
                    )
                )
            )
            sleep(1)
            # Hard exit without normal interpreter cleanup.
            # NOTE(review): the status value presumably acts as a sentinel for
            # the launching process -- confirm before changing it.
            os._exit(2333333)
636
-
637
-
638
if __name__ == "__main__":
    # Select the "spawn" start method before any worker process is created.
    # NOTE(review): presumably required because the workers use CUDA, which
    # PyTorch documents as incompatible with forked children -- confirm.
    torch.multiprocessing.set_start_method(method="spawn")
    main()