as-cle-bert committed on
Commit
626e10d
1 Parent(s): 2e814be

Create QdrantRag.py

Files changed (1)
  1. QdrantRag.py +212 -0
QdrantRag.py ADDED
@@ -0,0 +1,212 @@
+ from qdrant_client import QdrantClient, models
+ from sentence_transformers import SentenceTransformer
+ from transformers import AutoModel, AutoImageProcessor
+ import torch
+ import os
+ from datasets import load_dataset
+ from dotenv import load_dotenv
+ import numpy as np
+ import uuid
+ from PIL import Image
+ from fastembed import SparseTextEmbedding
+ import cohere
+
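+ # Load credentials (qdrant_url, qdrant_api_key, cohere_api_key) from a local .env file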
+ load_dotenv()
+
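+ # Encoders and clients: LaBSE for 768-d dense text vectors, DINOv2-large for
+ # 1024-d image vectors, SPLADE for sparse text vectors, plus the Qdrant and
+ # Cohere clients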
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ encoder = SentenceTransformer("sentence-transformers/LaBSE").to(device)
+ processor = AutoImageProcessor.from_pretrained("facebook/dinov2-large")
+ image_encoder = AutoModel.from_pretrained("facebook/dinov2-large").to(device)
+ qdrant_client = QdrantClient(url=os.getenv("qdrant_url"), api_key=os.getenv("qdrant_api_key"))
+ sparse_encoder = SparseTextEmbedding(model_name="prithivida/Splade_PP_en_v1")
+ co = cohere.ClientV2(os.getenv("cohere_api_key"))
+
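+ # Image dataset; its "text" column is used as the payload label for each image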
+ dataset = load_dataset("Karbo31881/Pokemon_images")
+ ds = dataset["train"]
+ labels = ds["text"]
+
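+ # Encode a text with SPLADE and key it to the "sparse-text" named vector for upserts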
+ def get_sparse_embedding(text: str, model: SparseTextEmbedding):
+     embeddings = list(model.embed(text))
+     vector = {"sparse-text": models.SparseVector(indices=embeddings[0].indices, values=embeddings[0].values)}
+     return vector
+
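+ # Same SPLADE encoding, wrapped as a NamedSparseVector for use as a search query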
+ def get_query_sparse_embedding(text: str, model: SparseTextEmbedding):
+     embeddings = list(model.embed(text))
+     query_vector = models.NamedSparseVector(
+         name="sparse-text",
+         vector=models.SparseVector(
+             indices=embeddings[0].indices,
+             values=embeddings[0].values,
+         ),
+     )
+     return query_vector
+
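+ # Upsert a text as two points: a dense LaBSE vector under point_id_dense and a
+ # SPLADE sparse vector under point_id_sparse; returns False if either upsert fails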
+ def upload_text_to_qdrant(client: QdrantClient, collection_name: str, encoder: SentenceTransformer, text: str, point_id_dense: int, point_id_sparse: int):
+     try:
+         docs = {"text": text}
+         client.upsert(
+             collection_name=collection_name,
+             points=[
+                 models.PointStruct(
+                     id=point_id_dense,
+                     vector={"dense-text": encoder.encode(docs["text"]).tolist()},
+                     payload=docs,
+                 )
+             ],
+         )
+         client.upsert(
+             collection_name=collection_name,
+             points=[
+                 models.PointStruct(
+                     id=point_id_sparse,
+                     vector=get_sparse_embedding(docs["text"], sparse_encoder),
+                     payload=docs,
+                 )
+             ],
+         )
+         return True
+     except Exception:
+         return False
+
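+ # Bulk-upload precomputed image embeddings from a .npy file, one point per label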
+ def upload_images_to_qdrant(client: QdrantClient, collection_name: str, vectorsfile: str, labelslist: list):
+     try:
+         vectors = np.load(vectorsfile)
+         docs = [{"label": label} for label in labelslist]
+         client.upload_points(
+             collection_name=collection_name,
+             points=[
+                 models.PointStruct(
+                     id=idx,
+                     vector=vectors[idx].tolist(),
+                     payload=doc,
+                 )
+                 for idx, doc in enumerate(docs)
+             ],
+         )
+         return True
+     except Exception:
+         return False
+
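+ # Semantic cache over Qdrant: stores question/answer pairs keyed by the dense
+ # question embedding and returns a cached answer when a new question scores
+ # above the similarity threshold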
+ class SemanticCache:
+     def __init__(self, client: QdrantClient, text_encoder: SentenceTransformer, collection_name: str, threshold: float = 0.75):
+         self.client = client
+         self.text_encoder = text_encoder
+         self.collection_name = collection_name
+         self.threshold = threshold
+
+     def upload_to_cache(self, question: str, answer: str):
+         docs = {"question": question, "answer": answer}
+         point_id = str(uuid.uuid4())
+         self.client.upsert(
+             collection_name=self.collection_name,
+             points=[
+                 models.PointStruct(
+                     id=point_id,
+                     vector=self.text_encoder.encode(docs["question"]).tolist(),
+                     payload=docs,
+                 )
+             ],
+         )
+
+     def search_cache(self, question: str, limit: int = 5):
+         vector = self.text_encoder.encode(question).tolist()
+         search_result = self.client.search(
+             collection_name=self.collection_name,
+             query_vector=vector,
+             query_filter=None,
+             limit=limit,
+         )
+         payloads = [hit.payload["answer"] for hit in search_result if hit.score > self.threshold]
+         return payloads[0] if payloads else ""
+
+
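+ # Hybrid neural searcher: dense + sparse text retrieval from Qdrant, Cohere
+ # reranking, and DINOv2-based image similarity search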
+ class NeuralSearcher:
+     def __init__(self, text_collection_name: str, image_collection_name: str, client: QdrantClient, text_encoder: SentenceTransformer, image_encoder: AutoModel, image_processor: AutoImageProcessor, sparse_encoder: SparseTextEmbedding):
+         self.text_collection_name = text_collection_name
+         self.image_collection_name = image_collection_name
+         self.text_encoder = text_encoder
+         self.image_encoder = image_encoder
+         self.image_processor = image_processor
+         self.qdrant_client = client
+         self.sparse_encoder = sparse_encoder
+
+     def search_text(self, text: str, limit: int = 5):
+         vector = self.text_encoder.encode(text).tolist()
+
+         search_result_dense = self.qdrant_client.search(
+             collection_name=self.text_collection_name,
+             query_vector=models.NamedVector(name="dense-text", vector=vector),
+             query_filter=None,
+             limit=limit,
+         )
+
+         search_result_sparse = self.qdrant_client.search(
+             collection_name=self.text_collection_name,
+             query_vector=get_query_sparse_embedding(text, self.sparse_encoder),
+             query_filter=None,
+             limit=limit,
+         )
+         payloads = [hit.payload["text"] for hit in search_result_dense]
+         payloads += [hit.payload["text"] for hit in search_result_sparse]
+         return payloads
+
+     def reranking(self, text: str, search_result: list):
+         results = co.rerank(model="rerank-v3.5", query=text, documents=search_result, top_n=3)
+         ranked_results = [search_result[r.index] for r in results.results]
+         return ranked_results
+
+     def search_image(self, image: Image.Image, limit: int = 1):
+         inputs = self.image_processor(images=image, return_tensors="pt").to(device)
+         with torch.no_grad():
+             outputs = self.image_encoder(**inputs).last_hidden_state.mean(dim=1).cpu().numpy()
+         search_result = self.qdrant_client.search(
+             collection_name=self.image_collection_name,
+             query_vector=outputs[0].tolist(),
+             query_filter=None,
+             limit=limit,
+         )
+         payloads = [hit.payload["label"] for hit in search_result]
+         return payloads
+
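+ # Hybrid text collection: a 768-d dense vector ("dense-text", LaBSE) alongside
+ # a SPLADE sparse vector ("sparse-text")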
+ qdrant_client.recreate_collection(
+     collection_name="pokemon_texts",
+     vectors_config={"dense-text": models.VectorParams(
+         size=768,  # Vector size is defined by the model used
+         distance=models.Distance.COSINE,
+     )},
+     sparse_vectors_config={"sparse-text": models.SparseVectorParams(
+         index=models.SparseIndexParams(
+             on_disk=False
+         )
+     )},
+ )
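+ # Wiki captions: English Pokemon names and descriptions to index as text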
+ textdata = load_dataset("wanghaofan/pokemon-wiki-captions")
+ names = textdata["train"]["name_en"]
+ texts = textdata["train"]["text_en"]
+
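+ # Each caption is uploaded twice: dense points take even IDs, sparse points odd IDs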
+ for j in range(len(texts)):
+     txt = names[j].upper() + "\n\n" + texts[j]
+     upload_text_to_qdrant(qdrant_client, "pokemon_texts", encoder, txt, 2 * j, 2 * j + 1)
+
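+ # Image collection: 1024-d DINOv2 vectors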
+ qdrant_client.recreate_collection(
+     collection_name="pokemon_images",
+     vectors_config=models.VectorParams(
+         size=1024,  # Vector size is defined by the model used
+         distance=models.Distance.COSINE,
+     ),
+ )
+ upload_images_to_qdrant(qdrant_client, "pokemon_images", "data/vector_pokemon.npy", labels)
+
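+ # Semantic cache collection: 768-d LaBSE question vectors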
+ qdrant_client.recreate_collection(
+     collection_name="semantic_cache",
+     vectors_config=models.VectorParams(
+         size=768,  # Vector size is defined by the model used
+         distance=models.Distance.COSINE,
+     ),
+ )
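+
+ # A minimal usage sketch (illustrative, not executed by this script): check the
+ # semantic cache first, fall back to hybrid search plus Cohere reranking on a
+ # miss, then cache the result. The sample question is a made-up example.
+ # searcher = NeuralSearcher("pokemon_texts", "pokemon_images", qdrant_client, encoder, image_encoder, processor, sparse_encoder)
+ # cache = SemanticCache(qdrant_client, encoder, "semantic_cache")
+ # question = "Which Pokemon evolves from Charmander?"
+ # answer = cache.search_cache(question)
+ # if not answer:
+ #     candidates = searcher.search_text(question)
+ #     answer = "\n".join(searcher.reranking(question, candidates))
+ #     cache.upload_to_cache(question, answer)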