File size: 3,601 Bytes
9e275b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
"""
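Average the weights of the most recent training checkpoints into a single
model (saved as ``best.pt``). Background on the technique:
https://alexander-stasiuk.medium.com/pytorch-weights-averaging-e2c0fa611a0c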
"""

import os

import torch

from Architectures.ToucanTTS.InferenceToucanTTS import ToucanTTS
from Architectures.Vocoder.HiFiGAN_Generator import HiFiGAN
from Utility.storage_config import MODELS_DIR


def load_net_toucan(path):
    """Build a ToucanTTS model from a checkpoint file mapped onto the CPU.

    The checkpoint must provide the keys ``"model"`` (the weights),
    ``"config"`` and ``"default_emb"``.

    Returns:
        ``(model, default_emb)``: the constructed ToucanTTS instance and the
        default embedding stored in the checkpoint.
    """
    # NOTE(review): recent torch releases default torch.load to weights_only=True;
    # if "config" holds a custom Python object this load may need adjusting — confirm.
    cpu = torch.device("cpu")
    checkpoint = torch.load(path, map_location=cpu)
    weights, config = checkpoint["model"], checkpoint["config"]
    return ToucanTTS(weights=weights, config=config), checkpoint["default_emb"]


def load_net_bigvgan(path):
    """Build a vocoder from the ``"generator"`` weights of a checkpoint file.

    Despite the function name, the network constructed here is ``HiFiGAN``.
    The checkpoint is mapped onto the CPU.

    Returns:
        ``(model, None)``: the second slot mirrors the ``(model, default_embed)``
        pair that ``average_checkpoints`` unpacks from its ``load_func``.
    """
    generator_weights = torch.load(path, map_location=torch.device("cpu"))["generator"]
    return HiFiGAN(weights=generator_weights), None


def get_n_recent_checkpoints_paths(checkpoint_dir, n=5):
    """Return the paths of the ``n`` most recent checkpoints in a directory.

    A checkpoint is a file named ``checkpoint_<step>.pt`` where ``<step>`` is a
    plain ASCII integer; files are ranked by that integer, highest first.
    Other files, including names such as ``checkpoint_best.pt``, are ignored.

    Args:
        checkpoint_dir: directory to scan (not recursive).
        n: maximum number of paths to return; fewer are returned when fewer
            checkpoints exist.

    Returns:
        A list of existing file paths, most recent first, or ``None`` when the
        directory contains no matching checkpoint.
    """
    print("selecting checkpoints...")
    prefix, suffix = "checkpoint_", ".pt"
    steps_and_names = list()
    for file_name in os.listdir(checkpoint_dir):
        if not (file_name.startswith(prefix) and file_name.endswith(suffix)):
            continue
        step = file_name[len(prefix):-len(suffix)]
        # Strict check instead of try/int(): int() would raise ValueError on
        # "best", and would also accept "1_2" or " 7", which are not step numbers.
        if step.isascii() and step.isdigit():
            # Keep the real file name so e.g. "checkpoint_007.pt" maps to a path that exists.
            steps_and_names.append((int(step), file_name))
    if not steps_and_names:
        return None
    # Sorting on (step, name) makes the order deterministic regardless of listdir order.
    steps_and_names.sort(reverse=True)
    return [os.path.join(checkpoint_dir, file_name) for _, file_name in steps_and_names[:n]]


def average_checkpoints(list_of_checkpoint_paths, load_func):
    """Average the parameters of several checkpoints into a single model.

    Each checkpoint is loaded with ``load_func(path=...)``, which must return a
    ``(model, default_embed)`` pair. Parameters are summed into a separate
    running accumulator and divided by the number of distinct checkpoints. So
    at most two loaded models plus the accumulator are alive at any time,
    rather than every checkpoint at once. The loaded tensors are never modified
    in place, except for the final returned model.

    Only entries of ``model.named_parameters()`` are averaged. Buffers and the
    returned ``default_embed`` come from the last checkpoint loaded.

    Args:
        list_of_checkpoint_paths: paths to average; a path listed more than once
            is loaded and counted once.
        load_func: callable ``load_func(path=...) -> (model, default_embed)``.

    Returns:
        ``(model, default_embed)`` or ``None``. The model is the last loaded
        model, with its parameters overwritten by the average and switched to
        eval mode. ``None`` is returned when no paths are given.
    """
    if list_of_checkpoint_paths is None or len(list_of_checkpoint_paths) == 0:
        return None
    # dict.fromkeys de-duplicates while keeping first-occurrence order.
    unique_paths = list(dict.fromkeys(list_of_checkpoint_paths))

    summed = None
    model = None
    default_embed = None
    for path_to_checkpoint in unique_paths:
        print("loading model {}".format(path_to_checkpoint))
        model, default_embed = load_func(path=path_to_checkpoint)
        params = dict(model.named_parameters())
        if summed is None:
            # Clone so the accumulator never aliases a loaded model's tensors.
            summed = {name: param.detach().clone() for name, param in params.items()}
        else:
            for name, total in summed.items():
                total.add_(params[name].detach())

    print("averaging...")
    checkpoint_amount = len(unique_paths)
    with torch.no_grad():
        for name, param in model.named_parameters():
            param.copy_(summed[name] / checkpoint_amount)
    model.eval()
    return model, default_embed


def save_model_for_use(model, name="", default_embed=None, dict_name="model"):
    """Serialize a model for inference with ``torch.save``.

    The written dict holds three entries:
      - the model's ``state_dict()`` under ``dict_name``;
      - ``default_embed`` under ``"default_emb"``;
      - the model's ``config`` attribute under ``"config"``, or ``None`` if the
        model has no such attribute.

    Args:
        model: the ``torch.nn.Module`` to save.
        name: destination file path.
        default_embed: stored unchanged under ``"default_emb"``.
        dict_name: key for the state dict, e.g. ``"generator"`` (the key
            ``load_net_bigvgan`` reads).
    """
    print("saving model...")
    # Not every model type is guaranteed to carry a ``config`` attribute; store
    # None for those instead of failing with AttributeError.
    config = getattr(model, "config", None)
    torch.save({dict_name: model.state_dict(), "default_emb": default_embed, "config": config}, name)
    print("...done!")


def make_best_in_all(n=3):
    """Write an averaged ``best.pt`` into every ToucanTTS run directory.

    Scans ``MODELS_DIR`` for sub-directories whose name contains
    ``"ToucanTTS"``. In each one, the ``n`` most recent
    ``checkpoint_<step>.pt`` files are averaged and the result is saved as
    ``best.pt`` in the same directory. Directories without checkpoints are
    skipped.

    Args:
        n: how many of the most recent checkpoints to average (default 3).
    """
    for model_dir in os.listdir(MODELS_DIR):
        model_dir_path = os.path.join(MODELS_DIR, model_dir)
        if not os.path.isdir(model_dir_path) or "ToucanTTS" not in model_dir:
            continue
        checkpoint_paths = get_n_recent_checkpoints_paths(checkpoint_dir=model_dir_path, n=n)
        # Also covers an empty list (e.g. n=0), for which average_checkpoints
        # returns None and the tuple-unpacking below would fail.
        if not checkpoint_paths:
            continue
        averaged_model, default_embed = average_checkpoints(checkpoint_paths, load_func=load_net_toucan)
        save_model_for_use(model=averaged_model, default_embed=default_embed, name=os.path.join(model_dir_path, "best.pt"))


def count_parameters(net):
    """Return how many trainable scalar values ``net`` holds.

    Parameters whose ``requires_grad`` flag is off are not counted.
    """
    total = 0
    for parameter in net.parameters():
        if parameter.requires_grad:
            total += parameter.numel()
    return total


if __name__ == "__main__":
    # Script entry point: regenerate best.pt for every ToucanTTS run under MODELS_DIR.
    make_best_in_all()