|
import argparse |
|
import glob |
|
import os |
|
import sys |
|
import time |
|
import traceback |
|
import numpy as np |
|
|
|
import torch |
|
|
|
from torch.nn.parallel import DistributedDataParallel as DDP_th |
|
from torch.optim import Adam |
|
from torch.utils.data import DataLoader |
|
from torch.utils.data.distributed import DistributedSampler |
|
from TTS.utils.audio import AudioProcessor |
|
from TTS.utils.console_logger import ConsoleLogger |
|
from TTS.utils.distribute import init_distributed |
|
from TTS.utils.generic_utils import (KeepAverage, count_parameters, |
|
create_experiment_folder, get_git_branch, |
|
remove_experiment_folder, set_init_dict) |
|
from TTS.utils.io import copy_model_files, load_config |
|
from TTS.utils.tensorboard_logger import TensorboardLogger |
|
from TTS.utils.training import setup_torch_training_env |
|
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data |
|
from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset |
|
from TTS.vocoder.utils.generic_utils import plot_results, setup_generator |
|
from TTS.vocoder.utils.io import save_best_model, save_checkpoint |
|
|
|
# Probe the training environment once at import time. ``use_cuda`` and
# ``num_gpus`` are read as module globals by setup_loader, format_data,
# format_test_data, train and main below.
# NOTE(review): the two positional ``True`` flags are presumably cuDNN
# enable/benchmark switches — confirm against
# ``TTS.utils.training.setup_torch_training_env``.
use_cuda, num_gpus = setup_torch_training_env(True, True)
|
|
|
|
|
def setup_loader(ap, is_val=False, verbose=False):
    """Build the training or validation ``DataLoader`` for WaveGrad.

    Reads the module-level config ``c``, the ``train_data`` / ``eval_data``
    item lists (filled in by ``main``) and the global ``num_gpus``.

    Args:
        ap: ``AudioProcessor`` handed to the dataset; its ``hop_length`` is
            also forwarded as ``hop_len``.
        is_val: build the validation loader instead of the training one.
        verbose: forwarded to ``WaveGradDataset``.

    Returns:
        A ``DataLoader``, or ``None`` when ``is_val`` is set and
        ``c.run_eval`` is false.
    """
    if is_val and not c.run_eval:
        return None

    use_distributed = num_gpus > 1
    dataset = WaveGradDataset(ap=ap,
                              items=eval_data if is_val else train_data,
                              seq_len=c.seq_len,
                              hop_len=ap.hop_length,
                              pad_short=c.pad_short,
                              conv_pad=c.conv_pad,
                              is_training=not is_val,
                              return_segments=True,
                              use_noise_augment=False,
                              use_cache=c.use_cache,
                              verbose=verbose)

    # With several GPUs the DistributedSampler shards the data, and a
    # DataLoader must not be given both a sampler and shuffle=True.
    dist_sampler = DistributedSampler(dataset) if use_distributed else None
    worker_count = c.num_val_loader_workers if is_val else c.num_loader_workers
    return DataLoader(dataset,
                      batch_size=c.batch_size,
                      shuffle=not use_distributed,
                      drop_last=False,
                      sampler=dist_sampler,
                      num_workers=worker_count,
                      pin_memory=False)
|
|
|
|
|
def format_data(data):
    """Unpack a loader batch and move it to the GPU when ``use_cuda`` is set.

    Args:
        data: a two-element ``(m, x)`` batch from the data loader
            (presumably conditioning features and waveform — confirm against
            ``WaveGradDataset``).

    Returns:
        ``(m, x)`` where ``x`` has gained a new axis at dim 1.
    """
    feats, wav = data
    wav = wav.unsqueeze(1)
    if not use_cuda:
        return feats, wav
    return feats.cuda(non_blocking=True), wav.cuda(non_blocking=True)
