# mediVedaLLM / app.py
# Last commit by rishi002: "Update app.py" (4fce8de, verified)
import os
import gradio as gr
from typing import Dict, List, Optional, TypedDict, Annotated
from collections import OrderedDict
import json
import re
from datetime import datetime
# LangGraph imports
from langgraph.graph import StateGraph, END
from langgraph.graph.message import add_messages
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
# LangChain imports
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import PromptTemplate
from langchain.llms.base import LLM
import google.generativeai as genai
# Your existing embeddings utility
from embeddings import (
load_pdf_files,
create_chunks,
get_embedding_model,
store_embeddings,
load_faiss_db
)
# Constants
DATA_PATH = "dataFolder/"  # folder of source PDFs read when the index is (re)built
DB_FAISS_PATH = "/tmp/vectorstore/db_faiss"
CACHE_DIR = "/tmp/models_cache"
os.makedirs(CACHE_DIR, exist_ok=True)

# Google AI API setup: configure the client only when a key is available.
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
if GOOGLE_API_KEY:
    genai.configure(api_key=GOOGLE_API_KEY)
else:
    print("⚠️ GOOGLE_API_KEY not found in environment variables!")
    print("Please set your Google API key in Hugging Face Spaces secrets.")
# ===== LANGGRAPH STATE DEFINITION =====
class ConversationState(TypedDict):
    """Per-session state passed through the LangGraph pipeline.

    One of these is stored per session id in ``active_sessions`` and handed to
    ``medical_graph.invoke`` on every turn; each node reads and writes keys here.
    """

    # Chat history. The Annotated reducer tells LangGraph to combine updates to
    # this key with add_messages instead of overwriting the list.
    messages: Annotated[list[BaseMessage], add_messages]
    # The user's question for the turn currently being processed.
    current_query: str
    # Free-text health profile supplied alongside the question.
    health_profile: str
    medical_entities: Dict[str, List[str]] # symptoms, medications, conditions, etc.
    # Initialized to "" for new sessions; no code visible in this file updates it.
    conversation_summary: str
    # page_content strings of the chunks returned by the FAISS retriever.
    retrieved_documents: List[str]
    # Follow-up questions parsed from the clarification check (at most 3).
    clarifying_questions: List[str]
    # Set by the clarification check; False when the LLM answered NO_CLARIFICATION_NEEDED.
    needs_clarification: bool
    # Prompt-ready text combining profile, tracked entities and recent turns.
    session_context: str
# ===== CUSTOM GEMINI LLM =====
class GeminiLLM(LLM):
    """LangChain ``LLM`` wrapper around a Google Gemini ``GenerativeModel``.

    Failures are not raised: ``_call`` returns a string starting with
    ``"Error generating response:"`` instead, so callers that need to detect
    errors must check for that prefix.
    """

    model_name: str = "gemini-2.0-flash"

    class Config:
        extra = 'forbid'
        arbitrary_types_allowed = True

    def __init__(self, model_name: str = "gemini-2.0-flash", **kwargs):
        super().__init__(model_name=model_name, **kwargs)

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager=None,
        **kwargs,
    ) -> str:
        """Send *prompt* to Gemini and return the generated text.

        Args:
            prompt: Full prompt text.
            stop: Optional stop sequences, forwarded to Gemini as
                ``stop_sequences`` so generation halts where LangChain asked.
            run_manager: LangChain callback manager; accepted so the signature
                matches ``LLM._call``, not used here.
            **kwargs: Extra invocation kwargs from LangChain; ignored.

        Returns:
            The response text, or ``"Error generating response: <exc>"`` when
            the request or the text extraction fails.
        """
        try:
            model = genai.GenerativeModel(self.model_name)
            if stop:
                response = model.generate_content(
                    prompt, generation_config={"stop_sequences": list(stop)}
                )
            else:
                response = model.generate_content(prompt)
            return response.text
        except Exception as e:
            return f"Error generating response: {str(e)}"

    @property
    def _identifying_params(self):
        """Parameters LangChain uses to identify/cache this LLM."""
        return {"model_name": self.model_name}

    @property
    def _llm_type(self):
        """Short type tag reported to LangChain."""
        return "gemini"
