import matplotlib
matplotlib.use('Agg')  # headless backend; must be set before any pyplot import

import glob
import importlib
import random

import numpy as np
import torch
import torch.utils.data
from resemblyzer import VoiceEncoder
from tqdm import tqdm

import utils
from data_gen.tts.data_gen_utils import build_phone_encoder, build_word_encoder
from data_gen.tts.emotion import inference as EmotionEncoder
from data_gen.tts.emotion.inference import embed_utterance as Embed_utterance
from data_gen.tts.emotion.inference import preprocess_wav
from inference.base_tts_infer import load_data_preprocessor
from tasks.tts.dataset_utils import BaseTTSDataset
from utils.hparams import hparams
from utils.indexed_datasets import IndexedDataset
from utils.pitch_utils import norm_interp_f0, f0_to_coarse



class GenerSpeech_dataset(BaseTTSDataset):
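    """Dataset for GenerSpeech: each item provides phoneme tokens, a mel
    spectrogram, pitch features, and utterance-level speaker and emotion
    embeddings."""
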
    def __init__(self, prefix, shuffle=False, test_items=None, test_sizes=None, data_dir=None):
        super().__init__(prefix, shuffle, test_items, test_sizes, data_dir)
        self.f0_mean, self.f0_std = hparams.get('f0_mean', None), hparams.get('f0_std', None)
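        # The 'valid' split is sampled on the fly: 300 random items are drawn
        # from the binarized training set.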
        if prefix == 'valid':
            indexed_ds = IndexedDataset(f'{self.data_dir}/train')
            sizes = np.load(f'{self.data_dir}/train_lengths.npy')
            index = list(range(len(indexed_ds)))
            random.shuffle(index)
            index = index[:300]
            self.sizes = sizes[index]
            self.indexed_ds = []
            for i in index:
                self.indexed_ds.append(indexed_ds[i])
            self.avail_idxs = list(range(len(self.sizes)))
            if hparams['min_frames'] > 0:
                self.avail_idxs = [x for x in self.avail_idxs if self.sizes[x] >= hparams['min_frames']]
            self.sizes = [self.sizes[i] for i in self.avail_idxs]

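        # For 'test' with a user-supplied input dir, bypass the binarized data
        # and build items directly from raw audio and transcripts.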
        if prefix == 'test' and hparams['test_input_dir'] != '':
            self.preprocessor, self.preprocess_args = load_data_preprocessor()
            self.indexed_ds, self.sizes = self.load_test_inputs(hparams['test_input_dir'])
            self.avail_idxs = list(range(len(self.sizes)))


    def load_test_inputs(self, test_input_dir):
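        """Build test items from raw .wav/.mp3 files in test_input_dir.

        Each audio file is expected to have a matching .txt transcript and
        .TextGrid alignment beside it.
        """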
        inp_wav_paths = sorted(glob.glob(f'{test_input_dir}/*.wav') + glob.glob(f'{test_input_dir}/*.mp3'))
        # Resolve the binarizer class from its dotted module path.
        binarizer_cls = hparams.get("binarizer_cls", 'data_gen.tts.base_binarizer.BaseBinarizer')
        pkg, cls_name = binarizer_cls.rsplit(".", 1)
        binarizer_cls = getattr(importlib.import_module(pkg), cls_name)

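        # Text-side encoders (phone/word) plus a Resemblyzer speaker encoder
        # (on GPU) for utterance-level speaker embeddings.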
        phone_encoder = build_phone_encoder(hparams['binary_data_dir'])
        word_encoder = build_word_encoder(hparams['binary_data_dir'])
        voice_encoder = VoiceEncoder().cuda()

        encoder = [phone_encoder, word_encoder]
        sizes = []
        items = []
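        # Load the pretrained emotion encoder once before processing files.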
        EmotionEncoder.load_model(hparams['emotion_encoder_path'])
        preprocessor, preprocess_args = self.preprocessor, self.preprocess_args

        for wav_fn in tqdm(inp_wav_paths):
            item_name = wav_fn[len(test_input_dir) + 1:].replace("/", "_")
            spk_id = emotion = 0
            item2tgfn = wav_fn.replace('.wav', '.TextGrid') # prepare textgrid alignment
            txtpath = wav_fn.replace('.wav', '.txt')  # prepare text
            with open(txtpath, 'r') as f:
                text_raw = f.readlines()
            ph, txt = preprocessor.txt_to_ph(preprocessor.txt_processor, text_raw[0], preprocess_args)

            item = binarizer_cls.process_item(item_name, ph, txt, item2tgfn, wav_fn, spk_id, emotion, encoder, hparams['binarization_args'])
            item['emo_embed'] = Embed_utterance(preprocess_wav(item['wav_fn']))
            item['spk_embed'] = voice_encoder.embed_utterance(item['wav'])
            items.append(item)
            sizes.append(item['len'])
        return items, sizes

    def _get_item(self, index):
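        """Resolve an index through avail_idxs and lazily open the on-disk dataset."""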
        if hasattr(self, 'avail_idxs') and self.avail_idxs is not None:
            index = self.avail_idxs[index]
        if self.indexed_ds is None:
            self.indexed_ds = IndexedDataset(f'{self.data_dir}/{self.prefix}')
        return self.indexed_ds[index]

    def __getitem__(self, index):
        hparams = self.hparams
        item = self._get_item(index)
        assert len(item['mel']) == self.sizes[index], (len(item['mel']), self.sizes[index])
        max_frames = hparams['max_frames']
        spec = torch.Tensor(item['mel'])[:max_frames]
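        # Round the frame budget down to a multiple of frames_multiple so the
        # mel length divides evenly (e.g. for frame-level downsampling).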
        max_frames = spec.shape[0] // hparams['frames_multiple'] * hparams['frames_multiple']
        spec = spec[:max_frames]
        phone = torch.LongTensor(item['phone'][:hparams['max_input_tokens']])
        sample = {
            "id": index,
            "item_name": item['item_name'],
            "text": item['txt'],
            "txt_token": phone,
            "mel": spec,
            "mel_nonpadding": spec.abs().sum(-1) > 0,
        }
        spec = sample['mel']
        T = spec.shape[0]
        sample['mel2ph'] = mel2ph = torch.LongTensor(item['mel2ph'])[:T] if 'mel2ph' in item else None
        if hparams['use_pitch_embed']:
            assert 'f0' in item
            if hparams.get('normalize_pitch', False):
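                # Re-normalize voiced frames to the global f0 statistics, then
                # clip to a plausible speech range (60-500 Hz).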
                f0 = item["f0"]
                if (f0 > 0).any() and f0[f0 > 0].std() > 0:
                    f0[f0 > 0] = (f0[f0 > 0] - f0[f0 > 0].mean()) / f0[f0 > 0].std() * hparams['f0_std'] + \
                                 hparams['f0_mean']
                    f0[f0 > 0] = f0[f0 > 0].clip(min=60, max=500)
                pitch = f0_to_coarse(f0)
                pitch = torch.LongTensor(pitch[:max_frames])
            else:
                pitch = torch.LongTensor(item.get("pitch"))[:max_frames] if "pitch" in item else None
            f0, uv = norm_interp_f0(item["f0"][:max_frames], hparams)
            uv = torch.FloatTensor(uv)
            f0 = torch.FloatTensor(f0)
        else:
            f0 = uv = torch.zeros_like(mel2ph)
            pitch = None
        sample["f0"], sample["uv"], sample["pitch"] = f0, uv, pitch
        sample["spk_embed"] = torch.Tensor(item['spk_embed'])
        sample["emotion"] = item['emotion']
        sample["emo_embed"] = torch.Tensor(item['emo_embed'])

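        # Optional word-level features (word tokens plus mel->word and
        # phone->word alignments).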
        if hparams.get('use_word', False):
            sample["ph_words"] = item["ph_words"]
            sample["word_tokens"] = torch.LongTensor(item["word_tokens"])
            sample["mel2word"] = torch.LongTensor(item.get("mel2word"))[:max_frames]
            sample["ph2word"] = torch.LongTensor(item['ph2word'][:hparams['max_input_tokens']])
        return sample

    def collater(self, samples):
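        """Collate per-utterance samples into a padded batch."""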
        if len(samples) == 0:
            return {}
        hparams = self.hparams
        id = torch.LongTensor([s['id'] for s in samples])
        item_names = [s['item_name'] for s in samples]
        text = [s['text'] for s in samples]
        txt_tokens = utils.collate_1d([s['txt_token'] for s in samples], 0)
        mels = utils.collate_2d([s['mel'] for s in samples], 0.0)
        txt_lengths = torch.LongTensor([s['txt_token'].numel() for s in samples])
        mel_lengths = torch.LongTensor([s['mel'].shape[0] for s in samples])

        batch = {
            'id': id,
            'item_name': item_names,
            'nsamples': len(samples),
            'text': text,
            'txt_tokens': txt_tokens,
            'txt_lengths': txt_lengths,
            'mels': mels,
            'mel_lengths': mel_lengths,
        }

        f0 = utils.collate_1d([s['f0'] for s in samples], 0.0)
        pitch = utils.collate_1d([s['pitch'] for s in samples]) if samples[0]['pitch'] is not None else None
        uv = utils.collate_1d([s['uv'] for s in samples])
        mel2ph = utils.collate_1d([s['mel2ph'] for s in samples], 0) if samples[0]['mel2ph'] is not None else None
        batch.update({
            'mel2ph': mel2ph,
            'pitch': pitch,
            'f0': f0,
            'uv': uv,
        })
        spk_embed = torch.stack([s['spk_embed'] for s in samples])
        batch['spk_embed'] = spk_embed
        emo_embed = torch.stack([s['emo_embed'] for s in samples])
        batch['emo_embed'] = emo_embed

        if hparams.get('use_word', False):
            ph_words = [s['ph_words'] for s in samples]
            batch['ph_words'] = ph_words
            word_tokens = utils.collate_1d([s['word_tokens'] for s in samples], 0)
            batch['word_tokens'] = word_tokens
            mel2word = utils.collate_1d([s['mel2word'] for s in samples], 0)
            batch['mel2word'] = mel2word
            ph2word = utils.collate_1d([s['ph2word'] for s in samples], 0)
            batch['ph2word'] = ph2word
        return batch
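

# A minimal usage sketch (not part of the original file): it assumes hparams
# have already been populated via utils.hparams.set_hparams() with a
# GenerSpeech config, and that binarized data exists under
# hparams['binary_data_dir'].
if __name__ == '__main__':
    from utils.hparams import set_hparams
    set_hparams()  # reads --config / --exp_name from the command line
    dataset = GenerSpeech_dataset(prefix='valid', shuffle=False)
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=4, collate_fn=dataset.collater, num_workers=0)
    batch = next(iter(loader))
    print(batch['txt_tokens'].shape, batch['mels'].shape)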