XzJosh committed on
Commit
902b99d
1 Parent(s): 3762e4f

Delete s2_train.py

Browse files
Files changed (1) hide show
  1. s2_train.py +0 -566
s2_train.py DELETED
@@ -1,566 +0,0 @@
1
- import utils, os
2
-
3
- hps = utils.get_hparams(stage=2)
4
- os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
5
- import torch
6
- from torch.nn import functional as F
7
- from torch.utils.data import DataLoader
8
- from torch.utils.tensorboard import SummaryWriter
9
- import torch.multiprocessing as mp
10
- import torch.distributed as dist, traceback
11
- from torch.nn.parallel import DistributedDataParallel as DDP
12
- from torch.cuda.amp import autocast, GradScaler
13
- from tqdm import tqdm
14
- import logging, traceback
15
-
16
- logging.getLogger("matplotlib").setLevel(logging.INFO)
17
- logging.getLogger("h5py").setLevel(logging.INFO)
18
- logging.getLogger("numba").setLevel(logging.INFO)
19
- from random import randint
20
- from module import commons
21
-
22
- from module.data_utils import (
23
- TextAudioSpeakerLoader,
24
- TextAudioSpeakerCollate,
25
- DistributedBucketSampler,
26
- )
27
- from module.models import (
28
- SynthesizerTrn,
29
- MultiPeriodDiscriminator,
30
- )
31
- from module.losses import generator_loss, discriminator_loss, feature_loss, kl_loss
32
- from module.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
33
- from process_ckpt import savee
34
-
35
- torch.backends.cudnn.benchmark = False
36
- torch.backends.cudnn.deterministic = False
37
- ###反正A100fp32更快,那试试tf32吧
38
- torch.backends.cuda.matmul.allow_tf32 = True
39
- torch.backends.cudnn.allow_tf32 = True
40
- torch.set_float32_matmul_precision("medium") # 最低精度但最快(也就快一丁点),对于结果造成不了影响
41
- # from config import pretrained_s2G,pretrained_s2D
42
- global_step = 0
43
-
44
-
45
def main():
    """Launch single-node, multi-GPU training: one worker process per GPU."""
    assert torch.cuda.is_available(), "CPU training is not allowed."

    n_gpus = torch.cuda.device_count()
    # torch.distributed rendezvous endpoint; a random port avoids collisions
    # with other training runs on the same machine.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(randint(20000, 55555))

    mp.spawn(run, nprocs=n_gpus, args=(n_gpus, hps))
