"""Vector retrieval with cross-encoder reranking over a LanceDB table.

All configuration is read from environment variables at import time:
TABLE_NAME, VECTOR_COLUMN, TEXT_COLUMN, BATCH_SIZE, RERANKER and EMB_MODEL.
"""

import os

import gradio as gr  # type: ignore
import lancedb  # type: ignore
from sentence_transformers import SentenceTransformer, CrossEncoder  # type: ignore

db = lancedb.connect(".lancedb")
TABLE = db.open_table(os.getenv("TABLE_NAME"))
VECTOR_COLUMN = os.getenv("VECTOR_COLUMN", "vector")
TEXT_COLUMN = os.getenv("TEXT_COLUMN", "text")
BATCH_SIZE = int(os.getenv("BATCH_SIZE", 32))
RERANKER = os.getenv("RERANKER", "cross-encoder/ms-marco-MiniLM-L-6-v2")

# Bi-encoder used to embed the query for the vector search.
retriever = SentenceTransformer(os.getenv("EMB_MODEL"))
# Cross-encoder used to rescore (query, document) pairs.
reranker = CrossEncoder(RERANKER)


def retrieve(query, k, rerank_factor=3):
    """Return the ``k`` most relevant documents for ``query``.

    The query is embedded with ``retriever``. Then ``k * rerank_factor``
    nearest rows are fetched from ``TABLE``, rescored with the cross-encoder
    ``reranker``, and the ``k`` highest-scoring texts are returned.

    Args:
        query: The search query text.
        k: Number of documents to return. Coerced to ``int`` because UI
            widgets may deliver it as a float.
        rerank_factor: Multiplier on ``k`` for how many candidates to fetch
            before reranking.

    Returns:
        A list of at most ``k`` values from ``TEXT_COLUMN``, best first.
        Returns an empty list when ``k <= 0`` or the search finds nothing.

    Raises:
        gr.Error: Wraps any failure during encoding, search or reranking so
            Gradio shows it to the user.
    """
    k = int(k)
    if k <= 0:
        return []

    try:
        query_vec = retriever.encode(query)
        rows = (
            TABLE.search(query_vec, vector_column_name=VECTOR_COLUMN)
            .limit(k * int(rerank_factor))
            .to_list()
        )
        documents = [row[TEXT_COLUMN] for row in rows]
        if not documents:
            # CrossEncoder.predict indexes its first element, so an empty
            # candidate list must not reach it.
            return []

        scores = reranker.predict(
            [(query, doc) for doc in documents], batch_size=BATCH_SIZE
        )
        # Sort on the score alone. Comparing the document values on ties could
        # raise for non-comparable values; the stable sort keeps retrieval
        # order among equal scores instead.
        ranked = sorted(zip(scores, documents), key=lambda pair: pair[0], reverse=True)
        return [doc for _, doc in ranked[:k]]
    except Exception as e:
        raise gr.Error(str(e)) from e