Upload 3 files
Browse files- utilities/model.py +167 -0
- utilities/sampler.py +588 -0
- utilities/tools.py +541 -0
utilities/model.py
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
import audiosr.hifigan as hifigan
|
4 |
+
|
5 |
+
|
6 |
+
def get_vocoder_config():
    """Return the HiFi-GAN hyper-parameters for the 16 kHz, 64-mel vocoder.

    A new dictionary is created on every call, so the caller may modify the
    result without affecting later calls.
    """
    distributed = dict(
        dist_backend="nccl",
        dist_url="tcp://localhost:54321",
        world_size=1,
    )
    return dict(
        resblock="1",
        num_gpus=6,
        batch_size=16,
        learning_rate=0.0002,
        adam_b1=0.8,
        adam_b2=0.99,
        lr_decay=0.999,
        seed=1234,
        upsample_rates=[5, 4, 2, 2, 2],
        upsample_kernel_sizes=[16, 16, 8, 4, 4],
        upsample_initial_channel=1024,
        resblock_kernel_sizes=[3, 7, 11],
        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        segment_size=8192,
        num_mels=64,
        num_freq=1025,
        n_fft=1024,
        hop_size=160,
        win_size=1024,
        sampling_rate=16000,
        fmin=0,
        fmax=8000,
        fmax_for_loss=None,
        num_workers=4,
        dist_config=distributed,
    )
|
38 |
+
|
39 |
+
|
40 |
+
def get_vocoder_config_48k():
    """Return the HiFi-GAN hyper-parameters for the 48 kHz, 256-mel vocoder.

    A new dictionary is created on every call, so the caller may modify the
    result without affecting later calls.
    """
    distributed = dict(
        dist_backend="nccl",
        dist_url="tcp://localhost:18273",
        world_size=1,
    )
    return dict(
        resblock="1",
        num_gpus=8,
        batch_size=128,
        learning_rate=0.0001,
        adam_b1=0.8,
        adam_b2=0.99,
        lr_decay=0.999,
        seed=1234,
        upsample_rates=[6, 5, 4, 2, 2],
        upsample_kernel_sizes=[12, 10, 8, 4, 4],
        upsample_initial_channel=1536,
        resblock_kernel_sizes=[3, 7, 11, 15],
        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5], [1, 3, 5]],
        segment_size=15360,
        num_mels=256,
        n_fft=2048,
        hop_size=480,
        win_size=2048,
        sampling_rate=48000,
        fmin=20,
        fmax=24000,
        fmax_for_loss=None,
        num_workers=8,
        dist_config=distributed,
    )
|
71 |
+
|
72 |
+
|
73 |
+
def get_available_checkpoint_keys(model, ckpt):
    """Collect the checkpoint tensors that can be loaded into ``model``.

    Args:
        model: object exposing ``state_dict()``.
        ckpt: checkpoint location accepted by ``torch.load``; the loaded
            object must have a ``"state_dict"`` entry.

    Returns:
        dict mapping names to checkpoint tensors, limited to entries whose
        name exists in the model's state dict with an identical size.
        Every skipped key and a final match summary are printed.
    """
    # Deserialize onto the CPU so that a checkpoint written on a CUDA device
    # can still be inspected on a host without a GPU.
    checkpoint_state = torch.load(ckpt, map_location="cpu")["state_dict"]
    model_state = model.state_dict()

    matched = {}
    for key, tensor in checkpoint_state.items():
        if key in model_state and model_state[key].size() == tensor.size():
            matched[key] = tensor
        else:
            print("==> WARNING: Skipping %s" % key)

    print("%s out of %s keys are matched" % (len(matched), len(checkpoint_state)))
    return matched
|
90 |
+
|
91 |
+
|
92 |
+
def get_param_num(model):
    """Return the total element count over everything in ``model.parameters()``."""
    total = 0
    for tensor in model.parameters():
        total += tensor.numel()
    return total
|
95 |
+
|
96 |
+
|
97 |
+
def torch_version_orig_mod_remove(state_dict):
    """Strip every ``"_orig_mod."`` marker from the generator's key names.

    Args:
        state_dict: mapping with a ``"generator"`` entry that is itself a
            mapping from key names to values.

    Returns:
        A new dict containing only ``"generator"``, whose keys have all
        occurrences of ``"_orig_mod."`` removed; values are reused as-is.
        Other top-level entries of ``state_dict`` are not copied.
    """
    generator_weights = state_dict["generator"]
    # str.replace leaves names without the marker untouched, so no
    # membership test is needed.
    cleaned = {
        name.replace("_orig_mod.", ""): weight
        for name, weight in generator_weights.items()
    }
    return {"generator": cleaned}
|
108 |
+
|
109 |
+
|
110 |
+
def get_vocoder(config, device, mel_bins):
    """Build a HiFi-GAN generator, switch it to eval mode and move it to ``device``.

    Args:
        config: ignored; kept only for interface compatibility. The
            configuration is always taken from ``get_vocoder_config()`` or
            ``get_vocoder_config_48k()``.
        device: target passed to ``vocoder.to``.
        mel_bins: ``64`` selects ``get_vocoder_config()``; any other value
            selects ``get_vocoder_config_48k()``.

    Returns:
        The ``hifigan.Generator_old`` instance after ``eval()``,
        ``remove_weight_norm()`` and ``to(device)``.
    """
    # NOTE: no checkpoint is loaded in this function (the loading code was
    # disabled), so the generator is returned with the weights it was
    # constructed with.
    settings = get_vocoder_config() if mel_bins == 64 else get_vocoder_config_48k()
    vocoder = hifigan.Generator_old(hifigan.AttrDict(settings))
    vocoder.eval()
    vocoder.remove_weight_norm()
    vocoder.to(device)
    return vocoder
|
150 |
+
|
151 |
+
|
152 |
+
def vocoder_infer(mels, vocoder, lengths=None):
    """Run ``vocoder`` on ``mels`` and return the result as int16 samples.

    Args:
        mels: input passed unchanged to ``vocoder``.
        vocoder: callable returning a tensor; dimension 1 of its output is
            squeezed away.
        lengths: optional bound; when given, the second axis of the result
            is truncated to ``[:lengths]``.

    Returns:
        ``numpy.ndarray`` of dtype int16 holding the output scaled by 32768.
    """
    with torch.no_grad():
        wavs = vocoder(mels).squeeze(1)

    # Clip to the int16 range before casting: a sample of exactly 1.0 scales
    # to 32768, which does not fit in int16 and would otherwise overflow.
    scaled = (wavs.cpu().numpy() * 32768).clip(-32768, 32767)
    wavs = scaled.astype("int16")

    if lengths is not None:
        wavs = wavs[:, :lengths]

    return wavs
