Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -1,14 +1,24 @@
|
|
1 |
import os
|
2 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
from langchain.chains import create_retrieval_chain
|
4 |
from langchain.chains.combine_documents import create_stuff_documents_chain
|
5 |
from langchain_core.prompts import PromptTemplate
|
6 |
from langchain.llms.base import LLM
|
7 |
-
from collections import OrderedDict
|
8 |
-
from typing import Optional, List
|
9 |
import google.generativeai as genai
|
10 |
|
11 |
-
#
|
12 |
from embeddings import (
|
13 |
load_pdf_files,
|
14 |
create_chunks,
|
@@ -31,22 +41,19 @@ if not GOOGLE_API_KEY:
|
|
31 |
else:
|
32 |
genai.configure(api_key=GOOGLE_API_KEY)
|
33 |
|
34 |
-
#
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
return db
|
46 |
-
|
47 |
-
db = load_or_create_faiss()
|
48 |
|
49 |
-
#
|
50 |
class GeminiLLM(LLM):
|
51 |
model_name: str = "gemini-2.0-flash"
|
52 |
|
@@ -73,74 +80,395 @@ class GeminiLLM(LLM):
|
|
73 |
def _llm_type(self):
|
74 |
return "gemini"
|
75 |
|
76 |
-
#
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
|
86 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
87 |
|
88 |
-
|
89 |
|
90 |
-
|
|
|
91 |
"""
|
|
|
|
|
|
|
92 |
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
|
100 |
-
|
|
|
|
|
|
|
|
|
101 |
|
102 |
-
|
103 |
-
|
104 |
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
|
111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
112 |
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
|
119 |
-
|
120 |
-
response = qa_chain.invoke({"input": full_query})
|
121 |
-
result = response["answer"] # Modern chains return "answer" not "result"
|
122 |
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
|
|
|
|
|
|
|
|
127 |
|
128 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
129 |
|
130 |
except Exception as e:
|
131 |
-
return f"Error: {str(e)}",
|
132 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
133 |
|
134 |
-
# Gradio Interface
|
135 |
iface = gr.Interface(
|
136 |
-
fn=
|
137 |
inputs=[
|
138 |
-
gr.Textbox(label="Question", placeholder="Enter your question here..."),
|
139 |
-
gr.Textbox(label="Health Profile", placeholder="Enter your health information (optional)...", value="No health profile provided")
|
|
|
140 |
],
|
141 |
-
outputs=[
|
142 |
-
|
143 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
144 |
)
|
145 |
|
146 |
-
|
|
|
|
1 |
import os
|
2 |
import gradio as gr
|
3 |
+
from typing import Dict, List, Optional, TypedDict, Annotated
|
4 |
+
from collections import OrderedDict
|
5 |
+
import json
|
6 |
+
import re
|
7 |
+
from datetime import datetime
|
8 |
+
|
9 |
+
# LangGraph imports
|
10 |
+
from langgraph.graph import StateGraph, END
|
11 |
+
from langgraph.graph.message import add_messages
|
12 |
+
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
|
13 |
+
|
14 |
+
# LangChain imports
|
15 |
from langchain.chains import create_retrieval_chain
|
16 |
from langchain.chains.combine_documents import create_stuff_documents_chain
|
17 |
from langchain_core.prompts import PromptTemplate
|
18 |
from langchain.llms.base import LLM
|
|
|
|
|
19 |
import google.generativeai as genai
|
20 |
|
21 |
+
# Your existing embeddings utility
|
22 |
from embeddings import (
|
23 |
load_pdf_files,
|
24 |
create_chunks,
|
|
|
41 |
else:
|
42 |
genai.configure(api_key=GOOGLE_API_KEY)
|
43 |
|
44 |
+
# ===== LANGGRAPH STATE DEFINITION =====
|
45 |
+
class ConversationState(TypedDict):
|
46 |
+
messages: Annotated[list[BaseMessage], add_messages]
|
47 |
+
current_query: str
|
48 |
+
health_profile: str
|
49 |
+
medical_entities: Dict[str, List[str]] # symptoms, medications, conditions, etc.
|
50 |
+
conversation_summary: str
|
51 |
+
retrieved_documents: List[str]
|
52 |
+
clarifying_questions: List[str]
|
53 |
+
needs_clarification: bool
|
54 |
+
session_context: str
|
|
|
|
|
|
|
55 |
|
56 |
+
# ===== CUSTOM GEMINI LLM =====
|
57 |
class GeminiLLM(LLM):
|
58 |
model_name: str = "gemini-2.0-flash"
|
59 |
|
|
|
80 |
def _llm_type(self):
|
81 |
return "gemini"
|
82 |
|
83 |
+
# ===== MEDICAL ENTITY EXTRACTOR =====
|
84 |
+
class MedicalEntityExtractor:
|
85 |
+
def __init__(self, llm):
|
86 |
+
self.llm = llm
|
87 |
+
self.extraction_prompt = """
|
88 |
+
Extract medical entities from the following text. Return a JSON object with these categories:
|
89 |
+
- symptoms: List of symptoms mentioned
|
90 |
+
- medications: List of medications mentioned
|
91 |
+
- conditions: List of medical conditions mentioned
|
92 |
+
- body_parts: List of body parts mentioned
|
93 |
+
- severity: Any severity indicators (mild, severe, etc.)
|
94 |
+
- duration: Any time-related information
|
95 |
+
|
96 |
+
Text: {text}
|
97 |
+
|
98 |
+
Return only valid JSON:
|
99 |
+
"""
|
100 |
+
|
101 |
+
def extract_entities(self, text: str) -> Dict[str, List[str]]:
|
102 |
+
try:
|
103 |
+
prompt = self.extraction_prompt.format(text=text)
|
104 |
+
response = self.llm._call(prompt)
|
105 |
+
|
106 |
+
# Try to parse JSON response
|
107 |
+
try:
|
108 |
+
entities = json.loads(response)
|
109 |
+
return entities
|
110 |
+
except:
|
111 |
+
# Fallback to simple regex extraction
|
112 |
+
return self._fallback_extraction(text)
|
113 |
+
except:
|
114 |
+
return self._fallback_extraction(text)
|
115 |
+
|
116 |
+
def _fallback_extraction(self, text: str) -> Dict[str, List[str]]:
|
117 |
+
# Simple keyword-based extraction as fallback
|
118 |
+
symptoms_keywords = ['fever', 'headache', 'cough', 'pain', 'nausea', 'vomiting', 'diarrhea', 'fatigue', 'weakness']
|
119 |
+
medications_keywords = ['paracetamol', 'ibuprofen', 'aspirin', 'acetaminophen', 'antibiotic']
|
120 |
+
|
121 |
+
text_lower = text.lower()
|
122 |
+
|
123 |
+
return {
|
124 |
+
"symptoms": [s for s in symptoms_keywords if s in text_lower],
|
125 |
+
"medications": [m for m in medications_keywords if m in text_lower],
|
126 |
+
"conditions": [],
|
127 |
+
"body_parts": [],
|
128 |
+
"severity": [],
|
129 |
+
"duration": []
|
130 |
+
}
|
131 |
+
|
132 |
+
# ===== LOAD FAISS DB =====
|
133 |
+
def load_or_create_faiss():
|
134 |
+
embedding_model = get_embedding_model()
|
135 |
+
if not os.path.exists(DB_FAISS_PATH):
|
136 |
+
print("π FAISS index not found. Creating new index...")
|
137 |
+
documents = load_pdf_files(DATA_PATH)
|
138 |
+
text_chunks = create_chunks(documents)
|
139 |
+
db = store_embeddings(text_chunks, embedding_model, DB_FAISS_PATH)
|
140 |
+
else:
|
141 |
+
print("β
Existing FAISS index found. Loading it...")
|
142 |
+
db = load_faiss_db(DB_FAISS_PATH, embedding_model)
|
143 |
+
return db
|
144 |
+
|
145 |
+
db = load_or_create_faiss()
|
146 |
+
gemini_llm = GeminiLLM()
|
147 |
+
entity_extractor = MedicalEntityExtractor(gemini_llm)
|
148 |
+
|
149 |
+
# ===== LANGGRAPH NODES =====
|
150 |
|
151 |
+
def entity_extraction_node(state: ConversationState) -> ConversationState:
|
152 |
+
"""Extract medical entities from current query and update state."""
|
153 |
+
current_query = state["current_query"]
|
154 |
+
|
155 |
+
# Extract entities from current query
|
156 |
+
new_entities = entity_extractor.extract_entities(current_query)
|
157 |
+
|
158 |
+
# Merge with existing entities
|
159 |
+
existing_entities = state.get("medical_entities", {})
|
160 |
+
|
161 |
+
for category, items in new_entities.items():
|
162 |
+
if category not in existing_entities:
|
163 |
+
existing_entities[category] = []
|
164 |
+
|
165 |
+
for item in items:
|
166 |
+
if item not in existing_entities[category]:
|
167 |
+
existing_entities[category].append(item)
|
168 |
+
|
169 |
+
state["medical_entities"] = existing_entities
|
170 |
+
return state
|
171 |
+
|
172 |
+
def context_builder_node(state: ConversationState) -> ConversationState:
|
173 |
+
"""Build comprehensive context from conversation history and entities."""
|
174 |
+
messages = state["messages"]
|
175 |
+
medical_entities = state.get("medical_entities", {})
|
176 |
+
health_profile = state.get("health_profile", "")
|
177 |
+
|
178 |
+
# Build context from recent messages (last 10 exchanges)
|
179 |
+
recent_messages = messages[-20:] if len(messages) > 20 else messages
|
180 |
+
|
181 |
+
conversation_context = []
|
182 |
+
for msg in recent_messages:
|
183 |
+
if isinstance(msg, HumanMessage):
|
184 |
+
conversation_context.append(f"User: {msg.content}")
|
185 |
+
elif isinstance(msg, AIMessage):
|
186 |
+
conversation_context.append(f"Assistant: {msg.content}")
|
187 |
+
|
188 |
+
# Create entity summary
|
189 |
+
entity_summary = ""
|
190 |
+
if medical_entities:
|
191 |
+
entity_parts = []
|
192 |
+
for category, items in medical_entities.items():
|
193 |
+
if items:
|
194 |
+
entity_parts.append(f"{category}: {', '.join(items)}")
|
195 |
+
entity_summary = " | ".join(entity_parts)
|
196 |
+
|
197 |
+
# Build session context
|
198 |
+
session_context = f"""
|
199 |
+
HEALTH PROFILE: {health_profile}
|
200 |
|
201 |
+
MEDICAL ENTITIES DISCUSSED: {entity_summary}
|
202 |
|
203 |
+
RECENT CONVERSATION:
|
204 |
+
{chr(10).join(conversation_context[-10:])} # Last 5 exchanges
|
205 |
"""
|
206 |
+
|
207 |
+
state["session_context"] = session_context
|
208 |
+
return state
|
209 |
|
210 |
+
def retrieval_node(state: ConversationState) -> ConversationState:
|
211 |
+
"""Retrieve relevant documents using enhanced query with context."""
|
212 |
+
current_query = state["current_query"]
|
213 |
+
session_context = state.get("session_context", "")
|
214 |
+
medical_entities = state.get("medical_entities", {})
|
215 |
+
|
216 |
+
# Create enhanced query for retrieval
|
217 |
+
entity_keywords = []
|
218 |
+
for category, items in medical_entities.items():
|
219 |
+
entity_keywords.extend(items)
|
220 |
+
|
221 |
+
enhanced_query = current_query
|
222 |
+
if entity_keywords:
|
223 |
+
enhanced_query += " " + " ".join(entity_keywords)
|
224 |
+
|
225 |
+
# Retrieve documents
|
226 |
+
retriever = db.as_retriever(search_kwargs={'k': 5})
|
227 |
+
retrieved_docs = retriever.invoke(enhanced_query)
|
228 |
+
|
229 |
+
# Extract document content
|
230 |
+
doc_contents = [doc.page_content for doc in retrieved_docs]
|
231 |
+
state["retrieved_documents"] = doc_contents
|
232 |
+
|
233 |
+
return state
|
234 |
|
235 |
+
def clarification_check_node(state: ConversationState) -> ConversationState:
|
236 |
+
"""Check if clarification is needed and generate clarifying questions."""
|
237 |
+
current_query = state["current_query"]
|
238 |
+
medical_entities = state.get("medical_entities", {})
|
239 |
+
retrieved_docs = state.get("retrieved_documents", [])
|
240 |
|
241 |
+
clarification_prompt = f"""
|
242 |
+
As a medical expert, analyze if the following query needs clarification for proper medical advice:
|
243 |
|
244 |
+
Query: {current_query}
|
245 |
+
Medical Context: {medical_entities}
|
246 |
+
|
247 |
+
If clarification is needed, generate 1-3 specific clarifying questions.
|
248 |
+
If no clarification needed, respond with "NO_CLARIFICATION_NEEDED"
|
249 |
+
|
250 |
+
Consider asking about:
|
251 |
+
- Symptom duration, severity, frequency
|
252 |
+
- Associated symptoms
|
253 |
+
- Current medications
|
254 |
+
- Recent changes in health
|
255 |
+
- Specific circumstances
|
256 |
+
|
257 |
+
Format: If clarification needed, list questions separated by '|'
|
258 |
+
"""
|
259 |
+
|
260 |
+
response = gemini_llm._call(clarification_prompt)
|
261 |
+
|
262 |
+
if "NO_CLARIFICATION_NEEDED" in response:
|
263 |
+
state["needs_clarification"] = False
|
264 |
+
state["clarifying_questions"] = []
|
265 |
+
else:
|
266 |
+
state["needs_clarification"] = True
|
267 |
+
questions = [q.strip() for q in response.split('|') if q.strip()]
|
268 |
+
state["clarifying_questions"] = questions[:3] # Max 3 questions
|
269 |
+
|
270 |
+
return state
|
271 |
|
272 |
+
def medical_advisor_node(state: ConversationState) -> ConversationState:
|
273 |
+
"""Generate medical advice using enhanced context."""
|
274 |
+
current_query = state["current_query"]
|
275 |
+
session_context = state.get("session_context", "")
|
276 |
+
retrieved_docs = state.get("retrieved_documents", [])
|
277 |
+
clarifying_questions = state.get("clarifying_questions", [])
|
278 |
+
|
279 |
+
# Combine retrieved documents
|
280 |
+
context_docs = "\n\n".join(retrieved_docs)
|
281 |
+
|
282 |
+
# Create comprehensive prompt
|
283 |
+
medical_prompt = f"""
|
284 |
+
You are an EXPERT MEDICAL ADVISOR. Use the complete session context to provide accurate, personalized medical advice.
|
285 |
|
286 |
+
IMPORTANT CONTEXT:
|
287 |
+
{session_context}
|
288 |
+
|
289 |
+
RELEVANT MEDICAL DOCUMENTS:
|
290 |
+
{context_docs}
|
291 |
|
292 |
+
CURRENT QUESTION: {current_query}
|
|
|
|
|
293 |
|
294 |
+
INSTRUCTIONS:
|
295 |
+
1. Use the COMPLETE conversation history and medical entities to understand the full context
|
296 |
+
2. Reference previous symptoms, conditions, and health information discussed
|
297 |
+
3. Provide personalized advice based on the user's health profile
|
298 |
+
4. If information is incomplete, mention what additional details would be helpful
|
299 |
+
5. Keep response under 300 words but comprehensive
|
300 |
+
6. Only provide information supported by the medical documents
|
301 |
+
7. If unsure, clearly state limitations
|
302 |
|
303 |
+
{"CLARIFYING QUESTIONS TO CONSIDER: " + " | ".join(clarifying_questions) if clarifying_questions else ""}
|
304 |
+
|
305 |
+
Provide your medical advice:
|
306 |
+
"""
|
307 |
+
|
308 |
+
response = gemini_llm._call(medical_prompt)
|
309 |
+
|
310 |
+
# Clean up response
|
311 |
+
sentences = [s.strip() for s in response.split('.') if s.strip()]
|
312 |
+
unique_sentences = list(OrderedDict.fromkeys(sentences))
|
313 |
+
cleaned_response = '. '.join(unique_sentences) + '.'
|
314 |
+
|
315 |
+
# Add to messages
|
316 |
+
state["messages"].append(AIMessage(content=cleaned_response))
|
317 |
+
|
318 |
+
return state
|
319 |
+
|
320 |
+
def should_ask_clarification(state: ConversationState) -> str:
|
321 |
+
"""Routing function to determine if clarification is needed."""
|
322 |
+
return "clarification" if state.get("needs_clarification", False) else "response"
|
323 |
+
|
324 |
+
# ===== BUILD LANGGRAPH =====
|
325 |
+
def create_medical_graph():
|
326 |
+
workflow = StateGraph(ConversationState)
|
327 |
+
|
328 |
+
# Add nodes
|
329 |
+
workflow.add_node("entity_extraction", entity_extraction_node)
|
330 |
+
workflow.add_node("context_builder", context_builder_node)
|
331 |
+
workflow.add_node("retrieval", retrieval_node)
|
332 |
+
workflow.add_node("clarification_check", clarification_check_node)
|
333 |
+
workflow.add_node("medical_advisor", medical_advisor_node)
|
334 |
+
|
335 |
+
# Add edges
|
336 |
+
workflow.set_entry_point("entity_extraction")
|
337 |
+
workflow.add_edge("entity_extraction", "context_builder")
|
338 |
+
workflow.add_edge("context_builder", "retrieval")
|
339 |
+
workflow.add_edge("retrieval", "clarification_check")
|
340 |
+
|
341 |
+
# Conditional routing
|
342 |
+
workflow.add_conditional_edges(
|
343 |
+
"clarification_check",
|
344 |
+
should_ask_clarification,
|
345 |
+
{
|
346 |
+
"clarification": "medical_advisor",
|
347 |
+
"response": "medical_advisor"
|
348 |
+
}
|
349 |
+
)
|
350 |
+
|
351 |
+
workflow.add_edge("medical_advisor", END)
|
352 |
+
|
353 |
+
return workflow.compile()
|
354 |
+
|
355 |
+
# Create the graph
|
356 |
+
medical_graph = create_medical_graph()
|
357 |
+
|
358 |
+
# ===== SESSION MANAGEMENT =====
|
359 |
+
# Simple in-memory session storage for demo
|
360 |
+
active_sessions: Dict[str, ConversationState] = {}
|
361 |
+
|
362 |
+
def get_or_create_session(session_id: str) -> ConversationState:
|
363 |
+
if session_id not in active_sessions:
|
364 |
+
active_sessions[session_id] = ConversationState(
|
365 |
+
messages=[],
|
366 |
+
current_query="",
|
367 |
+
health_profile="",
|
368 |
+
medical_entities={},
|
369 |
+
conversation_summary="",
|
370 |
+
retrieved_documents=[],
|
371 |
+
clarifying_questions=[],
|
372 |
+
needs_clarification=False,
|
373 |
+
session_context=""
|
374 |
+
)
|
375 |
+
return active_sessions[session_id]
|
376 |
+
|
377 |
+
# ===== MAIN API FUNCTION =====
|
378 |
+
def ask_question(query: str, health_info: str = "No health profile provided", session_id: str = "default"):
|
379 |
+
"""
|
380 |
+
Main API function - preserves your original interface while adding session support.
|
381 |
+
"""
|
382 |
+
try:
|
383 |
+
# Get or create session state
|
384 |
+
state = get_or_create_session(session_id)
|
385 |
+
|
386 |
+
# Update state with current query and health info
|
387 |
+
state["current_query"] = query
|
388 |
+
state["health_profile"] = health_info
|
389 |
+
|
390 |
+
# Add user message to conversation history
|
391 |
+
state["messages"].append(HumanMessage(content=query))
|
392 |
+
|
393 |
+
# Run the medical graph
|
394 |
+
result = medical_graph.invoke(state)
|
395 |
+
|
396 |
+
# Update session state
|
397 |
+
active_sessions[session_id] = result
|
398 |
+
|
399 |
+
# Extract response
|
400 |
+
last_message = result["messages"][-1]
|
401 |
+
response_text = last_message.content if hasattr(last_message, 'content') else str(last_message)
|
402 |
+
|
403 |
+
# Prepare additional info
|
404 |
+
clarifying_questions = result.get("clarifying_questions", [])
|
405 |
+
medical_entities = result.get("medical_entities", {})
|
406 |
+
|
407 |
+
# Format additional info for gradio
|
408 |
+
additional_info = {
|
409 |
+
"clarifying_questions": clarifying_questions,
|
410 |
+
"medical_entities": medical_entities,
|
411 |
+
"session_id": session_id
|
412 |
+
}
|
413 |
+
|
414 |
+
return response_text, additional_info
|
415 |
|
416 |
except Exception as e:
|
417 |
+
return f"Error: {str(e)}", {"error": True}
|
418 |
|
419 |
+
# ===== GRADIO INTERFACE =====
|
420 |
+
def gradio_interface(query, health_info, session_id):
|
421 |
+
"""Wrapper function for Gradio interface."""
|
422 |
+
response, additional_info = ask_question(query, health_info, session_id)
|
423 |
+
|
424 |
+
# Format additional info for display
|
425 |
+
info_display = ""
|
426 |
+
if additional_info.get("clarifying_questions"):
|
427 |
+
info_display += "**Clarifying Questions:**\n"
|
428 |
+
for i, q in enumerate(additional_info["clarifying_questions"], 1):
|
429 |
+
info_display += f"{i}. {q}\n"
|
430 |
+
info_display += "\n"
|
431 |
+
|
432 |
+
if additional_info.get("medical_entities"):
|
433 |
+
info_display += "**Medical Entities Tracked:**\n"
|
434 |
+
for category, items in additional_info["medical_entities"].items():
|
435 |
+
if items:
|
436 |
+
info_display += f"- {category.title()}: {', '.join(items)}\n"
|
437 |
+
|
438 |
+
return response, info_display
|
439 |
|
440 |
+
# Create Gradio Interface
|
441 |
iface = gr.Interface(
|
442 |
+
fn=gradio_interface,
|
443 |
inputs=[
|
444 |
+
gr.Textbox(label="Question", placeholder="Enter your medical question here..."),
|
445 |
+
gr.Textbox(label="Health Profile", placeholder="Enter your health information (optional)...", value="No health profile provided"),
|
446 |
+
gr.Textbox(label="Session ID", placeholder="Enter session ID (optional)", value="default")
|
447 |
],
|
448 |
+
outputs=[
|
449 |
+
gr.Textbox(label="Medical Advice", lines=10),
|
450 |
+
gr.Textbox(label="Additional Information", lines=5)
|
451 |
+
],
|
452 |
+
title="π₯ Advanced Medical RAG Chatbot with LangGraph",
|
453 |
+
description="""
|
454 |
+
**Features:**
|
455 |
+
- π Maintains conversation context across questions
|
456 |
+
- π§ Extracts and tracks medical entities (symptoms, medications, etc.)
|
457 |
+
- β Asks clarifying questions when needed
|
458 |
+
- π€ Personalizes responses based on health profile
|
459 |
+
- π Uses medical knowledge base for accurate information
|
460 |
+
|
461 |
+
**Tips:**
|
462 |
+
- Use the same Session ID to maintain conversation context
|
463 |
+
- Provide detailed health profile for personalized advice
|
464 |
+
- Answer clarifying questions for better recommendations
|
465 |
+
""",
|
466 |
+
examples=[
|
467 |
+
["I have been having fever for 2 days", "Age: 25, No chronic conditions", "user123"],
|
468 |
+
["What medicines should I take for this fever?", "Age: 25, No chronic conditions", "user123"],
|
469 |
+
["I also have a headache now", "Age: 25, No chronic conditions", "user123"]
|
470 |
+
]
|
471 |
)
|
472 |
|
473 |
+
if __name__ == "__main__":
|
474 |
+
iface.launch(share=True)
|