add LISA
- app.py +147 -15
- requirements.txt +3 -2
app.py
CHANGED
@@ -1,5 +1,6 @@
 # Author: Huzheng Yang
 # %%
+import copy
 import os
 USE_HUGGINGFACE_ZEROGPU = os.getenv("USE_HUGGINGFACE_ZEROGPU", "False").lower() in ["true", "1", "yes"]
 DOWNLOAD_ALL_MODELS_DATASETS = os.getenv("DOWNLOAD_ALL_MODELS_DATASETS", "False").lower() in ["true", "1", "yes"]
@@ -232,6 +233,10 @@ def ncut_run(
     recursion_l2_gamma=0.5,
     recursion_l3_gamma=0.5,
     video_output=False,
+    is_lisa=False,
+    lisa_prompt1="",
+    lisa_prompt2="",
+    lisa_prompt3="",
 ):
     logging_str = ""
     if "AlignedThreeModelAttnNodes" == model_name:
@@ -256,6 +261,24 @@ def ncut_run(
     if "AlignedThreeModelAttnNodes" == model_name:
         # dirty patch for the alignedcut paper
         features = run_alignedthreemodelattnnodes(images, model, batch_size=BATCH_SIZE)
+    elif is_lisa == True:
+        # dirty patch for the LISA model
+        features = []
+        with torch.no_grad():
+            model = model.cuda()
+            images = images.cuda()
+            lisa_prompts = [lisa_prompt1, lisa_prompt2, lisa_prompt3]
+            for prompt in lisa_prompts:
+                import bleach
+                prompt = bleach.clean(prompt)
+                prompt = prompt.strip()
+                # print(prompt)
+                # # copy the sting to a new string
+                # copy_s = copy.copy(prompt)
+                feature = model(images, input_str=prompt)[node_type][0]
+                feature = F.normalize(feature, dim=-1)
+                features.append(feature.cpu().float())
+            features = torch.stack(features)
     else:
         features = extract_features(
             images, model, node_type=node_type, layer=layer-1, batch_size=BATCH_SIZE
@@ -340,6 +363,14 @@ def ncut_run(
             galleries.append(to_pil_images(_rgb, target_size=56))
         return *galleries, logging_str
 
+    if is_lisa == True:
+        # dirty patch for the LISA model
+        galleries = []
+        for i_prompt in range(len(lisa_prompts)):
+            _rgb = rgb[i_prompt]
+            galleries.append(to_pil_images(_rgb, target_size=256))
+        return *galleries, logging_str
+
     rgb = dont_use_too_much_green(rgb)
 
 
@@ -451,7 +482,8 @@ def load_alignedthreemodel():
     # model = torch.load(save_path)
     return model
 
-
+promptable_diffusion_models = ["Diffusion(stabilityai/stable-diffusion-2)", "Diffusion(CompVis/stable-diffusion-v1-4)"]
+promptable_segmentation_models = ["LISA(xinlai/LISA-7B-v1)"]
 
 
 def run_fn(
@@ -462,6 +494,10 @@ def run_fn(
     node_type="block",
     positive_prompt="",
     negative_prompt="",
+    is_lisa=False,
+    lisa_prompt1="",
+    lisa_prompt2="",
+    lisa_prompt3="",
     affinity_focal_gamma=0.3,
     num_sample_ncut=10000,
     knn_ncut=10,
@@ -515,10 +551,10 @@ def run_fn(
         model.timestep = layer
         layer = 1
 
-    if model_name in
+    if model_name in promptable_diffusion_models:
         model.positive_prompt = positive_prompt
         model.negative_prompt = negative_prompt
-
+
     kwargs = {
         "model_name": model_name,
         "layer": layer,
@@ -543,11 +579,18 @@ def run_fn(
         "recursion_l2_gamma": recursion_l2_gamma,
         "recursion_l3_gamma": recursion_l3_gamma,
         "video_output": video_output,
+        "lisa_prompt1": lisa_prompt1,
+        "lisa_prompt2": lisa_prompt2,
+        "lisa_prompt3": lisa_prompt3,
+        "is_lisa": is_lisa,
     }
     # print(kwargs)
 
     if old_school_ncut:
-        super_duper_long_run(model, images, **kwargs)
+        return super_duper_long_run(model, images, **kwargs)
+
+    if is_lisa:
+        return super_duper_long_run(model, images, **kwargs)
 
     num_images = len(images)
     if num_images >= 100:
@@ -702,18 +745,26 @@ def make_output_images_section():
     output_gallery = gr.Gallery(value=[], label="NCUT Embedding", show_label=False, elem_id="ncut", columns=[3], rows=[1], object_fit="contain", height="auto")
     return output_gallery
 
-def make_parameters_section():
+def make_parameters_section(is_lisa=False):
     gr.Markdown("### Parameters <a style='color: #0044CC;' href='https://ncut-pytorch.readthedocs.io/en/latest/how_to_get_better_segmentation/' target='_blank'>Help</a>")
     from ncut_pytorch.backbone import list_models, get_demo_model_names
     model_names = list_models()
     model_names = sorted(model_names)
-
-
-
-
-
-
-
+    if is_lisa:
+        model_dropdown = gr.Dropdown(["LISA(xinlai/LISA-7B-v1)"], label="Backbone", value="LISA(xinlai/LISA-7B-v1)", elem_id="model_name")
+        layer_slider = gr.Slider(1, 6, step=1, label="LISA decoder: Layer index", value=6, elem_id="layer", visible=False)
+        layer_names = ["dec_0_input", "dec_0_attn", "dec_0_block", "dec_1_input", "dec_1_attn", "dec_1_block"]
+        positive_prompt = gr.Textbox(label="Prompt (Positive)", elem_id="prompt", placeholder="e.g. 'a photo of Gibson Les Pual guitar'", visible=False)
+        negative_prompt = gr.Textbox(label="Prompt (Negative)", elem_id="prompt", placeholder="e.g. 'a photo from egocentric view'", visible=False)
+        node_type_dropdown = gr.Dropdown(layer_names, label="LISA (SAM) decoder: Layer and Node", value="dec_1_block", elem_id="node_type")
+    else:
+        model_dropdown = gr.Dropdown(model_names, label="Backbone", value="DiNO(dino_vitb8_448)", elem_id="model_name")
+        layer_slider = gr.Slider(1, 12, step=1, label="Backbone: Layer index", value=10, elem_id="layer")
+        positive_prompt = gr.Textbox(label="Prompt (Positive)", elem_id="prompt", placeholder="e.g. 'a photo of Gibson Les Pual guitar'")
+        positive_prompt.visible = False
+        negative_prompt = gr.Textbox(label="Prompt (Negative)", elem_id="prompt", placeholder="e.g. 'a photo from egocentric view'")
+        negative_prompt.visible = False
+        node_type_dropdown = gr.Dropdown(["attn: attention output", "mlp: mlp output", "block: sum of residual"], label="Backbone: Layer type", value="block: sum of residual", elem_id="node_type", info="which feature to take from each layer?")
     num_eig_slider = gr.Slider(1, 1000, step=1, label="NCUT: Number of eigenvectors", value=100, elem_id="num_eig", info='increase for more clusters')
 
     def change_layer_slider(model_name):
@@ -724,6 +775,12 @@ def make_parameters_section():
             return (gr.Slider(1, 49, step=1, label="Diffusion: Timestep (Noise)", value=5, elem_id="layer", visible=True, info="Noise level, 50 is max noise"),
                     gr.Dropdown(SD_KEY_DICT[model_name], label="Diffusion: Layer and Node", value=default_layer, elem_id="node_type", info="U-Net (v1, v2) or DiT (v3)"))
 
+        if model_name == "LISSL(xinlai/LISSL-7B-v1)":
+            layer_names = ["dec_0_input", "dec_0_attn", "dec_0_block", "dec_1_input", "dec_1_attn", "dec_1_block"]
+            default_layer = "dec_1_block"
+            return (gr.Slider(1, 6, step=1, label="LISA decoder: Layer index", value=6, elem_id="layer", visible=False),
+                    gr.Dropdown(layer_names, label="LISA decoder: Layer and Node", value=default_layer, elem_id="node_type"))
+
         layer_dict = LAYER_DICT
         if model_name in layer_dict:
             value = layer_dict[model_name]
@@ -736,7 +793,7 @@ def make_parameters_section():
     model_dropdown.change(fn=change_layer_slider, inputs=model_dropdown, outputs=[layer_slider, node_type_dropdown])
 
     def change_prompt_text(model_name):
-        if model_name in
+        if model_name in promptable_diffusion_models:
             return (gr.Textbox(label="Prompt (Positive)", elem_id="prompt", placeholder="e.g. 'a photo of Gibson Les Pual guitar'", visible=True),
                     gr.Textbox(label="Prompt (Negative)", elem_id="prompt", placeholder="e.g. 'a photo from egocentric view'", visible=True))
         return (gr.Textbox(label="Prompt (Positive)", elem_id="prompt", placeholder="e.g. 'a photo of Gibson Les Pual guitar'", visible=False),
@@ -788,12 +845,15 @@ with demo:
 
             clear_images_button.click(lambda x: ([], []), outputs=[input_gallery, output_gallery])
 
+            false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
+            no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
 
             submit_button.click(
                 run_fn,
                 inputs=[
                     input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
                     positive_prompt, negative_prompt,
+                    false_placeholder, no_prompt, no_prompt, no_prompt,
                     affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
                     embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
                     perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown
@@ -848,11 +908,15 @@ with demo:
             logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")
 
             clear_images_button.click(lambda x: ([], []), outputs=[input_gallery, output_gallery])
+            false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
+            no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
+
             submit_button.click(
                 run_fn,
                 inputs=[
                     input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
                     positive_prompt, negative_prompt,
+                    false_placeholder, no_prompt, no_prompt, no_prompt,
                     affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
                     embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
                     perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown,
@@ -918,11 +982,15 @@ with demo:
             number_placeholder = gr.Number(0, label="Number placeholder", elem_id="number_placeholder")
             number_placeholder.visible = False
             clear_images_button.click(lambda x: ([],), outputs=[input_gallery])
+            false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
+            no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
+
             submit_button.click(
                 run_fn,
                 inputs=[
-                    input_gallery, model_dropdown, layer_slider,
+                    input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
                     positive_prompt, negative_prompt,
+                    false_placeholder, no_prompt, no_prompt, no_prompt,
                     affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
                     embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
                     perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown,
@@ -957,11 +1025,15 @@ with demo:
             clear_video_button.click(lambda x: (None, None), outputs=[video_input_gallery, video_output_gallery])
             place_holder_false = gr.Checkbox(label="Place holder", value=False, elem_id="place_holder_false")
             place_holder_false.visible = False
+            false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
+            no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
+
             submit_button.click(
                 run_fn,
                 inputs=[
-
+                    input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
                     positive_prompt, negative_prompt,
+                    false_placeholder, no_prompt, no_prompt, no_prompt,
                     affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
                     embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
                     perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown,
@@ -979,6 +1051,57 @@ with demo:
         from draft_gradio_app_text import make_demo
         make_demo()
 
+    with gr.Tab('Vision-Language'):
+        gr.Markdown('[LISA]((https://arxiv.org/pdf/2308.00692)) is a vision-language model. Input a text prompt and image, LISA generate segmentation masks.')
+        gr.Markdown('In the mask decoder layers, LISA updates the image features w.r.t. the text prompt')
+        gr.Markdown('This page aims to see how the text prompt affects the image features')
+        gr.Markdown('---')
+        gr.Markdown('<p style="text-align: center;">Color is <b>aligned</b> across 3 prompts. NCUT is computed on the concatenated features from 3 prompts.</p>')
+        with gr.Row():
+            with gr.Column(scale=5, min_width=200):
+                gr.Markdown('### Output (Prompt #1)')
+                l1_gallery = gr.Gallery(value=[], label="Prompt #1", show_label=False, elem_id="ncut_p1", columns=[3], rows=[5], object_fit="contain", height="auto")
+                prompt1 = gr.Textbox(label="Input Prompt #1", elem_id="prompt1", value="where is the person, include the clothes, don't include the guitar and chair", lines=3)
+            with gr.Column(scale=5, min_width=200):
+                gr.Markdown('### Output (Prompt #2)')
+                l2_gallery = gr.Gallery(value=[], label="Prompt #2", show_label=False, elem_id="ncut_p2", columns=[3], rows=[5], object_fit="contain", height="auto")
+                prompt2 = gr.Textbox(label="Input Prompt #2", elem_id="prompt2", value="where is the Gibson Les Pual guitar", lines=3)
+            with gr.Column(scale=5, min_width=200):
+                gr.Markdown('### Output (Prompt #3)')
+                l3_gallery = gr.Gallery(value=[], label="Prompt #3", show_label=False, elem_id="ncut_p3", columns=[3], rows=[5], object_fit="contain", height="auto")
+                prompt3 = gr.Textbox(label="Input Prompt #3", elem_id="prompt3", value="where is the floor", lines=3)
+
+        with gr.Row():
+            with gr.Column(scale=5, min_width=200):
+                input_gallery, submit_button, clear_images_button = make_input_images_section()
+                dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_dataset_images_section(advanced=False)
+                clear_images_button.click(lambda x: ([], [], [], []), outputs=[input_gallery, l1_gallery, l2_gallery, l3_gallery])
+
+            with gr.Column(scale=5, min_width=200):
+                [
+                    model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
+                    affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
+                    embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
+                    perplexity_slider, n_neighbors_slider, min_dist_slider,
+                    sampling_method_dropdown, positive_prompt, negative_prompt
+                ] = make_parameters_section(is_lisa=True)
+                logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")
+
+        galleries = [l1_gallery, l2_gallery, l3_gallery]
+        true_placeholder = gr.Checkbox(label="True placeholder", value=True, elem_id="true_placeholder", visible=False)
+        submit_button.click(
+            run_fn,
+            inputs=[
+                input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
+                positive_prompt, negative_prompt,
+                true_placeholder, prompt1, prompt2, prompt3,
+                affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
+                embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
+                perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown
+            ],
+            outputs=galleries + [logging_text],
+        )
+
     with gr.Tab('Model Aligned'):
         gr.Markdown('This page reproduce the results from the paper [AlignedCut](https://arxiv.org/abs/2406.18344)')
         gr.Markdown('---')
@@ -1022,11 +1145,16 @@ with demo:
 
 
             clear_images_button.click(lambda x: [] * (len(galleries) + 1), outputs=[input_gallery] + galleries)
+
+            false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
+            no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
+
            submit_button.click(
                run_fn,
                inputs=[
                    input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
                    positive_prompt, negative_prompt,
+                    false_placeholder, no_prompt, no_prompt, no_prompt,
                    affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
                    embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
                    perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown
@@ -1049,11 +1177,15 @@ with demo:
            ] = make_parameters_section()
            # logging text box
            logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")
+            false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
+            no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
+
            submit_button.click(
                run_fn,
                inputs=[
                    input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
                    positive_prompt, negative_prompt,
+                    false_placeholder, no_prompt, no_prompt, no_prompt,
                    affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
                    embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
                    perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown
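
Note on the placeholder components introduced above: Gradio's click(fn, inputs=[...]) binds the listed components to the target function's parameters by position, and this commit inserts four new parameters (is_lisa, lisa_prompt1, lisa_prompt2, lisa_prompt3) into run_fn directly after negative_prompt. Every pre-existing tab therefore fills those four slots with hidden stand-ins (false_placeholder, no_prompt, no_prompt, no_prompt), while the new Vision-Language tab passes true_placeholder, prompt1, prompt2, prompt3. A minimal, self-contained sketch of that contract follows; apart from the placeholder components, the names and the shortened signature are illustrative and not taken from app.py:

# Sketch only: shows the positional inputs -> parameters mapping that motivates
# the hidden placeholder components; it is not part of this commit.
import gradio as gr

def run_fn_sketch(images, positive_prompt, negative_prompt,
                  is_lisa=False, lisa_prompt1="", lisa_prompt2="", lisa_prompt3="",
                  affinity_focal_gamma=0.3):
    # Gradio fills these parameters from inputs=[...] strictly by position,
    # so every click handler must supply a value for the four new slots.
    return f"is_lisa={is_lisa}, prompts={[lisa_prompt1, lisa_prompt2, lisa_prompt3]}"

with gr.Blocks() as sketch_demo:
    images = gr.Gallery(value=[], label="Input")
    positive_prompt = gr.Textbox(visible=False)
    negative_prompt = gr.Textbox(visible=False)
    # hidden stand-ins, as in the commit, so old tabs keep their argument lists aligned
    false_placeholder = gr.Checkbox(label="False", value=False, visible=False)
    no_prompt = gr.Textbox("", visible=False)
    gamma = gr.Slider(0.01, 1.0, value=0.3, label="affinity_focal_gamma")
    out = gr.Textbox(label="debug")
    gr.Button("Run").click(
        run_fn_sketch,
        inputs=[images, positive_prompt, negative_prompt,
                false_placeholder, no_prompt, no_prompt, no_prompt,  # the four new slots
                gamma],
        outputs=out,
    )

Clicking Run in this sketch reports is_lisa=False with three empty prompts, which is what the old tabs send; the Vision-Language tab flips the checkbox slot to a hidden True checkbox and wires its three prompt textboxes into the next three slots.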
requirements.txt
CHANGED
@@ -2,7 +2,7 @@ torch
 torchvision
 opencv-python
 decord
-transformers
+transformers==4.31.0
 datasets
 diffusers
 accelerate
@@ -12,6 +12,7 @@ pillow==9.4.0
 SAM-2 @ git+https://github.com/huzeyann/segment-anything-2.git
 segment-anything @ git+https://github.com/facebookresearch/segment-anything.git@6fdee8f
 mobile-sam @ git+https://github.com/ChaoningZhang/MobileSAM.git@c12dd83
+lisa @ git+https://github.com/huzeyann/LISA.git
 timm
 open-clip-torch==2.20.0
-ncut-pytorch>=1.3.
+ncut-pytorch>=1.3.13
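
requirements.txt now pins transformers to 4.31.0 (presumably to match what the LISA dependency expects) and pulls lisa from a git fork. A small pre-launch check like the following, offered here as a suggestion rather than part of the commit, makes a mismatched environment visible before the Space builds its UI:

# Suggested sanity check (not part of this commit): confirm the pinned and
# newly added dependencies from requirements.txt resolved as expected.
from importlib.metadata import PackageNotFoundError, version

def check(package, expected=None):
    try:
        installed = version(package)
    except PackageNotFoundError:
        print(f"{package}: not installed")
        return
    if expected and installed != expected:
        print(f"{package}: {installed} (requirements.txt pins {expected})")
    else:
        print(f"{package}: {installed} ok")

check("transformers", "4.31.0")  # pinned in this commit
check("lisa")                    # distribution name assumed from the 'lisa @ git+...' requirement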