davanstrien HF staff commited on
Commit
31ab17b
·
verified ·
1 Parent(s): f08fe33

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +86 -0
app.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import gradio as gr
4
+ import requests
5
+ from PIL import Image
6
+ from io import BytesIO
7
+ from qdrant_client import QdrantClient
8
+ from transformers import ColQwen2, ColQwen2Processor
9
+
10
+ # Initialize ColPali model and processor
11
+ model_name = "vidore/colqwen2-v0.1"
12
+ device = "cuda:0" if torch.cuda.is_available() else "cpu" # You can change this to "mps" for Apple Silicon if needed
13
+ colpali_model = ColQwen2.from_pretrained(
14
+ model_name,
15
+ torch_dtype=torch.bfloat16,
16
+ device_map=device,
17
+ )
18
+ colpali_processor = ColQwen2Processor.from_pretrained(
19
+ model_name,
20
+ )
21
+
22
+ # Initialize Qdrant client
23
+ QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
24
+ qdrant_client = QdrantClient(url="https://davanstrien-qdrant-test.hf.space",
25
+ port=None, api_key=QDRANT_API_KEY, timeout=10)
26
+
27
+ collection_name = "your_collection_name" # Replace with your actual collection name
28
+
29
+ def search_images_by_text(query_text, top_k=5):
30
+ # Process and encode the text query
31
+ with torch.no_grad():
32
+ batch_query = colpali_processor.process_queries([query_text]).to(colpali_model.device)
33
+ query_embedding = colpali_model(**batch_query)
34
+
35
+ # Convert the query embedding to a list of vectors
36
+ multivector_query = query_embedding[0].cpu().float().numpy().tolist()
37
+
38
+ # Search in Qdrant
39
+ search_result = qdrant_client.query_points(
40
+ collection_name=collection_name,
41
+ query=multivector_query,
42
+ limit=top_k,
43
+ timeout=800,
44
+ )
45
+
46
+ return search_result
47
+
48
+ def modify_iiif_url(url, width, height):
49
+ parts = url.split('/')
50
+ size_index = -3
51
+ parts[size_index] = f"{width},{height}"
52
+ return '/'.join(parts)
53
+
54
+ def search_and_display(query, top_k, width, height):
55
+ results = search_images_by_text(query, top_k)
56
+ images = []
57
+ captions = []
58
+
59
+ for result in results.points:
60
+ modified_url = modify_iiif_url(result.payload['image_url'], width, height)
61
+ response = requests.get(modified_url)
62
+ img = Image.open(BytesIO(response.content)).convert("RGB")
63
+ images.append(img)
64
+ captions.append(f"Score: {result.score:.2f}")
65
+
66
+ return images, captions
67
+
68
+ # Define Gradio interface
69
+ iface = gr.Interface(
70
+ fn=search_and_display,
71
+ inputs=[
72
+ gr.Textbox(label="Search Query"),
73
+ gr.Slider(minimum=1, maximum=20, step=1, label="Number of Results", value=5),
74
+ gr.Slider(minimum=100, maximum=1000, step=50, label="Image Width", value=300),
75
+ gr.Slider(minimum=100, maximum=1000, step=50, label="Image Height", value=300)
76
+ ],
77
+ outputs=[
78
+ gr.Gallery(label="Search Results", show_label=False, columns=5, height="auto"),
79
+ gr.JSON(label="Captions")
80
+ ],
81
+ title="Image Search with IIIF Resizing",
82
+ description="Enter a text query to search for images. You can adjust the number of results and the size of the returned images."
83
+ )
84
+
85
+ # Launch the Gradio interface
86
+ iface.launch()