Flux9665 commited on
Commit
4b61798
β€’
1 Parent(s): 99b69ce

fix some label errors during plotting and close plots

Browse files
InferenceInterfaces/ToucanTTSInterface.py CHANGED
@@ -177,49 +177,61 @@ class ToucanTTSInterface(torch.nn.Module):
177
  pass
178
 
179
  if view or return_plot_as_filepath:
180
- fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(9, 5))
 
181
 
182
- ax.imshow(mel.cpu().numpy(), origin="lower", cmap='GnBu')
183
- ax.yaxis.set_visible(False)
184
- duration_splits, label_positions = cumsum_durations(durations.cpu().numpy())
185
- ax.xaxis.grid(True, which='minor')
186
- ax.set_xticks(label_positions, minor=False)
187
- if input_is_phones:
188
- phones = text.replace(" ", "|")
189
- else:
190
- phones = self.text2phone.get_phone_string(text, for_plot_labels=True)
191
- ax.set_xticklabels(phones)
192
- word_boundaries = list()
193
- for label_index, phone in enumerate(phones):
194
- if phone == "|":
195
- word_boundaries.append(label_positions[label_index])
 
 
 
196
 
197
- try:
198
- prev_word_boundary = 0
199
- word_label_positions = list()
200
- for word_boundary in word_boundaries:
201
- word_label_positions.append((word_boundary + prev_word_boundary) / 2)
202
- prev_word_boundary = word_boundary
203
- word_label_positions.append((duration_splits[-1] + prev_word_boundary) / 2)
204
 
205
- secondary_ax = ax.secondary_xaxis('bottom')
206
- secondary_ax.tick_params(axis="x", direction="out", pad=24)
207
- secondary_ax.set_xticks(word_label_positions, minor=False)
208
- secondary_ax.set_xticklabels(text.split())
209
- secondary_ax.tick_params(axis='x', colors='orange')
210
- secondary_ax.xaxis.label.set_color('orange')
211
- except ValueError:
212
- ax.set_title(text)
213
- except IndexError:
214
- ax.set_title(text)
 
 
215
 
216
- ax.vlines(x=duration_splits, colors="green", linestyles="solid", ymin=0, ymax=120, linewidth=0.5)
217
- ax.vlines(x=word_boundaries, colors="orange", linestyles="solid", ymin=0, ymax=120, linewidth=1.0)
218
- plt.subplots_adjust(left=0.02, bottom=0.2, right=0.98, top=.9, wspace=0.0, hspace=0.0)
219
- ax.set_aspect("auto")
 
 
220
 
221
  if return_plot_as_filepath:
222
- plt.savefig("tmp.png")
 
 
 
 
223
  return wave, sr, "tmp.png"
224
 
225
  return wave, sr
 
177
  pass
178
 
179
  if view or return_plot_as_filepath:
180
+ try:
181
+ fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(9, 5))
182
 
183
+ ax.imshow(mel.cpu().numpy(), origin="lower", cmap='GnBu')
184
+ ax.yaxis.set_visible(False)
185
+ duration_splits, label_positions = cumsum_durations(durations.cpu().numpy())
186
+ ax.xaxis.grid(True, which='minor')
187
+ ax.set_xticks(label_positions, minor=False)
188
+ if input_is_phones:
189
+ phones = text.replace(" ", "|")
190
+ else:
191
+ phones = self.text2phone.get_phone_string(text, for_plot_labels=True)
192
+ try:
193
+ ax.set_xticklabels(phones)
194
+ except IndexError:
195
+ pass
196
+ word_boundaries = list()
197
+ for label_index, phone in enumerate(phones):
198
+ if phone == "|":
199
+ word_boundaries.append(label_positions[label_index])
200
 
201
+ try:
202
+ prev_word_boundary = 0
203
+ word_label_positions = list()
204
+ for word_boundary in word_boundaries:
205
+ word_label_positions.append((word_boundary + prev_word_boundary) / 2)
206
+ prev_word_boundary = word_boundary
207
+ word_label_positions.append((duration_splits[-1] + prev_word_boundary) / 2)
208
 
209
+ secondary_ax = ax.secondary_xaxis('bottom')
210
+ secondary_ax.tick_params(axis="x", direction="out", pad=24)
211
+ secondary_ax.set_xticks(word_label_positions, minor=False)
212
+ secondary_ax.set_xticklabels(text.split())
213
+ secondary_ax.tick_params(axis='x', colors='orange')
214
+ secondary_ax.xaxis.label.set_color('orange')
215
+ except ValueError:
216
+ ax.set_title(text)
217
+ except IndexError:
218
+ ax.set_title(text)
219
+ except RuntimeError:
220
+ ax.set_title(text)
221
 
222
+ ax.vlines(x=duration_splits, colors="green", linestyles="solid", ymin=0, ymax=120, linewidth=0.5)
223
+ ax.vlines(x=word_boundaries, colors="orange", linestyles="solid", ymin=0, ymax=120, linewidth=1.0)
224
+ plt.subplots_adjust(left=0.02, bottom=0.2, right=0.98, top=.9, wspace=0.0, hspace=0.0)
225
+ ax.set_aspect("auto")
226
+ except:
227
+ pass
228
 
229
  if return_plot_as_filepath:
230
+ try:
231
+ plt.savefig("tmp.png")
232
+ plt.close()
233
+ except:
234
+ pass
235
  return wave, sr, "tmp.png"
236
 
237
  return wave, sr