Spaces:
Running
on
T4
Running
on
T4
fix some label errors during plotting and close plots
Browse files
InferenceInterfaces/ToucanTTSInterface.py
CHANGED
@@ -177,49 +177,61 @@ class ToucanTTSInterface(torch.nn.Module):
|
|
177 |
pass
|
178 |
|
179 |
if view or return_plot_as_filepath:
|
180 |
-
|
|
|
181 |
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
|
|
|
|
|
|
196 |
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
|
|
|
|
215 |
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
|
|
|
|
220 |
|
221 |
if return_plot_as_filepath:
|
222 |
-
|
|
|
|
|
|
|
|
|
223 |
return wave, sr, "tmp.png"
|
224 |
|
225 |
return wave, sr
|
|
|
177 |
pass
|
178 |
|
179 |
if view or return_plot_as_filepath:
|
180 |
+
try:
|
181 |
+
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(9, 5))
|
182 |
|
183 |
+
ax.imshow(mel.cpu().numpy(), origin="lower", cmap='GnBu')
|
184 |
+
ax.yaxis.set_visible(False)
|
185 |
+
duration_splits, label_positions = cumsum_durations(durations.cpu().numpy())
|
186 |
+
ax.xaxis.grid(True, which='minor')
|
187 |
+
ax.set_xticks(label_positions, minor=False)
|
188 |
+
if input_is_phones:
|
189 |
+
phones = text.replace(" ", "|")
|
190 |
+
else:
|
191 |
+
phones = self.text2phone.get_phone_string(text, for_plot_labels=True)
|
192 |
+
try:
|
193 |
+
ax.set_xticklabels(phones)
|
194 |
+
except IndexError:
|
195 |
+
pass
|
196 |
+
word_boundaries = list()
|
197 |
+
for label_index, phone in enumerate(phones):
|
198 |
+
if phone == "|":
|
199 |
+
word_boundaries.append(label_positions[label_index])
|
200 |
|
201 |
+
try:
|
202 |
+
prev_word_boundary = 0
|
203 |
+
word_label_positions = list()
|
204 |
+
for word_boundary in word_boundaries:
|
205 |
+
word_label_positions.append((word_boundary + prev_word_boundary) / 2)
|
206 |
+
prev_word_boundary = word_boundary
|
207 |
+
word_label_positions.append((duration_splits[-1] + prev_word_boundary) / 2)
|
208 |
|
209 |
+
secondary_ax = ax.secondary_xaxis('bottom')
|
210 |
+
secondary_ax.tick_params(axis="x", direction="out", pad=24)
|
211 |
+
secondary_ax.set_xticks(word_label_positions, minor=False)
|
212 |
+
secondary_ax.set_xticklabels(text.split())
|
213 |
+
secondary_ax.tick_params(axis='x', colors='orange')
|
214 |
+
secondary_ax.xaxis.label.set_color('orange')
|
215 |
+
except ValueError:
|
216 |
+
ax.set_title(text)
|
217 |
+
except IndexError:
|
218 |
+
ax.set_title(text)
|
219 |
+
except RuntimeError:
|
220 |
+
ax.set_title(text)
|
221 |
|
222 |
+
ax.vlines(x=duration_splits, colors="green", linestyles="solid", ymin=0, ymax=120, linewidth=0.5)
|
223 |
+
ax.vlines(x=word_boundaries, colors="orange", linestyles="solid", ymin=0, ymax=120, linewidth=1.0)
|
224 |
+
plt.subplots_adjust(left=0.02, bottom=0.2, right=0.98, top=.9, wspace=0.0, hspace=0.0)
|
225 |
+
ax.set_aspect("auto")
|
226 |
+
except:
|
227 |
+
pass
|
228 |
|
229 |
if return_plot_as_filepath:
|
230 |
+
try:
|
231 |
+
plt.savefig("tmp.png")
|
232 |
+
plt.close()
|
233 |
+
except:
|
234 |
+
pass
|
235 |
return wave, sr, "tmp.png"
|
236 |
|
237 |
return wave, sr
|