import torch

from Architectures.ToucanTTS.toucantts_meta_train_loop import train_loop as multi_language_loop
from Architectures.ToucanTTS.toucantts_train_loop import train_loop as mono_language_loop


def train_loop(net,  # an already initialized ToucanTTS model that should be trained.
               datasets,
               # a list of datasets to train on. All datasets within a language should already be combined into a single concat dataset,
               # so every list entry here is the (combined) dataset for one language. For a monolingual model, pass a list
               # with only one dataset in it. This makes the arbiter call the simple single-language train loop rather than the more complex
               # LAML-based one.
               train_samplers,  # the sampler(s) for the dataloader(s) (multi-GPU and single-GPU setups use different ones)
               gpu_count,  # number of GPUs to use
               device,  # the device on which this training should run.
               save_directory,  # directory where the models and visualizations should be saved.
               steps_per_checkpoint=None,  # how many steps to train before a checkpoint is created. This is only relevant for the multilingual case;
               # the monolingual case will do this once per epoch, regardless of the steps.
               path_to_checkpoint=None,  # path to a trained checkpoint to either continue training or fine-tune from.
               lr=0.0001,  # learning rate of the model.
               resume=False,  # whether to automatically load the most recent checkpoint and resume training from it.
               warmup_steps=4000,  # how many steps until the learning rate reaches the specified value and starts decreasing again.
               use_wandb=False,  # whether to use online experiment tracking with weights and biases. Requires prior CLI login.
               batch_size=32,  # how many samples to put into one batch. Higher batch size is more stable, but requires more VRAM.
               eval_lang="eng",  # in which language the evaluation sentence is to be plotted.
               fine_tune=False,  # whether to use the provided checkpoint as basis for fine-tuning.
               steps=200000,  # how many updates to run until training is completed
               use_less_loss=False,  # whether to use the loss that enforces a structure in the language embedding space
               freeze_lang_embs=False,  # whether to use the language embeddings from a checkpoint without modifying them, to maintain compatibility with the zero-shot method. This treats language embeddings from the given checkpoint as constants.
               ):
    torch.multiprocessing.set_start_method('spawn', force=True)
    if not isinstance(datasets, list):  # wrap a single dataset so that the dispatch below can rely on a list
        datasets = [datasets]
    if len(datasets) > 1:
        multi_language_loop(net=net,
                            datasets=datasets,
                            train_samplers=train_samplers,
                            device=device,
                            save_directory=save_directory,
                            batch_size=batch_size,
                            steps=steps,
                            steps_per_checkpoint=steps_per_checkpoint,
                            lr=lr,
                            lang=eval_lang,
                            path_to_checkpoint=path_to_checkpoint,
                            resume=resume,
                            fine_tune=fine_tune,
                            warmup_steps=warmup_steps,
                            use_wandb=use_wandb,
                            gpu_count=gpu_count,
                            use_less_loss=use_less_loss,
                            freeze_lang_embs=freeze_lang_embs
                            )
    else:
        mono_language_loop(net=net,
                           train_dataset=datasets[0],
                           train_sampler=train_samplers[0],
                           device=device,
                           save_directory=save_directory,
                           batch_size=batch_size,
                           lang=eval_lang,
                           lr=lr,
                           warmup_steps=warmup_steps,
                           path_to_checkpoint=path_to_checkpoint,
                           fine_tune=fine_tune,
                           resume=resume,
                           steps=steps,
                           use_wandb=use_wandb,
                           gpu_count=gpu_count,
                           steps_per_checkpoint=steps_per_checkpoint
                           )