namelessai committed on
Commit
13ddedd
·
verified ·
1 Parent(s): a496765

Upload 3 files

Browse files
Files changed (3) hide show
  1. utilities/model.py +167 -0
  2. utilities/sampler.py +588 -0
  3. utilities/tools.py +541 -0
utilities/model.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ import audiosr.hifigan as hifigan
4
+
5
+
6
def get_vocoder_config():
    """Return the HiFi-GAN hyper-parameters used for the 64-mel, 16 kHz vocoder.

    A new dictionary is created on every call, so the caller may modify the
    result without affecting later calls.

    Returns:
        dict: generator/training settings keyed by hyper-parameter name.
    """
    # Distributed-training settings are kept in their own nested mapping.
    dist_config = dict(
        dist_backend="nccl",
        dist_url="tcp://localhost:54321",
        world_size=1,
    )
    return dict(
        resblock="1",
        num_gpus=6,
        batch_size=16,
        learning_rate=0.0002,
        adam_b1=0.8,
        adam_b2=0.99,
        lr_decay=0.999,
        seed=1234,
        # The product of the upsample rates (160) equals hop_size below.
        upsample_rates=[5, 4, 2, 2, 2],
        upsample_kernel_sizes=[16, 16, 8, 4, 4],
        upsample_initial_channel=1024,
        resblock_kernel_sizes=[3, 7, 11],
        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        segment_size=8192,
        num_mels=64,
        num_freq=1025,
        n_fft=1024,
        hop_size=160,
        win_size=1024,
        sampling_rate=16000,
        fmin=0,
        fmax=8000,
        fmax_for_loss=None,
        num_workers=4,
        dist_config=dist_config,
    )
38
+
39
+
40
def get_vocoder_config_48k():
    """Return the HiFi-GAN hyper-parameters used for the 256-mel, 48 kHz vocoder.

    A new dictionary is created on every call, so the caller may modify the
    result without affecting later calls.

    Returns:
        dict: generator/training settings keyed by hyper-parameter name.
    """
    # Distributed-training settings are kept in their own nested mapping.
    dist_config = dict(
        dist_backend="nccl",
        dist_url="tcp://localhost:18273",
        world_size=1,
    )
    return dict(
        resblock="1",
        num_gpus=8,
        batch_size=128,
        learning_rate=0.0001,
        adam_b1=0.8,
        adam_b2=0.99,
        lr_decay=0.999,
        seed=1234,
        # The product of the upsample rates (480) equals hop_size below.
        upsample_rates=[6, 5, 4, 2, 2],
        upsample_kernel_sizes=[12, 10, 8, 4, 4],
        upsample_initial_channel=1536,
        resblock_kernel_sizes=[3, 7, 11, 15],
        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5], [1, 3, 5]],
        segment_size=15360,
        num_mels=256,
        n_fft=2048,
        hop_size=480,
        win_size=2048,
        sampling_rate=48000,
        fmin=20,
        fmax=24000,
        fmax_for_loss=None,
        num_workers=8,
        dist_config=dist_config,
    )
71
+
72
+
73
def get_available_checkpoint_keys(model, ckpt):
    """Select the checkpoint entries that can be loaded into ``model``.

    An entry is kept when its key exists in ``model.state_dict()`` and its
    tensor has the same size as the model's tensor; every other entry is
    reported on stdout and dropped.

    Args:
        model: module whose ``state_dict()`` defines the accepted keys/sizes.
        ckpt: path (or file-like object) of a checkpoint whose top-level
            ``"state_dict"`` entry holds the weights.

    Returns:
        dict: the subset of the checkpoint's state dict compatible with
        ``model``. Tensors are returned on the CPU.
    """
    # Loading onto the CPU lets checkpoints saved from a GPU be read on hosts
    # without CUDA; load_state_dict copies onto the model's device anyway.
    state_dict = torch.load(ckpt, map_location="cpu")["state_dict"]
    current_state_dict = model.state_dict()
    new_state_dict = {}
    for key, value in state_dict.items():
        target = current_state_dict.get(key)
        if target is not None and target.size() == value.size():
            new_state_dict[key] = value
        else:
            print("==> WARNING: Skipping %s" % key)
    print(
        "%s out of %s keys are matched" % (len(new_state_dict), len(state_dict))
    )
    return new_state_dict
90
+
91
+
92
def get_param_num(model):
    """Count the scalar parameters of ``model``.

    Args:
        model: object exposing ``parameters()`` that yields tensors.

    Returns:
        int: sum of ``numel()`` over every parameter tensor (0 if none).
    """
    total = 0
    for tensor in model.parameters():
        total += tensor.numel()
    return total
95
+
96
+
97
def torch_version_orig_mod_remove(state_dict):
    """Strip the ``_orig_mod.`` marker from the generator's parameter names.

    Args:
        state_dict: mapping with a ``"generator"`` entry that maps parameter
            names to values.

    Returns:
        dict: a new ``{"generator": {...}}`` mapping in which every occurrence
        of ``"_orig_mod."`` has been removed from the keys. Values are reused
        as-is and all other top-level entries of the input are dropped.
    """
    generator_weights = state_dict["generator"]
    # str.replace leaves names without the marker untouched, so one expression
    # covers both the prefixed and the plain keys.
    cleaned = {
        name.replace("_orig_mod.", ""): generator_weights[name]
        for name in generator_weights.keys()
    }
    return {"generator": cleaned}
108
+
109
+
110
def get_vocoder(config, device, mel_bins):
    """Build a HiFi-GAN generator in eval mode on ``device``.

    The previous implementation carried a MelGAN branch that could never run
    (the vocoder name was hard-coded to "HiFi-GAN") and two copies of the same
    HiFi-GAN set-up; both are collapsed here without changing the result.

    Args:
        config: unused. The value passed in was always replaced by the
            built-in hyper-parameters; the parameter is kept so existing
            callers keep working.
        device: target passed to ``vocoder.to``.
        mel_bins: 64 selects ``get_vocoder_config()``; any other value selects
            ``get_vocoder_config_48k()``.

    Returns:
        The ``hifigan.Generator_old`` instance with weight norm removed.
    """
    if mel_bins == 64:
        hparams = get_vocoder_config()
    else:
        hparams = get_vocoder_config_48k()

    vocoder = hifigan.Generator_old(hifigan.AttrDict(hparams))
    # NOTE(review): no checkpoint is loaded in this function (the loading code
    # had been commented out) -- presumably the caller restores the generator
    # weights afterwards; confirm against the call sites.
    vocoder.eval()
    vocoder.remove_weight_norm()
    vocoder.to(device)
    return vocoder
