huzey commited on
Commit
b806945
1 Parent(s): fe0a89a

added lots datasets

Browse files
Files changed (1) hide show
  1. app.py +115 -34
app.py CHANGED
@@ -37,31 +37,79 @@ from ncut_pytorch.backbone import MODEL_DICT, LAYER_DICT, RES_DICT
37
  from ncut_pytorch import NCUT
38
  from ncut_pytorch import eigenvector_to_rgb, rotate_rgb_cube
39
 
40
- DATASET_TUPS = [
41
- # (name, num_classes)
42
- ('UCSC-VLAA/Recap-COCO-30K', None),
43
- ('nateraw/pascal-voc-2012', None),
44
- ('johnowhitaker/imagenette2-320', 10),
45
- ('jainr3/diffusiondb-pixelart', None),
46
- ('nielsr/CelebA-faces', None),
47
- ('JapanDegitalMaterial/Places_in_Japan', None),
48
- ('Borismile/Anime-dataset', None),
49
- ('Multimodal-Fatima/CUB_train', 200),
50
- ('mrm8488/ImageNet1K-val', 1000),
51
- ("trashsock/hands-images", 8),
52
- ]
53
- DATASET_NAMES = [tup[0] for tup in DATASET_TUPS]
54
- DATASET_CLASSES = [tup[1] for tup in DATASET_TUPS]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
  from datasets import load_dataset
57
 
58
  def download_all_datasets():
59
- for name in DATASET_NAMES:
60
- print(f"Downloading {name}")
61
- try:
62
- load_dataset(name, trust_remote_code=True)
63
- except Exception as e:
64
- print(f"Error downloading {name}: {e}")
 
 
65
 
