jadechoghari
commited on
Create tools.py
Browse files
tools.py
ADDED
@@ -0,0 +1,566 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Author: Haohe Liu
|
2 |
+
# Email: [email protected]
|
3 |
+
# Date: 11 Feb 2023
|
4 |
+
|
5 |
+
import os
|
6 |
+
import json
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn.functional as F
|
10 |
+
import numpy as np
|
11 |
+
import matplotlib
|
12 |
+
from scipy.io import wavfile
|
13 |
+
from matplotlib import pyplot as plt
|
14 |
+
|
15 |
+
matplotlib.use("Agg")
|
16 |
+
|
17 |
+
import hashlib
|
18 |
+
import os
|
19 |
+
|
20 |
+
import requests
|
21 |
+
from tqdm import tqdm
|
22 |
+
|
23 |
+
# Common prefix of the remote "specvqgan_public" location hosting every asset below.
_SPECVQGAN_PUBLIC_ROOT = (
    "https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public"
)

# Asset name -> file name used for the local copy of the asset.
CKPT_MAP = {
    "vggishish_lpaps": "vggishish16.pt",
    "vggishish_mean_std_melspec_10s_22050hz": "train_means_stds_melspec_10s_22050hz.txt",
    "melception": "melception-21-05-10T09-28-40.pt",
}

# Asset name -> download URL. Every remote file carries the same name as its
# local copy, so the URL is the shared prefix plus that file name.
URL_MAP = {
    asset: "{}/{}".format(_SPECVQGAN_PUBLIC_ROOT, filename)
    for asset, filename in CKPT_MAP.items()
}

# Asset name -> expected MD5 hex digest of the downloaded file.
MD5_MAP = {
    "vggishish_lpaps": "197040c524a07ccacf7715d7080a80bd",
    "vggishish_mean_std_melspec_10s_22050hz": "f449c6fd0e248936c16f6d22492bb625",
    "melception": "a71a41041e945b457c7d3d814bbcf72d",
}

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
42 |
+
|
43 |
+
|
44 |
+
def read_list(fname):
    """Return the lines of the text file ``fname`` as a list of strings.

    Newline characters are stripped from both ends of each line; any other
    whitespace is kept, and empty lines are returned as empty strings.
    """
    with open(fname, "r") as handle:
        return [line.strip("\n") for line in handle]
|
51 |
+
|
52 |
+
|
53 |
+
def build_dataset_json_from_list(list_path):
    """Build a dataset dictionary from a plain-text list file.

    Each line of the file is either ``<wav>|<caption>`` or just ``<caption>``.
    Lines without a ``|`` get an empty ``"wav"`` entry.

    Parameters
    ----------
    list_path : str
        Path of the list file, read with :func:`read_list`.

    Returns
    -------
    dict
        ``{"data": [{"wav": ..., "caption": ...}, ...]}`` with one entry per
        line, in file order.
    """
    data = []
    for line in read_list(list_path):
        # Split on the first "|" only: a caption that itself contains "|"
        # used to make the two-value unpacking raise ValueError.
        wav, separator, caption = line.partition("|")
        if not separator:
            wav, caption = "", line
        data.append({"wav": wav, "caption": caption})
    return {"data": data}
|
68 |
+
|
69 |
+
|
70 |
+
def load_json(fname):
    """Parse the JSON file at ``fname`` and return the decoded object."""
    with open(fname, "r") as handle:
        return json.load(handle)
|
74 |
+
|
75 |
+
|
76 |
+
def read_json(dataset_json_file):
    """Load a dataset JSON file and return the value stored under ``"data"``.

    Raises ``KeyError`` when the decoded object has no ``"data"`` key.
    """
    with open(dataset_json_file, "r") as handle:
        content = json.load(handle)
    entries = content["data"]
    return entries
