Spaces:
Running
Running
Delete s2_train.py
Browse files- s2_train.py +0 -566
s2_train.py
DELETED
@@ -1,566 +0,0 @@
|
|
1 |
-
# Stage-2 (SoVITS / GAN) training script: module-level bootstrap.
# NOTE: hparams must be parsed BEFORE importing torch so that
# CUDA_VISIBLE_DEVICES takes effect for device enumeration.
import utils, os

hps = utils.get_hparams(stage=2)
os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torch.multiprocessing as mp
import torch.distributed as dist, traceback
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
import logging, traceback  # NOTE(review): traceback is imported twice (also above)

# Silence noisy third-party debug logging.
logging.getLogger("matplotlib").setLevel(logging.INFO)
logging.getLogger("h5py").setLevel(logging.INFO)
logging.getLogger("numba").setLevel(logging.INFO)
from random import randint
from module import commons

from module.data_utils import (
    TextAudioSpeakerLoader,
    TextAudioSpeakerCollate,
    DistributedBucketSampler,
)
from module.models import (
    SynthesizerTrn,
    MultiPeriodDiscriminator,
)
from module.losses import generator_loss, discriminator_loss, feature_loss, kl_loss
from module.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
from process_ckpt import savee

torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = False
### fp32 is faster anyway on the A100, so let's try tf32
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision("medium")  # lowest precision but fastest (only marginally); does not affect results
# from config import pretrained_s2G,pretrained_s2D
# Shared step counter, updated by train_and_evaluate() across epochs.
global_step = 0
def main():
    """Assume Single Node Multi GPUs Training Only"""
    assert torch.cuda.is_available(), "CPU training is not allowed."

    # One worker process per visible GPU.
    device_count = torch.cuda.device_count()

    # Rendezvous endpoint for the DDP process group; a random high port
    # avoids collisions between concurrent training runs on one host.
    os.environ.update(
        {
            "MASTER_ADDR": "localhost",
            "MASTER_PORT": str(randint(20000, 55555)),
        }
    )

    mp.spawn(run, nprocs=device_count, args=(device_count, hps))