150
+
151
+
152
def vocoder_infer(mels, vocoder, lengths=None):
    """Run ``vocoder`` on ``mels`` and return 16-bit PCM samples.

    Args:
        mels: input passed straight to ``vocoder``.
        vocoder: callable returning a tensor whose dimension 1 is squeezed
            away (assumes a (batch, 1, samples) output -- TODO confirm).
        lengths: optional upper bound applied to the last axis of every
            waveform (``wavs[:, :lengths]``).

    Returns:
        numpy.ndarray: int16 array of waveforms.
    """
    with torch.no_grad():
        wavs = vocoder(mels).squeeze(1)

    scaled = wavs.cpu().numpy() * 32768
    # A sample of exactly 1.0 scales to 32768, which does not fit in int16 and
    # would wrap to a negative value; saturate to the valid range first.
    wavs = scaled.clip(-32768, 32767).astype("int16")

    if lengths is not None:
        wavs = wavs[:, :lengths]

    return wavs
utilities/sampler.py ADDED
@@ -0,0 +1,588 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Iterator, List, Optional, Union
2
+ from collections import Counter
3
+ import logging
4
+ from operator import itemgetter
5
+ import random
6
+
7
+ import numpy as np
8
+
9
+ from torch.utils.data import DistributedSampler
10
+ from torch.utils.data.sampler import Sampler
11
+
12
+ LOGGER = logging.getLogger(__name__)
13
+
14
+ from torch.utils.data import Dataset, Sampler
15
+
16
+
17
class DatasetFromSampler(Dataset):
    """Expose the indices produced by a ``Sampler`` as an indexable dataset.

    The sampler is consumed lazily: it is turned into a list the first time
    an element is requested and that list is reused afterwards.

    Args:
        sampler: PyTorch sampler
    """

    def __init__(self, sampler: Sampler):
        """Store the sampler; nothing is drawn from it yet."""
        self.sampler = sampler
        # Filled on first __getitem__ call.
        self.sampler_list = None

    def __getitem__(self, index: int):
        """Return the ``index``-th value drawn from the sampler.

        Args:
            index: position in the materialised sampler output

        Returns:
            Single element by index
        """
        cached = self.sampler_list
        if cached is None:
            cached = self.sampler_list = list(self.sampler)
        return cached[index]

    def __len__(self) -> int:
        """
        Returns:
            int: ``len()`` of the wrapped sampler
        """
        return len(self.sampler)
45
+
46
+
47
class BalanceClassSampler(Sampler):
    """Draw the same number of samples from every class of an unbalanced set.

    Args:
        labels: list of class label for each elem in the dataset
        mode: how many samples to draw per class.
            ``"downsampling"`` uses the size of the smallest class,
            ``"upsampling"`` the size of the largest class, and an ``int``
            is used as the per-class count directly. Classes smaller than
            the requested count are sampled with replacement.

    Python API examples:

    .. code-block:: python

        from torch.utils.data import DataLoader

        train_labels = [0, 0, 0, 1, 1]
        train_sampler = BalanceClassSampler(train_labels, mode="upsampling")
        loader = DataLoader(train_data, sampler=train_sampler, batch_size=32)
    """

    def __init__(self, labels: List[int], mode: Union[str, int] = "downsampling"):
        """Sampler initialisation."""
        super().__init__(labels)

        labels = np.array(labels)
        unique_labels = set(labels)
        class_sizes = {lbl: (labels == lbl).sum() for lbl in unique_labels}

        # Dataset positions grouped by class label.
        positions = np.arange(len(labels))
        self.lbl2idx = {
            lbl: positions[labels == lbl].tolist() for lbl in unique_labels
        }

        if isinstance(mode, str):
            assert mode in ["downsampling", "upsampling"]

        if isinstance(mode, int):
            per_class = mode
        elif mode == "upsampling":
            per_class = max(class_sizes.values())
        else:
            per_class = min(class_sizes.values())

        self.labels = labels
        self.samples_per_class = per_class
        self.length = self.samples_per_class * len(unique_labels)

    def __iter__(self) -> Iterator[int]:
        """
        Returns:
            iterator of indices of stratified sample
        """
        picked = []
        for label in sorted(self.lbl2idx):
            pool = self.lbl2idx[label]
            # Replacement is only needed when the class is too small.
            with_replacement = self.samples_per_class > len(pool)
            drawn = np.random.choice(
                pool, self.samples_per_class, replace=with_replacement
            )
            picked += drawn.tolist()
        assert len(picked) == self.length
        np.random.shuffle(picked)

        return iter(picked)

    def __len__(self) -> int:
        """
        Returns:
            length of result sample
        """
        return self.length
144
+
145
+
146
class BatchBalanceClassSampler(Sampler):
    """
    This kind of sampler can be used for both metric learning and classification task.

    BatchSampler with the given strategy for the C unique classes dataset:
    - Selection `num_classes` of C classes for each batch
    - Selection `num_samples` instances for each class in the batch
    The epoch ends after `num_batches`.
    So, the batch size is `num_classes` * `num_samples`.

    One of the purposes of this sampler is to be used for
    forming triplets and pos/neg pairs inside the batch.
    To guarantee existence of these pairs in the batch,
    `num_classes` and `num_samples` should be > 1. (1)

    This type of sampling can be found in the classical paper of Person Re-Id,
    where P (`num_classes`) equals 32 and K (`num_samples`) equals 4:
    `In Defense of the Triplet Loss for Person Re-Identification`_.

    Args:
        labels: list of class labels for each elem in the dataset
        num_classes: number of classes in a batch, should be > 1
        num_samples: number of instances of each class in a batch, should be > 1
        num_batches: number of batches in epoch
            (default = len(labels) // (num_classes * num_samples))

    .. _In Defense of the Triplet Loss for Person Re-Identification:
        https://arxiv.org/abs/1703.07737

    Python API examples:

    .. code-block:: python

        from torch.utils.data import DataLoader

        train_sampler = BatchBalanceClassSampler(
            train_labels, num_classes=10, num_samples=4)
        # Every item yielded by the sampler is a whole batch of indices,
        # so it is passed as ``batch_sampler``.
        loader = DataLoader(train_data, batch_sampler=train_sampler)
    """

    def __init__(
        self,
        labels: Union[List[int], np.ndarray],
        num_classes: int,
        num_samples: int,
        num_batches: Optional[int] = None,
    ):
        """Sampler initialisation."""
        super().__init__(labels)
        classes = set(labels)

        assert isinstance(num_classes, int) and isinstance(num_samples, int)
        assert (1 < num_classes <= len(classes)) and (1 < num_samples)
        assert all(
            n > 1 for n in Counter(labels).values()
        ), "Each class shoud contain at least 2 instances to fit (1)"

        labels = np.array(labels)
        # Unique class labels as plain Python values, sampled from per batch.
        self._labels = list(set(labels.tolist()))
        self._num_classes = num_classes
        self._num_samples = num_samples
        self._batch_size = self._num_classes * self._num_samples
        self._num_batches = num_batches or len(labels) // self._batch_size
        # Dataset positions grouped by class label.
        self.lbl2idx = {
            label: np.arange(len(labels))[labels == label].tolist()
            for label in set(labels)
        }

    @property
    def batch_size(self) -> int:
        """
        Returns:
            this value should be used in DataLoader as batch size
        """
        return self._batch_size

    @property
    def batches_in_epoch(self) -> int:
        """
        Returns:
            number of batches in an epoch
        """
        return self._num_batches

    def __len__(self) -> int:
        """
        Returns:
            number of batches in an epoch (each yielded item is one batch)
        """
        return self._num_batches

    def __iter__(self) -> Iterator[List[int]]:
        """
        Returns:
            iterator over batches; each batch is a list of dataset indices
            holding `num_samples` indices for each of `num_classes` classes
        """
        batches = []
        for _ in range(self._num_batches):
            batch_indices = []
            # self._labels holds unique values and random.sample draws without
            # replacement, so the chosen classes are always distinct.
            classes_for_batch = random.sample(self._labels, self._num_classes)
            for cls_id in classes_for_batch:
                pool = self.lbl2idx[cls_id]
                # Replacement is only needed when the class is too small.
                replace_flag = self._num_samples > len(pool)
                batch_indices += np.random.choice(
                    pool, self._num_samples, replace=replace_flag
                ).tolist()
            batches.append(batch_indices)
        return iter(batches)