|
80 |
+
|
81 |
+
|
82 |
+
def copy_test_subset_data(metadata, testset_copy_target_path):
    """Copy the audio files listed in ``metadata`` into a target folder.

    The folder is created if needed. When it already holds exactly
    ``len(metadata)`` entries the copy is skipped; otherwise the files
    currently in it are removed and every ``each["wav"]`` is copied in.

    Parameters
    ----------
    metadata : sequence of dict
        Entries providing the source file path under the ``"wav"`` key.
    testset_copy_target_path : str
        Destination directory.

    Notes
    -----
    Removal and copy are best-effort: a failure on one file is printed and
    the remaining files are still processed.
    """
    # Local import: the rest of the module does not need shutil.
    import shutil

    os.makedirs(testset_copy_target_path, exist_ok=True)
    existing = os.listdir(testset_copy_target_path)
    if len(existing) == len(metadata):
        return

    # Stale or partial content: start from an empty folder.
    for name in existing:
        try:
            os.remove(os.path.join(testset_copy_target_path, name))
        except OSError as e:
            print(e)

    print("Copying test subset data to {}".format(testset_copy_target_path))
    for each in tqdm(metadata):
        # shutil.copy instead of a shell "cp" string: paths with spaces or
        # shell metacharacters are handled correctly and failures are visible.
        try:
            shutil.copy(each["wav"], testset_copy_target_path)
        except OSError as e:
            print(e)
|
99 |
+
|
100 |
+
|
101 |
+
def listdir_nohidden(path):
    """Yield the entries of directory ``path`` whose names do not start with a dot."""
    for entry in os.listdir(path):
        if entry.startswith("."):
            continue
        yield entry
|
105 |
+
|
106 |
+
|
107 |
+
def get_restore_step(path):
    """Choose the checkpoint file in ``path`` to resume from.

    Selection order:

    1. ``final.ckpt`` if present -> ``("final.ckpt", 0)``.
    2. Otherwise, if ``last.ckpt`` is present, the ``last-v<N>.ckpt`` file with
       the highest ``N``, or ``last.ckpt`` itself when no versioned file
       exists -> ``(name, 0)``.
    3. Otherwise the file whose name contains the largest ``step=<N>`` ->
       ``(name, N)``.

    Parameters
    ----------
    path : str
        Directory containing the checkpoint files.

    Returns
    -------
    tuple of (str, int)
        Checkpoint file name (relative to ``path``) and its step.

    Raises
    ------
    ValueError
        If none of the three rules matches any file in ``path``.
    """
    # Local import: only this function needs regular expressions.
    import re

    checkpoints = os.listdir(path)
    if os.path.exists(os.path.join(path, "final.ckpt")):
        return "final.ckpt", 0

    if os.path.exists(os.path.join(path, "last.ckpt")):
        # The previous implementation appended the version to the list before
        # comparing against its maximum, so a versioned file was never picked.
        fname, best_version = "last.ckpt", -1
        for name in checkpoints:
            match = re.fullmatch(r"last-v(\d+)\.ckpt", name)
            if match and int(match.group(1)) > best_version:
                best_version = int(match.group(1))
                fname = name
        return fname, 0

    # Files without a "step=<N>" marker (logs, configs, ...) are ignored
    # instead of crashing the parser.
    best_name, best_step = None, -1
    for name in checkpoints:
        match = re.search(r"step=(\d+)", name)
        if match and int(match.group(1)) > best_step:
            best_step = int(match.group(1))
            best_name = name
    if best_name is None:
        raise ValueError("No checkpoint with a 'step=' marker found in {}".format(path))
    return best_name, best_step
|
126 |
+
|
127 |
+
|
128 |
+
def download(url, local_path, chunk_size=1024):
    """Stream ``url`` to ``local_path`` while showing a progress bar.

    Parameters
    ----------
    url : str
        Address to fetch with an HTTP GET request.
    local_path : str
        Destination file; its parent directory is created when missing.
    chunk_size : int
        Number of bytes requested per streamed chunk.

    Raises
    ------
    requests.HTTPError
        If the server answers with an error status, so that an error page is
        never written to disk in place of the requested file.
    """
    target_dir = os.path.dirname(local_path)
    # A bare file name has no directory part, and os.makedirs("") would fail.
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)
    with requests.get(url, stream=True) as r:
        r.raise_for_status()
        total_size = int(r.headers.get("content-length", 0))
        with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
            with open(local_path, "wb") as f:
                for data in r.iter_content(chunk_size=chunk_size):
                    if data:
                        f.write(data)
                        # Advance by the bytes actually received; the final
                        # chunk is usually shorter than chunk_size.
                        pbar.update(len(data))
