import gradio as gr
from dotenv import load_dotenv
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda
from langchain_openai import ChatOpenAI

# Load environment variables (e.g. OPENAI_API_KEY) from a local .env file
load_dotenv()


# Load the FAISS index and assemble the RAG chain
def load_rag_pipeline():
    embeddings = HuggingFaceEmbeddings(model_name="pritamdeka/S-PubMedBert-MS-MARCO")
    # allow_dangerous_deserialization is required to load a locally saved,
    # pickle-backed FAISS index; only enable it for indexes you created yourself.
    db = FAISS.load_local(
        "parkinsons_vector_db",
        embeddings,
        allow_dangerous_deserialization=True,
    )
    retriever = db.as_retriever(search_kwargs={"k": 5})

    template = """You are a Parkinson's disease expert. Follow these rules:
1. Use {language_style} language (technical/simple)
2. Base answers ONLY on these sources
3. Cite sources in your answer
4. If unsure, say "I don't know"

Sources:
{context}

Question: {question}
Answer:"""
    prompt = ChatPromptTemplate.from_template(template)
    llm = ChatOpenAI(model="gpt-3.5-turbo")

    # Retrieve the top-k documents for the question and assemble the prompt inputs
    def build_context(inp):
        question_text = inp.get("question", "")
        language_style_text = inp.get("language_style", "")
        docs = retriever.invoke(question_text)
        context = "\n\n".join(doc.page_content for doc in docs)
        return {
            "question": question_text,
            "language_style": language_style_text,
            "context": context,
        }

    # build_context already receives the {"question": ..., "language_style": ...}
    # dict passed to invoke(), so no RunnableMap step is needed in front of it.
    # (Mapping both keys to RunnablePassthrough() would hand each key the whole
    # input dict rather than its value, breaking retrieval.)
    rag_chain = RunnableLambda(build_context) | prompt | llm
    return rag_chain


rag_chain = load_rag_pipeline()


# Gradio callback: run the chain and return the model's answer text
def query_rag(question, language_style):
    response = rag_chain.invoke({"question": question, "language_style": language_style})
    return response.content


# Create the Gradio interface
iface = gr.Interface(
    fn=query_rag,
    inputs=[
        gr.Textbox(label="Question"),
        gr.Radio(["simple", "technical"], label="Language Style", value="simple"),
    ],
    outputs=gr.Textbox(label="Answer"),
    title="Parkinson's RAG Assistant",
    description="Ask questions about Parkinson's disease with simple or technical explanations.",
)

iface.launch()
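

# --- Optional: building the FAISS index this app expects ---
# The script assumes a "parkinsons_vector_db" index already exists on disk.
# Below is a minimal sketch of how such an index could be created with the
# same embedding model; it is never called by the app. The source file path
# and the chunking parameters are hypothetical placeholders, not part of the
# original script.
def build_vector_db(source_path="parkinsons_docs.txt"):
    from langchain_text_splitters import RecursiveCharacterTextSplitter

    with open(source_path, encoding="utf-8") as f:
        raw_text = f.read()

    # Split the corpus into overlapping chunks so retrieval returns
    # passages small enough to fit in the prompt context
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    chunks = splitter.split_text(raw_text)

    embeddings = HuggingFaceEmbeddings(model_name="pritamdeka/S-PubMedBert-MS-MARCO")
    db = FAISS.from_texts(chunks, embeddings)
    db.save_local("parkinsons_vector_db")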