286
+
287
+
288
class DynamicBalanceClassSampler(Sampler):
    """
    Sampler for classification tasks with significant class imbalance.

    Sampling starts from the original class distribution and moves towards a
    uniform one (as with downsampling) as epochs go by.

    Let :math:`D_i = \\#C_i / \\#C_{min}`, where :math:`\\#C_i` is the size of
    class i and :math:`\\#C_{min}` the size of the rarest class, and let
    :math:`g(n_{epoch}) = \\lambda^{n_{epoch}}` be an exponential scheduler.
    For each epoch the ratio used is :math:`D_i^{g(n_{epoch})}` (capped by
    ``max_d``), and class i contributes that ratio times the minimum class
    size, truncated to an integer, samples.

    Notes:
        The per-class counts are recomputed, and the epoch counter advanced,
        every time the sampler is iterated. Towards the end of training an
        epoch holds roughly ``min_size_class * n_classes`` examples.

    Examples:

        >>> import torch
        >>> import numpy as np
        >>> from torch.utils import data

        >>> features = torch.Tensor(np.random.random((200, 100)))
        >>> labels = np.random.randint(0, 4, size=(200,))
        >>> sampler = DynamicBalanceClassSampler(labels)
        >>> labels = torch.LongTensor(labels)
        >>> dataset = data.TensorDataset(features, labels)
        >>> loader = data.DataLoader(dataset, batch_size=8, sampler=sampler)

        >>> for batch in loader:
        >>>     b_features, b_labels = batch

    Sampler was inspired by https://arxiv.org/abs/1901.06783
    """

    def __init__(
        self,
        labels: List[Union[int, str]],
        exp_lambda: float = 0.9,
        start_epoch: int = 0,
        max_d: Optional[int] = None,
        mode: Union[str, int] = "downsampling",
        ignore_warning: bool = False,
    ):
        """
        Args:
            labels: list of labels for each elem in the dataset
            exp_lambda: base of the exponential schedule, must be in (0, 1)
            start_epoch: start epoch number, can be useful for multi-stage
                experiments
            max_d: if not None, upper limit for the ratio between a class
                and the rarest class
            mode: ``"downsampling"`` keeps the size of the rarest class as the
                per-class base count; an ``int`` replaces that base count
            ignore_warning: ignore warning about min class size
        """
        assert isinstance(start_epoch, int)
        assert 0 < exp_lambda < 1, "exp_lambda must be in (0, 1)"
        super().__init__(labels)
        self.exp_lambda = exp_lambda
        self.max_d = np.inf if max_d is None else max_d
        self.epoch = start_epoch

        labels = np.array(labels)
        class_counts = Counter(labels)
        self.min_class_size = min(class_counts.values())

        if self.min_class_size < 100 and not ignore_warning:
            final_epoch_size = self.min_class_size * len(class_counts)
            LOGGER.warning(
                f"the smallest class contains only {self.min_class_size}"
                f" examples. At the end of training, epochs will contain"
                f" only {final_epoch_size} examples"
            )

        # Ratios are taken against the true smallest class, before ``mode``
        # may override ``min_class_size`` below.
        self.original_d = {
            label: count / self.min_class_size
            for label, count in class_counts.items()
        }
        positions = np.arange(len(labels))
        self.label2idxes = {
            label: positions[labels == label].tolist() for label in set(labels)
        }

        if isinstance(mode, int):
            self.min_class_size = mode
        else:
            assert mode == "downsampling"

        self.labels = labels
        self._update()

    def _update(self) -> None:
        """Recompute per-class sample counts and advance the epoch counter."""
        exponent = self._exp_scheduler()
        per_class = {}
        for label, ratio in self.original_d.items():
            capped = min(ratio**exponent, self.max_d)
            per_class[label] = int(capped * self.min_class_size)
        self.samples_per_classes = per_class
        self.length = np.sum(list(per_class.values()))
        self.epoch += 1

    def _exp_scheduler(self) -> float:
        """Return ``exp_lambda`` raised to the current epoch."""
        return self.exp_lambda**self.epoch

    def __iter__(self) -> Iterator[int]:
        """
        Returns:
            iterator of indices of stratified sample
        """
        picked = []
        for label in sorted(self.label2idxes):
            pool = self.label2idxes[label]
            wanted = self.samples_per_classes[label]
            # Replacement is only needed when the class is too small.
            drawn = np.random.choice(pool, wanted, replace=wanted > len(pool))
            picked += drawn.tolist()
        assert len(picked) == self.length
        np.random.shuffle(picked)
        # Prepare the counts for the next epoch before handing out this one.
        self._update()
        return iter(picked)

    def __len__(self) -> int:
        """
        Returns:
            length of result sample
        """
        return self.length
