import os
import time

import torch
import torch.multiprocessing
import wandb
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm

from Architectures.Vocoder.AdversarialLoss import discriminator_adv_loss
from Architectures.Vocoder.AdversarialLoss import generator_adv_loss
from Architectures.Vocoder.FeatureMatchingLoss import feature_loss
from Architectures.Vocoder.MelSpecLoss import MelSpectrogramLoss
from Utility.utils import delete_old_checkpoints
from Utility.utils import get_most_recent_checkpoint
from run_weight_averaging import average_checkpoints
from run_weight_averaging import get_n_recent_checkpoints_paths
from run_weight_averaging import load_net_bigvgan

def train_loop(generator,
               discriminator,
               train_dataset,
               device,
               model_save_dir,
               epochs_per_save=1,
               path_to_checkpoint=None,
               batch_size=32,
               epochs=100,
               resume=False,
               generator_steps_per_discriminator_step=5,
               generator_warmup=30000,
               use_wandb=False,
               finetune=False
               ):
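    """
    GAN training loop for the BigVGAN vocoder. The generator is trained on the mel
    spectrogram reconstruction loss alone during a warmup phase; afterwards,
    adversarial and feature matching losses are added and the discriminator is
    updated once every generator_steps_per_discriminator_step steps. Checkpoints
    are written every epochs_per_save epochs, and the two most recent ones are
    weight averaged into best.pt.
    """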
    step_counter = 0
    epoch = 0

    mel_l1 = MelSpectrogramLoss().to(device)

    g = generator.to(device)
    d = discriminator.to(device)
    g.train()
    d.train()
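
    # two separate optimizers: the discriminator runs at half the generator's
    # learning rate, and both decay by a factor of 0.5 at the same step milestones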
    optimizer_g = torch.optim.RAdam(g.parameters(), betas=(0.5, 0.9), lr=0.001, weight_decay=0.0)
    scheduler_g = MultiStepLR(optimizer_g, gamma=0.5, milestones=[500000, 1000000, 1200000, 1400000])
    optimizer_d = torch.optim.RAdam(d.parameters(), betas=(0.5, 0.9), lr=0.0005, weight_decay=0.0)
    scheduler_d = MultiStepLR(optimizer_d, gamma=0.5, milestones=[500000, 1000000, 1200000, 1400000])

    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=8,
                              pin_memory=True,
                              drop_last=True,
                              prefetch_factor=2,
                              persistent_workers=True)
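
    # resuming overrides any explicitly passed checkpoint path with the most
    # recent checkpoint found in model_save_dir; when finetuning, only the
    # weights are restored, not the optimizer and scheduler states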
    if resume:
        path_to_checkpoint = get_most_recent_checkpoint(checkpoint_dir=model_save_dir)
    if path_to_checkpoint is not None:
        check_dict = torch.load(path_to_checkpoint, map_location=device)
        if not finetune:
            optimizer_g.load_state_dict(check_dict["generator_optimizer"])
            optimizer_d.load_state_dict(check_dict["discriminator_optimizer"])
            scheduler_g.load_state_dict(check_dict["generator_scheduler"])
            scheduler_d.load_state_dict(check_dict["discriminator_scheduler"])
            step_counter = check_dict["step_counter"]
        d.load_state_dict(check_dict["discriminator"])
        g.load_state_dict(check_dict["generator"])

    start_time = time.time()

    for _ in range(epochs):
        epoch += 1
        discriminator_losses = list()
        generator_losses = list()
        mel_losses = list()
        feat_match_losses = list()
        adversarial_losses = list()

        optimizer_g.zero_grad()
        optimizer_d.zero_grad()

        for datapoint in tqdm(train_loader):

            ############################
            #         Generator        #
            ############################
            gold_wave = datapoint[0].to(device).unsqueeze(1)
            melspec = datapoint[1].to(device)
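            # alongside the final wave, the generator returns two intermediate
            # upsampled waveforms, which the discriminator consumes at multiple
            # resolutions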
            pred_wave, intermediate_wave_upsampled_twice, intermediate_wave_upsampled_once = g(melspec)
            if torch.any(torch.isnan(pred_wave)):
                print("A NaN in the wave! Skipping...")
                continue
            mel_loss = mel_l1(pred_wave.squeeze(1), gold_wave)
            generator_total_loss = mel_loss * 85.0
            if step_counter > generator_warmup + 100:  # a bit of warmup helps, but it's not that important
                d_outs, d_fmaps = d(wave=pred_wave,
                                    intermediate_wave_upsampled_twice=intermediate_wave_upsampled_twice,
                                    intermediate_wave_upsampled_once=intermediate_wave_upsampled_once)
                adversarial_loss = generator_adv_loss(d_outs)
                adversarial_losses.append(adversarial_loss.item())
                generator_total_loss = generator_total_loss + adversarial_loss * 2  # based on own experience
                d_gold_outs, d_gold_fmaps = d(gold_wave)
                feature_matching_loss = feature_loss(d_gold_fmaps, d_fmaps)
                feat_match_losses.append(feature_matching_loss.item())
                generator_total_loss = generator_total_loss + feature_matching_loss
            if torch.isnan(generator_total_loss):
                print("Loss turned to NaN, skipping. The GAN possibly collapsed.")
                continue
            step_counter += 1
            optimizer_g.zero_grad()
            generator_total_loss.backward()
            generator_losses.append(generator_total_loss.item())
            mel_losses.append(mel_loss.item())
            torch.nn.utils.clip_grad_norm_(g.parameters(), 10.0)
            optimizer_g.step()
            scheduler_g.step()
            optimizer_g.zero_grad()

            ############################
            #       Discriminator      #
            ############################
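            # the discriminator only starts training once the generator warmup is
            # over, and is then updated once every
            # generator_steps_per_discriminator_step generator steps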
            if step_counter > generator_warmup and step_counter % generator_steps_per_discriminator_step == 0:
                d_outs, d_fmaps = d(wave=pred_wave.detach(),
                                    intermediate_wave_upsampled_twice=intermediate_wave_upsampled_twice.detach(),
                                    intermediate_wave_upsampled_once=intermediate_wave_upsampled_once.detach(),
                                    discriminator_train_flag=True)
                d_gold_outs, d_gold_fmaps = d(gold_wave,
                                              discriminator_train_flag=True)  # have to recompute unfortunately due to autograd behaviour
                discriminator_loss = discriminator_adv_loss(d_gold_outs, d_outs)
                optimizer_d.zero_grad()
                discriminator_loss.backward()
                discriminator_losses.append(discriminator_loss.item())
                torch.nn.utils.clip_grad_norm_(d.parameters(), 10.0)
                optimizer_d.step()
                scheduler_d.step()
                optimizer_d.zero_grad()

        ##########################
        #     Epoch Complete     #
        ##########################

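        # keep only the 5 newest checkpoints on disk and maintain best.pt as a
        # weight average of the 2 most recent ones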
        if epoch % epochs_per_save == 0:
            g.eval()
            torch.save({
                "generator"              : g.state_dict(),
                "discriminator"          : d.state_dict(),
                "generator_optimizer"    : optimizer_g.state_dict(),
                "discriminator_optimizer": optimizer_d.state_dict(),
                "generator_scheduler"    : scheduler_g.state_dict(),
                "discriminator_scheduler": scheduler_d.state_dict(),
                "step_counter"           : step_counter
            }, os.path.join(model_save_dir, "checkpoint_{}.pt".format(step_counter)))
            g.train()
            delete_old_checkpoints(model_save_dir, keep=5)
            checkpoint_paths = get_n_recent_checkpoints_paths(checkpoint_dir=model_save_dir, n=2)
            averaged_model, _ = average_checkpoints(checkpoint_paths, load_func=load_net_bigvgan)
            torch.save(averaged_model.state_dict(), os.path.join(model_save_dir, "best.pt"))

        # LOGGING
        log_dict = dict()
        log_dict["Generator Loss"] = round(sum(generator_losses) / len(generator_losses), 3)
        log_dict["Mel Loss"] = round(sum(mel_losses) / len(mel_losses), 3)
        if len(feat_match_losses) > 0:
            log_dict["Feature Matching Loss"] = round(sum(feat_match_losses) / len(feat_match_losses), 3)
        if len(adversarial_losses) > 0:
            log_dict["Adversarial Loss"] = round(sum(adversarial_losses) / len(adversarial_losses), 3)
        if len(discriminator_losses) > 0:
            log_dict["Discriminator Loss"] = round(sum(discriminator_losses) / len(discriminator_losses), 3)
        print("Time elapsed for this run: {} Minutes".format(round((time.time() - start_time) / 60)))
        for key in log_dict:
            print(f"{key}: {log_dict[key]}")
        if use_wandb:
            wandb.log(log_dict, step=step_counter)
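
# Minimal usage sketch. The generator and discriminator classes named below are
# assumptions for illustration, not defined in this file; any modules with
# matching call signatures work, and the dataset is expected to yield
# (wave, melspec) pairs:
#
#     from Architectures.Vocoder.BigVGAN import BigVGAN
#     from Architectures.Vocoder.AvocodoDiscriminators import AvocodoHiFiGANJointDiscriminator
#
#     train_loop(generator=BigVGAN(),
#                discriminator=AvocodoHiFiGANJointDiscriminator(),
#                train_dataset=my_vocoder_dataset,
#                device=torch.device("cuda"),
#                model_save_dir="Models/Vocoder",
#                use_wandb=False)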