import os
import gradio as gr
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI
from dotenv import load_dotenv
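
# Read OPENAI_API_KEY (needed by ChatOpenAI) from a local .env file, if present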
load_dotenv()


# Load the FAISS vector store and assemble the RAG chain
def load_rag_pipeline():
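    """Load the FAISS index and wire up the retrieval-augmented generation chain."""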
    # The embedding model must match the one used when the vector index was built.
    embeddings = HuggingFaceEmbeddings(model_name="pritamdeka/S-PubMedBert-MS-MARCO")
    db = FAISS.load_local("parkinsons_vector_db", embeddings, allow_dangerous_deserialization=True)
    retriever = db.as_retriever(search_kwargs={"k": 5})  # retrieve the 5 most similar chunks
    template = """You are a Parkinson's disease expert. Follow these rules:
1. Use {language_style} language (technical/simple)
2. Base answers ONLY on these sources
3. Cite sources in your answer
4. If unsure, say "I don't know"
Sources:
{context}
Question: {question}
Answer:"""
    prompt = ChatPromptTemplate.from_template(template)
    llm = ChatOpenAI(model="gpt-3.5-turbo")
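
    # Fetch the top-k documents for the question and package the inputs the prompt expects.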
    def build_context(inp):
        question_text = inp.get("question", "")
        language_style_text = inp.get("language_style", "")
        docs = retriever.get_relevant_documents(question_text)
        # Number the retrieved chunks so the model can cite them (rule 3 in the prompt).
        context = "\n\n".join(f"[{i + 1}] {doc.page_content}" for i, doc in enumerate(docs))
        return {
            "question": question_text,
            "language_style": language_style_text,
            "context": context,
        }

    # build_context already pulls "question" and "language_style" out of the incoming
    # dict, so the dict is passed through as-is rather than remapped field by field
    # (wrapping each key in RunnablePassthrough() would replace it with the whole dict).
    rag_chain = (
        RunnablePassthrough()
        | build_context
        | prompt
        | llm
    )
    return rag_chain


# Build the chain once at startup so the index and models load a single time.
rag_chain = load_rag_pipeline()


# Gradio callback: run one question through the RAG chain and return the answer text.
def query_rag(question, language_style):
    response = rag_chain.invoke({"question": question, "language_style": language_style})
    return response.content


# Create the Gradio interface
iface = gr.Interface(
    fn=query_rag,
    inputs=[
        gr.Textbox(label="Question"),
        gr.Radio(["simple", "technical"], label="Language Style", value="simple"),
    ],
    outputs=gr.Textbox(label="Answer"),
    title="Parkinson's RAG Assistant",
    description="Ask questions about Parkinson's disease with simple or technical explanations.",
)
iface.launch()