|
|
|
|
|
def format_test_data(data):
    """Turn one unbatched test sample into a batch of size one.

    Args:
        data: an ``(m, x)`` pair, as produced by the dataset's
            ``load_test_samples``.

    Returns:
        ``(m, x)`` with one leading axis added to ``m`` and two leading axes
        added to ``x``; moved to the GPU when ``use_cuda`` is set.
    """
    feats, wav = data
    batch = (feats[None, ...], wav[None, None, ...])
    if use_cuda:
        batch = tuple(t.cuda(non_blocking=True) for t in batch)
    return batch
|
|
|
|
|
def train(model, criterion, optimizer,
          scheduler, scaler, ap, global_step, epoch):
    """Run one training epoch of the WaveGrad model.

    Relies on the module-level globals ``c`` (config), ``args`` (for
    ``rank``), ``c_logger``, ``tb_logger`` and ``OUT_PATH``.

    Args:
        model: the generator, optionally wrapped so that the underlying
            network is reachable through ``.module`` (e.g. DDP).
        criterion: loss between the sampled noise and the predicted noise.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: LR scheduler stepped after every iteration, or ``None``.
        scaler: AMP ``GradScaler``; only used when ``c.mixed_precision``.
        ap: ``AudioProcessor`` used to build the data loader.
        global_step: step counter at the start of the epoch.
        epoch: current epoch index (the loader is verbose on epoch 0).

    Returns:
        ``(avg_values, global_step)``: the running averages accumulated in
        ``KeepAverage`` and the step counter after the epoch.

    Raises:
        RuntimeError: if the loss becomes NaN.
    """
    data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
    # DistributedSampler derives its permutation from (seed, epoch); without
    # set_epoch every epoch would replay the same sample order.
    if isinstance(data_loader.sampler, DistributedSampler):
        data_loader.sampler.set_epoch(epoch)
    model.train()
    epoch_time = 0
    keep_avg = KeepAverage()
    # Exact number of batches this process iterates (accounts for sharding
    # and for the final partial batch, since drop_last=False).
    batch_n_iter = len(data_loader)
    end_time = time.time()
    c_logger.print_train_start()

    # WaveGrad-specific helpers live on the unwrapped network.
    net = model.module if hasattr(model, 'module') else model

    # setup the training noise schedule
    noise_schedule = c['train_noise_schedule']
    betas = np.linspace(noise_schedule['min_val'], noise_schedule['max_val'],
                        noise_schedule['num_steps'])
    net.compute_noise_level(betas)

    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # format data
        m, x = format_data(data)
        loader_time = time.time() - end_time

        global_step += 1

        with torch.cuda.amp.autocast(enabled=c.mixed_precision):
            # compute noisy input
            noise, x_noisy, noise_scale = net.compute_y_n(x)
            # forward pass through the (possibly wrapped) model so DDP syncs
            noise_hat = model(x_noisy, m, noise_scale)
            # compute losses
            loss = criterion(noise, noise_hat)
        loss_wavegrad_dict = {'wavegrad_loss': loss}

        # check nan loss
        if torch.isnan(loss).any():
            raise RuntimeError(f'Detected NaN loss at step {global_step}.')

        optimizer.zero_grad()

        # backward pass, with loss scaling under mixed precision
        if c.mixed_precision:
            scaler.scale(loss).backward()
            # unscale first so clipping sees the true gradient magnitudes
            scaler.unscale_(optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       c.clip_grad)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       c.clip_grad)
            optimizer.step()

        # schedule update
        if scheduler is not None:
            scheduler.step()

        # detach loss values to plain Python numbers
        loss_dict = {
            key: value if isinstance(value, (int, float)) else value.item()
            for key, value in loss_wavegrad_dict.items()
        }

        # epoch/step timing
        step_time = time.time() - start_time
        epoch_time += step_time

        current_lr = optimizer.param_groups[0]['lr']

        # update avg stats
        update_train_values = {
            'avg_' + key: value for key, value in loss_dict.items()
        }
        update_train_values['avg_loader_time'] = loader_time
        update_train_values['avg_step_time'] = step_time
        keep_avg.update_values(update_train_values)

        # print training stats
        if global_step % c.print_step == 0:
            log_dict = {
                'step_time': [step_time, 2],
                'loader_time': [loader_time, 4],
                "current_lr": current_lr,
                "grad_norm": grad_norm.item()
            }
            c_logger.print_train_step(batch_n_iter, num_iter, global_step,
                                      log_dict, loss_dict, keep_avg.avg_values)

        if args.rank == 0:
            # plot step stats
            if global_step % 10 == 0:
                iter_stats = {
                    "lr": current_lr,
                    "grad_norm": grad_norm.item(),
                    "step_time": step_time
                }
                iter_stats.update(loss_dict)
                tb_logger.tb_train_iter_stats(global_step, iter_stats)

            # save checkpoint
            if global_step % c.save_step == 0 and c.checkpoint:
                save_checkpoint(model,
                                optimizer,
                                scheduler,
                                None,
                                None,
                                None,
                                global_step,
                                epoch,
                                OUT_PATH,
                                model_losses=loss_dict,
                                scaler=scaler.state_dict() if c.mixed_precision else None)

        end_time = time.time()

    # print epoch stats
    c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)

    # plot training epoch stats
    epoch_stats = {"epoch_time": epoch_time}
    epoch_stats.update(keep_avg.avg_values)
    if args.rank == 0:
        tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
    if c.tb_model_param_stats and args.rank == 0:
        tb_logger.tb_model_weights(model, global_step)
    return keep_avg.avg_values, global_step