66
  def compute_ncut(
67
  features,
@@ -1126,21 +1174,31 @@ def make_dataset_images_section(advanced=False, is_random=False):
1126
 
1127
  gr.Markdown('### Load Datasets')
1128
  load_images_button = gr.Button("🔴 Load Images", elem_id="load-images-button", variant='primary')
1129
- advanced_radio = gr.Radio(["Basic", "Advanced"], label="Datasets", value="Advanced" if advanced else "Basic", elem_id="advanced-radio")
1130
  with gr.Column() as basic_block:
1131
- example_gallery = gr.Gallery(value=example_items, label="Example Set A", show_label=False, columns=[3], rows=[2], object_fit="scale-down", height="200px", show_share_button=False, elem_id="example-gallery")
1132
  with gr.Column() as advanced_block:
1133
- dataset_names = DATASET_NAMES
1134
- dataset_classes = DATASET_CLASSES
 
 
 
 
 
1135
  with gr.Row():
1136
- dataset_dropdown = gr.Dropdown(dataset_names, label="Dataset name", value="mrm8488/ImageNet1K-val", elem_id="dataset", min_width=300)
 
 
 
1137
  # num_images_slider = gr.Number(10, label="Number of images", elem_id="num_images")
1138
- num_images_slider = gr.Slider(1, 1000, step=1, label="Number of images", value=10, elem_id="num_images")
1139
  if not is_random:
1140
  filter_by_class_checkbox = gr.Checkbox(label="Filter by class", value=True, elem_id="filter_by_class_checkbox")
1141
  filter_by_class_text = gr.Textbox(label="Class to select", value="0,33,99", elem_id="filter_by_class_text", info=f"e.g. `0,1,2`. (1000 classes)", visible=True)
1142
- is_random_checkbox = gr.Checkbox(label="Random shuffle", value=False, elem_id="random_seed_checkbox")
1143
- random_seed_slider = gr.Slider(0, 1000, step=1, label="Random seed", value=1, elem_id="random_seed", visible=False)
 
 
1144
  if is_random:
1145
  filter_by_class_checkbox = gr.Checkbox(label="Filter by class", value=False, elem_id="filter_by_class_checkbox")
1146
  filter_by_class_text = gr.Textbox(label="Class to select", value="0,33,99", elem_id="filter_by_class_text", info=f"e.g. `0,1,2`. (1000 classes)", visible=False)
@@ -1159,10 +1217,17 @@ def make_dataset_images_section(advanced=False, is_random=False):
1159
  advanced_radio.change(fn=lambda x: gr.update(visible=x=="Advanced"), inputs=advanced_radio, outputs=[advanced_block])
1160
  advanced_radio.change(fn=lambda x: gr.update(visible=x=="Basic"), inputs=advanced_radio, outputs=[basic_block])
1161
 
 
 
 
 
 
 
 
 
1162
 
1163
  def change_filter_options(dataset_name):
1164
- idx = dataset_names.index(dataset_name)
1165
- num_classes = dataset_classes[idx]
1166
  if num_classes is None:
1167
  return (gr.Checkbox(label="Filter by class", value=False, elem_id="filter_by_class_checkbox", visible=False),
1168
  gr.Textbox(label="Class to select", value="0,1,2", elem_id="filter_by_class_text", info="e.g. `0,1,2`. This dataset has no class label", visible=False))
@@ -1171,8 +1236,7 @@ def make_dataset_images_section(advanced=False, is_random=False):
1171
  dataset_dropdown.change(fn=change_filter_options, inputs=dataset_dropdown, outputs=[filter_by_class_checkbox, filter_by_class_text])
1172
 
1173
  def change_filter_by_class(is_filter, dataset_name):
1174
- idx = dataset_names.index(dataset_name)
1175
- num_classes = dataset_classes[idx]
1176
  return gr.Textbox(label="Class to select", value="0,1,2", elem_id="filter_by_class_text", info=f"e.g. `0,1,2`. ({num_classes} classes)", visible=is_filter)
1177
  filter_by_class_checkbox.change(fn=change_filter_by_class, inputs=[filter_by_class_checkbox, dataset_dropdown], outputs=filter_by_class_text)
1178
 
@@ -1227,9 +1291,26 @@ def make_dataset_images_section(advanced=False, is_random=False):
1227
  image_idx = np.random.RandomState(seed).choice(len(dataset), num_images, replace=False).tolist()
1228
  else:
1229
  image_idx = list(range(num_images))
1230
- images = [dataset[i]['image'] for i in image_idx]
 
1231
  gr.Info(f"Loaded {len(images)} images from {dataset_name}")
1232
  del dataset
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1233
  return images
1234
 
1235
  load_images_button.click(load_dataset_images,
 
37
  from ncut_pytorch import NCUT
38
  from ncut_pytorch import eigenvector_to_rgb, rotate_rgb_cube
39
 
40
+
41
+ DATASETS = {
42
+ 'Common': [
43
+ ('mrm8488/ImageNet1K-val', 1000),
44
+ ('UCSC-VLAA/Recap-COCO-30K', None),
45
+ ('nateraw/pascal-voc-2012', None),
46
+ ('johnowhitaker/imagenette2-320', 10),
47
+ ('Multimodal-Fatima/CUB_train', 200),
48
+ ('saragag/FlBirds', 7),
49
+ ('microsoft/cats_vs_dogs', 2),
50
+ ('Robotkid2696/food_classification', 20),
51
+ ],
52
+ 'Face': [
53
+ ('nielsr/CelebA-faces', None),
54
+ ('huggan/anime-faces', None),
55
+ ],
56
+ 'Pose': [
57
+ ('razdab/sign_pose_M', None),
58
+ ('sayakpaul/poses-controlnet-dataset', None),
59
+ ('EgoThink/EgoThink', None),
60
+ ('junjuice0/vtuber-tachi-e', None),
61
+ ('Fiacre/small-animal-poses-controlnet-dataset', None),
62
+ ],
63
+ 'Hand': [
64
+ ('trashsock/hands-images', 8),
65
+ ('dduka/guitar-chords-v3', None),
66
+ ],
67
+ 'Satellite': [
68
+ ('arakesh/deepglobe-2448x2448', None),
69
+ ('tanganke/eurosat', 10),
70
+ ('wangyi111/EuroSAT-SAR', None),
71
+ ('efoley/sar_tile_512', None),
72
+ ],
73
+ 'Medical': [
74
+ ('Mahadih534/Chest_CT-Scan_images-Dataset', 4),
75
+ ('Falah/Alzheimer_MRI', 4),
76
+ ('sartajbhuvaji/Brain-Tumor-Classification', 4),
77
+ ('TrainingDataPro/chest-x-rays', None),
78
+ ('hongrui/mimic_chest_xray_v_1', None),
79
+ ('Leonardo6/path-vqa', None),
80
+ ('Itsunori/path-vqa_jap', None),
81
+ ('ruby-jrl/isic-2024-2', None),
82
+ ('VRJBro/lung_cancer_dataset', 5),
83
+ ('keremberke/blood-cell-object-detection', None)
84
+ ],
85
+ 'Miscs': [
86
+ ('yashvoladoddi37/kanjienglish', None),
87
+ ('Borismile/Anime-dataset', None),
88
+ ('jainr3/diffusiondb-pixelart', None),
89
+ ('jlbaker361/dcgan-eval-creative_gan_256_256', None),
90
+ ('Francesco/csgo-videogame', None),
91
+ ('Francesco/apex-videogame', None),
92
+ ('Marqo/deepfashion-multimodal', None),
93
+ ('huggan/pokemon', None),
94
+ ('huggan/few-shot-universe', None),
95
+ ('huggan/flowers-102-categories', None),
96
+ ('huggan/inat_butterflies_top10k', None),
97
+ ]
98
+ }
99
+ CENTER_CROP_DATASETS = ["razdab/sign_pose_M"]
100
+
101
 
102
  from datasets import load_dataset
103
 
104
  def download_all_datasets():
105
+ for cat in DATASETS.keys():
106
+ for tup in DATASETS[cat]:
107
+ name = tup[0]
108
+ print(f"Downloading {name}")
109
+ try:
110
+ load_dataset(name, trust_remote_code=True)
111
+ except Exception as e:
112
+ print(f"Error downloading {name}: {e}")
113
 
114
  def compute_ncut(
115
  features,
 
1174
 
1175
  gr.Markdown('### Load Datasets')
1176
  load_images_button = gr.Button("🔴 Load Images", elem_id="load-images-button", variant='primary')
1177
+ advanced_radio = gr.Radio(["Basic", "Advanced"], label="Datasets", value="Advanced" if advanced else "Basic", elem_id="advanced-radio", show_label=True)
1178
  with gr.Column() as basic_block:
1179
+ example_gallery = gr.Gallery(value=example_items, label="Example Images", show_label=True, columns=[3], rows=[2], object_fit="scale-down", height="200px", show_share_button=False, elem_id="example-gallery")
1180
  with gr.Column() as advanced_block:
1181
+ # dataset_names = DATASET_NAMES
1182
+ # dataset_classes = DATASET_CLASSES
1183
+ dataset_categories = list(DATASETS.keys())
1184
+ defualt_cat = dataset_categories[0]
1185
+ def get_choices(cat):
1186
+ return [tup[0] for tup in DATASETS[cat]]
1187
+ defualt_choices = get_choices(defualt_cat)
1188
  with gr.Row():
1189
+ dataset_radio = gr.Radio(dataset_categories, label="Dataset Category", value=defualt_cat, elem_id="dataset-radio", show_label=True, min_width=600)
1190
+ # dataset_dropdown = gr.Dropdown(dataset_names, label="Dataset name", value="mrm8488/ImageNet1K-val", elem_id="dataset", min_width=300)
1191
+ dataset_dropdown = gr.Dropdown(defualt_choices, label="Dataset name", value=defualt_choices[0], elem_id="dataset", min_width=400)
1192
+ dataset_radio.change(fn=lambda x: gr.update(choices=get_choices(x), value=get_choices(x)[0]), inputs=dataset_radio, outputs=dataset_dropdown)
1193
  # num_images_slider = gr.Number(10, label="Number of images", elem_id="num_images")
1194
+ num_images_slider = gr.Slider(1, 1000, step=1, label="Number of images", value=10, elem_id="num_images", min_width=200)
1195
  if not is_random:
1196
  filter_by_class_checkbox = gr.Checkbox(label="Filter by class", value=True, elem_id="filter_by_class_checkbox")
1197
  filter_by_class_text = gr.Textbox(label="Class to select", value="0,33,99", elem_id="filter_by_class_text", info=f"e.g. `0,1,2`. (1000 classes)", visible=True)
1198
+ # is_random_checkbox = gr.Checkbox(label="Random shuffle", value=False, elem_id="random_seed_checkbox")
1199
+ # random_seed_slider = gr.Slider(0, 1000, step=1, label="Random seed", value=1, elem_id="random_seed", visible=False)
1200
+ is_random_checkbox = gr.Checkbox(label="Random shuffle", value=True, elem_id="random_seed_checkbox")
1201
+ random_seed_slider = gr.Slider(0, 1000, step=1, label="Random seed", value=1, elem_id="random_seed", visible=True)
1202
  if is_random:
1203
  filter_by_class_checkbox = gr.Checkbox(label="Filter by class", value=False, elem_id="filter_by_class_checkbox")
1204
  filter_by_class_text = gr.Textbox(label="Class to select", value="0,33,99", elem_id="filter_by_class_text", info=f"e.g. `0,1,2`. (1000 classes)", visible=False)
 
1217
  advanced_radio.change(fn=lambda x: gr.update(visible=x=="Advanced"), inputs=advanced_radio, outputs=[advanced_block])
1218
  advanced_radio.change(fn=lambda x: gr.update(visible=x=="Basic"), inputs=advanced_radio, outputs=[basic_block])
1219
 
1220
+ def find_num_classes(dataset_name):
1221
+ num_classes = None
1222
+ for cat, datasets in DATASETS.items():
1223
+ datasets = [tup[0] for tup in datasets]
1224
+ if dataset_name in datasets:
1225
+ num_classes = DATASETS[cat][datasets.index(dataset_name)][1]
1226
+ break
1227
+ return num_classes
1228
 
1229
  def change_filter_options(dataset_name):
1230
+ num_classes = find_num_classes(dataset_name)
 
1231
  if num_classes is None:
1232
  return (gr.Checkbox(label="Filter by class", value=False, elem_id="filter_by_class_checkbox", visible=False),
1233
  gr.Textbox(label="Class to select", value="0,1,2", elem_id="filter_by_class_text", info="e.g. `0,1,2`. This dataset has no class label", visible=False))
 
1236
  dataset_dropdown.change(fn=change_filter_options, inputs=dataset_dropdown, outputs=[filter_by_class_checkbox, filter_by_class_text])
1237
 
1238
  def change_filter_by_class(is_filter, dataset_name):
1239
+ num_classes = find_num_classes(dataset_name)
 
1240
  return gr.Textbox(label="Class to select", value="0,1,2", elem_id="filter_by_class_text", info=f"e.g. `0,1,2`. ({num_classes} classes)", visible=is_filter)
1241
  filter_by_class_checkbox.change(fn=change_filter_by_class, inputs=[filter_by_class_checkbox, dataset_dropdown], outputs=filter_by_class_text)
1242
 
 
1291
  image_idx = np.random.RandomState(seed).choice(len(dataset), num_images, replace=False).tolist()
1292
  else:
1293
  image_idx = list(range(num_images))
1294
+ key = 'image' if 'image' in dataset[0] else list(dataset[0].keys())[0]
1295
+ images = [dataset[i][key] for i in image_idx]
1296
  gr.Info(f"Loaded {len(images)} images from {dataset_name}")
1297
  del dataset
1298
+
1299
+ if dataset_name in CENTER_CROP_DATASETS:
1300
+ def center_crop_image(image):
1301
+ # image: PIL image
1302
+ w, h = img.size
1303
+ min_hw = min(h, w)
1304
+ # center crop
1305
+ left = (w - min_hw) // 2
1306
+ top = (h - min_hw) // 2
1307
+ right = left + min_hw
1308
+ bottom = top + min_hw
1309
+ print(left, top, right, bottom)
1310
+ img = img.crop((left, top, right, bottom))
1311
+ return image
1312
+ images = [center_crop_image(image) for image in images]
1313
+
1314
  return images
1315
 
1316
  load_images_button.click(load_dataset_images,