huzey commited on
Commit
e9899b2
1 Parent(s): a75bd09
Files changed (1) hide show
  1. app.py +15 -10
app.py CHANGED
@@ -4,6 +4,7 @@ import copy
4
  from functools import partial
5
  from io import BytesIO
6
  import os
 
7
 
8
  from einops import rearrange
9
  from matplotlib import pyplot as plt
@@ -374,14 +375,15 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=F
374
  i_cluster += 1
375
  plt.tight_layout(h_pad=0.5, w_pad=0.3)
376
 
377
- buf = BytesIO()
378
- plt.savefig(buf, bbox_inches='tight', dpi=72)
 
379
 
380
- buf.seek(0) # Move to the start of the BytesIO buffer
381
- img = Image.open(buf)
382
  img = img.convert("RGB")
383
  img = copy.deepcopy(img)
384
- buf.close()
 
385
 
386
  fig_images.append(img)
387
  plt.close()
@@ -838,14 +840,17 @@ def plot_one_image_36_grid(original_image, tsne_rgb_images):
838
  if i_layer == 0:
839
  ax.text(-0.1, 0.5, model_name, va="center", ha="center", fontsize=16, transform=ax.transAxes, rotation=90,)
840
  plt.tight_layout()
841
- buf = BytesIO()
842
- plt.savefig(buf, bbox_inches='tight', pad_inches=0, dpi=100)
843
 
844
- buf.seek(0) # Move to the start of the BytesIO buffer
845
- img = Image.open(buf)
 
 
 
846
  img = img.convert("RGB")
847
  img = copy.deepcopy(img)
848
- buf.close()
 
 
849
  plt.close()
850
  return img
851
 
 
4
  from functools import partial
5
  from io import BytesIO
6
  import os
7
+ import uuid
8
 
9
  from einops import rearrange
10
  from matplotlib import pyplot as plt
 
375
  i_cluster += 1
376
  plt.tight_layout(h_pad=0.5, w_pad=0.3)
377
 
378
+ filename = uuid.uuid4()
379
+ tmp_path = f"/tmp/{filename}.png"
380
+ plt.savefig(tmp_path, bbox_inches='tight', dpi=72)
381
 
382
+ img = Image.open(tmp_path)
 
383
  img = img.convert("RGB")
384
  img = copy.deepcopy(img)
385
+
386
+ os.remove(tmp_path)
387
 
388
  fig_images.append(img)
389
  plt.close()
 
840
  if i_layer == 0:
841
  ax.text(-0.1, 0.5, model_name, va="center", ha="center", fontsize=16, transform=ax.transAxes, rotation=90,)
842
  plt.tight_layout()
 
 
843
 
844
+ filename = uuid.uuid4()
845
+ filename = f"/tmp/{filename}.png"
846
+ plt.savefig(filename, bbox_inches='tight', pad_inches=0, dpi=100)
847
+
848
+ img = Image.open(filename)
849
  img = img.convert("RGB")
850
  img = copy.deepcopy(img)
851
+
852
+ os.remove(filename)
853
+
854
  plt.close()
855
  return img
856