428
+
429
+
430
class MiniEpochSampler(Sampler):
    """
    Sampler that walks through the dataset in consecutive "mini epochs" of
    ``mini_epoch_len`` indices; every ``__iter__`` call yields the next one.

    Args:
        data_len: Size of the dataset
        mini_epoch_len: Num samples from the dataset used in one
            mini epoch.
        drop_last: If ``True``, the trailing mini epoch is skipped when it
            would hold fewer than ``mini_epoch_len`` samples
        shuffle: one of ``"per_mini_epoch"``, ``"per_epoch"``, or ``None``.
            The sampler will shuffle indices
            > "per_mini_epoch" -- every mini epoch (every ``__iter__`` call)
            > "per_epoch" -- every real epoch
            > None -- don't shuffle

    Note:
        ``len(sampler)`` always reports ``mini_epoch_len``, including for a
        shorter trailing mini epoch.

    Example:
        >>> MiniEpochSampler(len(dataset), mini_epoch_len=100)
        >>> MiniEpochSampler(len(dataset), mini_epoch_len=100, drop_last=True)
        >>> MiniEpochSampler(len(dataset), mini_epoch_len=100,
        >>>     shuffle="per_epoch")
    """

    def __init__(
        self,
        data_len: int,
        mini_epoch_len: int,
        drop_last: bool = False,
        shuffle: Optional[str] = None,
    ):
        """Sampler initialisation.

        Raises:
            ValueError: if ``shuffle`` is not one of the supported options.
        """
        super().__init__(None)

        self.data_len = int(data_len)
        self.mini_epoch_len = int(mini_epoch_len)

        # All bookkeeping uses the integer-normalised sizes so that the step
        # count and the remainder test agree with the index array below.
        self.steps = self.data_len // self.mini_epoch_len
        self.state_i = 0

        has_reminder = self.data_len - self.steps * self.mini_epoch_len > 0
        if self.steps == 0:
            self.divider = 1
        elif has_reminder and not drop_last:
            self.divider = self.steps + 1
        else:
            self.divider = self.steps

        self._indices = np.arange(self.data_len)
        self.indices = self._indices
        self.end_pointer = max(self.data_len, self.mini_epoch_len)

        if not (shuffle is None or shuffle in ["per_mini_epoch", "per_epoch"]):
            raise ValueError(
                "Shuffle must be one of ['per_mini_epoch', 'per_epoch']. "
                + f"Got {shuffle}"
            )
        self.shuffle_type = shuffle

    def shuffle(self) -> None:
        """Shuffle sampler indices according to ``shuffle_type``."""
        if self.shuffle_type == "per_mini_epoch" or (
            self.shuffle_type == "per_epoch" and self.state_i == 0
        ):
            if self.data_len >= self.mini_epoch_len:
                # ``indices`` aliases ``_indices``; the shuffle is in place.
                self.indices = self._indices
                np.random.shuffle(self.indices)
            else:
                # Dataset smaller than one mini epoch: oversample it.
                self.indices = np.random.choice(
                    self._indices, self.mini_epoch_len, replace=True
                )

    def __iter__(self) -> Iterator[int]:
        """Iterate over the next mini epoch.

        Returns:
            python iterator
        """
        # Wrap around once every mini epoch of the real epoch was served.
        self.state_i = self.state_i % self.divider
        self.shuffle()

        start = self.state_i * self.mini_epoch_len
        stop = (
            self.end_pointer
            if (self.state_i == self.steps)
            else (self.state_i + 1) * self.mini_epoch_len
        )
        indices = self.indices[start:stop].tolist()

        self.state_i += 1
        return iter(indices)

    def __len__(self) -> int:
        """
        Returns:
            int: length of the mini-epoch
        """
        return self.mini_epoch_len
527
+
528
+
529
class DistributedSamplerWrapper(DistributedSampler):
    """
    Wrapper over `Sampler` for distributed training.
    Allows you to use any sampler in distributed mode.

    It is especially useful in conjunction with
    `torch.nn.parallel.DistributedDataParallel`. In such case, each
    process can pass a DistributedSamplerWrapper instance as a DataLoader
    sampler, and load a subset of subsampled data of the original dataset
    that is exclusive to it.

    .. note::
        Sampler is assumed to be of constant size.
    """

    def __init__(
        self,
        sampler,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
    ):
        """

        Args:
            sampler: Sampler used for subsampling
            num_replicas (int, optional): Number of processes participating in
                distributed training
            rank (int, optional): Rank of the current process
                within ``num_replicas``
            shuffle (bool, optional): If true (default),
                sampler will shuffle the indices
        """
        super().__init__(
            DatasetFromSampler(sampler),
            num_replicas=num_replicas,
            rank=rank,
            shuffle=shuffle,
        )
        self.sampler = sampler

    def __iter__(self) -> Iterator[int]:
        """Iterate over sampler.

        Returns:
            python iterator over this replica's share of the wrapped
            sampler's output
        """
        # Re-wrap the sampler so that it is drawn from again on every pass.
        self.dataset = DatasetFromSampler(self.sampler)
        indexes_of_indexes = super().__iter__()
        subsampler_indexes = self.dataset
        # A list is built eagerly instead of using operator.itemgetter, which
        # returns a bare item for a single index and rejects zero indices, so
        # shards of size 0 or 1 are handled as well.
        return iter([subsampler_indexes[i] for i in indexes_of_indexes])
580
+
581
+
582
+ __all__ = [
583
+ "BalanceClassSampler",
584
+ "BatchBalanceClassSampler",
585
+ "DistributedSamplerWrapper",
586
+ "DynamicBalanceClassSampler",
587
+ "MiniEpochSampler",
588
+ ]
utilities/tools.py ADDED
@@ -0,0 +1,541 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Author: Haohe Liu
2
+ # Email: [email protected]
3
+ # Date: 11 Feb 2023
4
+
5
+ import os
6
+ import json
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import numpy as np
11
+ import matplotlib
12
+ from scipy.io import wavfile
13
+ from matplotlib import pyplot as plt
14
+
15
+
16
+ matplotlib.use("Agg")
17
+
18
+ import hashlib
19
+ import os
20
+
21
+ import requests
22
+ from tqdm import tqdm
23
+
24
# Public bucket that hosts every downloadable artefact listed below.
_SPECVQGAN_BASE_URL = (
    "https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/"
)

# Artefact name -> file name used for the local copy (and in the bucket).
CKPT_MAP = {
    "vggishish_lpaps": "vggishish16.pt",
    "vggishish_mean_std_melspec_10s_22050hz": "train_means_stds_melspec_10s_22050hz.txt",
    "melception": "melception-21-05-10T09-28-40.pt",
}

# Artefact name -> download URL; the remote file name equals the local one.
URL_MAP = {
    name: _SPECVQGAN_BASE_URL + filename for name, filename in CKPT_MAP.items()
}

# Artefact name -> expected MD5 hex digest of the downloaded file.
MD5_MAP = {
    "vggishish_lpaps": "197040c524a07ccacf7715d7080a80bd",
    "vggishish_mean_std_melspec_10s_22050hz": "f449c6fd0e248936c16f6d22492bb625",
    "melception": "a71a41041e945b457c7d3d814bbcf72d",
}

