# Source: voice-xtts2/TTS/bin/train_vocoder_wavegrad.py
# Uploaded by antoniomae1234 -- commit "changes in flenema" (2493d72, verified)
import argparse
import glob
import os
import sys
import time
import traceback
import numpy as np
import torch
# DISTRIBUTED
from torch.nn.parallel import DistributedDataParallel as DDP_th
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from TTS.utils.audio import AudioProcessor
from TTS.utils.console_logger import ConsoleLogger
from TTS.utils.distribute import init_distributed
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
create_experiment_folder, get_git_branch,
remove_experiment_folder, set_init_dict)
from TTS.utils.io import copy_model_files, load_config
from TTS.utils.tensorboard_logger import TensorboardLogger
from TTS.utils.training import setup_torch_training_env
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
from TTS.vocoder.utils.generic_utils import plot_results, setup_generator
from TTS.vocoder.utils.io import save_best_model, save_checkpoint
# Module-level setup executed at import time: the helper's two return values
# are kept as globals and read by the loader/formatting/training functions
# below.
# NOTE(review): the meaning of the two positional ``True`` flags is not visible
# here (presumably cudnn enable / benchmark) -- confirm against
# TTS.utils.training.setup_torch_training_env.
use_cuda, num_gpus = setup_torch_training_env(True, True)
def setup_loader(ap, is_val=False, verbose=False):
    """Build the DataLoader for the training or the validation split.

    Reads the module-level config ``c``, ``num_gpus`` and the ``train_data`` /
    ``eval_data`` item lists that ``main`` publishes as globals.

    Args:
        ap: audio processor passed to the dataset; its ``hop_length`` is used
            as the dataset hop length.
        is_val: when true, load ``eval_data`` with the validation worker count
            and mark the dataset as non-training.
        verbose: forwarded to ``WaveGradDataset``.

    Returns:
        A ``DataLoader``, or ``None`` when a validation loader is requested
        while ``c.run_eval`` is falsy.
    """
    # Evaluation switched off in the config: nothing to build.
    if is_val and not c.run_eval:
        return None

    if is_val:
        split_items = eval_data
        worker_count = c.num_val_loader_workers
    else:
        split_items = train_data
        worker_count = c.num_loader_workers

    multi_gpu = num_gpus > 1

    wavegrad_set = WaveGradDataset(
        ap=ap,
        items=split_items,
        seq_len=c.seq_len,
        hop_len=ap.hop_length,
        pad_short=c.pad_short,
        conv_pad=c.conv_pad,
        is_training=not is_val,
        return_segments=True,
        use_noise_augment=False,
        use_cache=c.use_cache,
        verbose=verbose,
    )

    # With several GPUs the distributed sampler decides the ordering, so the
    # loader's own shuffling is only enabled in the single-device case.
    dist_sampler = DistributedSampler(wavegrad_set) if multi_gpu else None

    return DataLoader(
        wavegrad_set,
        batch_size=c.batch_size,
        shuffle=not multi_gpu,
        drop_last=False,
        sampler=dist_sampler,
        num_workers=worker_count,
        pin_memory=False,
    )
def format_data(data):
    """Unpack a loader batch and prepare it for the model.

    Args:
        data: pair ``(m, x)`` as yielded by the data loader
            (presumably conditioning features and waveform -- TODO confirm).

    Returns:
        ``(m, x)`` where ``x`` has gained a new axis at position 1; both are
        moved to the GPU when the module-level ``use_cuda`` flag is set.
    """
    feats, wav = data
    # insert a singleton axis right after the first dimension
    wav = wav.unsqueeze(1)
    if not use_cuda:
        return feats, wav
    return feats.cuda(non_blocking=True), wav.cuda(non_blocking=True)
def format_test_data(data):
    """Prepare a single test sample for inference.

    Args:
        data: pair ``(m, x)`` for one sample, as returned by the dataset's
            ``load_test_samples``.

    Returns:
        ``(m, x)`` where ``m`` has one leading singleton axis prepended and
        ``x`` has two; both are moved to the GPU when the module-level
        ``use_cuda`` flag is set.
    """
    feats, wav = data
    # prepend the singleton axes the model expects for a lone sample
    feats = feats[None, ...]
    wav = wav[None, None, ...]
    if not use_cuda:
        return feats, wav
    return feats.cuda(non_blocking=True), wav.cuda(non_blocking=True)
