Spaces:
Running
on
Zero
Running
on
Zero
add test playground
Browse files
app.py
CHANGED
@@ -2,6 +2,7 @@
|
|
2 |
# %%
|
3 |
import copy
|
4 |
from datetime import datetime
|
|
|
5 |
import pickle
|
6 |
from functools import partial
|
7 |
from io import BytesIO
|
@@ -137,6 +138,7 @@ def compute_ncut(
|
|
137 |
indirect_connection=True,
|
138 |
make_orthogonal=False,
|
139 |
progess_start=0.4,
|
|
|
140 |
):
|
141 |
progress = gr.Progress()
|
142 |
logging_str = ""
|
@@ -165,6 +167,10 @@ def compute_ncut(
|
|
165 |
# print(f"NCUT time: {time.time() - start:.2f}s")
|
166 |
logging_str += f"NCUT time: {time.time() - start:.2f}s\n"
|
167 |
|
|
|
|
|
|
|
|
|
168 |
start = time.time()
|
169 |
progress(progess_start+0.01, desc="spectral-tSNE")
|
170 |
_, rgb = eigenvector_to_rgb(
|
@@ -272,10 +278,12 @@ def dont_use_too_much_green(image_rgb):
|
|
272 |
return image_rgb
|
273 |
|
274 |
|
275 |
-
def to_pil_images(images, target_size=512, resize=True):
|
276 |
size = images[0].shape[1]
|
277 |
multiplier = target_size // size
|
278 |
res = int(size * multiplier)
|
|
|
|
|
279 |
pil_images = [
|
280 |
Image.fromarray((image * 255).cpu().numpy().astype(np.uint8))
|
281 |
for image in images
|
@@ -855,6 +863,8 @@ def ncut_run(
|
|
855 |
|
856 |
# ailgnedcut
|
857 |
if not directed:
|
|
|
|
|
858 |
rgb, _logging_str, eigvecs = compute_ncut(
|
859 |
features,
|
860 |
num_eig=num_eig,
|
@@ -872,7 +882,12 @@ def ncut_run(
|
|
872 |
indirect_connection=indirect_connection,
|
873 |
make_orthogonal=make_orthogonal,
|
874 |
metric=ncut_metric,
|
|
|
875 |
)
|
|
|
|
|
|
|
|
|
876 |
if directed:
|
877 |
head_index_text = kwargs.get("head_index_text", None)
|
878 |
n_heads = features.shape[-2] # (batch, h, w, n_heads, d)
|
@@ -978,26 +993,26 @@ def ncut_run(
|
|
978 |
|
979 |
def _ncut_run(*args, **kwargs):
|
980 |
n_ret = kwargs.get("n_ret", 1)
|
981 |
-
try:
|
982 |
-
|
983 |
-
|
984 |
|
985 |
-
|
986 |
|
987 |
-
|
988 |
-
|
989 |
|
990 |
-
|
991 |
-
|
992 |
-
except Exception as e:
|
993 |
-
|
994 |
-
|
995 |
-
|
996 |
-
|
997 |
|
998 |
-
|
999 |
-
|
1000 |
-
|
1001 |
|
1002 |
if USE_HUGGINGFACE_ZEROGPU:
|
1003 |
@spaces.GPU(duration=30)
|
@@ -1213,6 +1228,7 @@ def run_fn(
|
|
1213 |
alignedcut_eig_norm_plot=False,
|
1214 |
advanced=False,
|
1215 |
directed=False,
|
|
|
1216 |
):
|
1217 |
# print(node_type2, head_index_text, make_symmetric)
|
1218 |
progress=gr.Progress()
|
@@ -1353,6 +1369,7 @@ def run_fn(
|
|
1353 |
"node_type2": node_type2,
|
1354 |
"head_index_text": head_index_text,
|
1355 |
"make_symmetric": make_symmetric,
|
|
|
1356 |
}
|
1357 |
# print(kwargs)
|
1358 |
|
@@ -1664,7 +1681,7 @@ def load_and_append(existing_images, *args, **kwargs):
|
|
1664 |
gr.Info(f"Total images: {len(existing_images)}")
|
1665 |
return existing_images
|
1666 |
|
1667 |
-
def make_input_images_section(rows=1, cols=3, height="auto", advanced=False, is_random=False, allow_download=False, markdown=True):
|
1668 |
if markdown:
|
1669 |
gr.Markdown('### Input Images')
|
1670 |
input_gallery = gr.Gallery(value=None, label="Input images", show_label=True, elem_id="input_images", columns=[cols], rows=[rows], object_fit="contain", height=height, type="pil", show_share_button=False,
|
@@ -1702,7 +1719,7 @@ def make_input_images_section(rows=1, cols=3, height="auto", advanced=False, is_
|
|
1702 |
with gr.Row():
|
1703 |
button = gr.Button("Load\n"+name, elem_id=f"example-{name}", elem_classes="small-button", variant='secondary', size="sm", scale=1, min_width=60)
|
1704 |
gallery = gr.Gallery(value=images, label=name, show_label=True, columns=[3], rows=[1], interactive=False, height=80, scale=8, object_fit="cover", min_width=140, allow_preview=False)
|
1705 |
-
button.click(fn=lambda: gr.update(value=load_dataset_images(True, dataset_name,
|
1706 |
return gallery, button
|
1707 |
example_items = [
|
1708 |
("EgoExo", ['./images/egoexo1.jpg', './images/egoexo3.jpg', './images/egoexo2.jpg'], "EgoExo"),
|
@@ -2040,7 +2057,7 @@ def make_output_images_section(markdown=True, button=True):
|
|
2040 |
add_rotate_flip_buttons(output_gallery)
|
2041 |
return output_gallery
|
2042 |
|
2043 |
-
def make_parameters_section(is_lisa=False, model_ratio=True):
|
2044 |
gr.Markdown("### Parameters <a style='color: #0044CC;' href='https://ncut-pytorch.readthedocs.io/en/latest/how_to_get_better_segmentation/' target='_blank'>Help</a>")
|
2045 |
from ncut_pytorch.backbone import list_models, get_demo_model_names
|
2046 |
model_names = list_models()
|
@@ -2105,18 +2122,18 @@ def make_parameters_section(is_lisa=False, model_ratio=True):
|
|
2105 |
gr.Textbox(label="Prompt (Negative)", elem_id="prompt", placeholder="e.g. 'a photo from egocentric view'", visible=False))
|
2106 |
model_dropdown.change(fn=change_prompt_text, inputs=model_dropdown, outputs=[positive_prompt, negative_prompt])
|
2107 |
|
2108 |
-
with gr.Accordion("Advanced Parameters: NCUT", open=False):
|
2109 |
gr.Markdown("<a href='https://ncut-pytorch.readthedocs.io/en/latest/how_to_get_better_segmentation/' target='_blank'>Docs: How to Get Better Segmentation</a>")
|
2110 |
affinity_focal_gamma_slider = gr.Slider(0.01, 1, step=0.01, label="NCUT: Affinity focal gamma", value=0.5, elem_id="affinity_focal_gamma", info="decrease for shaper segmentation")
|
2111 |
num_sample_ncut_slider = gr.Slider(100, 50000, step=100, label="NCUT: num_sample", value=10000, elem_id="num_sample_ncut", info="Nyström approximation")
|
2112 |
# sampling_method_dropdown = gr.Dropdown(["QuickFPS", "random"], label="NCUT: Sampling method", value="QuickFPS", elem_id="sampling_method", info="Nyström approximation")
|
2113 |
sampling_method_dropdown = gr.Radio(["QuickFPS", "random"], label="NCUT: Sampling method", value="QuickFPS", elem_id="sampling_method")
|
2114 |
# ncut_metric_dropdown = gr.Dropdown(["euclidean", "cosine"], label="NCUT: Distance metric", value="cosine", elem_id="ncut_metric")
|
2115 |
-
ncut_metric_dropdown = gr.Radio(["euclidean", "cosine"], label="NCUT: Distance metric", value="cosine", elem_id="ncut_metric")
|
2116 |
ncut_knn_slider = gr.Slider(1, 100, step=1, label="NCUT: KNN", value=10, elem_id="knn_ncut", info="Nyström approximation")
|
2117 |
ncut_indirect_connection = gr.Checkbox(label="indirect_connection", value=True, elem_id="ncut_indirect_connection", info="Add indirect connection to the sub-sampled graph")
|
2118 |
ncut_make_orthogonal = gr.Checkbox(label="make_orthogonal", value=False, elem_id="ncut_make_orthogonal", info="Apply post-hoc eigenvectors orthogonalization")
|
2119 |
-
with gr.Accordion("Advanced Parameters: Visualization", open=False):
|
2120 |
# embedding_method_dropdown = gr.Dropdown(["tsne_3d", "umap_3d", "umap_sphere", "tsne_2d", "umap_2d"], label="Coloring method", value="tsne_3d", elem_id="embedding_method")
|
2121 |
embedding_method_dropdown = gr.Radio(["tsne_3d", "umap_3d", "umap_sphere", "tsne_2d", "umap_2d"], label="Coloring method", value="tsne_3d", elem_id="embedding_method")
|
2122 |
# embedding_metric_dropdown = gr.Dropdown(["euclidean", "cosine"], label="t-SNE/UMAP metric", value="euclidean", elem_id="embedding_metric")
|
@@ -2147,8 +2164,9 @@ demo = gr.Blocks(
|
|
2147 |
css=custom_css,
|
2148 |
)
|
2149 |
with demo:
|
2150 |
-
|
2151 |
-
|
|
|
2152 |
with gr.Tab('AlignedCut'):
|
2153 |
|
2154 |
with gr.Row():
|
@@ -2989,7 +3007,7 @@ with demo:
|
|
2989 |
# sampling_method_dropdown = gr.Dropdown(["QuickFPS", "random"], label="NCUT: Sampling method", value="QuickFPS", elem_id="sampling_method", info="Nyström approximation")
|
2990 |
sampling_method_dropdown = gr.Radio(["QuickFPS", "random"], label="NCUT: Sampling method", value="QuickFPS", elem_id="sampling_method")
|
2991 |
# ncut_metric_dropdown = gr.Dropdown(["euclidean", "cosine"], label="NCUT: Distance metric", value="cosine", elem_id="ncut_metric")
|
2992 |
-
ncut_metric_dropdown = gr.Radio(["euclidean", "cosine"], label="NCUT: Distance metric", value="cosine", elem_id="ncut_metric")
|
2993 |
ncut_knn_slider = gr.Slider(1, 100, step=1, label="NCUT: KNN", value=10, elem_id="knn_ncut", info="Nyström approximation")
|
2994 |
ncut_indirect_connection = gr.Checkbox(label="indirect_connection", value=False, elem_id="ncut_indirect_connection", info="TODO: Indirect connection is not implemented for directed NCUT", interactive=False)
|
2995 |
ncut_make_orthogonal = gr.Checkbox(label="make_orthogonal", value=False, elem_id="ncut_make_orthogonal", info="Apply post-hoc eigenvectors orthogonalization")
|
@@ -3422,6 +3440,274 @@ with demo:
|
|
3422 |
outputs=[mask_gallery, crop_gallery])
|
3423 |
|
3424 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3425 |
with gr.Tab('📄About'):
|
3426 |
with gr.Column():
|
3427 |
gr.Markdown("**This demo is for Python package `ncut-pytorch`, please visit the [Documentation](https://ncut-pytorch.readthedocs.io/)**")
|
@@ -3481,6 +3767,7 @@ with demo:
|
|
3481 |
hidden_button.change(unlock_tabs, n_smiles, tab_recursivecut_advanced)
|
3482 |
hidden_button.change(unlock_tabs, n_smiles, tab_compare_models_advanced)
|
3483 |
hidden_button.change(unlock_tabs, n_smiles, tab_directed_ncut)
|
|
|
3484 |
|
3485 |
# with gr.Row():
|
3486 |
# with gr.Column():
|
@@ -3522,3 +3809,13 @@ demo.launch(share=True)
|
|
3522 |
# # %%
|
3523 |
|
3524 |
# %%
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
# %%
|
3 |
import copy
|
4 |
from datetime import datetime
|
5 |
+
import math
|
6 |
import pickle
|
7 |
from functools import partial
|
8 |
from io import BytesIO
|
|
|
138 |
indirect_connection=True,
|
139 |
make_orthogonal=False,
|
140 |
progess_start=0.4,
|
141 |
+
only_eigvecs=False,
|
142 |
):
|
143 |
progress = gr.Progress()
|
144 |
logging_str = ""
|
|
|
167 |
# print(f"NCUT time: {time.time() - start:.2f}s")
|
168 |
logging_str += f"NCUT time: {time.time() - start:.2f}s\n"
|
169 |
|
170 |
+
if only_eigvecs:
|
171 |
+
eigvecs = eigvecs.to("cpu").reshape(features.shape[:-1] + (num_eig,))
|
172 |
+
return None, logging_str, eigvecs
|
173 |
+
|
174 |
start = time.time()
|
175 |
progress(progess_start+0.01, desc="spectral-tSNE")
|
176 |
_, rgb = eigenvector_to_rgb(
|
|
|
278 |
return image_rgb
|
279 |
|
280 |
|
281 |
+
def to_pil_images(images, target_size=512, resize=True, force_size=False):
|
282 |
size = images[0].shape[1]
|
283 |
multiplier = target_size // size
|
284 |
res = int(size * multiplier)
|
285 |
+
if force_size:
|
286 |
+
res = target_size
|
287 |
pil_images = [
|
288 |
Image.fromarray((image * 255).cpu().numpy().astype(np.uint8))
|
289 |
for image in images
|
|
|
863 |
|
864 |
# ailgnedcut
|
865 |
if not directed:
|
866 |
+
only_eigvecs = kwargs.get("only_eigvecs", False)
|
867 |
+
|
868 |
rgb, _logging_str, eigvecs = compute_ncut(
|
869 |
features,
|
870 |
num_eig=num_eig,
|
|
|
882 |
indirect_connection=indirect_connection,
|
883 |
make_orthogonal=make_orthogonal,
|
884 |
metric=ncut_metric,
|
885 |
+
only_eigvecs=only_eigvecs,
|
886 |
)
|
887 |
+
|
888 |
+
if only_eigvecs:
|
889 |
+
return eigvecs, logging_str
|
890 |
+
|
891 |
if directed:
|
892 |
head_index_text = kwargs.get("head_index_text", None)
|
893 |
n_heads = features.shape[-2] # (batch, h, w, n_heads, d)
|
|
|
993 |
|
994 |
def _ncut_run(*args, **kwargs):
|
995 |
n_ret = kwargs.get("n_ret", 1)
|
996 |
+
# try:
|
997 |
+
# if torch.cuda.is_available():
|
998 |
+
# torch.cuda.empty_cache()
|
999 |
|
1000 |
+
# ret = ncut_run(*args, **kwargs)
|
1001 |
|
1002 |
+
# if torch.cuda.is_available():
|
1003 |
+
# torch.cuda.empty_cache()
|
1004 |
|
1005 |
+
# ret = list(ret)[:n_ret] + [ret[-1]]
|
1006 |
+
# return ret
|
1007 |
+
# except Exception as e:
|
1008 |
+
# gr.Error(str(e))
|
1009 |
+
# if torch.cuda.is_available():
|
1010 |
+
# torch.cuda.empty_cache()
|
1011 |
+
# return *(None for _ in range(n_ret)), "Error: " + str(e)
|
1012 |
|
1013 |
+
ret = ncut_run(*args, **kwargs)
|
1014 |
+
ret = list(ret)[:n_ret] + [ret[-1]]
|
1015 |
+
return ret
|
1016 |
|
1017 |
if USE_HUGGINGFACE_ZEROGPU:
|
1018 |
@spaces.GPU(duration=30)
|
|
|
1228 |
alignedcut_eig_norm_plot=False,
|
1229 |
advanced=False,
|
1230 |
directed=False,
|
1231 |
+
only_eigvecs=False,
|
1232 |
):
|
1233 |
# print(node_type2, head_index_text, make_symmetric)
|
1234 |
progress=gr.Progress()
|
|
|
1369 |
"node_type2": node_type2,
|
1370 |
"head_index_text": head_index_text,
|
1371 |
"make_symmetric": make_symmetric,
|
1372 |
+
"only_eigvecs": only_eigvecs,
|
1373 |
}
|
1374 |
# print(kwargs)
|
1375 |
|
|
|
1681 |
gr.Info(f"Total images: {len(existing_images)}")
|
1682 |
return existing_images
|
1683 |
|
1684 |
+
def make_input_images_section(rows=1, cols=3, height="auto", advanced=False, is_random=False, allow_download=False, markdown=True, n_example_images=100):
|
1685 |
if markdown:
|
1686 |
gr.Markdown('### Input Images')
|
1687 |
input_gallery = gr.Gallery(value=None, label="Input images", show_label=True, elem_id="input_images", columns=[cols], rows=[rows], object_fit="contain", height=height, type="pil", show_share_button=False,
|
|
|
1719 |
with gr.Row():
|
1720 |
button = gr.Button("Load\n"+name, elem_id=f"example-{name}", elem_classes="small-button", variant='secondary', size="sm", scale=1, min_width=60)
|
1721 |
gallery = gr.Gallery(value=images, label=name, show_label=True, columns=[3], rows=[1], interactive=False, height=80, scale=8, object_fit="cover", min_width=140, allow_preview=False)
|
1722 |
+
button.click(fn=lambda: gr.update(value=load_dataset_images(True, dataset_name, n_example_images, is_random=True, seed=42)), outputs=[input_gallery])
|
1723 |
return gallery, button
|
1724 |
example_items = [
|
1725 |
("EgoExo", ['./images/egoexo1.jpg', './images/egoexo3.jpg', './images/egoexo2.jpg'], "EgoExo"),
|
|
|
2057 |
add_rotate_flip_buttons(output_gallery)
|
2058 |
return output_gallery
|
2059 |
|
2060 |
+
def make_parameters_section(is_lisa=False, model_ratio=True, parameter_dropdown=True):
|
2061 |
gr.Markdown("### Parameters <a style='color: #0044CC;' href='https://ncut-pytorch.readthedocs.io/en/latest/how_to_get_better_segmentation/' target='_blank'>Help</a>")
|
2062 |
from ncut_pytorch.backbone import list_models, get_demo_model_names
|
2063 |
model_names = list_models()
|
|
|
2122 |
gr.Textbox(label="Prompt (Negative)", elem_id="prompt", placeholder="e.g. 'a photo from egocentric view'", visible=False))
|
2123 |
model_dropdown.change(fn=change_prompt_text, inputs=model_dropdown, outputs=[positive_prompt, negative_prompt])
|
2124 |
|
2125 |
+
with gr.Accordion("Advanced Parameters: NCUT", open=False, visible=parameter_dropdown):
|
2126 |
gr.Markdown("<a href='https://ncut-pytorch.readthedocs.io/en/latest/how_to_get_better_segmentation/' target='_blank'>Docs: How to Get Better Segmentation</a>")
|
2127 |
affinity_focal_gamma_slider = gr.Slider(0.01, 1, step=0.01, label="NCUT: Affinity focal gamma", value=0.5, elem_id="affinity_focal_gamma", info="decrease for shaper segmentation")
|
2128 |
num_sample_ncut_slider = gr.Slider(100, 50000, step=100, label="NCUT: num_sample", value=10000, elem_id="num_sample_ncut", info="Nyström approximation")
|
2129 |
# sampling_method_dropdown = gr.Dropdown(["QuickFPS", "random"], label="NCUT: Sampling method", value="QuickFPS", elem_id="sampling_method", info="Nyström approximation")
|
2130 |
sampling_method_dropdown = gr.Radio(["QuickFPS", "random"], label="NCUT: Sampling method", value="QuickFPS", elem_id="sampling_method")
|
2131 |
# ncut_metric_dropdown = gr.Dropdown(["euclidean", "cosine"], label="NCUT: Distance metric", value="cosine", elem_id="ncut_metric")
|
2132 |
+
ncut_metric_dropdown = gr.Radio(["euclidean", "cosine", "rbf"], label="NCUT: Distance metric", value="cosine", elem_id="ncut_metric")
|
2133 |
ncut_knn_slider = gr.Slider(1, 100, step=1, label="NCUT: KNN", value=10, elem_id="knn_ncut", info="Nyström approximation")
|
2134 |
ncut_indirect_connection = gr.Checkbox(label="indirect_connection", value=True, elem_id="ncut_indirect_connection", info="Add indirect connection to the sub-sampled graph")
|
2135 |
ncut_make_orthogonal = gr.Checkbox(label="make_orthogonal", value=False, elem_id="ncut_make_orthogonal", info="Apply post-hoc eigenvectors orthogonalization")
|
2136 |
+
with gr.Accordion("Advanced Parameters: Visualization", open=False, visible=parameter_dropdown):
|
2137 |
# embedding_method_dropdown = gr.Dropdown(["tsne_3d", "umap_3d", "umap_sphere", "tsne_2d", "umap_2d"], label="Coloring method", value="tsne_3d", elem_id="embedding_method")
|
2138 |
embedding_method_dropdown = gr.Radio(["tsne_3d", "umap_3d", "umap_sphere", "tsne_2d", "umap_2d"], label="Coloring method", value="tsne_3d", elem_id="embedding_method")
|
2139 |
# embedding_metric_dropdown = gr.Dropdown(["euclidean", "cosine"], label="t-SNE/UMAP metric", value="euclidean", elem_id="embedding_metric")
|
|
|
2164 |
css=custom_css,
|
2165 |
)
|
2166 |
with demo:
|
2167 |
+
|
2168 |
+
|
2169 |
+
|
2170 |
with gr.Tab('AlignedCut'):
|
2171 |
|
2172 |
with gr.Row():
|
|
|
3007 |
# sampling_method_dropdown = gr.Dropdown(["QuickFPS", "random"], label="NCUT: Sampling method", value="QuickFPS", elem_id="sampling_method", info="Nyström approximation")
|
3008 |
sampling_method_dropdown = gr.Radio(["QuickFPS", "random"], label="NCUT: Sampling method", value="QuickFPS", elem_id="sampling_method")
|
3009 |
# ncut_metric_dropdown = gr.Dropdown(["euclidean", "cosine"], label="NCUT: Distance metric", value="cosine", elem_id="ncut_metric")
|
3010 |
+
ncut_metric_dropdown = gr.Radio(["euclidean", "cosine", "rbf"], label="NCUT: Distance metric", value="cosine", elem_id="ncut_metric")
|
3011 |
ncut_knn_slider = gr.Slider(1, 100, step=1, label="NCUT: KNN", value=10, elem_id="knn_ncut", info="Nyström approximation")
|
3012 |
ncut_indirect_connection = gr.Checkbox(label="indirect_connection", value=False, elem_id="ncut_indirect_connection", info="TODO: Indirect connection is not implemented for directed NCUT", interactive=False)
|
3013 |
ncut_make_orthogonal = gr.Checkbox(label="make_orthogonal", value=False, elem_id="ncut_make_orthogonal", info="Apply post-hoc eigenvectors orthogonalization")
|
|
|
3440 |
outputs=[mask_gallery, crop_gallery])
|
3441 |
|
3442 |
|
3443 |
+
with gr.Tab('PlayGround (test)', visible=False) as test_playground_tab:
|
3444 |
+
eigvecs = gr.State(torch.tensor([]))
|
3445 |
+
with gr.Row():
|
3446 |
+
with gr.Column(scale=5, min_width=200):
|
3447 |
+
gr.Markdown("### Step 1: Load Images and Run NCUT")
|
3448 |
+
input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section(n_example_images=10)
|
3449 |
+
# submit_button.visible = False
|
3450 |
+
num_images_slider.value = 30
|
3451 |
+
[
|
3452 |
+
model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
|
3453 |
+
affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
|
3454 |
+
embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
3455 |
+
perplexity_slider, n_neighbors_slider, min_dist_slider,
|
3456 |
+
sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt
|
3457 |
+
] = make_parameters_section(parameter_dropdown=False)
|
3458 |
+
num_eig_slider.value = 1000
|
3459 |
+
num_eig_slider.visible = False
|
3460 |
+
logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information", autofocus=False, autoscroll=False)
|
3461 |
+
|
3462 |
+
false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
|
3463 |
+
no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
|
3464 |
+
|
3465 |
+
submit_button.click(
|
3466 |
+
partial(run_fn, n_ret=1, only_eigvecs=True),
|
3467 |
+
inputs=[
|
3468 |
+
input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
|
3469 |
+
positive_prompt, negative_prompt,
|
3470 |
+
false_placeholder, no_prompt, no_prompt, no_prompt,
|
3471 |
+
affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
|
3472 |
+
embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
3473 |
+
perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown
|
3474 |
+
],
|
3475 |
+
outputs=[eigvecs, logging_text],
|
3476 |
+
)
|
3477 |
+
|
3478 |
+
with gr.Column(scale=5, min_width=200):
|
3479 |
+
gr.Markdown("### Step 2a: Pick an Image")
|
3480 |
+
from gradio_image_prompter import ImagePrompter
|
3481 |
+
with gr.Row():
|
3482 |
+
image1_slider = gr.Slider(0, 100, step=1, label="Image#1 Index", value=0, elem_id="image1_slider", interactive=True)
|
3483 |
+
load_one_image_button = gr.Button("🔴 Load", elem_id="load_one_image_button", variant='primary')
|
3484 |
+
gr.Markdown("### Step 2b: Draw a Point")
|
3485 |
+
gr.Markdown("""
|
3486 |
+
<h5>
|
3487 |
+
🖱️ Left Click: Foreground </br>
|
3488 |
+
</h5>
|
3489 |
+
""")
|
3490 |
+
prompt_image1 = ImagePrompter(show_label=False, elem_id="prompt_image1", interactive=False)
|
3491 |
+
def update_prompt_image(original_images, index):
|
3492 |
+
images = original_images
|
3493 |
+
if images is None:
|
3494 |
+
return
|
3495 |
+
total_len = len(images)
|
3496 |
+
if total_len == 0:
|
3497 |
+
return
|
3498 |
+
if index >= total_len:
|
3499 |
+
index = total_len - 1
|
3500 |
+
|
3501 |
+
return ImagePrompter(value={'image': images[index][0], 'points': []}, interactive=True)
|
3502 |
+
# return gr.Image(value=images[index][0], elem_id=f"prompt_image{randint}", interactive=True)
|
3503 |
+
load_one_image_button.click(update_prompt_image, inputs=[input_gallery, image1_slider], outputs=[prompt_image1])
|
3504 |
+
|
3505 |
+
child_idx = gr.State([])
|
3506 |
+
current_idx = gr.State(None)
|
3507 |
+
n_eig = gr.State(64)
|
3508 |
+
with gr.Column(scale=5, min_width=200):
|
3509 |
+
gr.Markdown("### Step 3: Check groupping")
|
3510 |
+
child_distance_slider = gr.Slider(0, 0.5, step=0.001, label="Child Distance", value=0.1, elem_id="child_distance_slider", interactive=True)
|
3511 |
+
overlay_image_checkbox = gr.Checkbox(label="Overlay Image", value=True, elem_id="overlay_image_checkbox", interactive=True)
|
3512 |
+
run_button = gr.Button("🔴 RUN", elem_id="run_groupping", variant='primary')
|
3513 |
+
parent_plot = gr.Gallery(value=None, label="Parent", show_label=True, elem_id="parent_plot", interactive=False, rows=[1], columns=[2])
|
3514 |
+
parent_button = gr.Button("Use Parent", elem_id="run_parent")
|
3515 |
+
current_plot = gr.Gallery(value=None, label="Current", show_label=True, elem_id="current_plot", interactive=False, rows=[1], columns=[2])
|
3516 |
+
with gr.Column(scale=5, min_width=200):
|
3517 |
+
child_plots = []
|
3518 |
+
child_buttons = []
|
3519 |
+
for i in range(4):
|
3520 |
+
child_plots.append(gr.Gallery(value=None, label=f"Child {i}", show_label=True, elem_id=f"child_plot_{i}", interactive=False, rows=[1], columns=[2]))
|
3521 |
+
child_buttons.append(gr.Button(f"Use Child {i}", elem_id=f"run_child_{i}"))
|
3522 |
+
|
3523 |
+
def relative_xy(prompts):
|
3524 |
+
image = prompts['image']
|
3525 |
+
points = np.asarray(prompts['points'])
|
3526 |
+
if points.shape[0] == 0:
|
3527 |
+
return [], []
|
3528 |
+
is_point = points[:, 5] == 4.0
|
3529 |
+
points = points[is_point]
|
3530 |
+
is_positive = points[:, 2] == 1.0
|
3531 |
+
is_negative = points[:, 2] == 0.0
|
3532 |
+
xy = points[:, :2].tolist()
|
3533 |
+
if isinstance(image, str):
|
3534 |
+
image = Image.open(image)
|
3535 |
+
image = np.array(image)
|
3536 |
+
h, w = image.shape[:2]
|
3537 |
+
new_xy = [(x/w, y/h) for x, y in xy]
|
3538 |
+
# print(new_xy)
|
3539 |
+
return new_xy, is_positive
|
3540 |
+
|
3541 |
+
def xy_eigvec(prompts, image_idx, eigvecs):
|
3542 |
+
eigvec = eigvecs[image_idx]
|
3543 |
+
xy, is_positive = relative_xy(prompts)
|
3544 |
+
for i, (x, y) in enumerate(xy):
|
3545 |
+
if not is_positive[i]:
|
3546 |
+
continue
|
3547 |
+
x = int(x * eigvec.shape[1])
|
3548 |
+
y = int(y * eigvec.shape[0])
|
3549 |
+
return eigvec[y, x], (y, x)
|
3550 |
+
|
3551 |
+
from ncut_pytorch.ncut_pytorch import _transform_heatmap
|
3552 |
+
def _run_heatmap_fn(images, eigvecs, prompt_image_idx, prompt_points, n_eig, flat_idx=None, raw_heatmap=False, overlay_image=True):
|
3553 |
+
left = eigvecs[..., :n_eig]
|
3554 |
+
if flat_idx is not None:
|
3555 |
+
right = eigvecs.reshape(-1, eigvecs.shape[-1])[flat_idx]
|
3556 |
+
y, x = None, None
|
3557 |
+
else:
|
3558 |
+
right, (y, x) = xy_eigvec(prompt_points, prompt_image_idx, eigvecs)
|
3559 |
+
right = right[:n_eig]
|
3560 |
+
left = F.normalize(left, p=2, dim=1)
|
3561 |
+
_right = F.normalize(right, p=2, dim=0)
|
3562 |
+
heatmap = left @ _right.unsqueeze(-1)
|
3563 |
+
heatmap = heatmap.squeeze(-1)
|
3564 |
+
heatmap = 1 - heatmap
|
3565 |
+
heatmap = _transform_heatmap(heatmap)
|
3566 |
+
if raw_heatmap:
|
3567 |
+
return heatmap
|
3568 |
+
# apply hot colormap and covert to PIL image 256x256
|
3569 |
+
heatmap = heatmap.cpu().numpy()
|
3570 |
+
hot_map = matplotlib.cm.get_cmap('hot')
|
3571 |
+
heatmap = hot_map(heatmap)
|
3572 |
+
pil_images = to_pil_images(torch.tensor(heatmap), target_size=256, force_size=True)
|
3573 |
+
if overlay_image:
|
3574 |
+
overlaied_images = []
|
3575 |
+
for i_image in range(len(images)):
|
3576 |
+
rgb_image = images[i_image].resize((256, 256))
|
3577 |
+
rgb_image = np.array(rgb_image)
|
3578 |
+
heatmap_image = np.array(pil_images[i_image])[..., :3]
|
3579 |
+
blend_image = 0.5 * rgb_image + 0.5 * heatmap_image
|
3580 |
+
blend_image = Image.fromarray(blend_image.astype(np.uint8))
|
3581 |
+
overlaied_images.append(blend_image)
|
3582 |
+
pil_images = overlaied_images
|
3583 |
+
return pil_images, (y, x)
|
3584 |
+
|
3585 |
+
def farthest_point_sampling(
|
3586 |
+
features,
|
3587 |
+
start_feature,
|
3588 |
+
num_sample=300,
|
3589 |
+
h=9,
|
3590 |
+
):
|
3591 |
+
import fpsample
|
3592 |
+
|
3593 |
+
h = min(h, int(np.log2(features.shape[0])))
|
3594 |
+
|
3595 |
+
inp = features.cpu().numpy()
|
3596 |
+
inp = np.concatenate([inp, start_feature[None, :]], axis=0)
|
3597 |
+
|
3598 |
+
kdline_fps_samples_idx = fpsample.bucket_fps_kdline_sampling(
|
3599 |
+
inp, num_sample, h, start_idx=inp.shape[0] - 1
|
3600 |
+
).astype(np.int64)
|
3601 |
+
return kdline_fps_samples_idx
|
3602 |
+
|
3603 |
+
@torch.no_grad()
|
3604 |
+
def run_heatmap(images, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, flat_idx=None, overlay_image=True):
|
3605 |
+
gr.Info(f"current number of eigenvectors: {n_eig}")
|
3606 |
+
images = [image[0] for image in images]
|
3607 |
+
if isinstance(images[0], str):
|
3608 |
+
images = [Image.open(image[0]).convert("RGB").resize((256, 256)) for image in images]
|
3609 |
+
|
3610 |
+
current_heatmap, (y, x) = _run_heatmap_fn(images, eigvecs, image1_slider, prompt_image1, n_eig, flat_idx, overlay_image=overlay_image)
|
3611 |
+
parent_heatmap, _ = _run_heatmap_fn(images, eigvecs, image1_slider, prompt_image1, int(n_eig/2), flat_idx, overlay_image=overlay_image)
|
3612 |
+
|
3613 |
+
# find childs
|
3614 |
+
# pca_eigvecs
|
3615 |
+
_eigvecs = eigvecs.reshape(-1, eigvecs.shape[-1])
|
3616 |
+
u, s, v = torch.pca_lowrank(_eigvecs, q=8)
|
3617 |
+
_n = _eigvecs.shape[0]
|
3618 |
+
s /= math.sqrt(_n)
|
3619 |
+
_eigvecs = u @ torch.diag(s)
|
3620 |
+
|
3621 |
+
if flat_idx is None:
|
3622 |
+
_picked_eigvec = _eigvecs.reshape(*eigvecs.shape[:-1], 8)[image1_slider, y, x]
|
3623 |
+
else:
|
3624 |
+
_picked_eigvec = _eigvecs[flat_idx]
|
3625 |
+
l2_distance = torch.norm(_eigvecs - _picked_eigvec, dim=-1)
|
3626 |
+
average_distance = l2_distance.mean()
|
3627 |
+
distance_threshold = distance_slider * average_distance
|
3628 |
+
distance_mask = l2_distance < distance_threshold
|
3629 |
+
masked_eigvecs = _eigvecs[distance_mask]
|
3630 |
+
num_childs = min(4, masked_eigvecs.shape[0])
|
3631 |
+
assert num_childs > 0
|
3632 |
+
|
3633 |
+
child_idx = farthest_point_sampling(masked_eigvecs, _picked_eigvec, num_sample=num_childs+1)
|
3634 |
+
child_idx = np.sort(child_idx)[:-1]
|
3635 |
+
|
3636 |
+
# convert child_idx to flat_idx
|
3637 |
+
dummy_idx = torch.zeros(_eigvecs.shape[0], dtype=torch.bool)
|
3638 |
+
dummy_idx2 = torch.zeros(int(distance_mask.sum().item()), dtype=torch.bool)
|
3639 |
+
dummy_idx2[child_idx] = True
|
3640 |
+
dummy_idx[distance_mask] = dummy_idx2
|
3641 |
+
child_idx = torch.where(dummy_idx)[0]
|
3642 |
+
|
3643 |
+
|
3644 |
+
# current_child heatmap, for contrast
|
3645 |
+
current_child_heatmap = _run_heatmap_fn(images,eigvecs, image1_slider, prompt_image1, int(n_eig*2), flat_idx, raw_heatmap=True, overlay_image=overlay_image)
|
3646 |
+
|
3647 |
+
# child_heatmaps, contrast mean of current clicked point
|
3648 |
+
child_heatmaps = []
|
3649 |
+
for idx in child_idx:
|
3650 |
+
child_heatmap = _run_heatmap_fn(images,eigvecs, image1_slider, prompt_image1, int(n_eig*2), idx, raw_heatmap=True, overlay_image=overlay_image)
|
3651 |
+
heatmap = child_heatmap - current_child_heatmap
|
3652 |
+
# convert [-1, 1] to [0, 1]
|
3653 |
+
heatmap = (heatmap + 1) / 2
|
3654 |
+
heatmap = heatmap.cpu().numpy()
|
3655 |
+
cm = matplotlib.cm.get_cmap('bwr')
|
3656 |
+
heatmap = cm(heatmap)
|
3657 |
+
# bwr with contrast
|
3658 |
+
pil_images1 = to_pil_images(torch.tensor(heatmap), resize=256)
|
3659 |
+
# no contrast
|
3660 |
+
pil_images2, _ = _run_heatmap_fn(images, eigvecs, image1_slider, prompt_image1, int(n_eig*2), idx, overlay_image=overlay_image)
|
3661 |
+
|
3662 |
+
# combine contrast and no contrast
|
3663 |
+
pil_images = []
|
3664 |
+
for i in range(len(pil_images1)):
|
3665 |
+
pil_images.append(pil_images2[i])
|
3666 |
+
pil_images.append(pil_images1[i])
|
3667 |
+
|
3668 |
+
|
3669 |
+
child_heatmaps.append(pil_images)
|
3670 |
+
|
3671 |
+
return parent_heatmap, current_heatmap, *child_heatmaps, child_idx.tolist()
|
3672 |
+
|
3673 |
+
# def debug_fn(eigvecs):
|
3674 |
+
# shape = eigvecs.shape
|
3675 |
+
# gr.Info(f"eigvecs shape: {shape}")
|
3676 |
+
|
3677 |
+
# run_button.click(
|
3678 |
+
# debug_fn,
|
3679 |
+
# inputs=[eigvecs],
|
3680 |
+
# outputs=[],
|
3681 |
+
# )
|
3682 |
+
none_placeholder = gr.State(None)
|
3683 |
+
run_button.click(
|
3684 |
+
run_heatmap,
|
3685 |
+
inputs=[input_gallery, eigvecs, image1_slider, prompt_image1, n_eig, child_distance_slider, none_placeholder, overlay_image_checkbox],
|
3686 |
+
outputs=[parent_plot, current_plot, *child_plots, child_idx],
|
3687 |
+
)
|
3688 |
+
|
3689 |
+
def run_paraent(input_gallery, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, current_idx=None, overlay_image=True):
|
3690 |
+
n_eig = int(n_eig/2)
|
3691 |
+
return n_eig, *run_heatmap(input_gallery, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, current_idx, overlay_image)
|
3692 |
+
|
3693 |
+
parent_button.click(
|
3694 |
+
run_paraent,
|
3695 |
+
inputs=[input_gallery, eigvecs, image1_slider, prompt_image1, n_eig, child_distance_slider, current_idx, overlay_image_checkbox],
|
3696 |
+
outputs=[n_eig, parent_plot, current_plot, *child_plots, child_idx],
|
3697 |
+
)
|
3698 |
+
|
3699 |
+
def run_child(input_gallery, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, child_idx=[], overlay_image=True, i_child=0):
|
3700 |
+
n_eig = int(n_eig*2)
|
3701 |
+
flat_idx = child_idx[i_child]
|
3702 |
+
return n_eig, flat_idx, *run_heatmap(input_gallery, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, flat_idx, overlay_image)
|
3703 |
+
|
3704 |
+
for i in range(4):
|
3705 |
+
child_buttons[i].click(
|
3706 |
+
partial(run_child, i_child=i),
|
3707 |
+
inputs=[input_gallery, eigvecs, image1_slider, prompt_image1, n_eig, child_distance_slider, child_idx, overlay_image_checkbox],
|
3708 |
+
outputs=[n_eig, current_idx, parent_plot, current_plot, *child_plots, child_idx],
|
3709 |
+
)
|
3710 |
+
|
3711 |
with gr.Tab('📄About'):
|
3712 |
with gr.Column():
|
3713 |
gr.Markdown("**This demo is for Python package `ncut-pytorch`, please visit the [Documentation](https://ncut-pytorch.readthedocs.io/)**")
|
|
|
3767 |
hidden_button.change(unlock_tabs, n_smiles, tab_recursivecut_advanced)
|
3768 |
hidden_button.change(unlock_tabs, n_smiles, tab_compare_models_advanced)
|
3769 |
hidden_button.change(unlock_tabs, n_smiles, tab_directed_ncut)
|
3770 |
+
hidden_button.change(unlock_tabs, n_smiles, test_playground_tab)
|
3771 |
|
3772 |
# with gr.Row():
|
3773 |
# with gr.Column():
|
|
|
3809 |
# # %%
|
3810 |
|
3811 |
# %%
|
3812 |
+
|
3813 |
+
# %%
|
3814 |
+
|
3815 |
+
# %%
|
3816 |
+
|
3817 |
+
# %%
|
3818 |
+
|
3819 |
+
# %%
|
3820 |
+
|
3821 |
+
# %%
|