# Default torch device: CUDA when available, CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
43
+
44
+
45
def load_json(fname):
    """Parse the JSON file at ``fname`` and return its content.

    Args:
        fname: path of the file to read (opened in text mode).

    Returns:
        The decoded JSON document.
    """
    with open(fname, "r") as handle:
        return json.load(handle)
49
+
50
+
51
def read_json(dataset_json_file):
    """Return the ``"data"`` entry of a JSON dataset description.

    Args:
        dataset_json_file: path of a JSON file whose top level is a mapping
            containing a ``"data"`` key.

    Returns:
        The value stored under ``"data"``.
    """
    with open(dataset_json_file, "r") as handle:
        return json.load(handle)["data"]
55
+
56
+
57
def copy_test_subset_data(metadata, testset_copy_target_path):
    """Mirror the audio files listed in ``metadata`` into a target folder.

    The folder is created when missing. If it already holds exactly
    ``len(metadata)`` entries the copy is assumed to be complete and nothing
    is done; otherwise the folder is emptied and every file is copied again.

    Args:
        metadata: sequence of mappings; each one provides the source path
            under the ``"wav"`` key.
        testset_copy_target_path: directory receiving the copies.
    """
    # Local import: shutil is only needed by this function.
    import shutil

    os.makedirs(testset_copy_target_path, exist_ok=True)
    if len(os.listdir(testset_copy_target_path)) == len(metadata):
        return

    # Stale or partial content: start again from an empty folder.
    for entry in os.listdir(testset_copy_target_path):
        try:
            os.remove(os.path.join(testset_copy_target_path, entry))
        except OSError as e:
            print(e)

    print("Copying test subset data to {}".format(testset_copy_target_path))
    for each in tqdm(metadata):
        # shutil.copy replaces the former ``os.system("cp ...")`` call, which
        # broke on paths with spaces or shell metacharacters and only worked
        # where a ``cp`` binary exists. A failed copy is still non-fatal,
        # as it was when the shell exit status was ignored.
        try:
            shutil.copy(each["wav"], testset_copy_target_path)
        except OSError as e:
            print(e)
74
+
75
+
76
def listdir_nohidden(path):
    """Lazily yield the entries of ``path`` whose name does not start with ".".

    Args:
        path: directory to list.

    Yields:
        str: entry names, in the order returned by ``os.listdir``.
    """
    yield from (entry for entry in os.listdir(path) if not entry.startswith("."))
80
+
81
+
82
def get_restore_step(path):
    """Pick the checkpoint file in ``path`` that training should resume from.

    Selection order:
    1. ``final.ckpt`` if present.
    2. If there is no ``last.ckpt``: the file with the largest ``step=<n>``
       value in its name.
    3. Otherwise the ``last-v<k>.ckpt`` with the highest ``k``, falling back
       to ``last.ckpt`` when no versioned copy exists.

    Args:
        path: directory containing the checkpoint files.

    Returns:
        tuple: ``(file_name, step)``; ``step`` is the parsed step number in
        case 2 and ``0`` otherwise.

    Raises:
        ValueError: in case 2 when no file name contains ``step=`` or a step
            number cannot be parsed.
    """
    checkpoints = os.listdir(path)
    if os.path.exists(os.path.join(path, "final.ckpt")):
        return "final.ckpt", 0

    if not os.path.exists(os.path.join(path, "last.ckpt")):
        # Unrelated files (logs, configs, ...) are skipped rather than
        # crashing the step parsing.
        step_ckpts = [x for x in checkpoints if "step=" in x]
        if not step_ckpts:
            raise ValueError("No 'step=' checkpoint found in {}".format(path))
        steps = [int(x.split(".ckpt")[0].split("step=")[1]) for x in step_ckpts]
        return step_ckpts[np.argmax(steps)], np.max(steps)

    fname = "last.ckpt"
    best_version = None
    for x in checkpoints:
        if "last" in x and "-v" in x:
            this_version = int(x.split(".ckpt")[0].split("-v")[1])
            # Compare against the best version seen *before* recording this
            # one; comparing after recording made the test always false, so
            # versioned checkpoints were never selected.
            if best_version is None or this_version > best_version:
                best_version = this_version
                fname = "last-v%s.ckpt" % this_version
    return fname, 0
101
+
102
+
103
def download(url, local_path, chunk_size=1024):
    """Stream ``url`` into ``local_path`` while showing a progress bar.

    Args:
        url: address fetched with ``requests.get(..., stream=True)``.
        local_path: destination file; its parent directory is created when
            the path has one.
        chunk_size: number of bytes requested per streamed chunk.

    Raises:
        requests.HTTPError: if the server answers with an error status, so
            that an error page is never written out as the payload.
    """
    target_dir = os.path.split(local_path)[0]
    # A bare file name has no directory part, and os.makedirs("") would fail.
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)
    with requests.get(url, stream=True) as r:
        r.raise_for_status()
        total_size = int(r.headers.get("content-length", 0))
        with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
            with open(local_path, "wb") as f:
                for data in r.iter_content(chunk_size=chunk_size):
                    if data:
                        f.write(data)
                        # Advance by what was actually received; the final
                        # chunk is usually shorter than chunk_size.
                        pbar.update(len(data))
113
+
114
+
115
def md5_hash(path, chunk_size=1 << 20):
    """Return the MD5 hex digest of the file at ``path``.

    The file is hashed in blocks so that large checkpoints do not have to fit
    in memory at once.

    Args:
        path: file to hash (opened in binary mode).
        chunk_size: number of bytes read per block.

    Returns:
        str: hexadecimal MD5 digest of the file content.
    """
    digest = hashlib.md5()
    with open(path, "rb") as f:
        # read() returns b"" at end of file, which stops the iteration.
        for block in iter(lambda: f.read(chunk_size), b""):
            digest.update(block)
    return digest.hexdigest()
119
+
120
+
121
def get_ckpt_path(name, root, check=False):
    """Return the local path of artefact ``name``, downloading it if needed.

    Args:
        name: key of ``URL_MAP`` / ``CKPT_MAP`` / ``MD5_MAP``.
        root: directory that holds (or will hold) the file.
        check: when true, an existing file is re-downloaded if its MD5 digest
            differs from the expected one.

    Returns:
        str: ``root`` joined with the artefact's file name.

    Raises:
        AssertionError: if ``name`` is unknown, or if a freshly downloaded
            file does not have the expected MD5 digest.
    """
    assert name in URL_MAP
    path = os.path.join(root, CKPT_MAP[name])

    if os.path.exists(path):
        # The existing file is only hashed when verification was requested.
        needs_download = check and md5_hash(path) != MD5_MAP[name]
    else:
        needs_download = True

    if needs_download:
        print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
        download(URL_MAP[name], path)
        md5 = md5_hash(path)
        assert md5 == MD5_MAP[name], md5
    return path
