File size: 2,912 Bytes
31ab17b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os
import torch
import gradio as gr
import requests
from PIL import Image
from io import BytesIO
from qdrant_client import QdrantClient
from transformers import ColQwen2, ColQwen2Processor

# Load the ColQwen2 retrieval model and its matching processor.
# (The variables keep their historical "colpali_" prefix because the rest of
# the file refers to them by these names.)
model_name = "vidore/colqwen2-v0.1"
# Use the first CUDA GPU when one is available, otherwise fall back to CPU.
# You can change this to "mps" for Apple Silicon if needed.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
colpali_model = ColQwen2.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map=device,
)
colpali_processor = ColQwen2Processor.from_pretrained(
    model_name,
)

# Initialize Qdrant client. The API key is read from the environment and is
# None when QDRANT_API_KEY is unset.
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
qdrant_client = QdrantClient(url="https://davanstrien-qdrant-test.hf.space",
                             port=None, api_key=QDRANT_API_KEY, timeout=10)

# Name of the Qdrant collection to search. Set QDRANT_COLLECTION_NAME in the
# environment to point at a real collection without editing this file; the
# placeholder default is kept for backward compatibility.
collection_name = os.getenv("QDRANT_COLLECTION_NAME", "your_collection_name")

def search_images_by_text(query_text, top_k=5):
    """Embed a text query and use it to query the Qdrant collection.

    Args:
        query_text: The search string to embed.
        top_k: Maximum number of points requested from Qdrant (default 5).

    Returns:
        Whatever ``qdrant_client.query_points`` returns; the caller in this
        file reads its ``points`` attribute.
    """
    # Inference only: run the processor hand-off and the forward pass without
    # gradient tracking.
    with torch.no_grad():
        encoded_query = colpali_processor.process_queries([query_text])
        encoded_query = encoded_query.to(colpali_model.device)
        model_output = colpali_model(**encoded_query)

    # A single query was submitted, so take element 0 of the output
    # (presumably that query's multi-vector embedding -- confirm against the
    # model's return type). It is moved to CPU and cast to float32 before the
    # NumPy conversion, then turned into plain nested lists for the client.
    first_output = model_output[0]
    query_vectors = first_output.cpu().float().numpy().tolist()

    # This per-request timeout (800) is passed explicitly on the query call.
    return qdrant_client.query_points(
        collection_name=collection_name,
        query=query_vectors,
        limit=top_k,
        timeout=800,
    )

def modify_iiif_url(url, width, height):
    """Return *url* with its size segment replaced by ``"{width},{height}"``.

    The URL is assumed to follow the IIIF Image API layout
    ``.../{region}/{size}/{rotation}/{quality}.{format}``, so the size is the
    third ``/``-separated segment from the end.

    Args:
        url: IIIF image URL to rewrite.
        width: Requested width. Integral floats (e.g. ``300.0``) are
            rendered as integers because the IIIF size segment takes whole
            pixel counts; any other value is formatted unchanged.
        height: Requested height, handled the same way as *width*.

    Returns:
        The rewritten URL string.

    Raises:
        IndexError: If *url* has fewer than three ``/``-separated segments.
    """
    def _as_iiif_dimension(value):
        # UI widgets may hand over 300.0 instead of 300; "300.0,300.0" is not
        # a valid IIIF size, so collapse integral floats to ints. NaN, inf and
        # non-float values fall through untouched.
        if isinstance(value, float) and value.is_integer():
            return int(value)
        return value

    segments = url.split('/')
    segments[-3] = f"{_as_iiif_dimension(width)},{_as_iiif_dimension(height)}"
    return '/'.join(segments)

# Seconds to wait on each image request before giving up, so an unresponsive
# image server cannot hang the whole search.
_IMAGE_REQUEST_TIMEOUT = 30


def search_and_display(query, top_k, width, height):
    """Run a text search and download the matching images at the given size.

    Args:
        query: Text query forwarded to ``search_images_by_text``.
        top_k: Maximum number of results to request.
        width: Width written into each result's IIIF URL.
        height: Height written into each result's IIIF URL.

    Returns:
        A ``(images, captions)`` tuple of equal-length lists: RGB
        ``PIL.Image`` objects and strings of the form ``"Score: 0.12"``,
        in the order the search returned them.

    Raises:
        requests.RequestException: If an image request times out, fails to
            connect, or returns an HTTP error status.
    """
    results = search_images_by_text(query, top_k)
    images = []
    captions = []

    # A single session lets consecutive downloads reuse connections.
    with requests.Session() as session:
        for point in results.points:
            image_url = modify_iiif_url(point.payload['image_url'], width, height)
            response = session.get(image_url, timeout=_IMAGE_REQUEST_TIMEOUT)
            # Fail with a clear HTTP error instead of letting PIL choke on
            # the body of an error page.
            response.raise_for_status()
            images.append(Image.open(BytesIO(response.content)).convert("RGB"))
            captions.append(f"Score: {point.score:.2f}")

    return images, captions

# Define Gradio interface
def _build_interface():
    """Create the Gradio interface that drives ``search_and_display``.

    The four inputs map, in order, onto the function's ``query``, ``top_k``,
    ``width`` and ``height`` parameters; its two return values feed the
    gallery and the JSON panel respectively.
    """
    query_box = gr.Textbox(label="Search Query")
    result_count = gr.Slider(minimum=1, maximum=20, step=1, label="Number of Results", value=5)
    image_width = gr.Slider(minimum=100, maximum=1000, step=50, label="Image Width", value=300)
    image_height = gr.Slider(minimum=100, maximum=1000, step=50, label="Image Height", value=300)

    gallery = gr.Gallery(label="Search Results", show_label=False, columns=5, height="auto")
    captions_panel = gr.JSON(label="Captions")

    return gr.Interface(
        fn=search_and_display,
        inputs=[query_box, result_count, image_width, image_height],
        outputs=[gallery, captions_panel],
        title="Image Search with IIIF Resizing",
        description="Enter a text query to search for images. You can adjust the number of results and the size of the returned images.",
    )


iface = _build_interface()

# Launch the Gradio interface
iface.launch()