huzey commited on
Commit
38c0aa5
1 Parent(s): 2dd390f

add tree option

Browse files
Files changed (1) hide show
  1. app.py +18 -6
app.py CHANGED
@@ -2297,6 +2297,7 @@ with demo:
2297
  num_sample_fps_slider = gr.Slider(1, 5000, step=1, label="FPS: num_sample", value=1000, elem_id="num_sample_fps")
2298
  tsne_perplexity_slider = gr.Slider(1, 1000, step=1, label="t-SNE: perplexity", value=500, elem_id="perplexity_tsne")
2299
  fps_hc_seed_slider = gr.Slider(0, 1000, step=1, label="Seed", value=0, elem_id="fps_hc_seed")
 
2300
  tsne_plot = gr.Image(label="spectral-tSNE tree", elem_id="tsne_plot", interactive=False, format='png')
2301
 
2302
  tsne_2d_points = gr.State(np.array([]))
@@ -2309,10 +2310,20 @@ with demo:
2309
  # Plot the t-SNE points
2310
  fig, ax = plt.subplots(1, 1, figsize=(6, 6))
2311
  ax.scatter(tsne_embed[:, 0], tsne_embed[:, 1], s=20, c=fps_tsne3d_rgb)
 
 
 
 
2312
  # draw the edges
2313
  for i_edge in range(k, len(edges)):
2314
  edge = edges[i_edge]
2315
- ax.plot(tsne_embed[edge, 0], tsne_embed[edge, 1], 'k-', lw=1, alpha=0.7)
 
 
 
 
 
 
2316
  # highlight the selected node
2317
  if hightlight_idx is not None:
2318
  if highlight_connections:
@@ -2428,7 +2439,7 @@ with demo:
2428
  return pil_image
2429
 
2430
 
2431
- def run_fps_tsne_hierarchical(image_gallery, eigvecs, num_sample_fps, perplexity_tsne, tsne3d_rgb, seed=0, max_display_dots=300):
2432
  if len(eigvecs) == 0:
2433
  gr.Warning("Please run NCUT first.")
2434
  return
@@ -2464,9 +2475,10 @@ with demo:
2464
  tsne_embed[:, 0] = (tsne_embed[:, 0] - tsne_embed[:, 0].min()) / (tsne_embed[:, 0].max() - tsne_embed[:, 0].min()) * 2 - 1
2465
  tsne_embed[:, 1] = (tsne_embed[:, 1] - tsne_embed[:, 1].min()) / (tsne_embed[:, 1].max() - tsne_embed[:, 1].min()) * 2 - 1
2466
 
2467
- edges = build_tree(tsne_embed)
2468
- # edges = build_tree(fps_eigvecs, dist='cosine')
2469
- # edges = build_tree(fps_tsne3d_rgb)
 
2470
 
2471
  # Plot the t-SNE points
2472
  pil_image = plot_tsne_tree(tsne_embed, edges, fps_tsne3d_rgb, 0)
@@ -2484,7 +2496,7 @@ with demo:
2484
 
2485
  run_hierarchical_button.click(
2486
  run_fps_tsne_hierarchical,
2487
- inputs=[input_gallery, eigvecs, num_sample_fps_slider, tsne_perplexity_slider, tsne3d_rgb, fps_hc_seed_slider],
2488
  outputs=[tsne_2d_points, edges, fps_eigvecs, fps_tsne_rgb, fps_indices, tsne_plot, tsne_image_plot],
2489
  )
2490
  with gr.Row():
 
2297
  num_sample_fps_slider = gr.Slider(1, 5000, step=1, label="FPS: num_sample", value=1000, elem_id="num_sample_fps")
2298
  tsne_perplexity_slider = gr.Slider(1, 1000, step=1, label="t-SNE: perplexity", value=500, elem_id="perplexity_tsne")
2299
  fps_hc_seed_slider = gr.Slider(0, 1000, step=1, label="Seed", value=0, elem_id="fps_hc_seed")
2300
+ tree_method_radio = gr.Radio(["eigvecs", "tsne"], label="Tree Method (input type)", value="eigvecs", elem_id="tree_method")
2301
  tsne_plot = gr.Image(label="spectral-tSNE tree", elem_id="tsne_plot", interactive=False, format='png')
2302
 
2303
  tsne_2d_points = gr.State(np.array([]))
 
2310
  # Plot the t-SNE points
2311
  fig, ax = plt.subplots(1, 1, figsize=(6, 6))
2312
  ax.scatter(tsne_embed[:, 0], tsne_embed[:, 1], s=20, c=fps_tsne3d_rgb)
2313
+ # compute the length of the edges
2314
+ lengthes = np.linalg.norm(tsne_embed[edges[:, 0]] - tsne_embed[edges[:, 1]], axis=1)
2315
+ max_length = lengthes[k:].max()
2316
+ diag_length = np.linalg.norm(tsne_embed.max(axis=0) - tsne_embed.min(axis=0))
2317
  # draw the edges
2318
  for i_edge in range(k, len(edges)):
2319
  edge = edges[i_edge]
2320
+ # _do = np.clip(lengthes[i_edge] / (diag_length*0.3), 0, 1)
2321
+ if lengthes[i_edge] > diag_length*0.1:
2322
+ _do = 1.0
2323
+ else:
2324
+ _do = 0.0
2325
+ alpha = 0.7 * (1 - _do) + 0.0
2326
+ ax.plot(tsne_embed[edge, 0], tsne_embed[edge, 1], 'k-', lw=1, alpha=alpha)
2327
  # highlight the selected node
2328
  if hightlight_idx is not None:
2329
  if highlight_connections:
 
2439
  return pil_image
2440
 
2441
 
2442
+ def run_fps_tsne_hierarchical(image_gallery, eigvecs, num_sample_fps, perplexity_tsne, tsne3d_rgb, seed=0, tree_method='eigvecs', max_display_dots=300):
2443
  if len(eigvecs) == 0:
2444
  gr.Warning("Please run NCUT first.")
2445
  return
 
2475
  tsne_embed[:, 0] = (tsne_embed[:, 0] - tsne_embed[:, 0].min()) / (tsne_embed[:, 0].max() - tsne_embed[:, 0].min()) * 2 - 1
2476
  tsne_embed[:, 1] = (tsne_embed[:, 1] - tsne_embed[:, 1].min()) / (tsne_embed[:, 1].max() - tsne_embed[:, 1].min()) * 2 - 1
2477
 
2478
+ if tree_method == 'eigvecs':
2479
+ edges = build_tree(fps_eigvecs, dist='cosine')
2480
+ if tree_method == 'tsne':
2481
+ edges = build_tree(tsne_embed, dist='euclidean')
2482
 
2483
  # Plot the t-SNE points
2484
  pil_image = plot_tsne_tree(tsne_embed, edges, fps_tsne3d_rgb, 0)
 
2496
 
2497
  run_hierarchical_button.click(
2498
  run_fps_tsne_hierarchical,
2499
+ inputs=[input_gallery, eigvecs, num_sample_fps_slider, tsne_perplexity_slider, tsne3d_rgb, fps_hc_seed_slider, tree_method_radio],
2500
  outputs=[tsne_2d_points, edges, fps_eigvecs, fps_tsne_rgb, fps_indices, tsne_plot, tsne_image_plot],
2501
  )
2502
  with gr.Row():