# ===== MEDICAL ENTITY EXTRACTOR =====
class MedicalEntityExtractor:
    """Extract structured medical entities from free text with an LLM.

    The LLM is asked for a JSON object. If no JSON object can be recovered
    from its reply (or the call raises), a small keyword-based fallback is
    used. Either way, the result maps every name in ``CATEGORIES`` to a list
    of non-empty strings.
    """

    # Categories requested by the prompt; every returned dict contains all of them.
    CATEGORIES = ("symptoms", "medications", "conditions", "body_parts", "severity", "duration")

    def __init__(self, llm):
        """Store *llm*; it must provide ``_call(prompt) -> str``."""
        self.llm = llm
        self.extraction_prompt = """
Extract medical entities from the following text. Return a JSON object with these categories:
- symptoms: List of symptoms mentioned
- medications: List of medications mentioned
- conditions: List of medical conditions mentioned
- body_parts: List of body parts mentioned
- severity: Any severity indicators (mild, severe, etc.)
- duration: Any time-related information
Text: {text}
Return only valid JSON:
"""

    def extract_entities(self, text: str) -> Dict[str, List[str]]:
        """Return ``{category: [entity, ...]}`` for *text*.

        Falls back to keyword matching when the LLM call raises or its reply
        contains no parseable JSON object.
        """
        try:
            prompt = self.extraction_prompt.format(text=text)
            response = self.llm._call(prompt)
        except Exception:
            return self._fallback_extraction(text)
        parsed = self._parse_json_object(response)
        if parsed is None:
            return self._fallback_extraction(text)
        return self._normalize(parsed)

    @staticmethod
    def _parse_json_object(response) -> Optional[dict]:
        """Return the first JSON object found in *response*, or None.

        Tries the whole reply first, then the outermost ``{...}`` span, which
        covers replies wrapped in markdown code fences or preceded by prose.
        """
        if not isinstance(response, str):
            return None
        candidates = [response]
        match = re.search(r"\{.*\}", response, re.DOTALL)
        if match:
            candidates.append(match.group(0))
        for candidate in candidates:
            try:
                parsed = json.loads(candidate)
            except ValueError:  # json.JSONDecodeError subclasses ValueError
                continue
            if isinstance(parsed, dict):
                return parsed
        return None

    def _normalize(self, parsed: dict) -> Dict[str, List[str]]:
        """Coerce parsed JSON into ``{category: list[str]}`` for every category.

        A scalar value such as ``"mild"`` becomes a one-item list; it is not
        iterated character by character downstream.
        """
        entities: Dict[str, List[str]] = {}
        for category in self.CATEGORIES:
            value = parsed.get(category)
            if value is None:
                items = []
            elif isinstance(value, (list, tuple)):
                items = list(value)
            else:
                items = [value]
            entities[category] = [
                str(item).strip() for item in items if item is not None and str(item).strip()
            ]
        return entities

    def _fallback_extraction(self, text: str) -> Dict[str, List[str]]:
        """Keyword-based extraction used when the LLM result is unusable."""
        symptoms_keywords = ['fever', 'headache', 'cough', 'pain', 'nausea', 'vomiting', 'diarrhea', 'fatigue', 'weakness']
        medications_keywords = ['paracetamol', 'ibuprofen', 'aspirin', 'acetaminophen', 'antibiotic']
        text_lower = text.lower()
        return {
            "symptoms": [s for s in symptoms_keywords if s in text_lower],
            "medications": [m for m in medications_keywords if m in text_lower],
            "conditions": [],
            "body_parts": [],
            "severity": [],
            "duration": []
        }
