huzey commited on
Commit
8898a47
1 Parent(s): f5de82f

update default model

Browse files
Files changed (1) hide show
  1. app.py +9 -7
app.py CHANGED
@@ -45,7 +45,10 @@ from datasets import load_dataset
45
  def download_all_datasets():
46
  for name in DATASET_NAMES:
47
  print(f"Downloading {name}")
48
- load_dataset(name, trust_remote_code=True)
 
 
 
49
 
50
  def compute_ncut(
51
  features,
@@ -528,6 +531,10 @@ def make_dataset_images_section(open=False):
528
  return None
529
  if num_images > len(dataset):
530
  num_images = len(dataset)
 
 
 
 
531
  if is_filter:
532
  classes = list(map(int, filter_by_class_text.split(",")))
533
  labels = np.array(dataset['label'])
@@ -574,7 +581,7 @@ def make_parameters_section():
574
  gr.Markdown('### Parameters')
575
  from backbone import get_all_model_names
576
  model_names = get_all_model_names()
577
- model_dropdown = gr.Dropdown(model_names, label="Backbone", value="SAM2(sam2_hiera_t)", elem_id="model_name")
578
  layer_slider = gr.Slider(1, 12, step=1, label="Backbone: Layer index", value=12, elem_id="layer")
579
  node_type_dropdown = gr.Dropdown(["attn: attention output", "mlp: mlp output", "block: sum of residual"], label="Backbone: Layer type", value="block: sum of residual", elem_id="node_type", info="which feature to take from each layer?")
580
  num_eig_slider = gr.Slider(1, 1000, step=1, label="NCUT: Number of eigenvectors", value=100, elem_id="num_eig", info='increase for more clusters')
@@ -720,7 +727,6 @@ with gr.Blocks() as demo:
720
  hide_button.visible = False
721
  dataset_dropdown, num_images_slider, random_seed_slider, load_dataset_button = make_dataset_images_section()
722
  num_images_slider.value = 100
723
- dataset_dropdown.value = 'nielsr/CelebA-faces'
724
 
725
  with gr.Column(scale=5, min_width=200):
726
  with gr.Accordion("➡️ Recursion config", open=True):
@@ -737,10 +743,6 @@ with gr.Blocks() as demo:
737
  sampling_method_dropdown
738
  ] = make_parameters_section()
739
  num_eig_slider.visible = False
740
- model_dropdown.value = 'DiNO(dinov2_vitb14_reg)'
741
- layer_slider.value = 6
742
- node_type_dropdown.value = 'attn: attention output'
743
- affinity_focal_gamma_slider.value = 0.25
744
  # logging text box
745
  with gr.Row():
746
  with gr.Column(scale=5, min_width=200):
 
45
  def download_all_datasets():
46
  for name in DATASET_NAMES:
47
  print(f"Downloading {name}")
48
+ try:
49
+ load_dataset(name, trust_remote_code=True)
50
+ except Exception as e:
51
+ print(f"Error downloading {name}: {e}")
52
 
53
  def compute_ncut(
54
  features,
 
531
  return None
532
  if num_images > len(dataset):
533
  num_images = len(dataset)
534
+
535
+ if 'label' not in dataset and is_filter:
536
+ gr.Error(f"Dataset {dataset_name} has no class label.")
537
+ return None
538
  if is_filter:
539
  classes = list(map(int, filter_by_class_text.split(",")))
540
  labels = np.array(dataset['label'])
 
581
  gr.Markdown('### Parameters')
582
  from backbone import get_all_model_names
583
  model_names = get_all_model_names()
584
+ model_dropdown = gr.Dropdown(model_names, label="Backbone", value="DiNO(dino_vitb8)", elem_id="model_name")
585
  layer_slider = gr.Slider(1, 12, step=1, label="Backbone: Layer index", value=12, elem_id="layer")
586
  node_type_dropdown = gr.Dropdown(["attn: attention output", "mlp: mlp output", "block: sum of residual"], label="Backbone: Layer type", value="block: sum of residual", elem_id="node_type", info="which feature to take from each layer?")
587
  num_eig_slider = gr.Slider(1, 1000, step=1, label="NCUT: Number of eigenvectors", value=100, elem_id="num_eig", info='increase for more clusters')
 
727
  hide_button.visible = False
728
  dataset_dropdown, num_images_slider, random_seed_slider, load_dataset_button = make_dataset_images_section()
729
  num_images_slider.value = 100
 
730
 
731
  with gr.Column(scale=5, min_width=200):
732
  with gr.Accordion("➡️ Recursion config", open=True):
 
743
  sampling_method_dropdown
744
  ] = make_parameters_section()
745
  num_eig_slider.visible = False
 
 
 
 
746
  # logging text box
747
  with gr.Row():
748
  with gr.Column(scale=5, min_width=200):