import gradio as gr
from sentence_transformers import CrossEncoder
import pandas as pd

# Model instantiated once at import time and shared by every `rerank` call.
# NOTE(review): all-MiniLM-L12-v2 looks like an embedding (bi-encoder)
# checkpoint rather than one trained for pairwise scoring - confirm it is the
# intended model to load through CrossEncoder.
reranker = CrossEncoder("sentence-transformers/all-MiniLM-L12-v2")


def rerank(query: str, documents: pd.DataFrame) -> pd.DataFrame:
    """Score each unique document against *query* and order by that score.

    Args:
        query: The text every document is paired with for scoring.
        documents: Table with a ``"text"`` column holding the documents.
            The caller's frame is never modified.

    Returns:
        A new frame with duplicate, missing and blank ``"text"`` rows removed
        and a ``"rank"`` column holding the model score, sorted so the highest
        score comes first. If nothing is left to score, an empty frame that
        still carries the ``"rank"`` column is returned.

    Raises:
        KeyError: If *documents* has no ``"text"`` column.
    """
    unique_docs = documents.drop_duplicates("text")

    # Missing or whitespace-only cells (e.g. unfilled rows of an input grid)
    # carry nothing to score, so they are filtered out before the model runs.
    text_column = unique_docs["text"]
    has_content = text_column.notna() & text_column.astype(str).str.strip().ne("")
    scored = unique_docs[has_content].copy()

    # Guard: never hand the model an empty batch; keep the output schema
    # stable so the caller still sees a "rank" column.
    if scored.empty:
        scored["rank"] = pd.Series(dtype="float64")
        return scored

    # str() keeps non-string cells (numbers, etc.) acceptable to the tokenizer.
    pairs = [[query, str(text)] for text in scored["text"]]
    scored["rank"] = reranker.predict(pairs)
    return scored.sort_values(by="rank", ascending=False)


# Demo UI: a query box and an editable one-column document table feed
# `rerank`; its result is rendered in a second table below the button.
with gr.Blocks() as demo:
    gr.Markdown("""# RAG - Augment 
                
                Applies reranking to the retrieved documents using [sentence-transformers/all-MiniLM-L12-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2).
                
                Part of [AI blueprint](https://github.com/huggingface/ai-blueprint) - a blueprint for AI development, focusing on practical examples of RAG, information extraction, analysis and fine-tuning in the age of LLMs.""")

    # Inputs, created in the order they should appear on the page.
    query_input = gr.Textbox(
        lines=3,
        label="Query",
        placeholder="Enter your question here...",
    )
    documents_input = gr.Dataframe(
        headers=["text"],
        label="Documents",
        interactive=True,
        wrap=True,
    )

    submit_btn = gr.Button("Submit")

    # Output table: the input text plus the "rank" column added by `rerank`.
    documents_output = gr.Dataframe(
        headers=["text", "rank"],
        label="Documents",
        wrap=True,
    )

    # Clicking the button passes (query, documents) to `rerank` and shows the
    # returned frame in the output table.
    submit_btn.click(
        fn=rerank,
        outputs=[documents_output],
        inputs=[query_input, documents_input],
    )

# Start the Gradio app.
demo.launch()