"""Gradio demo: fetch the snack images most similar to a query image.

At start-up the script loads a pre-built LSH index from ``lsh.pickle``, a
vision model used to embed query images, and the candidate image dataset.
It then serves a Gradio interface whose handler is :func:`query`.
"""

import os
import pickle
import tempfile

import gradio as gr
from datasets import load_dataset
from transformers import AutoModel

# `LSH` and `Table` imports are necessary in order for the
# `lsh.pickle` file to load successfully.
from similarity_utils import LSH, BuildLSHTable, Table

seed = 42

# Only runs once when the script is first run.
# NOTE: unpickling can execute arbitrary code, so `lsh.pickle` must come
# from a trusted source (it is expected to ship alongside this script).
with open("lsh.pickle", "rb") as handle:
    loaded_lsh = pickle.load(handle)

# Load model for computing embeddings.
model_ckpt = "google/vit-base-patch16-224"
# Alternative checkpoint:
# model_ckpt = "matteopilotto/vit-base-patch16-224-in21k-snacks"
model = AutoModel.from_pretrained(model_ckpt)

lsh_builder = BuildLSHTable(model)
lsh_builder.lsh = loaded_lsh

# Candidate images.
# NOTE(review): the ids looked up in `query` index into this shuffled split,
# so the seed presumably has to match the one used when `lsh.pickle` was
# built -- confirm before changing it.
dataset = load_dataset("Matthijs/snacks")
candidate_dataset = dataset["train"].shuffle(seed=seed)


def query(image, top_k):
    """Return the ``top_k`` best matches for ``image`` as gallery items.

    Args:
        image: Query image, passed straight to ``lsh_builder.query``. The
            interface below supplies it as a PIL image.
        top_k: Maximum number of matches to return. Coerced to ``int``
            because it comes from a slider widget.

    Returns:
        A list of ``(file_path, caption)`` tuples, best match first, in the
        format accepted by ``gr.Gallery``.
    """
    results = lsh_builder.query(image)

    # Keys look like "<image_id>_<label>"; rank them by their associated
    # value, highest first (presumably a similarity score).
    ranked = sorted(results, key=results.get, reverse=True)[: int(top_k)]
    if not ranked:
        return []

    # Save into a fresh directory per request. Fixed file names in the
    # working directory would let concurrent requests overwrite each
    # other's images before the gallery has read them.
    out_dir = tempfile.mkdtemp(prefix="similar_")

    gallery_items = []
    for rank, key in enumerate(ranked):
        parts = key.split("_")
        image_id, label = parts[0], parts[1]
        candidate = candidate_dataset[int(image_id)]["image"]

        filename = os.path.join(out_dir, f"similar_{rank}.png")
        candidate.save(filename)

        # The gallery component accepts a list of tuples, where the first
        # element is a path to a file and the second element is an optional
        # caption for that image.
        gallery_items.append((filename, f"Label: {label}"))

    return gallery_items


title = "Fetch Similar Snacks 🪴"


# You can set the type of gr.Image to be PIL, numpy or str (filepath).
# Not sure what the best for this demo is.
gr.Interface(
    query,
    inputs=[gr.Image(type="pil"), gr.Slider(value=5, minimum=1, maximum=10, step=1)],
    # Older styling variant, kept for reference:
    # outputs=gr.Gallery().style(grid=[3], height="auto"),
    outputs=gr.Gallery(),
    # Filenames denote the integer labels.
    # Know more here: https://hf.co/datasets/beans
    # NOTE(review): this link points at the beans dataset, but the demo
    # loads "Matthijs/snacks" -- confirm which dataset the example files
    # and their labels actually come from.
    title=title,
    examples=[["0.png", 5], ["1.png", 5], ["2.png", 5]],
).launch()