EnglishToucan / Architectures /ToucanTTS /toucantts_train_loop_arbiter.py
Flux9665's picture
initial commit
6faeba1
raw
history blame
4.67 kB
import torch
from Architectures.ToucanTTS.toucantts_meta_train_loop import train_loop as multi_language_loop
from Architectures.ToucanTTS.toucantts_train_loop import train_loop as mono_language_loop
def train_loop(net,
               datasets,
               train_samplers,
               gpu_count,
               device,
               save_directory,
               steps_per_checkpoint=None,
               path_to_checkpoint=None,
               lr=0.0001,
               resume=False,
               warmup_steps=4000,
               use_wandb=False,
               batch_size=32,
               eval_lang="eng",
               fine_tune=False,
               steps=200000,
               use_less_loss=False,
               freeze_lang_embs=False,
               ):
    """
    Dispatch training of a ToucanTTS model to the appropriate train loop.

    If more than one dataset is given, the multilingual (meta / LAML based) train
    loop is used; with exactly one dataset, the simple monolingual train loop is used.

    Args:
        net: an already initialized ToucanTTS model that should be trained.
        datasets: a list of datasets to train on, one (combined) dataset per language.
            Every dataset within a language should already be a concat dataset of all
            the datasets in that language. For a monolingual model, pass a list with
            only one dataset in it (a single non-list dataset is also accepted and is
            wrapped into a one-element list).
        train_samplers: the sampler(s) for the dataloader(s) (multi-GPU and single-GPU
            use different ones). A single sampler object is wrapped into a one-element
            list, mirroring how ``datasets`` is handled.
        gpu_count: amount of GPUs to use.
        device: the device where this training should run on.
        save_directory: directory where the models and visualizations should be saved.
        steps_per_checkpoint: how many steps should be trained before a checkpoint is
            created. Forwarded to both train loops.
            NOTE(review): an older comment claimed the monolingual loop checkpoints once
            per epoch regardless of this value; confirm against toucantts_train_loop.
        path_to_checkpoint: path to a trained checkpoint to either continue training or
            fine-tune from.
        lr: learning rate of the model.
        resume: whether to automatically load the most recent checkpoint and resume
            training from it.
        warmup_steps: how many steps until the learning rate reaches the specified value
            and starts decreasing again.
        use_wandb: whether to use online experiment tracking with weights and biases.
            Requires prior CLI login.
        batch_size: how many samples to put into one batch. Higher batch size is more
            stable, but requires more VRAM.
        eval_lang: in which language the evaluation sentence is to be plotted
            (forwarded to the loops as ``lang``).
        fine_tune: whether to use the provided checkpoint as basis for fine-tuning.
        steps: how many updates to run until training is completed.
        use_less_loss: whether to use the loss that enforces a structure in the language
            embedding space. Only forwarded to the multilingual loop.
        freeze_lang_embs: whether to use the language embeddings from a checkpoint
            without modifying them, to maintain compatibility with the zero-shot method.
            This treats language embeddings from the given checkpoint as constants.
            Only forwarded to the multilingual loop.

    Raises:
        ValueError: if ``datasets`` is an empty list/tuple.
    """
    # Normalise inputs so both branches can index them uniformly.
    if isinstance(datasets, (list, tuple)):
        datasets = list(datasets)
    else:
        datasets = [datasets]
    if not isinstance(train_samplers, (list, tuple)):
        train_samplers = [train_samplers]

    # Validate before touching global multiprocessing state, so a bad call has no side effects.
    if not datasets:
        raise ValueError("train_loop needs at least one dataset, but an empty collection was given.")

    # force=True overrides any start method that was set earlier in the process.
    torch.multiprocessing.set_start_method('spawn', force=True)

    if len(datasets) > 1:
        multi_language_loop(net=net,
                            datasets=datasets,
                            train_samplers=train_samplers,
                            device=device,
                            save_directory=save_directory,
                            batch_size=batch_size,
                            steps=steps,
                            steps_per_checkpoint=steps_per_checkpoint,
                            lr=lr,
                            lang=eval_lang,
                            path_to_checkpoint=path_to_checkpoint,
                            resume=resume,
                            fine_tune=fine_tune,
                            warmup_steps=warmup_steps,
                            use_wandb=use_wandb,
                            gpu_count=gpu_count,
                            use_less_loss=use_less_loss,
                            freeze_lang_embs=freeze_lang_embs
                            )
    else:
        mono_language_loop(net=net,
                           train_dataset=datasets[0],
                           train_sampler=train_samplers[0],
                           device=device,
                           save_directory=save_directory,
                           batch_size=batch_size,
                           lang=eval_lang,
                           lr=lr,
                           warmup_steps=warmup_steps,
                           path_to_checkpoint=path_to_checkpoint,
                           fine_tune=fine_tune,
                           resume=resume,
                           steps=steps,
                           use_wandb=use_wandb,
                           gpu_count=gpu_count,
                           steps_per_checkpoint=steps_per_checkpoint
                           )