# ===== LOAD FAISS DB =====
def load_or_create_faiss():
    """Return the FAISS vector store, building it first when none exists.

    If ``DB_FAISS_PATH`` does not exist, PDFs under ``DATA_PATH`` are loaded,
    chunked and embedded through ``store_embeddings`` (which receives
    ``DB_FAISS_PATH``, presumably to persist the index there). Otherwise the
    existing index is loaded with ``load_faiss_db``. Both paths use the same
    embedding model from ``get_embedding_model``.
    """
    embedding_model = get_embedding_model()
    if os.path.exists(DB_FAISS_PATH):
        print("✅ Existing FAISS index found. Loading it...")
        return load_faiss_db(DB_FAISS_PATH, embedding_model)

    print("🔄 FAISS index not found. Creating new index...")
    documents = load_pdf_files(DATA_PATH)
    text_chunks = create_chunks(documents)
    return store_embeddings(text_chunks, embedding_model, DB_FAISS_PATH)
# Shared resources used by the graph nodes: vector store, LLM, entity extractor.
db = load_or_create_faiss()
gemini_llm = GeminiLLM(model_name="gemini-2.0-flash")
entity_extractor = MedicalEntityExtractor(llm=gemini_llm)
# ===== LANGGRAPH NODES =====
def entity_extraction_node(state: ConversationState) -> ConversationState:
    """Extract entities from ``state["current_query"]`` and merge them into state.

    New items are appended to the per-category lists already stored in
    ``state["medical_entities"]``, so entities accumulate across turns.
    Duplicates are skipped case-insensitively ("Fever" vs "fever"). A category
    whose value is a single scalar (e.g. ``"mild"``) is stored as one item
    rather than being iterated character by character.
    """
    new_entities = entity_extractor.extract_entities(state["current_query"]) or {}
    merged = state.get("medical_entities") or {}

    for category, items in new_entities.items():
        if not items:
            continue
        values = items if isinstance(items, (list, tuple)) else [items]
        bucket = merged.setdefault(category, [])
        seen = {str(existing).casefold() for existing in bucket}
        for value in values:
            if value is None:
                continue
            label = str(value).strip()
            key = label.casefold()
            if not label or key in seen:
                continue
            seen.add(key)
            bucket.append(label)

    state["medical_entities"] = merged
    return state
def context_builder_node(state: ConversationState) -> ConversationState:
    """Build ``state["session_context"]`` for the advisor prompt.

    The context holds three parts:
    - the health profile;
    - a one-line summary of tracked medical entities;
    - the last 10 User/Assistant lines (about 5 exchanges), taken from the
      20 most recent messages.
    """
    messages = state["messages"]
    medical_entities = state.get("medical_entities") or {}
    health_profile = state.get("health_profile", "")

    transcript = []
    for msg in messages[-20:]:
        if isinstance(msg, HumanMessage):
            transcript.append(f"User: {msg.content}")
        elif isinstance(msg, AIMessage):
            transcript.append(f"Assistant: {msg.content}")

    entity_parts = []
    for category, items in medical_entities.items():
        if not items:
            continue
        # Tolerate scalar or non-string values so the summary never raises.
        values = items if isinstance(items, (list, tuple)) else [items]
        entity_parts.append(f"{category}: {', '.join(str(v) for v in values)}")
    entity_summary = " | ".join(entity_parts)

    # Keep only the last 10 lines (~5 exchanges). Computed outside the
    # f-string so no stray annotation text ends up inside the prompt.
    recent_transcript = "\n".join(transcript[-10:])
    state["session_context"] = f"""
HEALTH PROFILE: {health_profile}
MEDICAL ENTITIES DISCUSSED: {entity_summary}
RECENT CONVERSATION:
{recent_transcript}
"""
    return state
def retrieval_node(state: ConversationState) -> ConversationState:
    """Retrieve the top-5 knowledge-base chunks for the current turn.

    The retrieval query is the user's question followed by every tracked
    medical entity, deduplicated in first-seen order. Entities accumulate
    across turns, so the query also carries context from earlier messages.
    The chunks' ``page_content`` strings are stored in
    ``state["retrieved_documents"]``.
    """
    current_query = state["current_query"]
    medical_entities = state.get("medical_entities") or {}

    keywords = []
    for items in medical_entities.values():
        if not items:
            continue
        for value in (items if isinstance(items, (list, tuple)) else [items]):
            keyword = str(value).strip()
            if keyword:
                keywords.append(keyword)
    # dict.fromkeys removes repeats while preserving first-seen order.
    entity_keywords = list(dict.fromkeys(keywords))

    enhanced_query = current_query
    if entity_keywords:
        enhanced_query += " " + " ".join(entity_keywords)

    retriever = db.as_retriever(search_kwargs={'k': 5})
    retrieved_docs = retriever.invoke(enhanced_query)
    state["retrieved_documents"] = [doc.page_content for doc in retrieved_docs]
    return state
