Flux9665's picture
use explicit code instead of relying on release download
9e275b8
raw
history blame
8.49 kB
import os
import time
import torch
import torch.multiprocessing
import wandb
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
from Architectures.Vocoder.AdversarialLoss import discriminator_adv_loss
from Architectures.Vocoder.AdversarialLoss import generator_adv_loss
from Architectures.Vocoder.FeatureMatchingLoss import feature_loss
from Architectures.Vocoder.MelSpecLoss import MelSpectrogramLoss
from Utility.utils import delete_old_checkpoints
from Utility.utils import get_most_recent_checkpoint
from run_weight_averaging import average_checkpoints
from run_weight_averaging import get_n_recent_checkpoints_paths
from run_weight_averaging import load_net_bigvgan
def train_loop(generator,
               discriminator,
               train_dataset,
               device,
               model_save_dir,
               epochs_per_save=1,
               path_to_checkpoint=None,
               batch_size=32,
               epochs=100,
               resume=False,
               generator_steps_per_discriminator_step=5,
               generator_warmup=30000,
               use_wandb=False,
               finetune=False
               ):
    """
    Adversarial training loop for the vocoder.

    The generator is first trained on the mel-spectrogram loss alone for
    ``generator_warmup`` steps (plus a small margin); after that, adversarial
    and feature-matching terms are added, and the discriminator is updated once
    every ``generator_steps_per_discriminator_step`` generator steps.

    Args:
        generator: vocoder generator module (moved onto ``device`` here).
        discriminator: discriminator module (moved onto ``device`` here).
        train_dataset: dataset yielding (gold_wave, melspec) pairs.
        device: torch device to train on.
        model_save_dir: directory for numbered checkpoints and the averaged "best.pt".
        epochs_per_save: checkpoint/average every this many epochs.
        path_to_checkpoint: optional checkpoint to restore; ignored (overwritten)
            when ``resume`` is True.
        batch_size: samples per batch.
        epochs: number of passes over the dataset.
        resume: if True, restore from the most recent checkpoint in ``model_save_dir``.
        generator_steps_per_discriminator_step: interval of discriminator updates.
        generator_warmup: steps of mel-only generator training before GAN terms start.
        use_wandb: if True, log the per-epoch averaged losses to wandb.
        finetune: if True, load only the model weights from the checkpoint and
            start optimizers, schedulers and the step counter from scratch.
    """
    step_counter = 0
    epoch = 0
    mel_l1 = MelSpectrogramLoss().to(device)
    g = generator.to(device)
    d = discriminator.to(device)
    g.train()
    d.train()
    # Discriminator runs at half the generator's learning rate; both decay in
    # lockstep at the same (step-based) milestones.
    optimizer_g = torch.optim.RAdam(g.parameters(), betas=(0.5, 0.9), lr=0.001, weight_decay=0.0)
    scheduler_g = MultiStepLR(optimizer_g, gamma=0.5, milestones=[500000, 1000000, 1200000, 1400000])
    optimizer_d = torch.optim.RAdam(d.parameters(), betas=(0.5, 0.9), lr=0.0005, weight_decay=0.0)
    scheduler_d = MultiStepLR(optimizer_d, gamma=0.5, milestones=[500000, 1000000, 1200000, 1400000])
    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=8,
                              pin_memory=True,
                              drop_last=True,
                              prefetch_factor=2,
                              persistent_workers=True)
    if resume:
        path_to_checkpoint = get_most_recent_checkpoint(checkpoint_dir=model_save_dir)
    if path_to_checkpoint is not None:
        # NOTE(review): this unpickles the full checkpoint (optimizer state etc.);
        # only restore checkpoints from trusted sources.
        check_dict = torch.load(path_to_checkpoint, map_location=device)
        if not finetune:
            optimizer_g.load_state_dict(check_dict["generator_optimizer"])
            optimizer_d.load_state_dict(check_dict["discriminator_optimizer"])
            scheduler_g.load_state_dict(check_dict["generator_scheduler"])
            scheduler_d.load_state_dict(check_dict["discriminator_scheduler"])
            step_counter = check_dict["step_counter"]
        d.load_state_dict(check_dict["discriminator"])
        g.load_state_dict(check_dict["generator"])
    start_time = time.time()
    for _ in range(epochs):
        epoch += 1
        discriminator_losses = list()
        generator_losses = list()
        mel_losses = list()
        feat_match_losses = list()
        adversarial_losses = list()
        optimizer_g.zero_grad()
        optimizer_d.zero_grad()
        for datapoint in tqdm(train_loader):

            ############################
            #         Generator        #
            ############################

            gold_wave = datapoint[0].to(device).unsqueeze(1)
            melspec = datapoint[1].to(device)
            pred_wave, intermediate_wave_upsampled_twice, intermediate_wave_upsampled_once = g(melspec)
            if torch.any(torch.isnan(pred_wave)):
                print("A NaN in the wave! Skipping...")
                continue
            # NOTE(review): pred is squeezed to drop its channel dim while
            # gold_wave keeps the one added above — confirm MelSpectrogramLoss
            # broadcasts these shapes as intended.
            mel_loss = mel_l1(pred_wave.squeeze(1), gold_wave)
            generator_total_loss = mel_loss * 85.0
            if step_counter > generator_warmup + 100:  # a bit of warmup helps, but it's not that important
                d_outs, d_fmaps = d(wave=pred_wave,
                                    intermediate_wave_upsampled_twice=intermediate_wave_upsampled_twice,
                                    intermediate_wave_upsampled_once=intermediate_wave_upsampled_once)
                adversarial_loss = generator_adv_loss(d_outs)
                adversarial_losses.append(adversarial_loss.item())
                generator_total_loss = generator_total_loss + adversarial_loss * 2  # based on own experience
                d_gold_outs, d_gold_fmaps = d(gold_wave)
                feature_matching_loss = feature_loss(d_gold_fmaps, d_fmaps)
                feat_match_losses.append(feature_matching_loss.item())
                generator_total_loss = generator_total_loss + feature_matching_loss
            if torch.isnan(generator_total_loss):
                print("Loss turned to NaN, skipping. The GAN possibly collapsed.")
                continue
            step_counter += 1
            optimizer_g.zero_grad()
            generator_total_loss.backward()
            generator_losses.append(generator_total_loss.item())
            mel_losses.append(mel_loss.item())
            torch.nn.utils.clip_grad_norm_(g.parameters(), 10.0)
            optimizer_g.step()
            scheduler_g.step()
            optimizer_g.zero_grad()

            ############################
            #       Discriminator      #
            ############################

            if step_counter > generator_warmup and step_counter % generator_steps_per_discriminator_step == 0:
                # Detach the generator outputs so discriminator gradients don't
                # flow back into the generator.
                d_outs, d_fmaps = d(wave=pred_wave.detach(),
                                    intermediate_wave_upsampled_twice=intermediate_wave_upsampled_twice.detach(),
                                    intermediate_wave_upsampled_once=intermediate_wave_upsampled_once.detach(),
                                    discriminator_train_flag=True)
                d_gold_outs, d_gold_fmaps = d(gold_wave,
                                              discriminator_train_flag=True)  # have to recompute unfortunately due to autograd behaviour
                discriminator_loss = discriminator_adv_loss(d_gold_outs, d_outs)
                optimizer_d.zero_grad()
                discriminator_loss.backward()
                discriminator_losses.append(discriminator_loss.item())
                torch.nn.utils.clip_grad_norm_(d.parameters(), 10.0)
                optimizer_d.step()
                scheduler_d.step()
                optimizer_d.zero_grad()

        ##########################
        #     Epoch Complete     #
        ##########################

        if epoch % epochs_per_save == 0:
            g.eval()
            torch.save({
                "generator"              : g.state_dict(),
                "discriminator"          : d.state_dict(),
                "generator_optimizer"    : optimizer_g.state_dict(),
                "discriminator_optimizer": optimizer_d.state_dict(),
                "generator_scheduler"    : scheduler_g.state_dict(),
                "discriminator_scheduler": scheduler_d.state_dict(),
                "step_counter"           : step_counter
            }, os.path.join(model_save_dir, "checkpoint_{}.pt".format(step_counter)))
            g.train()
            delete_old_checkpoints(model_save_dir, keep=5)
            # Weight-average the two most recent checkpoints into "best.pt".
            checkpoint_paths = get_n_recent_checkpoints_paths(checkpoint_dir=model_save_dir, n=2)
            averaged_model, _ = average_checkpoints(checkpoint_paths, load_func=load_net_bigvgan)
            torch.save(averaged_model.state_dict(), os.path.join(model_save_dir, "best.pt"))

        # LOGGING
        # Guard every average against an empty list: if all batches of an epoch
        # were skipped by the NaN checks above, an unconditional division would
        # raise ZeroDivisionError (the original only guarded three of the five).
        log_dict = dict()
        for metric_name, losses in [("Generator Loss", generator_losses),
                                    ("Mel Loss", mel_losses),
                                    ("Feature Matching Loss", feat_match_losses),
                                    ("Adversarial Loss", adversarial_losses),
                                    ("Discriminator Loss", discriminator_losses)]:
            if len(losses) > 0:
                log_dict[metric_name] = round(sum(losses) / len(losses), 3)
        print("Time elapsed for this run: {} Minutes".format(round((time.time() - start_time) / 60)))
        for key in log_dict:
            print(f"{key}: {log_dict[key]}")
        if use_wandb:
            wandb.log(log_dict, step=step_counter)