huzey committed
Commit 4fa11ec
1 Parent(s): 6706a30
Files changed (2):
  1. app.py +147 -15
  2. requirements.txt +3 -2
app.py CHANGED

@@ -1,5 +1,6 @@
 # Author: Huzheng Yang
 # %%
+import copy
 import os
 USE_HUGGINGFACE_ZEROGPU = os.getenv("USE_HUGGINGFACE_ZEROGPU", "False").lower() in ["true", "1", "yes"]
 DOWNLOAD_ALL_MODELS_DATASETS = os.getenv("DOWNLOAD_ALL_MODELS_DATASETS", "False").lower() in ["true", "1", "yes"]
@@ -232,6 +233,10 @@ def ncut_run(
     recursion_l2_gamma=0.5,
     recursion_l3_gamma=0.5,
     video_output=False,
+    is_lisa=False,
+    lisa_prompt1="",
+    lisa_prompt2="",
+    lisa_prompt3="",
 ):
     logging_str = ""
     if "AlignedThreeModelAttnNodes" == model_name:
@@ -256,6 +261,24 @@ def ncut_run(
     if "AlignedThreeModelAttnNodes" == model_name:
         # dirty patch for the alignedcut paper
         features = run_alignedthreemodelattnnodes(images, model, batch_size=BATCH_SIZE)
+    elif is_lisa == True:
+        # dirty patch for the LISA model
+        features = []
+        with torch.no_grad():
+            model = model.cuda()
+            images = images.cuda()
+            lisa_prompts = [lisa_prompt1, lisa_prompt2, lisa_prompt3]
+            for prompt in lisa_prompts:
+                import bleach
+                prompt = bleach.clean(prompt)
+                prompt = prompt.strip()
+                # print(prompt)
+                # # copy the sting to a new string
+                # copy_s = copy.copy(prompt)
+                feature = model(images, input_str=prompt)[node_type][0]
+                feature = F.normalize(feature, dim=-1)
+                features.append(feature.cpu().float())
+            features = torch.stack(features)
     else:
         features = extract_features(
             images, model, node_type=node_type, layer=layer-1, batch_size=BATCH_SIZE
@@ -340,6 +363,14 @@ def ncut_run(
             galleries.append(to_pil_images(_rgb, target_size=56))
         return *galleries, logging_str
 
+    if is_lisa == True:
+        # dirty patch for the LISA model
+        galleries = []
+        for i_prompt in range(len(lisa_prompts)):
+            _rgb = rgb[i_prompt]
+            galleries.append(to_pil_images(_rgb, target_size=256))
+        return *galleries, logging_str
+
     rgb = dont_use_too_much_green(rgb)
 
 
@@ -451,7 +482,8 @@ def load_alignedthreemodel():
     # model = torch.load(save_path)
     return model
 
-promptable_models = ["Diffusion(stabilityai/stable-diffusion-2)", "Diffusion(CompVis/stable-diffusion-v1-4)"]
+promptable_diffusion_models = ["Diffusion(stabilityai/stable-diffusion-2)", "Diffusion(CompVis/stable-diffusion-v1-4)"]
+promptable_segmentation_models = ["LISA(xinlai/LISA-7B-v1)"]
 
 
 def run_fn(
@@ -462,6 +494,10 @@ def run_fn(
     node_type="block",
     positive_prompt="",
     negative_prompt="",
+    is_lisa=False,
+    lisa_prompt1="",
+    lisa_prompt2="",
+    lisa_prompt3="",
     affinity_focal_gamma=0.3,
     num_sample_ncut=10000,
     knn_ncut=10,
@@ -515,10 +551,10 @@ def run_fn(
         model.timestep = layer
         layer = 1
 
-    if model_name in promptable_models:
+    if model_name in promptable_diffusion_models:
         model.positive_prompt = positive_prompt
         model.negative_prompt = negative_prompt
-
+
     kwargs = {
         "model_name": model_name,
         "layer": layer,
@@ -543,11 +579,18 @@ def run_fn(
         "recursion_l2_gamma": recursion_l2_gamma,
         "recursion_l3_gamma": recursion_l3_gamma,
         "video_output": video_output,
+        "lisa_prompt1": lisa_prompt1,
+        "lisa_prompt2": lisa_prompt2,
+        "lisa_prompt3": lisa_prompt3,
+        "is_lisa": is_lisa,
     }
     # print(kwargs)
 
     if old_school_ncut:
-        super_duper_long_run(model, images, **kwargs)
+        return super_duper_long_run(model, images, **kwargs)
+
+    if is_lisa:
+        return super_duper_long_run(model, images, **kwargs)
 
     num_images = len(images)
     if num_images >= 100:
@@ -702,18 +745,26 @@ def make_output_images_section():
     output_gallery = gr.Gallery(value=[], label="NCUT Embedding", show_label=False, elem_id="ncut", columns=[3], rows=[1], object_fit="contain", height="auto")
     return output_gallery
 
-def make_parameters_section():
+def make_parameters_section(is_lisa=False):
     gr.Markdown("### Parameters <a style='color: #0044CC;' href='https://ncut-pytorch.readthedocs.io/en/latest/how_to_get_better_segmentation/' target='_blank'>Help</a>")
     from ncut_pytorch.backbone import list_models, get_demo_model_names
     model_names = list_models()
     model_names = sorted(model_names)
-    model_dropdown = gr.Dropdown(model_names, label="Backbone", value="DiNO(dino_vitb8_448)", elem_id="model_name")
-    layer_slider = gr.Slider(1, 12, step=1, label="Backbone: Layer index", value=10, elem_id="layer")
-    positive_prompt = gr.Textbox(label="Prompt (Positive)", elem_id="prompt", placeholder="e.g. 'a photo of Gibson Les Pual guitar'")
-    positive_prompt.visible = False
-    negative_prompt = gr.Textbox(label="Prompt (Negative)", elem_id="prompt", placeholder="e.g. 'a photo from egocentric view'")
-    negative_prompt.visible = False
-    node_type_dropdown = gr.Dropdown(["attn: attention output", "mlp: mlp output", "block: sum of residual"], label="Backbone: Layer type", value="block: sum of residual", elem_id="node_type", info="which feature to take from each layer?")
+    if is_lisa:
+        model_dropdown = gr.Dropdown(["LISA(xinlai/LISA-7B-v1)"], label="Backbone", value="LISA(xinlai/LISA-7B-v1)", elem_id="model_name")
+        layer_slider = gr.Slider(1, 6, step=1, label="LISA decoder: Layer index", value=6, elem_id="layer", visible=False)
+        layer_names = ["dec_0_input", "dec_0_attn", "dec_0_block", "dec_1_input", "dec_1_attn", "dec_1_block"]
+        positive_prompt = gr.Textbox(label="Prompt (Positive)", elem_id="prompt", placeholder="e.g. 'a photo of Gibson Les Pual guitar'", visible=False)
+        negative_prompt = gr.Textbox(label="Prompt (Negative)", elem_id="prompt", placeholder="e.g. 'a photo from egocentric view'", visible=False)
+        node_type_dropdown = gr.Dropdown(layer_names, label="LISA (SAM) decoder: Layer and Node", value="dec_1_block", elem_id="node_type")
+    else:
+        model_dropdown = gr.Dropdown(model_names, label="Backbone", value="DiNO(dino_vitb8_448)", elem_id="model_name")
+        layer_slider = gr.Slider(1, 12, step=1, label="Backbone: Layer index", value=10, elem_id="layer")
+        positive_prompt = gr.Textbox(label="Prompt (Positive)", elem_id="prompt", placeholder="e.g. 'a photo of Gibson Les Pual guitar'")
+        positive_prompt.visible = False
+        negative_prompt = gr.Textbox(label="Prompt (Negative)", elem_id="prompt", placeholder="e.g. 'a photo from egocentric view'")
+        negative_prompt.visible = False
+        node_type_dropdown = gr.Dropdown(["attn: attention output", "mlp: mlp output", "block: sum of residual"], label="Backbone: Layer type", value="block: sum of residual", elem_id="node_type", info="which feature to take from each layer?")
     num_eig_slider = gr.Slider(1, 1000, step=1, label="NCUT: Number of eigenvectors", value=100, elem_id="num_eig", info='increase for more clusters')
 
     def change_layer_slider(model_name):
@@ -724,6 +775,12 @@ def make_parameters_section():
             return (gr.Slider(1, 49, step=1, label="Diffusion: Timestep (Noise)", value=5, elem_id="layer", visible=True, info="Noise level, 50 is max noise"),
                     gr.Dropdown(SD_KEY_DICT[model_name], label="Diffusion: Layer and Node", value=default_layer, elem_id="node_type", info="U-Net (v1, v2) or DiT (v3)"))
 
+        if model_name == "LISSL(xinlai/LISSL-7B-v1)":
+            layer_names = ["dec_0_input", "dec_0_attn", "dec_0_block", "dec_1_input", "dec_1_attn", "dec_1_block"]
+            default_layer = "dec_1_block"
+            return (gr.Slider(1, 6, step=1, label="LISA decoder: Layer index", value=6, elem_id="layer", visible=False),
+                    gr.Dropdown(layer_names, label="LISA decoder: Layer and Node", value=default_layer, elem_id="node_type"))
+
         layer_dict = LAYER_DICT
         if model_name in layer_dict:
             value = layer_dict[model_name]
@@ -736,7 +793,7 @@ def make_parameters_section():
     model_dropdown.change(fn=change_layer_slider, inputs=model_dropdown, outputs=[layer_slider, node_type_dropdown])
 
     def change_prompt_text(model_name):
-        if model_name in promptable_models:
+        if model_name in promptable_diffusion_models:
             return (gr.Textbox(label="Prompt (Positive)", elem_id="prompt", placeholder="e.g. 'a photo of Gibson Les Pual guitar'", visible=True),
                     gr.Textbox(label="Prompt (Negative)", elem_id="prompt", placeholder="e.g. 'a photo from egocentric view'", visible=True))
         return (gr.Textbox(label="Prompt (Positive)", elem_id="prompt", placeholder="e.g. 'a photo of Gibson Les Pual guitar'", visible=False),
@@ -788,12 +845,15 @@ with demo:
 
         clear_images_button.click(lambda x: ([], []), outputs=[input_gallery, output_gallery])
 
+        false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
+        no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
 
         submit_button.click(
             run_fn,
             inputs=[
                 input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
                 positive_prompt, negative_prompt,
+                false_placeholder, no_prompt, no_prompt, no_prompt,
                 affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
                 embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
                 perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown
@@ -848,11 +908,15 @@ with demo:
         logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")
 
         clear_images_button.click(lambda x: ([], []), outputs=[input_gallery, output_gallery])
+        false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
+        no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
+
         submit_button.click(
             run_fn,
             inputs=[
                 input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
                 positive_prompt, negative_prompt,
+                false_placeholder, no_prompt, no_prompt, no_prompt,
                 affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
                 embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
                 perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown,
@@ -918,11 +982,15 @@ with demo:
         number_placeholder = gr.Number(0, label="Number placeholder", elem_id="number_placeholder")
         number_placeholder.visible = False
        clear_images_button.click(lambda x: ([],), outputs=[input_gallery])
+        false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
+        no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
+
         submit_button.click(
             run_fn,
             inputs=[
-                input_gallery, model_dropdown, layer_slider, l1_num_eig_slider, node_type_dropdown,
+                input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
                 positive_prompt, negative_prompt,
+                false_placeholder, no_prompt, no_prompt, no_prompt,
                 affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
                 embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
                 perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown,
@@ -957,11 +1025,15 @@ with demo:
         clear_video_button.click(lambda x: (None, None), outputs=[video_input_gallery, video_output_gallery])
         place_holder_false = gr.Checkbox(label="Place holder", value=False, elem_id="place_holder_false")
         place_holder_false.visible = False
+        false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
+        no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
+
         submit_button.click(
             run_fn,
             inputs=[
-                video_input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
+                input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
                 positive_prompt, negative_prompt,
+                false_placeholder, no_prompt, no_prompt, no_prompt,
                 affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
                 embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
                 perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown,
@@ -979,6 +1051,57 @@ with demo:
         from draft_gradio_app_text import make_demo
         make_demo()
 
+    with gr.Tab('Vision-Language'):
+        gr.Markdown('[LISA]((https://arxiv.org/pdf/2308.00692)) is a vision-language model. Input a text prompt and image, LISA generate segmentation masks.')
+        gr.Markdown('In the mask decoder layers, LISA updates the image features w.r.t. the text prompt')
+        gr.Markdown('This page aims to see how the text prompt affects the image features')
+        gr.Markdown('---')
+        gr.Markdown('<p style="text-align: center;">Color is <b>aligned</b> across 3 prompts. NCUT is computed on the concatenated features from 3 prompts.</p>')
+        with gr.Row():
+            with gr.Column(scale=5, min_width=200):
+                gr.Markdown('### Output (Prompt #1)')
+                l1_gallery = gr.Gallery(value=[], label="Prompt #1", show_label=False, elem_id="ncut_p1", columns=[3], rows=[5], object_fit="contain", height="auto")
+                prompt1 = gr.Textbox(label="Input Prompt #1", elem_id="prompt1", value="where is the person, include the clothes, don't include the guitar and chair", lines=3)
+            with gr.Column(scale=5, min_width=200):
+                gr.Markdown('### Output (Prompt #2)')
+                l2_gallery = gr.Gallery(value=[], label="Prompt #2", show_label=False, elem_id="ncut_p2", columns=[3], rows=[5], object_fit="contain", height="auto")
+                prompt2 = gr.Textbox(label="Input Prompt #2", elem_id="prompt2", value="where is the Gibson Les Pual guitar", lines=3)
+            with gr.Column(scale=5, min_width=200):
+                gr.Markdown('### Output (Prompt #3)')
+                l3_gallery = gr.Gallery(value=[], label="Prompt #3", show_label=False, elem_id="ncut_p3", columns=[3], rows=[5], object_fit="contain", height="auto")
+                prompt3 = gr.Textbox(label="Input Prompt #3", elem_id="prompt3", value="where is the floor", lines=3)
+
+        with gr.Row():
+            with gr.Column(scale=5, min_width=200):
+                input_gallery, submit_button, clear_images_button = make_input_images_section()
+                dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_dataset_images_section(advanced=False)
+                clear_images_button.click(lambda x: ([], [], [], []), outputs=[input_gallery, l1_gallery, l2_gallery, l3_gallery])
+
+            with gr.Column(scale=5, min_width=200):
+                [
+                    model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
+                    affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
+                    embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
+                    perplexity_slider, n_neighbors_slider, min_dist_slider,
+                    sampling_method_dropdown, positive_prompt, negative_prompt
+                ] = make_parameters_section(is_lisa=True)
+                logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")
+
+        galleries = [l1_gallery, l2_gallery, l3_gallery]
+        true_placeholder = gr.Checkbox(label="True placeholder", value=True, elem_id="true_placeholder", visible=False)
+        submit_button.click(
+            run_fn,
+            inputs=[
+                input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
+                positive_prompt, negative_prompt,
+                true_placeholder, prompt1, prompt2, prompt3,
+                affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
+                embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
+                perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown
+            ],
+            outputs=galleries + [logging_text],
+        )
+
     with gr.Tab('Model Aligned'):
         gr.Markdown('This page reproduce the results from the paper [AlignedCut](https://arxiv.org/abs/2406.18344)')
         gr.Markdown('---')
@@ -1022,11 +1145,16 @@ with demo:
 
 
         clear_images_button.click(lambda x: [] * (len(galleries) + 1), outputs=[input_gallery] + galleries)
+
+        false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
+        no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
+
         submit_button.click(
             run_fn,
             inputs=[
                 input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
                 positive_prompt, negative_prompt,
+                false_placeholder, no_prompt, no_prompt, no_prompt,
                 affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
                 embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
                 perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown
@@ -1049,11 +1177,15 @@ with demo:
         ] = make_parameters_section()
         # logging text box
         logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")
+        false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
+        no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
+
        submit_button.click(
            run_fn,
            inputs=[
                input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
                positive_prompt, negative_prompt,
+               false_placeholder, no_prompt, no_prompt, no_prompt,
                affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
                embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
                perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown
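
Note: the new invisible false_placeholder / no_prompt components above exist only so that every tab can still fill run_fn's longer positional inputs list now that it takes is_lisa and three LISA prompts; the Vision-Language tab instead passes true_placeholder and the three prompt textboxes. A minimal sketch of that hidden-placeholder pattern (standalone illustration, not code from this commit; component and function names are reused only as examples):

import gradio as gr

def run_fn(text, is_lisa, lisa_prompt1):
    # every tab passes a value for every parameter, even when unused
    return f"is_lisa={is_lisa}, prompt={lisa_prompt1!r}, text={text!r}"

with gr.Blocks() as demo:
    text = gr.Textbox(label="Input")
    # invisible components supply the defaults for tabs without LISA controls,
    # keeping the positional `inputs` list the same length everywhere
    false_placeholder = gr.Checkbox(value=False, visible=False)
    no_prompt = gr.Textbox(value="", visible=False)
    out = gr.Textbox(label="Output")
    gr.Button("Submit").click(run_fn, inputs=[text, false_placeholder, no_prompt], outputs=[out])

# demo.launch()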
requirements.txt CHANGED

@@ -2,7 +2,7 @@ torch
 torchvision
 opencv-python
 decord
-transformers
+transformers==4.31.0
 datasets
 diffusers
 accelerate
@@ -12,6 +12,7 @@ pillow==9.4.0
 SAM-2 @ git+https://github.com/huzeyann/segment-anything-2.git
 segment-anything @ git+https://github.com/facebookresearch/segment-anything.git@6fdee8f
 mobile-sam @ git+https://github.com/ChaoningZhang/MobileSAM.git@c12dd83
+lisa @ git+https://github.com/huzeyann/LISA.git
 timm
 open-clip-torch==2.20.0
-ncut-pytorch>=1.3.10
+ncut-pytorch>=1.3.13
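
Note: transformers is now pinned to 4.31.0 (presumably to match the newly added lisa dependency) and ncut-pytorch is bumped to >=1.3.13. A small runtime sanity check is a convenient way to catch a stale environment; this is a sketch, not part of the commit, and the versions are simply those listed above:

import importlib.metadata as md

# versions taken from requirements.txt above; adjust if the pins change
assert md.version("transformers") == "4.31.0", md.version("transformers")
print("ncut-pytorch:", md.version("ncut-pytorch"))  # expected >= 1.3.13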