huzey commited on
Commit
2dd390f
1 Parent(s): fa3878c

add tree+image click

Browse files
Files changed (1) hide show
  1. app.py +38 -21
app.py CHANGED
@@ -13,6 +13,8 @@ import uuid
13
  import zipfile
14
  import multiprocessing as mp
15
 
 
 
16
  from einops import rearrange
17
  from matplotlib import pyplot as plt
18
  import matplotlib
@@ -2472,26 +2474,32 @@ with demo:
2472
  # Plot the t-SNE points with image heatmaps
2473
  big_pil_image = plot_tsne_with_image_heatmaps(images, eigvecs, fps_eigvecs, tsne_embed, fps_tsne3d_rgb, max_display_dots)
2474
 
2475
- return tsne_embed, edges, fps_eigvecs, fps_tsne3d_rgb, fps_idx, pil_image, big_pil_image
2476
 
2477
- big_tsne_plot = gr.Image(label="spectral-tSNE tree [+ Cluster Heatmap]", elem_id="big_tsne_plot", interactive=False, format='png')
 
 
 
 
2478
 
2479
  run_hierarchical_button.click(
2480
  run_fps_tsne_hierarchical,
2481
  inputs=[input_gallery, eigvecs, num_sample_fps_slider, tsne_perplexity_slider, tsne3d_rgb, fps_hc_seed_slider],
2482
- outputs=[tsne_2d_points, edges, fps_eigvecs, fps_tsne_rgb, fps_indices, tsne_plot, big_tsne_plot],
2483
  )
2484
- gr.Markdown('---')
2485
- gr.Markdown('<h3 style="text-align: center;">↓ interactively inspect the hierarchical structure</h3>')
2486
- gr.Markdown('---')
2487
  with gr.Row():
2488
- from gradio_image_prompter import ImagePrompter
2489
  with gr.Column(scale=5, min_width=200) as tsne_select:
 
 
 
2490
  tsne_prompt_image = ImagePrompter(show_label=True, elem_id="tsne_prompt_image", interactive=False, label="spectral-tSNE tree")
2491
  # copy plot to tsne_prompt_image on change
2492
  # tsne_plot.change(fn=lambda x: gr.update(value={'image': x}, interactive=True),
2493
  # inputs=[tsne_plot], outputs=[tsne_prompt_image])
2494
  with gr.Column(scale=5, min_width=200) as image_select:
 
 
 
2495
  image_plot = ImagePrompter(show_label=True, elem_id="image_plot", interactive=False, label="NCUT spectral-tSNE")
2496
  image_slider = gr.Slider(0, 100, step=1, label="Image Index", value=0, elem_id="image_slider", interactive=True)
2497
  def update_image_prompt(image_slider, output_gallery):
@@ -2505,11 +2513,16 @@ with demo:
2505
  image_slider.change(fn=update_image_prompt, inputs=[image_slider, output_gallery], outputs=[image_plot])
2506
  output_gallery.change(fn=update_image_prompt, inputs=[image_slider, output_gallery], outputs=[image_plot])
2507
  output_gallery.change(fn=lambda x: gr.update(maximum=len(x)-1, interactive=True), inputs=[output_gallery], outputs=[image_slider])
 
 
 
 
 
2508
  with gr.Column(scale=5, min_width=200):
2509
  gr.Markdown('<h3 style="text-align: center;">Help</h3>')
2510
  with gr.Accordion("Instructions", open=True):
2511
  gr.Markdown("""
2512
- 1. Click one dot on the left-side image.
2513
  - Only the last clicked dot will be used
2514
  - Eraser is at top-right corner
2515
  - Use the right-side Radio to switch tree/image
@@ -2524,32 +2537,34 @@ with demo:
2524
  2. Cluster Heatmap: max of N cosine similarity to N points in the connected components.
2525
  """)
2526
  with gr.Column(scale=5, min_width=200):
2527
- prompt_radio = gr.Radio(["Tree", "Image"], label="Where to click on?", value="Tree", elem_id="prompt_radio", show_label=True)
2528
  granularity_slider = gr.Slider(1, 1000, step=1, label="Cluster Granularity (k)", value=100, elem_id="granularity")
2529
  num_sample_fps_slider.change(fn=lambda x: gr.update(maximum=x, interactive=True), inputs=[num_sample_fps_slider], outputs=[granularity_slider])
2530
  def updaste_tsne_plot_change_granularity(granularity, tsne_embed, edges, fps_tsne_rgb, tsne_prompt_image):
2531
  # Plot the t-SNE points
2532
  pil_image = plot_tsne_tree(tsne_embed, edges, fps_tsne_rgb, granularity)
