|
import argparse |
|
import os |
|
import sys |
|
import traceback |
|
import time |
|
import glob |
|
import random |
|
|
|
import torch |
|
from torch.utils.data import DataLoader |
|
|
|
|
|
|
|
from TTS.tts.utils.visual import plot_spectrogram |
|
from TTS.utils.audio import AudioProcessor |
|
from TTS.utils.radam import RAdam |
|
from TTS.utils.io import copy_model_files, load_config |
|
from TTS.utils.training import setup_torch_training_env |
|
from TTS.utils.console_logger import ConsoleLogger |
|
from TTS.utils.tensorboard_logger import TensorboardLogger |
|
from TTS.utils.generic_utils import ( |
|
KeepAverage, |
|
count_parameters, |
|
create_experiment_folder, |
|
get_git_branch, |
|
remove_experiment_folder, |
|
set_init_dict, |
|
) |
|
from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset |
|
from TTS.vocoder.datasets.preprocess import ( |
|
load_wav_data, |
|
load_wav_feat_data |
|
) |
|
from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss |
|
from TTS.vocoder.utils.generic_utils import setup_wavernn |
|
from TTS.vocoder.utils.io import save_best_model, save_checkpoint |
|
|
|
|
|
# Configure the global torch training environment; returns whether CUDA is
# usable and how many GPUs are visible (consumed throughout this script).
use_cuda, num_gpus = setup_torch_training_env(True, True)
|
|
|
|
|
def setup_loader(ap, is_val=False, verbose=False):
    """Build a ``DataLoader`` over the train or eval split.

    Args:
        ap: AudioProcessor used by the dataset for feature extraction.
        is_val: when True, load the eval split instead of the train split.
        verbose: forwarded to the dataset for verbose logging.

    Returns:
        A ``DataLoader``, or ``None`` when evaluation is disabled in the config.
    """
    # An eval loader is only built when the config enables eval runs.
    if is_val and not c.run_eval:
        return None

    dataset = WaveRNNDataset(
        ap=ap,
        items=eval_data if is_val else train_data,
        seq_len=c.seq_len,
        hop_len=ap.hop_length,
        pad=c.padding,
        mode=c.mode,
        mulaw=c.mulaw,
        is_training=not is_val,
        verbose=verbose,
    )
    num_workers = c.num_val_loader_workers if is_val else c.num_loader_workers
    return DataLoader(
        dataset,
        shuffle=True,
        collate_fn=dataset.collate,
        batch_size=c.batch_size,
        num_workers=num_workers,
        pin_memory=True,
    )
|
|
|
|
|
def format_data(data):
    """Unpack a collated batch and move its tensors to GPU when available.

    Returns:
        Tuple of (x_input, mels, y_coarse) tensors.
    """
    x_input, mels, y_coarse = data[0], data[1], data[2]

    # Non-blocking transfers overlap with compute when pin_memory is enabled.
    if use_cuda:
        x_input = x_input.cuda(non_blocking=True)
        mels = mels.cuda(non_blocking=True)
        y_coarse = y_coarse.cuda(non_blocking=True)

    return x_input, mels, y_coarse
