# Hugging Face Space status: Runtime error
import os | |
import torch | |
import gradio as gr | |
import requests | |
from PIL import Image | |
from io import BytesIO | |
from qdrant_client import QdrantClient | |
from transformers import ColQwen2, ColQwen2Processor | |
# --- ColQwen2 (ColPali-family) retrieval model ---------------------------------
# NOTE(review): `transformers` does not appear to export a class named
# `ColQwen2`; the `ColQwen2` / `ColQwen2Processor` pair (with
# `process_queries`) comes from the ColPali project's `colpali_engine.models`.
# Confirm the import at the top of the file — it is a plausible cause of the
# Space's "Runtime error" status.
model_name = "vidore/colqwen2-v0.1"

# Use the first CUDA GPU when present, otherwise run on CPU.
# (Set this to "mps" to target Apple Silicon instead.)
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Weights are loaded in bfloat16 and placed on `device`.
colpali_model = ColQwen2.from_pretrained(
    model_name,
    device_map=device,
    torch_dtype=torch.bfloat16,
)
# Processor that turns raw query strings into model inputs.
colpali_processor = ColQwen2Processor.from_pretrained(model_name)
# --- Qdrant vector store -------------------------------------------------------
# API key from the environment (e.g. a Space secret); os.getenv returns None
# when the variable is unset.
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")

# port=None: presumably so the client does not append Qdrant's default port to
# this HTTPS URL and talks to the Space on its standard port — confirm.
# timeout=10 is the client-wide request timeout in seconds.
qdrant_client = QdrantClient(
    url="https://davanstrien-qdrant-test.hf.space",
    port=None,
    api_key=QDRANT_API_KEY,
    timeout=10,
)

# Collection to search. The literal default is only a placeholder; set the
# QDRANT_COLLECTION_NAME environment variable to point the app at a real
# collection without editing the code.
collection_name = os.getenv("QDRANT_COLLECTION_NAME", "your_collection_name")
def search_images_by_text(query_text, top_k=5):
    """Embed *query_text* with the ColQwen2 model and search Qdrant with it.

    Args:
        query_text: Free-text search query.
        top_k: Maximum number of points to ask Qdrant for.

    Returns:
        The response of ``qdrant_client.query_points``; callers read its
        ``.points`` attribute.
    """
    # Inference only — no autograd graph is needed for the query encoder.
    with torch.no_grad():
        inputs = colpali_processor.process_queries([query_text])
        inputs = inputs.to(colpali_model.device)
        embeddings = colpali_model(**inputs)

    # A single query was encoded, so take row 0 and convert it to a plain
    # nested list of floats for Qdrant's multi-vector query (presumably one
    # vector per query token — confirm against the model output shape).
    query_vectors = embeddings[0].cpu().float().numpy().tolist()

    # NOTE(review): `timeout=800` here is a per-request limit, but the client
    # was built with `timeout=10`; the shorter one may fire first on slow
    # searches — confirm which value is intended.
    return qdrant_client.query_points(
        collection_name=collection_name,
        query=query_vectors,
        limit=top_k,
        timeout=800,
    )
def modify_iiif_url(url, width, height):
    """Return *url* with its IIIF size segment replaced by ``"{width},{height}"``.

    IIIF Image API URLs end in ``.../{region}/{size}/{rotation}/{quality}.{format}``,
    so the size is the third ``/``-separated segment from the end.

    Args:
        url: A IIIF Image API image URL.
        width: Target width in pixels.
        height: Target height in pixels.

    Returns:
        The URL with only the size segment changed.
    """
    parts = url.split('/')
    size_index = -3
    # IIIF size values must be integers. Round-and-cast so a float input such
    # as 300.0 (e.g. from a UI slider) yields "300", not the invalid "300.0".
    parts[size_index] = f"{int(round(width))},{int(round(height))}"
    return '/'.join(parts)
def search_and_display(query, top_k, width, height):
    """Gradio callback: search for *query* and fetch each hit's image, resized.

    Args:
        query: Text query forwarded to ``search_images_by_text``.
        top_k: Number of results to request.
        width: Target pixel width written into each IIIF image URL.
        height: Target pixel height written into each IIIF image URL.

    Returns:
        ``(images, captions)``: a list of RGB ``PIL.Image`` objects for the
        gallery and a parallel list of ``"Score: …"`` strings.

    Raises:
        requests.HTTPError: if an image server answers with an error status.
        requests.Timeout: if an image server does not respond in time.
    """
    results = search_images_by_text(query, top_k)
    images = []
    captions = []
    for result in results.points:
        modified_url = modify_iiif_url(result.payload['image_url'], width, height)
        # requests has no default timeout; without one, a single unresponsive
        # image server would hang the whole callback indefinitely.
        response = requests.get(modified_url, timeout=30)
        # Fail with a clear HTTP error instead of handing an HTML error page
        # to PIL, which would raise a confusing "cannot identify image" error.
        response.raise_for_status()
        img = Image.open(BytesIO(response.content)).convert("RGB")
        images.append(img)
        captions.append(f"Score: {result.score:.2f}")
    return images, captions
# --- Gradio UI -----------------------------------------------------------------
# Input components, passed positionally to search_and_display as
# (query, top_k, width, height).
query_box = gr.Textbox(label="Search Query")
top_k_slider = gr.Slider(minimum=1, maximum=20, step=1, label="Number of Results", value=5)
width_slider = gr.Slider(minimum=100, maximum=1000, step=50, label="Image Width", value=300)
height_slider = gr.Slider(minimum=100, maximum=1000, step=50, label="Image Height", value=300)

# Output components, filled from the (images, captions) tuple the callback returns.
results_gallery = gr.Gallery(label="Search Results", show_label=False, columns=5, height="auto")
captions_json = gr.JSON(label="Captions")

iface = gr.Interface(
    fn=search_and_display,
    inputs=[query_box, top_k_slider, width_slider, height_slider],
    outputs=[results_gallery, captions_json],
    title="Image Search with IIIF Resizing",
    description="Enter a text query to search for images. You can adjust the number of results and the size of the returned images.",
)

# Start the Gradio app.
iface.launch()