2533
- if tsne_prompt_image is None:
2534
- return gr.update(value={'image': pil_image}, interactive=True)
2535
- return gr.update(value={'image': pil_image, 'points': tsne_prompt_image['points']}, interactive=True)
2536
  granularity_slider.change(updaste_tsne_plot_change_granularity,
2537
  inputs=[granularity_slider, tsne_2d_points, edges, fps_tsne_rgb, tsne_prompt_image],
2538
- outputs=[tsne_prompt_image])
2539
  tsne_plot.change(updaste_tsne_plot_change_granularity,
2540
  inputs=[granularity_slider, tsne_2d_points, edges, fps_tsne_rgb],
2541
- outputs=[tsne_prompt_image])
2542
  prompt_radio.change(update_image_prompt, inputs=[image_slider, output_gallery], outputs=[image_plot])
 
 
 
2543
  run_inspection_button = gr.Button("🔴 RUN Inspection", elem_id="run_inspection", variant='primary')
2544
  inspect_logging_text = gr.Textbox("Logging information", lines=3, label="Logging", elem_id="inspect_logging", type="text", placeholder="Logging information", autofocus=False, autoscroll=False)
2545
  # output_slot_radio = gr.Radio([1, 2, 3], label="Output Row", value=1, elem_id="output_slot", show_label=True)
2546
 
2547
  delete_all_output_button = gr.Button("❌ Delete All Output", elem_id="delete_all_output", variant='secondary')
2548
-
 
 
2549
  image_select.visible = False
2550
- tsne_select.visible = True
2551
- prompt_radio.change(fn=lambda x: gr.update(visible=x=="Tree"), inputs=prompt_radio, outputs=[tsne_select])
2552
  prompt_radio.change(fn=lambda x: gr.update(visible=x=="Image"), inputs=prompt_radio, outputs=[image_select])
 
2553
 
2554
  MAX_ROWS = 20
2555
  current_output_row = gr.State(0)
@@ -2634,19 +2649,21 @@ with demo:
2634
  closest_idx = np.argmax(sim)
2635
  return closest_idx, (_x_ratio, _y_ratio)
2636
 
2637
- def find_closest_fps_point(prompt_radio, tsne_prompt, image_prompt, i_image, tsne2d_embed, eigvecs, fps_eigvecs):
2638
  try:
2639
  if prompt_radio == "Tree":
2640
  return find_closest_fps_point_for_tsne_tree_plot(tsne_prompt, tsne2d_embed)
2641
  if prompt_radio == "Image":
2642
  return find_closest_fps_point_for_image_prompt(image_prompt, i_image, eigvecs, fps_eigvecs)
 
 
2643
  except:
2644
  raise gr.Error("""No blue point is selected. <br/>Please left-click on the image to select a blue point. <br/>After reloading the image (e.g., change granularity), please use the eraser to remove the previous point, then click on the image to select a blue point.""")
2645
 
2646
- def run_inspection(tsne_prompt, image_prompt, prompt_radio, current_output_row, tsne2d_embed, edges, fps_eigvecs, fps_tsne_rgb, fps_indices, granularity, eigvecs, i_image, tsne3d_rgb, input_gallery, output_row_occupy, max_rows=MAX_ROWS):
2647
  if len(tsne2d_embed) == 0:
2648
  raise gr.Error("Please run FPS+Cluster first.")
2649
- closest_idx, (_x, _y) = find_closest_fps_point(prompt_radio, tsne_prompt, image_prompt, i_image, tsne2d_embed, eigvecs, fps_eigvecs)
2650
  closest_rgb = fps_tsne_rgb[closest_idx]
2651
  closest_rgb = (closest_rgb * 255).astype(np.uint8)
2652
 
@@ -2728,7 +2745,7 @@ with demo:
2728
 
2729
  run_inspection_button.click(
2730
  run_inspection,
2731
- inputs=[tsne_prompt_image, image_plot, prompt_radio, current_output_row, tsne_2d_points, edges, fps_eigvecs, fps_tsne_rgb, fps_indices, granularity_slider, eigvecs, image_slider, tsne3d_rgb, input_gallery, output_row_occupy],
2732
  outputs=inspect_output_rows + output_tree_images + heatmap_galleries + text_blocks + [current_output_row, output_row_occupy, inspect_logging_text],
2733
  )
2734
 
 
13
  import zipfile
14
  import multiprocessing as mp
15
 
16
+ from gradio_image_prompter import ImagePrompter
17
+
18
  from einops import rearrange
19
  from matplotlib import pyplot as plt
20
  import matplotlib
 
2474
  # Plot the t-SNE points with image heatmaps
2475
  big_pil_image = plot_tsne_with_image_heatmaps(images, eigvecs, fps_eigvecs, tsne_embed, fps_tsne3d_rgb, max_display_dots)
2476
 