|
138 |
+
|
139 |
+
|
140 |
+
def md5_hash(path):
    """Return the MD5 hex digest of the file at ``path``.

    The file is hashed in fixed-size chunks so that large checkpoint files do
    not have to be loaded into memory at once.
    """
    digest = hashlib.md5()
    with open(path, "rb") as f:
        # iter() with a sentinel stops at the empty bytes object returned at EOF.
        for chunk in iter(lambda: f.read(1024 * 1024), b""):
            digest.update(chunk)
    return digest.hexdigest()
|
144 |
+
|
145 |
+
|
146 |
+
def get_ckpt_path(name, root, check=False):
    """Return the local path of asset ``name`` under ``root``, downloading it if needed.

    The asset is downloaded when the file does not exist yet, or when
    ``check`` is true and the MD5 of the existing file differs from the value
    in ``MD5_MAP``. A freshly downloaded file is verified against ``MD5_MAP``.

    Parameters
    ----------
    name : str
        Key of ``URL_MAP`` / ``CKPT_MAP`` / ``MD5_MAP``.
    root : str
        Directory that holds (or will hold) the file.
    check : bool
        Whether to verify the checksum of an already existing file.
    """
    assert name in URL_MAP
    path = os.path.join(root, CKPT_MAP[name])
    expected_md5 = MD5_MAP[name]

    # The checksum is only computed when the file exists and check is requested.
    needs_download = (not os.path.exists(path)) or (
        check and md5_hash(path) != expected_md5
    )
    if needs_download:
        url = URL_MAP[name]
        print("Downloading {} model from {} to {}".format(name, url, path))
        download(url, path)
        md5 = md5_hash(path)
        assert md5 == expected_md5, md5
    return path
|
155 |
+
|
156 |
+
|
157 |
+
class KeyNotFoundError(Exception):
    """Raised when a key path cannot be resolved in a nested container.

    Attributes ``cause``, ``keys`` and ``visited`` keep the constructor
    arguments; the exception message lists whichever of ``keys`` and
    ``visited`` were given, followed by the cause.
    """

    def __init__(self, cause, keys=None, visited=None):
        self.cause = cause
        self.keys = keys
        self.visited = visited

        # Optional sections appear only when their value was supplied.
        optional_sections = (
            ("Key not found: {}", keys),
            ("Visited: {}", visited),
        )
        lines = [
            template.format(value)
            for template, value in optional_sections
            if value is not None
        ]
        lines.append("Cause:\n{}".format(cause))
        super().__init__("\n".join(lines))