def clarification_check_node(state: ConversationState) -> ConversationState:
    """Ask the LLM whether the query needs clarification; record its questions.

    Sets ``state["clarifying_questions"]`` (at most 3) and
    ``state["needs_clarification"]`` (True only when at least one question was
    parsed). Parsing works as follows:
    - The reply is split on '|' (the format the prompt asks for) and also on
      newlines, in case the model answers with a list instead.
    - Leading bullets and "1." / "1)" markers are stripped.
    - If any fragment contains '?', only those fragments are kept.
    - An LLM error string (prefix ``"Error generating response:"``) is never
      treated as a question.
    """
    current_query = state["current_query"]
    medical_entities = state.get("medical_entities", {})
    clarification_prompt = f"""
As a medical expert, analyze if the following query needs clarification for proper medical advice:
Query: {current_query}
Medical Context: {medical_entities}
If clarification is needed, generate 1-3 specific clarifying questions.
If no clarification needed, respond with "NO_CLARIFICATION_NEEDED"
Consider asking about:
- Symptom duration, severity, frequency
- Associated symptoms
- Current medications
- Recent changes in health
- Specific circumstances
Format: If clarification needed, list questions separated by '|'
"""
    response = gemini_llm._call(clarification_prompt).strip()

    questions: List[str] = []
    no_clarification = "NO_CLARIFICATION_NEEDED" in response
    llm_failed = response.startswith("Error generating response:")
    if not no_clarification and not llm_failed:
        fragments = []
        for part in re.split(r"[|\n]", response):
            text = re.sub(r"^\s*(?:[-*•]|\d+[.)])\s*", "", part).strip()
            if text:
                fragments.append(text)
        asked = [fragment for fragment in fragments if "?" in fragment]
        questions = (asked or fragments)[:3]  # Max 3 questions

    state["needs_clarification"] = bool(questions)
    state["clarifying_questions"] = questions
    return state
def _dedupe_repeated_sentences(text: str) -> str:
    """Remove sentences already seen earlier in *text* while keeping its layout.

    Fragments are split only where '.', '!' or '?' is followed by whitespace.
    Decimals such as "2.5 mg" are therefore never split. Kept fragments are
    rejoined with single spaces, and line breaks (paragraphs, markdown lists)
    are preserved. Fragments without letters (e.g. a "1." list marker) are
    never treated as duplicates.
    """
    seen = set()
    kept_lines = []
    for line in text.splitlines():
        body = line.strip()
        if not body:
            kept_lines.append("")
            continue
        indent = line[: len(line) - len(line.lstrip())]
        fresh = []
        for sentence in re.split(r"(?<=[.!?])\s+", body):
            if not sentence:
                continue
            if any(ch.isalpha() for ch in sentence):
                if sentence in seen:
                    continue
                seen.add(sentence)
            fresh.append(sentence)
        if fresh:
            kept_lines.append(indent + " ".join(fresh))
    return "\n".join(kept_lines).strip()