def run(rank, n_gpus, hps):
    """Per-process training entry point (one process per GPU, via mp.spawn).

    Builds the data pipeline, generator/discriminator, optimizers and LR
    schedulers, resumes from the latest checkpoints when possible (falling
    back to the configured pretrained weights), then runs the epoch loop.

    Args:
        rank: index of this process / its GPU.
        n_gpus: total number of processes in the DDP group.
        hps: hyper-parameter namespace from utils.get_hparams(stage=2).
    """
    global global_step
    if rank == 0:
        logger = utils.get_logger(hps.data.exp_dir)
        logger.info(hps)
        # utils.check_git_hash(hps.s2_ckpt_dir)
        writer = SummaryWriter(log_dir=hps.s2_ckpt_dir)
        writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))

    dist.init_process_group(
        backend="gloo" if os.name == "nt" else "nccl",  # Windows has no NCCL
        init_method="env://",
        world_size=n_gpus,
        rank=rank,
    )
    torch.manual_seed(hps.train.seed)
    torch.cuda.set_device(rank)

    train_dataset = TextAudioSpeakerLoader(hps.data)
    # Bucket boundaries (in frames): batch together samples of similar length
    # to minimise padding waste.
    train_sampler = DistributedBucketSampler(
        train_dataset,
        hps.train.batch_size,
        [
            32, 300, 400, 500, 600, 700, 800, 900, 1000,
            1100, 1200, 1300, 1400, 1500, 1600, 1700, 1800, 1900,
        ],
        num_replicas=n_gpus,
        rank=rank,
        shuffle=True,
    )
    collate_fn = TextAudioSpeakerCollate()
    train_loader = DataLoader(
        train_dataset,
        num_workers=6,
        shuffle=False,  # ordering is owned by the bucket sampler
        pin_memory=True,
        collate_fn=collate_fn,
        batch_sampler=train_sampler,
        persistent_workers=True,
        prefetch_factor=16,
    )
    # if rank == 0:
    #     eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data, val=True)
    #     eval_loader = DataLoader(eval_dataset, num_workers=0, shuffle=False,
    #                              batch_size=1, pin_memory=True,
    #                              drop_last=False, collate_fn=collate_fn)

    net_g = SynthesizerTrn(
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model,
    ).cuda(rank)

    net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank)
    for name, param in net_g.named_parameters():
        if not param.requires_grad:
            print(name, "not requires_grad")

    # Text-related submodules train with a reduced LR (text_low_lr_rate);
    # collect their parameter ids so the base group can exclude them.
    te_p = list(map(id, net_g.enc_p.text_embedding.parameters()))
    et_p = list(map(id, net_g.enc_p.encoder_text.parameters()))
    mrte_p = list(map(id, net_g.enc_p.mrte.parameters()))
    base_params = filter(
        lambda p: id(p) not in te_p + et_p + mrte_p and p.requires_grad,
        net_g.parameters(),
    )

    optim_g = torch.optim.AdamW(
        # filter(lambda p: p.requires_grad, net_g.parameters()),  ### default: same lr for every layer
        [
            {"params": base_params, "lr": hps.train.learning_rate},
            {
                "params": net_g.enc_p.text_embedding.parameters(),
                "lr": hps.train.learning_rate * hps.train.text_low_lr_rate,
            },
            {
                "params": net_g.enc_p.encoder_text.parameters(),
                "lr": hps.train.learning_rate * hps.train.text_low_lr_rate,
            },
            {
                "params": net_g.enc_p.mrte.parameters(),
                "lr": hps.train.learning_rate * hps.train.text_low_lr_rate,
            },
        ],
        hps.train.learning_rate,
        betas=hps.train.betas,
        eps=hps.train.eps,
    )
    optim_d = torch.optim.AdamW(
        net_d.parameters(),
        hps.train.learning_rate,
        betas=hps.train.betas,
        eps=hps.train.eps,
    )
    net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
    net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)

    try:
        # Auto-resume from the latest checkpoints in the experiment dir.
        # FIX: was a bare `except:`, which also swallowed KeyboardInterrupt
        # and SystemExit during resume; narrowed to Exception below.
        _, _, _, epoch_str = utils.load_checkpoint(
            utils.latest_checkpoint_path("%s/logs_s2" % hps.data.exp_dir, "D_*.pth"),
            net_d,
            optim_d,
        )  # loading D usually succeeds
        if rank == 0:
            logger.info("loaded D")
        # _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g, load_opt=0)
        _, _, _, epoch_str = utils.load_checkpoint(
            utils.latest_checkpoint_path("%s/logs_s2" % hps.data.exp_dir, "G_*.pth"),
            net_g,
            optim_g,
        )
        global_step = (epoch_str - 1) * len(train_loader)
    except Exception:
        # First run: nothing to resume — start fresh and load the configured
        # pretrained weights instead.
        # traceback.print_exc()
        epoch_str = 1
        global_step = 0
        if hps.train.pretrained_s2G != "":
            if rank == 0:
                logger.info("loaded pretrained %s" % hps.train.pretrained_s2G)
            print(
                net_g.module.load_state_dict(
                    torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
                    strict=False,
                )
            )  ## optimizer state intentionally not loaded
        if hps.train.pretrained_s2D != "":
            if rank == 0:
                logger.info("loaded pretrained %s" % hps.train.pretrained_s2D)
            print(
                net_d.module.load_state_dict(
                    torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"]
                )
            )

    # scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
    # scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)

    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
        optim_g, gamma=hps.train.lr_decay, last_epoch=-1
    )
    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
        optim_d, gamma=hps.train.lr_decay, last_epoch=-1
    )
    # Fast-forward the schedulers to the resumed epoch.
    for _ in range(epoch_str):
        scheduler_g.step()
        scheduler_d.step()

    scaler = GradScaler(enabled=hps.train.fp16_run)

    for epoch in range(epoch_str, hps.train.epochs + 1):
        if rank == 0:
            train_and_evaluate(
                rank,
                epoch,
                hps,
                [net_g, net_d],
                [optim_g, optim_d],
                [scheduler_g, scheduler_d],
                scaler,
                # [train_loader, eval_loader], logger, [writer, writer_eval])
                [train_loader, None],  # evaluation loader is disabled
                logger,
                [writer, writer_eval],
            )
        else:
            # Non-zero ranks train without logging/summaries.
            train_and_evaluate(
                rank,
                epoch,
                hps,
                [net_g, net_d],
                [optim_g, optim_d],
                [scheduler_g, scheduler_d],
                scaler,
                [train_loader, None],
                None,
                None,
            )
        scheduler_g.step()
        scheduler_d.step()