130
+
131
+
132
class KeyNotFoundError(Exception):
    """Raised when a key path cannot be resolved in a nested container.

    Attributes are set from the constructor arguments: ``cause`` (the
    underlying error), ``keys`` (the full key path) and ``visited`` (the
    part of the path that was resolved). The exception message lists the
    keys and visited path when given, followed by the cause.
    """

    def __init__(self, cause, keys=None, visited=None):
        self.cause = cause
        self.keys = keys
        self.visited = visited

        # Optional parts appear only when a value was supplied.
        optional_parts = (
            ("Key not found: {}", keys),
            ("Visited: {}", visited),
        )
        lines = [
            template.format(value)
            for template, value in optional_parts
            if value is not None
        ]
        lines.append("Cause:\n{}".format(cause))
        super().__init__("\n".join(lines))
145
+
146
+
147
def retrieve(
    list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
):
    """Look up a value in a nested list/dict structure by a path-like key.

    Callable nodes met on the way are called and replaced, in place, by their
    result when :attr:`expand` is ``True``.

    Parameters
    ----------
        list_or_dict : list or dict
            Possibly nested list or dictionary.
        key : str
            key/to/value, path like string describing all keys necessary to
            consider to get to the desired value. For non-dict containers
            each path element is converted with ``int`` and used as index.
        splitval : str
            String that defines the delimiter between keys of the
            different depth levels in `key`.
        default : obj
            Value returned if :attr:`key` is not found. ``None`` means
            "no default": the lookup error is raised instead.
        expand : bool
            Whether to expand callable nodes on the path or not.
        pass_success : bool
            If ``True``, return a ``(value, success)`` tuple where
            ``success`` is ``False`` when ``default`` was used.

    Returns
    -------
        The desired value or if :attr:`default` is not ``None`` and the
        :attr:`key` is not found returns ``default``. With
        :attr:`pass_success` the value is paired with a success flag.

    Raises
    ------
        KeyNotFoundError if ``key`` not in ``list_or_dict`` and
        :attr:`default` is ``None``.
    """

    path = key.split(splitval)

    found = True
    node = list_or_dict
    try:
        visited = []
        # Container and key under which ``node`` is stored; needed to write
        # an expanded callable back into the structure.
        container = None
        container_key = None
        for part in path:
            if callable(node):
                if not expand:
                    raise KeyNotFoundError(
                        ValueError(
                            "Trying to get past callable node with expand=False."
                        ),
                        keys=path,
                        visited=visited,
                    )
                node = node()
                container[container_key] = node

            container_key = part
            container = node

            try:
                node = node[part] if isinstance(node, dict) else node[int(part)]
            except (KeyError, IndexError, ValueError) as e:
                raise KeyNotFoundError(e, keys=path, visited=visited)

            visited += [part]
        # final expansion of retrieved value
        if expand and callable(node):
            node = node()
            container[container_key] = node
    except KeyNotFoundError as e:
        if default is None:
            raise e
        node = default
        found = False

    if pass_success:
        return node, found
    return node
228
+
229
+
230
def to_device(data, device):
    """Convert the numpy members of a batch tuple to tensors on ``device``.

    Two layouts are handled, told apart by ``len(data)``:

    * 12 entries: ``(ids, raw_texts, speakers, texts, src_lens, max_src_len,
      mels, mel_lens, max_mel_len, pitches, energies, durations)``
    * 6 entries: ``(ids, raw_texts, speakers, texts, src_lens, max_src_len)``

    ``ids``, ``raw_texts`` and the ``max_*`` entries are passed through
    unchanged. Any other length falls through and returns ``None``.
    """

    def _as_is(array):
        # Keep the dtype numpy already gave the array.
        return torch.from_numpy(array).to(device)

    def _as_long(array):
        return torch.from_numpy(array).long().to(device)

    def _as_float(array):
        return torch.from_numpy(array).float().to(device)

    n_fields = len(data)

    if n_fields == 12:
        (
            ids,
            raw_texts,
            speakers,
            texts,
            src_lens,
            max_src_len,
            mels,
            mel_lens,
            max_mel_len,
            pitches,
            energies,
            durations,
        ) = data
        return (
            ids,
            raw_texts,
            _as_long(speakers),
            _as_long(texts),
            _as_is(src_lens),
            max_src_len,
            _as_float(mels),
            _as_is(mel_lens),
            max_mel_len,
            _as_float(pitches),
            _as_is(energies),
            _as_long(durations),
        )

    if n_fields == 6:
        ids, raw_texts, speakers, texts, src_lens, max_src_len = data
        return (
            ids,
            raw_texts,
            _as_long(speakers),
            _as_long(texts),
            _as_is(src_lens),
            max_src_len,
        )
279
+
280
+
281
def log(logger, step=None, fig=None, audio=None, sampling_rate=22050, tag=""):
    """Send a figure and/or an audio clip to ``logger`` under ``tag``.

    Parameters
    ----------
    logger
        Object exposing ``add_figure(tag, fig)`` and
        ``add_audio(tag, audio, sample_rate=...)``.
    step
        Accepted for interface compatibility; not forwarded to the logger.
    fig
        Figure to log, skipped when ``None``.
    audio
        Waveform to log, skipped when ``None``. It is divided by 1.1 times
        its peak absolute value before being logged.
    sampling_rate : int
        Sample rate handed to ``add_audio``.
    tag : str
        Tag used for both the figure and the audio entry.
    """
    if fig is not None:
        logger.add_figure(tag, fig)

    if audio is None:
        return

    # Scale so the loudest sample sits slightly below full scale.
    peak = max(abs(audio))
    logger.add_audio(
        tag,
        audio / (peak * 1.1),
        sample_rate=sampling_rate,
    )
307
+
308
+
309
def get_mask_from_lengths(lengths, max_len=None):
    """Build a boolean padding mask from a 1-D tensor of sequence lengths.

    Parameters
    ----------
    lengths : torch.Tensor
        1-D tensor holding the valid length of every sequence in the batch.
    max_len : int, optional
        Width of the mask. Defaults to the largest value in ``lengths``.

    Returns
    -------
    torch.Tensor
        Boolean tensor of shape ``(len(lengths), max_len)`` located on the
        same device as ``lengths``; ``True`` marks positions at or beyond the
        sequence length (i.e. padding).
    """
    if max_len is None:
        max_len = torch.max(lengths).item()

    # Create the position ids on the device of `lengths` so the comparison
    # never mixes devices and no module-level `device` name is required.
    ids = torch.arange(0, max_len, device=lengths.device)
    mask = ids.unsqueeze(0) >= lengths.unsqueeze(1)

    return mask
