from collections import OrderedDict

import torch

import utils
from models import SynthesizerTrn


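def copyStateDict(state_dict):
    """Return a copy of state_dict with a leading DataParallel prefix removed.

    If the first key starts with 'module', the first dot-separated component
    is dropped from every key (e.g. 'module.enc.weight' -> 'enc.weight').
    """
    if list(state_dict.keys())[0].startswith('module'):
        start_idx = 1
    else:
        start_idx = 0
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        # Rejoin with '.', the separator the key was split on, so nested
        # names keep their original form.
        name = '.'.join(k.split('.')[start_idx:])
        new_state_dict[name] = v
    return new_state_dict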
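

# Sketch (hypothetical helper, not used elsewhere in this script): models
# wrapped in torch.nn.DataParallel or DistributedDataParallel save their
# weights under keys prefixed with "module.". copyStateDict decides from the
# first key alone; this variant checks each key and strips exactly one
# leading "module." where present, leaving all other keys untouched.
from collections import OrderedDict


def strip_module_prefix(state_dict):
    prefix = "module."
    stripped = OrderedDict()
    for key, value in state_dict.items():
        # Only a leading "module." is removed; inner occurrences are kept.
        if key.startswith(prefix):
            key = key[len(prefix):]
        stripped[key] = value
    return stripped


# Usage: strip_module_prefix({"module.enc.weight": w}) maps to {"enc.weight": w}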


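def removeOptimizer(config: str, input_model: str, output_model: str):
    """Strip the trained optimizer state from a checkpoint to shrink it.

    The model weights are copied unchanged; the optimizer entry is replaced
    by the state of a freshly constructed AdamW and the iteration counter is
    reset to 0.
    """
    hps = utils.get_hparams_from_file(config)

    # Build the generator only to obtain parameters for a fresh optimizer.
    net_g = SynthesizerTrn(hps.data.filter_length // 2 + 1,
                           hps.train.segment_size // hps.data.hop_length,
                           **hps.model)

    # A newly created optimizer holds no per-parameter state yet, so its
    # state_dict() is small compared with that of a trained one.
    optim_g = torch.optim.AdamW(net_g.parameters(),
                                hps.train.learning_rate,
                                betas=hps.train.betas,
                                eps=hps.train.eps)

    state_dict_g = torch.load(input_model, map_location="cpu")
    new_dict_g = copyStateDict(state_dict_g)

    # Keep only the model weights from the input checkpoint.
    new_dict_g = dict(new_dict_g['model'])

    torch.save(
        {
            'model': new_dict_g,
            'iteration': 0,
            'optimizer': optim_g.state_dict(),
            'learning_rate': 0.0001
        }, output_model)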
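

# Sketch (hypothetical helper, not used elsewhere in this script): the
# dictionary transformation performed by removeOptimizer, written without
# torch so it can be checked on plain dicts. A trained AdamW state holds two
# moment tensors per parameter (exp_avg and exp_avg_sq); it is discarded and
# replaced by the state of a freshly constructed optimizer, while the model
# weights are carried over unchanged.
def build_release_checkpoint(checkpoint, fresh_optimizer_state,
                             learning_rate=0.0001):
    return {
        'model': dict(checkpoint['model']),  # shallow copy of the weights
        'iteration': 0,                      # training counter is reset
        'optimizer': fresh_optimizer_state,  # empty per-parameter state
        'learning_rate': learning_rate,
    }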


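if __name__ == "__main__":
    import argparse
    import os.path

    parser = argparse.ArgumentParser(
        description="Remove the optimizer state from a training checkpoint.")
    parser.add_argument("-c",
                        "--config",
                        type=str,
                        default='configs/config.json',
                        help="path to the training config (JSON)")
    parser.add_argument("-i", "--input", type=str, required=True,
                        help="path to the input checkpoint")
    parser.add_argument("-o", "--output", type=str, default=None,
                        help="output path (default: <input>_release<ext>)")

    args = parser.parse_args()

    output = args.output

    if output is None:
        # Insert "_release" before the extension, e.g. G_1000.pth -> G_1000_release.pth
        filename, ext = os.path.splitext(args.input)
        output = filename + "_release" + ext

    removeOptimizer(args.config, args.input, output)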
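

# Sketch (hypothetical helper, not used elsewhere in this script): the
# default output name computed in the __main__ block, factored out so it can
# be tested on its own. "_release" is inserted before the extension; a path
# without an extension simply gets the suffix appended.
import os.path


def default_release_path(input_path):
    filename, ext = os.path.splitext(input_path)
    return filename + "_release" + ext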