|
utilities/sampler.py
ADDED
@@ -0,0 +1,588 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Iterator, List, Optional, Union
|
2 |
+
from collections import Counter
|
3 |
+
import logging
|
4 |
+
from operator import itemgetter
|
5 |
+
import random
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
from torch.utils.data import DistributedSampler
|
10 |
+
from torch.utils.data.sampler import Sampler
|
11 |
+
|
12 |
+
LOGGER = logging.getLogger(__name__)
|
13 |
+
|
14 |
+
from torch.utils.data import Dataset, Sampler
|
15 |
+
|
16 |
+
|
17 |
+
class DatasetFromSampler(Dataset):
    """Expose the indices produced by a `Sampler` as an indexable dataset.

    Args:
        sampler: PyTorch sampler
    """

    def __init__(self, sampler: Sampler):
        """Keep the sampler; its indices are drawn lazily on first lookup."""
        self.sampler_list = None
        self.sampler = sampler

    def __getitem__(self, index: int):
        """Return the sampler output stored at position ``index``.

        The sampler is iterated once, on the first lookup, and the resulting
        list is cached for all later lookups.

        Args:
            index: position within the cached list of sampler outputs

        Returns:
            Single element by index
        """
        cached = self.sampler_list
        if cached is None:
            cached = self.sampler_list = list(self.sampler)
        return cached[index]

    def __len__(self) -> int:
        """
        Returns:
            int: ``len()`` of the wrapped sampler
        """
        return len(self.sampler)
|
45 |
+
|
46 |
+
|
47 |
+
class BalanceClassSampler(Sampler):
    """Sampler that draws the same number of items from every class.

    Args:
        labels: list of class label for each elem in the dataset
        mode: Strategy to balance classes.
            ``"downsampling"`` draws as many items per class as the smallest
            class holds, ``"upsampling"`` as many as the largest class holds,
            and an ``int`` draws exactly that many items per class.

    Python API examples:

    .. code-block:: python

        from torch.utils.data import DataLoader

        train_labels = train_data.targets.cpu().numpy().tolist()
        train_sampler = BalanceClassSampler(train_labels, mode=5000)
        train_loader = DataLoader(train_data, sampler=train_sampler, batch_size=32)
    """

    def __init__(self, labels: List[int], mode: Union[str, int] = "downsampling"):
        """Sampler initialisation."""
        super().__init__(labels)

        labels = np.array(labels)
        unique_labels = set(labels)
        class_sizes = {label: (labels == label).sum() for label in unique_labels}

        positions = np.arange(len(labels))
        self.lbl2idx = {
            label: positions[labels == label].tolist() for label in unique_labels
        }

        if isinstance(mode, str):
            assert mode in ["downsampling", "upsampling"]

        if isinstance(mode, int):
            per_class = mode
        elif mode == "upsampling":
            per_class = max(class_sizes.values())
        else:
            per_class = min(class_sizes.values())

        self.labels = labels
        self.samples_per_class = per_class
        self.length = self.samples_per_class * len(unique_labels)

    def __iter__(self) -> Iterator[int]:
        """
        Returns:
            iterator of indices of stratified sample
        """
        drawn = []
        for label in sorted(self.lbl2idx):
            pool = self.lbl2idx[label]
            # Sample with replacement only when the class is too small.
            with_replacement = self.samples_per_class > len(pool)
            picks = np.random.choice(
                pool, self.samples_per_class, replace=with_replacement
            )
            drawn.extend(picks.tolist())
        assert len(drawn) == self.length
        np.random.shuffle(drawn)

        return iter(drawn)

    def __len__(self) -> int:
        """
        Returns:
            length of result sample
        """
        return self.length
|
144 |
+
|
145 |
+
|
146 |
+
class BatchBalanceClassSampler(Sampler):
    """
    This kind of sampler can be used for both metric learning and classification task.

    BatchSampler with the given strategy for the C unique classes dataset:
    - Selection `num_classes` of C classes for each batch
    - Selection `num_samples` instances for each class in the batch
    The epoch ends after `num_batches`.
    So, the batch size is `num_classes` * `num_samples`.

    One of the purposes of this sampler is to be used for
    forming triplets and pos/neg pairs inside the batch.
    To guarantee existence of these pairs in the batch,
    `num_classes` and `num_samples` should be > 1. (1)

    This type of sampling can be found in the classical paper of Person Re-Id,
    where P (`num_classes`) equals 32 and K (`num_samples`) equals 4:
    `In Defense of the Triplet Loss for Person Re-Identification`_.

    Args:
        labels: list of class labels for each elem in the dataset
        num_classes: number of classes in a batch, should be > 1
        num_samples: number of instances of each class in a batch, should be > 1
        num_batches: number of batches in epoch
            (default = len(labels) // (num_classes * num_samples))

    .. _In Defense of the Triplet Loss for Person Re-Identification:
        https://arxiv.org/abs/1703.07737

    Python API examples:

    .. code-block:: python

        from torch.utils.data import DataLoader

        train_labels = train_data.targets.cpu().numpy().tolist()
        train_sampler = BatchBalanceClassSampler(
            train_labels, num_classes=10, num_samples=4)
        train_loader = DataLoader(train_data, batch_sampler=train_sampler)
    """

    def __init__(
        self,
        labels: Union[List[int], np.ndarray],
        num_classes: int,
        num_samples: int,
        num_batches: Optional[int] = None,
    ):
        """Sampler initialisation."""
        super().__init__(labels)
        classes = set(labels)

        assert isinstance(num_classes, int) and isinstance(num_samples, int)
        assert (1 < num_classes <= len(classes)) and (1 < num_samples)
        assert all(
            n > 1 for n in Counter(labels).values()
        ), "Each class shoud contain at least 2 instances to fit (1)"

        labels = np.array(labels)
        self._labels = list(set(labels.tolist()))
        self._num_classes = num_classes
        self._num_samples = num_samples
        self._batch_size = self._num_classes * self._num_samples
        self._num_batches = num_batches or len(labels) // self._batch_size
        self.lbl2idx = {
            label: np.arange(len(labels))[labels == label].tolist()
            for label in set(labels)
        }

    @property
    def batch_size(self) -> int:
        """
        Returns:
            this value should be used in DataLoader as batch size
        """
        return self._batch_size

    @property
    def batches_in_epoch(self) -> int:
        """
        Returns:
            number of batches in an epoch
        """
        return self._num_batches

    def __len__(self) -> int:
        """
        Returns:
            number of batches in an epoch (each batch holds
            ``batch_size`` indices)
        """
        return self._num_batches

    def __iter__(self) -> Iterator[List[int]]:
        """
        Returns:
            iterator over batches; every batch is a list of dataset indices
        """
        batches = []
        for _ in range(self._num_batches):
            batch_indices = []
            # ``self._labels`` holds unique labels and ``random.sample`` draws
            # without replacement, so the chosen classes are always distinct.
            classes_for_batch = random.sample(self._labels, self._num_classes)
            for cls_id in classes_for_batch:
                pool = self.lbl2idx[cls_id]
                # Sample with replacement only when the class is too small.
                replace_flag = self._num_samples > len(pool)
                batch_indices += np.random.choice(
                    pool, self._num_samples, replace=replace_flag
                ).tolist()
            batches.append(batch_indices)
        return iter(batches)