2477
+ return tsne_embed, edges, fps_eigvecs, fps_tsne3d_rgb, fps_idx, pil_image, gr.update(value={'image': big_pil_image, 'points': []}, interactive=True)
2478
 
2479
+ gr.Markdown('---')
2480
+ gr.Markdown('<h3 style="text-align: center;">↓ interactively inspect the hierarchical structure</h3>')
2481
+ gr.Markdown('---')
2482
+ # big_tsne_plot = gr.Image(label="spectral-tSNE tree [+ Cluster Heatmap]", elem_id="big_tsne_plot", interactive=False, format='png')
2483
+ tsne_image_plot = ImagePrompter(show_label=True, elem_id="tsne_image_plot", interactive=False, label="spectral-tSNE tree [+ Cluster Heatmap]")
2484
 
2485
  run_hierarchical_button.click(
2486
  run_fps_tsne_hierarchical,
2487
  inputs=[input_gallery, eigvecs, num_sample_fps_slider, tsne_perplexity_slider, tsne3d_rgb, fps_hc_seed_slider],
2488
+ outputs=[tsne_2d_points, edges, fps_eigvecs, fps_tsne_rgb, fps_indices, tsne_plot, tsne_image_plot],
2489
  )
 
 
 
2490
  with gr.Row():
 
2491
  with gr.Column(scale=5, min_width=200) as tsne_select:
2492
+ gr.Markdown('---')
2493
+ gr.Markdown('<h3 style="text-align: center;">Please click on the image blow ↓</h3>')
2494
+ gr.Markdown('---')
2495
  tsne_prompt_image = ImagePrompter(show_label=True, elem_id="tsne_prompt_image", interactive=False, label="spectral-tSNE tree")
2496
  # copy plot to tsne_prompt_image on change
2497
  # tsne_plot.change(fn=lambda x: gr.update(value={'image': x}, interactive=True),
2498
  # inputs=[tsne_plot], outputs=[tsne_prompt_image])
2499
  with gr.Column(scale=5, min_width=200) as image_select:
2500
+ gr.Markdown('---')
2501
+ gr.Markdown('<h3 style="text-align: center;">Please click on the image blow ↓</h3>')
2502
+ gr.Markdown('---')
2503
  image_plot = ImagePrompter(show_label=True, elem_id="image_plot", interactive=False, label="NCUT spectral-tSNE")
2504
  image_slider = gr.Slider(0, 100, step=1, label="Image Index", value=0, elem_id="image_slider", interactive=True)
2505
  def update_image_prompt(image_slider, output_gallery):
 
2513
  image_slider.change(fn=update_image_prompt, inputs=[image_slider, output_gallery], outputs=[image_plot])
2514
  output_gallery.change(fn=update_image_prompt, inputs=[image_slider, output_gallery], outputs=[image_plot])
2515
  output_gallery.change(fn=lambda x: gr.update(maximum=len(x)-1, interactive=True), inputs=[output_gallery], outputs=[image_slider])
2516
+ with gr.Column(scale=5, min_width=200) as tsne_image_select:
2517
+ gr.Markdown('---')
2518
+ gr.Markdown('<h3 style="text-align: center;">Please click on the image above ↑</h3>')
2519
+ gr.Markdown('---')
2520
+ tsne_non_prompt_image = gr.Image(label="spectral-tSNE tree", elem_id="tsne_non_prompt_image", interactive=False, format='png')
2521
  with gr.Column(scale=5, min_width=200):
2522
  gr.Markdown('<h3 style="text-align: center;">Help</h3>')
2523
  with gr.Accordion("Instructions", open=True):
2524
  gr.Markdown("""
2525
+ 1. Click one dot on the image.
2526
  - Only the last clicked dot will be used
2527
  - Eraser is at top-right corner
2528
  - Use the right-side Radio to switch tree/image
 
2537
  2. Cluster Heatmap: max of N cosine similarity to N points in the connected components.
2538
  """)
2539
  with gr.Column(scale=5, min_width=200):
2540
+ prompt_radio = gr.Radio(["Tree [+Image]", "Image"], label="Where to click on?", value="Tree [+Image]", elem_id="prompt_radio", show_label=True)
2541
  granularity_slider = gr.Slider(1, 1000, step=1, label="Cluster Granularity (k)", value=100, elem_id="granularity")
2542
  num_sample_fps_slider.change(fn=lambda x: gr.update(maximum=x, interactive=True), inputs=[num_sample_fps_slider], outputs=[granularity_slider])
2543
  def updaste_tsne_plot_change_granularity(granularity, tsne_embed, edges, fps_tsne_rgb, tsne_prompt_image):
2544
  # Plot the t-SNE points
2545
  pil_image = plot_tsne_tree(tsne_embed, edges, fps_tsne_rgb, granularity)
