File size: 5,843 Bytes
2493d72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import datetime
import os
import re

import torch
from TTS.speaker_encoder.model import SpeakerEncoder
from TTS.utils.generic_utils import check_argument


def to_camel(text):
    """Turn a snake_case identifier into CamelCase (e.g. ``speaker_encoder`` -> ``SpeakerEncoder``).

    The input is first passed through ``str.capitalize`` (first character
    upper-cased, all others lower-cased). Afterwards every underscore that is
    not the very first character and is followed by an ASCII letter is
    removed, and that letter is upper-cased. Underscores followed by digits,
    other underscores, or nothing are kept.
    """
    def _upper_following_letter(match):
        # group(1) is the single ASCII letter right after the underscore.
        return match.group(1).upper()

    head_capitalized = text.capitalize()
    # (?!^) stops a leading underscore at index 0 from being consumed.
    return re.sub(r'(?!^)_([a-zA-Z])', _upper_following_letter, head_capitalized)


def setup_model(c):
    """Instantiate a ``SpeakerEncoder`` from the ``model`` section of config ``c``.

    ``c.model`` must provide the keys ``input_dim``, ``proj_dim``, ``lstm_dim``
    and ``num_lstm_layers``; they are passed positionally, in that order, to
    the ``SpeakerEncoder`` constructor.

    NOTE(review): the config checker also validates
    ``model.use_lstm_with_projection``, but it is not forwarded here. Confirm
    whether ``SpeakerEncoder`` accepts/needs it.
    """
    model_cfg = c.model
    positional_args = (
        model_cfg['input_dim'],
        model_cfg['proj_dim'],
        model_cfg['lstm_dim'],
        model_cfg['num_lstm_layers'],
    )
    return SpeakerEncoder(*positional_args)


def save_checkpoint(model, optimizer, model_loss, out_path,
                    current_step, epoch):
    """Write a training checkpoint to ``<out_path>/checkpoint_<current_step>.pth.tar``.

    The saved object is a dict with keys ``model`` (``model.state_dict()``),
    ``optimizer`` (``optimizer.state_dict()``, or ``None`` when no optimizer is
    given), ``step``, ``epoch``, ``loss`` and ``date`` (today's date as text,
    e.g. "January 01, 2020"). The target path is printed before saving.
    """
    file_name = 'checkpoint_{}.pth.tar'.format(current_step)
    target_path = os.path.join(out_path, file_name)
    print(" | | > Checkpoint saving : {}".format(target_path))

    model_state = model.state_dict()
    if optimizer is None:
        optimizer_state = None
    else:
        optimizer_state = optimizer.state_dict()
    saved_on = datetime.date.today().strftime("%B %d, %Y")

    checkpoint = dict(
        model=model_state,
        optimizer=optimizer_state,
        step=current_step,
        epoch=epoch,
        loss=model_loss,
        date=saved_on,
    )
    torch.save(checkpoint, target_path)


def save_best_model(model, optimizer, model_loss, best_loss, out_path,
                    current_step):
    """Save ``model`` to ``<out_path>/best_model.pth.tar`` if ``model_loss`` beats ``best_loss``.

    Args:
        model: module whose ``state_dict()`` is saved.
        optimizer: optimizer whose ``state_dict()`` is saved, or ``None``.
            ``None`` is stored as ``None``, matching ``save_checkpoint``.
        model_loss: loss of the current model.
        best_loss: best (lowest) loss seen so far.
        out_path: directory the best-model file is written into.
        current_step: training step stored alongside the weights.

    Returns:
        ``model_loss`` if it was strictly lower than ``best_loss`` (in which
        case the file is written), otherwise ``best_loss`` unchanged.
    """
    if model_loss < best_loss:
        # Accept a missing optimizer instead of crashing on None.state_dict().
        optimizer_state = optimizer.state_dict() if optimizer is not None else None
        state = {
            'model': model.state_dict(),
            'optimizer': optimizer_state,
            'step': current_step,
            'loss': model_loss,
            'date': datetime.date.today().strftime("%B %d, %Y"),
        }
        best_loss = model_loss
        bestmodel_path = os.path.join(out_path, 'best_model.pth.tar')
        print("\n > BEST MODEL ({0:.5f}) : {1:}".format(
            model_loss, bestmodel_path))
        torch.save(state, bestmodel_path)
    return best_loss