|
286 |
+
|
287 |
+
|
288 |
+
class DynamicBalanceClassSampler(Sampler):
    """
    Sampler for classification tasks with significant class imbalance.

    Sampling starts from the original class distribution and moves towards a
    uniform one, as with downsampling.

    Let :math:`D_i = \\#C_i / \\#C_{min}`, where :math:`\\#C_i` is the size of
    class i and :math:`\\#C_{min}` is the size of the rarest class, and let
    :math:`g(n\\_epoch) = exp\\_lambda^{n\\_epoch}`. For each epoch the
    coefficient in use is :math:`D_i^{g(n\\_epoch)}` (capped by ``max_d``),
    and ``int(coefficient * min_class_size)`` items are drawn from class i.

    Notes:
        In the end of the training, epochs will contain only
        min_size_class * n_classes examples. So, possible it will not
        necessary to do validation on each epoch. For this reason use
        ControlFlowCallback.

    Examples:

        >>> import torch
        >>> import numpy as np
        >>> from torch.utils import data

        >>> features = torch.Tensor(np.random.random((200, 100)))
        >>> labels = np.random.randint(0, 4, size=(200,))
        >>> sampler = DynamicBalanceClassSampler(labels)
        >>> labels = torch.LongTensor(labels)
        >>> dataset = data.TensorDataset(features, labels)
        >>> loader = data.DataLoader(dataset, sampler=sampler, batch_size=8)

        >>> for batch in loader:
        >>>     b_features, b_labels = batch

    Sampler was inspired by https://arxiv.org/abs/1901.06783
    """

    def __init__(
        self,
        labels: List[Union[int, str]],
        exp_lambda: float = 0.9,
        start_epoch: int = 0,
        max_d: Optional[int] = None,
        mode: Union[str, int] = "downsampling",
        ignore_warning: bool = False,
    ):
        """
        Args:
            labels: list of labels for each elem in the dataset
            exp_lambda: exponent figure for schedule
            start_epoch: start epoch number, can be useful for multi-stage
                experiments
            max_d: if not None, upper bound applied to every per-class
                coefficient, heuristic
            mode: number of samples per class in the end of training. Must be
                "downsampling" or number. Before change it, make sure that you
                understand how does it work
            ignore_warning: ignore warning about min class size
        """
        assert isinstance(start_epoch, int)
        assert 0 < exp_lambda < 1, "exp_lambda must be in (0, 1)"
        super().__init__(labels)
        self.exp_lambda = exp_lambda
        self.max_d = np.inf if max_d is None else max_d
        self.epoch = start_epoch

        labels = np.array(labels)
        class_counts = Counter(labels)
        self.min_class_size = min(class_counts.values())

        if self.min_class_size < 100 and not ignore_warning:
            final_epoch_size = self.min_class_size * len(class_counts)
            LOGGER.warning(
                f"the smallest class contains only {self.min_class_size}"
                f" examples. At the end of training, epochs will contain only"
                f" {final_epoch_size} examples"
            )

        smallest = self.min_class_size
        self.original_d = {
            label: count / smallest for label, count in class_counts.items()
        }
        positions = np.arange(len(labels))
        self.label2idxes = {
            label: positions[labels == label].tolist() for label in set(labels)
        }

        # An integer mode replaces the base class size used by ``_update``;
        # the ratios in ``original_d`` stay relative to the rarest class.
        if isinstance(mode, int):
            self.min_class_size = mode
        else:
            assert mode == "downsampling"

        self.labels = labels
        self._update()

    def _update(self) -> None:
        """Recompute per-class sample counts and advance the epoch counter."""
        exponent = self._exp_scheduler()
        per_class = {}
        for label, ratio in self.original_d.items():
            damped = min(ratio**exponent, self.max_d)
            per_class[label] = int(damped * self.min_class_size)
        self.samples_per_classes = per_class
        self.length = np.sum(list(per_class.values()))
        self.epoch += 1

    def _exp_scheduler(self) -> float:
        """Return ``exp_lambda`` raised to the current epoch number."""
        return self.exp_lambda**self.epoch

    def __iter__(self) -> Iterator[int]:
        """
        Returns:
            iterator of indices of stratified sample
        """
        drawn = []
        for label in sorted(self.label2idxes):
            pool = self.label2idxes[label]
            wanted = self.samples_per_classes[label]
            # Sample with replacement only when the class is too small.
            picks = np.random.choice(pool, wanted, replace=wanted > len(pool))
            drawn.extend(picks.tolist())
        assert len(drawn) == self.length
        np.random.shuffle(drawn)
        # Prepare the counts (and length) for the next epoch.
        self._update()
        return iter(drawn)

    def __len__(self) -> int:
        """
        Returns:
            length of result sample
        """
        return self.length
