ramiibrahim2002 committed
Commit 940f08c · 1 Parent(s): f0181fb
Files changed (1)
  1. app.py +74 -4
app.py CHANGED
@@ -1,7 +1,77 @@
  import gradio as gr

- def greet(name):
-     return "Hello " + name + "!!"

- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
+ import os
  import gradio as gr
+ from langchain_community.vectorstores import FAISS
+ from langchain_community.embeddings import HuggingFaceEmbeddings
+ from langchain_core.prompts import ChatPromptTemplate
+ from langchain_core.runnables import RunnablePassthrough, RunnableMap
+ from langchain_openai import ChatOpenAI
+ from dotenv import load_dotenv
+
+ load_dotenv()
+
+ # Load FAISS and RAG
+ def load_rag_pipeline():
+     embeddings = HuggingFaceEmbeddings(model_name="pritamdeka/S-PubMedBert-MS-MARCO")
+     db = FAISS.load_local("parkinsons_vector_db", embeddings, allow_dangerous_deserialization=True)
+     retriever = db.as_retriever(search_kwargs={"k": 5})
+
+     template = """You are a Parkinson's disease expert. Follow these rules:
+     1. Use {language_style} language (technical/simple)
+     2. Base answers ONLY on these sources
+     3. Cite sources in your answer
+     4. If unsure, say "I don't know"
+
+     Sources:
+     {context}
+
+     Question: {question}
+     Answer:"""
+
+     prompt = ChatPromptTemplate.from_template(template)
+     llm = ChatOpenAI(model="gpt-3.5-turbo")
+
+     def build_context(inp):
+         question_text = inp.get("question", "")
+         language_style_text = inp.get("language_style", "")
+
+         docs = retriever.get_relevant_documents(question_text)
+         context = " ".join(doc.page_content for doc in docs)
+
+         return {
+             "question": question_text,
+             "language_style": language_style_text,
+             "context": context
+         }
+
+     rag_chain = (
+         RunnableMap({
+             "question": lambda inp: inp["question"],
+             "language_style": lambda inp: inp["language_style"]
+         })
+         | build_context
+         | prompt
+         | llm
+     )
+
+     return rag_chain
+
+ rag_chain = load_rag_pipeline()
+
+ # Gradio Function
+ def query_rag(question, language_style):
+     response = rag_chain.invoke({"question": question, "language_style": language_style})
+     return response.content
+
+ # Create Gradio Interface
+ iface = gr.Interface(
+     fn=query_rag,
+     inputs=[
+         gr.Textbox(label="Question"),
+         gr.Radio(["simple", "technical"], label="Language Style", value="simple"),
+     ],
+     outputs=gr.Textbox(label="Answer"),
+     title="Parkinson's RAG Assistant",
+     description="Ask questions about Parkinson's disease with simple or technical explanations.",
+ )
+
+ iface.launch()
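
app.py loads a prebuilt index from the parkinsons_vector_db folder, which is not part of this diff. Below is a minimal sketch of how such an index could be built with the same embedding model; the data/ directory, loader, and chunking parameters are assumptions for illustration, not taken from this repo.

from langchain_community.document_loaders import DirectoryLoader, TextLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_text_splitters import RecursiveCharacterTextSplitter

# Hypothetical source corpus; this path is an assumption, not part of the commit.
docs = DirectoryLoader("data", glob="**/*.txt", loader_cls=TextLoader).load()
chunks = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100).split_documents(docs)

# Same embedding model app.py uses at query time, so query and index vectors match.
embeddings = HuggingFaceEmbeddings(model_name="pritamdeka/S-PubMedBert-MS-MARCO")
db = FAISS.from_documents(chunks, embeddings)
db.save_local("parkinsons_vector_db")  # folder that app.py passes to FAISS.load_local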