def medical_advisor_node(state: ConversationState) -> ConversationState:
    """Generate the advisor reply and append it to the history as an AIMessage.

    The prompt combines ``session_context``, the retrieved document chunks,
    the current question and any clarifying questions. Repeated sentences in
    the model output are removed by ``_dedupe_repeated_sentences``, which
    leaves numbers, punctuation and line breaks intact.
    """
    current_query = state["current_query"]
    session_context = state.get("session_context", "")
    retrieved_docs = state.get("retrieved_documents", [])
    clarifying_questions = state.get("clarifying_questions", [])
    # Combine retrieved documents
    context_docs = "\n\n".join(retrieved_docs)
    # Create comprehensive prompt
    medical_prompt = f"""
You are an EXPERT MEDICAL ADVISOR. Use the complete session context to provide accurate, personalized medical advice.
IMPORTANT CONTEXT:
{session_context}
RELEVANT MEDICAL DOCUMENTS:
{context_docs}
CURRENT QUESTION: {current_query}
INSTRUCTIONS:
1. Use the COMPLETE conversation history and medical entities to understand the full context
2. Reference previous symptoms, conditions, and health information discussed
3. Provide personalized advice based on the user's health profile
4. If information is incomplete, mention what additional details would be helpful
5. Keep response under 300 words but comprehensive
6. Only provide information supported by the medical documents
7. If unsure, clearly state limitations
{"CLARIFYING QUESTIONS TO CONSIDER: " + " | ".join(clarifying_questions) if clarifying_questions else ""}
Provide your medical advice:
"""
    response = gemini_llm._call(medical_prompt)
    cleaned_response = _dedupe_repeated_sentences(response)
    # Add to messages
    state["messages"].append(AIMessage(content=cleaned_response))
    return state
def should_ask_clarification(state: ConversationState) -> str:
    """Pick the routing label used after the clarification check.

    Returns ``"clarification"`` when ``needs_clarification`` is truthy,
    otherwise ``"response"``.
    """
    if state.get("needs_clarification", False):
        return "clarification"
    return "response"
# ===== BUILD LANGGRAPH =====
def create_medical_graph():
    """Wire the pipeline nodes into a ``StateGraph`` and return it compiled.

    Flow: entity_extraction -> context_builder -> retrieval ->
    clarification_check -> medical_advisor -> END. Both routing labels after
    clarification_check map to medical_advisor, so every turn ends in the
    advisor node; clarifying questions reach it through the state.
    """
    graph = StateGraph(ConversationState)

    node_table = (
        ("entity_extraction", entity_extraction_node),
        ("context_builder", context_builder_node),
        ("retrieval", retrieval_node),
        ("clarification_check", clarification_check_node),
        ("medical_advisor", medical_advisor_node),
    )
    for node_name, handler in node_table:
        graph.add_node(node_name, handler)

    graph.set_entry_point("entity_extraction")
    linear_chain = ["entity_extraction", "context_builder", "retrieval", "clarification_check"]
    for upstream, downstream in zip(linear_chain, linear_chain[1:]):
        graph.add_edge(upstream, downstream)

    route_targets = {label: "medical_advisor" for label in ("clarification", "response")}
    graph.add_conditional_edges("clarification_check", should_ask_clarification, route_targets)
    graph.add_edge("medical_advisor", END)
    return graph.compile()
# Compile the LangGraph pipeline once, at import time.
medical_graph = create_medical_graph()

# ===== SESSION MANAGEMENT =====
# Demo-only in-memory store: session id -> ConversationState. It is not
# persisted, and nothing in this file evicts entries.
active_sessions: Dict[str, ConversationState] = dict()
def get_or_create_session(session_id: str) -> ConversationState:
    """Return the stored state for *session_id*, creating an empty one on first use.

    The returned dict is the object held in ``active_sessions``, so in-place
    edits by the caller are visible to later lookups.
    """
    existing = active_sessions.get(session_id)
    if existing is not None:
        return existing
    fresh = ConversationState(
        messages=[],
        current_query="",
        health_profile="",
        medical_entities={},
        conversation_summary="",
        retrieved_documents=[],
        clarifying_questions=[],
        needs_clarification=False,
        session_context="",
    )
    active_sessions[session_id] = fresh
    return fresh