|
428 |
+
|
429 |
+
|
430 |
+
class MiniEpochSampler(Sampler):
    """
    Sampler that walks through the dataset in chunks of ``mini_epoch_len``.

    Each ``__iter__`` call yields the next chunk; after the last chunk the
    walk restarts from the beginning.

    Args:
        data_len: Size of the dataset
        mini_epoch_len: Num samples from the dataset used in one
            mini epoch.
        drop_last: If ``True``, the trailing chunk that is shorter than
            ``mini_epoch_len`` is skipped
        shuffle: one of ``"per_mini_epoch"``, ``"per_epoch"``, or ``None``.
            The sampler will shuffle indices
            > "per_mini_epoch" -- every mini epoch (every ``__iter__`` call)
            > "per_epoch" -- every real epoch
            > None -- don't shuffle

    Example:
        >>> MiniEpochSampler(len(dataset), mini_epoch_len=100)
        >>> MiniEpochSampler(len(dataset), mini_epoch_len=100, drop_last=True)
        >>> MiniEpochSampler(len(dataset), mini_epoch_len=100,
        >>>     shuffle="per_epoch")
    """

    def __init__(
        self,
        data_len: int,
        mini_epoch_len: int,
        drop_last: bool = False,
        shuffle: str = None,
    ):
        """Sampler initialisation."""
        super().__init__(None)

        self.data_len = int(data_len)
        self.mini_epoch_len = int(mini_epoch_len)

        self.steps = int(data_len / self.mini_epoch_len)
        self.state_i = 0

        # ``divider`` is the number of ``__iter__`` calls forming one full pass.
        leftover = data_len - self.steps * mini_epoch_len
        if self.steps == 0:
            self.divider = 1
        else:
            keep_tail = leftover > 0 and not drop_last
            self.divider = self.steps + 1 if keep_tail else self.steps

        self._indices = np.arange(self.data_len)
        self.indices = self._indices
        self.end_pointer = max(self.data_len, self.mini_epoch_len)

        if shuffle is not None and shuffle not in ("per_mini_epoch", "per_epoch"):
            raise ValueError(
                f"Shuffle must be one of ['per_mini_epoch', 'per_epoch']. Got {shuffle}"
            )
        self.shuffle_type = shuffle

    def shuffle(self) -> None:
        """Shuffle sampler indices."""
        every_call = self.shuffle_type == "per_mini_epoch"
        pass_start = self.shuffle_type == "per_epoch" and self.state_i == 0
        if not (every_call or pass_start):
            return

        if self.data_len >= self.mini_epoch_len:
            # ``indices`` aliases ``_indices``, so the shuffle is in place.
            self.indices = self._indices
            np.random.shuffle(self.indices)
        else:
            # Dataset smaller than one mini epoch: draw with replacement.
            self.indices = np.random.choice(
                self._indices, self.mini_epoch_len, replace=True
            )

    def __iter__(self) -> Iterator[int]:
        """Iterate over sampler.

        Returns:
            python iterator
        """
        self.state_i %= self.divider
        self.shuffle()

        begin = self.state_i * self.mini_epoch_len
        if self.state_i == self.steps:
            end = self.end_pointer
        else:
            end = begin + self.mini_epoch_len
        chunk = self.indices[begin:end].tolist()

        self.state_i += 1
        return iter(chunk)

    def __len__(self) -> int:
        """
        Returns:
            int: length of the mini-epoch
        """
        return self.mini_epoch_len
|
527 |
+
|
528 |
+
|
529 |
+
class DistributedSamplerWrapper(DistributedSampler):
    """
    Wrapper over `Sampler` for distributed training.
    Allows you to use any sampler in distributed mode.

    It is especially useful in conjunction with
    `torch.nn.parallel.DistributedDataParallel`. In such case, each
    process can pass a DistributedSamplerWrapper instance as a DataLoader
    sampler, and load a subset of subsampled data of the original dataset
    that is exclusive to it.

    .. note::
        Sampler is assumed to be of constant size.
    """

    def __init__(
        self,
        sampler,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
    ):
        """

        Args:
            sampler: Sampler used for subsampling
            num_replicas (int, optional): Number of processes participating in
                distributed training
            rank (int, optional): Rank of the current process
                within ``num_replicas``
            shuffle (bool, optional): If true (default),
                sampler will shuffle the indices
        """
        super().__init__(
            DatasetFromSampler(sampler),
            num_replicas=num_replicas,
            rank=rank,
            shuffle=shuffle,
        )
        self.sampler = sampler

    def __iter__(self) -> Iterator[int]:
        """Iterate over sampler.

        Returns:
            python iterator
        """
        # Wrap the sampler anew so its indices are drawn again on this pass.
        self.dataset = DatasetFromSampler(self.sampler)
        indexes_of_indexes = super().__iter__()
        subsampler_indexes = self.dataset
        # Look the positions up one by one: ``itemgetter(*positions)`` returns
        # a bare element for a single position and raises for none, so it
        # cannot be used when this process receives fewer than two indices.
        return iter([subsampler_indexes[i] for i in indexes_of_indexes])
|
580 |
+
|
581 |
+
|
582 |
+
__all__ = [
|
583 |
+
"BalanceClassSampler",
|
584 |
+
"BatchBalanceClassSampler",
|
585 |
+
"DistributedSamplerWrapper",
|
586 |
+
"DynamicBalanceClassSampler",
|
587 |
+
"MiniEpochSampler",
|
588 |
+
]
|
utilities/tools.py
ADDED
@@ -0,0 +1,541 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Author: Haohe Liu
|
2 |
+
# Email: [email protected]
|
3 |
+
# Date: 11 Feb 2023
|
4 |
+
|
5 |
+
import os
|
6 |
+
import json
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn.functional as F
|
10 |
+
import numpy as np
|
11 |
+
import matplotlib
|
12 |
+
from scipy.io import wavfile
|
13 |
+
from matplotlib import pyplot as plt
|
14 |
+
|
15 |
+
|
16 |
+
matplotlib.use("Agg")
|
17 |
+
|
18 |
+
import hashlib
|
19 |
+
import os
|
20 |
+
|
21 |
+
import requests
|
22 |
+
from tqdm import tqdm
|
23 |
+
|
24 |
+
# Download location of each pretrained asset, keyed by asset name.
URL_MAP = {
    "vggishish_lpaps": "https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/vggishish16.pt",
    "vggishish_mean_std_melspec_10s_22050hz": "https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/train_means_stds_melspec_10s_22050hz.txt",
    "melception": "https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/melception-21-05-10T09-28-40.pt",
}

