import os from io import BytesIO from multiprocessing import Pool, cpu_count from datasets import load_dataset from PIL import Image import gradio as gr import pandas as pd imagenet_hard_dataset = load_dataset("taesiri/imagenet-hard", split="validation") THUMBNAIL_PATH = "dataset/thumbnails" os.makedirs(THUMBNAIL_PATH, exist_ok=True) max_size = (480, 480) all_origins = set() all_labels = set() dataset_df = None def process_image(i): global all_origins image = imagenet_hard_dataset[i]["image"].convert("RGB") url_prefix = "https://imagenet-hard.taesiri.ai/" origin = imagenet_hard_dataset[i]["origin"] label = imagenet_hard_dataset[i]["english_label"] save_path = os.path.join(THUMBNAIL_PATH, origin) # make sure the folder exists os.makedirs(save_path, exist_ok=True) image_path = os.path.join(save_path, f"{i}.jpg") image.thumbnail(max_size, Image.LANCZOS) image.save(image_path, "JPEG", quality=100) url = url_prefix + image_path return { "preview": url, "filepath": image_path, "origin": imagenet_hard_dataset[i]["origin"], "labels": imagenet_hard_dataset[i]["english_label"], } # PREPROCESSING if os.path.exists("dataset.pkl"): dataset_df = pd.read_pickle("dataset.pkl") all_origins = set(dataset_df["origin"]) all_labels = set().union(*dataset_df["labels"]) else: with Pool(cpu_count()) as pool: samples_data = pool.map(process_image, range(len(imagenet_hard_dataset))) dataset_df = pd.DataFrame(samples_data) print(dataset_df) all_origins = set(dataset_df["origin"]) all_labels = set().union(*dataset_df["labels"]) # save dataframe on disk dataset_df.to_csv("dataset.csv") dataset_df.to_pickle("dataset.pkl") def get_slice(origin, label): global dataset_df if not origin and not label: filtered_df = dataset_df else: filtered_df = dataset_df[ (dataset_df["origin"] == origin if origin else True) & (dataset_df["labels"].apply(lambda x: label in x) if label else True) ] max_value = len(filtered_df) // 16 returned_values = [] start_index = 0 end_index = start_index + 16 slice_df = filtered_df.iloc[start_index:end_index] for row in slice_df.itertuples(): returned_values.append(gr.update(value=row.preview)) returned_values.append(gr.update(value=row.origin)) returned_values.append(gr.update(value=row.labels)) if len(returned_values) < 48: returned_values.extend([None] * (48 - len(returned_values))) filtered_df = gr.Dataframe(filtered_df, datatype="markdown") return filtered_df, gr.update(maximum=max_value, value=0), *returned_values def reset_filters_fn(): return gr.update(value=None), gr.update(value=None) def make_grid(grid_size): list_of_components = [] with gr.Row(): for row_counter in range(grid_size[0]): with gr.Column(): for col_counter in range(grid_size[1]): item_image = gr.Image() with gr.Accordion("Click for details", open=False): item_source = gr.Textbox(label="Source Dataset") item_labels = gr.Textbox(label="Labels") list_of_components.append(item_image) list_of_components.append(item_source) list_of_components.append(item_labels) return list_of_components def slider_upadte(slider, df): returned_values = [] start_index = (slider) * 16 end_index = start_index + 16 slice_df = df.iloc[start_index:end_index] for row in slice_df.itertuples(): returned_values.append(gr.update(value=row.preview)) returned_values.append(gr.update(value=row.origin)) returned_values.append(gr.update(value=row.labels)) if len(returned_values) < 48: returned_values.extend([None] * (48 - len(returned_values))) return returned_values with gr.Blocks() as demo: gr.Markdown("# ImageNet-Hard Browser") # add link to home page and dataset gr.HTML("") gr.HTML() gr.HTML( """