|
|
|
|
|
def train(model, optimizer, criterion, scheduler, scaler, ap, global_step, epoch):
    """Run one training epoch over the train split.

    Args:
        model: WaveRNN model to train.
        optimizer: optimizer updating the model parameters.
        criterion: loss function matching the model output mode.
        scheduler: optional LR scheduler, stepped once per iteration.
        scaler: ``torch.cuda.amp.GradScaler`` (only used with mixed precision).
        ap: AudioProcessor for periodic sample generation.
        global_step: running global iteration counter.
        epoch: current epoch index (0-based).

    Returns:
        Tuple of (average training values dict, updated global_step).
    """
    data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
    model.train()
    epoch_time = 0
    keep_avg = KeepAverage()
    if use_cuda:
        batch_n_iter = int(len(data_loader.dataset) /
                           (c.batch_size * num_gpus))
    else:
        batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
    end_time = time.time()
    c_logger.print_train_start()

    for num_iter, data in enumerate(data_loader):
        start_time = time.time()
        x_input, mels, y_coarse = format_data(data)
        loader_time = time.time() - end_time
        global_step += 1

        optimizer.zero_grad()

        if c.mixed_precision:
            # mixed precision training
            with torch.cuda.amp.autocast():
                y_hat = model(x_input, mels)
                if isinstance(model.mode, int):
                    # softmax output: align (B, T, C) with the target layout
                    y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
                else:
                    y_coarse = y_coarse.float()
                y_coarse = y_coarse.unsqueeze(-1)
                loss = criterion(y_hat, y_coarse)
            scaler.scale(loss).backward()
            # unscale before clipping so the clip threshold is in real units
            scaler.unscale_(optimizer)
            if c.grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), c.grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            # full precision training
            y_hat = model(x_input, mels)
            if isinstance(model.mode, int):
                y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
            else:
                y_coarse = y_coarse.float()
            y_coarse = y_coarse.unsqueeze(-1)
            loss = criterion(y_hat, y_coarse)
            # BUGFIX: the original checked `loss.item() is None`, which can
            # never be true (`Tensor.item()` returns a number or raises).
            # Guard against divergence by detecting NaN losses instead.
            if torch.isnan(loss):
                raise RuntimeError(" [!] NaN loss. Exiting ...")
            loss.backward()
            if c.grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), c.grad_clip)
            optimizer.step()

        if scheduler is not None:
            scheduler.step()

        cur_lr = list(optimizer.param_groups)[0]["lr"]

        step_time = time.time() - start_time
        epoch_time += step_time

        # aggregate running averages for console/tensorboard reporting
        update_train_values = dict()
        loss_dict = dict()
        loss_dict["model_loss"] = loss.item()
        for key, value in loss_dict.items():
            update_train_values["avg_" + key] = value
        update_train_values["avg_loader_time"] = loader_time
        update_train_values["avg_step_time"] = step_time
        keep_avg.update_values(update_train_values)

        # print training stats
        if global_step % c.print_step == 0:
            log_dict = {"step_time": [step_time, 2],
                        "loader_time": [loader_time, 4],
                        "current_lr": cur_lr,
                        }
            c_logger.print_train_step(batch_n_iter,
                                      num_iter,
                                      global_step,
                                      log_dict,
                                      loss_dict,
                                      keep_avg.avg_values,
                                      )

        # plot step stats to tensorboard
        if global_step % 10 == 0:
            iter_stats = {"lr": cur_lr, "step_time": step_time}
            iter_stats.update(loss_dict)
            tb_logger.tb_train_iter_stats(global_step, iter_stats)

        # save checkpoint and log generated samples
        if global_step % c.save_step == 0:
            if c.checkpoint:
                save_checkpoint(model,
                                optimizer,
                                scheduler,
                                None,
                                None,
                                None,
                                global_step,
                                epoch,
                                OUT_PATH,
                                model_losses=loss_dict,
                                scaler=scaler.state_dict() if c.mixed_precision else None
                                )

            # synthesize audio from a random training sample for inspection
            rand_idx = random.randrange(0, len(train_data))
            wav_path = train_data[rand_idx] if not isinstance(
                train_data[rand_idx], (tuple, list)) else train_data[rand_idx][0]
            wav = ap.load_wav(wav_path)
            ground_mel = ap.melspectrogram(wav)
            sample_wav = model.generate(ground_mel,
                                        c.batched,
                                        c.target_samples,
                                        c.overlap_samples,
                                        use_cuda
                                        )
            predict_mel = ap.melspectrogram(sample_wav)

            # compare ground-truth vs. predicted spectrograms
            figures = {"train/ground_truth": plot_spectrogram(ground_mel.T),
                       "train/prediction": plot_spectrogram(predict_mel.T)
                       }
            tb_logger.tb_train_figures(global_step, figures)

            # sample audio
            tb_logger.tb_train_audios(
                global_step, {
                    "train/audio": sample_wav}, c.audio["sample_rate"]
            )
        end_time = time.time()

    # print epoch stats
    c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)

    # plot epoch stats to tensorboard
    epoch_stats = {"epoch_time": epoch_time}
    epoch_stats.update(keep_avg.avg_values)
    tb_logger.tb_train_epoch_stats(global_step, epoch_stats)

    return keep_avg.avg_values, global_step
