|
import lancedb |
|
import os |
|
import gradio as gr |
|
from sentence_transformers import SentenceTransformer, CrossEncoder |
|
|
|
|
|
def _require_env(name):
    """Return environment variable ``name``; raise a clear error if it is unset or empty."""
    value = os.getenv(name)
    if not value:
        raise RuntimeError(f"environment variable {name!r} must be set")
    return value


# LanceDB database location; defaults to the local ".lancedb" directory and
# can be overridden with the LANCEDB_URI environment variable.
db = lancedb.connect(os.getenv("LANCEDB_URI", ".lancedb"))

# Table searched by `retrieve`; TABLE_NAME is required.
TABLE = db.open_table(_require_env("TABLE_NAME"))
# Name of the column holding the embedding vectors in TABLE.
VECTOR_COLUMN = os.getenv("VECTOR_COLUMN", "vector")
# Name of the column whose values are returned as passage text.
TEXT_COLUMN = os.getenv("TEXT_COLUMN", "text")
# Inference batch size, read from the BATCH_SIZE environment variable.
BATCH_SIZE = int(os.getenv("BATCH_SIZE", "32"))
# Cross-encoder model name used to rerank vector-search candidates.
RERANKER = os.getenv("RERANKER", "cross-encoder/ms-marco-MiniLM-L-6-v2")

# Embedding model for queries; EMB_MODEL is required.
# NOTE(review): presumably this must be the same model that produced the
# vectors stored in TABLE — confirm.
retriever = SentenceTransformer(_require_env("EMB_MODEL"))
reranker = CrossEncoder(RERANKER)
|
|
|
|
|
def retrieve(query, k, rerank_factor=3):
    """Return the ``k`` passages most relevant to ``query``.

    Two-stage search: the ``retriever`` model embeds the query and a vector
    search over ``TABLE`` fetches ``k * rerank_factor`` candidate rows; the
    ``reranker`` cross-encoder then scores every (query, passage) pair and the
    ``k`` highest-scoring passages are kept.

    Args:
        query: The search string.
        k: Number of passages to return. Coerced to ``int`` because numeric
            UI inputs may arrive as floats. Values <= 0 yield ``[]``.
        rerank_factor: How many times ``k`` candidates to over-fetch from the
            vector search before reranking.

    Returns:
        A list of at most ``k`` passage texts (the ``TEXT_COLUMN`` values),
        ordered by descending reranker score. Empty if the search finds nothing.

    Raises:
        gr.Error: If search or reranking fails; the original exception is
            chained as ``__cause__``.
    """
    query_vec = retriever.encode(query)

    try:
        k = int(k)
        if k <= 0:
            return []

        candidates = (
            TABLE.search(query_vec, vector_column_name=VECTOR_COLUMN)
            .limit(k * rerank_factor)
            .to_list()
        )
        passages = [row[TEXT_COLUMN] for row in candidates]

        # CrossEncoder.predict inspects its first element, so an empty
        # candidate list is short-circuited instead of being passed through.
        if not passages:
            return []

        scores = reranker.predict(
            [(query, passage) for passage in passages],
            batch_size=BATCH_SIZE,
        )

        # Sort on the score alone: passages are never compared on ties, and
        # because the sort is stable, tied passages keep vector-search order.
        ranked = sorted(
            zip(scores, passages), key=lambda pair: pair[0], reverse=True
        )
        return [passage for _, passage in ranked[:k]]

    except Exception as e:
        raise gr.Error(str(e)) from e
|
|