huzey commited on
Commit
5a2bda7
1 Parent(s): d92e52a

fix msg, faster eig norm

Browse files
Files changed (2) hide show
  1. app.py +32 -9
  2. fps_cluster.py +2 -1
app.py CHANGED
@@ -2232,6 +2232,12 @@ with demo:
2232
  def __run_fn(*args, **kwargs):
2233
  eigvecs, rgb, logging_str = run_fn(*args, **kwargs)
2234
  rgb_gallery = to_pil_images(rgb)
 
 
 
 
 
 
2235
  return eigvecs, rgb, rgb_gallery, logging_str
2236
 
2237
  submit_button.click(
@@ -2408,8 +2414,8 @@ with demo:
2408
  gr.Markdown("Known Issue: Resize the browser window will break the clicking, please refresh the page.")
2409
  with gr.Accordion("Outputs", open=True):
2410
  gr.Markdown("""
2411
- 1. spectral-tSNE tree: ◆ means connected components to the selected point.
2412
- 2. Cluster Heatmap: max cosine similarity to any points in the connected components.
2413
  """)
2414
  with gr.Column(scale=5, min_width=200):
2415
  prompt_radio = gr.Radio(["Tree", "Image"], label="Where to click on?", value="Tree", elem_id="prompt_radio", show_label=True)
@@ -2427,6 +2433,7 @@ with demo:
2427
  tsne_plot.change(updaste_tsne_plot_change_granularity,
2428
  inputs=[granularity_slider, tsne_2d_points, edges, fps_tsne_rgb],
2429
  outputs=[tsne_prompt_image])
 
2430
  run_inspection_button = gr.Button("🔴 RUN Inspection", elem_id="run_inspection", variant='primary')
2431
  inspect_logging_text = gr.Textbox("Logging information", lines=3, label="Logging", elem_id="inspect_logging", type="text", placeholder="Logging information", autofocus=False, autoscroll=False)
2432
  # output_slot_radio = gr.Radio([1, 2, 3], label="Output Row", value=1, elem_id="output_slot", show_label=True)
@@ -2500,8 +2507,8 @@ with demo:
2500
  x = int(x * w)
2501
  y = int(y * h)
2502
  eigvec = _eigvec[y, x]
2503
- dist = np.linalg.norm(fps_eigvecs - eigvec, axis=1)
2504
- closest_idx = np.argmin(dist)
2505
  return closest_idx
2506
 
2507
  def find_closest_fps_point(prompt_radio, tsne_prompt, image_prompt, i_image, tsne2d_embed, eigvecs, fps_eigvecs):
@@ -2529,13 +2536,29 @@ with demo:
2529
  output_tsne_plot = plot_tsne_tree(tsne2d_embed, edges, fps_tsne_rgb, granularity, closest_idx, highlight_connections=True)
2530
 
2531
  # draw heatmap for the connected components
 
2532
  connected_eigvecs = fps_eigvecs[connected_idxs]
2533
- left = torch.tensor(eigvecs).float() # B H W 3
2534
- right = torch.tensor(connected_eigvecs).float()
2535
- left = F.normalize(left, p=2, dim=-1)
2536
- right = F.normalize(right, p=2, dim=-1)
 
2537
  similarity = left @ right.T
2538
- similarity = similarity.max(dim=-1).values # B H W
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2539
  hot_map = matplotlib.colormaps['hot']
2540
  heatmap = hot_map(similarity)[..., :3] # B H W 3
2541
  heatmap_images = to_pil_images(torch.tensor(heatmap), target_size=256, force_size=True)
 
2232
  def __run_fn(*args, **kwargs):
2233
  eigvecs, rgb, logging_str = run_fn(*args, **kwargs)
2234
  rgb_gallery = to_pil_images(rgb)
2235
+ # normalize the eigvecs
2236
+ eigvecs = torch.tensor(eigvecs)
2237
+ if torch.cuda.is_available():
2238
+ eigvecs = eigvecs.cuda()
2239
+ eigvecs = F.normalize(eigvecs, p=2, dim=-1)
2240
+ eigvecs = eigvecs.cpu().numpy()
2241
  return eigvecs, rgb, rgb_gallery, logging_str
2242
 
2243
  submit_button.click(
 
2414
  gr.Markdown("Known Issue: Resize the browser window will break the clicking, please refresh the page.")
2415
  with gr.Accordion("Outputs", open=True):
2416
  gr.Markdown("""
2417
+ 1. spectral-tSNE tree: ◆ marker is the N points, connected components to the clicked dot.
2418
+ 2. Cluster Heatmap: max of N cosine similarity to N points in the connected components.
2419
  """)
2420
  with gr.Column(scale=5, min_width=200):
2421
  prompt_radio = gr.Radio(["Tree", "Image"], label="Where to click on?", value="Tree", elem_id="prompt_radio", show_label=True)
 
2433
  tsne_plot.change(updaste_tsne_plot_change_granularity,
2434
  inputs=[granularity_slider, tsne_2d_points, edges, fps_tsne_rgb],
2435
  outputs=[tsne_prompt_image])
2436
+ prompt_radio.change(update_image_prompt, inputs=[image_slider, output_gallery], outputs=[image_plot])
2437
  run_inspection_button = gr.Button("🔴 RUN Inspection", elem_id="run_inspection", variant='primary')
2438
  inspect_logging_text = gr.Textbox("Logging information", lines=3, label="Logging", elem_id="inspect_logging", type="text", placeholder="Logging information", autofocus=False, autoscroll=False)
2439
  # output_slot_radio = gr.Radio([1, 2, 3], label="Output Row", value=1, elem_id="output_slot", show_label=True)
 
2507
  x = int(x * w)
2508
  y = int(y * h)
2509
  eigvec = _eigvec[y, x]
2510
+ sim = fps_eigvecs @ eigvec
2511
+ closest_idx = np.argmax(sim)
2512
  return closest_idx
2513
 
2514
  def find_closest_fps_point(prompt_radio, tsne_prompt, image_prompt, i_image, tsne2d_embed, eigvecs, fps_eigvecs):
 
2536
  output_tsne_plot = plot_tsne_tree(tsne2d_embed, edges, fps_tsne_rgb, granularity, closest_idx, highlight_connections=True)
2537
 
2538
  # draw heatmap for the connected components
2539
+ ## cosine distance
2540
  connected_eigvecs = fps_eigvecs[connected_idxs]
2541
+ left = eigvecs.astype(np.float32) # B H W C
2542
+ right = connected_eigvecs.astype(np.float32) # N C
2543
+ # left = F.normalize(left, p=2, dim=-1)
2544
+ # right = F.normalize(right, p=2, dim=-1)
2545
+ # eigvec is already normalized when saved to gr.State
2546
  similarity = left @ right.T
2547
+ similarity = similarity.max(axis=-1) # B H W N
2548
+ ## euclidean distance
2549
+ # b, h, w = tsne3d_rgb.shape[:3]
2550
+ # tsne3d_rgb = tsne3d_rgb.reshape(b*h*w, 3)
2551
+ # connected_rgb = tsne3d_rgb[fps_indices][connected_idxs]
2552
+ # left = torch.tensor(tsne3d_rgb).float() # (B H W) 3
2553
+ # right = torch.tensor(connected_rgb).float() # N 3
2554
+ # # dist B H W N
2555
+ # dist = left[:, None] - right[None]
2556
+ # dist = torch.sqrt((dist ** 2).sum(dim=-1))
2557
+ # dist = dist.min(dim=-1).values # B H W
2558
+ # dist = dist.reshape(b, h, w)
2559
+ # gr.Info(f"dist: min={dist.min().item()}, max={dist.max().item()}, mean={dist.mean().item()}", 3)
2560
+ # similarity = 1 - dist
2561
+
2562
  hot_map = matplotlib.colormaps['hot']
2563
  heatmap = hot_map(similarity)[..., :3] # B H W 3
2564
  heatmap_images = to_pil_images(torch.tensor(heatmap), target_size=256, force_size=True)
fps_cluster.py CHANGED
@@ -5,7 +5,8 @@ import torch
5
 
6
  def build_tree(all_dots):
7
  num_sample = all_dots.shape[0]
8
- center = all_dots.mean(axis=0)
 
9
  distances_to_center = np.linalg.norm(all_dots - center, axis=1)
10
  start_idx = np.argmin(distances_to_center)
11
  indices = [start_idx]
 
5
 
6
  def build_tree(all_dots):
7
  num_sample = all_dots.shape[0]
8
+ # center = all_dots.mean(axis=0)
9
+ center = np.median(all_dots, axis=0)
10
  distances_to_center = np.linalg.norm(all_dots - center, axis=1)
11
  start_idx = np.argmin(distances_to_center)
12
  indices = [start_idx]