huzey committed on
Commit
560d63b
1 Parent(s): b83f3a8

add prompt diffusion

Browse files
Files changed (2) hide show
  1. app.py +47 -20
  2. requirements.txt +1 -1
app.py CHANGED
@@ -351,19 +351,19 @@ def ncut_run(
351
  return to_pil_images(rgb), logging_str
352
 
353
  def _ncut_run(*args, **kwargs):
354
- # try:
355
- # ret = ncut_run(*args, **kwargs)
356
- # if torch.cuda.is_available():
357
- # torch.cuda.empty_cache()
358
- # return ret
359
- # except Exception as e:
360
- # gr.Error(str(e))
361
- # if torch.cuda.is_available():
362
- # torch.cuda.empty_cache()
363
- # return [], "Error: " + str(e)
364
-
365
- ret = ncut_run(*args, **kwargs)
366
- return ret
367
 
368
  if USE_HUGGINGFACE_ZEROGPU:
369
  @spaces.GPU(duration=20)
@@ -445,12 +445,17 @@ def load_alignedthreemodel():
445
  # model = torch.load(save_path)
446
  return model
447
 
 
 
 
448
  def run_fn(
449
  images,
450
  model_name="SAM(sam_vit_b)",
451
  layer=-1,
452
  num_eig=100,
453
  node_type="block",
 
 
454
  affinity_focal_gamma=0.3,
455
  num_sample_ncut=10000,
456
  knn_ncut=10,
@@ -503,6 +508,10 @@ def run_fn(
503
  if "stable" in model_name.lower() and "diffusion" in model_name.lower():
504
  model.timestep = layer
505
  layer = 1
 
 
 
 
506
 
507
  kwargs = {
508
  "model_name": model_name,
@@ -694,6 +703,10 @@ def make_parameters_section():
694
  model_names = sorted(model_names)
695
  model_dropdown = gr.Dropdown(model_names, label="Backbone", value="DiNO(dino_vitb8_448)", elem_id="model_name")
696
  layer_slider = gr.Slider(1, 12, step=1, label="Backbone: Layer index", value=10, elem_id="layer")
 
 
 
 
697
  node_type_dropdown = gr.Dropdown(["attn: attention output", "mlp: mlp output", "block: sum of residual"], label="Backbone: Layer type", value="block: sum of residual", elem_id="node_type", info="which feature to take from each layer?")
698
  num_eig_slider = gr.Slider(1, 1000, step=1, label="NCUT: Number of eigenvectors", value=100, elem_id="num_eig", info='increase for more clusters')
699
 
@@ -716,6 +729,14 @@ def make_parameters_section():
716
  gr.Slider(1, value, step=1, label="Backbone: Layer index", value=value, elem_id="layer", visible=True))
717
  model_dropdown.change(fn=change_layer_slider, inputs=model_dropdown, outputs=[layer_slider, node_type_dropdown])
718
 
 
 
 
 
 
 
 
 
719
  with gr.Accordion("➡️ Click to expand: more parameters", open=False):
720
  gr.Markdown("<a href='https://ncut-pytorch.readthedocs.io/en/latest/how_to_get_better_segmentation/' target='_blank'>Docs: How to Get Better Segmentation</a>")
721
  affinity_focal_gamma_slider = gr.Slider(0.01, 1, step=0.01, label="NCUT: Affinity focal gamma", value=0.5, elem_id="affinity_focal_gamma", info="decrease for shaper segmentation")
@@ -732,7 +753,7 @@ def make_parameters_section():
732
  affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
733
  embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
734
  perplexity_slider, n_neighbors_slider, min_dist_slider,
735
- sampling_method_dropdown]
736
 
737
  demo = gr.Blocks(
738
  theme=gr.themes.Base(spacing_size='md', text_size='lg', primary_hue='blue', neutral_hue='slate', secondary_hue='pink'),
@@ -754,7 +775,7 @@ with demo:
754
  affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
755
  embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
756
  perplexity_slider, n_neighbors_slider, min_dist_slider,
757
- sampling_method_dropdown
758
  ] = make_parameters_section()
759
  # logging text box
760
  logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")
@@ -766,6 +787,7 @@ with demo:
766
  run_fn,
767
  inputs=[
768
  input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
 
769
  affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
770
  embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
771
  perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown
@@ -809,7 +831,7 @@ with demo:
809
  affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
810
  embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
811
  perplexity_slider, n_neighbors_slider, min_dist_slider,
812
- sampling_method_dropdown
813
  ] = make_parameters_section()
814
  old_school_ncut_checkbox = gr.Checkbox(label="Old school NCut", value=True, elem_id="old_school_ncut")
815
  invisible_list = [old_school_ncut_checkbox, num_sample_ncut_slider, knn_ncut_slider,
@@ -824,6 +846,7 @@ with demo:
824
  run_fn,
825
  inputs=[
826
  input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
 
827
  affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
828
  embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
829
  perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown,
@@ -870,7 +893,7 @@ with demo:
870
  affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
871
  embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
872
  perplexity_slider, n_neighbors_slider, min_dist_slider,
873
- sampling_method_dropdown
874
  ] = make_parameters_section()
875
  num_eig_slider.visible = False
876
  affinity_focal_gamma_slider.visible = False
@@ -893,6 +916,7 @@ with demo:
893
  run_fn,
894
  inputs=[
895
  input_gallery, model_dropdown, layer_slider, l1_num_eig_slider, node_type_dropdown,
 
896
  affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
897
  embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
898
  perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown,
@@ -916,7 +940,7 @@ with demo:
916
  affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
917
  embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
918
  perplexity_slider, n_neighbors_slider, min_dist_slider,
919
- sampling_method_dropdown
920
  ] = make_parameters_section()
921
  num_sample_tsne_slider.value = 1000
922
  perplexity_slider.value = 500
@@ -931,6 +955,7 @@ with demo:
931
  run_fn,
932
  inputs=[
933
  video_input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
 
934
  affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
935
  embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
936
  perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown,
@@ -971,7 +996,7 @@ with demo:
971
  affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
972
  embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
973
  perplexity_slider, n_neighbors_slider, min_dist_slider,
974
- sampling_method_dropdown
975
  ] = make_parameters_section()
976
  model_dropdown.value = "AlignedThreeModelAttnNodes"
977
  model_dropdown.visible = False
@@ -995,6 +1020,7 @@ with demo:
995
  run_fn,
996
  inputs=[
997
  input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
 
998
  affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
999
  embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
1000
  perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown
@@ -1013,7 +1039,7 @@ with demo:
1013
  affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
1014
  embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
1015
  perplexity_slider, n_neighbors_slider, min_dist_slider,
1016
- sampling_method_dropdown
1017
  ] = make_parameters_section()
1018
  # logging text box
1019
  logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")
@@ -1021,6 +1047,7 @@ with demo:
1021
  run_fn,
1022
  inputs=[
1023
  input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
 
1024
  affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
1025
  embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
1026
  perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown
 
351
  return to_pil_images(rgb), logging_str
352
 
353
  def _ncut_run(*args, **kwargs):
354
+ try:
355
+ ret = ncut_run(*args, **kwargs)
356
+ if torch.cuda.is_available():
357
+ torch.cuda.empty_cache()
358
+ return ret
359
+ except Exception as e:
360
+ gr.Error(str(e))
361
+ if torch.cuda.is_available():
362
+ torch.cuda.empty_cache()
363
+ return [], "Error: " + str(e)
364
+
365
+ # ret = ncut_run(*args, **kwargs)
366
+ # return ret
367
 
368
  if USE_HUGGINGFACE_ZEROGPU:
369
  @spaces.GPU(duration=20)
 
445
  # model = torch.load(save_path)
446
  return model
447
 
448
+ promptable_models = ["Diffusion(stabilityai/stable-diffusion-2)", "Diffusion(CompVis/stable-diffusion-v1-4)"]
449
+
450
+
451
  def run_fn(
452
  images,
453
  model_name="SAM(sam_vit_b)",
454
  layer=-1,
455
  num_eig=100,
456
  node_type="block",
457
+ positive_prompt="",
458
+ negative_prompt="",
459
  affinity_focal_gamma=0.3,
460
  num_sample_ncut=10000,
461
  knn_ncut=10,
 
508
  if "stable" in model_name.lower() and "diffusion" in model_name.lower():
509
  model.timestep = layer
510
  layer = 1
511
+
512
+ if model_name in promptable_models:
513
+ model.positive_prompt = positive_prompt
514
+ model.negative_prompt = negative_prompt
515
 
516
  kwargs = {
517
  "model_name": model_name,
 
703
  model_names = sorted(model_names)
704
  model_dropdown = gr.Dropdown(model_names, label="Backbone", value="DiNO(dino_vitb8_448)", elem_id="model_name")
705
  layer_slider = gr.Slider(1, 12, step=1, label="Backbone: Layer index", value=10, elem_id="layer")
706
+ positive_prompt = gr.Textbox(label="Prompt (Positive)", elem_id="prompt", placeholder="e.g. 'a photo of Gibson Les Pual guitar'")
707
+ positive_prompt.visible = False
708
+ negative_prompt = gr.Textbox(label="Prompt (Negative)", elem_id="prompt", placeholder="e.g. 'a photo from egocentric view'")
709
+ negative_prompt.visible = False
710
  node_type_dropdown = gr.Dropdown(["attn: attention output", "mlp: mlp output", "block: sum of residual"], label="Backbone: Layer type", value="block: sum of residual", elem_id="node_type", info="which feature to take from each layer?")
711
  num_eig_slider = gr.Slider(1, 1000, step=1, label="NCUT: Number of eigenvectors", value=100, elem_id="num_eig", info='increase for more clusters')
712
 
 
729
  gr.Slider(1, value, step=1, label="Backbone: Layer index", value=value, elem_id="layer", visible=True))
730
  model_dropdown.change(fn=change_layer_slider, inputs=model_dropdown, outputs=[layer_slider, node_type_dropdown])
731
 
732
+ def change_prompt_text(model_name):
733
+ if model_name in promptable_models:
734
+ return (gr.Textbox(label="Prompt (Positive)", elem_id="prompt", placeholder="e.g. 'a photo of Gibson Les Pual guitar'", visible=True),
735
+ gr.Textbox(label="Prompt (Negative)", elem_id="prompt", placeholder="e.g. 'a photo from egocentric view'", visible=True))
736
+ return (gr.Textbox(label="Prompt (Positive)", elem_id="prompt", placeholder="e.g. 'a photo of Gibson Les Pual guitar'", visible=False),
737
+ gr.Textbox(label="Prompt (Negative)", elem_id="prompt", placeholder="e.g. 'a photo from egocentric view'", visible=False))
738
+ model_dropdown.change(fn=change_prompt_text, inputs=model_dropdown, outputs=[positive_prompt, negative_prompt])
739
+
740
  with gr.Accordion("➡️ Click to expand: more parameters", open=False):
741
  gr.Markdown("<a href='https://ncut-pytorch.readthedocs.io/en/latest/how_to_get_better_segmentation/' target='_blank'>Docs: How to Get Better Segmentation</a>")
742
  affinity_focal_gamma_slider = gr.Slider(0.01, 1, step=0.01, label="NCUT: Affinity focal gamma", value=0.5, elem_id="affinity_focal_gamma", info="decrease for shaper segmentation")
 
753
  affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
754
  embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
755
  perplexity_slider, n_neighbors_slider, min_dist_slider,
756
+ sampling_method_dropdown, positive_prompt, negative_prompt]
757
 
758
  demo = gr.Blocks(
759
  theme=gr.themes.Base(spacing_size='md', text_size='lg', primary_hue='blue', neutral_hue='slate', secondary_hue='pink'),
 
775
  affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
776
  embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
777
  perplexity_slider, n_neighbors_slider, min_dist_slider,
778
+ sampling_method_dropdown, positive_prompt, negative_prompt
779
  ] = make_parameters_section()
780
  # logging text box
781
  logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")
 
787
  run_fn,
788
  inputs=[
789
  input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
790
+ positive_prompt, negative_prompt,
791
  affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
792
  embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
793
  perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown
 
831
  affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
832
  embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
833
  perplexity_slider, n_neighbors_slider, min_dist_slider,
834
+ sampling_method_dropdown, positive_prompt, negative_prompt
835
  ] = make_parameters_section()
836
  old_school_ncut_checkbox = gr.Checkbox(label="Old school NCut", value=True, elem_id="old_school_ncut")
837
  invisible_list = [old_school_ncut_checkbox, num_sample_ncut_slider, knn_ncut_slider,
 
846
  run_fn,
847
  inputs=[
848
  input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
849
+ positive_prompt, negative_prompt,
850
  affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
851
  embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
852
  perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown,
 
893
  affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
894
  embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
895
  perplexity_slider, n_neighbors_slider, min_dist_slider,
896
+ sampling_method_dropdown, positive_prompt, negative_prompt
897
  ] = make_parameters_section()
898
  num_eig_slider.visible = False
899
  affinity_focal_gamma_slider.visible = False
 
916
  run_fn,
917
  inputs=[
918
  input_gallery, model_dropdown, layer_slider, l1_num_eig_slider, node_type_dropdown,
919
+ positive_prompt, negative_prompt,
920
  affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
921
  embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
922
  perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown,
 
940
  affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
941
  embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
942
  perplexity_slider, n_neighbors_slider, min_dist_slider,
943
+ sampling_method_dropdown, positive_prompt, negative_prompt
944
  ] = make_parameters_section()
945
  num_sample_tsne_slider.value = 1000
946
  perplexity_slider.value = 500
 
955
  run_fn,
956
  inputs=[
957
  video_input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
958
+ positive_prompt, negative_prompt,
959
  affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
960
  embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
961
  perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown,
 
996
  affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
997
  embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
998
  perplexity_slider, n_neighbors_slider, min_dist_slider,
999
+ sampling_method_dropdown, positive_prompt, negative_prompt
1000
  ] = make_parameters_section()
1001
  model_dropdown.value = "AlignedThreeModelAttnNodes"
1002
  model_dropdown.visible = False
 
1020
  run_fn,
1021
  inputs=[
1022
  input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
1023
+ positive_prompt, negative_prompt,
1024
  affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
1025
  embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
1026
  perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown
 
1039
  affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
1040
  embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
1041
  perplexity_slider, n_neighbors_slider, min_dist_slider,
1042
+ sampling_method_dropdown, positive_prompt, negative_prompt
1043
  ] = make_parameters_section()
1044
  # logging text box
1045
  logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")
 
1047
  run_fn,
1048
  inputs=[
1049
  input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
1050
+ positive_prompt, negative_prompt,
1051
  affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
1052
  embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
1053
  perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown
requirements.txt CHANGED
@@ -14,4 +14,4 @@ segment-anything @ git+https://github.com/facebookresearch/segment-anything.git@
14
  mobile-sam @ git+https://github.com/ChaoningZhang/MobileSAM.git@c12dd83
15
  timm
16
  open-clip-torch==2.20.0
17
- ncut-pytorch>=1.3.8
 
14
  mobile-sam @ git+https://github.com/ChaoningZhang/MobileSAM.git@c12dd83
15
  timm
16
  open-clip-torch==2.20.0
17
+ ncut-pytorch>=1.3.10