msh2481
committed on
Commit
•
ef93aa3
1
Parent(s):
82b3a71
Step 3
Browse files- backend/semantic_search.py +15 -7
backend/semantic_search.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
-
import lancedb
|
2 |
import os
|
3 |
-
import gradio as gr
|
4 |
-
from sentence_transformers import SentenceTransformer
|
5 |
|
6 |
|
7 |
db = lancedb.connect(".lancedb")
|
@@ -10,17 +10,25 @@ TABLE = db.open_table(os.getenv("TABLE_NAME"))
|
|
10 |
VECTOR_COLUMN = os.getenv("VECTOR_COLUMN", "vector")
|
11 |
TEXT_COLUMN = os.getenv("TEXT_COLUMN", "text")
|
12 |
BATCH_SIZE = int(os.getenv("BATCH_SIZE", 32))
|
|
|
13 |
|
14 |
retriever = SentenceTransformer(os.getenv("EMB_MODEL"))
|
|
|
15 |
|
16 |
|
17 |
-
def retrieve(query, k):
|
18 |
query_vec = retriever.encode(query)
|
19 |
try:
|
20 |
-
documents =
|
|
|
|
|
|
|
|
|
21 |
documents = [doc[TEXT_COLUMN] for doc in documents]
|
22 |
-
|
23 |
-
|
|
|
|
|
24 |
|
25 |
except Exception as e:
|
26 |
raise gr.Error(str(e))
|
|
|
1 |
+
import lancedb # type: ignore
|
2 |
import os
|
3 |
+
import gradio as gr # type: ignore
|
4 |
+
from sentence_transformers import SentenceTransformer, CrossEncoder # type: ignore
|
5 |
|
6 |
|
7 |
db = lancedb.connect(".lancedb")
|
|
|
10 |
VECTOR_COLUMN = os.getenv("VECTOR_COLUMN", "vector")
|
11 |
TEXT_COLUMN = os.getenv("TEXT_COLUMN", "text")
|
12 |
BATCH_SIZE = int(os.getenv("BATCH_SIZE", 32))
|
13 |
+
RERANKER = os.getenv("RERANKER", "cross-encoder/ms-marco-MiniLM-L-6-v2")
|
14 |
|
15 |
retriever = SentenceTransformer(os.getenv("EMB_MODEL"))
|
16 |
+
reranker = CrossEncoder(RERANKER)
|
17 |
|
18 |
|
19 |
+
def retrieve(query, k, rerank_factor=3):
    """Return the ``k`` documents most relevant to ``query``.

    The query is embedded with the module-level ``retriever``, the
    ``k * rerank_factor`` nearest rows are fetched from ``TABLE``, and those
    candidates are re-scored against the query with the module-level
    cross-encoder ``reranker``.

    Args:
        query: Query text to search for.
        k: Number of documents to return.
        rerank_factor: Multiplier applied to ``k`` to size the candidate
            pool handed to the reranker.

    Returns:
        Up to ``k`` values of ``TEXT_COLUMN``, ordered from highest to
        lowest reranker score. An empty list when the search finds nothing.

    Raises:
        gr.Error: Wraps any exception raised while searching or reranking.
    """
    query_vec = retriever.encode(query)
    try:
        rows = (
            TABLE.search(query_vec, vector_column_name=VECTOR_COLUMN)
            .limit(k * rerank_factor)
            .to_list()
        )
        documents = [row[TEXT_COLUMN] for row in rows]
        if not documents:
            # Nothing to rerank; avoid calling the cross-encoder on an
            # empty batch.
            return []

        scores = reranker.predict([(query, doc) for doc in documents])
        # Sort on the score alone so ties keep the vector-search order
        # instead of falling back to comparing document text.
        ranked = sorted(
            zip(scores, documents),
            key=lambda pair: pair[0],
            reverse=True,
        )
        return [doc for _, doc in ranked[:k]]

    except Exception as e:
        # Chain the cause so the original traceback stays available.
        raise gr.Error(str(e)) from e
|