import gradio as gr
import torch
from PIL import Image
import os
import numpy as np
import matplotlib.pyplot as plt
from transformers import CLIPProcessor, CLIPModel, AutoProcessor, AutoModel

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load models and processors
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
dino_model = AutoModel.from_pretrained("facebook/dinov2-base").to(device)
dino_processor = AutoProcessor.from_pretrained("facebook/dinov2-base")


def get_image_embedding(image, model, processor, model_type):
    """Return an L2-normalized embedding for a single image (file path or PIL.Image)."""
    if isinstance(image, str):  # Handle file-path input
        image = Image.open(image)
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        if model_type == "clip":
            embedding = model.get_image_features(**inputs)
        elif model_type == "dinov2":
            outputs = model(**inputs)
            embedding = outputs.last_hidden_state.mean(dim=1)  # Global pooling over patch tokens
    embedding /= embedding.norm(dim=-1, keepdim=True)  # Normalize so dot products are cosine similarities
    return embedding


def retrieve_images(query_img, gallery_imgs, model, processor, model_type, top_k=10):
    """Rank gallery images by cosine similarity to the query and return the top-k."""
    query_embedding = get_image_embedding(query_img, model, processor, model_type)
    gallery_embeddings = []
    for img in gallery_imgs:
        emb = get_image_embedding(img, model, processor, model_type)
        gallery_embeddings.append((emb, img))
    rank_list = []
    for emb, img in gallery_embeddings:
        similarity_score = (query_embedding @ emb.T).item()
        rank_list.append((similarity_score, img))
    rank_list = sorted(rank_list, key=lambda x: x[0], reverse=True)[:top_k]
    return [img for _, img in rank_list]


def display_results(query_img, gallery_imgs, top_k):
    """Run retrieval with both models and prepend the query image to each result list."""
    clip_results = retrieve_images(query_img, gallery_imgs, clip_model, clip_processor, "clip", top_k)
    dino_results = retrieve_images(query_img, gallery_imgs, dino_model, dino_processor, "dinov2", top_k)
    return [query_img] + clip_results, [query_img] + dino_results


def gradio_interface(query_img, gallery_imgs, top_k):
    if not isinstance(gallery_imgs, list):
        gallery_imgs = [gallery_imgs]
    # gr.File can return file objects / NamedStrings; reduce everything to plain paths
    gallery_imgs = [img.name if hasattr(img, "name") else img for img in gallery_imgs]
    clip_res, dino_res = display_results(query_img, gallery_imgs, top_k)
    return clip_res, dino_res


# Build example inputs from the bundled dataset
gallery_path = "dataset/gallery"
filenames = os.listdir(gallery_path)
flag_filenames = [filename for filename in filenames if "flag" in filename]
tattoo_filenames = [filename for filename in filenames if "tattoo" in filename]
gallery_examples_flags = [os.path.join(gallery_path, filename) for filename in flag_filenames]
gallery_examples_tattoos = [os.path.join(gallery_path, filename) for filename in tattoo_filenames]
query_examples = ["dataset/query/american_flag46.jpg", "dataset/query/bird.jpg"]

print(gallery_examples_flags)
print(gallery_examples_tattoos)

demo = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Image(type="pil", label="Query Image"),
        gr.File(file_types=["image"], label="Gallery Images", file_count="multiple", elem_id="gallery-files"),
        gr.Slider(1, 30, value=10, step=1, label="Top-K Matches"),
    ],
    outputs=[
        gr.Gallery(label="CLIP Retrieval Results", elem_id="clip-results", rows=1, columns=30),
        gr.Gallery(label="DINOv2 Retrieval Results", elem_id="dino-results", rows=1, columns=30),
    ],
    title="CLIP vs DINOv2 Image Retrieval",
    description="Upload a query image and gallery images to see the top-k retrieval results side by side using CLIP and DINOv2.",
    examples=[
        [query_examples[0], gallery_examples_flags, 10],
        [query_examples[1], gallery_examples_tattoos, 10],
    ],
    css="""
    #gallery-files { max-height: 150px; overflow-y: scroll; }
    #clip-results, #dino-results { max-height: 150px; }
    """,
)

demo.launch(share=True)
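
# Optional sanity check (a minimal sketch, not part of the app): the retrieval helper can
# be exercised directly from a Python shell without launching the Gradio UI. The paths
# below reuse the example dataset layout assumed earlier in this script; adjust them to
# your own files if your directory structure differs.
#
# query = "dataset/query/american_flag46.jpg"
# gallery = [os.path.join(gallery_path, f) for f in flag_filenames]
# print("CLIP top-5:", retrieve_images(query, gallery, clip_model, clip_processor, "clip", top_k=5))
# print("DINOv2 top-5:", retrieve_images(query, gallery, dino_model, dino_processor, "dinov2", top_k=5))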