File size: 5,250 Bytes
8b0ab3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
from __future__ import absolute_import, division, print_function, unicode_literals
from typing import Tuple
import sys
from argparse import ArgumentParser

import torch
import numpy as np
import os
import json
import torch

sys.path.append(os.path.join(os.path.dirname(__file__), "../src/glow_tts"))

from scipy.io.wavfile import write
from hifi.env import AttrDict
from hifi.models import Generator


from text import text_to_sequence
import commons
import models
import utils


def check_directory(dir):
    """Terminate the program if ``dir`` does not exist on disk.

    Args:
        dir: Path to check with ``os.path.exists``.

    Raises:
        SystemExit: With an error message naming ``dir`` when the path is
            missing.
    """
    if os.path.exists(dir):
        return
    message = "Error: {} directory does not exist".format(dir)
    sys.exit(message)


class TextToMel:
    """Glow-TTS wrapper that converts text into a generated mel output.

    On construction the model directory is validated with
    ``check_directory`` and the latest checkpoint found in it is loaded.

    Args:
        glow_model_dir: Directory containing the Glow-TTS hyper-parameters
            and checkpoints.
        device: Device the model and its inputs are placed on. Anything
            accepted by ``.to()`` works (e.g. ``"cpu"``, ``"cuda"``,
            ``"cuda:0"``).
    """

    def __init__(self, glow_model_dir, device="cuda"):
        self.glow_model_dir = glow_model_dir
        check_directory(self.glow_model_dir)
        self.device = device
        self.hps, self.glow_tts_model = self.load_glow_tts()

    def load_glow_tts(self):
        """Build the Glow-TTS model and load the latest checkpoint into it.

        Returns:
            A ``(hps, model)`` tuple: the hyper-parameters read from
            ``glow_model_dir`` and the model in eval mode.
        """
        hps = utils.get_hparams_from_dir(self.glow_model_dir)
        checkpoint_path = utils.latest_checkpoint_path(self.glow_model_dir)
        symbols = list(hps.data.punc) + list(hps.data.chars)
        glow_tts_model = models.FlowGenerator(
            # One extra symbol id is reserved when "add_blank" is enabled.
            len(symbols) + getattr(hps.data, "add_blank", False),
            out_channels=hps.data.n_mel_channels,
            **hps.model
        )

        # Honour whatever device was requested (not only the literal "cuda"),
        # so that e.g. "cuda:0" does not silently leave the model on the CPU.
        glow_tts_model.to(self.device)

        utils.load_checkpoint(checkpoint_path, glow_tts_model)
        glow_tts_model.decoder.store_inverse()
        glow_tts_model.eval()

        return hps, glow_tts_model

    def generate_mel(self, text, noise_scale=0.667, length_scale=1.0):
        """Run Glow-TTS on ``text`` and return the generated output tensor.

        Args:
            text: Input text to synthesize.
            noise_scale: Forwarded to the model as ``noise_scale``.
            length_scale: Forwarded to the model as ``length_scale``.

        Returns:
            The first element of the first item returned by the model
            (presumably the mel-spectrogram -- TODO confirm shape). The
            tensor stays on ``self.device``; it is not converted to numpy.
        """
        symbols = list(self.hps.data.punc) + list(self.hps.data.chars)
        cleaner = self.hps.data.text_cleaners
        if getattr(self.hps.data, "add_blank", False):
            text_norm = text_to_sequence(text, symbols, cleaner)
            # The blank id is len(symbols), matching the extra symbol
            # reserved in load_glow_tts().
            text_norm = commons.intersperse(text_norm, len(symbols))
        else:
            # If not using "add_blank" option during training, adding spaces
            # at the beginning and the end of utterance improves quality.
            text = " " + text.strip() + " "
            text_norm = text_to_sequence(text, symbols, cleaner)

        # Add a leading batch axis of size 1.
        sequence = np.array(text_norm)[None, :]

        x_tst = torch.from_numpy(sequence).long().to(self.device)
        x_tst_lengths = torch.tensor([x_tst.shape[1]], device=self.device)

        with torch.no_grad():
            # Only the first generated tensor is used; the unpacking shape is
            # kept so the expected structure of the model output is unchanged.
            (y_gen_tst, *_), *_, (_attn_gen, *_) = self.glow_tts_model(
                x_tst,
                x_tst_lengths,
                gen=True,
                noise_scale=noise_scale,
                length_scale=length_scale,
            )

        # Drop the input tensors before asking the allocator to release its
        # cached blocks.
        del x_tst
        del x_tst_lengths
        torch.cuda.empty_cache()
        return y_gen_tst


