msh2481 committed on
Commit
ef93aa3
1 Parent(s): 82b3a71
Files changed (1) hide show
  1. backend/semantic_search.py +15 -7
backend/semantic_search.py CHANGED
@@ -1,7 +1,7 @@
1
- import lancedb
2
  import os
3
- import gradio as gr
4
- from sentence_transformers import SentenceTransformer
5
 
6
 
7
  db = lancedb.connect(".lancedb")
@@ -10,17 +10,25 @@ TABLE = db.open_table(os.getenv("TABLE_NAME"))
10
  VECTOR_COLUMN = os.getenv("VECTOR_COLUMN", "vector")
11
  TEXT_COLUMN = os.getenv("TEXT_COLUMN", "text")
12
  BATCH_SIZE = int(os.getenv("BATCH_SIZE", 32))
 
13
 
14
  retriever = SentenceTransformer(os.getenv("EMB_MODEL"))
 
15
 
16
 
17
def retrieve(query, k):
    """Return the text of the ``k`` nearest stored documents for ``query``.

    The query is embedded with the module-level ``retriever`` and the
    embedding is used for a vector search over ``TABLE`` on the
    ``VECTOR_COLUMN`` column. Only the ``TEXT_COLUMN`` value of each hit
    is returned.

    Args:
        query: Text to search for.
        k: Maximum number of hits to request from the table.

    Returns:
        A list with the ``TEXT_COLUMN`` value of every hit, in the order
        the search returned them.

    Raises:
        gr.Error: If the search or the text extraction fails; the message
            is the text of the original exception.
    """
    embedding = retriever.encode(query)
    try:
        search = TABLE.search(embedding, vector_column_name=VECTOR_COLUMN)
        hits = search.limit(k).to_list()

        texts = []
        for hit in hits:
            texts.append(hit[TEXT_COLUMN])
        return texts

    except Exception as err:
        # Re-raised as a Gradio error carrying the original message.
        raise gr.Error(str(err))
 
1
+ import lancedb # type: ignore
2
  import os
3
+ import gradio as gr # type: ignore
4
+ from sentence_transformers import SentenceTransformer, CrossEncoder # type: ignore
5
 
6
 
7
  db = lancedb.connect(".lancedb")
 
10
  VECTOR_COLUMN = os.getenv("VECTOR_COLUMN", "vector")
11
  TEXT_COLUMN = os.getenv("TEXT_COLUMN", "text")
12
  BATCH_SIZE = int(os.getenv("BATCH_SIZE", 32))
13
+ RERANKER = os.getenv("RERANKER", "cross-encoder/ms-marco-MiniLM-L-6-v2")
14
 
15
  retriever = SentenceTransformer(os.getenv("EMB_MODEL"))
16
+ reranker = CrossEncoder(RERANKER)
17
 
18
 
19
def retrieve(query, k, rerank_factor=3):
    """Return the ``k`` most relevant document texts for ``query``.

    Candidates are fetched with a vector search over ``TABLE`` (using the
    embedding produced by ``retriever``) and then re-ordered by the
    cross-encoder ``reranker``, which scores each ``(query, document)``
    pair.

    Args:
        query: Text to search for.
        k: Number of documents to return.
        rerank_factor: Multiplier applied to ``k`` to decide how many
            vector-search candidates are fetched and passed to the
            reranker.

    Returns:
        Up to ``k`` document texts, highest reranker score first. Documents
        with equal scores keep their vector-search order. An empty list is
        returned when the search yields no candidates.

    Raises:
        gr.Error: If searching or reranking fails; the message is the text
            of the original exception, which is kept as ``__cause__``.
    """
    query_vec = retriever.encode(query)
    try:
        candidates = (
            TABLE.search(query_vec, vector_column_name=VECTOR_COLUMN)
            .limit(k * rerank_factor)
            .to_list()
        )
        documents = [doc[TEXT_COLUMN] for doc in candidates]
        if not documents:
            # Nothing to rerank; the reranker may not accept an empty
            # batch, so avoid calling it at all.
            return []

        scores = reranker.predict([(query, doc) for doc in documents])
        # Sort on the score only. Sorting the raw (score, document) tuples
        # would fall back to comparing the documents themselves on ties,
        # which reorders tied results by text and can raise TypeError for
        # non-comparable column values. sorted() is stable, so ties keep
        # their retrieval order.
        ranked = sorted(
            zip(scores, documents),
            key=lambda pair: pair[0],
            reverse=True,
        )
        return [doc for _, doc in ranked[:k]]

    except Exception as e:
        # Chain the cause so the original traceback is not lost in logs.
        raise gr.Error(str(e)) from e