Spaces:
Running
Running
as-cle-bert
commited on
Commit
•
626e10d
1
Parent(s):
2e814be
Create QdrantRag.py
Browse files- QdrantRag.py +212 -0
QdrantRag.py
ADDED
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from qdrant_client import QdrantClient, models
|
2 |
+
from sentence_transformers import SentenceTransformer
|
3 |
+
from transformers import AutoModel, AutoImageProcessor
|
4 |
+
import torch
|
5 |
+
import os
|
6 |
+
from datasets import load_dataset
|
7 |
+
from dotenv import load_dotenv
|
8 |
+
import numpy as np
|
9 |
+
import uuid
|
10 |
+
from PIL import Image, ImageFile
|
11 |
+
from fastembed import SparseTextEmbedding
|
12 |
+
import cohere
|
13 |
+
|
14 |
+
load_dotenv()
|
15 |
+
|
16 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
17 |
+
encoder = SentenceTransformer("sentence-transformers/LaBSE").to(device)
|
18 |
+
processor = AutoImageProcessor.from_pretrained('facebook/dinov2-large')
|
19 |
+
image_encoder = AutoModel.from_pretrained("facebook/dinov2-large").to(device)
|
20 |
+
qdrant_client = QdrantClient(url=os.getenv("qdrant_url"), api_key=os.getenv("qdrant_api_key"))
|
21 |
+
sparse_encoder = SparseTextEmbedding(model_name="prithivida/Splade_PP_en_v1")
|
22 |
+
co = cohere.ClientV2(os.getenv("cohere_api_key"))
|
23 |
+
|
24 |
+
dataset = load_dataset("Karbo31881/Pokemon_images")
|
25 |
+
ds = dataset["train"]
|
26 |
+
labels = ds["text"]
|
27 |
+
|
28 |
+
def get_sparse_embedding(text: str, model: SparseTextEmbedding):
|
29 |
+
embeddings = list(model.embed(text))
|
30 |
+
vector = {f"sparse-text": models.SparseVector(indices=embeddings[0].indices, values=embeddings[0].values)}
|
31 |
+
return vector
|
32 |
+
|
33 |
+
def get_query_sparse_embedding(text: str, model: SparseTextEmbedding):
|
34 |
+
embeddings = list(model.embed(text))
|
35 |
+
query_vector = models.NamedSparseVector(
|
36 |
+
name="sparse-text",
|
37 |
+
vector=models.SparseVector(
|
38 |
+
indices=embeddings[0].indices,
|
39 |
+
values=embeddings[0].values,
|
40 |
+
),
|
41 |
+
)
|
42 |
+
return query_vector
|
43 |
+
|
44 |
+
def upload_text_to_qdrant(client: QdrantClient, collection_name: str, encoder: SentenceTransformer, text: str, point_id_dense: int, point_id_sparse: int):
|
45 |
+
try:
|
46 |
+
docs = {"text": text}
|
47 |
+
client.upsert(
|
48 |
+
collection_name=collection_name,
|
49 |
+
points=[
|
50 |
+
models.PointStruct(
|
51 |
+
id=point_id_dense,
|
52 |
+
vector={f"dense-text": encoder.encode(docs["text"]).tolist()},
|
53 |
+
payload=docs,
|
54 |
+
)
|
55 |
+
],
|
56 |
+
)
|
57 |
+
client.upsert(
|
58 |
+
collection_name=collection_name,
|
59 |
+
points=[
|
60 |
+
models.PointStruct(
|
61 |
+
id=point_id_sparse,
|
62 |
+
vector=get_sparse_embedding(docs["text"], sparse_encoder),
|
63 |
+
payload=docs,
|
64 |
+
)
|
65 |
+
],
|
66 |
+
)
|
67 |
+
return True
|
68 |
+
except Exception as e:
|
69 |
+
return False
|
70 |
+
|
71 |
+
def upload_images_to_qdrant(client: QdrantClient, collection_name: str, vectorsfile: str, labelslist: list):
|
72 |
+
try:
|
73 |
+
vectors = np.load(vectorsfile)
|
74 |
+
docs = []
|
75 |
+
for label in labelslist:
|
76 |
+
docs.append({"label": label})
|
77 |
+
client.upload_points(
|
78 |
+
collection_name=collection_name,
|
79 |
+
points=[
|
80 |
+
models.PointStruct(
|
81 |
+
id=idx,
|
82 |
+
vector=vectors[idx].tolist(),
|
83 |
+
payload=doc,
|
84 |
+
)
|
85 |
+
for idx, doc in enumerate(docs)
|
86 |
+
],
|
87 |
+
)
|
88 |
+
return True
|
89 |
+
except Exception as e:
|
90 |
+
return False
|
91 |
+
|
92 |
+
class SemanticCache:
|
93 |
+
def __init__(self, client: QdrantClient, text_encoder: SentenceTransformer, collection_name: str, threshold: float = 0.75):
|
94 |
+
self.client = client
|
95 |
+
self.text_encoder = text_encoder
|
96 |
+
self.collection_name = collection_name
|
97 |
+
self.threshold = threshold
|
98 |
+
def upload_to_cache(self, question: str, answer: str):
|
99 |
+
docs = {"question": question, "answer": answer}
|
100 |
+
point_id = str(uuid.uuid4())
|
101 |
+
self.client.upsert(
|
102 |
+
collection_name=self.collection_name,
|
103 |
+
points=[
|
104 |
+
models.PointStruct(
|
105 |
+
id=point_id,
|
106 |
+
vector=self.text_encoder.encode(docs["question"]).tolist(),
|
107 |
+
payload=docs,
|
108 |
+
)
|
109 |
+
],
|
110 |
+
)
|
111 |
+
def search_cache(self, question: str, limit: int = 5):
|
112 |
+
vector = self.text_encoder.encode(question).tolist()
|
113 |
+
search_result = self.client.search(
|
114 |
+
collection_name=self.collection_name,
|
115 |
+
query_vector=vector,
|
116 |
+
query_filter=None,
|
117 |
+
limit=limit,
|
118 |
+
)
|
119 |
+
payloads = [hit.payload["answer"] for hit in search_result if hit.score > self.threshold]
|
120 |
+
if len(payloads) > 0:
|
121 |
+
return payloads[0]
|
122 |
+
else:
|
123 |
+
return ""
|
124 |
+
|
125 |
+
|
126 |
+
class NeuralSearcher:
|
127 |
+
def __init__(self, text_collection_name: str, image_collection_name: str, client: QdrantClient, text_encoder: SentenceTransformer , image_encoder: AutoModel, image_processor: AutoImageProcessor, sparse_encoder: SparseTextEmbedding):
|
128 |
+
self.text_collection_name = text_collection_name
|
129 |
+
self.image_collection_name = image_collection_name
|
130 |
+
self.text_encoder = text_encoder
|
131 |
+
self.image_encoder = image_encoder
|
132 |
+
self.image_processor = image_processor
|
133 |
+
self.qdrant_client = client
|
134 |
+
self.sparse_encoder = sparse_encoder
|
135 |
+
|
136 |
+
def search_text(self, text: str, limit: int = 5):
|
137 |
+
vector = self.text_encoder.encode(text).tolist()
|
138 |
+
|
139 |
+
search_result_dense = self.qdrant_client.search(
|
140 |
+
collection_name=self.text_collection_name,
|
141 |
+
query_vector=models.NamedVector(name="dense-text", vector=vector),
|
142 |
+
query_filter=None,
|
143 |
+
limit=limit,
|
144 |
+
)
|
145 |
+
|
146 |
+
search_result_sparse = self.qdrant_client.search(
|
147 |
+
collection_name=self.text_collection_name,
|
148 |
+
query_vector=get_query_sparse_embedding(text, self.sparse_encoder),
|
149 |
+
query_filter=None,
|
150 |
+
limit=limit,
|
151 |
+
)
|
152 |
+
payloads = [hit.payload["text"] for hit in search_result_dense]
|
153 |
+
payloads += [hit.payload["text"] for hit in search_result_sparse]
|
154 |
+
return payloads
|
155 |
+
def reranking(self, text: str, search_result: list):
|
156 |
+
results = co.rerank(model="rerank-v3.5", query=text, documents=search_result, top_n = 3)
|
157 |
+
ranked_results = [search_result[results.results[i].index] for i in range(3)]
|
158 |
+
return ranked_results
|
159 |
+
def search_image(self, image: ImageFile, limit: int = 1):
|
160 |
+
img = image
|
161 |
+
inputs = self.image_processor(images=img, return_tensors="pt").to(device)
|
162 |
+
with torch.no_grad():
|
163 |
+
outputs = self.image_encoder(**inputs).last_hidden_state.mean(dim=1).cpu().numpy()
|
164 |
+
search_result = self.qdrant_client.search(
|
165 |
+
collection_name=self.image_collection_name,
|
166 |
+
query_vector=outputs[0].tolist(),
|
167 |
+
query_filter=None,
|
168 |
+
limit=limit,
|
169 |
+
)
|
170 |
+
payloads = [hit.payload["label"] for hit in search_result]
|
171 |
+
return payloads
|
172 |
+
|
173 |
+
qdrant_client.recreate_collection(
|
174 |
+
collection_name="pokemon_texts",
|
175 |
+
vectors_config={"dense-text": models.VectorParams(
|
176 |
+
size=768, # Vector size is defined by used model
|
177 |
+
distance=models.Distance.COSINE,
|
178 |
+
)},
|
179 |
+
sparse_vectors_config={"sparse-text": models.SparseVectorParams(
|
180 |
+
index=models.SparseIndexParams(
|
181 |
+
on_disk=False
|
182 |
+
)
|
183 |
+
)}
|
184 |
+
)
|
185 |
+
textdata = load_dataset("wanghaofan/pokemon-wiki-captions")
|
186 |
+
names = textdata["train"]["name_en"]
|
187 |
+
texts = textdata["train"]["text_en"]
|
188 |
+
|
189 |
+
c = 0
|
190 |
+
|
191 |
+
for j in range(len(texts)):
|
192 |
+
txt = names[j].upper() + "\n\n" + texts[j]
|
193 |
+
l = c+1
|
194 |
+
upload_text_to_qdrant(qdrant_client, "pokemon_texts", encoder, txt, c, l)
|
195 |
+
c = l+1
|
196 |
+
|
197 |
+
qdrant_client.recreate_collection(
|
198 |
+
collection_name="pokemon_images",
|
199 |
+
vectors_config=models.VectorParams(
|
200 |
+
size=1024, # Vector size is defined by used model
|
201 |
+
distance=models.Distance.COSINE,
|
202 |
+
),
|
203 |
+
)
|
204 |
+
upload_images_to_qdrant(qdrant_client, "pokemon_images", "data/vector_pokemon.npy", labels)
|
205 |
+
|
206 |
+
qdrant_client.recreate_collection(
|
207 |
+
collection_name="semantic_cache",
|
208 |
+
vectors_config=models.VectorParams(
|
209 |
+
size=768, # Vector size is defined by used model
|
210 |
+
distance=models.Distance.COSINE,
|
211 |
+
),
|
212 |
+
)
|