# Local file name of each asset: the last path component of its URL.
CKPT_MAP = {name: url.rsplit("/", 1)[-1] for name, url in URL_MAP.items()}

# Expected MD5 hex digest of each downloaded asset.
MD5_MAP = {
    "vggishish_lpaps": "197040c524a07ccacf7715d7080a80bd",
    "vggishish_mean_std_melspec_10s_22050hz": "f449c6fd0e248936c16f6d22492bb625",
    "melception": "a71a41041e945b457c7d3d814bbcf72d",
}

# Default torch device, chosen once at import time: CUDA when available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
43 |
+
|
44 |
+
|
45 |
+
def load_json(fname):
    """Parse the JSON file at ``fname`` and return the decoded object."""
    with open(fname, "r") as handle:
        return json.load(handle)
|
49 |
+
|
50 |
+
|
51 |
+
def read_json(dataset_json_file):
    """Return the value stored under the top-level ``"data"`` key of a JSON file.

    Raises ``KeyError`` when the decoded object has no ``"data"`` entry.
    """
    with open(dataset_json_file, "r") as handle:
        payload = json.load(handle)
    return payload["data"]
|
55 |
+
|
56 |
+
|
57 |
+
def copy_test_subset_data(metadata, testset_copy_target_path):
    """Mirror the wav files listed in ``metadata`` into a target folder.

    Parameters
    ----------
    metadata : sequence of dict
        Each entry must provide the source file path under the ``"wav"`` key.
    testset_copy_target_path : str
        Destination directory; created when missing.

    Notes
    -----
    When the destination already holds exactly ``len(metadata)`` entries the
    copy is assumed to be complete and nothing is done. Otherwise the
    destination is emptied and every file is copied again.

    Files are copied with ``shutil.copy`` rather than a shell ``cp`` command so
    that paths containing spaces or shell metacharacters are handled correctly
    and the function does not depend on a POSIX shell. A file that cannot be
    copied is reported and skipped, matching the previous best-effort
    behaviour.
    """
    import shutil

    os.makedirs(testset_copy_target_path, exist_ok=True)
    existing = os.listdir(testset_copy_target_path)
    if len(existing) == len(metadata):
        return

    # Stale or partial content: clear it before copying again.
    for file in existing:
        try:
            os.remove(os.path.join(testset_copy_target_path, file))
        except OSError as e:
            print(e)

    print("Copying test subset data to {}".format(testset_copy_target_path))
    for each in tqdm(metadata):
        try:
            shutil.copy(each["wav"], testset_copy_target_path)
        except OSError as e:
            print(e)
|
74 |
+
|
75 |
+
|
76 |
+
def listdir_nohidden(path):
    """Yield the entries of ``path`` whose names do not start with a dot."""
    yield from (entry for entry in os.listdir(path) if not entry.startswith("."))
|
80 |
+
|
81 |
+
|
82 |
+
def get_restore_step(path):
    """Pick the checkpoint file in ``path`` that training should resume from.

    Selection order:

    1. ``final.ckpt`` if it exists.
    2. Otherwise, if ``last.ckpt`` exists, the ``last-v<N>.ckpt`` file with the
       highest ``N``, falling back to ``last.ckpt`` when there is no versioned
       file.
    3. Otherwise the ``*.ckpt`` file whose name carries the largest
       ``step=<N>`` value.

    Parameters
    ----------
    path : str
        Directory containing the checkpoint files.

    Returns
    -------
    tuple
        ``(file_name, step)``. ``step`` is the parsed step number in case 3 and
        ``0`` in cases 1 and 2.

    Raises
    ------
    ValueError
        If none of the three cases matches any file in ``path``.
    """
    import re

    checkpoints = os.listdir(path)

    if os.path.exists(os.path.join(path, "final.ckpt")):
        return "final.ckpt", 0

    if os.path.exists(os.path.join(path, "last.ckpt")):
        # The previous implementation appended the version to its list before
        # comparing against the list maximum, so a versioned file could never
        # win and "last.ckpt" was always returned.
        versioned = []
        for name in checkpoints:
            match = re.fullmatch(r"last-v(\d+)\.ckpt", name)
            if match:
                versioned.append((int(match.group(1)), name))
        if versioned:
            return max(versioned)[1], 0
        return "last.ckpt", 0

    # Ignore files that are not step-stamped checkpoints instead of crashing
    # while parsing their names.
    step_of = {}
    for name in checkpoints:
        match = re.search(r"step=(\d+)", name)
        if match and name.endswith(".ckpt"):
            step_of[name] = int(match.group(1))
    if not step_of:
        raise ValueError("No 'step=<N>' checkpoint found in {}".format(path))
    newest = max(step_of, key=step_of.get)
    return newest, step_of[newest]
|
101 |
+
|
102 |
+
|
103 |
+
def download(url, local_path, chunk_size=1024):
    """Stream ``url`` to ``local_path`` while showing a byte progress bar.

    Parameters
    ----------
    url : str
        Address to fetch with an HTTP GET request.
    local_path : str
        Destination file; its parent directory is created when needed.
    chunk_size : int
        Number of bytes requested per streamed chunk.

    Raises
    ------
    requests.HTTPError
        If the server answers with an error status. Without this check the
        error page would be written to ``local_path`` as if it were the file.
    """
    target_dir = os.path.dirname(local_path)
    # A bare file name has no directory part and os.makedirs("") would fail.
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)
    with requests.get(url, stream=True) as r:
        r.raise_for_status()
        total_size = int(r.headers.get("content-length", 0))
        with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
            with open(local_path, "wb") as f:
                for data in r.iter_content(chunk_size=chunk_size):
                    if data:
                        f.write(data)
                        # Advance by what was actually written; the final
                        # chunk is usually shorter than chunk_size.
                        pbar.update(len(data))
|
113 |
+
|
114 |
+
|
115 |
+
def md5_hash(path):
    """Return the MD5 hex digest of the file at ``path``.

    The file is read in fixed-size chunks so that large checkpoint files do
    not have to fit in memory at once.
    """
    digest = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest()
