# Hugging Face Spaces page chrome captured during scraping ("Spaces",
# runtime status "Sleeping"); not part of the application source.
import os | |
from io import BytesIO | |
from multiprocessing import Pool, cpu_count | |
from datasets import load_dataset | |
from PIL import Image | |
import gradio as gr | |
import pandas as pd | |
# Validation split of ImageNet-Hard; downloaded/cached by HF datasets on
# first run and indexed by the preprocessing workers below.
imagenet_hard_dataset = load_dataset("taesiri/imagenet-hard", split="validation")

# Destination folder for the resized thumbnail JPEGs.
THUMBNAIL_PATH = "dataset/thumbnails"
os.makedirs(THUMBNAIL_PATH, exist_ok=True)

# Upper bound (width, height) passed to PIL's Image.thumbnail.
max_size = (480, 480)

# Filter vocabularies, filled during preprocessing: the distinct source
# datasets and the distinct category labels seen across all samples.
all_origins = set()
all_labels = set()

# Metadata table for the browser; populated below from the on-disk cache
# or rebuilt from scratch.
dataset_df = None

# Display names for each raw origin identifier stored in the dataset.
beautiful_dataset_names = {
    "imagenet": "ImageNet",
    "imagenet_a": "ImageNet-A",
    "imagenet_r": "ImageNet-R",
    "imagenet_sketch": "ImageNet-Sketch",
    "objectnet": "ObjectNet",
    "imagenet_v2": "ImageNet-V2",
}
def process_image(i):
    """Thumbnail sample *i* and return its metadata record.

    Reads one sample from the module-level `imagenet_hard_dataset`,
    writes a resized JPEG to THUMBNAIL_PATH/<origin>/<i>.jpg, and
    returns the dict later assembled into the browser DataFrame.

    Args:
        i: integer index into the dataset split.

    Returns:
        dict with keys "preview" (public URL), "filepath" (local path),
        "origin" (source-dataset id) and "labels" (english labels).
    """
    # Fetch the sample exactly once: indexing the HF dataset decodes the
    # image, so the original's repeated `imagenet_hard_dataset[i]`
    # lookups redid that work. (Also dropped the dead `global
    # all_origins` — it was never assigned, and a Pool worker could not
    # mutate the parent's set anyway.)
    sample = imagenet_hard_dataset[i]
    image = sample["image"].convert("RGB")
    origin = sample["origin"]
    label = sample["english_label"]

    url_prefix = "https://imagenet-hard.taesiri.ai/"
    save_path = os.path.join(THUMBNAIL_PATH, origin)
    # One folder per source dataset; workers may race, hence exist_ok.
    os.makedirs(save_path, exist_ok=True)
    image_path = os.path.join(save_path, f"{i}.jpg")

    # Resize in place (aspect ratio preserved), then save at max quality.
    image.thumbnail(max_size, Image.LANCZOS)
    image.save(image_path, "JPEG", quality=100)

    return {
        "preview": url_prefix + image_path,
        "filepath": image_path,
        "origin": origin,
        "labels": label,
    }
# PREPROCESSING: load the cached metadata table, or build it once.
if os.path.exists("dataset.pkl"):
    # Cached run: restore the DataFrame and rebuild the filter sets.
    dataset_df = pd.read_pickle("dataset.pkl")
    all_origins = set(dataset_df["origin"])
    all_labels = set().union(*dataset_df["labels"])
else:
    # First run: thumbnail every sample in parallel. NOTE(review): this
    # relies on fork-style multiprocessing so workers inherit the
    # module-level dataset — confirm if the app must run on spawn
    # platforms (Windows/macOS defaults).
    with Pool(cpu_count()) as pool:
        samples_data = pool.map(process_image, range(len(imagenet_hard_dataset)))

    dataset_df = pd.DataFrame(samples_data)
    print(dataset_df)

    all_origins = set(dataset_df["origin"])
    all_labels = set().union(*dataset_df["labels"])

    # Persist a human-readable CSV plus a fast-reload pickle cache.
    dataset_df.to_csv("dataset.csv")
    dataset_df.to_pickle("dataset.pkl")
def get_values_for_the_slice(slice_df, grid_size=16):
    """Build the gradio updates (image + caption) for one page of the grid.

    Args:
        slice_df: DataFrame page with "preview", "origin" and "labels"
            columns (labels is an iterable of label strings).
        grid_size: number of grid cells to fill (default 16, matching the
            4x4 grid); missing trailing cells are padded with None so the
            number of outputs always matches the grid components.

    Returns:
        List of length `grid_size` of gr.update values (None for padding).
    """
    returned_values = []
    for row in slice_df.itertuples():
        # Human-readable caption: comma-joined labels, underscores spaced.
        labels = ", ".join(row.labels).replace("_", " ")
        # Fall back to the raw origin id rather than raising KeyError if
        # an origin outside beautiful_dataset_names ever appears.
        dataset_name = beautiful_dataset_names.get(row.origin, row.origin)
        label_string = f"{labels} - ({dataset_name})"
        returned_values.append(gr.update(label=label_string, value=row.preview))
    # Pad short pages so every output component still receives a value.
    if len(returned_values) < grid_size:
        returned_values.extend([None] * (grid_size - len(returned_values)))
    return returned_values
