huzey commited on
Commit
d90f66b
1 Parent(s): abc6adf

update cluster sort

Browse files
Files changed (1) hide show
  1. app.py +2 -0
app.py CHANGED
@@ -447,6 +447,8 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=F
447
  mask = (heatmap > top_p).float()
448
  # take top 3 masks only
449
  mask_sort_values = mask.mean((1, 2))
 
 
450
  mask_sort_idx = torch.argsort(mask_sort_values, descending=True)
451
  mask = mask[mask_sort_idx[:3]]
452
  sort_values.append(mask.mean().item())
 
447
  mask = (heatmap > top_p).float()
448
  # take top 3 masks only
449
  mask_sort_values = mask.mean((1, 2))
450
+ _sort_value2 = (heatmap > 0.1).float().mean((1, 2)) * 0.1
451
+ mask_sort_values += _sort_value2
452
  mask_sort_idx = torch.argsort(mask_sort_values, descending=True)
453
  mask = mask[mask_sort_idx[:3]]
454
  sort_values.append(mask.mean().item())