class MelToWav:
    """HiFi-GAN vocoder wrapper that converts a mel tensor into int16 audio.

    On construction the model directory is validated with
    ``check_directory``, ``config.json`` is read from it and the latest
    ``g_*`` checkpoint is loaded into the generator.

    Args:
        hifi_model_dir: Directory containing ``config.json`` and the
            generator checkpoints.
        device: Device the generator and its inputs are placed on.
    """

    # Scale from the generator's float output to 16-bit PCM
    # (presumably the output is roughly within [-1, 1] -- TODO confirm).
    MAX_WAV_VALUE = 32768.0

    def __init__(self, hifi_model_dir, device="cuda"):
        self.hifi_model_dir = hifi_model_dir
        check_directory(self.hifi_model_dir)
        self.device = device
        self.h, self.hifi_gan_generator = self.load_hifi_gan()

    def load_hifi_gan(self):
        """Build the HiFi-GAN generator and load its latest checkpoint.

        Returns:
            A ``(h, generator)`` tuple: the configuration loaded from
            ``config.json`` and the generator in eval mode with weight
            normalization removed.

        Raises:
            FileNotFoundError: If the resolved checkpoint path is not a file.
        """
        checkpoint_path = utils.latest_checkpoint_path(self.hifi_model_dir, regex="g_*")
        config_file = os.path.join(self.hifi_model_dir, "config.json")
        # Context manager so the config file handle is always closed.
        with open(config_file) as config_handle:
            json_config = json.load(config_handle)
        h = AttrDict(json_config)
        torch.manual_seed(h.seed)

        generator = Generator(h).to(self.device)

        # Explicit check instead of `assert`, which is stripped under -O.
        if not os.path.isfile(checkpoint_path):
            raise FileNotFoundError(
                "HiFi-GAN checkpoint not found: {}".format(checkpoint_path)
            )
        print("Loading '{}'".format(checkpoint_path))
        state_dict_g = torch.load(checkpoint_path, map_location=self.device)
        print("Complete.")

        generator.load_state_dict(state_dict_g["generator"])

        generator.eval()
        generator.remove_weight_norm()

        return h, generator

    def generate_wav(self, mel):
        """Vocode ``mel`` into 16-bit PCM samples.

        Args:
            mel: Tensor passed to the generator after being moved to
                ``self.device``.

        Returns:
            A ``(audio, sampling_rate)`` tuple where ``audio`` is an int16
            numpy array and ``sampling_rate`` comes from the loaded config.
        """
        # Inference only: avoid recording an autograd graph.
        with torch.no_grad():
            y_g_hat = self.hifi_gan_generator(mel.to(self.device))  # passing through vocoder

        audio = y_g_hat.squeeze() * self.MAX_WAV_VALUE
        # Keep samples inside the int16 range; without this a sample at +1.0
        # scales to 32768 and overflows in the cast below.
        audio = torch.clamp(
            audio, min=-self.MAX_WAV_VALUE, max=self.MAX_WAV_VALUE - 1.0
        )
        audio = audio.cpu().detach().numpy().astype("int16")

        # Drop tensor references before asking the allocator to release its
        # cached blocks.
        del y_g_hat
        del mel
        torch.cuda.empty_cache()
        return audio, self.h.sampling_rate


if __name__ == "__main__":
    # Command-line entry point: synthesize --text with the Glow-TTS model in
    # --model and the HiFi-GAN vocoder in --gan, then write the result to --wav.
    cli = ArgumentParser()
    option_specs = (
        ("-m", "--model", {"required": True, "type": str}),
        ("-g", "--gan", {"required": True, "type": str}),
        ("-d", "--device", {"type": str, "default": "cpu"}),
        ("-t", "--text", {"type": str, "required": True}),
        ("-w", "--wav", {"type": str, "required": True}),
    )
    for short_flag, long_flag, options in option_specs:
        cli.add_argument(short_flag, long_flag, **options)
    opts = cli.parse_args()

    synthesizer = TextToMel(glow_model_dir=opts.model, device=opts.device)
    vocoder = MelToWav(hifi_model_dir=opts.gan, device=opts.device)

    spectrogram = synthesizer.generate_mel(opts.text)
    samples, sample_rate = vocoder.generate_wav(spectrogram)

    write(filename=opts.wav, rate=sample_rate, data=samples)