def train_and_evaluate(
    rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers
):
    """Run one training epoch of the GAN (generator + discriminator).

    Alternates a discriminator step and a generator step per batch under
    AMP autocast, logs scalars/images on rank 0 at log_interval, and saves
    checkpoints/weights at the end of qualifying epochs.

    Args:
        rank: this process's GPU index.
        epoch: 1-based epoch number.
        hps: hyper-parameter namespace.
        nets: [net_g, net_d] DDP-wrapped models.
        optims: [optim_g, optim_d].
        schedulers: [scheduler_g, scheduler_d] (unused here; stepped by caller).
        scaler: shared GradScaler for fp16 training.
        loaders: [train_loader, eval_loader] (eval_loader is None in this setup).
        logger: rank-0 logger, or None on other ranks.
        writers: [writer, writer_eval] on rank 0, else None.
    """
    net_g, net_d = nets
    optim_g, optim_d = optims
    # scheduler_g, scheduler_d = schedulers
    train_loader, eval_loader = loaders
    if writers is not None:
        writer, writer_eval = writers

    # Re-shuffle the bucket sampler deterministically per epoch.
    train_loader.batch_sampler.set_epoch(epoch)
    global global_step

    net_g.train()
    net_d.train()
    for batch_idx, (
        ssl,
        ssl_lengths,
        spec,
        spec_lengths,
        y,
        y_lengths,
        text,
        text_lengths,
    ) in tqdm(enumerate(train_loader)):
        # Move the batch to this rank's GPU.
        spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(
            rank, non_blocking=True
        )
        y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(
            rank, non_blocking=True
        )
        ssl = ssl.cuda(rank, non_blocking=True)
        ssl.requires_grad = False  # SSL features are fixed inputs, not trained
        # ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
        text, text_lengths = text.cuda(rank, non_blocking=True), text_lengths.cuda(
            rank, non_blocking=True
        )

        with autocast(enabled=hps.train.fp16_run):
            (
                y_hat,
                kl_ssl,
                ids_slice,
                x_mask,
                z_mask,
                (z, z_p, m_p, logs_p, m_q, logs_q),
                stats_ssl,
            ) = net_g(ssl, spec, spec_lengths, text, text_lengths)

            # Ground-truth mel from the linear spectrogram, sliced to the
            # same random segments the generator synthesised.
            mel = spec_to_mel_torch(
                spec,
                hps.data.filter_length,
                hps.data.n_mel_channels,
                hps.data.sampling_rate,
                hps.data.mel_fmin,
                hps.data.mel_fmax,
            )
            y_mel = commons.slice_segments(
                mel, ids_slice, hps.train.segment_size // hps.data.hop_length
            )
            y_hat_mel = mel_spectrogram_torch(
                y_hat.squeeze(1),
                hps.data.filter_length,
                hps.data.n_mel_channels,
                hps.data.sampling_rate,
                hps.data.hop_length,
                hps.data.win_length,
                hps.data.mel_fmin,
                hps.data.mel_fmax,
            )

            # Slice the waveform to the matching audio segment.
            y = commons.slice_segments(
                y, ids_slice * hps.data.hop_length, hps.train.segment_size
            )  # slice

            # Discriminator step: y_hat is detached so generator grads
            # don't flow through this pass.
            y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
            with autocast(enabled=False):
                # Losses are computed in fp32 for numerical stability.
                loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
                    y_d_hat_r, y_d_hat_g
                )
                loss_disc_all = loss_disc
        optim_d.zero_grad()
        scaler.scale(loss_disc_all).backward()
        scaler.unscale_(optim_d)  # unscale before clipping so the norm is real
        grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
        scaler.step(optim_d)

        with autocast(enabled=hps.train.fp16_run):
            # Generator step (second discriminator pass, this time with grads).
            y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
            with autocast(enabled=False):
                loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
                loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl

                loss_fm = feature_loss(fmap_r, fmap_g)
                loss_gen, losses_gen = generator_loss(y_d_hat_g)
                loss_gen_all = loss_gen + loss_fm + loss_mel + kl_ssl * 1 + loss_kl

        optim_g.zero_grad()
        scaler.scale(loss_gen_all).backward()
        scaler.unscale_(optim_g)
        grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
        scaler.step(optim_g)
        scaler.update()  # one update per iteration, after both optimizer steps

        if rank == 0:
            if global_step % hps.train.log_interval == 0:
                lr = optim_g.param_groups[0]["lr"]
                losses = [loss_disc, loss_gen, loss_fm, loss_mel, kl_ssl, loss_kl]
                logger.info(
                    "Train Epoch: {} [{:.0f}%]".format(
                        epoch, 100.0 * batch_idx / len(train_loader)
                    )
                )
                logger.info([x.item() for x in losses] + [global_step, lr])

                scalar_dict = {
                    "loss/g/total": loss_gen_all,
                    "loss/d/total": loss_disc_all,
                    "learning_rate": lr,
                    "grad_norm_d": grad_norm_d,
                    "grad_norm_g": grad_norm_g,
                }
                scalar_dict.update(
                    {
                        "loss/g/fm": loss_fm,
                        "loss/g/mel": loss_mel,
                        "loss/g/kl_ssl": kl_ssl,
                        "loss/g/kl": loss_kl,
                    }
                )

                # scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)})
                # scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)})
                # scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)})
                image_dict = {
                    "slice/mel_org": utils.plot_spectrogram_to_numpy(
                        y_mel[0].data.cpu().numpy()
                    ),
                    "slice/mel_gen": utils.plot_spectrogram_to_numpy(
                        y_hat_mel[0].data.cpu().numpy()
                    ),
                    "all/mel": utils.plot_spectrogram_to_numpy(
                        mel[0].data.cpu().numpy()
                    ),
                    "all/stats_ssl": utils.plot_spectrogram_to_numpy(
                        stats_ssl[0].data.cpu().numpy()
                    ),
                }
                utils.summarize(
                    writer=writer,
                    global_step=global_step,
                    images=image_dict,
                    scalars=scalar_dict,
                )
        global_step += 1
    # End-of-epoch checkpointing (rank 0 only).
    if epoch % hps.train.save_every_epoch == 0 and rank == 0:
        if hps.train.if_save_latest == 0:
            # Keep a full history: one checkpoint pair per global_step.
            utils.save_checkpoint(
                net_g,
                optim_g,
                hps.train.learning_rate,
                epoch,
                os.path.join(
                    "%s/logs_s2" % hps.data.exp_dir, "G_{}.pth".format(global_step)
                ),
            )
            utils.save_checkpoint(
                net_d,
                optim_d,
                hps.train.learning_rate,
                epoch,
                os.path.join(
                    "%s/logs_s2" % hps.data.exp_dir, "D_{}.pth".format(global_step)
                ),
            )
        else:
            # "Latest only" mode: reuse a fixed sentinel filename so each
            # save overwrites the previous checkpoint.
            utils.save_checkpoint(
                net_g,
                optim_g,
                hps.train.learning_rate,
                epoch,
                os.path.join(
                    "%s/logs_s2" % hps.data.exp_dir, "G_{}.pth".format(233333333333)
                ),
            )
            utils.save_checkpoint(
                net_d,
                optim_d,
                hps.train.learning_rate,
                epoch,
                os.path.join(
                    "%s/logs_s2" % hps.data.exp_dir, "D_{}.pth".format(233333333333)
                ),
            )
        if rank == 0 and hps.train.if_save_every_weights == True:
            # Also export inference-ready generator weights via savee().
            if hasattr(net_g, "module"):
                ckpt = net_g.module.state_dict()
            else:
                ckpt = net_g.state_dict()
            logger.info(
                "saving ckpt %s_e%s:%s"
                % (
                    hps.name,
                    epoch,
                    savee(
                        ckpt,
                        hps.name + "_e%s_s%s" % (epoch, global_step),
                        epoch,
                        global_step,
                        hps,
                    ),
                )
            )

    if rank == 0:
        logger.info("====> Epoch: {}".format(epoch))