|
119 |
+
|
120 |
+
|
121 |
+
def get_ckpt_path(name, root, check=False):
    """Return the local path of asset ``name`` under ``root``, fetching it if needed.

    The file is downloaded when it does not exist yet, or when ``check`` is
    true and its MD5 digest differs from the expected one. A freshly
    downloaded file is always verified against the expected digest.

    Parameters
    ----------
    name : str
        Asset key; must be present in ``URL_MAP``.
    root : str
        Directory in which the asset is stored.
    check : bool
        Whether to verify the digest of an already existing file.
    """
    assert name in URL_MAP
    path = os.path.join(root, CKPT_MAP[name])

    needs_download = not os.path.exists(path)
    if not needs_download and check:
        needs_download = md5_hash(path) != MD5_MAP[name]

    if needs_download:
        source = URL_MAP[name]
        print("Downloading {} model from {} to {}".format(name, source, path))
        download(source, path)
        md5 = md5_hash(path)
        assert md5 == MD5_MAP[name], md5
    return path
|
130 |
+
|
131 |
+
|
132 |
+
class KeyNotFoundError(Exception):
    """Raised when a key path cannot be resolved in a nested container.

    The exception message lists the requested keys and the keys visited so far
    (each only when provided), followed by the underlying cause.
    """

    def __init__(self, cause, keys=None, visited=None):
        self.cause = cause
        self.keys = keys
        self.visited = visited

        labelled = (("Key not found: {}", keys), ("Visited: {}", visited))
        lines = [
            template.format(value)
            for template, value in labelled
            if value is not None
        ]
        lines.append("Cause:\n{}".format(cause))
        super().__init__("\n".join(lines))
|
145 |
+
|
146 |
+
|
147 |
+
def retrieve(
    list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
):
    """Look up a value in a nested list/dict structure by a path-like key.

    Callable nodes met on the way are called and replaced in-place by their
    return value when :attr:`expand` is ``True``.

    Parameters
    ----------
    list_or_dict : list or dict
        Possibly nested list or dictionary.
    key : str
        key/to/value, path like string naming every key needed to reach the
        desired value. Components addressing a non-dict node are converted
        with ``int`` and used as indices.
    splitval : str
        Delimiter between the components of `key`.
    default : obj
        Value returned if :attr:`key` is not found. ``None`` means "no
        default": the lookup error is raised instead.
    expand : bool
        Whether to expand callable nodes on the path or not.
    pass_success : bool
        If ``True`` return a ``(value, success)`` tuple, where ``success`` is
        ``False`` when ``default`` was substituted.

    Returns
    -------
    The desired value, or ``default`` if it is not ``None`` and the
    :attr:`key` is not found; wrapped in a tuple with the success flag when
    :attr:`pass_success` is ``True``.

    Raises
    ------
    KeyNotFoundError if ``key`` is not in ``list_or_dict`` and
    :attr:`default` is ``None``.
    """

    path = key.split(splitval)

    found = True
    node = list_or_dict
    try:
        seen = []
        owner = None
        owner_key = None
        for part in path:
            if callable(node):
                if not expand:
                    raise KeyNotFoundError(
                        ValueError(
                            "Trying to get past callable node with expand=False."
                        ),
                        keys=path,
                        visited=seen,
                    )
                # Expand the callable and store the result where it came from.
                node = node()
                owner[owner_key] = node

            owner_key = part
            owner = node

            try:
                node = node[part] if isinstance(node, dict) else node[int(part)]
            except (KeyError, IndexError, ValueError) as e:
                raise KeyNotFoundError(e, keys=path, visited=seen)

            seen += [part]
        # final expansion of retrieved value
        if expand and callable(node):
            node = node()
            owner[owner_key] = node
    except KeyNotFoundError as e:
        if default is None:
            raise e
        node = default
        found = False

    if pass_success:
        return node, found
    return node
|
228 |
+
|
229 |
+
|
230 |
+
def to_device(data, device):
    """Convert the numpy members of a batch tuple to tensors on ``device``.

    Parameters
    ----------
    data : sequence
        Either a 12-element batch ``(ids, raw_texts, speakers, texts,
        src_lens, max_src_len, mels, mel_lens, max_mel_len, pitches,
        energies, durations)`` or a 6-element batch holding only the first six
        of those fields. The converted fields are passed to
        ``torch.from_numpy``.
    device : torch.device or str
        Target device.

    Returns
    -------
    tuple or None
        A tuple with the same layout as ``data``. ``speakers``, ``texts`` and
        ``durations`` are cast to long, ``mels`` and ``pitches`` to float;
        ``src_lens``, ``mel_lens`` and ``energies`` keep their dtype. The
        remaining fields are passed through untouched. ``None`` is returned
        for any other input length.
    """

    def as_is(array):
        return torch.from_numpy(array).to(device)

    def as_long(array):
        return torch.from_numpy(array).long().to(device)

    def as_float(array):
        return torch.from_numpy(array).float().to(device)

    size = len(data)

    if size == 12:
        ids, raw_texts, speakers, texts, src_lens, max_src_len = data[:6]
        mels, mel_lens, max_mel_len, pitches, energies, durations = data[6:]
        return (
            ids,
            raw_texts,
            as_long(speakers),
            as_long(texts),
            as_is(src_lens),
            max_src_len,
            as_float(mels),
            as_is(mel_lens),
            max_mel_len,
            as_float(pitches),
            as_is(energies),
            as_long(durations),
        )

    if size == 6:
        ids, raw_texts, speakers, texts, src_lens, max_src_len = data
        return (
            ids,
            raw_texts,
            as_long(speakers),
            as_long(texts),
            as_is(src_lens),
            max_src_len,
        )

    return None
