import os
from io import BytesIO
from multiprocessing import Pool, cpu_count
from datasets import load_dataset
from PIL import Image
import gradio as gr
import pandas as pd
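
# Load the ImageNet-Hard validation split from the Hugging Face Hub
# (downloaded and cached locally on first run).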
imagenet_hard_dataset = load_dataset("taesiri/imagenet-hard", split="validation")
THUMBNAIL_PATH = "dataset/thumbnails"
os.makedirs(THUMBNAIL_PATH, exist_ok=True)
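
# Thumbnails are shrunk in place to fit within this bounding box before being saved.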
max_size = (480, 480)
all_origins = set()
all_labels = set()
dataset_df = None


def process_image(i):
    """Create a thumbnail for sample `i` and return its metadata row."""
    image = imagenet_hard_dataset[i]["image"].convert("RGB")
    url_prefix = "https://imagenet-hard.taesiri.ai/"
    origin = imagenet_hard_dataset[i]["origin"]
    label = imagenet_hard_dataset[i]["english_label"]
    save_path = os.path.join(THUMBNAIL_PATH, origin)
    # make sure the folder exists
    os.makedirs(save_path, exist_ok=True)
    image_path = os.path.join(save_path, f"{i}.jpg")
    image.thumbnail(max_size, Image.LANCZOS)
    image.save(image_path, "JPEG", quality=100)
    url = url_prefix + image_path
    return {
        "preview": url,
        "filepath": image_path,
        "origin": origin,
        "labels": label,
    }

# PREPROCESSING
if os.path.exists("dataset.pkl"):
    dataset_df = pd.read_pickle("dataset.pkl")
    all_origins = set(dataset_df["origin"])
    all_labels = set().union(*dataset_df["labels"])
else:
    with Pool(cpu_count()) as pool:
        samples_data = pool.map(process_image, range(len(imagenet_hard_dataset)))

    dataset_df = pd.DataFrame(samples_data)
    print(dataset_df)
    all_origins = set(dataset_df["origin"])
    all_labels = set().union(*dataset_df["labels"])
    # save dataframe on disk
    dataset_df.to_csv("dataset.csv")
    dataset_df.to_pickle("dataset.pkl")
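
# dataset_df now holds one row per image: thumbnail URL, local file path,
# origin dataset, and English labels.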


def get_slice(origin, label):
    """Filter the table by origin and/or label and fill the first page of the grid."""
    global dataset_df
    if not origin and not label:
        filtered_df = dataset_df
    else:
        filtered_df = dataset_df[
            (dataset_df["origin"] == origin if origin else True)
            & (dataset_df["labels"].apply(lambda x: label in x) if label else True)
        ]
    max_value = len(filtered_df) // 16
    # Each of the 16 grid cells has three components (image, source, labels) -> 48 outputs.
    returned_values = []
    start_index = 0
    end_index = start_index + 16
    slice_df = filtered_df.iloc[start_index:end_index]
    for row in slice_df.itertuples():
        returned_values.append(gr.update(value=row.preview))
        returned_values.append(gr.update(value=row.origin))
        returned_values.append(gr.update(value=row.labels))
    if len(returned_values) < 48:
        returned_values.extend([None] * (48 - len(returned_values)))
    filtered_df = gr.Dataframe(filtered_df, datatype="markdown")
    return filtered_df, gr.update(maximum=max_value, value=0), *returned_values


def reset_filters_fn():
    return gr.update(value=None), gr.update(value=None)


def make_grid(grid_size):
    """Build a grid_size[0] x grid_size[1] grid of (image, source, labels) components."""
    list_of_components = []
    with gr.Row():
        for row_counter in range(grid_size[0]):
            with gr.Column():
                for col_counter in range(grid_size[1]):
                    item_image = gr.Image()
                    with gr.Accordion("Click for details", open=False):
                        item_source = gr.Textbox(label="Source Dataset")
                        item_labels = gr.Textbox(label="Labels")
                    list_of_components.append(item_image)
                    list_of_components.append(item_source)
                    list_of_components.append(item_labels)
    return list_of_components


def slider_update(slider, df):
    """Fill the grid with the page selected by the slider."""
    returned_values = []
    start_index = int(slider) * 16
    end_index = start_index + 16
    slice_df = df.iloc[start_index:end_index]
    for row in slice_df.itertuples():
        returned_values.append(gr.update(value=row.preview))
        returned_values.append(gr.update(value=row.origin))
        returned_values.append(gr.update(value=row.labels))
    if len(returned_values) < 48:
        returned_values.extend([None] * (48 - len(returned_values)))
    return returned_values
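
# UI layout: origin/label filters, a hidden dataframe holding the current filtered
# slice, a page slider, and a 4x4 preview grid (48 output components in total).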


with gr.Blocks() as demo:
    gr.Markdown("# ImageNet-Hard Browser")
    # links to the project home page and the dataset
    gr.HTML(
        """
        <center>
        <span style="font-size: 14px; vertical-align: middle;">
            <a href='https://zoom.taesiri.ai/'>Project Home Page</a> &nbsp;|&nbsp;
            <a href='https://huggingface.co/datasets/taesiri/imagenet-hard'>Dataset</a>
        </span>
        </center>
        """
    )
    with gr.Row():
        origin_dropdown = gr.Dropdown(sorted(all_origins), label="Origin")
        label_dropdown = gr.Dropdown(sorted(all_labels), label="Label")

    with gr.Row():
        show_btn = gr.Button("Show")
        reset_filters = gr.Button("Reset Filters")

    preview_dataframe = gr.Dataframe(height=500, visible=False)

    gr.Markdown("## Preview")
    maximum_value = len(dataset_df) // 16
    preview_slider = gr.Slider(minimum=0, maximum=maximum_value, step=1, value=0)
    all_components = make_grid((4, 4))
    show_btn.click(
        fn=get_slice,
        inputs=[origin_dropdown, label_dropdown],
        outputs=[preview_dataframe, preview_slider, *all_components],
    )

    reset_filters.click(
        fn=reset_filters_fn,
        inputs=[],
        outputs=[origin_dropdown, label_dropdown],
    )

    preview_slider.change(
        fn=slider_update,
        inputs=[preview_slider, preview_dataframe],
        outputs=[*all_components],
    )

demo.launch()