318
+
319
+
320
def expand(values, durations):
    """Repeat each value according to its duration and return a numpy array.

    ``values[i]`` is emitted ``int(durations[i])`` times; durations that are
    zero or negative contribute nothing.
    """
    repeated = [
        value
        for value, d in zip(values, durations)
        for _ in range(max(0, int(d)))
    ]
    return np.array(repeated)
325
+
326
+
327
def synth_one_sample_val(
    targets, predictions, vocoder, model_config, preprocess_config
):
    """Plot and optionally vocode one randomly chosen item of a batch.

    ``targets`` and ``predictions`` are indexed positionally; the positions
    used here are ``targets[0]`` (basenames), ``targets[6]`` (target mels),
    ``targets[11]`` (durations), ``predictions[0]`` / ``predictions[1]``
    (mel predictions before / after the postnet), ``predictions[2]`` (pitch),
    ``predictions[3]`` (energy), ``predictions[8]`` (source lengths) and
    ``predictions[9]`` (mel lengths).

    Returns
    -------
    tuple
        ``(fig, wav_reconstruction, wav_prediction, basename)``; both wav
        entries are ``None`` when ``vocoder`` is ``None``.

    NOTE(review): ``plot_mel`` as defined in this file takes
    ``(data, titles=None)`` and plots each ``data[i]`` directly, while this
    function passes three positional arguments and ``(mel, pitch, energy)``
    tuples -- confirm which ``plot_mel`` this call is meant for.
    NOTE(review): ``vocoder_infer`` receives four positional arguments here
    but only two in ``synth_one_sample`` -- confirm its signature.
    """
    # Draw a random batch index; the batch size is read from targets[6].
    index = np.random.choice(list(np.arange(targets[6].size(0))))

    basename = targets[0][index]
    src_len = predictions[8][index].item()
    mel_len = predictions[9][index].item()
    # Crop to mel_len along the first axis, then swap the two axes.
    mel_target = targets[6][index, :mel_len].detach().transpose(0, 1)

    mel_prediction = predictions[0][index, :mel_len].detach().transpose(0, 1)
    postnet_mel_prediction = predictions[1][index, :mel_len].detach().transpose(0, 1)
    duration = targets[11][index, :src_len].detach().cpu().numpy()

    # Phoneme-level features have src_len entries and are stretched with
    # expand() using the durations; otherwise they are cropped to mel_len.
    if preprocess_config["preprocessing"]["pitch"]["feature"] == "phoneme_level":
        pitch = predictions[2][index, :src_len].detach().cpu().numpy()
        pitch = expand(pitch, duration)
    else:
        pitch = predictions[2][index, :mel_len].detach().cpu().numpy()

    if preprocess_config["preprocessing"]["energy"]["feature"] == "phoneme_level":
        energy = predictions[3][index, :src_len].detach().cpu().numpy()
        energy = expand(energy, duration)
    else:
        energy = predictions[3][index, :mel_len].detach().cpu().numpy()

    # Concatenate the pitch entry with the first two energy entries of
    # stats.json; what they represent is not visible here.
    with open(
        os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json")
    ) as f:
        stats = json.load(f)
        stats = stats["pitch"] + stats["energy"][:2]

    # from datetime import datetime
    # now = datetime.now()
    # current_time = now.strftime("%D:%H:%M:%S")
    # np.save(("mel_pred_%s.npy" % current_time).replace("/","-"), mel_prediction.cpu().numpy())
    # np.save(("postnet_mel_prediction_%s.npy" % current_time).replace("/","-"), postnet_mel_prediction.cpu().numpy())
    # np.save(("mel_target_%s.npy" % current_time).replace("/","-"), mel_target.cpu().numpy())

    # NOTE(review): see the docstring -- this call shape does not match the
    # plot_mel(data, titles=None) defined in this file.
    fig = plot_mel(
        [
            (mel_prediction.cpu().numpy(), pitch, energy),
            (postnet_mel_prediction.cpu().numpy(), pitch, energy),
            (mel_target.cpu().numpy(), pitch, energy),
        ],
        stats,
        [
            "Raw mel spectrogram prediction",
            "Postnet mel prediction",
            "Ground-Truth Spectrogram",
        ],
    )

    if vocoder is not None:
        # Imported lazily, only when a vocoder is actually supplied.
        from .model import vocoder_infer

        # Vocode the ground-truth mel and the postnet prediction; unsqueeze(0)
        # adds a leading batch axis and [0] takes the first returned item.
        wav_reconstruction = vocoder_infer(
            mel_target.unsqueeze(0),
            vocoder,
            model_config,
            preprocess_config,
        )[0]
        wav_prediction = vocoder_infer(
            postnet_mel_prediction.unsqueeze(0),
            vocoder,
            model_config,
            preprocess_config,
        )[0]
    else:
        wav_reconstruction = wav_prediction = None

    return fig, wav_reconstruction, wav_prediction, basename
399
+
400
+
401
def synth_one_sample(mel_input, mel_prediction, labels, vocoder):
    """Vocode an input mel batch and a predicted mel batch.

    Both tensors have their last two axes swapped via ``permute(0, 2, 1)``
    before being handed to ``vocoder_infer``. ``labels`` is not used.

    Returns
    -------
    tuple
        ``(wav_reconstruction, wav_prediction)``, or ``(None, None)`` when
        ``vocoder`` is ``None``.
    """
    if vocoder is None:
        return None, None

    # Imported lazily, only when a vocoder is actually supplied.
    from .model import vocoder_infer

    reconstructed = vocoder_infer(mel_input.permute(0, 2, 1), vocoder)
    predicted = vocoder_infer(mel_prediction.permute(0, 2, 1), vocoder)
    return reconstructed, predicted