def get_slice(origin, label):
    """Filter the dataset by origin and/or label and show page 0.

    Args:
        origin: source-dataset id to keep, or falsy for no origin filter.
        label: category label to keep, or falsy for no label filter.

    Returns:
        (filtered gr.Dataframe, slider update with the new page range,
        followed by the 16 grid-cell updates for the first page).
    """
    global dataset_df

    if not origin and not label:
        filtered_df = dataset_df
    else:
        # A bare `True` acts as a no-op mask when one filter is unset
        # (pandas broadcasts it in the elementwise `&`).
        filtered_df = dataset_df[
            (dataset_df["origin"] == origin if origin else True)
            & (dataset_df["labels"].apply(lambda x: label in x) if label else True)
        ]

    # Last valid zero-based page. (len - 1) // 16 fixes the original
    # len // 16, which produced an empty trailing page whenever the row
    # count was an exact multiple of 16; max(..., 0) guards empty results.
    max_value = max((len(filtered_df) - 1) // 16, 0)

    # First page of results.
    slice_df = filtered_df.iloc[0:16]
    returned_values = get_values_for_the_slice(slice_df)

    filtered_df = gr.Dataframe(filtered_df, datatype="markdown")
    return filtered_df, gr.update(maximum=max_value, value=0), *returned_values
def reset_filters_fn():
    """Return updates that clear the origin and category dropdowns."""
    cleared_origin = gr.update(value=None)
    cleared_label = gr.update(value=None)
    return cleared_origin, cleared_label
def make_grid(grid_size):
    """Lay out a grid of gr.Image components inside the current Blocks.

    Args:
        grid_size: (rows, cols) tuple; one gr.Column is created per
            `rows` count and filled with `cols` images.

    Returns:
        Flat list of the created gr.Image components, in creation order
        (column by column, top to bottom).
    """
    rows, cols = grid_size
    list_of_components = []
    with gr.Row():
        for _ in range(rows):
            with gr.Column():
                for _ in range(cols):
                    list_of_components.append(gr.Image())
    return list_of_components
def slider_upadte(slider, df):
    """Return the 16 grid updates for the page selected on the slider.

    NOTE(review): the name keeps its original typo ("upadte") because
    the .change() wiring below references it by this name.
    """
    page_start = slider * 16
    page = df.iloc[page_start : page_start + 16]
    return get_values_for_the_slice(page)
with gr.Blocks() as demo:
    gr.Markdown("# ImageNet-Hard Browser")

    # Links to the project home page and the dataset card. (Two empty
    # placeholder gr.HTML components were removed — they rendered
    # nothing and served no purpose.)
    gr.HTML(
        """
    <center>
    <span style="font-size: 14px; vertical-align: middle;">
    <a href='https://zoom.taesiri.ai/'>Project Home Page</a> |
    <a href='https://huggingface.co/datasets/taesiri/imagenet-hard'>Dataset</a>
    </span>
    </center>
    """
    )

    # Filter controls: source dataset and category label.
    with gr.Row():
        origin_dropdown = gr.Dropdown(all_origins, label="Origin")
        label_dropdown = gr.Dropdown(all_labels, label="Category")
    with gr.Row():
        show_btn = gr.Button("Show")
        reset_filters = gr.Button("Reset Filters")

    # Hidden holder for the currently filtered DataFrame; fed back into
    # slider_upadte when the user changes pages.
    preview_dataframe = gr.Dataframe(visible=False)

    gr.Markdown("## Preview")
    # Pages are zero-based: get_slice resets the slider to 0 and
    # slider_upadte computes start = page * 16, so minimum/value must be
    # 0 (the original minimum=1, value=1 skipped the first page and put
    # the reset value below the slider's own minimum).
    maximum_value = len(dataset_df) // 16
    preview_slider = gr.Slider(minimum=0, maximum=maximum_value, step=1, value=0)

    all_components = make_grid((4, 4))

    # Show: apply the filters, refresh the slider range, fill the grid.
    show_btn.click(
        fn=get_slice,
        inputs=[origin_dropdown, label_dropdown],
        outputs=[preview_dataframe, preview_slider, *all_components],
    )
    # Reset: clear both dropdowns.
    reset_filters.click(
        fn=reset_filters_fn,
        inputs=[],
        outputs=[origin_dropdown, label_dropdown],
    )
    # Page change: re-render the grid from the stored filtered frame.
    preview_slider.change(
        fn=slider_upadte,
        inputs=[preview_slider, preview_dataframe],
        outputs=[*all_components],
    )

demo.launch()