huzey commited on
Commit
c9c808c
1 Parent(s): ed7561b

update backbone UI

Browse files
Files changed (1) hide show
  1. app.py +16 -4
app.py CHANGED
@@ -501,7 +501,7 @@ def ncut_run(
501
  images = images.cuda()
502
  _images = reverse_transform_image(images, stablediffusion="stable" in model_name.lower())
503
  cluster_images = make_cluster_plot(eigvecs, _images, h=h, w=w)
504
- logging_str += f"Plot time: {time.time() - start:.2f}s\n"
505
 
506
 
507
  if video_output:
@@ -1083,6 +1083,14 @@ def make_parameters_section(is_lisa=False):
1083
  from ncut_pytorch.backbone import list_models, get_demo_model_names
1084
  model_names = list_models()
1085
  model_names = sorted(model_names)
 
 
 
 
 
 
 
 
1086
  if is_lisa:
1087
  model_dropdown = gr.Dropdown(["LISA(xinlai/LISA-7B-v1)"], label="Backbone", value="LISA(xinlai/LISA-7B-v1)", elem_id="model_name")
1088
  layer_slider = gr.Slider(1, 6, step=1, label="LISA decoder: Layer index", value=6, elem_id="layer", visible=False)
@@ -1092,8 +1100,10 @@ def make_parameters_section(is_lisa=False):
1092
  node_type_dropdown = gr.Dropdown(layer_names, label="LISA (SAM) decoder: Layer and Node", value="dec_1_block", elem_id="node_type")
1093
  else:
1094
  # remove LISA from the list
1095
- model_names = [m for m in model_names if "LISA" not in m]
1096
- model_dropdown = gr.Dropdown(model_names, label="Backbone", value="DiNO(dino_vitb8_448)", elem_id="model_name")
 
 
1097
  layer_slider = gr.Slider(1, 12, step=1, label="Backbone: Layer index", value=10, elem_id="layer")
1098
  positive_prompt = gr.Textbox(label="Prompt (Positive)", elem_id="prompt", placeholder="e.g. 'a photo of Gibson Les Pual guitar'")
1099
  positive_prompt.visible = False
@@ -1170,7 +1180,7 @@ with demo:
1170
 
1171
  with gr.Column(scale=5, min_width=200):
1172
  output_gallery = make_output_images_section()
1173
- cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=False, elem_id="clusters", columns=[2], rows=[1], object_fit="contain", height=450, show_share_button=True, preview=True, interactive=False)
1174
  [
1175
  model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
1176
  affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
@@ -1632,3 +1642,5 @@ demo.launch(share=True)
1632
  # images = [(Image.open(image), None) for image in default_images]
1633
  # ret = run_fn(images, num_eig=30)
1634
  # # %%
 
 
 
501
  images = images.cuda()
502
  _images = reverse_transform_image(images, stablediffusion="stable" in model_name.lower())
503
  cluster_images = make_cluster_plot(eigvecs, _images, h=h, w=w)
504
+ logging_str += f"plot time: {time.time() - start:.2f}s\n"
505
 
506
 
507
  if video_output:
 
1083
  from ncut_pytorch.backbone import list_models, get_demo_model_names
1084
  model_names = list_models()
1085
  model_names = sorted(model_names)
1086
+ def get_filtered_model_names(name):
1087
+ return [m for m in model_names if name.lower() in m.lower()]
1088
+ def get_default_model_name(name):
1089
+ lst = get_filtered_model_names(name)
1090
+ if len(lst) > 1:
1091
+ return lst[1]
1092
+ return lst[0]
1093
+
1094
  if is_lisa:
1095
  model_dropdown = gr.Dropdown(["LISA(xinlai/LISA-7B-v1)"], label="Backbone", value="LISA(xinlai/LISA-7B-v1)", elem_id="model_name")
1096
  layer_slider = gr.Slider(1, 6, step=1, label="LISA decoder: Layer index", value=6, elem_id="layer", visible=False)
 
1100
  node_type_dropdown = gr.Dropdown(layer_names, label="LISA (SAM) decoder: Layer and Node", value="dec_1_block", elem_id="node_type")
1101
  else:
1102
  # remove LISA from the list
1103
+ model_radio = gr.Radio(["CLIP", "DiNO", "Diffusion", "ImageNet", "MAE", "SAM"], label="Backbone", value="DiNO", elem_id="model_radio", show_label=True)
1104
+ model_dropdown = gr.Dropdown(get_filtered_model_names("DiNO"), label="", value="DiNO(dino_vitb8_448)", elem_id="model_name", show_label=False)
1105
+ model_radio.change(fn=lambda x: gr.update(choices=get_filtered_model_names(x), value=get_default_model_name(x)), inputs=model_radio, outputs=[model_dropdown])
1106
+ # model_radio.change(fn=lambda x: gr.update(value=get_filtered_model_names(x)[0]), inputs=model_radio, outputs=[model_dropdown])
1107
  layer_slider = gr.Slider(1, 12, step=1, label="Backbone: Layer index", value=10, elem_id="layer")
1108
  positive_prompt = gr.Textbox(label="Prompt (Positive)", elem_id="prompt", placeholder="e.g. 'a photo of Gibson Les Pual guitar'")
1109
  positive_prompt.visible = False
 
1180
 
1181
  with gr.Column(scale=5, min_width=200):
1182
  output_gallery = make_output_images_section()
1183
+ cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=False, elem_id="clusters", columns=[2], rows=[1], object_fit="contain", height=300, show_share_button=True, preview=True, interactive=False)
1184
  [
1185
  model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
1186
  affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
 
1642
  # images = [(Image.open(image), None) for image in default_images]
1643
  # ret = run_fn(images, num_eig=30)
1644
  # # %%
1645
+
1646
+ # %%