davanstrien's picture
davanstrien HF staff
Create app.py
31ab17b verified
raw
history blame
2.91 kB
import os
import torch
import gradio as gr
import requests
from PIL import Image
from io import BytesIO
from qdrant_client import QdrantClient
from transformers import ColQwen2, ColQwen2Processor
# Initialize the ColQwen2 retrieval model and its processor.
model_name = "vidore/colqwen2-v0.1"
# Use the first CUDA GPU when available, otherwise fall back to CPU.
# (You can change this to "mps" for Apple Silicon if needed.)
device = "cuda:0" if torch.cuda.is_available() else "cpu"
colpali_model = ColQwen2.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map=device,
)
colpali_processor = ColQwen2Processor.from_pretrained(model_name)

# Initialize the Qdrant client. The API key comes from the environment so it is
# never committed. The server URL and collection name can be overridden the same
# way; both fall back to the values this app was originally written against.
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
QDRANT_URL = os.getenv("QDRANT_URL", "https://davanstrien-qdrant-test.hf.space")
qdrant_client = QdrantClient(
    url=QDRANT_URL,
    # NOTE(review): presumably port=None makes the client use the URL's own
    # (HTTPS) port rather than appending Qdrant's default REST port -- confirm.
    port=None,
    api_key=QDRANT_API_KEY,
    timeout=10,
)
# Collection to search. Set QDRANT_COLLECTION rather than editing this file;
# the default is a placeholder that must be replaced for real use.
collection_name = os.getenv("QDRANT_COLLECTION", "your_collection_name")
def search_images_by_text(query_text, top_k=5):
    """Embed a text query with the ColQwen2 model and search Qdrant with it.

    Args:
        query_text: Free-text search query.
        top_k: Maximum number of matches to return. Coerced to ``int`` because
            UI widgets such as ``gr.Slider`` may deliver whole numbers as
            floats, and Qdrant's ``limit`` expects an integer.

    Returns:
        The response of ``qdrant_client.query_points``; its ``points``
        attribute holds the scored matches.
    """
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        batch_query = colpali_processor.process_queries([query_text]).to(colpali_model.device)
        query_embedding = colpali_model(**batch_query)

    # The batch holds exactly one query, so take entry 0: presumably a
    # (num_tokens, dim) tensor of per-token vectors -- TODO confirm. It is cast
    # to float32 before .numpy() because NumPy has no bfloat16 dtype (the model
    # is loaded in bfloat16), then converted to a plain list of vectors.
    multivector_query = query_embedding[0].cpu().float().numpy().tolist()

    # NOTE(review): the client itself is built with timeout=10 while this query
    # asks for timeout=800; confirm which bound is actually intended to apply.
    return qdrant_client.query_points(
        collection_name=collection_name,
        query=multivector_query,
        limit=int(top_k),
        timeout=800,
    )
def _iiif_dimension(value):
    """Render one IIIF size component, dropping a trailing ``.0`` from whole floats.

    Gradio sliders may hand back whole numbers as floats (``300.0``), which
    would otherwise yield an invalid IIIF size such as ``300.0,300.0``. Any
    other value (ints, strings such as ``""`` for the ``,h`` form) is rendered
    with ``str`` exactly as before.
    """
    if isinstance(value, float) and value.is_integer():
        return str(int(value))
    return str(value)


def modify_iiif_url(url, width, height):
    """Return ``url`` with its IIIF Image API size segment set to ``width,height``.

    IIIF Image API URLs end in ``.../{region}/{size}/{rotation}/{quality}.{format}``,
    so the size is the third path segment from the end.

    Note: the exact ``w,h`` form scales to precisely those pixels and may distort
    the aspect ratio. The IIIF ``!w,h`` form would fit inside the box instead.

    Args:
        url: A IIIF Image API request URL.
        width: Target width in pixels.
        height: Target height in pixels.

    Returns:
        The URL with only its size segment replaced.
    """
    segments = url.split("/")
    segments[-3] = f"{_iiif_dimension(width)},{_iiif_dimension(height)}"
    return "/".join(segments)
def search_and_display(query, top_k, width, height):
    """Search for images matching ``query`` and download resized copies of them.

    Args:
        query: Free-text search query.
        top_k: Number of results to retrieve.
        width: Target width in pixels, written into each IIIF URL.
        height: Target height in pixels, written into each IIIF URL.

    Returns:
        A ``(images, captions)`` pair of equal-length lists. ``images`` holds RGB
        PIL images and ``captions`` holds ``"Score: x.xx"`` strings, both in
        result order.

    Raises:
        requests.HTTPError: if an image server answers with an error status.
        requests.Timeout: if an image server does not respond in time.
    """
    results = search_images_by_text(query, top_k)
    images = []
    captions = []
    for result in results.points:
        modified_url = modify_iiif_url(result.payload['image_url'], width, height)
        # requests has no default timeout; without one a single unresponsive
        # IIIF server would hang this call indefinitely.
        response = requests.get(modified_url, timeout=30)
        # Surface HTTP errors directly instead of letting PIL fail on an HTML
        # error page with an unhelpful "cannot identify image file".
        response.raise_for_status()
        img = Image.open(BytesIO(response.content)).convert("RGB")
        images.append(img)
        captions.append(f"Score: {result.score:.2f}")
    return images, captions
# Gradio front end: a query box plus three sliders feed search_and_display
# (query, top_k, width, height). Its two return values populate the gallery
# and the captions panel, respectively.
iface = gr.Interface(
    title="Image Search with IIIF Resizing",
    description="Enter a text query to search for images. You can adjust the number of results and the size of the returned images.",
    fn=search_and_display,
    inputs=[
        gr.Textbox(label="Search Query"),
        # top_k: how many matches to fetch from Qdrant.
        gr.Slider(label="Number of Results", minimum=1, maximum=20, step=1, value=5),
        # Pixel width/height written into each IIIF URL's size segment.
        gr.Slider(label="Image Width", minimum=100, maximum=1000, step=50, value=300),
        gr.Slider(label="Image Height", minimum=100, maximum=1000, step=50, value=300),
    ],
    outputs=[
        gr.Gallery(label="Search Results", show_label=False, columns=5, height="auto"),
        gr.JSON(label="Captions"),
    ],
)

# Start the Gradio web app.
iface.launch()