Spaces: Running on Zero

Commit: add prompt diffusion
Browse files
- app.py (+47, -20)
- requirements.txt (+1, -1)
app.py
CHANGED
@@ -351,19 +351,19 @@ def ncut_run(
|
|
351 |
return to_pil_images(rgb), logging_str
|
352 |
|
353 |
def _ncut_run(*args, **kwargs):
|
354 |
-
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
-
|
365 |
-
ret = ncut_run(*args, **kwargs)
|
366 |
-
return ret
|
367 |
|
368 |
if USE_HUGGINGFACE_ZEROGPU:
|
369 |
@spaces.GPU(duration=20)
|
@@ -445,12 +445,17 @@ def load_alignedthreemodel():
|
|
445 |
# model = torch.load(save_path)
|
446 |
return model
|
447 |
|
|
|
|
|
|
|
448 |
def run_fn(
|
449 |
images,
|
450 |
model_name="SAM(sam_vit_b)",
|
451 |
layer=-1,
|
452 |
num_eig=100,
|
453 |
node_type="block",
|
|
|
|
|
454 |
affinity_focal_gamma=0.3,
|
455 |
num_sample_ncut=10000,
|
456 |
knn_ncut=10,
|
@@ -503,6 +508,10 @@ def run_fn(
|
|
503 |
if "stable" in model_name.lower() and "diffusion" in model_name.lower():
|
504 |
model.timestep = layer
|
505 |
layer = 1
|
|
|
|
|
|
|
|
|
506 |
|
507 |
kwargs = {
|
508 |
"model_name": model_name,
|
@@ -694,6 +703,10 @@ def make_parameters_section():
|
|
694 |
model_names = sorted(model_names)
|
695 |
model_dropdown = gr.Dropdown(model_names, label="Backbone", value="DiNO(dino_vitb8_448)", elem_id="model_name")
|
696 |
layer_slider = gr.Slider(1, 12, step=1, label="Backbone: Layer index", value=10, elem_id="layer")
|
|
|
|
|
|
|
|
|
697 |
node_type_dropdown = gr.Dropdown(["attn: attention output", "mlp: mlp output", "block: sum of residual"], label="Backbone: Layer type", value="block: sum of residual", elem_id="node_type", info="which feature to take from each layer?")
|
698 |
num_eig_slider = gr.Slider(1, 1000, step=1, label="NCUT: Number of eigenvectors", value=100, elem_id="num_eig", info='increase for more clusters')
|
699 |
|
@@ -716,6 +729,14 @@ def make_parameters_section():
|
|
716 |
gr.Slider(1, value, step=1, label="Backbone: Layer index", value=value, elem_id="layer", visible=True))
|
717 |
model_dropdown.change(fn=change_layer_slider, inputs=model_dropdown, outputs=[layer_slider, node_type_dropdown])
|
718 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
719 |
with gr.Accordion("➡️ Click to expand: more parameters", open=False):
|
720 |
gr.Markdown("<a href='https://ncut-pytorch.readthedocs.io/en/latest/how_to_get_better_segmentation/' target='_blank'>Docs: How to Get Better Segmentation</a>")
|
721 |
affinity_focal_gamma_slider = gr.Slider(0.01, 1, step=0.01, label="NCUT: Affinity focal gamma", value=0.5, elem_id="affinity_focal_gamma", info="decrease for shaper segmentation")
|
@@ -732,7 +753,7 @@ def make_parameters_section():
|
|
732 |
affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
|
733 |
embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
734 |
perplexity_slider, n_neighbors_slider, min_dist_slider,
|
735 |
-
sampling_method_dropdown]
|
736 |
|
737 |
demo = gr.Blocks(
|
738 |
theme=gr.themes.Base(spacing_size='md', text_size='lg', primary_hue='blue', neutral_hue='slate', secondary_hue='pink'),
|
@@ -754,7 +775,7 @@ with demo:
|
|
754 |
affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
|
755 |
embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
756 |
perplexity_slider, n_neighbors_slider, min_dist_slider,
|
757 |
-
sampling_method_dropdown
|
758 |
] = make_parameters_section()
|
759 |
# logging text box
|
760 |
logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")
|
@@ -766,6 +787,7 @@ with demo:
|
|
766 |
run_fn,
|
767 |
inputs=[
|
768 |
input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
|
|
|
769 |
affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
|
770 |
embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
771 |
perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown
|
@@ -809,7 +831,7 @@ with demo:
|
|
809 |
affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
|
810 |
embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
811 |
perplexity_slider, n_neighbors_slider, min_dist_slider,
|
812 |
-
sampling_method_dropdown
|
813 |
] = make_parameters_section()
|
814 |
old_school_ncut_checkbox = gr.Checkbox(label="Old school NCut", value=True, elem_id="old_school_ncut")
|
815 |
invisible_list = [old_school_ncut_checkbox, num_sample_ncut_slider, knn_ncut_slider,
|
@@ -824,6 +846,7 @@ with demo:
|
|
824 |
run_fn,
|
825 |
inputs=[
|
826 |
input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
|
|
|
827 |
affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
|
828 |
embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
829 |
perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown,
|
@@ -870,7 +893,7 @@ with demo:
|
|
870 |
affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
|
871 |
embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
872 |
perplexity_slider, n_neighbors_slider, min_dist_slider,
|
873 |
-
sampling_method_dropdown
|
874 |
] = make_parameters_section()
|
875 |
num_eig_slider.visible = False
|
876 |
affinity_focal_gamma_slider.visible = False
|
@@ -893,6 +916,7 @@ with demo:
|
|
893 |
run_fn,
|
894 |
inputs=[
|
895 |
input_gallery, model_dropdown, layer_slider, l1_num_eig_slider, node_type_dropdown,
|
|
|
896 |
affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
|
897 |
embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
898 |
perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown,
|
@@ -916,7 +940,7 @@ with demo:
|
|
916 |
affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
|
917 |
embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
918 |
perplexity_slider, n_neighbors_slider, min_dist_slider,
|
919 |
-
sampling_method_dropdown
|
920 |
] = make_parameters_section()
|
921 |
num_sample_tsne_slider.value = 1000
|
922 |
perplexity_slider.value = 500
|
@@ -931,6 +955,7 @@ with demo:
|
|
931 |
run_fn,
|
932 |
inputs=[
|
933 |
video_input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
|
|
|
934 |
affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
|
935 |
embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
936 |
perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown,
|
@@ -971,7 +996,7 @@ with demo:
|
|
971 |
affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
|
972 |
embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
973 |
perplexity_slider, n_neighbors_slider, min_dist_slider,
|
974 |
-
sampling_method_dropdown
|
975 |
] = make_parameters_section()
|
976 |
model_dropdown.value = "AlignedThreeModelAttnNodes"
|
977 |
model_dropdown.visible = False
|
@@ -995,6 +1020,7 @@ with demo:
|
|
995 |
run_fn,
|
996 |
inputs=[
|
997 |
input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
|
|
|
998 |
affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
|
999 |
embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
1000 |
perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown
|
@@ -1013,7 +1039,7 @@ with demo:
|
|
1013 |
affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
|
1014 |
embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
1015 |
perplexity_slider, n_neighbors_slider, min_dist_slider,
|
1016 |
-
sampling_method_dropdown
|
1017 |
] = make_parameters_section()
|
1018 |
# logging text box
|
1019 |
logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")
|
@@ -1021,6 +1047,7 @@ with demo:
|
|
1021 |
run_fn,
|
1022 |
inputs=[
|
1023 |
input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
|
|
|
1024 |
affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
|
1025 |
embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
1026 |
perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown
|
|
|
351 |
return to_pil_images(rgb), logging_str
|
352 |
|
353 |
def _ncut_run(*args, **kwargs):
|
354 |
+
try:
|
355 |
+
ret = ncut_run(*args, **kwargs)
|
356 |
+
if torch.cuda.is_available():
|
357 |
+
torch.cuda.empty_cache()
|
358 |
+
return ret
|
359 |
+
except Exception as e:
|
360 |
+
gr.Error(str(e))
|
361 |
+
if torch.cuda.is_available():
|
362 |
+
torch.cuda.empty_cache()
|
363 |
+
return [], "Error: " + str(e)
|
364 |
+
|
365 |
+
# ret = ncut_run(*args, **kwargs)
|
366 |
+
# return ret
|
367 |
|
368 |
if USE_HUGGINGFACE_ZEROGPU:
|
369 |
@spaces.GPU(duration=20)
|
|
|
445 |
# model = torch.load(save_path)
|
446 |
return model
|
447 |
|
448 |
+
promptable_models = ["Diffusion(stabilityai/stable-diffusion-2)", "Diffusion(CompVis/stable-diffusion-v1-4)"]
|
449 |
+
|
450 |
+
|
451 |
def run_fn(
|
452 |
images,
|
453 |
model_name="SAM(sam_vit_b)",
|
454 |
layer=-1,
|
455 |
num_eig=100,
|
456 |
node_type="block",
|
457 |
+
positive_prompt="",
|
458 |
+
negative_prompt="",
|
459 |
affinity_focal_gamma=0.3,
|
460 |
num_sample_ncut=10000,
|
461 |
knn_ncut=10,
|
|
|
508 |
if "stable" in model_name.lower() and "diffusion" in model_name.lower():
|
509 |
model.timestep = layer
|
510 |
layer = 1
|
511 |
+
|
512 |
+
if model_name in promptable_models:
|
513 |
+
model.positive_prompt = positive_prompt
|
514 |
+
model.negative_prompt = negative_prompt
|
515 |
|
516 |
kwargs = {
|
517 |
"model_name": model_name,
|
|
|
703 |
model_names = sorted(model_names)
|
704 |
model_dropdown = gr.Dropdown(model_names, label="Backbone", value="DiNO(dino_vitb8_448)", elem_id="model_name")
|
705 |
layer_slider = gr.Slider(1, 12, step=1, label="Backbone: Layer index", value=10, elem_id="layer")
|
706 |
+
positive_prompt = gr.Textbox(label="Prompt (Positive)", elem_id="prompt", placeholder="e.g. 'a photo of Gibson Les Pual guitar'")
|
707 |
+
positive_prompt.visible = False
|
708 |
+
negative_prompt = gr.Textbox(label="Prompt (Negative)", elem_id="prompt", placeholder="e.g. 'a photo from egocentric view'")
|
709 |
+
negative_prompt.visible = False
|
710 |
node_type_dropdown = gr.Dropdown(["attn: attention output", "mlp: mlp output", "block: sum of residual"], label="Backbone: Layer type", value="block: sum of residual", elem_id="node_type", info="which feature to take from each layer?")
|
711 |
num_eig_slider = gr.Slider(1, 1000, step=1, label="NCUT: Number of eigenvectors", value=100, elem_id="num_eig", info='increase for more clusters')
|
712 |
|
|
|
729 |
gr.Slider(1, value, step=1, label="Backbone: Layer index", value=value, elem_id="layer", visible=True))
|
730 |
model_dropdown.change(fn=change_layer_slider, inputs=model_dropdown, outputs=[layer_slider, node_type_dropdown])
|
731 |
|
732 |
+
def change_prompt_text(model_name):
|
733 |
+
if model_name in promptable_models:
|
734 |
+
return (gr.Textbox(label="Prompt (Positive)", elem_id="prompt", placeholder="e.g. 'a photo of Gibson Les Pual guitar'", visible=True),
|
735 |
+
gr.Textbox(label="Prompt (Negative)", elem_id="prompt", placeholder="e.g. 'a photo from egocentric view'", visible=True))
|
736 |
+
return (gr.Textbox(label="Prompt (Positive)", elem_id="prompt", placeholder="e.g. 'a photo of Gibson Les Pual guitar'", visible=False),
|
737 |
+
gr.Textbox(label="Prompt (Negative)", elem_id="prompt", placeholder="e.g. 'a photo from egocentric view'", visible=False))
|
738 |
+
model_dropdown.change(fn=change_prompt_text, inputs=model_dropdown, outputs=[positive_prompt, negative_prompt])
|
739 |
+
|
740 |
with gr.Accordion("➡️ Click to expand: more parameters", open=False):
|
741 |
gr.Markdown("<a href='https://ncut-pytorch.readthedocs.io/en/latest/how_to_get_better_segmentation/' target='_blank'>Docs: How to Get Better Segmentation</a>")
|
742 |
affinity_focal_gamma_slider = gr.Slider(0.01, 1, step=0.01, label="NCUT: Affinity focal gamma", value=0.5, elem_id="affinity_focal_gamma", info="decrease for shaper segmentation")
|
|
|
753 |
affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
|
754 |
embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
755 |
perplexity_slider, n_neighbors_slider, min_dist_slider,
|
756 |
+
sampling_method_dropdown, positive_prompt, negative_prompt]
|
757 |
|
758 |
demo = gr.Blocks(
|
759 |
theme=gr.themes.Base(spacing_size='md', text_size='lg', primary_hue='blue', neutral_hue='slate', secondary_hue='pink'),
|
|
|
775 |
affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
|
776 |
embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
777 |
perplexity_slider, n_neighbors_slider, min_dist_slider,
|
778 |
+
sampling_method_dropdown, positive_prompt, negative_prompt
|
779 |
] = make_parameters_section()
|
780 |
# logging text box
|
781 |
logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")
|
|
|
787 |
run_fn,
|
788 |
inputs=[
|
789 |
input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
|
790 |
+
positive_prompt, negative_prompt,
|
791 |
affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
|
792 |
embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
793 |
perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown
|
|
|
831 |
affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
|
832 |
embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
833 |
perplexity_slider, n_neighbors_slider, min_dist_slider,
|
834 |
+
sampling_method_dropdown, positive_prompt, negative_prompt
|
835 |
] = make_parameters_section()
|
836 |
old_school_ncut_checkbox = gr.Checkbox(label="Old school NCut", value=True, elem_id="old_school_ncut")
|
837 |
invisible_list = [old_school_ncut_checkbox, num_sample_ncut_slider, knn_ncut_slider,
|
|
|
846 |
run_fn,
|
847 |
inputs=[
|
848 |
input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
|
849 |
+
positive_prompt, negative_prompt,
|
850 |
affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
|
851 |
embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
852 |
perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown,
|
|
|
893 |
affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
|
894 |
embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
895 |
perplexity_slider, n_neighbors_slider, min_dist_slider,
|
896 |
+
sampling_method_dropdown, positive_prompt, negative_prompt
|
897 |
] = make_parameters_section()
|
898 |
num_eig_slider.visible = False
|
899 |
affinity_focal_gamma_slider.visible = False
|
|
|
916 |
run_fn,
|
917 |
inputs=[
|
918 |
input_gallery, model_dropdown, layer_slider, l1_num_eig_slider, node_type_dropdown,
|
919 |
+
positive_prompt, negative_prompt,
|
920 |
affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
|
921 |
embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
922 |
perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown,
|
|
|
940 |
affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
|
941 |
embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
942 |
perplexity_slider, n_neighbors_slider, min_dist_slider,
|
943 |
+
sampling_method_dropdown, positive_prompt, negative_prompt
|
944 |
] = make_parameters_section()
|
945 |
num_sample_tsne_slider.value = 1000
|
946 |
perplexity_slider.value = 500
|
|
|
955 |
run_fn,
|
956 |
inputs=[
|
957 |
video_input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
|
958 |
+
positive_prompt, negative_prompt,
|
959 |
affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
|
960 |
embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
961 |
perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown,
|
|
|
996 |
affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
|
997 |
embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
998 |
perplexity_slider, n_neighbors_slider, min_dist_slider,
|
999 |
+
sampling_method_dropdown, positive_prompt, negative_prompt
|
1000 |
] = make_parameters_section()
|
1001 |
model_dropdown.value = "AlignedThreeModelAttnNodes"
|
1002 |
model_dropdown.visible = False
|
|
|
1020 |
run_fn,
|
1021 |
inputs=[
|
1022 |
input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
|
1023 |
+
positive_prompt, negative_prompt,
|
1024 |
affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
|
1025 |
embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
1026 |
perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown
|
|
|
1039 |
affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
|
1040 |
embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
1041 |
perplexity_slider, n_neighbors_slider, min_dist_slider,
|
1042 |
+
sampling_method_dropdown, positive_prompt, negative_prompt
|
1043 |
] = make_parameters_section()
|
1044 |
# logging text box
|
1045 |
logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")
|
|
|
1047 |
run_fn,
|
1048 |
inputs=[
|
1049 |
input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
|
1050 |
+
positive_prompt, negative_prompt,
|
1051 |
affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
|
1052 |
embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
1053 |
perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown
|
requirements.txt
CHANGED
@@ -14,4 +14,4 @@ segment-anything @ git+https://github.com/facebookresearch/segment-anything.git@
|
|
14 |
mobile-sam @ git+https://github.com/ChaoningZhang/MobileSAM.git@c12dd83
|
15 |
timm
|
16 |
open-clip-torch==2.20.0
|
17 |
-
ncut-pytorch>=1.3.
|
|
|
14 |
mobile-sam @ git+https://github.com/ChaoningZhang/MobileSAM.git@c12dd83
|
15 |
timm
|
16 |
open-clip-torch==2.20.0
|
17 |
+
ncut-pytorch>=1.3.10
|