Spaces:
Running
on
Zero
Running
on
Zero
added lots datasets
Browse files
app.py
CHANGED
@@ -37,31 +37,79 @@ from ncut_pytorch.backbone import MODEL_DICT, LAYER_DICT, RES_DICT
|
|
37 |
from ncut_pytorch import NCUT
|
38 |
from ncut_pytorch import eigenvector_to_rgb, rotate_rgb_cube
|
39 |
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
|
56 |
from datasets import load_dataset
|
57 |
|
58 |
def download_all_datasets():
|
59 |
-
for
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
|
|
|
|
65 |
|
66 |
def compute_ncut(
|
67 |
features,
|
@@ -1126,21 +1174,31 @@ def make_dataset_images_section(advanced=False, is_random=False):
|
|
1126 |
|
1127 |
gr.Markdown('### Load Datasets')
|
1128 |
load_images_button = gr.Button("🔴 Load Images", elem_id="load-images-button", variant='primary')
|
1129 |
-
advanced_radio = gr.Radio(["Basic", "Advanced"], label="Datasets", value="Advanced" if advanced else "Basic", elem_id="advanced-radio")
|
1130 |
with gr.Column() as basic_block:
|
1131 |
-
example_gallery = gr.Gallery(value=example_items, label="Example
|
1132 |
with gr.Column() as advanced_block:
|
1133 |
-
dataset_names = DATASET_NAMES
|
1134 |
-
dataset_classes = DATASET_CLASSES
|
|
|
|
|
|
|
|
|
|
|
1135 |
with gr.Row():
|
1136 |
-
|
|
|
|
|
|
|
1137 |
# num_images_slider = gr.Number(10, label="Number of images", elem_id="num_images")
|
1138 |
-
num_images_slider = gr.Slider(1, 1000, step=1, label="Number of images", value=10, elem_id="num_images")
|
1139 |
if not is_random:
|
1140 |
filter_by_class_checkbox = gr.Checkbox(label="Filter by class", value=True, elem_id="filter_by_class_checkbox")
|
1141 |
filter_by_class_text = gr.Textbox(label="Class to select", value="0,33,99", elem_id="filter_by_class_text", info=f"e.g. `0,1,2`. (1000 classes)", visible=True)
|
1142 |
-
is_random_checkbox = gr.Checkbox(label="Random shuffle", value=False, elem_id="random_seed_checkbox")
|
1143 |
-
random_seed_slider = gr.Slider(0, 1000, step=1, label="Random seed", value=1, elem_id="random_seed", visible=False)
|
|
|
|
|
1144 |
if is_random:
|
1145 |
filter_by_class_checkbox = gr.Checkbox(label="Filter by class", value=False, elem_id="filter_by_class_checkbox")
|
1146 |
filter_by_class_text = gr.Textbox(label="Class to select", value="0,33,99", elem_id="filter_by_class_text", info=f"e.g. `0,1,2`. (1000 classes)", visible=False)
|
@@ -1159,10 +1217,17 @@ def make_dataset_images_section(advanced=False, is_random=False):
|
|
1159 |
advanced_radio.change(fn=lambda x: gr.update(visible=x=="Advanced"), inputs=advanced_radio, outputs=[advanced_block])
|
1160 |
advanced_radio.change(fn=lambda x: gr.update(visible=x=="Basic"), inputs=advanced_radio, outputs=[basic_block])
|
1161 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1162 |
|
1163 |
def change_filter_options(dataset_name):
|
1164 |
-
|
1165 |
-
num_classes = dataset_classes[idx]
|
1166 |
if num_classes is None:
|
1167 |
return (gr.Checkbox(label="Filter by class", value=False, elem_id="filter_by_class_checkbox", visible=False),
|
1168 |
gr.Textbox(label="Class to select", value="0,1,2", elem_id="filter_by_class_text", info="e.g. `0,1,2`. This dataset has no class label", visible=False))
|
@@ -1171,8 +1236,7 @@ def make_dataset_images_section(advanced=False, is_random=False):
|
|
1171 |
dataset_dropdown.change(fn=change_filter_options, inputs=dataset_dropdown, outputs=[filter_by_class_checkbox, filter_by_class_text])
|
1172 |
|
1173 |
def change_filter_by_class(is_filter, dataset_name):
|
1174 |
-
|
1175 |
-
num_classes = dataset_classes[idx]
|
1176 |
return gr.Textbox(label="Class to select", value="0,1,2", elem_id="filter_by_class_text", info=f"e.g. `0,1,2`. ({num_classes} classes)", visible=is_filter)
|
1177 |
filter_by_class_checkbox.change(fn=change_filter_by_class, inputs=[filter_by_class_checkbox, dataset_dropdown], outputs=filter_by_class_text)
|
1178 |
|
@@ -1227,9 +1291,26 @@ def make_dataset_images_section(advanced=False, is_random=False):
|
|
1227 |
image_idx = np.random.RandomState(seed).choice(len(dataset), num_images, replace=False).tolist()
|
1228 |
else:
|
1229 |
image_idx = list(range(num_images))
|
1230 |
-
|
|
|
1231 |
gr.Info(f"Loaded {len(images)} images from {dataset_name}")
|
1232 |
del dataset
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1233 |
return images
|
1234 |
|
1235 |
load_images_button.click(load_dataset_images,
|
|
|
37 |
from ncut_pytorch import NCUT
|
38 |
from ncut_pytorch import eigenvector_to_rgb, rotate_rgb_cube
|
39 |
|
40 |
+
|
41 |
+
DATASETS = {
|
42 |
+
'Common': [
|
43 |
+
('mrm8488/ImageNet1K-val', 1000),
|
44 |
+
('UCSC-VLAA/Recap-COCO-30K', None),
|
45 |
+
('nateraw/pascal-voc-2012', None),
|
46 |
+
('johnowhitaker/imagenette2-320', 10),
|
47 |
+
('Multimodal-Fatima/CUB_train', 200),
|
48 |
+
('saragag/FlBirds', 7),
|
49 |
+
('microsoft/cats_vs_dogs', 2),
|
50 |
+
('Robotkid2696/food_classification', 20),
|
51 |
+
],
|
52 |
+
'Face': [
|
53 |
+
('nielsr/CelebA-faces', None),
|
54 |
+
('huggan/anime-faces', None),
|
55 |
+
],
|
56 |
+
'Pose': [
|
57 |
+
('razdab/sign_pose_M', None),
|
58 |
+
('sayakpaul/poses-controlnet-dataset', None),
|
59 |
+
('EgoThink/EgoThink', None),
|
60 |
+
('junjuice0/vtuber-tachi-e', None),
|
61 |
+
('Fiacre/small-animal-poses-controlnet-dataset', None),
|
62 |
+
],
|
63 |
+
'Hand': [
|
64 |
+
('trashsock/hands-images', 8),
|
65 |
+
('dduka/guitar-chords-v3', None),
|
66 |
+
],
|
67 |
+
'Satellite': [
|
68 |
+
('arakesh/deepglobe-2448x2448', None),
|
69 |
+
('tanganke/eurosat', 10),
|
70 |
+
('wangyi111/EuroSAT-SAR', None),
|
71 |
+
('efoley/sar_tile_512', None),
|
72 |
+
],
|
73 |
+
'Medical': [
|
74 |
+
('Mahadih534/Chest_CT-Scan_images-Dataset', 4),
|
75 |
+
('Falah/Alzheimer_MRI', 4),
|
76 |
+
('sartajbhuvaji/Brain-Tumor-Classification', 4),
|
77 |
+
('TrainingDataPro/chest-x-rays', None),
|
78 |
+
('hongrui/mimic_chest_xray_v_1', None),
|
79 |
+
('Leonardo6/path-vqa', None),
|
80 |
+
('Itsunori/path-vqa_jap', None),
|
81 |
+
('ruby-jrl/isic-2024-2', None),
|
82 |
+
('VRJBro/lung_cancer_dataset', 5),
|
83 |
+
('keremberke/blood-cell-object-detection', None)
|
84 |
+
],
|
85 |
+
'Miscs': [
|
86 |
+
('yashvoladoddi37/kanjienglish', None),
|
87 |
+
('Borismile/Anime-dataset', None),
|
88 |
+
('jainr3/diffusiondb-pixelart', None),
|
89 |
+
('jlbaker361/dcgan-eval-creative_gan_256_256', None),
|
90 |
+
('Francesco/csgo-videogame', None),
|
91 |
+
('Francesco/apex-videogame', None),
|
92 |
+
('Marqo/deepfashion-multimodal', None),
|
93 |
+
('huggan/pokemon', None),
|
94 |
+
('huggan/few-shot-universe', None),
|
95 |
+
('huggan/flowers-102-categories', None),
|
96 |
+
('huggan/inat_butterflies_top10k', None),
|
97 |
+
]
|
98 |
+
}
|
99 |
+
CENTER_CROP_DATASETS = ["razdab/sign_pose_M"]
|
100 |
+
|
101 |
|
102 |
from datasets import load_dataset
|
103 |
|
104 |
def download_all_datasets():
|
105 |
+
for cat in DATASETS.keys():
|
106 |
+
for tup in DATASETS[cat]:
|
107 |
+
name = tup[0]
|
108 |
+
print(f"Downloading {name}")
|
109 |
+
try:
|
110 |
+
load_dataset(name, trust_remote_code=True)
|
111 |
+
except Exception as e:
|
112 |
+
print(f"Error downloading {name}: {e}")
|
113 |
|
114 |
def compute_ncut(
|
115 |
features,
|
|
|
1174 |
|
1175 |
gr.Markdown('### Load Datasets')
|
1176 |
load_images_button = gr.Button("🔴 Load Images", elem_id="load-images-button", variant='primary')
|
1177 |
+
advanced_radio = gr.Radio(["Basic", "Advanced"], label="Datasets", value="Advanced" if advanced else "Basic", elem_id="advanced-radio", show_label=True)
|
1178 |
with gr.Column() as basic_block:
|
1179 |
+
example_gallery = gr.Gallery(value=example_items, label="Example Images", show_label=True, columns=[3], rows=[2], object_fit="scale-down", height="200px", show_share_button=False, elem_id="example-gallery")
|
1180 |
with gr.Column() as advanced_block:
|
1181 |
+
# dataset_names = DATASET_NAMES
|
1182 |
+
# dataset_classes = DATASET_CLASSES
|
1183 |
+
dataset_categories = list(DATASETS.keys())
|
1184 |
+
defualt_cat = dataset_categories[0]
|
1185 |
+
def get_choices(cat):
|
1186 |
+
return [tup[0] for tup in DATASETS[cat]]
|
1187 |
+
defualt_choices = get_choices(defualt_cat)
|
1188 |
with gr.Row():
|
1189 |
+
dataset_radio = gr.Radio(dataset_categories, label="Dataset Category", value=defualt_cat, elem_id="dataset-radio", show_label=True, min_width=600)
|
1190 |
+
# dataset_dropdown = gr.Dropdown(dataset_names, label="Dataset name", value="mrm8488/ImageNet1K-val", elem_id="dataset", min_width=300)
|
1191 |
+
dataset_dropdown = gr.Dropdown(defualt_choices, label="Dataset name", value=defualt_choices[0], elem_id="dataset", min_width=400)
|
1192 |
+
dataset_radio.change(fn=lambda x: gr.update(choices=get_choices(x), value=get_choices(x)[0]), inputs=dataset_radio, outputs=dataset_dropdown)
|
1193 |
# num_images_slider = gr.Number(10, label="Number of images", elem_id="num_images")
|
1194 |
+
num_images_slider = gr.Slider(1, 1000, step=1, label="Number of images", value=10, elem_id="num_images", min_width=200)
|
1195 |
if not is_random:
|
1196 |
filter_by_class_checkbox = gr.Checkbox(label="Filter by class", value=True, elem_id="filter_by_class_checkbox")
|
1197 |
filter_by_class_text = gr.Textbox(label="Class to select", value="0,33,99", elem_id="filter_by_class_text", info=f"e.g. `0,1,2`. (1000 classes)", visible=True)
|
1198 |
+
# is_random_checkbox = gr.Checkbox(label="Random shuffle", value=False, elem_id="random_seed_checkbox")
|
1199 |
+
# random_seed_slider = gr.Slider(0, 1000, step=1, label="Random seed", value=1, elem_id="random_seed", visible=False)
|
1200 |
+
is_random_checkbox = gr.Checkbox(label="Random shuffle", value=True, elem_id="random_seed_checkbox")
|
1201 |
+
random_seed_slider = gr.Slider(0, 1000, step=1, label="Random seed", value=1, elem_id="random_seed", visible=True)
|
1202 |
if is_random:
|
1203 |
filter_by_class_checkbox = gr.Checkbox(label="Filter by class", value=False, elem_id="filter_by_class_checkbox")
|
1204 |
filter_by_class_text = gr.Textbox(label="Class to select", value="0,33,99", elem_id="filter_by_class_text", info=f"e.g. `0,1,2`. (1000 classes)", visible=False)
|
|
|
1217 |
advanced_radio.change(fn=lambda x: gr.update(visible=x=="Advanced"), inputs=advanced_radio, outputs=[advanced_block])
|
1218 |
advanced_radio.change(fn=lambda x: gr.update(visible=x=="Basic"), inputs=advanced_radio, outputs=[basic_block])
|
1219 |
|
1220 |
+
def find_num_classes(dataset_name):
|
1221 |
+
num_classes = None
|
1222 |
+
for cat, datasets in DATASETS.items():
|
1223 |
+
datasets = [tup[0] for tup in datasets]
|
1224 |
+
if dataset_name in datasets:
|
1225 |
+
num_classes = DATASETS[cat][datasets.index(dataset_name)][1]
|
1226 |
+
break
|
1227 |
+
return num_classes
|
1228 |
|
1229 |
def change_filter_options(dataset_name):
|
1230 |
+
num_classes = find_num_classes(dataset_name)
|
|
|
1231 |
if num_classes is None:
|
1232 |
return (gr.Checkbox(label="Filter by class", value=False, elem_id="filter_by_class_checkbox", visible=False),
|
1233 |
gr.Textbox(label="Class to select", value="0,1,2", elem_id="filter_by_class_text", info="e.g. `0,1,2`. This dataset has no class label", visible=False))
|
|
|
1236 |
dataset_dropdown.change(fn=change_filter_options, inputs=dataset_dropdown, outputs=[filter_by_class_checkbox, filter_by_class_text])
|
1237 |
|
1238 |
def change_filter_by_class(is_filter, dataset_name):
|
1239 |
+
num_classes = find_num_classes(dataset_name)
|
|
|
1240 |
return gr.Textbox(label="Class to select", value="0,1,2", elem_id="filter_by_class_text", info=f"e.g. `0,1,2`. ({num_classes} classes)", visible=is_filter)
|
1241 |
filter_by_class_checkbox.change(fn=change_filter_by_class, inputs=[filter_by_class_checkbox, dataset_dropdown], outputs=filter_by_class_text)
|
1242 |
|
|
|
1291 |
image_idx = np.random.RandomState(seed).choice(len(dataset), num_images, replace=False).tolist()
|
1292 |
else:
|
1293 |
image_idx = list(range(num_images))
|
1294 |
+
key = 'image' if 'image' in dataset[0] else list(dataset[0].keys())[0]
|
1295 |
+
images = [dataset[i][key] for i in image_idx]
|
1296 |
gr.Info(f"Loaded {len(images)} images from {dataset_name}")
|
1297 |
del dataset
|
1298 |
+
|
1299 |
+
if dataset_name in CENTER_CROP_DATASETS:
|
1300 |
+
def center_crop_image(image):
|
1301 |
+
# image: PIL image
|
1302 |
+
w, h = img.size
|
1303 |
+
min_hw = min(h, w)
|
1304 |
+
# center crop
|
1305 |
+
left = (w - min_hw) // 2
|
1306 |
+
top = (h - min_hw) // 2
|
1307 |
+
right = left + min_hw
|
1308 |
+
bottom = top + min_hw
|
1309 |
+
print(left, top, right, bottom)
|
1310 |
+
img = img.crop((left, top, right, bottom))
|
1311 |
+
return image
|
1312 |
+
images = [center_crop_image(image) for image in images]
|
1313 |
+
|
1314 |
return images
|
1315 |
|
1316 |
load_images_button.click(load_dataset_images,
|