61
-
62
-
63
def run(rank, n_gpus, hps):
    """Per-process training entry point, spawned once per GPU by ``main()``.

    Builds the data pipeline, generator/discriminator, optimizers and LR
    schedulers, resumes from the latest checkpoint (or loads pretrained
    weights on a first run), then drives the epoch loop.
    """
    global global_step
    if rank == 0:
        logger = utils.get_logger(hps.data.exp_dir)
        logger.info(hps)
        writer = SummaryWriter(log_dir=hps.s2_ckpt_dir)
        writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))

    # NCCL is unavailable on Windows; fall back to gloo there.
    dist.init_process_group(
        backend="gloo" if os.name == "nt" else "nccl",
        init_method="env://",
        world_size=n_gpus,
        rank=rank,
    )
    torch.manual_seed(hps.train.seed)
    torch.cuda.set_device(rank)

    train_dataset = TextAudioSpeakerLoader(hps.data)
    # Bucket boundaries (in frames): batching samples of similar length
    # minimises padding waste.
    train_sampler = DistributedBucketSampler(
        train_dataset,
        hps.train.batch_size,
        [
            32,
            300,
            400,
            500,
            600,
            700,
            800,
            900,
            1000,
            1100,
            1200,
            1300,
            1400,
            1500,
            1600,
            1700,
            1800,
            1900,
        ],
        num_replicas=n_gpus,
        rank=rank,
        shuffle=True,
    )
    collate_fn = TextAudioSpeakerCollate()
    train_loader = DataLoader(
        train_dataset,
        num_workers=6,
        shuffle=False,  # shuffling is handled by the bucket sampler
        pin_memory=True,
        collate_fn=collate_fn,
        batch_sampler=train_sampler,
        persistent_workers=True,
        prefetch_factor=16,
    )

    net_g = SynthesizerTrn(
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model,
    ).cuda(rank)

    net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank)
    for name, param in net_g.named_parameters():
        if not param.requires_grad:
            print(name, "not requires_grad")

    # Text-related submodules train with a reduced learning rate; collect
    # their parameter ids so the remaining ("base") parameters can be
    # separated into their own optimizer group.
    te_p = list(map(id, net_g.enc_p.text_embedding.parameters()))
    et_p = list(map(id, net_g.enc_p.encoder_text.parameters()))
    mrte_p = list(map(id, net_g.enc_p.mrte.parameters()))
    base_params = filter(
        lambda p: id(p) not in te_p + et_p + mrte_p and p.requires_grad,
        net_g.parameters(),
    )

    optim_g = torch.optim.AdamW(
        [
            {"params": base_params, "lr": hps.train.learning_rate},
            {
                "params": net_g.enc_p.text_embedding.parameters(),
                "lr": hps.train.learning_rate * hps.train.text_low_lr_rate,
            },
            {
                "params": net_g.enc_p.encoder_text.parameters(),
                "lr": hps.train.learning_rate * hps.train.text_low_lr_rate,
            },
            {
                "params": net_g.enc_p.mrte.parameters(),
                "lr": hps.train.learning_rate * hps.train.text_low_lr_rate,
            },
        ],
        hps.train.learning_rate,
        betas=hps.train.betas,
        eps=hps.train.eps,
    )
    optim_d = torch.optim.AdamW(
        net_d.parameters(),
        hps.train.learning_rate,
        betas=hps.train.betas,
        eps=hps.train.eps,
    )
    net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
    net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)

    try:
        # Auto-resume from the most recent D/G checkpoint pair if present.
        _, _, _, epoch_str = utils.load_checkpoint(
            utils.latest_checkpoint_path("%s/logs_s2" % hps.data.exp_dir, "D_*.pth"),
            net_d,
            optim_d,
        )
        if rank == 0:
            logger.info("loaded D")
        _, _, _, epoch_str = utils.load_checkpoint(
            utils.latest_checkpoint_path("%s/logs_s2" % hps.data.exp_dir, "G_*.pth"),
            net_g,
            optim_g,
        )
        global_step = (epoch_str - 1) * len(train_loader)
    except Exception:
        # First run (no checkpoint yet): start fresh and load the pretrained
        # weights if configured. Narrowed from a bare `except:` so that
        # SystemExit/KeyboardInterrupt still propagate.
        epoch_str = 1
        global_step = 0
        if hps.train.pretrained_s2G != "":
            if rank == 0:
                logger.info("loaded pretrained %s" % hps.train.pretrained_s2G)
            # strict=False: the pretrained generator may not match every key;
            # the optimizer state is deliberately not restored here.
            print(
                net_g.module.load_state_dict(
                    torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
                    strict=False,
                )
            )
        if hps.train.pretrained_s2D != "":
            if rank == 0:
                logger.info("loaded pretrained %s" % hps.train.pretrained_s2D)
            print(
                net_d.module.load_state_dict(
                    torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"]
                )
            )

    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
        optim_g, gamma=hps.train.lr_decay, last_epoch=-1
    )
    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
        optim_d, gamma=hps.train.lr_decay, last_epoch=-1
    )
    # Fast-forward the schedulers to the resume epoch.
    for _ in range(epoch_str):
        scheduler_g.step()
        scheduler_d.step()

    scaler = GradScaler(enabled=hps.train.fp16_run)

    for epoch in range(epoch_str, hps.train.epochs + 1):
        if rank == 0:
            # Rank 0 owns logging and the TensorBoard writers.
            train_and_evaluate(
                rank,
                epoch,
                hps,
                [net_g, net_d],
                [optim_g, optim_d],
                [scheduler_g, scheduler_d],
                scaler,
                [train_loader, None],
                logger,
                [writer, writer_eval],
            )
        else:
            train_and_evaluate(
                rank,
                epoch,
                hps,
                [net_g, net_d],
                [optim_g, optim_d],
                [scheduler_g, scheduler_d],
                scaler,
                [train_loader, None],
                None,
                None,
            )
        scheduler_g.step()
        scheduler_d.step()