|
|
|
|
|
@torch.no_grad()
def evaluate(model, criterion, ap, global_step, epoch):
    """Compute validation losses and, on rank 0, log a synthesized sample.

    Relies on the module-level globals ``c`` (config), ``args`` (for
    ``rank``), ``c_logger`` and ``tb_logger``.

    Args:
        model: the generator, optionally wrapped so that the underlying
            network is reachable through ``.module`` (e.g. DDP).
        criterion: loss between the sampled noise and the predicted noise.
        ap: ``AudioProcessor`` used for the loader and for plotting.
        global_step: current training step; all eval logging is tagged with
            this step (it is no longer advanced per eval batch).
        epoch: current epoch index (the loader is verbose on epoch 0).

    Returns:
        The running averages accumulated in ``KeepAverage`` (fed with
        ``avg_``-prefixed keys), or an empty dict when evaluation is disabled.
    """
    data_loader = setup_loader(ap, is_val=True, verbose=(epoch == 0))
    if data_loader is None:
        # setup_loader returns None when ``c.run_eval`` is off.
        return {}
    model.eval()
    epoch_time = 0
    keep_avg = KeepAverage()
    end_time = time.time()
    c_logger.print_eval_start()

    # WaveGrad-specific helpers live on the unwrapped network.
    net = model.module if hasattr(model, 'module') else model

    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # format data
        m, x = format_data(data)
        loader_time = time.time() - end_time

        # compute noisy input and predict the noise
        noise, x_noisy, noise_scale = net.compute_y_n(x)
        noise_hat = model(x_noisy, m, noise_scale)

        # compute losses
        loss = criterion(noise, noise_hat)
        loss_wavegrad_dict = {'wavegrad_loss': loss}

        # detach loss values to plain Python numbers
        loss_dict = {
            key: value if isinstance(value, (int, float)) else value.item()
            for key, value in loss_wavegrad_dict.items()
        }

        step_time = time.time() - start_time
        epoch_time += step_time

        # update avg stats
        update_eval_values = {
            'avg_' + key: value for key, value in loss_dict.items()
        }
        update_eval_values['avg_loader_time'] = loader_time
        update_eval_values['avg_step_time'] = step_time
        keep_avg.update_values(update_eval_values)

        if c.print_eval:
            c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)

        # reset so the next loader_time measures only that batch's wait
        end_time = time.time()

    if args.rank == 0:
        # load a full (non-segmented) test sample
        data_loader.dataset.return_segments = False
        samples = data_loader.dataset.load_test_samples(1)
        m, x = format_test_data(samples[0])

        # setup the test noise schedule and run inference
        noise_schedule = c['test_noise_schedule']
        betas = np.linspace(noise_schedule['min_val'], noise_schedule['max_val'],
                            noise_schedule['num_steps'])
        net.compute_noise_level(betas)
        x_pred = net.inference(m)

        # plot results
        figures = plot_results(x_pred, x, ap, global_step, 'eval')
        tb_logger.tb_eval_figures(global_step, figures)

        # sample audio
        sample_voice = x_pred[0].squeeze(0).detach().cpu().numpy()
        tb_logger.tb_eval_audios(global_step, {'eval/audio': sample_voice},
                                 c.audio["sample_rate"])

        tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
        data_loader.dataset.return_segments = True

    return keep_avg.avg_values