def check_config_speaker_encoder(c):
    """Check the config.json file of the speaker encoder.

    Runs ``check_argument`` over every expected key of the config ``c``:
    top-level run info, the ``audio`` section, training parameters,
    checkpoint/output parameters, the ``model`` section, the ``storage``
    section, and every entry of the ``datasets`` list. ``c`` is indexed with
    ``c['audio']``, ``c['model']``, ``c['storage']`` and ``c['datasets']``,
    so it must support item access.

    NOTE(review): failure behaviour (raise vs. print) and the exact meaning
    of ``restricted=True`` / ``alternative=`` are defined by
    ``check_argument``. Presumably ``restricted`` means "key is required";
    confirm there.
    """
    check_argument('run_name', c, restricted=True, val_type=str)
    check_argument('run_description', c, val_type=str)

    # audio processing parameters
    check_argument('audio', c, restricted=True, val_type=dict)
    # NOTE(review): max_val 2056 / 4058 below look like typos for 2048 / 4096;
    # bounds are kept unchanged to avoid rejecting existing configs.
    check_argument('num_mels', c['audio'], restricted=True, val_type=int, min_val=10, max_val=2056)
    check_argument('fft_size', c['audio'], restricted=True, val_type=int, min_val=128, max_val=4058)
    check_argument('sample_rate', c['audio'], restricted=True, val_type=int, min_val=512, max_val=100000)
    check_argument('frame_length_ms', c['audio'], restricted=True, val_type=float, min_val=10, max_val=1000, alternative='win_length')
    check_argument('frame_shift_ms', c['audio'], restricted=True, val_type=float, min_val=1, max_val=1000, alternative='hop_length')
    check_argument('preemphasis', c['audio'], restricted=True, val_type=float, min_val=0, max_val=1)
    check_argument('min_level_db', c['audio'], restricted=True, val_type=int, min_val=-1000, max_val=10)
    check_argument('ref_level_db', c['audio'], restricted=True, val_type=int, min_val=0, max_val=1000)
    check_argument('power', c['audio'], restricted=True, val_type=float, min_val=1, max_val=5)
    check_argument('griffin_lim_iters', c['audio'], restricted=True, val_type=int, min_val=10, max_val=1000)

    # training parameters
    check_argument('loss', c, enum_list=['ge2e', 'angleproto'], restricted=True, val_type=str)
    check_argument('grad_clip', c, restricted=True, val_type=float)
    check_argument('epochs', c, restricted=True, val_type=int, min_val=1)
    check_argument('lr', c, restricted=True, val_type=float, min_val=0)
    check_argument('lr_decay', c, restricted=True, val_type=bool)
    check_argument('warmup_steps', c, restricted=True, val_type=int, min_val=0)
    check_argument('tb_model_param_stats', c, restricted=True, val_type=bool)
    check_argument('num_speakers_in_batch', c, restricted=True, val_type=int)
    check_argument('num_loader_workers', c, restricted=True, val_type=int)
    check_argument('wd', c, restricted=True, val_type=float, min_val=0.0, max_val=1.0)

    # checkpoint and output parameters
    check_argument('steps_plot_stats', c, restricted=True, val_type=int)
    check_argument('checkpoint', c, restricted=True, val_type=bool)
    check_argument('save_step', c, restricted=True, val_type=int)
    check_argument('print_step', c, restricted=True, val_type=int)
    check_argument('output_path', c, restricted=True, val_type=str)

    # model parameters
    check_argument('model', c, restricted=True, val_type=dict)
    check_argument('input_dim', c['model'], restricted=True, val_type=int)
    check_argument('proj_dim', c['model'], restricted=True, val_type=int)
    check_argument('lstm_dim', c['model'], restricted=True, val_type=int)
    check_argument('num_lstm_layers', c['model'], restricted=True, val_type=int)
    check_argument('use_lstm_with_projection', c['model'], restricted=True, val_type=bool)

    # in-memory storage parameters
    check_argument('storage', c, restricted=True, val_type=dict)
    check_argument('sample_from_storage_p', c['storage'], restricted=True, val_type=float, min_val=0.0, max_val=1.0)
    check_argument('storage_size', c['storage'], restricted=True, val_type=int, min_val=1, max_val=100)
    check_argument('additive_noise', c['storage'], restricted=True, val_type=float, min_val=0.0, max_val=1.0)

    # datasets - every entry of the list is checked (not just the first one)
    check_argument('datasets', c, restricted=True, val_type=list)
    for dataset_entry in c['datasets']:
        check_argument('name', dataset_entry, restricted=True, val_type=str)
        check_argument('path', dataset_entry, restricted=True, val_type=str)
        # meta_file_train may be a single file name or a list of them
        check_argument('meta_file_train', dataset_entry, restricted=True, val_type=[str, list])
        check_argument('meta_file_val', dataset_entry, restricted=True, val_type=str)