import os
import time

import torch
import torch.multiprocessing
import wandb
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm

from Preprocessing.AudioPreprocessor import AudioPreprocessor
from Preprocessing.EnCodecAudioPreprocessor import CodecAudioPreprocessor
from Utility.WarmupScheduler import ToucanWarmupScheduler as WarmupScheduler
from Utility.utils import delete_old_checkpoints
from Utility.utils import get_most_recent_checkpoint
from Utility.utils import plot_progress_spec_toucantts
from run_weight_averaging import average_checkpoints
from run_weight_averaging import get_n_recent_checkpoints_paths
from run_weight_averaging import load_net_toucan
from run_weight_averaging import save_model_for_use
def collate_and_pad(batch):
    # text, text_len, speech, speech_len, durations, energy, pitch, utterance condition, language_id, speaker embedding
    return (pad_sequence([datapoint[0] for datapoint in batch], batch_first=True).float(),
            torch.stack([datapoint[1] for datapoint in batch]).squeeze(1),
            [datapoint[2] for datapoint in batch],
            torch.stack([datapoint[3] for datapoint in batch]).squeeze(1),
            pad_sequence([datapoint[4] for datapoint in batch], batch_first=True),
            pad_sequence([datapoint[5] for datapoint in batch], batch_first=True),
            pad_sequence([datapoint[6] for datapoint in batch], batch_first=True),
            None,
            torch.stack([datapoint[8] for datapoint in batch]),
            torch.stack([datapoint[9] for datapoint in batch]))
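
# Note on the collate function above: variable-length entries (text, durations, energy, pitch) are
# zero-padded to the longest item in the batch, fixed-size entries (lengths, language id, speaker
# embedding) are stacked, the codec indexes of the speech (index 2) are passed through as a plain
# list because they are only decoded to spectrograms on the GPU inside the training loop, and the
# utterance condition slot (index 7) is not used in this loop, so None is returned in its place.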


def train_loop(net,
               train_dataset,
               device,
               save_directory,
               batch_size,
               lang,
               lr,
               warmup_steps,
               path_to_checkpoint,
               fine_tune,
               resume,
               steps,
               use_wandb,
               train_sampler,
               gpu_count,
               steps_per_checkpoint
               ):
    """
    see train loop arbiter for explanations of the arguments
    """
    net = net.to(device)

    if gpu_count > 1:
        rank = int(os.environ["LOCAL_RANK"])
    else:
        rank = 0

    if steps_per_checkpoint is None:
        steps_per_checkpoint = len(train_dataset) // batch_size
    if steps < warmup_steps * 5:
        warmup_steps = steps // 5
        print(f"too much warmup given the amount of steps, reducing warmup to {warmup_steps} steps")
    torch.multiprocessing.set_sharing_strategy('file_system')
    batch_sampler_train = torch.utils.data.BatchSampler(train_sampler, batch_size, drop_last=True)
    train_loader = DataLoader(dataset=train_dataset,
                              batch_sampler=batch_sampler_train,
                              num_workers=0,  # has to be 0, otherwise copies of the dataset are created, which is not feasible for large scale trainings. This is not optimal for small trainings, but necessary for scalability.
                              pin_memory=True,
                              prefetch_factor=None,
                              collate_fn=collate_and_pad)
    ap = CodecAudioPreprocessor(input_sr=-1, device=device)
    spec_extractor = AudioPreprocessor(input_sr=16000, output_sr=16000, device=device)
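    # the dataset stores compressed codec indexes instead of spectrograms; during training these are
    # decoded back to waveforms with the codec model (ap) and then converted to 16 kHz mel
    # spectrograms (spec_extractor) on the fly, sample by sample, inside the batch loop below.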
    step_counter = 0

    if isinstance(net, torch.nn.parallel.DistributedDataParallel):
        model = net.module
    else:
        model = net

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = WarmupScheduler(optimizer, peak_lr=lr, warmup_steps=warmup_steps, max_steps=steps)
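    # the scheduler is stepped once per batch further below; the exact schedule shape is defined by
    # ToucanWarmupScheduler, which (as the name and arguments suggest) ramps the learning rate up to
    # peak_lr over warmup_steps and then anneals it over the remaining steps until max_steps.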
    epoch = 0
    if resume:
        path_to_checkpoint = get_most_recent_checkpoint(checkpoint_dir=save_directory)
    if path_to_checkpoint is not None:
        check_dict = torch.load(path_to_checkpoint, map_location=device)
        model.load_state_dict(check_dict["model"])
        if not fine_tune:
            optimizer.load_state_dict(check_dict["optimizer"])
            scheduler.load_state_dict(check_dict["scheduler"])
            step_counter = check_dict["step_counter"]
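    # resume continues from the most recent checkpoint in save_directory including optimizer,
    # scheduler and step counter; fine_tune instead loads only the model weights from the given
    # checkpoint and starts with a fresh optimizer, scheduler and step counter.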

    start_time = time.time()
    regression_losses_total = list()
    stochastic_losses_total = list()
    duration_losses_total = list()
    pitch_losses_total = list()
    energy_losses_total = list()
    while True:
        net.train()
        epoch += 1
        for batch in tqdm(train_loader):
            text_tensors = batch[0].to(device)
            text_lengths = batch[1].squeeze().to(device)
            speech_indexes = batch[2]
            speech_lengths = batch[3].squeeze().to(device)
            gold_durations = batch[4].to(device)
            gold_pitch = batch[6].to(device)  # mind the switched order
            gold_energy = batch[5].to(device)  # mind the switched order
            lang_ids = batch[8].squeeze(1).to(device)

            speech_batch = list()  # I wish this could be done in the collate function or in the getitem, but using DL models in multiprocessing on very large datasets causes just way too many issues.
            for speech_sample in speech_indexes:
                with torch.inference_mode():
                    wave = ap.indexes_to_audio(speech_sample.int().to(device)).detach()
                    mel = spec_extractor.audio_to_mel_spec_tensor(wave, explicit_sampling_rate=16000).transpose(0, 1).detach().cpu()
                gold_speech_sample = mel.clone()
                speech_batch.append(gold_speech_sample)
            gold_speech = pad_sequence(speech_batch, batch_first=True).to(device)

            run_stochastic = (step_counter > warmup_steps) or fine_tune
            train_loss = 0.0
            utterance_embedding = batch[9].to(device)
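            # the forward pass returns five losses: the regression loss is the spectrogram
            # reconstruction term, the stochastic loss comes from the flow-based part of the model
            # (only added once run_stochastic is True, i.e. after warmup or when fine-tuning),
            # and the remaining three supervise the duration, pitch and energy predictors.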
            regression_loss, stochastic_loss, duration_loss, pitch_loss, energy_loss = net(
                text_tensors=text_tensors,
                text_lengths=text_lengths,
                gold_speech=gold_speech,
                speech_lengths=speech_lengths,
                gold_durations=gold_durations,
                gold_pitch=gold_pitch,
                gold_energy=gold_energy,
                utterance_embedding=utterance_embedding,
                lang_ids=lang_ids,
                return_feats=False,
                run_stochastic=run_stochastic
            )

            if torch.isnan(regression_loss) or torch.isnan(duration_loss) or torch.isnan(pitch_loss) or torch.isnan(energy_loss):
                print("One of the losses turned to NaN! Skipping this batch ...")
                continue

            train_loss = train_loss + duration_loss
            train_loss = train_loss + pitch_loss
            train_loss = train_loss + energy_loss
            train_loss = train_loss + regression_loss

            regression_losses_total.append(regression_loss.item())
            duration_losses_total.append(duration_loss.item())
            pitch_losses_total.append(pitch_loss.item())
            energy_losses_total.append(energy_loss.item())

            if stochastic_loss is not None:
                if torch.isnan(stochastic_loss):
                    print("Flow loss turned to NaN! Skipping this batch ...")
                    continue
                stochastic_losses_total.append(stochastic_loss.item())
                train_loss = train_loss + stochastic_loss
            else:
                stochastic_losses_total.append(0)
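            # all loss terms are summed with equal weight into the total training loss; a zero is
            # logged for the stochastic loss while it is not being computed, so the running averages
            # stay comparable in length across checkpoint intervals.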
            optimizer.zero_grad()
            if isinstance(train_loss, float):
                print("There is no loss for this step! Skipping ...")
                continue
            if gpu_count > 1:
                torch.distributed.barrier()
            train_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0, error_if_nonfinite=False)
            optimizer.step()
            scheduler.step()
            step_counter += 1
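            # gradients are clipped to a max norm of 1.0 before the optimizer step, and the learning
            # rate scheduler is advanced once per batch, so warmup_steps and steps are both counted
            # in optimizer updates rather than in epochs.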
            if step_counter % steps_per_checkpoint == 0:
                # evaluation interval is happening
                if rank == 0:
                    net.eval()
                    default_embedding = train_dataset[0][9].to(device)
                    torch.save({
                        "model"       : model.state_dict(),
                        "optimizer"   : optimizer.state_dict(),
                        "step_counter": step_counter,
                        "scheduler"   : scheduler.state_dict(),
                        "default_emb" : default_embedding,
                        "config"      : model.config
                    }, os.path.join(save_directory, "checkpoint_{}.pt".format(step_counter)))
                    delete_old_checkpoints(save_directory, keep=5)

                    print(f"\nEpoch: {epoch}")
                    print(f"Time elapsed: {round((time.time() - start_time) / 60)} Minutes")
                    print("Reconstruction Loss: {}".format(round(sum(regression_losses_total) / len(regression_losses_total), 3)))
                    print(f"Steps: {step_counter}\n")

                    if use_wandb:
                        wandb.log({
                            "regression_loss": round(sum(regression_losses_total) / len(regression_losses_total), 5),
                            "stochastic_loss": round(sum(stochastic_losses_total) / len(stochastic_losses_total), 5),
                            "duration_loss"  : round(sum(duration_losses_total) / len(duration_losses_total), 5),
                            "pitch_loss"     : round(sum(pitch_losses_total) / len(pitch_losses_total), 5),
                            "energy_loss"    : round(sum(energy_losses_total) / len(energy_losses_total), 5),
                            "learning_rate"  : optimizer.param_groups[0]['lr']
                        }, step=step_counter)

                    regression_losses_total = list()
                    stochastic_losses_total = list()
                    duration_losses_total = list()
                    pitch_losses_total = list()
                    energy_losses_total = list()

                    path_to_most_recent_plot = plot_progress_spec_toucantts(model,
                                                                            device,
                                                                            save_dir=save_directory,
                                                                            step=step_counter,
                                                                            lang=lang,
                                                                            default_emb=default_embedding,
                                                                            run_stochastic=run_stochastic)
                    if use_wandb:
                        wandb.log({
                            "progress_plot": wandb.Image(path_to_most_recent_plot)
                        }, step=step_counter)

                    checkpoint_paths = get_n_recent_checkpoints_paths(checkpoint_dir=save_directory, n=1)
                    averaged_model, default_embed = average_checkpoints(checkpoint_paths, load_func=load_net_toucan)
                    save_model_for_use(model=averaged_model, default_embed=default_embed, name=os.path.join(save_directory, "best.pt"))
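                    # with n=1 the "averaging" degenerates to exporting just the most recent
                    # checkpoint: it is loaded via load_net_toucan and re-saved as best.pt,
                    # presumably in the stripped-down format that save_model_for_use produces for
                    # inference.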

                    if step_counter > steps:
                        return  # DONE

                    net.train()

        print("\n\n\nEPOCH COMPLETE\n\n\n")
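
# Hypothetical usage sketch (not part of this module): the model class, dataset object, sampler and
# all hyperparameter values named below are illustrative assumptions only; in the actual codebase
# this function is invoked by the train loop arbiter mentioned in the docstring.
#
#   net = ToucanTTS()
#   train_sampler = torch.utils.data.RandomSampler(train_set)
#   train_loop(net=net,
#              train_dataset=train_set,
#              device=torch.device("cuda"),
#              save_directory="Models/MyRun",
#              batch_size=32,
#              lang="eng",
#              lr=1e-4,
#              warmup_steps=4000,
#              path_to_checkpoint=None,
#              fine_tune=False,
#              resume=False,
#              steps=200000,
#              use_wandb=False,
#              train_sampler=train_sampler,
#              gpu_count=1,
#              steps_per_checkpoint=None)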