|
279 |
+
|
280 |
+
|
281 |
+
def log(logger, step=None, fig=None, audio=None, sampling_rate=22050, tag=""):
    """Send a figure and/or an audio clip to ``logger`` under ``tag``.

    Parameters
    ----------
    logger : object
        Must provide ``add_figure(tag, fig, ...)`` and
        ``add_audio(tag, audio, sample_rate=..., ...)``.
        NOTE(review): presumably a ``torch.utils.tensorboard.SummaryWriter``;
        confirm against the callers.
    step : int, optional
        Step the entries belong to. When given it is forwarded to the logger
        as ``global_step``; previously it was accepted but silently ignored.
        When ``None`` the logger is called exactly as before.
    fig : object, optional
        Figure to log; skipped when ``None``.
    audio : array-like, optional
        Waveform to log; skipped when ``None``. It is scaled so that its peak
        magnitude becomes ``1 / 1.1``.
    sampling_rate : int
        Sample rate passed along with the audio.
    tag : str
        Name under which the entries are stored.
    """
    step_kwargs = {} if step is None else {"global_step": step}

    if fig is not None:
        logger.add_figure(tag, fig, **step_kwargs)

    if audio is not None:
        peak = abs(audio).max()
        # A silent clip has a zero peak; dividing by it would fill the clip
        # with NaNs, so it is logged unscaled instead.
        if peak > 0:
            audio = audio / (peak * 1.1)
        logger.add_audio(
            tag,
            audio,
            sample_rate=sampling_rate,
            **step_kwargs
        )
|
307 |
+
|
308 |
+
|
309 |
+
def get_mask_from_lengths(lengths, max_len=None):
    """Build a boolean padding mask from per-sequence lengths.

    Parameters
    ----------
    lengths : torch.Tensor
        One length per batch element; indexed as a 1-D tensor.
    max_len : int, optional
        Width of the mask. Defaults to the largest value in ``lengths``.

    Returns
    -------
    torch.Tensor
        Boolean tensor of shape ``(batch, max_len)`` that is ``True`` at
        positions at or beyond each sequence's length. It lives on the same
        device as ``lengths``.
    """
    batch_size = lengths.shape[0]
    if max_len is None:
        max_len = torch.max(lengths).item()

    # Create the positions on the device of `lengths` rather than on the
    # module-level default device, so the comparison below never mixes devices.
    ids = (
        torch.arange(0, max_len, device=lengths.device)
        .unsqueeze(0)
        .expand(batch_size, -1)
    )
    mask = ids >= lengths.unsqueeze(1).expand(-1, max_len)

    return mask
|
318 |
+
|
319 |
+
|
320 |
+
def expand(values, durations):
    """Repeat each value by its duration and return the result as an array.

    ``durations`` entries are truncated with ``int``; entries that are zero or
    negative drop the corresponding value. Pairs are formed with ``zip``, so
    the shorter of the two inputs decides how many values are used.
    """
    repeated = [
        value
        for value, d in zip(values, durations)
        for _ in range(max(0, int(d)))
    ]
    return np.array(repeated)
|
325 |
+
|
326 |
+
|
327 |
+
def synth_one_sample_val(
    targets, predictions, vocoder, model_config, preprocess_config
):
    """Plot and optionally vocode one randomly chosen sample of a batch.

    Parameters
    ----------
    targets : sequence
        Batch tuple; ``targets[0]`` holds the sample names and ``targets[6]``
        the target mel tensor, indexed as ``[sample, frame, ...]``.
    predictions : sequence
        Model outputs; entries ``0`` and ``1`` are read as the raw and postnet
        mel predictions and entry ``9`` as the per-sample mel lengths.
    vocoder : object or None
        When ``None`` no waveform is synthesised.
    model_config, preprocess_config :
        Forwarded unchanged to ``vocoder_infer``.

    Returns
    -------
    tuple
        ``(fig, wav_reconstruction, wav_prediction, basename)``. Both waveforms
        are ``None`` when ``vocoder`` is ``None``.

    Notes
    -----
    ``plot_mel`` in this module accepts ``(data, titles)`` and draws each data
    entry as an image. The earlier call passed an additional ``stats``
    argument and ``(mel, pitch, energy)`` tuples, which raised ``TypeError``.
    The pitch/energy/stats values were only built for that call and have been
    dropped with it.
    """
    index = np.random.choice(list(np.arange(targets[6].size(0))))

    basename = targets[0][index]
    mel_len = predictions[9][index].item()
    mel_target = targets[6][index, :mel_len].detach().transpose(0, 1)
    mel_prediction = predictions[0][index, :mel_len].detach().transpose(0, 1)
    postnet_mel_prediction = predictions[1][index, :mel_len].detach().transpose(0, 1)

    fig = plot_mel(
        [
            mel_prediction.cpu().numpy(),
            postnet_mel_prediction.cpu().numpy(),
            mel_target.cpu().numpy(),
        ],
        [
            "Raw mel spectrogram prediction",
            "Postnet mel prediction",
            "Ground-Truth Spectrogram",
        ],
    )

    if vocoder is not None:
        from .model import vocoder_infer

        # NOTE(review): synth_one_sample calls vocoder_infer(mel, vocoder) with
        # two arguments while four are passed here; confirm the signature in
        # the model module.
        wav_reconstruction = vocoder_infer(
            mel_target.unsqueeze(0),
            vocoder,
            model_config,
            preprocess_config,
        )[0]
        wav_prediction = vocoder_infer(
            postnet_mel_prediction.unsqueeze(0),
            vocoder,
            model_config,
            preprocess_config,
        )[0]
    else:
        wav_reconstruction = wav_prediction = None

    return fig, wav_reconstruction, wav_prediction, basename
|
399 |
+
|
400 |
+
|
401 |
+
def synth_one_sample(mel_input, mel_prediction, labels, vocoder):
    """Vocode a ground-truth mel batch and its predicted counterpart.

    Both mel tensors are permuted with ``(0, 2, 1)`` before being handed to
    ``vocoder_infer``. ``labels`` is accepted but not used.

    Returns
    -------
    tuple
        ``(wav_reconstruction, wav_prediction)``; ``(None, None)`` when
        ``vocoder`` is ``None``.
    """
    if vocoder is None:
        return None, None

    from .model import vocoder_infer

    reconstructed = vocoder_infer(mel_input.permute(0, 2, 1), vocoder)
    predicted = vocoder_infer(mel_prediction.permute(0, 2, 1), vocoder)
    return reconstructed, predicted
