File size: 2,289 Bytes
940f08c
b71501e
940f08c
 
 
 
 
 
b71501e
940f08c
b71501e
940f08c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import os
import gradio as gr
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough, RunnableMap
from langchain_openai import ChatOpenAI
from dotenv import load_dotenv

# Populate environment variables from a .env file via python-dotenv.
# NOTE(review): presumably this is where OPENAI_API_KEY for ChatOpenAI comes
# from — confirm against the deployment setup.
load_dotenv()

def format_docs(docs):
    """Join retrieved documents into a numbered source list for the prompt.

    Each document's ``page_content`` is prefixed with a 1-based ``[i]`` label
    and entries are separated by blank lines, so the model can tell sources
    apart and cite them by number.

    Args:
        docs: Iterable of objects exposing a ``page_content`` string
            (LangChain ``Document`` instances in practice).

    Returns:
        The formatted context string; ``""`` when ``docs`` is empty.
    """
    return "\n\n".join(
        f"[{i}] {doc.page_content}" for i, doc in enumerate(docs, start=1)
    )

# Load FAISS and RAG
def load_rag_pipeline(
    db_path="parkinsons_vector_db",
    embedding_model="pritamdeka/S-PubMedBert-MS-MARCO",
    llm_model="gpt-3.5-turbo",
    k=5,
):
    """Build the retrieval-augmented QA chain over the local FAISS index.

    The returned runnable expects a dict ``{"question": str,
    "language_style": str}``. It retrieves the top-``k`` chunks for the
    question, fills the prompt template and calls the chat model.

    Args:
        db_path: Directory of the saved FAISS index.
        embedding_model: Hugging Face model used to embed queries.
            NOTE(review): presumably must match the model the index was
            built with — confirm.
        llm_model: OpenAI chat model name passed to ``ChatOpenAI``.
        k: Number of chunks to retrieve per question.

    Returns:
        A LangChain runnable whose ``invoke`` returns the chat model's
        message (answer text in ``.content``).
    """
    # Function-scope import keeps this fix self-contained.
    from langchain_core.runnables import RunnableLambda

    embeddings = HuggingFaceEmbeddings(model_name=embedding_model)
    # SECURITY: load_local unpickles the stored docstore. This is acceptable
    # only because the index is a trusted local artifact; never point
    # db_path at files from an untrusted source.
    db = FAISS.load_local(db_path, embeddings, allow_dangerous_deserialization=True)
    retriever = db.as_retriever(search_kwargs={"k": k})

    template = """You are a Parkinson's disease expert. Follow these rules:
1. Use {language_style} language (technical/simple)
2. Base answers ONLY on these sources
3. Cite sources in your answer
4. If unsure, say "I don't know"

Sources:
{context}

Question: {question}
Answer:"""

    prompt = ChatPromptTemplate.from_template(template)
    llm = ChatOpenAI(model=llm_model)

    def build_context(inp):
        """Retrieve sources for ``inp["question"]`` and assemble prompt variables."""
        question_text = inp.get("question", "")
        language_style_text = inp.get("language_style", "")

        # retriever.invoke supersedes the deprecated get_relevant_documents.
        docs = retriever.invoke(question_text)

        return {
            "question": question_text,
            "language_style": language_style_text,
            "context": format_docs(docs),
        }

    # The caller's dict already carries "question" and "language_style", so it
    # goes straight into build_context. The previous RunnableMap of two
    # RunnablePassthroughs copied the WHOLE input dict into each key, which
    # handed the retriever a dict instead of the question string.
    return RunnableLambda(build_context) | prompt | llm

# Build the RAG chain once at import time; query_rag() reuses this single
# instance for every request rather than reloading the index and models.
rag_chain = load_rag_pipeline()

# Gradio Function
def query_rag(question, language_style):
    """Answer a Parkinson's question through the RAG chain for the Gradio UI.

    Args:
        question: Free-text question typed by the user.
        language_style: Answer register chosen in the UI ("simple" or
            "technical"); falls back to "simple" when empty.

    Returns:
        The model's answer text. If ``question`` is empty, ``None`` or
        whitespace-only, it instead returns a short message asking for a
        question, and no retrieval or LLM call is made.
    """
    # Skip the embedding lookup and the paid LLM call when there is nothing to ask.
    if not question or not question.strip():
        return "Please enter a question about Parkinson's disease."
    response = rag_chain.invoke(
        {"question": question.strip(), "language_style": language_style or "simple"}
    )
    return response.content

# Create Gradio Interface: one question textbox plus a simple/technical
# selector, both forwarded to query_rag; its return value fills the answer box.
iface = gr.Interface(
    fn=query_rag,
    inputs=[
        gr.Textbox(label="Question"),
        gr.Radio(["simple", "technical"], label="Language Style", value="simple"),
    ],
    outputs=gr.Textbox(label="Answer"),
    title="Parkinson's RAG Assistant",
    description="Ask questions about Parkinson's disease with simple or technical explanations.",
)

# Start the (blocking) web server only when run as a script, so importing
# this module, e.g. from tests or tooling, does not launch the app.
if __name__ == "__main__":
    iface.launch()