# Captured from a Hugging Face Space file view: the Space status read
# "Runtime error"; file size 2,912 bytes; revision 31ab17b.
import os
import torch
import gradio as gr
import requests
from PIL import Image
from io import BytesIO
from qdrant_client import QdrantClient
from transformers import ColQwen2, ColQwen2Processor
# Initialize ColPali model and processor.
# NOTE(review): `ColQwen2` / `ColQwen2Processor.process_queries` match the
# colpali_engine API (`from colpali_engine.models import ...`); confirm that the
# installed `transformers` really exports these names, since the import at the
# top of the file pulls them from `transformers`.
model_name = "vidore/colqwen2-v0.1"
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
# You can change this to "mps" for Apple Silicon if needed.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
# Weights are loaded in bfloat16 and placed on `device` via `device_map`.
colpali_model = ColQwen2.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, device_map=device
)
colpali_processor = ColQwen2Processor.from_pretrained(model_name)
# Initialize Qdrant client.
# Connection settings can be overridden through environment variables so the
# app can point at another Qdrant instance or collection without code edits;
# the defaults are the values this app was originally written against.
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
QDRANT_URL = os.getenv("QDRANT_URL", "https://davanstrien-qdrant-test.hf.space")
# NOTE(review): port=None is presumably there so the client uses the URL's own
# (HTTPS) port instead of Qdrant's default port; confirm against qdrant-client docs.
qdrant_client = QdrantClient(
    url=QDRANT_URL,
    port=None,
    api_key=QDRANT_API_KEY,
    timeout=10,
)
# "your_collection_name" is a placeholder: set QDRANT_COLLECTION (or change the
# default) to the real collection name.
collection_name = os.getenv("QDRANT_COLLECTION", "your_collection_name")
def search_images_by_text(query_text, top_k=5):
    """Embed *query_text* with the ColQwen2 model and query Qdrant with it.

    Args:
        query_text: Free-text search query.
        top_k: Maximum number of points to return (sent as Qdrant's ``limit``).

    Returns:
        The response of ``qdrant_client.query_points``; callers read its
        ``.points`` attribute.
    """
    # Inference only: no autograd graph is needed while encoding the query.
    with torch.no_grad():
        inputs = colpali_processor.process_queries([query_text])
        inputs = inputs.to(colpali_model.device)
        embeddings = colpali_model(**inputs)

    # The batch holds a single query, so take row 0 and turn its list of
    # vectors into plain Python lists for the Qdrant client. `.float()` upcasts
    # first; the model was loaded in bfloat16, which NumPy cannot represent.
    query_vectors = embeddings[0].cpu().float().numpy().tolist()

    return qdrant_client.query_points(
        collection_name=collection_name,
        query=query_vectors,
        limit=top_k,
        timeout=800,
    )
def modify_iiif_url(url, width, height, preserve_aspect=False):
    """Return *url* with the size segment of its IIIF Image API path replaced.

    IIIF image URLs end in ``.../{region}/{size}/{rotation}/{quality}.{format}``,
    so the size is the third ``/``-separated segment from the end.

    Args:
        url: IIIF Image API URL.
        width: Target width in pixels.
        height: Target height in pixels. Both dimensions are rounded to whole
            pixels because IIIF sizes must be integers; a float such as
            ``300.0`` would otherwise yield the invalid size ``300.0,300.0``.
        preserve_aspect: When true, emit the IIIF best-fit form ``!w,h``
            (scale to fit inside the box, keeping the aspect ratio) instead
            of ``w,h``, which resizes to exactly those dimensions.

    Returns:
        The rewritten URL.
    """
    segments = url.split('/')
    size_index = -3  # [-1] quality.format, [-2] rotation, [-3] size
    size = f"{round(float(width))},{round(float(height))}"
    if preserve_aspect:
        size = "!" + size
    segments[size_index] = size
    return '/'.join(segments)
def search_and_display(query, top_k, width, height):
    """Search for *query* and download each hit as a resized RGB image.

    Args:
        query: Text query passed to ``search_images_by_text``.
        top_k: Number of results to request. It is coerced to ``int`` because
            UI sliders may deliver numeric values as floats.
        width: Target width in pixels for the IIIF request (coerced to ``int``).
        height: Target height in pixels for the IIIF request (coerced to ``int``).

    Returns:
        ``(images, captions)``: parallel lists of PIL RGB images and
        ``"Score: ..."`` strings, one pair per hit whose image was fetched
        and decoded. A hit whose image cannot be downloaded or decoded is
        logged and skipped, so one broken URL does not fail the whole search.
    """
    import logging

    logger = logging.getLogger(__name__)

    results = search_images_by_text(query, int(top_k))
    images = []
    captions = []
    for point in results.points:
        image_url = modify_iiif_url(point.payload['image_url'], int(width), int(height))
        try:
            # Without a timeout, requests can block forever on a stalled server.
            response = requests.get(image_url, timeout=30)
            # Fail on HTTP errors instead of handing an error page to PIL.
            response.raise_for_status()
            img = Image.open(BytesIO(response.content)).convert("RGB")
        except (requests.RequestException, OSError) as err:
            # PIL's UnidentifiedImageError is a subclass of OSError.
            logger.warning("Skipping image %s: %s", image_url, err)
            continue
        images.append(img)
        captions.append(f"Score: {point.score:.2f}")
    return images, captions
# Define Gradio interface.
# The four inputs are passed positionally to search_and_display as
# (query, top_k, width, height); its (images, captions) return value fills the
# two outputs in order.
_search_inputs = [
    gr.Textbox(label="Search Query"),
    gr.Slider(minimum=1, maximum=20, step=1, label="Number of Results", value=5),
    gr.Slider(minimum=100, maximum=1000, step=50, label="Image Width", value=300),
    gr.Slider(minimum=100, maximum=1000, step=50, label="Image Height", value=300),
]
_search_outputs = [
    gr.Gallery(label="Search Results", show_label=False, columns=5, height="auto"),
    gr.JSON(label="Captions"),
]
iface = gr.Interface(
    fn=search_and_display,
    inputs=_search_inputs,
    outputs=_search_outputs,
    title="Image Search with IIIF Resizing",
    description=(
        "Enter a text query to search for images. You can adjust the number "
        "of results and the size of the returned images."
    ),
)
# Launch the Gradio interface: start the web server for `iface`.
iface.launch()