|
170 |
+
|
171 |
+
|
172 |
+
def retrieve(
    list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
):
    """Given a nested list or dict return the desired value at key expanding
    callable nodes if necessary and :attr:`expand` is ``True``. The expansion
    is done in-place.

    Parameters
    ----------
        list_or_dict : list or dict
            Possibly nested list or dictionary. May itself be a callable that
            returns one when :attr:`expand` is ``True``.
        key : str
            key/to/value, path like string describing all keys necessary to
            consider to get to the desired value. List indices can also be
            passed here.
        splitval : str
            String that defines the delimiter between keys of the
            different depth levels in `key`.
        default : obj
            Value returned if :attr:`key` is not found.
        expand : bool
            Whether to expand callable nodes on the path or not.
        pass_success : bool
            If ``True``, return a ``(value, success)`` tuple where ``success``
            is ``False`` when ``default`` was substituted.

    Returns
    -------
        The desired value or if :attr:`default` is not ``None`` and the
        :attr:`key` is not found returns ``default``. With
        :attr:`pass_success` the value is paired with the success flag.

    Raises
    ------
        KeyNotFoundError if ``key`` not in ``list_or_dict`` and
        :attr:`default` is ``None``.
    """

    keys = key.split(splitval)

    success = True
    try:
        visited = []
        parent = None
        last_key = None
        for part in keys:
            if callable(list_or_dict):
                if not expand:
                    raise KeyNotFoundError(
                        ValueError(
                            "Trying to get past callable node with expand=False."
                        ),
                        keys=keys,
                        visited=visited,
                    )
                list_or_dict = list_or_dict()
                # A callable root has no parent container to write back into.
                if parent is not None:
                    parent[last_key] = list_or_dict

            container = list_or_dict
            try:
                # Lists are addressed by integer position, dicts by the raw key.
                index = part if isinstance(container, dict) else int(part)
                list_or_dict = container[index]
            except (KeyError, IndexError, ValueError, TypeError) as e:
                # TypeError covers nodes that cannot be indexed at all, so
                # that ``default`` also applies to them.
                raise KeyNotFoundError(e, keys=keys, visited=visited)

            # Remember the index actually used (an int for lists) so a later
            # in-place expansion can be stored back into the container.
            parent = container
            last_key = index

            visited.append(part)
        # final expansion of retrieved value
        if expand and callable(list_or_dict):
            list_or_dict = list_or_dict()
            parent[last_key] = list_or_dict
    except KeyNotFoundError as e:
        if default is None:
            raise e
        else:
            list_or_dict = default
            success = False

    if not pass_success:
        return list_or_dict
    else:
        return list_or_dict, success
|
253 |
+
|
254 |
+
|
255 |
+
def to_device(data, device):
    """Convert the NumPy fields of a batch tuple to tensors on ``device``.

    Two layouts are recognised by length:

    * 12 items: ``(ids, raw_texts, speakers, texts, src_lens, max_src_len,
      mels, mel_lens, max_mel_len, pitches, energies, durations)``
    * 6 items: ``(ids, raw_texts, speakers, texts, src_lens, max_src_len)``

    ``ids``, ``raw_texts`` and the ``max_*`` entries are passed through
    unchanged. A tuple of the same layout is returned; any other length
    yields ``None``.
    """

    def _as_is(array):
        # Keep the dtype of the NumPy array.
        return torch.from_numpy(array).to(device)

    def _as_long(array):
        return torch.from_numpy(array).long().to(device)

    def _as_float(array):
        return torch.from_numpy(array).float().to(device)

    field_count = len(data)

    if field_count == 12:
        (
            ids,
            raw_texts,
            speakers,
            texts,
            src_lens,
            max_src_len,
            mels,
            mel_lens,
            max_mel_len,
            pitches,
            energies,
            durations,
        ) = data
        return (
            ids,
            raw_texts,
            _as_long(speakers),
            _as_long(texts),
            _as_is(src_lens),
            max_src_len,
            _as_float(mels),
            _as_is(mel_lens),
            max_mel_len,
            _as_float(pitches),
            _as_is(energies),
            _as_long(durations),
        )

    if field_count == 6:
        ids, raw_texts, speakers, texts, src_lens, max_src_len = data
        return (
            ids,
            raw_texts,
            _as_long(speakers),
            _as_long(texts),
            _as_is(src_lens),
            max_src_len,
        )

    return None
|
304 |
+
|
305 |
+
|
306 |
+
def log(logger, step=None, fig=None, audio=None, sampling_rate=22050, tag=""):
    """Send a figure and/or an audio clip to ``logger``.

    Parameters
    ----------
    logger : object
        Must provide ``add_figure(tag, figure, global_step=...)`` and
        ``add_audio(tag, audio, global_step=..., sample_rate=...)``
        (presumably a TensorBoard ``SummaryWriter`` - confirm against callers).
    step : int, optional
        Step the entries are recorded under.
    fig : object, optional
        Figure to log; skipped when ``None``.
    audio : array-like, optional
        One-dimensional waveform; skipped when ``None``. It is scaled so that
        its peak magnitude becomes ``1 / 1.1`` before logging.
    sampling_rate : int
        Sample rate forwarded to ``add_audio``.
    tag : str
        Name the entries are logged under.
    """
    if fig is not None:
        # Forward the step: it used to be accepted but silently dropped.
        logger.add_figure(tag, fig, global_step=step)

    if audio is not None:
        peak = max(abs(audio))
        # A silent clip has a zero peak; leave it as-is rather than divide by zero.
        if peak > 0:
            audio = audio / (peak * 1.1)
        logger.add_audio(
            tag,
            audio,
            global_step=step,
            sample_rate=sampling_rate,
        )