417
+
418
+
419
def synth_samples(targets, predictions, vocoder, model_config, preprocess_config, path):
    """Save a spectrogram image and a wav file for every item of a batch.

    For each item a figure is written to ``<path>/<basename>_postnet_2.png``;
    afterwards the whole batch of postnet mels is vocoded and each waveform is
    written to ``<path>/<basename>.wav``. Basenames come from ``targets[0]``;
    ``predictions`` is indexed positionally (``[1]`` postnet mels, ``[2]``
    pitch, ``[3]`` energy, ``[8]`` source lengths, ``[9]`` mel lengths).

    Returns ``None``; the function only produces files.

    NOTE(review): ``plot_mel`` as defined in this file takes
    ``(data, titles=None)`` and plots each ``data[i]`` directly, while this
    function passes three positional arguments and ``(mel, pitch, energy)``
    tuples -- confirm which ``plot_mel`` this call is meant for.
    NOTE(review): ``vocoder_infer`` is called with ``model_config``,
    ``preprocess_config`` and ``lengths=`` here but with only two arguments
    in ``synth_one_sample`` -- confirm its signature.
    """
    # (diff_output, diff_loss, latent_loss) = diffusion

    basenames = targets[0]

    for i in range(len(predictions[1])):
        basename = basenames[i]
        src_len = predictions[8][i].item()
        mel_len = predictions[9][i].item()
        # Crop to mel_len along the first axis, then swap the two axes.
        mel_prediction = predictions[1][i, :mel_len].detach().transpose(0, 1)
        # diff_output = diff_output[i, :mel_len].detach().transpose(0, 1)
        # duration = predictions[5][i, :src_len].detach().cpu().numpy()
        # Unlike synth_one_sample_val, phoneme-level features are NOT
        # stretched with expand() here (those calls are commented out).
        if preprocess_config["preprocessing"]["pitch"]["feature"] == "phoneme_level":
            pitch = predictions[2][i, :src_len].detach().cpu().numpy()
            # pitch = expand(pitch, duration)
        else:
            pitch = predictions[2][i, :mel_len].detach().cpu().numpy()
        if preprocess_config["preprocessing"]["energy"]["feature"] == "phoneme_level":
            energy = predictions[3][i, :src_len].detach().cpu().numpy()
            # energy = expand(energy, duration)
        else:
            energy = predictions[3][i, :mel_len].detach().cpu().numpy()
        # import ipdb; ipdb.set_trace()
        # stats.json is re-read on every iteration of the loop.
        with open(
            os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json")
        ) as f:
            stats = json.load(f)
            stats = stats["pitch"] + stats["energy"][:2]

        # NOTE(review): see the docstring -- this call shape does not match
        # the plot_mel(data, titles=None) defined in this file.
        fig = plot_mel(
            [
                (mel_prediction.cpu().numpy(), pitch, energy),
            ],
            stats,
            ["Synthetized Spectrogram by PostNet"],
        )
        # np.save("{}_postnet.npy".format(basename), mel_prediction.cpu().numpy())
        # savefig/close act on pyplot's current figure; `fig` itself is unused.
        plt.savefig(os.path.join(path, "{}_postnet_2.png".format(basename)))
        plt.close()

    from .model import vocoder_infer

    # Vocode the full (uncropped) batch in one call; the per-item lengths are
    # the mel lengths multiplied by the STFT hop length.
    mel_predictions = predictions[1].transpose(1, 2)
    lengths = predictions[9] * preprocess_config["preprocessing"]["stft"]["hop_length"]
    wav_predictions = vocoder_infer(
        mel_predictions, vocoder, model_config, preprocess_config, lengths=lengths
    )

    sampling_rate = preprocess_config["preprocessing"]["audio"]["sampling_rate"]
    for wav, basename in zip(wav_predictions, basenames):
        wavfile.write(os.path.join(path, "{}.wav".format(basename)), sampling_rate, wav)
470
+
471
+
472
def plot_mel(data, titles=None):
    """Draw each entry of ``data`` as an image in its own subplot row.

    Parameters
    ----------
    data : sequence
        2-D arrays to display, one per row of the figure.
    titles : sequence, optional
        One title per entry of ``data``; no titles are shown when omitted.

    Returns
    -------
    The matplotlib figure holding all subplots.
    """
    n_rows = len(data)
    fig, axes = plt.subplots(n_rows, 1, squeeze=False)
    if titles is None:
        titles = [None] * n_rows

    for row, mel in enumerate(data):
        # squeeze=False keeps `axes` two-dimensional, hence the [row][0].
        ax = axes[row][0]
        ax.imshow(mel, origin="lower", aspect="auto")
        ax.set_aspect(2.5, adjustable="box")
        ax.set_ylim(0, mel.shape[0])
        ax.set_title(titles[row], fontsize="medium")
        ax.tick_params(labelsize="x-small", left=False, labelleft=False)
        ax.set_anchor("W")

    return fig
487
+
488
+
489
def pad_1D(inputs, PAD=0):
    """Right-pad 1-D arrays to a common length and stack them.

    Parameters
    ----------
    inputs : sequence of numpy arrays
        1-D arrays of possibly different lengths.
    PAD : scalar
        Fill value used for the padding.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(len(inputs), longest_length)``.
    """
    target_len = max(len(x) for x in inputs)

    rows = []
    for x in inputs:
        missing = target_len - x.shape[0]
        rows.append(np.pad(x, (0, missing), mode="constant", constant_values=PAD))

    return np.stack(rows)
500
+
501
+
502
def pad_2D(inputs, maxlen=None):
    """Zero-pad 2-D arrays along their first axis and stack them.

    Parameters
    ----------
    inputs : sequence of numpy arrays
        2-D arrays sharing the size of their second axis.
    maxlen : int, optional
        Target length of the first axis. When falsy, the longest first axis
        among ``inputs`` is used.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(len(inputs), target_length, second_axis_size)``.

    Raises
    ------
    ValueError
        If an input is longer than the target length.
    """
    target_len = maxlen if maxlen else max(np.shape(x)[0] for x in inputs)

    padded_items = []
    for x in inputs:
        n_rows = np.shape(x)[0]
        if n_rows > target_len:
            raise ValueError("not max_len")

        n_cols = np.shape(x)[1]
        # np.pad applies the (0, k) width to every axis, so the second axis
        # is cut back to its original size afterwards.
        grown = np.pad(
            x, (0, target_len - n_rows), mode="constant", constant_values=0
        )
        padded_items.append(grown[:, :n_cols])

    return np.stack(padded_items)
521
+
522
+
523
def pad(input_ele, mel_max_length=None):
    """Zero-pad a list of tensors along their first axis and stack them.

    Parameters
    ----------
    input_ele : sequence of torch.Tensor
        1-D or 2-D tensors whose first axis may differ in size.
    mel_max_length : int, optional
        Target size of the first axis. When falsy, the largest first-axis
        size among ``input_ele`` is used.

    Returns
    -------
    torch.Tensor
        The padded tensors stacked along a new leading axis.

    Raises
    ------
    ValueError
        If an element is neither 1-D nor 2-D. Previously such an element
        silently reused the padded tensor of the preceding element.
    """
    if mel_max_length:
        max_len = mel_max_length
    else:
        max_len = max(item.size(0) for item in input_ele)

    out_list = []
    for item in input_ele:
        missing = max_len - item.size(0)
        rank = len(item.shape)
        if rank == 1:
            pad_spec = (0, missing)
        elif rank == 2:
            # F.pad lists the last axis first: leave it alone and extend only
            # the end of the first axis.
            pad_spec = (0, 0, 0, missing)
        else:
            raise ValueError(
                "pad() supports 1-D and 2-D tensors, got {} dimensions".format(rank)
            )
        out_list.append(F.pad(item, pad_spec, "constant", 0.0))
    return torch.stack(out_list)