import torch

import audiosr.hifigan as hifigan


def get_vocoder_config():
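    """HiFi-GAN hyperparameters for the 16 kHz vocoder (64 mel bins, hop size 160)."""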
    return {
        "resblock": "1",
        "num_gpus": 6,
        "batch_size": 16,
        "learning_rate": 0.0002,
        "adam_b1": 0.8,
        "adam_b2": 0.99,
        "lr_decay": 0.999,
        "seed": 1234,
        "upsample_rates": [5, 4, 2, 2, 2],
        "upsample_kernel_sizes": [16, 16, 8, 4, 4],
        "upsample_initial_channel": 1024,
        "resblock_kernel_sizes": [3, 7, 11],
        "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        "segment_size": 8192,
        "num_mels": 64,
        "num_freq": 1025,
        "n_fft": 1024,
        "hop_size": 160,
        "win_size": 1024,
        "sampling_rate": 16000,
        "fmin": 0,
        "fmax": 8000,
        "fmax_for_loss": None,
        "num_workers": 4,
        "dist_config": {
            "dist_backend": "nccl",
            "dist_url": "tcp://localhost:54321",
            "world_size": 1,
        },
    }


def get_vocoder_config_48k():
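    """HiFi-GAN hyperparameters for the 48 kHz vocoder (256 mel bins, hop size 480)."""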
    return {
        "resblock": "1",
        "num_gpus": 8,
        "batch_size": 128,
        "learning_rate": 0.0001,
        "adam_b1": 0.8,
        "adam_b2": 0.99,
        "lr_decay": 0.999,
        "seed": 1234,
        "upsample_rates": [6, 5, 4, 2, 2],
        "upsample_kernel_sizes": [12, 10, 8, 4, 4],
        "upsample_initial_channel": 1536,
        "resblock_kernel_sizes": [3, 7, 11, 15],
        "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5], [1, 3, 5]],
        "segment_size": 15360,
        "num_mels": 256,
        "n_fft": 2048,
        "hop_size": 480,
        "win_size": 2048,
        "sampling_rate": 48000,
        "fmin": 20,
        "fmax": 24000,
        "fmax_for_loss": None,
        "num_workers": 8,
        "dist_config": {
            "dist_backend": "nccl",
            "dist_url": "tcp://localhost:18273",
            "world_size": 1,
        },
    }


def get_available_checkpoint_keys(model, ckpt):
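    """Return the checkpoint entries whose names and tensor shapes match `model`'s state dict."""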
    state_dict = torch.load(ckpt)["state_dict"]
    current_state_dict = model.state_dict()
    new_state_dict = {}
    for k in state_dict.keys():
        if (
            k in current_state_dict.keys()
            and current_state_dict[k].size() == state_dict[k].size()
        ):
            new_state_dict[k] = state_dict[k]
        else:
            print("==> WARNING: Skipping %s" % k)
    print(
        "%s out of %s keys are matched"
        % (len(new_state_dict.keys()), len(state_dict.keys()))
    )
    return new_state_dict


def get_param_num(model):
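    """Return the total number of parameters in `model`."""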
    num_param = sum(param.numel() for param in model.parameters())
    return num_param


def torch_version_orig_mod_remove(state_dict):
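    """Strip the `_orig_mod.` prefix that `torch.compile` adds to the generator's state-dict keys."""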
    new_state_dict = {}
    new_state_dict["generator"] = {}
    for key in state_dict["generator"].keys():
        if "_orig_mod." in key:
            new_state_dict["generator"][key.replace("_orig_mod.", "")] = state_dict[
                "generator"
            ][key]
        else:
            new_state_dict["generator"][key] = state_dict["generator"][key]
    return new_state_dict


def get_vocoder(config, device, mel_bins):
    # Note: the `config` argument is currently unused; the built-in HiFi-GAN
    # configuration is selected from `mel_bins` instead.
    name = "HiFi-GAN"
    speaker = ""
    if name == "MelGAN":
        if speaker == "LJSpeech":
            vocoder = torch.hub.load(
                "descriptinc/melgan-neurips", "load_melgan", "linda_johnson"
            )
        elif speaker == "universal":
            vocoder = torch.hub.load(
                "descriptinc/melgan-neurips", "load_melgan", "multi_speaker"
            )
        vocoder.mel2wav.eval()
        vocoder.mel2wav.to(device)
    elif name == "HiFi-GAN":
        # 64 mel bins -> 16 kHz configuration, otherwise the 48 kHz one.
        if mel_bins == 64:
            config = hifigan.AttrDict(get_vocoder_config())
        else:
            config = hifigan.AttrDict(get_vocoder_config_48k())
        vocoder = hifigan.Generator_old(config)
        # print("Load hifigan/g_01080000")
        # ckpt = torch.load(os.path.join(ROOT, "hifigan/g_01080000"))
        # ckpt = torch.load(os.path.join(ROOT, "hifigan/g_00660000"))
        # ckpt = torch_version_orig_mod_remove(ckpt)
        # vocoder.load_state_dict(ckpt["generator"])
        vocoder.eval()
        vocoder.remove_weight_norm()
        vocoder.to(device)
    return vocoder


def vocoder_infer(mels, vocoder, lengths=None):
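    """Run the vocoder on a batch of mel spectrograms and return int16 numpy
    waveforms. If `lengths` is given, it is applied as a single sample-count
    cut-off to every waveform in the batch."""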
    with torch.no_grad():
        wavs = vocoder(mels).squeeze(1)

    wavs = (wavs.cpu().numpy() * 32768).astype("int16")

    if lengths is not None:
        wavs = wavs[:, :lengths]

    # wavs = [wav for wav in wavs]

    # for i in range(len(mels)):
    #     if lengths is not None:
    #         wavs[i] = wavs[i][: lengths[i]]

    return wavs
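

if __name__ == "__main__":
    # Usage sketch (not part of the library API): shows the expected call order
    # for get_vocoder / vocoder_infer. The generator built here is randomly
    # initialised because get_vocoder does not load a checkpoint, so the output
    # is noise; the input layout [batch, n_mels, frames] is an assumption based
    # on the standard HiFi-GAN interface.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    vocoder = get_vocoder(None, device, mel_bins=64)  # 16 kHz / 64-mel config
    print("Vocoder parameters:", get_param_num(vocoder))
    dummy_mels = torch.randn(1, 64, 100, device=device)
    wavs = vocoder_infer(dummy_mels, vocoder)  # int16 numpy array, shape [batch, samples]
    print("Waveform shape:", wavs.shape, "dtype:", wavs.dtype)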