|
332 |
+
|
333 |
+
|
334 |
+
def get_mask_from_lengths(lengths, max_len=None):
    """Build a boolean padding mask from a 1-D tensor of lengths.

    Parameters
    ----------
    lengths : torch.Tensor
        Tensor of shape ``(batch,)`` holding the valid length of each item.
    max_len : int, optional
        Width of the mask; defaults to the largest value in ``lengths``.

    Returns
    -------
    torch.Tensor
        Boolean tensor of shape ``(batch, max_len)``, on the same device as
        ``lengths``, that is ``True`` at positions ``>= length`` (padding).
    """
    if max_len is None:
        max_len = torch.max(lengths).item()

    # Create the positions on the device of ``lengths`` rather than on the
    # module-level default device, which may differ and break the comparison.
    ids = torch.arange(0, max_len, device=lengths.device).unsqueeze(0)
    mask = ids >= lengths.unsqueeze(1)

    return mask
|
343 |
+
|
344 |
+
|
345 |
+
def expand(values, durations):
    """Repeat each value according to its duration and return a NumPy array.

    ``values[i]`` is repeated ``int(durations[i])`` times; durations that are
    zero or negative contribute nothing. Pairs are taken with ``zip``, so the
    shorter of the two inputs decides how many values are used.
    """
    repeated = [
        value
        for value, d in zip(values, durations)
        for _ in range(max(0, int(d)))
    ]
    return np.array(repeated)
|
350 |
+
|
351 |
+
|
352 |
+
def synth_one_sample_val(
    targets, predictions, vocoder, model_config, preprocess_config
):
    """Plot and optionally vocode one randomly chosen item of a validation batch.

    Parameters
    ----------
    targets : sequence
        Batch targets; ``targets[0]`` is indexed for the item name and
        ``targets[6]`` for the target mel-spectrograms.
    predictions : sequence
        Model outputs; ``predictions[0]`` and ``predictions[1]`` are indexed
        for the raw and postnet mel predictions, ``predictions[9]`` for the
        mel lengths.
    vocoder : object or None
        Passed to ``vocoder_infer``; when ``None`` no audio is generated.
    model_config, preprocess_config : object
        Forwarded unchanged to ``vocoder_infer``.

    Returns
    -------
    tuple
        ``(fig, wav_reconstruction, wav_prediction, basename)``; both
        waveforms are ``None`` when ``vocoder`` is ``None``.

    Notes
    -----
    ``plot_mel`` in this module takes ``(data, titles)`` with one 2-D array
    per panel. The previous call passed ``(mel, pitch, energy)`` tuples plus
    a ``stats`` argument and therefore always raised ``TypeError``; the
    pitch/energy/stats values were only used for that call and are no longer
    computed.
    """
    index = np.random.choice(list(np.arange(targets[6].size(0))))

    basename = targets[0][index]
    mel_len = predictions[9][index].item()

    # Trim each spectrogram to its valid length and swap its two axes before plotting.
    mel_target = targets[6][index, :mel_len].detach().transpose(0, 1)
    mel_prediction = predictions[0][index, :mel_len].detach().transpose(0, 1)
    postnet_mel_prediction = predictions[1][index, :mel_len].detach().transpose(0, 1)

    fig = plot_mel(
        [
            mel_prediction.cpu().numpy(),
            postnet_mel_prediction.cpu().numpy(),
            mel_target.cpu().numpy(),
        ],
        [
            "Raw mel spectrogram prediction",
            "Postnet mel prediction",
            "Ground-Truth Spectrogram",
        ],
    )

    if vocoder is not None:
        from .model_util import vocoder_infer

        wav_reconstruction = vocoder_infer(
            mel_target.unsqueeze(0),
            vocoder,
            model_config,
            preprocess_config,
        )[0]
        wav_prediction = vocoder_infer(
            postnet_mel_prediction.unsqueeze(0),
            vocoder,
            model_config,
            preprocess_config,
        )[0]
    else:
        wav_reconstruction = wav_prediction = None

    return fig, wav_reconstruction, wav_prediction, basename