|
|
|
|
|
def main(args):
    """Load data, build the model and optimizer, and run the epoch loop.

    Sets the module-level ``train_data`` / ``eval_data`` read by
    ``setup_loader``, and relies on the globals ``c``, ``OUT_PATH``,
    ``use_cuda``, ``num_gpus`` and ``c_logger``.

    Args:
        args: parsed command-line arguments; ``restore_path``, ``rank`` and
            ``group_id`` are read, and ``restore_step`` is set here.
    """
    global train_data, eval_data
    print(f" > Loading wavs from: {c.data_path}")
    if c.feature_path is not None:
        print(f" > Loading features from: {c.feature_path}")
        eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path, c.eval_split_size)
    else:
        eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size)

    # setup audio processor
    ap = AudioProcessor(**c.audio)

    # distributed setup
    if num_gpus > 1:
        init_distributed(args.rank, num_gpus, args.group_id,
                         c.distributed["backend"], c.distributed["url"])

    # setup model
    model = setup_generator(c)

    # scaler for mixed precision
    scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None

    # setup optimizer
    optimizer = Adam(model.parameters(), lr=c.lr, weight_decay=0)

    # setup scheduler
    scheduler = None
    if 'lr_scheduler' in c:
        scheduler_cls = getattr(torch.optim.lr_scheduler, c.lr_scheduler)
        scheduler = scheduler_cls(optimizer, **c.lr_scheduler_params)

    # loss; moved to the GPU together with the model below when use_cuda,
    # so CPU-only runs do not fail here.
    criterion = torch.nn.L1Loss()

    if args.restore_path:
        checkpoint = torch.load(args.restore_path, map_location='cpu')
        try:
            print(" > Restoring Model...")
            model.load_state_dict(checkpoint['model'])
            print(" > Restoring Optimizer...")
            optimizer.load_state_dict(checkpoint['optimizer'])
            # This script saves checkpoints with ``scheduler`` possibly None
            # and ``scaler=None`` when mixed precision is off, so the stored
            # entries may be None; the current run may also lack a scheduler.
            if scheduler is not None and checkpoint.get('scheduler') is not None:
                print(" > Restoring LR Scheduler...")
                scheduler.load_state_dict(checkpoint['scheduler'])
                # NOTE: Not sure if necessary
                scheduler.optimizer = optimizer
            if c.mixed_precision and checkpoint.get("scaler") is not None:
                print(" > Restoring AMP Scaler...")
                scaler.load_state_dict(checkpoint["scaler"])
        except RuntimeError:
            # restore only matching layers
            print(" > Partial model initialization...")
            model_dict = model.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint['model'], c)
            model.load_state_dict(model_dict)
            del model_dict

        # reset lr to the configured value
        for group in optimizer.param_groups:
            group['lr'] = c.lr

        print(" > Model restored from step %d" % checkpoint['step'],
              flush=True)
        args.restore_step = checkpoint['step']
    else:
        args.restore_step = 0

    if use_cuda:
        model.cuda()
        criterion.cuda()

    # distributed wrapper
    if num_gpus > 1:
        model = DDP_th(model, device_ids=[args.rank])

    num_params = count_parameters(model)
    print(" > WaveGrad has {} parameters".format(num_params), flush=True)

    best_loss = float('inf')

    global_step = args.restore_step
    for epoch in range(0, c.epochs):
        c_logger.print_epoch_start(epoch, c.epochs)
        train_avg_loss_dict, global_step = train(model, criterion, optimizer,
                                                 scheduler, scaler, ap,
                                                 global_step, epoch)
        if c.run_eval:
            eval_avg_loss_dict = evaluate(model, criterion, ap,
                                          global_step, epoch)
            c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
            epoch_losses = eval_avg_loss_dict
        else:
            # no validation loader exists when run_eval is off; select the
            # best model on the training averages (same ``avg_`` keys).
            epoch_losses = train_avg_loss_dict
        target_loss = epoch_losses[c.target_loss]
        best_loss = save_best_model(target_loss,
                                    best_loss,
                                    model,
                                    optimizer,
                                    scheduler,
                                    None,
                                    None,
                                    None,
                                    global_step,
                                    epoch,
                                    OUT_PATH,
                                    model_losses=epoch_losses,
                                    scaler=scaler.state_dict() if c.mixed_precision else None)