def evaluate(hps, generator, eval_loader, writer_eval):
    """Run inference on the eval set and log mels/audio to TensorBoard.

    Not called from the current training loop (eval_loader is passed as
    None by run()); kept for manual/offline evaluation.

    Args:
        hps: hyper-parameter namespace.
        generator: DDP-wrapped SynthesizerTrn (accessed via .module.infer).
        eval_loader: DataLoader yielding the same tuple layout as training.
        writer_eval: SummaryWriter for the eval run.
    """
    generator.eval()
    image_dict = {}
    audio_dict = {}
    print("Evaluating ...")
    with torch.no_grad():
        for batch_idx, (
            ssl,
            ssl_lengths,
            spec,
            spec_lengths,
            y,
            y_lengths,
            text,
            text_lengths,
        ) in enumerate(eval_loader):
            print(111)
            spec, spec_lengths = spec.cuda(), spec_lengths.cuda()
            y, y_lengths = y.cuda(), y_lengths.cuda()
            ssl = ssl.cuda()
            text, text_lengths = text.cuda(), text_lengths.cuda()
            # Two inference variants per sample; `test` is forwarded to
            # infer() — presumably toggles an inference mode, TODO confirm.
            for test in [0, 1]:
                y_hat, mask, *_ = generator.module.infer(
                    ssl, spec, spec_lengths, text, text_lengths, test=test
                )
                # Mask sum gives frame count; scale by hop to get samples.
                y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length

                mel = spec_to_mel_torch(
                    spec,
                    hps.data.filter_length,
                    hps.data.n_mel_channels,
                    hps.data.sampling_rate,
                    hps.data.mel_fmin,
                    hps.data.mel_fmax,
                )
                y_hat_mel = mel_spectrogram_torch(
                    y_hat.squeeze(1).float(),
                    hps.data.filter_length,
                    hps.data.n_mel_channels,
                    hps.data.sampling_rate,
                    hps.data.hop_length,
                    hps.data.win_length,
                    hps.data.mel_fmin,
                    hps.data.mel_fmax,
                )
                image_dict.update(
                    {
                        f"gen/mel_{batch_idx}_{test}": utils.plot_spectrogram_to_numpy(
                            y_hat_mel[0].cpu().numpy()
                        )
                    }
                )
                audio_dict.update(
                    {f"gen/audio_{batch_idx}_{test}": y_hat[0, :, : y_hat_lengths[0]]}
                )
            # Ground-truth mel/audio logged once per batch.
            image_dict.update(
                {
                    f"gt/mel_{batch_idx}": utils.plot_spectrogram_to_numpy(
                        mel[0].cpu().numpy()
                    )
                }
            )
            audio_dict.update({f"gt/audio_{batch_idx}": y[0, :, : y_lengths[0]]})

        # y_hat, mask, *_ = generator.module.infer(ssl, spec_lengths, speakers, y=None)
        # audio_dict.update({
        #     f"gen/audio_{batch_idx}_style_pred": y_hat[0, :, :]
        # })

    utils.summarize(
        writer=writer_eval,
        global_step=global_step,
        images=image_dict,
        audios=audio_dict,
        audio_sampling_rate=hps.data.sampling_rate,
    )
    generator.train()
# Script entry point: spawn one training process per visible GPU.
if __name__ == "__main__":
    main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|