import os
import time

import torch
import torch.multiprocessing
import wandb
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm

from Architectures.Vocoder.AdversarialLoss import discriminator_adv_loss
from Architectures.Vocoder.AdversarialLoss import generator_adv_loss
from Architectures.Vocoder.FeatureMatchingLoss import feature_loss
from Architectures.Vocoder.MelSpecLoss import MelSpectrogramLoss
from Utility.utils import delete_old_checkpoints
from Utility.utils import get_most_recent_checkpoint
from run_weight_averaging import average_checkpoints
from run_weight_averaging import get_n_recent_checkpoints_paths
from run_weight_averaging import load_net_bigvgan
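

# GAN training loop for the BigVGAN vocoder: the generator is optimised on a
# mel-spectrogram L1 loss and, once a warmup phase has passed, additionally on
# an adversarial loss and a feature-matching loss, while the discriminator is
# only updated every few generator steps after that same warmup.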
def train_loop(generator,
               discriminator,
               train_dataset,
               device,
               model_save_dir,
               epochs_per_save=1,
               path_to_checkpoint=None,
               batch_size=32,
               epochs=100,
               resume=False,
               generator_steps_per_discriminator_step=5,
               generator_warmup=30000,
               use_wandb=False,
               finetune=False
               ):
    step_counter = 0
    epoch = 0

    mel_l1 = MelSpectrogramLoss().to(device)

    g = generator.to(device)
    d = discriminator.to(device)
    g.train()
    d.train()

    optimizer_g = torch.optim.RAdam(g.parameters(), betas=(0.5, 0.9), lr=0.001, weight_decay=0.0)
    scheduler_g = MultiStepLR(optimizer_g, gamma=0.5, milestones=[500000, 1000000, 1200000, 1400000])
    optimizer_d = torch.optim.RAdam(d.parameters(), betas=(0.5, 0.9), lr=0.0005, weight_decay=0.0)
    scheduler_d = MultiStepLR(optimizer_d, gamma=0.5, milestones=[500000, 1000000, 1200000, 1400000])
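    # The schedulers are stepped once per optimizer update (see the inner loop
    # below), so the milestones count update steps, not epochs. The
    # discriminator starts at half the generator's learning rate and both are
    # halved at each milestone.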

    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=8,
                              pin_memory=True,
                              drop_last=True,
                              prefetch_factor=2,
                              persistent_workers=True)
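    # drop_last keeps every batch at the full batch_size; persistent workers
    # avoid re-spawning the loader processes at the start of every epoch.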

    if resume:
        path_to_checkpoint = get_most_recent_checkpoint(checkpoint_dir=model_save_dir)
    if path_to_checkpoint is not None:
        check_dict = torch.load(path_to_checkpoint, map_location=device)
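        # when finetuning, only the network weights are restored; optimizers,
        # schedulers and the step counter start fresh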
        if not finetune:
            optimizer_g.load_state_dict(check_dict["generator_optimizer"])
            optimizer_d.load_state_dict(check_dict["discriminator_optimizer"])
            scheduler_g.load_state_dict(check_dict["generator_scheduler"])
            scheduler_d.load_state_dict(check_dict["discriminator_scheduler"])
            step_counter = check_dict["step_counter"]
        d.load_state_dict(check_dict["discriminator"])
        g.load_state_dict(check_dict["generator"])

    start_time = time.time()

    for _ in range(epochs):
        epoch += 1
        discriminator_losses = list()
        generator_losses = list()
        mel_losses = list()
        feat_match_losses = list()
        adversarial_losses = list()
        optimizer_g.zero_grad()
        optimizer_d.zero_grad()

        for datapoint in tqdm(train_loader):

            ############################
            #         Generator        #
            ############################
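            # The generator always minimises the weighted mel-spectrogram L1;
            # the adversarial and feature-matching terms are only added once
            # the warmup phase (plus a small margin of steps) has passed.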
            gold_wave = datapoint[0].to(device).unsqueeze(1)
            melspec = datapoint[1].to(device)
            pred_wave, intermediate_wave_upsampled_twice, intermediate_wave_upsampled_once = g(melspec)
            if torch.any(torch.isnan(pred_wave)):
                print("A NaN in the wave! Skipping...")
                continue
            mel_loss = mel_l1(pred_wave.squeeze(1), gold_wave)
            generator_total_loss = mel_loss * 85.0
            if step_counter > generator_warmup + 100:  # a bit of warmup helps, but it's not that important
                d_outs, d_fmaps = d(wave=pred_wave,
                                    intermediate_wave_upsampled_twice=intermediate_wave_upsampled_twice,
                                    intermediate_wave_upsampled_once=intermediate_wave_upsampled_once)
                adversarial_loss = generator_adv_loss(d_outs)
                adversarial_losses.append(adversarial_loss.item())
                generator_total_loss = generator_total_loss + adversarial_loss * 2  # based on own experience
                d_gold_outs, d_gold_fmaps = d(gold_wave)
                feature_matching_loss = feature_loss(d_gold_fmaps, d_fmaps)
                feat_match_losses.append(feature_matching_loss.item())
                generator_total_loss = generator_total_loss + feature_matching_loss
            if torch.isnan(generator_total_loss):
                print("Loss turned to NaN, skipping. The GAN possibly collapsed.")
                continue
            step_counter += 1
            optimizer_g.zero_grad()
            generator_total_loss.backward()
            generator_losses.append(generator_total_loss.item())
            mel_losses.append(mel_loss.item())
            torch.nn.utils.clip_grad_norm_(g.parameters(), 10.0)
            optimizer_g.step()
            scheduler_g.step()
            optimizer_g.zero_grad()

            ############################
            #       Discriminator      #
            ############################
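            # The generator outputs are detached here, so this update only
            # moves the discriminator; it runs once every
            # generator_steps_per_discriminator_step generator steps, and only
            # after the warmup is over.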
            if step_counter > generator_warmup and step_counter % generator_steps_per_discriminator_step == 0:
                d_outs, d_fmaps = d(wave=pred_wave.detach(),
                                    intermediate_wave_upsampled_twice=intermediate_wave_upsampled_twice.detach(),
                                    intermediate_wave_upsampled_once=intermediate_wave_upsampled_once.detach(),
                                    discriminator_train_flag=True)
                d_gold_outs, d_gold_fmaps = d(gold_wave,
                                              discriminator_train_flag=True)  # have to recompute unfortunately due to autograd behaviour
                discriminator_loss = discriminator_adv_loss(d_gold_outs, d_outs)
                optimizer_d.zero_grad()
                discriminator_loss.backward()
                discriminator_losses.append(discriminator_loss.item())
                torch.nn.utils.clip_grad_norm_(d.parameters(), 10.0)
                optimizer_d.step()
                scheduler_d.step()
                optimizer_d.zero_grad()

        ##########################
        #     Epoch Complete     #
        ##########################

        if epoch % epochs_per_save == 0:
            g.eval()
            torch.save({
                "generator"              : g.state_dict(),
                "discriminator"          : d.state_dict(),
                "generator_optimizer"    : optimizer_g.state_dict(),
                "discriminator_optimizer": optimizer_d.state_dict(),
                "generator_scheduler"    : scheduler_g.state_dict(),
                "discriminator_scheduler": scheduler_d.state_dict(),
                "step_counter"           : step_counter
            }, os.path.join(model_save_dir, "checkpoint_{}.pt".format(step_counter)))
            g.train()
            delete_old_checkpoints(model_save_dir, keep=5)
            checkpoint_paths = get_n_recent_checkpoints_paths(checkpoint_dir=model_save_dir, n=2)
            averaged_model, _ = average_checkpoints(checkpoint_paths, load_func=load_net_bigvgan)
            torch.save(averaged_model.state_dict(), os.path.join(model_save_dir, "best.pt"))
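            # "best.pt" always holds the weight average of the two most recent
            # checkpoints rather than any single snapshot.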

        # LOGGING
        log_dict = dict()
        log_dict["Generator Loss"] = round(sum(generator_losses) / len(generator_losses), 3)
        log_dict["Mel Loss"] = round(sum(mel_losses) / len(mel_losses), 3)
        if len(feat_match_losses) > 0:
            log_dict["Feature Matching Loss"] = round(sum(feat_match_losses) / len(feat_match_losses), 3)
        if len(adversarial_losses) > 0:
            log_dict["Adversarial Loss"] = round(sum(adversarial_losses) / len(adversarial_losses), 3)
        if len(discriminator_losses) > 0:
            log_dict["Discriminator Loss"] = round(sum(discriminator_losses) / len(discriminator_losses), 3)

        print("Time elapsed for this run: {} Minutes".format(round((time.time() - start_time) / 60)))
        for key in log_dict:
            print(f"{key}: {log_dict[key]}")
        if use_wandb:
            wandb.log(log_dict, step=step_counter)
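

# ----------------------------------------------------------------------------
# Minimal usage sketch (assumption): the class names, dataset and paths below
# are hypothetical placeholders, not part of this file; only the call
# signature of train_loop above is taken from the code itself.
#
# if __name__ == '__main__':
#     gen = SomeBigVGANGenerator()               # hypothetical
#     disc = SomeMultiScaleDiscriminator()       # hypothetical
#     data = SomeWaveAndMelDataset("corpus/")    # hypothetical
#     train_loop(generator=gen,
#                discriminator=disc,
#                train_dataset=data,
#                device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
#                model_save_dir="Models/Vocoder",  # hypothetical path
#                batch_size=32,
#                epochs=100,
#                generator_warmup=30000,
#                use_wandb=False,
#                resume=False)
# ----------------------------------------------------------------------------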