def train(model, criterion, optimizer,
          scheduler, scaler, ap, global_step, epoch):
    """Run one training epoch of the WaveGrad vocoder.

    Relies on module-level state prepared by the script entry point: the
    config ``c``, ``c_logger``, ``tb_logger`` (rank 0 only), ``args``,
    ``OUT_PATH``, ``use_cuda`` and ``num_gpus``.

    Args:
        model: generator network; when it is wrapped (e.g. by DDP) the real
            network is reached through ``model.module``.
        criterion: callable invoked as ``criterion(noise, noise_hat)``.
        optimizer: optimizer stepped once per batch.
        scheduler: LR scheduler stepped once per batch, or ``None``.
        scaler: AMP gradient scaler; only used when ``c.mixed_precision``.
        ap: audio processor forwarded to ``setup_loader``.
        global_step: step counter at the start of the epoch.
        epoch: epoch index; the loader is verbose only when it is 0.

    Returns:
        Tuple ``(keep_avg.avg_values, global_step)``: the running averages of
        this epoch and the updated step counter.

    Raises:
        RuntimeError: if the loss contains NaN.
    """
    data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
    # The loader (and therefore its sampler) is rebuilt every epoch, so a
    # DistributedSampler would otherwise always start from its default epoch
    # and produce the same ordering each time. Tell it which epoch this is.
    if isinstance(data_loader.sampler, DistributedSampler):
        data_loader.sampler.set_epoch(epoch)

    model.train()
    epoch_time = 0
    keep_avg = KeepAverage()
    if use_cuda:
        batch_n_iter = int(
            len(data_loader.dataset) / (c.batch_size * num_gpus))
    else:
        batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
    end_time = time.time()
    c_logger.print_train_start()

    # Custom methods live on the wrapped network, not on the wrapper.
    core_model = model.module if hasattr(model, 'module') else model

    # setup noise schedule
    noise_schedule = c['train_noise_schedule']
    betas = np.linspace(noise_schedule['min_val'],
                        noise_schedule['max_val'],
                        noise_schedule['num_steps'])
    core_model.compute_noise_level(betas)

    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # format data
        m, x = format_data(data)
        loader_time = time.time() - end_time

        global_step += 1

        with torch.cuda.amp.autocast(enabled=c.mixed_precision):
            # compute noisy input
            noise, x_noisy, noise_scale = core_model.compute_y_n(x)

            # forward pass (through the wrapper so DDP hooks still run)
            noise_hat = model(x_noisy, m, noise_scale)

            # compute losses
            loss = criterion(noise, noise_hat)
            loss_wavegrad_dict = {'wavegrad_loss': loss}

        # check nan loss
        if torch.isnan(loss).any():
            raise RuntimeError(f'Detected NaN loss at step {global_step}.')

        optimizer.zero_grad()

        # backward pass with loss scaling
        if c.mixed_precision:
            scaler.scale(loss).backward()
            # gradients must be unscaled before their norm is clipped
            scaler.unscale_(optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       c.clip_grad)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       c.clip_grad)
            optimizer.step()

        # schedule update
        if scheduler is not None:
            scheduler.step()

        # disconnect loss values; plain numbers are kept as they are
        # (same rule as in ``evaluate``)
        loss_dict = {
            key: value if isinstance(value, (int, float)) else value.item()
            for key, value in loss_wavegrad_dict.items()
        }

        # epoch/step timing
        step_time = time.time() - start_time
        epoch_time += step_time

        # get current learning rates
        current_lr = list(optimizer.param_groups)[0]['lr']

        # update avg stats
        update_train_values = {
            'avg_' + key: value for key, value in loss_dict.items()
        }
        update_train_values['avg_loader_time'] = loader_time
        update_train_values['avg_step_time'] = step_time
        keep_avg.update_values(update_train_values)

        # print training stats
        if global_step % c.print_step == 0:
            log_dict = {
                'step_time': [step_time, 2],
                'loader_time': [loader_time, 4],
                "current_lr": current_lr,
                "grad_norm": grad_norm.item()
            }
            c_logger.print_train_step(batch_n_iter, num_iter, global_step,
                                      log_dict, loss_dict, keep_avg.avg_values)

        # only the rank-0 process writes tensorboard logs and checkpoints
        if args.rank == 0:
            # plot step stats
            if global_step % 10 == 0:
                iter_stats = {
                    "lr": current_lr,
                    "grad_norm": grad_norm.item(),
                    "step_time": step_time
                }
                iter_stats.update(loss_dict)
                tb_logger.tb_train_iter_stats(global_step, iter_stats)

            # save checkpoint
            if global_step % c.save_step == 0 and c.checkpoint:
                save_checkpoint(
                    model,
                    optimizer,
                    scheduler,
                    None,
                    None,
                    None,
                    global_step,
                    epoch,
                    OUT_PATH,
                    model_losses=loss_dict,
                    scaler=scaler.state_dict() if c.mixed_precision else None)

        end_time = time.time()

    # print epoch stats
    c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)

    # Plot Training Epoch Stats
    epoch_stats = {"epoch_time": epoch_time}
    epoch_stats.update(keep_avg.avg_values)
    if args.rank == 0:
        tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
    # TODO: plot model stats
    if c.tb_model_param_stats and args.rank == 0:
        tb_logger.tb_model_weights(model, global_step)
    return keep_avg.avg_values, global_step