|
|
|
|
|
@torch.no_grad()
def evaluate(model, criterion, ap, global_step, epoch):
    """Run one evaluation pass over the eval split.

    Every ``c.test_every_epochs`` epochs also synthesizes a sample from a
    random eval utterance and logs audio/spectrograms to tensorboard.

    Returns:
        Dict of averaged evaluation values (includes ``avg_model_loss``).
    """
    data_loader = setup_loader(ap, is_val=True, verbose=(epoch == 0))
    model.eval()
    epoch_time = 0
    keep_avg = KeepAverage()
    end_time = time.time()
    c_logger.print_eval_start()
    # NOTE: the @torch.no_grad() decorator already disables autograd for the
    # whole function; the previous inner `with torch.no_grad():` was redundant
    # and has been removed (behavior unchanged).
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        x_input, mels, y_coarse = format_data(data)
        loader_time = time.time() - end_time
        global_step += 1

        y_hat = model(x_input, mels)
        if isinstance(model.mode, int):
            # softmax output: align (B, T, C) with the target layout
            y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
        else:
            y_coarse = y_coarse.float()
        y_coarse = y_coarse.unsqueeze(-1)
        loss = criterion(y_hat, y_coarse)

        loss_dict = dict()
        loss_dict["model_loss"] = loss.item()

        step_time = time.time() - start_time
        epoch_time += step_time

        # update running averages
        update_eval_values = dict()
        for key, value in loss_dict.items():
            update_eval_values["avg_" + key] = value
        update_eval_values["avg_loader_time"] = loader_time
        update_eval_values["avg_step_time"] = step_time
        keep_avg.update_values(update_eval_values)

        # print eval stats
        if c.print_eval:
            c_logger.print_eval_step(
                num_iter, loss_dict, keep_avg.avg_values)

    if epoch % c.test_every_epochs == 0 and epoch != 0:
        # synthesize audio from a random eval sample for inspection
        rand_idx = random.randrange(0, len(eval_data))
        wav_path = eval_data[rand_idx] if not isinstance(
            eval_data[rand_idx], (tuple, list)) else eval_data[rand_idx][0]
        wav = ap.load_wav(wav_path)
        ground_mel = ap.melspectrogram(wav)
        sample_wav = model.generate(ground_mel,
                                    c.batched,
                                    c.target_samples,
                                    c.overlap_samples,
                                    use_cuda
                                    )
        predict_mel = ap.melspectrogram(sample_wav)

        # sample audio
        tb_logger.tb_eval_audios(
            global_step, {
                "eval/audio": sample_wav}, c.audio["sample_rate"]
        )

        # compare ground-truth vs. predicted spectrograms
        figures = {"eval/ground_truth": plot_spectrogram(ground_mel.T),
                   "eval/prediction": plot_spectrogram(predict_mel.T)
                   }
        tb_logger.tb_eval_figures(global_step, figures)

    tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
    return keep_avg.avg_values
|
|
|
|
|
|
|
def main(args):
    """Set up data, model, loss and optimizer, then run the train/eval loop.

    Relies on module globals populated in the ``__main__`` block (`c`,
    `OUT_PATH`, `c_logger`, `tb_logger`) and populates the module-level
    `train_data` / `eval_data` splits consumed by `setup_loader`.
    """
    global train_data, eval_data

    # setup audio processor from the config audio parameters
    ap = AudioProcessor(**c.audio)

    # load train/eval splits
    print(f" > Loading wavs from: {c.data_path}")
    if c.feature_path is not None:
        print(f" > Loading features from: {c.feature_path}")
        eval_data, train_data = load_wav_feat_data(
            c.data_path, c.feature_path, c.eval_split_size)
    else:
        eval_data, train_data = load_wav_data(
            c.data_path, c.eval_split_size)

    # setup model
    model_wavernn = setup_wavernn(c)

    # setup AMP scaler (only needed for mixed precision)
    scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None

    # select the loss matching the model output mode
    if c.mode == "mold":
        criterion = discretized_mix_logistic_loss
    elif c.mode == "gauss":
        criterion = gaussian_loss
    elif isinstance(c.mode, int):
        criterion = torch.nn.CrossEntropyLoss()
    else:
        # BUGFIX: previously an unknown mode left `criterion` unbound and
        # training failed later with a confusing NameError. Fail fast with
        # a clear message instead.
        raise ValueError(f" [!] Unknown model mode: {c.mode}")

    if use_cuda:
        model_wavernn.cuda()
        if isinstance(c.mode, int):
            criterion.cuda()

    optimizer = RAdam(model_wavernn.parameters(), lr=c.lr, weight_decay=0)

    # optional LR scheduler from the config
    scheduler = None
    if "lr_scheduler" in c:
        scheduler = getattr(torch.optim.lr_scheduler, c.lr_scheduler)
        scheduler = scheduler(optimizer, **c.lr_scheduler_params)

    # restore a checkpoint when fine-tuning / continuing a run
    if args.restore_path:
        checkpoint = torch.load(args.restore_path, map_location="cpu")
        try:
            print(" > Restoring Model...")
            model_wavernn.load_state_dict(checkpoint["model"])
            print(" > Restoring Optimizer...")
            optimizer.load_state_dict(checkpoint["optimizer"])
            if "scheduler" in checkpoint:
                print(" > Restoring Generator LR Scheduler...")
                scheduler.load_state_dict(checkpoint["scheduler"])
                scheduler.optimizer = optimizer
            if "scaler" in checkpoint and c.mixed_precision:
                print(" > Restoring AMP Scaler...")
                scaler.load_state_dict(checkpoint["scaler"])
        except RuntimeError:
            # restore only matching layers (e.g. architecture changed)
            print(" > Partial model initialization...")
            model_dict = model_wavernn.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint["model"], c)
            model_wavernn.load_state_dict(model_dict)

        print(" > Model restored from step %d" %
              checkpoint["step"], flush=True)
        args.restore_step = checkpoint["step"]
    else:
        args.restore_step = 0

    num_parameters = count_parameters(model_wavernn)
    print(" > Model has {} parameters".format(num_parameters), flush=True)

    # BUGFIX: the original `if "best_loss" not in locals():` is always true
    # inside a function body; initialize directly.
    best_loss = float("inf")

    global_step = args.restore_step
    for epoch in range(0, c.epochs):
        c_logger.print_epoch_start(epoch, c.epochs)
        _, global_step = train(model_wavernn, optimizer,
                               criterion, scheduler, scaler, ap, global_step, epoch)
        eval_avg_loss_dict = evaluate(
            model_wavernn, criterion, ap, global_step, epoch)
        c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
        target_loss = eval_avg_loss_dict["avg_model_loss"]
        best_loss = save_best_model(
            target_loss,
            best_loss,
            model_wavernn,
            optimizer,
            scheduler,
            None,
            None,
            None,
            global_step,
            epoch,
            OUT_PATH,
            model_losses=eval_avg_loss_dict,
            scaler=scaler.state_dict() if c.mixed_precision else None
        )