264
-
265
-
266
def train_and_evaluate(
    rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers
):
    """Run one GAN training epoch over ``train_loader``.

    Rank 0 additionally writes TensorBoard scalars/images every
    ``hps.train.log_interval`` steps and saves checkpoints every
    ``hps.train.save_every_epoch`` epochs.
    """
    net_g, net_d = nets
    optim_g, optim_d = optims
    train_loader, eval_loader = loaders
    if writers is not None:
        writer, writer_eval = writers

    # Re-seed the bucket sampler so each epoch gets a different shuffle.
    train_loader.batch_sampler.set_epoch(epoch)
    global global_step

    net_g.train()
    net_d.train()
    for batch_idx, (
        ssl,
        ssl_lengths,
        spec,
        spec_lengths,
        y,
        y_lengths,
        text,
        text_lengths,
    ) in tqdm(enumerate(train_loader)):
        # Move the batch to this rank's GPU.
        spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(
            rank, non_blocking=True
        )
        y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(
            rank, non_blocking=True
        )
        ssl = ssl.cuda(rank, non_blocking=True)
        ssl.requires_grad = False  # SSL features are fixed inputs, never trained
        text, text_lengths = text.cuda(rank, non_blocking=True), text_lengths.cuda(
            rank, non_blocking=True
        )

        with autocast(enabled=hps.train.fp16_run):
            (
                y_hat,
                kl_ssl,
                ids_slice,
                x_mask,
                z_mask,
                (z, z_p, m_p, logs_p, m_q, logs_q),
                stats_ssl,
            ) = net_g(ssl, spec, spec_lengths, text, text_lengths)

            mel = spec_to_mel_torch(
                spec,
                hps.data.filter_length,
                hps.data.n_mel_channels,
                hps.data.sampling_rate,
                hps.data.mel_fmin,
                hps.data.mel_fmax,
            )
            # Ground-truth mel restricted to the randomly sliced segment.
            y_mel = commons.slice_segments(
                mel, ids_slice, hps.train.segment_size // hps.data.hop_length
            )
            y_hat_mel = mel_spectrogram_torch(
                y_hat.squeeze(1),
                hps.data.filter_length,
                hps.data.n_mel_channels,
                hps.data.sampling_rate,
                hps.data.hop_length,
                hps.data.win_length,
                hps.data.mel_fmin,
                hps.data.mel_fmax,
            )

            # Slice the waveform to the same segment (frames -> samples).
            y = commons.slice_segments(
                y, ids_slice * hps.data.hop_length, hps.train.segment_size
            )

            # --- Discriminator pass (generator output detached) ---
            y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
            with autocast(enabled=False):
                loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
                    y_d_hat_r, y_d_hat_g
                )
                loss_disc_all = loss_disc
        optim_d.zero_grad()
        scaler.scale(loss_disc_all).backward()
        scaler.unscale_(optim_d)
        grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
        scaler.step(optim_d)

        with autocast(enabled=hps.train.fp16_run):
            # --- Generator pass ---
            y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
            with autocast(enabled=False):
                loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
                loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl

                loss_fm = feature_loss(fmap_r, fmap_g)
                loss_gen, losses_gen = generator_loss(y_d_hat_g)
                loss_gen_all = loss_gen + loss_fm + loss_mel + kl_ssl * 1 + loss_kl

        optim_g.zero_grad()
        scaler.scale(loss_gen_all).backward()
        scaler.unscale_(optim_g)
        grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
        scaler.step(optim_g)
        scaler.update()

        if rank == 0:
            if global_step % hps.train.log_interval == 0:
                lr = optim_g.param_groups[0]["lr"]
                losses = [loss_disc, loss_gen, loss_fm, loss_mel, kl_ssl, loss_kl]
                logger.info(
                    "Train Epoch: {} [{:.0f}%]".format(
                        epoch, 100.0 * batch_idx / len(train_loader)
                    )
                )
                logger.info([x.item() for x in losses] + [global_step, lr])

                scalar_dict = {
                    "loss/g/total": loss_gen_all,
                    "loss/d/total": loss_disc_all,
                    "learning_rate": lr,
                    "grad_norm_d": grad_norm_d,
                    "grad_norm_g": grad_norm_g,
                }
                scalar_dict.update(
                    {
                        "loss/g/fm": loss_fm,
                        "loss/g/mel": loss_mel,
                        "loss/g/kl_ssl": kl_ssl,
                        "loss/g/kl": loss_kl,
                    }
                )

                image_dict = {
                    "slice/mel_org": utils.plot_spectrogram_to_numpy(
                        y_mel[0].data.cpu().numpy()
                    ),
                    "slice/mel_gen": utils.plot_spectrogram_to_numpy(
                        y_hat_mel[0].data.cpu().numpy()
                    ),
                    "all/mel": utils.plot_spectrogram_to_numpy(
                        mel[0].data.cpu().numpy()
                    ),
                    "all/stats_ssl": utils.plot_spectrogram_to_numpy(
                        stats_ssl[0].data.cpu().numpy()
                    ),
                }
                utils.summarize(
                    writer=writer,
                    global_step=global_step,
                    images=image_dict,
                    scalars=scalar_dict,
                )
        global_step += 1

    if epoch % hps.train.save_every_epoch == 0 and rank == 0:
        if hps.train.if_save_latest == 0:
            # Keep the full history: one G/D pair per save step.
            utils.save_checkpoint(
                net_g,
                optim_g,
                hps.train.learning_rate,
                epoch,
                os.path.join(
                    "%s/logs_s2" % hps.data.exp_dir, "G_{}.pth".format(global_step)
                ),
            )
            utils.save_checkpoint(
                net_d,
                optim_d,
                hps.train.learning_rate,
                epoch,
                os.path.join(
                    "%s/logs_s2" % hps.data.exp_dir, "D_{}.pth".format(global_step)
                ),
            )
        else:
            # "Latest only" mode: a fixed sentinel step number, so each save
            # overwrites the previous checkpoint pair.
            utils.save_checkpoint(
                net_g,
                optim_g,
                hps.train.learning_rate,
                epoch,
                os.path.join(
                    "%s/logs_s2" % hps.data.exp_dir, "G_{}.pth".format(233333333333)
                ),
            )
            utils.save_checkpoint(
                net_d,
                optim_d,
                hps.train.learning_rate,
                epoch,
                os.path.join(
                    "%s/logs_s2" % hps.data.exp_dir, "D_{}.pth".format(233333333333)
                ),
            )
        if rank == 0 and hps.train.if_save_every_weights == True:
            # Export a lightweight inference-only snapshot via savee().
            if hasattr(net_g, "module"):
                ckpt = net_g.module.state_dict()
            else:
                ckpt = net_g.state_dict()
            logger.info(
                "saving ckpt %s_e%s:%s"
                % (
                    hps.name,
                    epoch,
                    savee(
                        ckpt,
                        hps.name + "_e%s_s%s" % (epoch, global_step),
                        epoch,
                        global_step,
                        hps,
                    ),
                )
            )

    if rank == 0:
        logger.info("====> Epoch: {}".format(epoch))