# ===== MAIN API FUNCTION =====
def ask_question(query: str, health_info: str = "No health profile provided", session_id: str = "default"):
    """Run one conversational turn for *session_id* and return the reply.

    Args:
        query: The user's question. A blank or None query is answered with a
            prompt to enter a question, and the pipeline is not run.
        health_info: Free-text health profile. Blank values fall back to
            "No health profile provided".
        session_id: Key into ``active_sessions``; reuse it to keep context.

    Returns:
        ``(response_text, additional_info)``. On success, ``additional_info``
        holds ``clarifying_questions``, ``medical_entities`` and
        ``session_id``. On failure, the text is ``"Error: <exception>"`` and
        the dict is ``{"error": True}``.
    """
    if not (query or "").strip():
        return "Please enter a medical question.", {
            "clarifying_questions": [],
            "medical_entities": {},
            "session_id": session_id,
        }

    state = None
    user_message = None
    try:
        state = get_or_create_session(session_id)
        state["current_query"] = query
        state["health_profile"] = (health_info or "").strip() or "No health profile provided"
        user_message = HumanMessage(content=query)
        state["messages"].append(user_message)

        result = medical_graph.invoke(state)
        active_sessions[session_id] = result

        last_message = result["messages"][-1]
        response_text = last_message.content if hasattr(last_message, 'content') else str(last_message)
        additional_info = {
            "clarifying_questions": result.get("clarifying_questions", []),
            "medical_entities": result.get("medical_entities", {}),
            "session_id": session_id
        }
        return response_text, additional_info
    except Exception as e:
        # Drop the unanswered question from history so a failed turn does not
        # leak into the context of later questions in this session.
        if state is not None and user_message is not None:
            history = state.get("messages") or []
            if history and history[-1] is user_message:
                history.pop()
        return f"Error: {str(e)}", {"error": True}
# ===== GRADIO INTERFACE =====
def gradio_interface(query, health_info, session_id):
    """Gradio callback: run ``ask_question`` and render its extras as text.

    Returns:
        ``(response, info_display)``. ``info_display`` lists any clarifying
        questions and the non-empty medical-entity categories for the session,
        using markdown-style bold headings.
    """
    response, additional_info = ask_question(query, health_info, session_id)

    info_display = ""
    questions = additional_info.get("clarifying_questions")
    if questions:
        info_display += "**Clarifying Questions:**\n"
        info_display += "".join(f"{i}. {q}\n" for i, q in enumerate(questions, 1))
        info_display += "\n"

    entities = additional_info.get("medical_entities")
    if entities:
        info_display += "**Medical Entities Tracked:**\n"
        for category, items in entities.items():
            if not items:
                continue
            # Values may come from raw LLM JSON: accept scalars and non-strings.
            values = items if isinstance(items, (list, tuple)) else [items]
            label = str(category).replace("_", " ").title()
            info_display += f"- {label}: {', '.join(str(v) for v in values)}\n"

    return response, info_display
# Create Gradio Interface: three text inputs (question, health profile,
# session id) mapped to gradio_interface's two text outputs.
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Textbox(label="Question", placeholder="Enter your medical question here..."),
        gr.Textbox(label="Health Profile", placeholder="Enter your health information (optional)...", value="No health profile provided"),
        gr.Textbox(label="Session ID", placeholder="Enter session ID (optional)", value="default")
    ],
    outputs=[
        gr.Textbox(label="Medical Advice", lines=10),
        gr.Textbox(label="Additional Information", lines=5)
    ],
    title="🏥 Advanced Medical RAG Chatbot with LangGraph",
    description="""
**Features:**

- 💭 Maintains conversation context across questions
- 🧠 Extracts and tracks medical entities (symptoms, medications, etc.)
- ❓ Asks clarifying questions when needed
- 👤 Personalizes responses based on health profile
- 📚 Uses medical knowledge base for accurate information

**Tips:**

- Use the same Session ID to maintain conversation context
- Provide detailed health profile for personalized advice
- Answer clarifying questions for better recommendations
""",
    # Examples share one session id to demonstrate multi-turn context.
    examples=[
        ["I have been having fever for 2 days", "Age: 25, No chronic conditions", "user123"],
        ["What medicines should I take for this fever?", "Age: 25, No chronic conditions", "user123"],
        ["I also have a headache now", "Age: 25, No chronic conditions", "user123"]
    ]
)
if __name__ == "__main__":
    # Start the Gradio app only when this file is executed directly.
    iface.launch(share=True)