Spaces:
Runtime error
Runtime error
Update
Browse files
app.py
CHANGED
@@ -173,7 +173,9 @@ def visualize_attention(
|
|
173 |
|
174 |
# Convert plot to image
|
175 |
fig.canvas.draw()
|
176 |
-
|
|
|
|
|
177 |
visualizations.append(vis_image)
|
178 |
plt.close(fig)
|
179 |
|
@@ -200,11 +202,14 @@ def visualize_attention(
|
|
200 |
|
201 |
# Convert plot to image
|
202 |
fig.canvas.draw()
|
203 |
-
|
|
|
|
|
204 |
plt.close(fig)
|
205 |
|
206 |
return visualizations, rollout_image
|
207 |
|
|
|
208 |
# Create Gradio interface
|
209 |
iface = gr.Interface(
|
210 |
fn=visualize_attention,
|
|
|
173 |
|
174 |
# Convert plot to image
|
175 |
fig.canvas.draw()
|
176 |
+
data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
|
177 |
+
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
|
178 |
+
vis_image = Image.fromarray(data)
|
179 |
visualizations.append(vis_image)
|
180 |
plt.close(fig)
|
181 |
|
|
|
202 |
|
203 |
# Convert plot to image
|
204 |
fig.canvas.draw()
|
205 |
+
data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
|
206 |
+
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
|
207 |
+
rollout_image = Image.fromarray(data)
|
208 |
plt.close(fig)
|
209 |
|
210 |
return visualizations, rollout_image
|
211 |
|
212 |
+
|
213 |
# Create Gradio interface
|
214 |
iface = gr.Interface(
|
215 |
fn=visualize_attention,
|