import argparse
import glob
import os
import sys
import time
import traceback
from inspect import signature

import torch
from torch.nn.parallel import DistributedDataParallel as DDP_th
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from TTS.utils.audio import AudioProcessor
from TTS.utils.console_logger import ConsoleLogger
from TTS.utils.distribute import init_distributed
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
                                     create_experiment_folder, get_git_branch,
                                     remove_experiment_folder, set_init_dict)
from TTS.utils.io import copy_model_files, load_config
from TTS.utils.radam import RAdam
from TTS.utils.tensorboard_logger import TensorboardLogger
from TTS.utils.training import setup_torch_training_env
from TTS.vocoder.datasets.gan_dataset import GANDataset
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss
from TTS.vocoder.utils.generic_utils import (plot_results, setup_discriminator,
                                             setup_generator)
from TTS.vocoder.utils.io import save_best_model, save_checkpoint

use_cuda, num_gpus = setup_torch_training_env(True, True)
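
# NOTE: the globals `c` (config), `args`, `OUT_PATH`, `c_logger` and
# `tb_logger` used by the functions below are defined in the `__main__`
# block at the bottom of this file.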


def setup_loader(ap, is_val=False, verbose=False):
    if is_val and not c.run_eval:
        loader = None
    else:
        dataset = GANDataset(ap=ap,
                             items=eval_data if is_val else train_data,
                             seq_len=c.seq_len,
                             hop_len=ap.hop_length,
                             pad_short=c.pad_short,
                             conv_pad=c.conv_pad,
                             is_training=not is_val,
                             return_segments=not is_val,
                             use_noise_augment=c.use_noise_augment,
                             use_cache=c.use_cache,
                             verbose=verbose)
        dataset.shuffle_mapping()
        sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None
        loader = DataLoader(dataset,
                            batch_size=1 if is_val else c.batch_size,
                            shuffle=num_gpus <= 1,
                            drop_last=False,
                            sampler=sampler,
                            num_workers=c.num_val_loader_workers
                            if is_val else c.num_loader_workers,
                            pin_memory=False)
    return loader


def format_data(data):
    if isinstance(data[0], list):
        # return separate (input, waveform) pairs for the generator
        # and the discriminator
        c_G, x_G = data[0]
        c_D, x_D = data[1]

        # dispatch data to GPU
        if use_cuda:
            c_G = c_G.cuda(non_blocking=True)
            x_G = x_G.cuda(non_blocking=True)
            c_D = c_D.cuda(non_blocking=True)
            x_D = x_D.cuda(non_blocking=True)

        return c_G, x_G, c_D, x_D

    # single (input, waveform) pair shared by both networks
    co, x = data
    if use_cuda:
        co = co.cuda(non_blocking=True)
        x = x.cuda(non_blocking=True)
    return co, x, None, None


def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
          scheduler_G, scheduler_D, ap, global_step, epoch):
    data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
    model_G.train()
    model_D.train()
    epoch_time = 0
    keep_avg = KeepAverage()
    if use_cuda:
        batch_n_iter = int(
            len(data_loader.dataset) / (c.batch_size * num_gpus))
    else:
        batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
    end_time = time.time()
    c_logger.print_train_start()
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # format data
        c_G, y_G, c_D, y_D = format_data(data)
        loader_time = time.time() - end_time

        global_step += 1

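        ##############################
        # GENERATOR
        ##############################

        # generator pass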
        y_hat = model_G(c_G)
        y_hat_sub = None
        y_G_sub = None
        y_hat_vis = y_hat  # keep the full-band output for visualization

        # PQMF formatting: synthesize the full band from sub-band outputs
        if y_hat.shape[1] > 1:
            y_hat_sub = y_hat
            y_hat = model_G.pqmf_synthesis(y_hat)
            y_hat_vis = y_hat
            y_G_sub = model_G.pqmf_analysis(y_G)

        scores_fake, feats_fake, feats_real = None, None, None
        if global_step > c.steps_to_start_discriminator:

            # run D with or without conditional features
            if len(signature(model_D.forward).parameters) == 2:
                D_out_fake = model_D(y_hat, c_G)
            else:
                D_out_fake = model_D(y_hat)
            D_out_real = None

            # real features are only needed for the feature-matching loss
            if c.use_feat_match_loss:
                with torch.no_grad():
                    D_out_real = model_D(y_G)

            # format D outputs
            if isinstance(D_out_fake, tuple):
                scores_fake, feats_fake = D_out_fake
                if D_out_real is None:
                    feats_real = None
                else:
                    _, feats_real = D_out_real
            else:
                scores_fake = D_out_fake

        # compute generator losses
        loss_G_dict = criterion_G(y_hat, y_G, scores_fake, feats_fake,
                                  feats_real, y_hat_sub, y_G_sub)
        loss_G = loss_G_dict['G_loss']

        # optimize the generator
        optimizer_G.zero_grad()
        loss_G.backward()
        if c.gen_clip_grad > 0:
            torch.nn.utils.clip_grad_norm_(model_G.parameters(),
                                           c.gen_clip_grad)
        optimizer_G.step()
        if scheduler_G is not None:
            scheduler_G.step()

        loss_dict = dict()
        for key, value in loss_G_dict.items():
            if isinstance(value, (int, float)):
                loss_dict[key] = value
            else:
                loss_dict[key] = value.item()

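        ##############################
        # DISCRIMINATOR
        ##############################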
        if global_step >= c.steps_to_start_discriminator:

            # discriminator pass on freshly generated audio
            with torch.no_grad():
                y_hat = model_G(c_D)

            # PQMF formatting
            if y_hat.shape[1] > 1:
                y_hat = model_G.pqmf_synthesis(y_hat)

            # run D with or without conditional features
            if len(signature(model_D.forward).parameters) == 2:
                D_out_fake = model_D(y_hat.detach(), c_D)
                D_out_real = model_D(y_D, c_D)
            else:
                D_out_fake = model_D(y_hat.detach())
                D_out_real = model_D(y_D)

            # format D outputs
            if isinstance(D_out_fake, tuple):
                scores_fake, feats_fake = D_out_fake
                if D_out_real is None:
                    scores_real, feats_real = None, None
                else:
                    scores_real, feats_real = D_out_real
            else:
                scores_fake = D_out_fake
                scores_real = D_out_real

            # compute discriminator losses
            loss_D_dict = criterion_D(scores_fake, scores_real)
            loss_D = loss_D_dict['D_loss']

            # optimize the discriminator
            optimizer_D.zero_grad()
            loss_D.backward()
            if c.disc_clip_grad > 0:
                torch.nn.utils.clip_grad_norm_(model_D.parameters(),
                                               c.disc_clip_grad)
            optimizer_D.step()
            if scheduler_D is not None:
                scheduler_D.step()

            for key, value in loss_D_dict.items():
                if isinstance(value, (int, float)):
                    loss_dict[key] = value
                else:
                    loss_dict[key] = value.item()

        step_time = time.time() - start_time
        epoch_time += step_time

        # get current learning rates
        current_lr_G = optimizer_G.param_groups[0]['lr']
        current_lr_D = optimizer_D.param_groups[0]['lr']

        # update running averages
        update_train_values = dict()
        for key, value in loss_dict.items():
            update_train_values['avg_' + key] = value
        update_train_values['avg_loader_time'] = loader_time
        update_train_values['avg_step_time'] = step_time
        keep_avg.update_values(update_train_values)

        # print training stats
        if global_step % c.print_step == 0:
            log_dict = {
                'step_time': [step_time, 2],
                'loader_time': [loader_time, 4],
                'current_lr_G': current_lr_G,
                'current_lr_D': current_lr_D
            }
            c_logger.print_train_step(batch_n_iter, num_iter, global_step,
                                      log_dict, loss_dict, keep_avg.avg_values)

        if args.rank == 0:
            # plot step stats on Tensorboard
            if global_step % 10 == 0:
                iter_stats = {
                    'lr_G': current_lr_G,
                    'lr_D': current_lr_D,
                    'step_time': step_time
                }
                iter_stats.update(loss_dict)
                tb_logger.tb_train_iter_stats(global_step, iter_stats)

            # save the model and log training figures/audios
            if global_step % c.save_step == 0:
                if c.checkpoint:
                    save_checkpoint(model_G,
                                    optimizer_G,
                                    scheduler_G,
                                    model_D,
                                    optimizer_D,
                                    scheduler_D,
                                    global_step,
                                    epoch,
                                    OUT_PATH,
                                    model_losses=loss_dict)

                # plot spectrograms of the generated audio
                figures = plot_results(y_hat_vis, y_G, ap, global_step,
                                       'train')
                tb_logger.tb_train_figures(global_step, figures)

                # log a sample of the generated audio
                sample_voice = y_hat_vis[0].squeeze(0).detach().cpu().numpy()
                tb_logger.tb_train_audios(global_step,
                                          {'train/audio': sample_voice},
                                          c.audio["sample_rate"])
        end_time = time.time()

    # print epoch stats
    c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)

    # plot epoch stats on Tensorboard
    epoch_stats = {"epoch_time": epoch_time}
    epoch_stats.update(keep_avg.avg_values)
    if args.rank == 0:
        tb_logger.tb_train_epoch_stats(global_step, epoch_stats)

    return keep_avg.avg_values, global_step


@torch.no_grad()
def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch):
    data_loader = setup_loader(ap, is_val=True, verbose=(epoch == 0))
    model_G.eval()
    model_D.eval()
    epoch_time = 0
    keep_avg = KeepAverage()
    end_time = time.time()
    c_logger.print_eval_start()
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # format data
        c_G, y_G, _, _ = format_data(data)
        loader_time = time.time() - end_time

        global_step += 1

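        # generator pass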
        y_hat = model_G(c_G)
        y_hat_sub = None
        y_G_sub = None

        # PQMF formatting
        if y_hat.shape[1] > 1:
            y_hat_sub = y_hat
            y_hat = model_G.pqmf_synthesis(y_hat)
            y_G_sub = model_G.pqmf_analysis(y_G)

        scores_fake, feats_fake, feats_real = None, None, None
        if global_step > c.steps_to_start_discriminator:

            # run D with or without conditional features
            if len(signature(model_D.forward).parameters) == 2:
                D_out_fake = model_D(y_hat, c_G)
            else:
                D_out_fake = model_D(y_hat)
            D_out_real = None

            if c.use_feat_match_loss:
                with torch.no_grad():
                    D_out_real = model_D(y_G)

            # format D outputs
            if isinstance(D_out_fake, tuple):
                scores_fake, feats_fake = D_out_fake
                if D_out_real is None:
                    feats_real = None
                else:
                    _, feats_real = D_out_real
            else:
                scores_fake = D_out_fake
                feats_fake, feats_real = None, None

        # compute generator losses
        loss_G_dict = criterion_G(y_hat, y_G, scores_fake, feats_fake,
                                  feats_real, y_hat_sub, y_G_sub)

        loss_dict = dict()
        for key, value in loss_G_dict.items():
            if isinstance(value, (int, float)):
                loss_dict[key] = value
            else:
                loss_dict[key] = value.item()

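        ##############################
        # DISCRIMINATOR
        ##############################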
        if global_step >= c.steps_to_start_discriminator:

            # discriminator pass
            with torch.no_grad():
                y_hat = model_G(c_G)

            # PQMF formatting
            if y_hat.shape[1] > 1:
                y_hat = model_G.pqmf_synthesis(y_hat)

            # run D with or without conditional features
            if len(signature(model_D.forward).parameters) == 2:
                D_out_fake = model_D(y_hat.detach(), c_G)
                D_out_real = model_D(y_G, c_G)
            else:
                D_out_fake = model_D(y_hat.detach())
                D_out_real = model_D(y_G)

            # format D outputs
            if isinstance(D_out_fake, tuple):
                scores_fake, feats_fake = D_out_fake
                if D_out_real is None:
                    scores_real, feats_real = None, None
                else:
                    scores_real, feats_real = D_out_real
            else:
                scores_fake = D_out_fake
                scores_real = D_out_real

            # compute discriminator losses
            loss_D_dict = criterion_D(scores_fake, scores_real)

            for key, value in loss_D_dict.items():
                if isinstance(value, (int, float)):
                    loss_dict[key] = value
                else:
                    loss_dict[key] = value.item()

        step_time = time.time() - start_time
        epoch_time += step_time

        # update running averages
        update_eval_values = dict()
        for key, value in loss_dict.items():
            update_eval_values['avg_' + key] = value
        update_eval_values['avg_loader_time'] = loader_time
        update_eval_values['avg_step_time'] = step_time
        keep_avg.update_values(update_eval_values)

        # print eval stats
        if c.print_eval:
            c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)

        # reset the loader timer so loader_time covers only data loading
        end_time = time.time()

    if args.rank == 0:
        # plot spectrograms of the generated audio
        figures = plot_results(y_hat, y_G, ap, global_step, 'eval')
        tb_logger.tb_eval_figures(global_step, figures)

        # log a sample of the generated audio
        sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
        tb_logger.tb_eval_audios(global_step, {'eval/audio': sample_voice},
                                 c.audio["sample_rate"])

        tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)

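    # have the loader return full clips instead of segments from here on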
    data_loader.dataset.return_segments = False

    return keep_avg.avg_values


def main(args):
    global train_data, eval_data
    print(f" > Loading wavs from: {c.data_path}")
    if c.feature_path is not None:
        print(f" > Loading features from: {c.feature_path}")
        eval_data, train_data = load_wav_feat_data(
            c.data_path, c.feature_path, c.eval_split_size)
    else:
        eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size)

    # setup audio processor
    ap = AudioProcessor(**c.audio)

    # DISTRIBUTED
    if num_gpus > 1:
        init_distributed(args.rank, num_gpus, args.group_id,
                         c.distributed["backend"], c.distributed["url"])

    # setup models
    model_gen = setup_generator(c)
    model_disc = setup_discriminator(c)

    # setup optimizers
    optimizer_gen = RAdam(model_gen.parameters(), lr=c.lr_gen, weight_decay=0)
    optimizer_disc = RAdam(model_disc.parameters(),
                           lr=c.lr_disc,
                           weight_decay=0)

    # setup optional LR schedulers
    scheduler_gen = None
    scheduler_disc = None
    if 'lr_scheduler_gen' in c:
        scheduler_gen = getattr(torch.optim.lr_scheduler, c.lr_scheduler_gen)
        scheduler_gen = scheduler_gen(optimizer_gen,
                                      **c.lr_scheduler_gen_params)
    if 'lr_scheduler_disc' in c:
        scheduler_disc = getattr(torch.optim.lr_scheduler, c.lr_scheduler_disc)
        scheduler_disc = scheduler_disc(optimizer_disc,
                                        **c.lr_scheduler_disc_params)

    # setup criteria
    criterion_gen = GeneratorLoss(c)
    criterion_disc = DiscriminatorLoss(c)

    if args.restore_path:
        checkpoint = torch.load(args.restore_path, map_location='cpu')
        try:
            print(" > Restoring Generator Model...")
            model_gen.load_state_dict(checkpoint['model'])
            print(" > Restoring Generator Optimizer...")
            optimizer_gen.load_state_dict(checkpoint['optimizer'])
            print(" > Restoring Discriminator Model...")
            model_disc.load_state_dict(checkpoint['model_disc'])
            print(" > Restoring Discriminator Optimizer...")
            optimizer_disc.load_state_dict(checkpoint['optimizer_disc'])
            if 'scheduler' in checkpoint:
                print(" > Restoring Generator LR Scheduler...")
                scheduler_gen.load_state_dict(checkpoint['scheduler'])
                # NOTE: rebind the optimizer in case the scheduler was
                # saved with a stale reference
                scheduler_gen.optimizer = optimizer_gen
            if 'scheduler_disc' in checkpoint:
                print(" > Restoring Discriminator LR Scheduler...")
                scheduler_disc.load_state_dict(checkpoint['scheduler_disc'])
                scheduler_disc.optimizer = optimizer_disc
        except RuntimeError:
            # restore only matching layers
            print(" > Partial model initialization...")
            model_dict = model_gen.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint['model'], c)
            model_gen.load_state_dict(model_dict)

            model_dict = model_disc.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint['model_disc'], c)
            model_disc.load_state_dict(model_dict)
            del model_dict

        # reset learning rates to the config values
        for group in optimizer_gen.param_groups:
            group['lr'] = c.lr_gen

        for group in optimizer_disc.param_groups:
            group['lr'] = c.lr_disc

        print(" > Model restored from step %d" % checkpoint['step'],
              flush=True)
        args.restore_step = checkpoint['step']
    else:
        args.restore_step = 0

    if use_cuda:
        model_gen.cuda()
        criterion_gen.cuda()
        model_disc.cuda()
        criterion_disc.cuda()

    # DISTRIBUTED
    if num_gpus > 1:
        model_gen = DDP_th(model_gen, device_ids=[args.rank])
        model_disc = DDP_th(model_disc, device_ids=[args.rank])

    num_params = count_parameters(model_gen)
    print(" > Generator has {} parameters".format(num_params), flush=True)
    num_params = count_parameters(model_disc)
    print(" > Discriminator has {} parameters".format(num_params), flush=True)

    best_loss = float('inf')

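    # training loop: train and evaluate once per epoch, keeping the
    # checkpoint with the best target eval loss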
    global_step = args.restore_step
    for epoch in range(0, c.epochs):
        c_logger.print_epoch_start(epoch, c.epochs)
        _, global_step = train(model_gen, criterion_gen, optimizer_gen,
                               model_disc, criterion_disc, optimizer_disc,
                               scheduler_gen, scheduler_disc, ap, global_step,
                               epoch)
        eval_avg_loss_dict = evaluate(model_gen, criterion_gen, model_disc,
                                      criterion_disc, ap, global_step, epoch)
        c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
        target_loss = eval_avg_loss_dict[c.target_loss]
        best_loss = save_best_model(target_loss,
                                    best_loss,
                                    model_gen,
                                    optimizer_gen,
                                    scheduler_gen,
                                    model_disc,
                                    optimizer_disc,
                                    scheduler_disc,
                                    global_step,
                                    epoch,
                                    OUT_PATH,
                                    model_losses=eval_avg_loss_dict)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--continue_path',
        type=str,
        help='Training output folder to continue a previous run. If set, '
             '"--config_path" is ignored.',
        default='',
        required='--config_path' not in sys.argv)
    parser.add_argument(
        '--restore_path',
        type=str,
        help='Model file to restore. Use it to fine-tune a model.',
        default='')
    parser.add_argument('--config_path',
                        type=str,
                        help='Path to the config file for training.',
                        required='--continue_path' not in sys.argv)
    # NOTE: `action='store_true'` replaces the original `type=bool`, which
    # treats any non-empty string (including "False") as True.
    parser.add_argument('--debug',
                        action='store_true',
                        help='Do not verify commit integrity to run training.')

    # DISTRIBUTED
    parser.add_argument(
        '--rank',
        type=int,
        default=0,
        help='DISTRIBUTED: process rank for distributed training.')
    parser.add_argument('--group_id',
                        type=str,
                        default="",
                        help='DISTRIBUTED: process group id.')
    args = parser.parse_args()
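
    # Example invocations (script and path names are illustrative):
    #   python train_gan_vocoder.py --config_path config.json
    #   python train_gan_vocoder.py --continue_path /path/to/previous/run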

    if args.continue_path != '':
        args.output_path = args.continue_path
        args.config_path = os.path.join(args.continue_path, 'config.json')
        list_of_files = glob.glob(args.continue_path + "/*.pth.tar")
        latest_model_file = max(list_of_files, key=os.path.getctime)
        args.restore_path = latest_model_file
        print(f" > Training continues from {args.restore_path}")

    # setup output paths and read configs
    c = load_config(args.config_path)

    OUT_PATH = args.continue_path
    if args.continue_path == '':
        OUT_PATH = create_experiment_folder(c.output_path, c.run_name,
                                            args.debug)

    AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios')

    c_logger = ConsoleLogger()

    if args.rank == 0:
        os.makedirs(AUDIO_PATH, exist_ok=True)
        new_fields = {}
        if args.restore_path:
            new_fields["restore_path"] = args.restore_path
        new_fields["github_branch"] = get_git_branch()
        copy_model_files(c, args.config_path, OUT_PATH, new_fields)
        os.chmod(AUDIO_PATH, 0o775)
        os.chmod(OUT_PATH, 0o775)

        LOG_DIR = OUT_PATH
        tb_logger = TensorboardLogger(LOG_DIR, model_name='VOCODER')

        # write model description to Tensorboard
        tb_logger.tb_add_text('model-description', c['run_description'], 0)

    try:
        main(args)
    except KeyboardInterrupt:
        remove_experiment_folder(OUT_PATH)
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)
    except Exception:
        remove_experiment_folder(OUT_PATH)
        traceback.print_exc()
        sys.exit(1)