taesiri commited on
Commit
ac69117
·
1 Parent(s): d97d4d2
Files changed (1) hide show
  1. app.py +7 -2
app.py CHANGED
@@ -173,7 +173,9 @@ def visualize_attention(
173
 
174
  # Convert plot to image
175
  fig.canvas.draw()
176
- vis_image = Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.buffer_rgba())
 
 
177
  visualizations.append(vis_image)
178
  plt.close(fig)
179
 
@@ -200,11 +202,14 @@ def visualize_attention(
200
 
201
  # Convert plot to image
202
  fig.canvas.draw()
203
- rollout_image = Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.buffer_rgba())
 
 
204
  plt.close(fig)
205
 
206
  return visualizations, rollout_image
207
 
 
208
  # Create Gradio interface
209
  iface = gr.Interface(
210
  fn=visualize_attention,
 
173
 
174
  # Convert plot to image
175
  fig.canvas.draw()
176
+ data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
177
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
178
+ vis_image = Image.fromarray(data)
179
  visualizations.append(vis_image)
180
  plt.close(fig)
181
 
 
202
 
203
  # Convert plot to image
204
  fig.canvas.draw()
205
+ data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
206
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
207
+ rollout_image = Image.fromarray(data)
208
  plt.close(fig)
209
 
210
  return visualizations, rollout_image
211
 
212
+
213
  # Create Gradio interface
214
  iface = gr.Interface(
215
  fn=visualize_attention,