rishi002 commited on
Commit
4fce8de
Β·
verified Β·
1 Parent(s): fc2736d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +395 -67
app.py CHANGED
@@ -1,14 +1,24 @@
1
  import os
2
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
3
  from langchain.chains import create_retrieval_chain
4
  from langchain.chains.combine_documents import create_stuff_documents_chain
5
  from langchain_core.prompts import PromptTemplate
6
  from langchain.llms.base import LLM
7
- from collections import OrderedDict
8
- from typing import Optional, List
9
  import google.generativeai as genai
10
 
11
- # Custom utility functions
12
  from embeddings import (
13
  load_pdf_files,
14
  create_chunks,
@@ -31,22 +41,19 @@ if not GOOGLE_API_KEY:
31
  else:
32
  genai.configure(api_key=GOOGLE_API_KEY)
33
 
34
- # Load or create FAISS vector store
35
- def load_or_create_faiss():
36
- embedding_model = get_embedding_model()
37
- if not os.path.exists(DB_FAISS_PATH):
38
- print("πŸ”„ FAISS index not found. Creating new index...")
39
- documents = load_pdf_files(DATA_PATH)
40
- text_chunks = create_chunks(documents)
41
- db = store_embeddings(text_chunks, embedding_model, DB_FAISS_PATH)
42
- else:
43
- print("βœ… Existing FAISS index found. Loading it...")
44
- db = load_faiss_db(DB_FAISS_PATH, embedding_model)
45
- return db
46
-
47
- db = load_or_create_faiss()
48
 
49
- # βœ… Custom Gemini LLM wrapper for LangChain
50
  class GeminiLLM(LLM):
51
  model_name: str = "gemini-2.0-flash"
52
 
@@ -73,74 +80,395 @@ class GeminiLLM(LLM):
73
  def _llm_type(self):
74
  return "gemini"
75
 
76
- # Prompt template with user health profile - Updated for modern LangChain
77
- CUSTOM_PROMPT_TEMPLATE = """
78
- You are an EXPERT MEDICAL ADVISOR. Use the user's Health Profile to answer the medical query of user accurately and professionally.
79
- If you do not have exact answer in the vector document then you should ask some follow up questions before answering the question.
80
- If you don't know the answer, just say that you don't know. Don't make up an answer.
81
- Only provide information from the given context.
82
- Keep your answer concise and avoid repeating the same information.
83
- Each important point should be stated only once.
84
- NOTE: SUMMARIZE YOUR ANSWERS STRICTLY WITHIN 300 WORDS.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
- Context: {context}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
- Question: {input}
89
 
90
- Start the answer directly.
 
91
  """
 
 
 
92
 
93
- # QA Chain constructor using modern LangChain approach
94
- def create_qa_chain():
95
- prompt = PromptTemplate(
96
- template=CUSTOM_PROMPT_TEMPLATE,
97
- input_variables=["context", "input"]
98
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
- gemini_llm = GeminiLLM()
 
 
 
 
101
 
102
- # Create the document chain
103
- combine_docs_chain = create_stuff_documents_chain(gemini_llm, prompt)
104
 
105
- # Create the retrieval chain
106
- return create_retrieval_chain(
107
- db.as_retriever(search_kwargs={'k': 3}),
108
- combine_docs_chain
109
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
- qa_chain = create_qa_chain()
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
- # Function to handle question asking
114
- def ask_question(query: str, health_info: str = "No health profile provided"):
115
- try:
116
- # Combine user question and health info into one input
117
- full_query = f"User Health Profile: {health_info}\nQuestion: {query}"
118
 
119
- # Use the correct input key for modern LangChain
120
- response = qa_chain.invoke({"input": full_query})
121
- result = response["answer"] # Modern chains return "answer" not "result"
122
 
123
- # Deduplicate output
124
- sentences = [s.strip() for s in result.split('.') if s.strip()]
125
- unique_sentences = list(OrderedDict.fromkeys(sentences))
126
- cleaned_result = '. '.join(unique_sentences) + '.'
 
 
 
 
127
 
128
- return cleaned_result, []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
  except Exception as e:
131
- return f"Error: {str(e)}", []
132
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
 
134
- # Gradio Interface
135
  iface = gr.Interface(
136
- fn=ask_question,
137
  inputs=[
138
- gr.Textbox(label="Question", placeholder="Enter your question here..."),
139
- gr.Textbox(label="Health Profile", placeholder="Enter your health information (optional)...", value="No health profile provided")
 
140
  ],
141
- outputs=["text", "json"],
142
- title="Medical RAG Chatbot",
143
- description="Ask medical questions and optionally provide your health profile for personalized responses."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  )
145
 
146
- iface.launch(share=True)
 
 
1
  import os
2
  import gradio as gr
3
+ from typing import Dict, List, Optional, TypedDict, Annotated
4
+ from collections import OrderedDict
5
+ import json
6
+ import re
7
+ from datetime import datetime
8
+
9
+ # LangGraph imports
10
+ from langgraph.graph import StateGraph, END
11
+ from langgraph.graph.message import add_messages
12
+ from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
13
+
14
+ # LangChain imports
15
  from langchain.chains import create_retrieval_chain
16
  from langchain.chains.combine_documents import create_stuff_documents_chain
17
  from langchain_core.prompts import PromptTemplate
18
  from langchain.llms.base import LLM
 
 
19
  import google.generativeai as genai
20
 
21
+ # Your existing embeddings utility
22
  from embeddings import (
23
  load_pdf_files,
24
  create_chunks,
 
41
  else:
42
  genai.configure(api_key=GOOGLE_API_KEY)
43
 
44
+ # ===== LANGGRAPH STATE DEFINITION =====
45
+ class ConversationState(TypedDict):
46
+ messages: Annotated[list[BaseMessage], add_messages]
47
+ current_query: str
48
+ health_profile: str
49
+ medical_entities: Dict[str, List[str]] # symptoms, medications, conditions, etc.
50
+ conversation_summary: str
51
+ retrieved_documents: List[str]
52
+ clarifying_questions: List[str]
53
+ needs_clarification: bool
54
+ session_context: str
 
 
 
55
 
56
+ # ===== CUSTOM GEMINI LLM =====
57
  class GeminiLLM(LLM):
58
  model_name: str = "gemini-2.0-flash"
59
 
 
80
  def _llm_type(self):
81
  return "gemini"
82
 
83
+ # ===== MEDICAL ENTITY EXTRACTOR =====
84
+ class MedicalEntityExtractor:
85
+ def __init__(self, llm):
86
+ self.llm = llm
87
+ self.extraction_prompt = """
88
+ Extract medical entities from the following text. Return a JSON object with these categories:
89
+ - symptoms: List of symptoms mentioned
90
+ - medications: List of medications mentioned
91
+ - conditions: List of medical conditions mentioned
92
+ - body_parts: List of body parts mentioned
93
+ - severity: Any severity indicators (mild, severe, etc.)
94
+ - duration: Any time-related information
95
+
96
+ Text: {text}
97
+
98
+ Return only valid JSON:
99
+ """
100
+
101
+ def extract_entities(self, text: str) -> Dict[str, List[str]]:
102
+ try:
103
+ prompt = self.extraction_prompt.format(text=text)
104
+ response = self.llm._call(prompt)
105
+
106
+ # Try to parse JSON response
107
+ try:
108
+ entities = json.loads(response)
109
+ return entities
110
+ except:
111
+ # Fallback to simple regex extraction
112
+ return self._fallback_extraction(text)
113
+ except:
114
+ return self._fallback_extraction(text)
115
+
116
+ def _fallback_extraction(self, text: str) -> Dict[str, List[str]]:
117
+ # Simple keyword-based extraction as fallback
118
+ symptoms_keywords = ['fever', 'headache', 'cough', 'pain', 'nausea', 'vomiting', 'diarrhea', 'fatigue', 'weakness']
119
+ medications_keywords = ['paracetamol', 'ibuprofen', 'aspirin', 'acetaminophen', 'antibiotic']
120
+
121
+ text_lower = text.lower()
122
+
123
+ return {
124
+ "symptoms": [s for s in symptoms_keywords if s in text_lower],
125
+ "medications": [m for m in medications_keywords if m in text_lower],
126
+ "conditions": [],
127
+ "body_parts": [],
128
+ "severity": [],
129
+ "duration": []
130
+ }
131
+
132
+ # ===== LOAD FAISS DB =====
133
+ def load_or_create_faiss():
134
+ embedding_model = get_embedding_model()
135
+ if not os.path.exists(DB_FAISS_PATH):
136
+ print("πŸ”„ FAISS index not found. Creating new index...")
137
+ documents = load_pdf_files(DATA_PATH)
138
+ text_chunks = create_chunks(documents)
139
+ db = store_embeddings(text_chunks, embedding_model, DB_FAISS_PATH)
140
+ else:
141
+ print("βœ… Existing FAISS index found. Loading it...")
142
+ db = load_faiss_db(DB_FAISS_PATH, embedding_model)
143
+ return db
144
+
145
+ db = load_or_create_faiss()
146
+ gemini_llm = GeminiLLM()
147
+ entity_extractor = MedicalEntityExtractor(gemini_llm)
148
+
149
+ # ===== LANGGRAPH NODES =====
150
 
151
+ def entity_extraction_node(state: ConversationState) -> ConversationState:
152
+ """Extract medical entities from current query and update state."""
153
+ current_query = state["current_query"]
154
+
155
+ # Extract entities from current query
156
+ new_entities = entity_extractor.extract_entities(current_query)
157
+
158
+ # Merge with existing entities
159
+ existing_entities = state.get("medical_entities", {})
160
+
161
+ for category, items in new_entities.items():
162
+ if category not in existing_entities:
163
+ existing_entities[category] = []
164
+
165
+ for item in items:
166
+ if item not in existing_entities[category]:
167
+ existing_entities[category].append(item)
168
+
169
+ state["medical_entities"] = existing_entities
170
+ return state
171
+
172
+ def context_builder_node(state: ConversationState) -> ConversationState:
173
+ """Build comprehensive context from conversation history and entities."""
174
+ messages = state["messages"]
175
+ medical_entities = state.get("medical_entities", {})
176
+ health_profile = state.get("health_profile", "")
177
+
178
+ # Build context from recent messages (last 10 exchanges)
179
+ recent_messages = messages[-20:] if len(messages) > 20 else messages
180
+
181
+ conversation_context = []
182
+ for msg in recent_messages:
183
+ if isinstance(msg, HumanMessage):
184
+ conversation_context.append(f"User: {msg.content}")
185
+ elif isinstance(msg, AIMessage):
186
+ conversation_context.append(f"Assistant: {msg.content}")
187
+
188
+ # Create entity summary
189
+ entity_summary = ""
190
+ if medical_entities:
191
+ entity_parts = []
192
+ for category, items in medical_entities.items():
193
+ if items:
194
+ entity_parts.append(f"{category}: {', '.join(items)}")
195
+ entity_summary = " | ".join(entity_parts)
196
+
197
+ # Build session context
198
+ session_context = f"""
199
+ HEALTH PROFILE: {health_profile}
200
 
201
+ MEDICAL ENTITIES DISCUSSED: {entity_summary}
202
 
203
+ RECENT CONVERSATION:
204
+ {chr(10).join(conversation_context[-10:])} # Last 5 exchanges
205
  """
206
+
207
+ state["session_context"] = session_context
208
+ return state
209
 
210
+ def retrieval_node(state: ConversationState) -> ConversationState:
211
+ """Retrieve relevant documents using enhanced query with context."""
212
+ current_query = state["current_query"]
213
+ session_context = state.get("session_context", "")
214
+ medical_entities = state.get("medical_entities", {})
215
+
216
+ # Create enhanced query for retrieval
217
+ entity_keywords = []
218
+ for category, items in medical_entities.items():
219
+ entity_keywords.extend(items)
220
+
221
+ enhanced_query = current_query
222
+ if entity_keywords:
223
+ enhanced_query += " " + " ".join(entity_keywords)
224
+
225
+ # Retrieve documents
226
+ retriever = db.as_retriever(search_kwargs={'k': 5})
227
+ retrieved_docs = retriever.invoke(enhanced_query)
228
+
229
+ # Extract document content
230
+ doc_contents = [doc.page_content for doc in retrieved_docs]
231
+ state["retrieved_documents"] = doc_contents
232
+
233
+ return state
234
 
235
+ def clarification_check_node(state: ConversationState) -> ConversationState:
236
+ """Check if clarification is needed and generate clarifying questions."""
237
+ current_query = state["current_query"]
238
+ medical_entities = state.get("medical_entities", {})
239
+ retrieved_docs = state.get("retrieved_documents", [])
240
 
241
+ clarification_prompt = f"""
242
+ As a medical expert, analyze if the following query needs clarification for proper medical advice:
243
 
244
+ Query: {current_query}
245
+ Medical Context: {medical_entities}
246
+
247
+ If clarification is needed, generate 1-3 specific clarifying questions.
248
+ If no clarification needed, respond with "NO_CLARIFICATION_NEEDED"
249
+
250
+ Consider asking about:
251
+ - Symptom duration, severity, frequency
252
+ - Associated symptoms
253
+ - Current medications
254
+ - Recent changes in health
255
+ - Specific circumstances
256
+
257
+ Format: If clarification needed, list questions separated by '|'
258
+ """
259
+
260
+ response = gemini_llm._call(clarification_prompt)
261
+
262
+ if "NO_CLARIFICATION_NEEDED" in response:
263
+ state["needs_clarification"] = False
264
+ state["clarifying_questions"] = []
265
+ else:
266
+ state["needs_clarification"] = True
267
+ questions = [q.strip() for q in response.split('|') if q.strip()]
268
+ state["clarifying_questions"] = questions[:3] # Max 3 questions
269
+
270
+ return state
271
 
272
+ def medical_advisor_node(state: ConversationState) -> ConversationState:
273
+ """Generate medical advice using enhanced context."""
274
+ current_query = state["current_query"]
275
+ session_context = state.get("session_context", "")
276
+ retrieved_docs = state.get("retrieved_documents", [])
277
+ clarifying_questions = state.get("clarifying_questions", [])
278
+
279
+ # Combine retrieved documents
280
+ context_docs = "\n\n".join(retrieved_docs)
281
+
282
+ # Create comprehensive prompt
283
+ medical_prompt = f"""
284
+ You are an EXPERT MEDICAL ADVISOR. Use the complete session context to provide accurate, personalized medical advice.
285
 
286
+ IMPORTANT CONTEXT:
287
+ {session_context}
288
+
289
+ RELEVANT MEDICAL DOCUMENTS:
290
+ {context_docs}
291
 
292
+ CURRENT QUESTION: {current_query}
 
 
293
 
294
+ INSTRUCTIONS:
295
+ 1. Use the COMPLETE conversation history and medical entities to understand the full context
296
+ 2. Reference previous symptoms, conditions, and health information discussed
297
+ 3. Provide personalized advice based on the user's health profile
298
+ 4. If information is incomplete, mention what additional details would be helpful
299
+ 5. Keep response under 300 words but comprehensive
300
+ 6. Only provide information supported by the medical documents
301
+ 7. If unsure, clearly state limitations
302
 
303
+ {"CLARIFYING QUESTIONS TO CONSIDER: " + " | ".join(clarifying_questions) if clarifying_questions else ""}
304
+
305
+ Provide your medical advice:
306
+ """
307
+
308
+ response = gemini_llm._call(medical_prompt)
309
+
310
+ # Clean up response
311
+ sentences = [s.strip() for s in response.split('.') if s.strip()]
312
+ unique_sentences = list(OrderedDict.fromkeys(sentences))
313
+ cleaned_response = '. '.join(unique_sentences) + '.'
314
+
315
+ # Add to messages
316
+ state["messages"].append(AIMessage(content=cleaned_response))
317
+
318
+ return state
319
+
320
+ def should_ask_clarification(state: ConversationState) -> str:
321
+ """Routing function to determine if clarification is needed."""
322
+ return "clarification" if state.get("needs_clarification", False) else "response"
323
+
324
+ # ===== BUILD LANGGRAPH =====
325
+ def create_medical_graph():
326
+ workflow = StateGraph(ConversationState)
327
+
328
+ # Add nodes
329
+ workflow.add_node("entity_extraction", entity_extraction_node)
330
+ workflow.add_node("context_builder", context_builder_node)
331
+ workflow.add_node("retrieval", retrieval_node)
332
+ workflow.add_node("clarification_check", clarification_check_node)
333
+ workflow.add_node("medical_advisor", medical_advisor_node)
334
+
335
+ # Add edges
336
+ workflow.set_entry_point("entity_extraction")
337
+ workflow.add_edge("entity_extraction", "context_builder")
338
+ workflow.add_edge("context_builder", "retrieval")
339
+ workflow.add_edge("retrieval", "clarification_check")
340
+
341
+ # Conditional routing
342
+ workflow.add_conditional_edges(
343
+ "clarification_check",
344
+ should_ask_clarification,
345
+ {
346
+ "clarification": "medical_advisor",
347
+ "response": "medical_advisor"
348
+ }
349
+ )
350
+
351
+ workflow.add_edge("medical_advisor", END)
352
+
353
+ return workflow.compile()
354
+
355
+ # Create the graph
356
+ medical_graph = create_medical_graph()
357
+
358
+ # ===== SESSION MANAGEMENT =====
359
+ # Simple in-memory session storage for demo
360
+ active_sessions: Dict[str, ConversationState] = {}
361
+
362
+ def get_or_create_session(session_id: str) -> ConversationState:
363
+ if session_id not in active_sessions:
364
+ active_sessions[session_id] = ConversationState(
365
+ messages=[],
366
+ current_query="",
367
+ health_profile="",
368
+ medical_entities={},
369
+ conversation_summary="",
370
+ retrieved_documents=[],
371
+ clarifying_questions=[],
372
+ needs_clarification=False,
373
+ session_context=""
374
+ )
375
+ return active_sessions[session_id]
376
+
377
+ # ===== MAIN API FUNCTION =====
378
+ def ask_question(query: str, health_info: str = "No health profile provided", session_id: str = "default"):
379
+ """
380
+ Main API function - preserves your original interface while adding session support.
381
+ """
382
+ try:
383
+ # Get or create session state
384
+ state = get_or_create_session(session_id)
385
+
386
+ # Update state with current query and health info
387
+ state["current_query"] = query
388
+ state["health_profile"] = health_info
389
+
390
+ # Add user message to conversation history
391
+ state["messages"].append(HumanMessage(content=query))
392
+
393
+ # Run the medical graph
394
+ result = medical_graph.invoke(state)
395
+
396
+ # Update session state
397
+ active_sessions[session_id] = result
398
+
399
+ # Extract response
400
+ last_message = result["messages"][-1]
401
+ response_text = last_message.content if hasattr(last_message, 'content') else str(last_message)
402
+
403
+ # Prepare additional info
404
+ clarifying_questions = result.get("clarifying_questions", [])
405
+ medical_entities = result.get("medical_entities", {})
406
+
407
+ # Format additional info for gradio
408
+ additional_info = {
409
+ "clarifying_questions": clarifying_questions,
410
+ "medical_entities": medical_entities,
411
+ "session_id": session_id
412
+ }
413
+
414
+ return response_text, additional_info
415
 
416
  except Exception as e:
417
+ return f"Error: {str(e)}", {"error": True}
418
 
419
+ # ===== GRADIO INTERFACE =====
420
+ def gradio_interface(query, health_info, session_id):
421
+ """Wrapper function for Gradio interface."""
422
+ response, additional_info = ask_question(query, health_info, session_id)
423
+
424
+ # Format additional info for display
425
+ info_display = ""
426
+ if additional_info.get("clarifying_questions"):
427
+ info_display += "**Clarifying Questions:**\n"
428
+ for i, q in enumerate(additional_info["clarifying_questions"], 1):
429
+ info_display += f"{i}. {q}\n"
430
+ info_display += "\n"
431
+
432
+ if additional_info.get("medical_entities"):
433
+ info_display += "**Medical Entities Tracked:**\n"
434
+ for category, items in additional_info["medical_entities"].items():
435
+ if items:
436
+ info_display += f"- {category.title()}: {', '.join(items)}\n"
437
+
438
+ return response, info_display
439
 
440
+ # Create Gradio Interface
441
  iface = gr.Interface(
442
+ fn=gradio_interface,
443
  inputs=[
444
+ gr.Textbox(label="Question", placeholder="Enter your medical question here..."),
445
+ gr.Textbox(label="Health Profile", placeholder="Enter your health information (optional)...", value="No health profile provided"),
446
+ gr.Textbox(label="Session ID", placeholder="Enter session ID (optional)", value="default")
447
  ],
448
+ outputs=[
449
+ gr.Textbox(label="Medical Advice", lines=10),
450
+ gr.Textbox(label="Additional Information", lines=5)
451
+ ],
452
+ title="πŸ₯ Advanced Medical RAG Chatbot with LangGraph",
453
+ description="""
454
+ **Features:**
455
+ - πŸ’­ Maintains conversation context across questions
456
+ - 🧠 Extracts and tracks medical entities (symptoms, medications, etc.)
457
+ - ❓ Asks clarifying questions when needed
458
+ - πŸ‘€ Personalizes responses based on health profile
459
+ - πŸ“š Uses medical knowledge base for accurate information
460
+
461
+ **Tips:**
462
+ - Use the same Session ID to maintain conversation context
463
+ - Provide detailed health profile for personalized advice
464
+ - Answer clarifying questions for better recommendations
465
+ """,
466
+ examples=[
467
+ ["I have been having fever for 2 days", "Age: 25, No chronic conditions", "user123"],
468
+ ["What medicines should I take for this fever?", "Age: 25, No chronic conditions", "user123"],
469
+ ["I also have a headache now", "Age: 25, No chronic conditions", "user123"]
470
+ ]
471
  )
472
 
473
+ if __name__ == "__main__":
474
+ iface.launch(share=True)