@torch.no_grad()
def evaluate(model, criterion, ap, global_step, epoch):
    """Compute validation losses and, on rank 0, log an inference sample.

    Relies on module-level state prepared by the script entry point: the
    config ``c``, ``c_logger``, ``tb_logger`` (rank 0 only) and ``args``.

    Args:
        model: generator network; when it is wrapped (e.g. by DDP) the real
            network is reached through ``model.module``.
        criterion: callable invoked as ``criterion(noise, noise_hat)``.
        ap: audio processor forwarded to ``setup_loader`` and ``plot_results``.
        global_step: training step at which evaluation starts.
        epoch: epoch index; the loader is verbose only when it is 0.

    Returns:
        ``keep_avg.avg_values`` -- the running averages collected over the
        validation batches. When no validation loader is available (evaluation
        disabled in the config) the averages of a fresh ``KeepAverage`` are
        returned without touching the model.
    """
    data_loader = setup_loader(ap, is_val=True, verbose=(epoch == 0))
    keep_avg = KeepAverage()
    # ``setup_loader`` hands back None when ``c.run_eval`` is off; iterating
    # it would raise a TypeError, so bail out early instead.
    if data_loader is None:
        return keep_avg.avg_values

    model.eval()
    end_time = time.time()
    c_logger.print_eval_start()

    # Custom methods live on the wrapped network, not on the wrapper.
    core_model = model.module if hasattr(model, 'module') else model

    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # format data
        m, x = format_data(data)
        loader_time = time.time() - end_time

        # Local counter only: the caller's step is not modified, but the
        # logging calls below use this advanced value.
        global_step += 1

        # compute noisy input
        noise, x_noisy, noise_scale = core_model.compute_y_n(x)

        # forward pass
        noise_hat = model(x_noisy, m, noise_scale)

        # compute losses
        loss = criterion(noise, noise_hat)
        loss_wavegrad_dict = {'wavegrad_loss': loss}

        # disconnect loss values; plain numbers are kept as they are
        loss_dict = {
            key: value if isinstance(value, (int, float)) else value.item()
            for key, value in loss_wavegrad_dict.items()
        }

        step_time = time.time() - start_time

        # update avg stats
        update_eval_values = {
            'avg_' + key: value for key, value in loss_dict.items()
        }
        update_eval_values['avg_loader_time'] = loader_time
        update_eval_values['avg_step_time'] = step_time
        keep_avg.update_values(update_eval_values)

        # print eval stats
        if c.print_eval:
            c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)

    # only the rank-0 process synthesises a sample and writes tensorboard logs
    if args.rank == 0:
        # test samples are requested with segment cropping switched off
        data_loader.dataset.return_segments = False
        samples = data_loader.dataset.load_test_samples(1)
        m, x = format_test_data(samples[0])

        # setup noise schedule and inference
        noise_schedule = c['test_noise_schedule']
        betas = np.linspace(noise_schedule['min_val'],
                            noise_schedule['max_val'],
                            noise_schedule['num_steps'])
        core_model.compute_noise_level(betas)
        # compute voice
        x_pred = core_model.inference(m)

        # compute spectrograms
        figures = plot_results(x_pred, x, ap, global_step, 'eval')
        tb_logger.tb_eval_figures(global_step, figures)

        # Sample audio
        sample_voice = x_pred[0].squeeze(0).detach().cpu().numpy()
        tb_logger.tb_eval_audios(global_step, {'eval/audio': sample_voice},
                                 c.audio["sample_rate"])

        tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
    data_loader.dataset.return_segments = True

    return keep_avg.avg_values