|
424 |
+
|
425 |
+
|
426 |
+
def synth_one_sample(mel_input, mel_prediction, labels, vocoder):
    """Vocode an input mel-spectrogram batch and its predicted counterpart.

    Both tensors have their last two axes swapped with ``permute(0, 2, 1)``
    before being handed to ``vocoder_infer``. ``labels`` is not used.

    Returns
    -------
    tuple
        ``(wav_reconstruction, wav_prediction)``, or ``(None, None)`` when
        ``vocoder`` is ``None``.
    """
    if vocoder is None:
        return None, None

    from .model_util import vocoder_infer

    wav_reconstruction = vocoder_infer(mel_input.permute(0, 2, 1), vocoder)
    wav_prediction = vocoder_infer(mel_prediction.permute(0, 2, 1), vocoder)
    return wav_reconstruction, wav_prediction
|
442 |
+
|
443 |
+
|
444 |
+
def synth_samples(targets, predictions, vocoder, model_config, preprocess_config, path):
    """Save a spectrogram image and a waveform for every item of a batch.

    For each item, the postnet mel prediction is plotted to
    ``<path>/<basename>_postnet_2.png``. The whole batch is then vocoded with
    ``vocoder_infer`` and each waveform is written to ``<path>/<basename>.wav``.

    Parameters
    ----------
    targets : sequence
        ``targets[0]`` provides the item names.
    predictions : sequence
        ``predictions[1]`` is indexed for the postnet mel predictions and
        ``predictions[9]`` for the mel lengths.
    vocoder, model_config : object
        Forwarded to ``vocoder_infer``.
    preprocess_config : dict
        Supplies ``["preprocessing"]["stft"]["hop_length"]`` and
        ``["preprocessing"]["audio"]["sampling_rate"]``; also forwarded to
        ``vocoder_infer``.
    path : str
        Output directory (must exist).

    Notes
    -----
    ``plot_mel`` in this module takes ``(data, titles)`` with one 2-D array
    per panel. The previous call passed a ``(mel, pitch, energy)`` tuple plus
    a ``stats`` argument and therefore always raised ``TypeError``; the
    pitch/energy/stats values were only used for that call and are no longer
    computed.
    """
    basenames = targets[0]

    for i in range(len(predictions[1])):
        basename = basenames[i]
        mel_len = predictions[9][i].item()
        mel_prediction = predictions[1][i, :mel_len].detach().transpose(0, 1)

        fig = plot_mel(
            [mel_prediction.cpu().numpy()],
            ["Synthetized Spectrogram by PostNet"],
        )
        # Save and close the figure just created instead of relying on
        # matplotlib's implicit "current figure".
        fig.savefig(os.path.join(path, "{}_postnet_2.png".format(basename)))
        plt.close(fig)

    from .model_util import vocoder_infer

    mel_predictions = predictions[1].transpose(1, 2)
    # Mel lengths are scaled by the STFT hop length before being passed as ``lengths``.
    lengths = predictions[9] * preprocess_config["preprocessing"]["stft"]["hop_length"]
    wav_predictions = vocoder_infer(
        mel_predictions, vocoder, model_config, preprocess_config, lengths=lengths
    )

    sampling_rate = preprocess_config["preprocessing"]["audio"]["sampling_rate"]
    for wav, basename in zip(wav_predictions, basenames):
        wavfile.write(os.path.join(path, "{}.wav".format(basename)), sampling_rate, wav)