|
417 |
+
|
418 |
+
|
419 |
+
def synth_samples(targets, predictions, vocoder, model_config, preprocess_config, path):
    """Write a spectrogram image and a wav file for every sample of a batch.

    Parameters
    ----------
    targets : sequence
        Batch tuple; only ``targets[0]`` (the sample names) is read.
    predictions : sequence
        Model outputs; entry ``1`` is read as the postnet mel prediction and
        entry ``9`` as the per-sample mel lengths.
    vocoder : object
        Forwarded to ``vocoder_infer``.
    model_config, preprocess_config :
        Forwarded to ``vocoder_infer``; ``preprocess_config`` also supplies the
        STFT hop length and the audio sampling rate.
    path : str
        Existing directory that receives ``<name>_postnet_2.png`` and
        ``<name>.wav`` for each sample.

    Notes
    -----
    ``plot_mel`` in this module accepts ``(data, titles)`` and draws each data
    entry as an image. The earlier call passed an additional ``stats``
    argument and a ``(mel, pitch, energy)`` tuple, which raised ``TypeError``.
    The pitch/energy values and the per-iteration read of ``stats.json`` were
    only needed for that call and have been dropped with it.
    """
    basenames = targets[0]

    for i in range(len(predictions[1])):
        basename = basenames[i]
        mel_len = predictions[9][i].item()
        mel_prediction = predictions[1][i, :mel_len].detach().transpose(0, 1)

        fig = plot_mel(
            [mel_prediction.cpu().numpy()],
            ["Synthetized Spectrogram by PostNet"],
        )
        # Save and close this exact figure instead of relying on pyplot's
        # implicit "current figure".
        fig.savefig(os.path.join(path, "{}_postnet_2.png".format(basename)))
        plt.close(fig)

    from .model import vocoder_infer

    mel_predictions = predictions[1].transpose(1, 2)
    lengths = predictions[9] * preprocess_config["preprocessing"]["stft"]["hop_length"]
    # NOTE(review): synth_one_sample calls vocoder_infer(mel, vocoder) with two
    # positional arguments while four are passed here; confirm the signature
    # in the model module.
    wav_predictions = vocoder_infer(
        mel_predictions, vocoder, model_config, preprocess_config, lengths=lengths
    )

    sampling_rate = preprocess_config["preprocessing"]["audio"]["sampling_rate"]
    for wav, basename in zip(wav_predictions, basenames):
        wavfile.write(os.path.join(path, "{}.wav".format(basename)), sampling_rate, wav)
|
470 |
+
|
471 |
+
|
472 |
+
def plot_mel(data, titles=None):
    """Draw each entry of ``data`` as an image in its own subplot row.

    Parameters
    ----------
    data : sequence of array-like
        Images to draw; the first dimension of each is used as the y extent.
    titles : sequence of str, optional
        One title per entry of ``data``. Defaults to no titles.

    Returns
    -------
    The matplotlib figure holding one row per entry.
    """
    fig, axes = plt.subplots(len(data), 1, squeeze=False)
    if titles is None:
        titles = [None] * len(data)

    for row, mel in enumerate(data):
        ax = axes[row][0]
        ax.imshow(mel, origin="lower", aspect="auto")
        ax.set_aspect(2.5, adjustable="box")
        ax.set_ylim(0, mel.shape[0])
        ax.set_title(titles[row], fontsize="medium")
        ax.tick_params(labelsize="x-small", left=False, labelleft=False)
        ax.set_anchor("W")

    return fig
|
487 |
+
|
488 |
+
|
489 |
+
def pad_1D(inputs, PAD=0):
    """Right-pad 1-D arrays with ``PAD`` to a common length and stack them.

    Returns an array with one row per input, each as long as the longest
    input.
    """
    longest = max(len(item) for item in inputs)
    rows = [
        np.pad(
            item,
            (0, longest - item.shape[0]),
            mode="constant",
            constant_values=PAD,
        )
        for item in inputs
    ]
    return np.stack(rows)
|
500 |
+
|
501 |
+
|
502 |
+
def pad_2D(inputs, maxlen=None):
    """Zero-pad 2-D arrays along their first axis and stack them.

    Parameters
    ----------
    inputs : sequence of 2-D array-like
        Arrays that share the size of their second axis.
    maxlen : int, optional
        Target size of the first axis. When omitted (or falsy) the largest
        first-axis size found in ``inputs`` is used.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(len(inputs), target_len, width)``.

    Raises
    ------
    ValueError
        If an input is longer than ``maxlen``.
    """

    def _pad_rows(x, max_len):
        length = np.shape(x)[0]
        if length > max_len:
            raise ValueError(
                "input length {} exceeds maxlen {}".format(length, max_len)
            )
        # Pad only the first axis; the second axis is left untouched, so no
        # trimming back to the original width is needed afterwards.
        return np.pad(
            x, ((0, max_len - length), (0, 0)), mode="constant", constant_values=0
        )

    max_len = maxlen if maxlen else max(np.shape(x)[0] for x in inputs)
    return np.stack([_pad_rows(x, max_len) for x in inputs])
|
521 |
+
|
522 |
+
|
523 |
+
def pad(input_ele, mel_max_length=None):
    """Zero-pad 1-D or 2-D tensors along their first dimension and stack them.

    Parameters
    ----------
    input_ele : sequence of torch.Tensor
        Tensors to pad; each must have one or two dimensions.
    mel_max_length : int, optional
        Target size of the first dimension. When omitted (or falsy) the
        largest first-dimension size found in ``input_ele`` is used.

    Returns
    -------
    torch.Tensor
        The padded tensors stacked along a new leading dimension.

    Raises
    ------
    ValueError
        If a tensor has neither one nor two dimensions. Previously such a
        tensor caused the padded result of the preceding element to be
        appended a second time (or an ``UnboundLocalError`` when it came
        first).
    """
    if mel_max_length:
        max_len = mel_max_length
    else:
        max_len = max(batch.size(0) for batch in input_ele)

    out_list = []
    for batch in input_ele:
        extra = max_len - batch.size(0)
        if batch.dim() == 1:
            one_batch_padded = F.pad(batch, (0, extra), "constant", 0.0)
        elif batch.dim() == 2:
            # Pad spec is (last-dim left, last-dim right, first-dim left,
            # first-dim right): only the end of the first dimension grows.
            one_batch_padded = F.pad(batch, (0, 0, 0, extra), "constant", 0.0)
        else:
            raise ValueError(
                "pad() supports 1-D or 2-D tensors, got {} dimensions".format(
                    batch.dim()
                )
            )
        out_list.append(one_batch_padded)
    out_padded = torch.stack(out_list)
    return out_padded
|