def main(args):  # pylint: disable=redefined-outer-name
    """Set up data, model and optimizer, then run the train/eval epoch loop.

    Publishes ``train_data`` and ``eval_data`` as module globals (read by
    ``setup_loader``) and relies on the module-level config ``c``,
    ``c_logger``, ``OUT_PATH``, ``use_cuda`` and ``num_gpus``.

    Args:
        args: parsed command-line namespace. ``restore_path``, ``rank`` and
            ``group_id`` are read; ``restore_step`` is written.
    """
    # pylint: disable=global-variable-undefined
    global train_data, eval_data
    print(f" > Loading wavs from: {c.data_path}")
    if c.feature_path is not None:
        print(f" > Loading features from: {c.feature_path}")
        eval_data, train_data = load_wav_feat_data(c.data_path,
                                                   c.feature_path,
                                                   c.eval_split_size)
    else:
        eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size)

    # setup audio processor
    ap = AudioProcessor(**c.audio)

    # DISTRIBUTED
    if num_gpus > 1:
        init_distributed(args.rank, num_gpus, args.group_id,
                         c.distributed["backend"], c.distributed["url"])

    # setup models
    model = setup_generator(c)

    # scaler for mixed_precision
    scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None

    # setup optimizers
    optimizer = Adam(model.parameters(), lr=c.lr, weight_decay=0)

    # schedulers
    scheduler = None
    if 'lr_scheduler' in c:
        scheduler = getattr(torch.optim.lr_scheduler, c.lr_scheduler)
        scheduler = scheduler(optimizer, **c.lr_scheduler_params)

    # setup criterion; it is moved to the GPU further down only when CUDA is
    # in use, so it must not be sent there unconditionally here.
    criterion = torch.nn.L1Loss()

    if args.restore_path:
        # NOTE: torch.load unpickles the file -- only restore checkpoints that
        # come from a trusted source.
        checkpoint = torch.load(args.restore_path, map_location='cpu')
        try:
            print(" > Restoring Model...")
            model.load_state_dict(checkpoint['model'])
            print(" > Restoring Optimizer...")
            optimizer.load_state_dict(checkpoint['optimizer'])
            # Restore the scheduler only when one is configured now AND a
            # state was actually stored in the checkpoint.
            if scheduler is not None and checkpoint.get('scheduler') is not None:
                print(" > Restoring LR Scheduler...")
                scheduler.load_state_dict(checkpoint['scheduler'])
                # NOTE: Not sure if necessary
                scheduler.optimizer = optimizer
            # Same for the AMP scaler: the stored entry may be None when the
            # checkpoint was written without mixed precision.
            if c.mixed_precision and checkpoint.get("scaler") is not None:
                print(" > Restoring AMP Scaler...")
                scaler.load_state_dict(checkpoint["scaler"])
        except RuntimeError:
            # restore only matching layers.
            print(" > Partial model initialization...")
            model_dict = model.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint['model'], c)
            model.load_state_dict(model_dict)
            del model_dict

        # reset lr if not continuing training.
        for group in optimizer.param_groups:
            group['lr'] = c.lr

        print(" > Model restored from step %d" % checkpoint['step'],
              flush=True)
        args.restore_step = checkpoint['step']
    else:
        args.restore_step = 0

    if use_cuda:
        model.cuda()
        criterion.cuda()

    # DISTRIBUTED
    if num_gpus > 1:
        model = DDP_th(model, device_ids=[args.rank])

    num_params = count_parameters(model)
    print(" > WaveGrad has {} parameters".format(num_params), flush=True)

    best_loss = float('inf')

    global_step = args.restore_step
    for epoch in range(0, c.epochs):
        c_logger.print_epoch_start(epoch, c.epochs)
        train_avg_loss_dict, global_step = train(model, criterion, optimizer,
                                                 scheduler, scaler, ap,
                                                 global_step, epoch)
        # Without evaluation there is no validation loader to iterate, so the
        # training averages are used to rank checkpoints instead.
        if c.run_eval:
            epoch_loss_dict = evaluate(model, criterion, ap,
                                       global_step, epoch)
        else:
            epoch_loss_dict = train_avg_loss_dict
        c_logger.print_epoch_end(epoch, epoch_loss_dict)
        target_loss = epoch_loss_dict[c.target_loss]
        best_loss = save_best_model(
            target_loss,
            best_loss,
            model,
            optimizer,
            scheduler,
            None,
            None,
            None,
            global_step,
            epoch,
            OUT_PATH,
            model_losses=epoch_loss_dict,
            scaler=scaler.state_dict() if c.mixed_precision else None)