2546
+ return gr.update(value=pil_image, label=f"spectral-tSNE tree [k={granularity}]")
 
 
2547
  granularity_slider.change(updaste_tsne_plot_change_granularity,
2548
  inputs=[granularity_slider, tsne_2d_points, edges, fps_tsne_rgb, tsne_prompt_image],
2549
+ outputs=[tsne_non_prompt_image])
2550
  tsne_plot.change(updaste_tsne_plot_change_granularity,
2551
  inputs=[granularity_slider, tsne_2d_points, edges, fps_tsne_rgb],
2552
+ outputs=[tsne_non_prompt_image])
2553
  prompt_radio.change(update_image_prompt, inputs=[image_slider, output_gallery], outputs=[image_plot])
2554
+ # prompt_radio.change(updaste_tsne_plot_change_granularity,
2555
+ # inputs=[granularity_slider, tsne_2d_points, edges, fps_tsne_rgb, tsne_prompt_image],
2556
+ # outputs=[tsne_non_prompt_image])
2557
  run_inspection_button = gr.Button("🔴 RUN Inspection", elem_id="run_inspection", variant='primary')
2558
  inspect_logging_text = gr.Textbox("Logging information", lines=3, label="Logging", elem_id="inspect_logging", type="text", placeholder="Logging information", autofocus=False, autoscroll=False)
2559
  # output_slot_radio = gr.Radio([1, 2, 3], label="Output Row", value=1, elem_id="output_slot", show_label=True)
2560
 
2561
  delete_all_output_button = gr.Button("❌ Delete All Output", elem_id="delete_all_output", variant='secondary')
2562
+
2563
+ tsne_image_select.visible = True
2564
+ tsne_select.visible = False
2565
  image_select.visible = False
 
 
2566
  prompt_radio.change(fn=lambda x: gr.update(visible=x=="Image"), inputs=prompt_radio, outputs=[image_select])
2567
+ prompt_radio.change(fn=lambda x: gr.update(visible=x=="Tree [+Image]"), inputs=prompt_radio, outputs=[tsne_image_select])
2568
 
2569
  MAX_ROWS = 20
2570
  current_output_row = gr.State(0)
 
2649
  closest_idx = np.argmax(sim)
2650
  return closest_idx, (_x_ratio, _y_ratio)
2651
 
2652
+ def find_closest_fps_point(prompt_radio, tsne_image_prompt, tsne_prompt, image_prompt, i_image, tsne2d_embed, eigvecs, fps_eigvecs):
2653
  try:
2654
  if prompt_radio == "Tree":
2655
  return find_closest_fps_point_for_tsne_tree_plot(tsne_prompt, tsne2d_embed)
2656
  if prompt_radio == "Image":
2657
  return find_closest_fps_point_for_image_prompt(image_prompt, i_image, eigvecs, fps_eigvecs)
2658
+ if prompt_radio == "Tree [+Image]":
2659
+ return find_closest_fps_point_for_tsne_tree_plot(tsne_image_prompt, tsne2d_embed)
2660
  except:
2661
  raise gr.Error("""No blue point is selected. <br/>Please left-click on the image to select a blue point. <br/>After reloading the image (e.g., change granularity), please use the eraser to remove the previous point, then click on the image to select a blue point.""")
2662
 
2663
+ def run_inspection(tsne_image_prompt, tsne_prompt, image_prompt, prompt_radio, current_output_row, tsne2d_embed, edges, fps_eigvecs, fps_tsne_rgb, fps_indices, granularity, eigvecs, i_image, tsne3d_rgb, input_gallery, output_row_occupy, max_rows=MAX_ROWS):
2664
  if len(tsne2d_embed) == 0:
2665
  raise gr.Error("Please run FPS+Cluster first.")
2666
+ closest_idx, (_x, _y) = find_closest_fps_point(prompt_radio, tsne_image_prompt, tsne_prompt, image_prompt, i_image, tsne2d_embed, eigvecs, fps_eigvecs)
2667
  closest_rgb = fps_tsne_rgb[closest_idx]
2668
  closest_rgb = (closest_rgb * 255).astype(np.uint8)
2669
 
 
2745
 
2746
  run_inspection_button.click(
2747
  run_inspection,
2748
+ inputs=[tsne_image_plot, tsne_prompt_image, image_plot, prompt_radio, current_output_row, tsne_2d_points, edges, fps_eigvecs, fps_tsne_rgb, fps_indices, granularity_slider, eigvecs, image_slider, tsne3d_rgb, input_gallery, output_row_occupy],
2749
  outputs=inspect_output_rows + output_tree_images + heatmap_galleries + text_blocks + [current_output_row, output_row_occupy, inspect_logging_text],
2750
  )
2751