|
|
|
|
|
if __name__ == "__main__":
    # parse command-line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--continue_path",
        type=str,
        help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.',
        default="",
        required="--config_path" not in sys.argv,
    )
    parser.add_argument(
        "--restore_path",
        type=str,
        help="Model file to be restored. Use to finetune a model.",
        default="",
    )
    parser.add_argument(
        "--config_path",
        type=str,
        help="Path to config file for training.",
        required="--continue_path" not in sys.argv,
    )
    parser.add_argument(
        "--debug",
        type=bool,
        default=False,
        help="Do not verify commit integrity to run training.",
    )

    # DISTRIBUTED arguments (rank/group id passed by the launcher)
    parser.add_argument(
        "--rank",
        type=int,
        default=0,
        help="DISTRIBUTED: process rank for distributed training.",
    )
    parser.add_argument(
        "--group_id", type=str, default="", help="DISTRIBUTED: process group id."
    )
    args = parser.parse_args()

    # when continuing a run, resume from the newest checkpoint in the folder
    if args.continue_path != "":
        args.output_path = args.continue_path
        args.config_path = os.path.join(args.continue_path, "config.json")
        list_of_files = glob.glob(
            args.continue_path + "/*.pth.tar"
        )
        latest_model_file = max(list_of_files, key=os.path.getctime)
        args.restore_path = latest_model_file
        print(f" > Training continues for {args.restore_path}")

    # load the training config (global `c` used by all functions above)
    c = load_config(args.config_path)

    _ = os.path.dirname(os.path.realpath(__file__))

    # resolve the experiment output folder
    OUT_PATH = args.continue_path
    if args.continue_path == "":
        OUT_PATH = create_experiment_folder(
            c.output_path, c.run_name, args.debug
        )

    AUDIO_PATH = os.path.join(OUT_PATH, "test_audios")

    c_logger = ConsoleLogger()

    # only rank 0 writes config copies and creates output folders
    if args.rank == 0:
        os.makedirs(AUDIO_PATH, exist_ok=True)
        new_fields = {}
        if args.restore_path:
            new_fields["restore_path"] = args.restore_path
        new_fields["github_branch"] = get_git_branch()
        copy_model_files(
            c, args.config_path, OUT_PATH, new_fields
        )
        os.chmod(AUDIO_PATH, 0o775)
        os.chmod(OUT_PATH, 0o775)

    LOG_DIR = OUT_PATH
    tb_logger = TensorboardLogger(LOG_DIR, model_name="VOCODER")

    # write model description to tensorboard
    tb_logger.tb_add_text("model-description", c["run_description"], 0)

    # run training; clean up the experiment folder on interrupt or failure
    try:
        main(args)
    except KeyboardInterrupt:
        remove_experiment_folder(OUT_PATH)
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)
    except Exception:
        remove_experiment_folder(OUT_PATH)
        traceback.print_exc()
        sys.exit(1)
|
|