if __name__ == '__main__':
    # Script entry point: parse the CLI, load the config, prepare the output
    # folder and loggers (as module globals used by the functions above), then
    # run ``main`` and clean up the experiment folder on failure/interrupt.

    def _parse_bool_flag(value):
        """Convert a command-line string to bool.

        ``type=bool`` would turn every non-empty string -- including
        "False" -- into True. Only the usual "false" spellings (and the empty
        string) map to False here; any other value stays True as before.
        """
        return value.strip().lower() not in ('', '0', 'false', 'f', 'no',
                                             'n', 'off')

    parser = argparse.ArgumentParser()
    # --continue_path and --config_path are mutually "required": each one is
    # mandatory only when the other is absent from the command line.
    parser.add_argument(
        '--continue_path',
        type=str,
        help=
        'Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.',
        default='',
        required='--config_path' not in sys.argv)
    parser.add_argument(
        '--restore_path',
        type=str,
        help='Model file to be restored. Use to finetune a model.',
        default='')
    parser.add_argument('--config_path',
                        type=str,
                        help='Path to config file for training.',
                        required='--continue_path' not in sys.argv)
    parser.add_argument('--debug',
                        type=_parse_bool_flag,
                        default=False,
                        help='Do not verify commit integrity to run training.')

    # DISTRIBUTED
    parser.add_argument(
        '--rank',
        type=int,
        default=0,
        help='DISTRIBUTED: process rank for distributed training.')
    parser.add_argument('--group_id',
                        type=str,
                        default="",
                        help='DISTRIBUTED: process group id.')
    args = parser.parse_args()

    if args.continue_path != '':
        args.output_path = args.continue_path
        args.config_path = os.path.join(args.continue_path, 'config.json')
        # every checkpoint file stored directly in the folder being continued
        list_of_files = glob.glob(args.continue_path + "/*.pth.tar")
        if not list_of_files:
            # fail with a clear message instead of max()'s
            # "arg is an empty sequence"
            raise FileNotFoundError(
                f" [!] No '*.pth.tar' checkpoint found in {args.continue_path}")
        # resume from the most recently created checkpoint
        latest_model_file = max(list_of_files, key=os.path.getctime)
        args.restore_path = latest_model_file
        print(f" > Training continues for {args.restore_path}")

    # setup output paths and read configs
    c = load_config(args.config_path)
    # check_config(c)

    # DISTRIBUTED
    if c.mixed_precision:
        print(" > Mixed precision is enabled")

    OUT_PATH = args.continue_path
    if args.continue_path == '':
        OUT_PATH = create_experiment_folder(c.output_path, c.run_name,
                                            args.debug)

    AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios')

    c_logger = ConsoleLogger()

    # only the rank-0 process prepares the folders and the tensorboard logger
    if args.rank == 0:
        os.makedirs(AUDIO_PATH, exist_ok=True)
        new_fields = {}
        if args.restore_path:
            new_fields["restore_path"] = args.restore_path
        new_fields["github_branch"] = get_git_branch()
        copy_model_files(c, args.config_path,
                         OUT_PATH, new_fields)
        os.chmod(AUDIO_PATH, 0o775)
        os.chmod(OUT_PATH, 0o775)

        LOG_DIR = OUT_PATH
        tb_logger = TensorboardLogger(LOG_DIR, model_name='VOCODER')

        # write model desc to tensorboard
        tb_logger.tb_add_text('model-description', c['run_description'], 0)

    try:
        main(args)
    except KeyboardInterrupt:
        remove_experiment_folder(OUT_PATH)
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)  # pylint: disable=protected-access
    except Exception:  # pylint: disable=broad-except
        remove_experiment_folder(OUT_PATH)
        traceback.print_exc()
        sys.exit(1)