# Hugging Face Space status at capture time: Runtime error
import pickle | |
import gradio as gr | |
from datasets import load_dataset | |
from transformers import AutoModel | |
# `LSH` and `Table` imports are necessary in order for the | |
# `lsh.pickle` file to load successfully. | |
from similarity_utils import LSH, BuildLSHTable, Table | |
# Shuffle seed for the candidate split. `query` uses the image IDs returned by
# the LSH lookup as indices into the shuffled split, so this value presumably
# must match the one used when `lsh.pickle` was built — TODO confirm.
seed = 42

# Executed once, at import time. Unpickling only works because `LSH` and
# `Table` are imported at the top of the file.
# NOTE(review): `pickle.load` can execute arbitrary code; this is acceptable
# only because `lsh.pickle` ships with the app. Never load user-supplied
# files this way.
with open("lsh.pickle", "rb") as handle:
    loaded_lsh = pickle.load(handle)

# Embedding model wrapped by the LSH helper. The prebuilt hash table loaded
# above is attached to the helper rather than rebuilt from scratch.
# NOTE(review): assumes this is the checkpoint the table was built with. An
# alternative (snacks fine-tuned) checkpoint that was tried previously is
# "matteopilotto/vit-base-patch16-224-in21k-snacks".
model_ckpt = "google/vit-base-patch16-224"
model = AutoModel.from_pretrained(model_ckpt)
lsh_builder = BuildLSHTable(model)
lsh_builder.lsh = loaded_lsh

# Pool of candidate images returned by `query`: the "train" split of the
# snacks dataset, in a seed-determined shuffled order.
dataset = load_dataset("Matthijs/snacks")
candidate_dataset = dataset["train"].shuffle(seed=seed)
def query(image, top_k):
    """Return the `top_k` candidate images most similar to `image`.

    Args:
        image: The query image from the `gr.Image(type="pil")` input. It is
            `None` when the user submits without uploading anything.
        top_k: How many results to return. It comes from a Gradio slider and
            may arrive as a float, so it is converted with `int`.

    Returns:
        A list of `(image, caption)` tuples for `gr.Gallery`, ordered from the
        highest-scoring candidate down. The images are returned as objects
        rather than saved to shared file paths. Fixed filenames in the
        working directory let concurrent requests overwrite each other's
        results.
    """
    if image is None:
        return []

    # NOTE(review): assumes `lsh_builder.query` returns a mapping from
    # "<image_id>_<label>" keys to scores (it is sorted by `results.get`).
    # Confirm in similarity_utils.
    results = lsh_builder.query(image)
    ranked_keys = sorted(results, key=results.get, reverse=True)[: int(top_k)]

    gallery = []
    for key in ranked_keys:
        # Split on the first "_" only, so labels containing underscores stay
        # intact.
        image_id, _, label = key.partition("_")
        candidate = candidate_dataset[int(image_id)]["image"]
        gallery.append((candidate, f"Label: {label}"))
    return gallery
title = "Fetch Similar Snacks 🪴"

# `type="pil"` makes Gradio pass the uploaded image to `query` as a PIL image.
# NOTE(review): assumes `BuildLSHTable.query` accepts PIL input. Confirm in
# similarity_utils.
# The Interface is bound to `demo` so `gradio app.py` (reload mode) can find it.
# It is still launched at import time, as before.
demo = gr.Interface(
    query,
    inputs=[gr.Image(type="pil"), gr.Slider(value=5, minimum=1, maximum=10, step=1)],
    outputs=gr.Gallery(),
    title=title,
    # Bundled example query images, each run with top_k=5.
    examples=[["0.png", 5], ["1.png", 5], ["2.png", 5]],
)
demo.launch()