huzey commited on
Commit
5afcac2
1 Parent(s): 9913a19

add test playground

Browse files
Files changed (1) hide show
  1. app.py +323 -26
app.py CHANGED
@@ -2,6 +2,7 @@
2
  # %%
3
  import copy
4
  from datetime import datetime
 
5
  import pickle
6
  from functools import partial
7
  from io import BytesIO
@@ -137,6 +138,7 @@ def compute_ncut(
137
  indirect_connection=True,
138
  make_orthogonal=False,
139
  progess_start=0.4,
 
140
  ):
141
  progress = gr.Progress()
142
  logging_str = ""
@@ -165,6 +167,10 @@ def compute_ncut(
165
  # print(f"NCUT time: {time.time() - start:.2f}s")
166
  logging_str += f"NCUT time: {time.time() - start:.2f}s\n"
167
 
 
 
 
 
168
  start = time.time()
169
  progress(progess_start+0.01, desc="spectral-tSNE")
170
  _, rgb = eigenvector_to_rgb(
@@ -272,10 +278,12 @@ def dont_use_too_much_green(image_rgb):
272
  return image_rgb
273
 
274
 
275
- def to_pil_images(images, target_size=512, resize=True):
276
  size = images[0].shape[1]
277
  multiplier = target_size // size
278
  res = int(size * multiplier)
 
 
279
  pil_images = [
280
  Image.fromarray((image * 255).cpu().numpy().astype(np.uint8))
281
  for image in images
@@ -855,6 +863,8 @@ def ncut_run(
855
 
856
  # ailgnedcut
857
  if not directed:
 
 
858
  rgb, _logging_str, eigvecs = compute_ncut(
859
  features,
860
  num_eig=num_eig,
@@ -872,7 +882,12 @@ def ncut_run(
872
  indirect_connection=indirect_connection,
873
  make_orthogonal=make_orthogonal,
874
  metric=ncut_metric,
 
875
  )
 
 
 
 
876
  if directed:
877
  head_index_text = kwargs.get("head_index_text", None)
878
  n_heads = features.shape[-2] # (batch, h, w, n_heads, d)
@@ -978,26 +993,26 @@ def ncut_run(
978
 
979
  def _ncut_run(*args, **kwargs):
980
  n_ret = kwargs.get("n_ret", 1)
981
- try:
982
- if torch.cuda.is_available():
983
- torch.cuda.empty_cache()
984
 
985
- ret = ncut_run(*args, **kwargs)
986
 
987
- if torch.cuda.is_available():
988
- torch.cuda.empty_cache()
989
 
990
- ret = list(ret)[:n_ret] + [ret[-1]]
991
- return ret
992
- except Exception as e:
993
- gr.Error(str(e))
994
- if torch.cuda.is_available():
995
- torch.cuda.empty_cache()
996
- return *(None for _ in range(n_ret)), "Error: " + str(e)
997
 
998
- # ret = ncut_run(*args, **kwargs)
999
- # ret = list(ret)[:n_ret] + [ret[-1]]
1000
- # return ret
1001
 
1002
  if USE_HUGGINGFACE_ZEROGPU:
1003
  @spaces.GPU(duration=30)
@@ -1213,6 +1228,7 @@ def run_fn(
1213
  alignedcut_eig_norm_plot=False,
1214
  advanced=False,
1215
  directed=False,
 
1216
  ):
1217
  # print(node_type2, head_index_text, make_symmetric)
1218
  progress=gr.Progress()
@@ -1353,6 +1369,7 @@ def run_fn(
1353
  "node_type2": node_type2,
1354
  "head_index_text": head_index_text,
1355
  "make_symmetric": make_symmetric,
 
1356
  }
1357
  # print(kwargs)
1358
 
@@ -1664,7 +1681,7 @@ def load_and_append(existing_images, *args, **kwargs):
1664
  gr.Info(f"Total images: {len(existing_images)}")
1665
  return existing_images
1666
 
1667
- def make_input_images_section(rows=1, cols=3, height="auto", advanced=False, is_random=False, allow_download=False, markdown=True):
1668
  if markdown:
1669
  gr.Markdown('### Input Images')
1670
  input_gallery = gr.Gallery(value=None, label="Input images", show_label=True, elem_id="input_images", columns=[cols], rows=[rows], object_fit="contain", height=height, type="pil", show_share_button=False,
@@ -1702,7 +1719,7 @@ def make_input_images_section(rows=1, cols=3, height="auto", advanced=False, is_
1702
  with gr.Row():
1703
  button = gr.Button("Load\n"+name, elem_id=f"example-{name}", elem_classes="small-button", variant='secondary', size="sm", scale=1, min_width=60)
1704
  gallery = gr.Gallery(value=images, label=name, show_label=True, columns=[3], rows=[1], interactive=False, height=80, scale=8, object_fit="cover", min_width=140, allow_preview=False)
1705
- button.click(fn=lambda: gr.update(value=load_dataset_images(True, dataset_name, 100, is_random=True, seed=42)), outputs=[input_gallery])
1706
  return gallery, button
1707
  example_items = [
1708
  ("EgoExo", ['./images/egoexo1.jpg', './images/egoexo3.jpg', './images/egoexo2.jpg'], "EgoExo"),
@@ -2040,7 +2057,7 @@ def make_output_images_section(markdown=True, button=True):
2040
  add_rotate_flip_buttons(output_gallery)
2041
  return output_gallery
2042
 
2043
- def make_parameters_section(is_lisa=False, model_ratio=True):
2044
  gr.Markdown("### Parameters <a style='color: #0044CC;' href='https://ncut-pytorch.readthedocs.io/en/latest/how_to_get_better_segmentation/' target='_blank'>Help</a>")
2045
  from ncut_pytorch.backbone import list_models, get_demo_model_names
2046
  model_names = list_models()
@@ -2105,18 +2122,18 @@ def make_parameters_section(is_lisa=False, model_ratio=True):
2105
  gr.Textbox(label="Prompt (Negative)", elem_id="prompt", placeholder="e.g. 'a photo from egocentric view'", visible=False))
2106
  model_dropdown.change(fn=change_prompt_text, inputs=model_dropdown, outputs=[positive_prompt, negative_prompt])
2107
 
2108
- with gr.Accordion("Advanced Parameters: NCUT", open=False):
2109
  gr.Markdown("<a href='https://ncut-pytorch.readthedocs.io/en/latest/how_to_get_better_segmentation/' target='_blank'>Docs: How to Get Better Segmentation</a>")
2110
  affinity_focal_gamma_slider = gr.Slider(0.01, 1, step=0.01, label="NCUT: Affinity focal gamma", value=0.5, elem_id="affinity_focal_gamma", info="decrease for shaper segmentation")
2111
  num_sample_ncut_slider = gr.Slider(100, 50000, step=100, label="NCUT: num_sample", value=10000, elem_id="num_sample_ncut", info="Nyström approximation")
2112
  # sampling_method_dropdown = gr.Dropdown(["QuickFPS", "random"], label="NCUT: Sampling method", value="QuickFPS", elem_id="sampling_method", info="Nyström approximation")
2113
  sampling_method_dropdown = gr.Radio(["QuickFPS", "random"], label="NCUT: Sampling method", value="QuickFPS", elem_id="sampling_method")
2114
  # ncut_metric_dropdown = gr.Dropdown(["euclidean", "cosine"], label="NCUT: Distance metric", value="cosine", elem_id="ncut_metric")
2115
- ncut_metric_dropdown = gr.Radio(["euclidean", "cosine"], label="NCUT: Distance metric", value="cosine", elem_id="ncut_metric")
2116
  ncut_knn_slider = gr.Slider(1, 100, step=1, label="NCUT: KNN", value=10, elem_id="knn_ncut", info="Nyström approximation")
2117
  ncut_indirect_connection = gr.Checkbox(label="indirect_connection", value=True, elem_id="ncut_indirect_connection", info="Add indirect connection to the sub-sampled graph")
2118
  ncut_make_orthogonal = gr.Checkbox(label="make_orthogonal", value=False, elem_id="ncut_make_orthogonal", info="Apply post-hoc eigenvectors orthogonalization")
2119
- with gr.Accordion("Advanced Parameters: Visualization", open=False):
2120
  # embedding_method_dropdown = gr.Dropdown(["tsne_3d", "umap_3d", "umap_sphere", "tsne_2d", "umap_2d"], label="Coloring method", value="tsne_3d", elem_id="embedding_method")
2121
  embedding_method_dropdown = gr.Radio(["tsne_3d", "umap_3d", "umap_sphere", "tsne_2d", "umap_2d"], label="Coloring method", value="tsne_3d", elem_id="embedding_method")
2122
  # embedding_metric_dropdown = gr.Dropdown(["euclidean", "cosine"], label="t-SNE/UMAP metric", value="euclidean", elem_id="embedding_metric")
@@ -2147,8 +2164,9 @@ demo = gr.Blocks(
2147
  css=custom_css,
2148
  )
2149
  with demo:
2150
-
2151
-
 
2152
  with gr.Tab('AlignedCut'):
2153
 
2154
  with gr.Row():
@@ -2989,7 +3007,7 @@ with demo:
2989
  # sampling_method_dropdown = gr.Dropdown(["QuickFPS", "random"], label="NCUT: Sampling method", value="QuickFPS", elem_id="sampling_method", info="Nyström approximation")
2990
  sampling_method_dropdown = gr.Radio(["QuickFPS", "random"], label="NCUT: Sampling method", value="QuickFPS", elem_id="sampling_method")
2991
  # ncut_metric_dropdown = gr.Dropdown(["euclidean", "cosine"], label="NCUT: Distance metric", value="cosine", elem_id="ncut_metric")
2992
- ncut_metric_dropdown = gr.Radio(["euclidean", "cosine"], label="NCUT: Distance metric", value="cosine", elem_id="ncut_metric")
2993
  ncut_knn_slider = gr.Slider(1, 100, step=1, label="NCUT: KNN", value=10, elem_id="knn_ncut", info="Nyström approximation")
2994
  ncut_indirect_connection = gr.Checkbox(label="indirect_connection", value=False, elem_id="ncut_indirect_connection", info="TODO: Indirect connection is not implemented for directed NCUT", interactive=False)
2995
  ncut_make_orthogonal = gr.Checkbox(label="make_orthogonal", value=False, elem_id="ncut_make_orthogonal", info="Apply post-hoc eigenvectors orthogonalization")
@@ -3422,6 +3440,274 @@ with demo:
3422
  outputs=[mask_gallery, crop_gallery])
3423
 
3424
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3425
  with gr.Tab('📄About'):
3426
  with gr.Column():
3427
  gr.Markdown("**This demo is for Python package `ncut-pytorch`, please visit the [Documentation](https://ncut-pytorch.readthedocs.io/)**")
@@ -3481,6 +3767,7 @@ with demo:
3481
  hidden_button.change(unlock_tabs, n_smiles, tab_recursivecut_advanced)
3482
  hidden_button.change(unlock_tabs, n_smiles, tab_compare_models_advanced)
3483
  hidden_button.change(unlock_tabs, n_smiles, tab_directed_ncut)
 
3484
 
3485
  # with gr.Row():
3486
  # with gr.Column():
@@ -3522,3 +3809,13 @@ demo.launch(share=True)
3522
  # # %%
3523
 
3524
  # %%
 
 
 
 
 
 
 
 
 
 
 
2
  # %%
3
  import copy
4
  from datetime import datetime
5
+ import math
6
  import pickle
7
  from functools import partial
8
  from io import BytesIO
 
138
  indirect_connection=True,
139
  make_orthogonal=False,
140
  progess_start=0.4,
141
+ only_eigvecs=False,
142
  ):
143
  progress = gr.Progress()
144
  logging_str = ""
 
167
  # print(f"NCUT time: {time.time() - start:.2f}s")
168
  logging_str += f"NCUT time: {time.time() - start:.2f}s\n"
169
 
170
+ if only_eigvecs:
171
+ eigvecs = eigvecs.to("cpu").reshape(features.shape[:-1] + (num_eig,))
172
+ return None, logging_str, eigvecs
173
+
174
  start = time.time()
175
  progress(progess_start+0.01, desc="spectral-tSNE")
176
  _, rgb = eigenvector_to_rgb(
 
278
  return image_rgb
279
 
280
 
281
+ def to_pil_images(images, target_size=512, resize=True, force_size=False):
282
  size = images[0].shape[1]
283
  multiplier = target_size // size
284
  res = int(size * multiplier)
285
+ if force_size:
286
+ res = target_size
287
  pil_images = [
288
  Image.fromarray((image * 255).cpu().numpy().astype(np.uint8))
289
  for image in images
 
863
 
864
  # ailgnedcut
865
  if not directed:
866
+ only_eigvecs = kwargs.get("only_eigvecs", False)
867
+
868
  rgb, _logging_str, eigvecs = compute_ncut(
869
  features,
870
  num_eig=num_eig,
 
882
  indirect_connection=indirect_connection,
883
  make_orthogonal=make_orthogonal,
884
  metric=ncut_metric,
885
+ only_eigvecs=only_eigvecs,
886
  )
887
+
888
+ if only_eigvecs:
889
+ return eigvecs, logging_str
890
+
891
  if directed:
892
  head_index_text = kwargs.get("head_index_text", None)
893
  n_heads = features.shape[-2] # (batch, h, w, n_heads, d)
 
993
 
994
  def _ncut_run(*args, **kwargs):
995
  n_ret = kwargs.get("n_ret", 1)
996
+ # try:
997
+ # if torch.cuda.is_available():
998
+ # torch.cuda.empty_cache()
999
 
1000
+ # ret = ncut_run(*args, **kwargs)
1001
 
1002
+ # if torch.cuda.is_available():
1003
+ # torch.cuda.empty_cache()
1004
 
1005
+ # ret = list(ret)[:n_ret] + [ret[-1]]
1006
+ # return ret
1007
+ # except Exception as e:
1008
+ # gr.Error(str(e))
1009
+ # if torch.cuda.is_available():
1010
+ # torch.cuda.empty_cache()
1011
+ # return *(None for _ in range(n_ret)), "Error: " + str(e)
1012
 
1013
+ ret = ncut_run(*args, **kwargs)
1014
+ ret = list(ret)[:n_ret] + [ret[-1]]
1015
+ return ret
1016
 
1017
  if USE_HUGGINGFACE_ZEROGPU:
1018
  @spaces.GPU(duration=30)
 
1228
  alignedcut_eig_norm_plot=False,
1229
  advanced=False,
1230
  directed=False,
1231
+ only_eigvecs=False,
1232
  ):
1233
  # print(node_type2, head_index_text, make_symmetric)
1234
  progress=gr.Progress()
 
1369
  "node_type2": node_type2,
1370
  "head_index_text": head_index_text,
1371
  "make_symmetric": make_symmetric,
1372
+ "only_eigvecs": only_eigvecs,
1373
  }
1374
  # print(kwargs)
1375
 
 
1681
  gr.Info(f"Total images: {len(existing_images)}")
1682
  return existing_images
1683
 
1684
+ def make_input_images_section(rows=1, cols=3, height="auto", advanced=False, is_random=False, allow_download=False, markdown=True, n_example_images=100):
1685
  if markdown:
1686
  gr.Markdown('### Input Images')
1687
  input_gallery = gr.Gallery(value=None, label="Input images", show_label=True, elem_id="input_images", columns=[cols], rows=[rows], object_fit="contain", height=height, type="pil", show_share_button=False,
 
1719
  with gr.Row():
1720
  button = gr.Button("Load\n"+name, elem_id=f"example-{name}", elem_classes="small-button", variant='secondary', size="sm", scale=1, min_width=60)
1721
  gallery = gr.Gallery(value=images, label=name, show_label=True, columns=[3], rows=[1], interactive=False, height=80, scale=8, object_fit="cover", min_width=140, allow_preview=False)
1722
+ button.click(fn=lambda: gr.update(value=load_dataset_images(True, dataset_name, n_example_images, is_random=True, seed=42)), outputs=[input_gallery])
1723
  return gallery, button
1724
  example_items = [
1725
  ("EgoExo", ['./images/egoexo1.jpg', './images/egoexo3.jpg', './images/egoexo2.jpg'], "EgoExo"),
 
2057
  add_rotate_flip_buttons(output_gallery)
2058
  return output_gallery
2059
 
2060
+ def make_parameters_section(is_lisa=False, model_ratio=True, parameter_dropdown=True):
2061
  gr.Markdown("### Parameters <a style='color: #0044CC;' href='https://ncut-pytorch.readthedocs.io/en/latest/how_to_get_better_segmentation/' target='_blank'>Help</a>")
2062
  from ncut_pytorch.backbone import list_models, get_demo_model_names
2063
  model_names = list_models()
 
2122
  gr.Textbox(label="Prompt (Negative)", elem_id="prompt", placeholder="e.g. 'a photo from egocentric view'", visible=False))
2123
  model_dropdown.change(fn=change_prompt_text, inputs=model_dropdown, outputs=[positive_prompt, negative_prompt])
2124
 
2125
+ with gr.Accordion("Advanced Parameters: NCUT", open=False, visible=parameter_dropdown):
2126
  gr.Markdown("<a href='https://ncut-pytorch.readthedocs.io/en/latest/how_to_get_better_segmentation/' target='_blank'>Docs: How to Get Better Segmentation</a>")
2127
  affinity_focal_gamma_slider = gr.Slider(0.01, 1, step=0.01, label="NCUT: Affinity focal gamma", value=0.5, elem_id="affinity_focal_gamma", info="decrease for shaper segmentation")
2128
  num_sample_ncut_slider = gr.Slider(100, 50000, step=100, label="NCUT: num_sample", value=10000, elem_id="num_sample_ncut", info="Nyström approximation")
2129
  # sampling_method_dropdown = gr.Dropdown(["QuickFPS", "random"], label="NCUT: Sampling method", value="QuickFPS", elem_id="sampling_method", info="Nyström approximation")
2130
  sampling_method_dropdown = gr.Radio(["QuickFPS", "random"], label="NCUT: Sampling method", value="QuickFPS", elem_id="sampling_method")
2131
  # ncut_metric_dropdown = gr.Dropdown(["euclidean", "cosine"], label="NCUT: Distance metric", value="cosine", elem_id="ncut_metric")
2132
+ ncut_metric_dropdown = gr.Radio(["euclidean", "cosine", "rbf"], label="NCUT: Distance metric", value="cosine", elem_id="ncut_metric")
2133
  ncut_knn_slider = gr.Slider(1, 100, step=1, label="NCUT: KNN", value=10, elem_id="knn_ncut", info="Nyström approximation")
2134
  ncut_indirect_connection = gr.Checkbox(label="indirect_connection", value=True, elem_id="ncut_indirect_connection", info="Add indirect connection to the sub-sampled graph")
2135
  ncut_make_orthogonal = gr.Checkbox(label="make_orthogonal", value=False, elem_id="ncut_make_orthogonal", info="Apply post-hoc eigenvectors orthogonalization")
2136
+ with gr.Accordion("Advanced Parameters: Visualization", open=False, visible=parameter_dropdown):
2137
  # embedding_method_dropdown = gr.Dropdown(["tsne_3d", "umap_3d", "umap_sphere", "tsne_2d", "umap_2d"], label="Coloring method", value="tsne_3d", elem_id="embedding_method")
2138
  embedding_method_dropdown = gr.Radio(["tsne_3d", "umap_3d", "umap_sphere", "tsne_2d", "umap_2d"], label="Coloring method", value="tsne_3d", elem_id="embedding_method")
2139
  # embedding_metric_dropdown = gr.Dropdown(["euclidean", "cosine"], label="t-SNE/UMAP metric", value="euclidean", elem_id="embedding_metric")
 
2164
  css=custom_css,
2165
  )
2166
  with demo:
2167
+
2168
+
2169
+
2170
  with gr.Tab('AlignedCut'):
2171
 
2172
  with gr.Row():
 
3007
  # sampling_method_dropdown = gr.Dropdown(["QuickFPS", "random"], label="NCUT: Sampling method", value="QuickFPS", elem_id="sampling_method", info="Nyström approximation")
3008
  sampling_method_dropdown = gr.Radio(["QuickFPS", "random"], label="NCUT: Sampling method", value="QuickFPS", elem_id="sampling_method")
3009
  # ncut_metric_dropdown = gr.Dropdown(["euclidean", "cosine"], label="NCUT: Distance metric", value="cosine", elem_id="ncut_metric")
3010
+ ncut_metric_dropdown = gr.Radio(["euclidean", "cosine", "rbf"], label="NCUT: Distance metric", value="cosine", elem_id="ncut_metric")
3011
  ncut_knn_slider = gr.Slider(1, 100, step=1, label="NCUT: KNN", value=10, elem_id="knn_ncut", info="Nyström approximation")
3012
  ncut_indirect_connection = gr.Checkbox(label="indirect_connection", value=False, elem_id="ncut_indirect_connection", info="TODO: Indirect connection is not implemented for directed NCUT", interactive=False)
3013
  ncut_make_orthogonal = gr.Checkbox(label="make_orthogonal", value=False, elem_id="ncut_make_orthogonal", info="Apply post-hoc eigenvectors orthogonalization")
 
3440
  outputs=[mask_gallery, crop_gallery])
3441
 
3442
 
3443
+ with gr.Tab('PlayGround (test)', visible=False) as test_playground_tab:
3444
+ eigvecs = gr.State(torch.tensor([]))
3445
+ with gr.Row():
3446
+ with gr.Column(scale=5, min_width=200):
3447
+ gr.Markdown("### Step 1: Load Images and Run NCUT")
3448
+ input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section(n_example_images=10)
3449
+ # submit_button.visible = False
3450
+ num_images_slider.value = 30
3451
+ [
3452
+ model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
3453
+ affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
3454
+ embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
3455
+ perplexity_slider, n_neighbors_slider, min_dist_slider,
3456
+ sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt
3457
+ ] = make_parameters_section(parameter_dropdown=False)
3458
+ num_eig_slider.value = 1000
3459
+ num_eig_slider.visible = False
3460
+ logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information", autofocus=False, autoscroll=False)
3461
+
3462
+ false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
3463
+ no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
3464
+
3465
+ submit_button.click(
3466
+ partial(run_fn, n_ret=1, only_eigvecs=True),
3467
+ inputs=[
3468
+ input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
3469
+ positive_prompt, negative_prompt,
3470
+ false_placeholder, no_prompt, no_prompt, no_prompt,
3471
+ affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
3472
+ embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
3473
+ perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown
3474
+ ],
3475
+ outputs=[eigvecs, logging_text],
3476
+ )
3477
+
3478
+ with gr.Column(scale=5, min_width=200):
3479
+ gr.Markdown("### Step 2a: Pick an Image")
3480
+ from gradio_image_prompter import ImagePrompter
3481
+ with gr.Row():
3482
+ image1_slider = gr.Slider(0, 100, step=1, label="Image#1 Index", value=0, elem_id="image1_slider", interactive=True)
3483
+ load_one_image_button = gr.Button("🔴 Load", elem_id="load_one_image_button", variant='primary')
3484
+ gr.Markdown("### Step 2b: Draw a Point")
3485
+ gr.Markdown("""
3486
+ <h5>
3487
+ 🖱️ Left Click: Foreground </br>
3488
+ </h5>
3489
+ """)
3490
+ prompt_image1 = ImagePrompter(show_label=False, elem_id="prompt_image1", interactive=False)
3491
+ def update_prompt_image(original_images, index):
3492
+ images = original_images
3493
+ if images is None:
3494
+ return
3495
+ total_len = len(images)
3496
+ if total_len == 0:
3497
+ return
3498
+ if index >= total_len:
3499
+ index = total_len - 1
3500
+
3501
+ return ImagePrompter(value={'image': images[index][0], 'points': []}, interactive=True)
3502
+ # return gr.Image(value=images[index][0], elem_id=f"prompt_image{randint}", interactive=True)
3503
+ load_one_image_button.click(update_prompt_image, inputs=[input_gallery, image1_slider], outputs=[prompt_image1])
3504
+
3505
+ child_idx = gr.State([])
3506
+ current_idx = gr.State(None)
3507
+ n_eig = gr.State(64)
3508
+ with gr.Column(scale=5, min_width=200):
3509
+ gr.Markdown("### Step 3: Check groupping")
3510
+ child_distance_slider = gr.Slider(0, 0.5, step=0.001, label="Child Distance", value=0.1, elem_id="child_distance_slider", interactive=True)
3511
+ overlay_image_checkbox = gr.Checkbox(label="Overlay Image", value=True, elem_id="overlay_image_checkbox", interactive=True)
3512
+ run_button = gr.Button("🔴 RUN", elem_id="run_groupping", variant='primary')
3513
+ parent_plot = gr.Gallery(value=None, label="Parent", show_label=True, elem_id="parent_plot", interactive=False, rows=[1], columns=[2])
3514
+ parent_button = gr.Button("Use Parent", elem_id="run_parent")
3515
+ current_plot = gr.Gallery(value=None, label="Current", show_label=True, elem_id="current_plot", interactive=False, rows=[1], columns=[2])
3516
+ with gr.Column(scale=5, min_width=200):
3517
+ child_plots = []
3518
+ child_buttons = []
3519
+ for i in range(4):
3520
+ child_plots.append(gr.Gallery(value=None, label=f"Child {i}", show_label=True, elem_id=f"child_plot_{i}", interactive=False, rows=[1], columns=[2]))
3521
+ child_buttons.append(gr.Button(f"Use Child {i}", elem_id=f"run_child_{i}"))
3522
+
3523
+ def relative_xy(prompts):
3524
+ image = prompts['image']
3525
+ points = np.asarray(prompts['points'])
3526
+ if points.shape[0] == 0:
3527
+ return [], []
3528
+ is_point = points[:, 5] == 4.0
3529
+ points = points[is_point]
3530
+ is_positive = points[:, 2] == 1.0
3531
+ is_negative = points[:, 2] == 0.0
3532
+ xy = points[:, :2].tolist()
3533
+ if isinstance(image, str):
3534
+ image = Image.open(image)
3535
+ image = np.array(image)
3536
+ h, w = image.shape[:2]
3537
+ new_xy = [(x/w, y/h) for x, y in xy]
3538
+ # print(new_xy)
3539
+ return new_xy, is_positive
3540
+
3541
+ def xy_eigvec(prompts, image_idx, eigvecs):
3542
+ eigvec = eigvecs[image_idx]
3543
+ xy, is_positive = relative_xy(prompts)
3544
+ for i, (x, y) in enumerate(xy):
3545
+ if not is_positive[i]:
3546
+ continue
3547
+ x = int(x * eigvec.shape[1])
3548
+ y = int(y * eigvec.shape[0])
3549
+ return eigvec[y, x], (y, x)
3550
+
3551
+ from ncut_pytorch.ncut_pytorch import _transform_heatmap
3552
+ def _run_heatmap_fn(images, eigvecs, prompt_image_idx, prompt_points, n_eig, flat_idx=None, raw_heatmap=False, overlay_image=True):
3553
+ left = eigvecs[..., :n_eig]
3554
+ if flat_idx is not None:
3555
+ right = eigvecs.reshape(-1, eigvecs.shape[-1])[flat_idx]
3556
+ y, x = None, None
3557
+ else:
3558
+ right, (y, x) = xy_eigvec(prompt_points, prompt_image_idx, eigvecs)
3559
+ right = right[:n_eig]
3560
+ left = F.normalize(left, p=2, dim=1)
3561
+ _right = F.normalize(right, p=2, dim=0)
3562
+ heatmap = left @ _right.unsqueeze(-1)
3563
+ heatmap = heatmap.squeeze(-1)
3564
+ heatmap = 1 - heatmap
3565
+ heatmap = _transform_heatmap(heatmap)
3566
+ if raw_heatmap:
3567
+ return heatmap
3568
+ # apply hot colormap and covert to PIL image 256x256
3569
+ heatmap = heatmap.cpu().numpy()
3570
+ hot_map = matplotlib.cm.get_cmap('hot')
3571
+ heatmap = hot_map(heatmap)
3572
+ pil_images = to_pil_images(torch.tensor(heatmap), target_size=256, force_size=True)
3573
+ if overlay_image:
3574
+ overlaied_images = []
3575
+ for i_image in range(len(images)):
3576
+ rgb_image = images[i_image].resize((256, 256))
3577
+ rgb_image = np.array(rgb_image)
3578
+ heatmap_image = np.array(pil_images[i_image])[..., :3]
3579
+ blend_image = 0.5 * rgb_image + 0.5 * heatmap_image
3580
+ blend_image = Image.fromarray(blend_image.astype(np.uint8))
3581
+ overlaied_images.append(blend_image)
3582
+ pil_images = overlaied_images
3583
+ return pil_images, (y, x)
3584
+
3585
+ def farthest_point_sampling(
3586
+ features,
3587
+ start_feature,
3588
+ num_sample=300,
3589
+ h=9,
3590
+ ):
3591
+ import fpsample
3592
+
3593
+ h = min(h, int(np.log2(features.shape[0])))
3594
+
3595
+ inp = features.cpu().numpy()
3596
+ inp = np.concatenate([inp, start_feature[None, :]], axis=0)
3597
+
3598
+ kdline_fps_samples_idx = fpsample.bucket_fps_kdline_sampling(
3599
+ inp, num_sample, h, start_idx=inp.shape[0] - 1
3600
+ ).astype(np.int64)
3601
+ return kdline_fps_samples_idx
3602
+
3603
+ @torch.no_grad()
3604
+ def run_heatmap(images, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, flat_idx=None, overlay_image=True):
3605
+ gr.Info(f"current number of eigenvectors: {n_eig}")
3606
+ images = [image[0] for image in images]
3607
+ if isinstance(images[0], str):
3608
+ images = [Image.open(image[0]).convert("RGB").resize((256, 256)) for image in images]
3609
+
3610
+ current_heatmap, (y, x) = _run_heatmap_fn(images, eigvecs, image1_slider, prompt_image1, n_eig, flat_idx, overlay_image=overlay_image)
3611
+ parent_heatmap, _ = _run_heatmap_fn(images, eigvecs, image1_slider, prompt_image1, int(n_eig/2), flat_idx, overlay_image=overlay_image)
3612
+
3613
+ # find childs
3614
+ # pca_eigvecs
3615
+ _eigvecs = eigvecs.reshape(-1, eigvecs.shape[-1])
3616
+ u, s, v = torch.pca_lowrank(_eigvecs, q=8)
3617
+ _n = _eigvecs.shape[0]
3618
+ s /= math.sqrt(_n)
3619
+ _eigvecs = u @ torch.diag(s)
3620
+
3621
+ if flat_idx is None:
3622
+ _picked_eigvec = _eigvecs.reshape(*eigvecs.shape[:-1], 8)[image1_slider, y, x]
3623
+ else:
3624
+ _picked_eigvec = _eigvecs[flat_idx]
3625
+ l2_distance = torch.norm(_eigvecs - _picked_eigvec, dim=-1)
3626
+ average_distance = l2_distance.mean()
3627
+ distance_threshold = distance_slider * average_distance
3628
+ distance_mask = l2_distance < distance_threshold
3629
+ masked_eigvecs = _eigvecs[distance_mask]
3630
+ num_childs = min(4, masked_eigvecs.shape[0])
3631
+ assert num_childs > 0
3632
+
3633
+ child_idx = farthest_point_sampling(masked_eigvecs, _picked_eigvec, num_sample=num_childs+1)
3634
+ child_idx = np.sort(child_idx)[:-1]
3635
+
3636
+ # convert child_idx to flat_idx
3637
+ dummy_idx = torch.zeros(_eigvecs.shape[0], dtype=torch.bool)
3638
+ dummy_idx2 = torch.zeros(int(distance_mask.sum().item()), dtype=torch.bool)
3639
+ dummy_idx2[child_idx] = True
3640
+ dummy_idx[distance_mask] = dummy_idx2
3641
+ child_idx = torch.where(dummy_idx)[0]
3642
+
3643
+
3644
+ # current_child heatmap, for contrast
3645
+ current_child_heatmap = _run_heatmap_fn(images,eigvecs, image1_slider, prompt_image1, int(n_eig*2), flat_idx, raw_heatmap=True, overlay_image=overlay_image)
3646
+
3647
+ # child_heatmaps, contrast mean of current clicked point
3648
+ child_heatmaps = []
3649
+ for idx in child_idx:
3650
+ child_heatmap = _run_heatmap_fn(images,eigvecs, image1_slider, prompt_image1, int(n_eig*2), idx, raw_heatmap=True, overlay_image=overlay_image)
3651
+ heatmap = child_heatmap - current_child_heatmap
3652
+ # convert [-1, 1] to [0, 1]
3653
+ heatmap = (heatmap + 1) / 2
3654
+ heatmap = heatmap.cpu().numpy()
3655
+ cm = matplotlib.cm.get_cmap('bwr')
3656
+ heatmap = cm(heatmap)
3657
+ # bwr with contrast
3658
+ pil_images1 = to_pil_images(torch.tensor(heatmap), resize=256)
3659
+ # no contrast
3660
+ pil_images2, _ = _run_heatmap_fn(images, eigvecs, image1_slider, prompt_image1, int(n_eig*2), idx, overlay_image=overlay_image)
3661
+
3662
+ # combine contrast and no contrast
3663
+ pil_images = []
3664
+ for i in range(len(pil_images1)):
3665
+ pil_images.append(pil_images2[i])
3666
+ pil_images.append(pil_images1[i])
3667
+
3668
+
3669
+ child_heatmaps.append(pil_images)
3670
+
3671
+ return parent_heatmap, current_heatmap, *child_heatmaps, child_idx.tolist()
3672
+
3673
+ # def debug_fn(eigvecs):
3674
+ # shape = eigvecs.shape
3675
+ # gr.Info(f"eigvecs shape: {shape}")
3676
+
3677
+ # run_button.click(
3678
+ # debug_fn,
3679
+ # inputs=[eigvecs],
3680
+ # outputs=[],
3681
+ # )
3682
+ none_placeholder = gr.State(None)
3683
+ run_button.click(
3684
+ run_heatmap,
3685
+ inputs=[input_gallery, eigvecs, image1_slider, prompt_image1, n_eig, child_distance_slider, none_placeholder, overlay_image_checkbox],
3686
+ outputs=[parent_plot, current_plot, *child_plots, child_idx],
3687
+ )
3688
+
3689
+ def run_paraent(input_gallery, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, current_idx=None, overlay_image=True):
3690
+ n_eig = int(n_eig/2)
3691
+ return n_eig, *run_heatmap(input_gallery, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, current_idx, overlay_image)
3692
+
3693
+ parent_button.click(
3694
+ run_paraent,
3695
+ inputs=[input_gallery, eigvecs, image1_slider, prompt_image1, n_eig, child_distance_slider, current_idx, overlay_image_checkbox],
3696
+ outputs=[n_eig, parent_plot, current_plot, *child_plots, child_idx],
3697
+ )
3698
+
3699
+ def run_child(input_gallery, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, child_idx=[], overlay_image=True, i_child=0):
3700
+ n_eig = int(n_eig*2)
3701
+ flat_idx = child_idx[i_child]
3702
+ return n_eig, flat_idx, *run_heatmap(input_gallery, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, flat_idx, overlay_image)
3703
+
3704
+ for i in range(4):
3705
+ child_buttons[i].click(
3706
+ partial(run_child, i_child=i),
3707
+ inputs=[input_gallery, eigvecs, image1_slider, prompt_image1, n_eig, child_distance_slider, child_idx, overlay_image_checkbox],
3708
+ outputs=[n_eig, current_idx, parent_plot, current_plot, *child_plots, child_idx],
3709
+ )
3710
+
3711
  with gr.Tab('📄About'):
3712
  with gr.Column():
3713
  gr.Markdown("**This demo is for Python package `ncut-pytorch`, please visit the [Documentation](https://ncut-pytorch.readthedocs.io/)**")
 
3767
  hidden_button.change(unlock_tabs, n_smiles, tab_recursivecut_advanced)
3768
  hidden_button.change(unlock_tabs, n_smiles, tab_compare_models_advanced)
3769
  hidden_button.change(unlock_tabs, n_smiles, tab_directed_ncut)
3770
+ hidden_button.change(unlock_tabs, n_smiles, test_playground_tab)
3771
 
3772
  # with gr.Row():
3773
  # with gr.Column():
 
3809
  # # %%
3810
 
3811
  # %%
3812
+
3813
+ # %%
3814
+
3815
+ # %%
3816
+
3817
+ # %%
3818
+
3819
+ # %%
3820
+
3821
+ # %%