484
-
485
-
486
def evaluate(hps, generator, eval_loader, writer_eval):
    """Synthesize evaluation samples and log mel images/audio to TensorBoard.

    Runs ``generator.module.infer`` in both ``test`` modes (0 and 1) for each
    batch, then writes generated and ground-truth mels/audio under
    ``gen/...`` and ``gt/...`` tags. ``generator`` is assumed to be a DDP
    wrapper (``.module`` is accessed directly).
    """
    generator.eval()
    image_dict = {}
    audio_dict = {}
    print("Evaluating ...")
    with torch.no_grad():
        for batch_idx, (
            ssl,
            ssl_lengths,
            spec,
            spec_lengths,
            y,
            y_lengths,
            text,
            text_lengths,
        ) in enumerate(eval_loader):
            # Removed a stray `print(111)` debug leftover here.
            spec, spec_lengths = spec.cuda(), spec_lengths.cuda()
            y, y_lengths = y.cuda(), y_lengths.cuda()
            ssl = ssl.cuda()
            text, text_lengths = text.cuda(), text_lengths.cuda()
            # Compare both inference test modes side by side.
            for test in [0, 1]:
                y_hat, mask, *_ = generator.module.infer(
                    ssl, spec, spec_lengths, text, text_lengths, test=test
                )
                # Valid output length in samples (frames * hop_length).
                y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length

                mel = spec_to_mel_torch(
                    spec,
                    hps.data.filter_length,
                    hps.data.n_mel_channels,
                    hps.data.sampling_rate,
                    hps.data.mel_fmin,
                    hps.data.mel_fmax,
                )
                y_hat_mel = mel_spectrogram_torch(
                    y_hat.squeeze(1).float(),
                    hps.data.filter_length,
                    hps.data.n_mel_channels,
                    hps.data.sampling_rate,
                    hps.data.hop_length,
                    hps.data.win_length,
                    hps.data.mel_fmin,
                    hps.data.mel_fmax,
                )
                image_dict.update(
                    {
                        f"gen/mel_{batch_idx}_{test}": utils.plot_spectrogram_to_numpy(
                            y_hat_mel[0].cpu().numpy()
                        )
                    }
                )
                audio_dict.update(
                    {f"gen/audio_{batch_idx}_{test}": y_hat[0, :, : y_hat_lengths[0]]}
                )
            # Ground-truth references, once per batch.
            image_dict.update(
                {
                    f"gt/mel_{batch_idx}": utils.plot_spectrogram_to_numpy(
                        mel[0].cpu().numpy()
                    )
                }
            )
            audio_dict.update({f"gt/audio_{batch_idx}": y[0, :, : y_lengths[0]]})

    utils.summarize(
        writer=writer_eval,
        global_step=global_step,
        images=image_dict,
        audios=audio_dict,
        audio_sampling_rate=hps.data.sampling_rate,
    )
    generator.train()
563
-
564
-
565
# Script entry point: launch multi-process training.
if __name__ == "__main__":
    main()