|
495 |
+
|
496 |
+
|
497 |
+
def plot_mel(data, titles=None):
    """Draw each 2-D array in ``data`` as an image in its own row.

    Parameters
    ----------
    data : sequence of 2-D arrays
        One array per subplot, drawn with ``imshow`` and ``origin="lower"``.
    titles : sequence of str, optional
        One title per subplot; no titles when omitted.

    Returns
    -------
    The matplotlib figure holding ``len(data)`` vertically stacked subplots.
    """
    panel_count = len(data)
    fig, axes = plt.subplots(panel_count, 1, squeeze=False)
    if titles is None:
        titles = [None] * panel_count

    for row, mel in enumerate(data):
        # squeeze=False keeps ``axes`` two-dimensional even for one panel.
        ax = axes[row][0]
        ax.imshow(mel, origin="lower", aspect="auto")
        ax.set_aspect(2.5, adjustable="box")
        ax.set_ylim(0, mel.shape[0])
        ax.set_title(titles[row], fontsize="medium")
        ax.tick_params(labelsize="x-small", left=False, labelleft=False)
        ax.set_anchor("W")

    return fig
|
512 |
+
|
513 |
+
|
514 |
+
def pad_1D(inputs, PAD=0):
    """Right-pad 1-D arrays to a common length and stack them.

    Parameters
    ----------
    inputs : sequence of 1-D numpy arrays
        Arrays of possibly different lengths.
    PAD : scalar
        Fill value appended to the shorter arrays.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(len(inputs), longest_length)``.
    """
    target_len = max(len(x) for x in inputs)

    rows = []
    for x in inputs:
        missing = target_len - x.shape[0]
        rows.append(np.pad(x, (0, missing), mode="constant", constant_values=PAD))

    return np.stack(rows)
|
525 |
+
|
526 |
+
|
527 |
+
def pad_2D(inputs, maxlen=None):
    """Zero-pad 2-D arrays along their first axis to a common length and stack them.

    Parameters
    ----------
    inputs : sequence of 2-D numpy arrays
        Arrays whose first dimension may differ.
    maxlen : int, optional
        Target length of the first axis; when falsy, the longest first
        dimension found in ``inputs`` is used.

    Returns
    -------
    numpy.ndarray
        The stacked, padded arrays.

    Raises
    ------
    ValueError
        If an array is longer than the target length.
    """
    if maxlen:
        target_len = maxlen
    else:
        target_len = max(np.shape(x)[0] for x in inputs)

    def _pad_rows(x):
        row_count = np.shape(x)[0]
        if row_count > target_len:
            raise ValueError("not max_len")

        col_count = np.shape(x)[1]
        # np.pad applies the single (before, after) pair to every axis, so the
        # columns added on the second axis are cut off again afterwards.
        grown = np.pad(
            x, (0, target_len - row_count), mode="constant", constant_values=0
        )
        return grown[:, :col_count]

    return np.stack([_pad_rows(x) for x in inputs])
|
546 |
+
|
547 |
+
|
548 |
+
def pad(input_ele, mel_max_length=None):
    """Zero-pad 1-D or 2-D tensors along their first axis and stack them.

    Parameters
    ----------
    input_ele : sequence of torch.Tensor
        Tensors with one or two dimensions whose first dimension may differ.
    mel_max_length : int, optional
        Target length of the first axis; when falsy, the longest first
        dimension found in ``input_ele`` is used.

    Returns
    -------
    torch.Tensor
        The padded tensors stacked along a new leading axis.

    Raises
    ------
    ValueError
        If a tensor has neither one nor two dimensions. Previously such a
        tensor either failed on an unbound local or silently reused the
        padded result of the preceding item.
    """
    if mel_max_length:
        max_len = mel_max_length
    else:
        max_len = max(batch.size(0) for batch in input_ele)

    out_list = []
    for batch in input_ele:
        missing = max_len - batch.size(0)
        if batch.dim() == 1:
            padding = (0, missing)
        elif batch.dim() == 2:
            # F.pad lists the last axis first: leave it alone, grow the first axis.
            padding = (0, 0, 0, missing)
        else:
            raise ValueError(
                "pad() supports 1-D and 2-D tensors, got {} dimensions".format(
                    batch.dim()
                )
            )
        out_list.append(F.pad(batch, padding, "constant", 0.0))
    out_padded = torch.stack(out_list)
    return out_padded
|