|
|
|
|
|
def str_to_bool(value):
    """Parse a boolean command-line value.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for any
    non-empty string (so ``--debug False`` enabled debug mode). This helper
    interprets the text explicitly instead.

    Raises:
        argparse.ArgumentTypeError: for text that is not a recognised boolean.
    """
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--continue_path',
        type=str,
        help=
        'Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.',
        default='',
        required='--config_path' not in sys.argv)
    parser.add_argument(
        '--restore_path',
        type=str,
        help='Model file to be restored. Use to finetune a model.',
        default='')
    parser.add_argument('--config_path',
                        type=str,
                        help='Path to config file for training.',
                        required='--continue_path' not in sys.argv)
    parser.add_argument('--debug',
                        type=str_to_bool,
                        default=False,
                        help='Do not verify commit integrity to run training.')

    # DISTRIBUTED
    parser.add_argument(
        '--rank',
        type=int,
        default=0,
        help='DISTRIBUTED: process rank for distributed training.')
    parser.add_argument('--group_id',
                        type=str,
                        default="",
                        help='DISTRIBUTED: process group id.')
    args = parser.parse_args()

    if args.continue_path != '':
        args.output_path = args.continue_path
        args.config_path = os.path.join(args.continue_path, 'config.json')
        list_of_files = glob.glob(os.path.join(args.continue_path, "*.pth.tar"))
        if not list_of_files:
            # fail with a clear CLI error instead of max() on an empty list
            parser.error(f"no checkpoint (*.pth.tar) found in {args.continue_path}")
        # resume from the checkpoint with the most recent ctime
        latest_model_file = max(list_of_files, key=os.path.getctime)
        args.restore_path = latest_model_file
        print(f" > Training continues for {args.restore_path}")

    # setup output paths and read configs
    c = load_config(args.config_path)

    if c.mixed_precision:
        print(" > Mixed precision is enabled")

    OUT_PATH = args.continue_path
    if args.continue_path == '':
        OUT_PATH = create_experiment_folder(c.output_path, c.run_name,
                                            args.debug)

    AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios')

    c_logger = ConsoleLogger()

    # only the first process prepares the output folder
    if args.rank == 0:
        os.makedirs(AUDIO_PATH, exist_ok=True)
        new_fields = {}
        if args.restore_path:
            new_fields["restore_path"] = args.restore_path
        new_fields["github_branch"] = get_git_branch()
        copy_model_files(c, args.config_path,
                         OUT_PATH, new_fields)
        os.chmod(AUDIO_PATH, 0o775)
        os.chmod(OUT_PATH, 0o775)

    LOG_DIR = OUT_PATH
    tb_logger = TensorboardLogger(LOG_DIR, model_name='VOCODER')

    # write model description to tensorboard
    tb_logger.tb_add_text('model-description', c['run_description'], 0)

    try:
        main(args)
    except KeyboardInterrupt:
        # NOTE(review): remove_experiment_folder presumably keeps folders that
        # already contain checkpoints — verify in TTS.utils.generic_utils.
        remove_experiment_folder(OUT_PATH)
        # sys.exit raises SystemExit, which is caught immediately, so this
        # always ends in os._exit(0): a hard exit that skips cleanup handlers.
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)  # pylint: disable=protected-access
    except Exception:  # pylint: disable=broad-except
        remove_experiment_folder(OUT_PATH)
        traceback.print_exc()
        sys.exit(1)
|
|