# Hugging Face Spaces app (Space status when captured: Running)
import json
import logging
import os
import re
from collections import OrderedDict
from datetime import datetime
from typing import Annotated, Dict, List, Optional, TypedDict

import google.generativeai as genai
import gradio as gr
# LangChain imports
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.llms.base import LLM
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.prompts import PromptTemplate
# LangGraph imports
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages

# Your existing embeddings utility
from embeddings import (
    create_chunks,
    get_embedding_model,
    load_faiss_db,
    load_pdf_files,
    store_embeddings,
)
# Constants
DATA_PATH = "dataFolder/"  # source PDFs used to build the FAISS index
DB_FAISS_PATH = "/tmp/vectorstore/db_faiss"  # where the FAISS index is stored / loaded
CACHE_DIR = "/tmp/models_cache"
# Import-time side effect: make sure the model cache directory exists.
os.makedirs(CACHE_DIR, exist_ok=True)

# Google AI API setup
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
if not GOOGLE_API_KEY:
    # Do not abort at import time so the UI can still start; Gemini calls are
    # not configured with a key in this case and are expected to fail.
    print("⚠️ GOOGLE_API_KEY not found in environment variables!")
    print("Please set your Google API key in Hugging Face Spaces secrets.")
else:
    genai.configure(api_key=GOOGLE_API_KEY)
# ===== LANGGRAPH STATE DEFINITION ===== | |
class ConversationState(TypedDict):
    """State dict threaded through every node of the medical LangGraph.

    One instance per chat session is kept in ``active_sessions`` and passed to
    ``medical_graph.invoke``; each node reads and updates some of its keys.
    """
    # Chat history. The ``add_messages`` annotation registers LangGraph's
    # message reducer for this key, so updates are merged into the existing
    # list instead of replacing it.
    messages: Annotated[list[BaseMessage], add_messages]
    # Latest user question (set by ask_question before each graph run).
    current_query: str
    # Free-text health profile supplied by the user.
    health_profile: str
    # category -> items, e.g. symptoms, medications, conditions, body_parts, etc.
    medical_entities: Dict[str, List[str]]
    # Initialized to "" by get_or_create_session; no node in this file reads or writes it.
    conversation_summary: str
    # page_content strings of the chunks returned by retrieval_node.
    retrieved_documents: List[str]
    # Up to 3 follow-up questions produced by clarification_check_node.
    clarifying_questions: List[str]
    # Set by clarification_check_node; read by should_ask_clarification.
    needs_clarification: bool
    # Prompt context (profile, entities, recent turns) built by context_builder_node.
    session_context: str
# ===== CUSTOM GEMINI LLM ===== | |
class GeminiLLM(LLM):
    """Minimal LangChain ``LLM`` wrapper around Google Gemini.

    Generation is synchronous via ``google.generativeai``. Failures are not
    raised: they are logged and returned as text starting with
    ``"Error generating response:"`` so callers that treat the output as plain
    text keep working.
    """

    model_name: str = "gemini-2.0-flash"

    class Config:
        extra = 'forbid'
        arbitrary_types_allowed = True

    def __init__(self, model_name: str = "gemini-2.0-flash", **kwargs):
        super().__init__(model_name=model_name, **kwargs)

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager=None,
        **kwargs,
    ) -> str:
        """Generate a completion for *prompt*.

        Args:
            prompt: Full prompt text sent to Gemini.
            stop: Optional stop sequences; the output is cut at the earliest
                occurrence of any of them.
            run_manager: LangChain callback manager (accepted for signature
                compatibility with ``LLM._call``; unused).
            **kwargs: Extra invocation parameters (ignored).

        Returns:
            The generated text, or ``"Error generating response: <reason>"``
            if the Gemini call fails.
        """
        try:
            model = genai.GenerativeModel(self.model_name)
            response = model.generate_content(prompt)
            text = response.text
        except Exception as e:
            logging.getLogger(__name__).exception("Gemini generation failed")
            return f"Error generating response: {str(e)}"

        if stop:
            # Honor stop sequences, which LangChain passes but Gemini is not told about here.
            hits = [pos for pos in (text.find(s) for s in stop if s) if pos != -1]
            if hits:
                text = text[:min(hits)]
        return text

    @property
    def _identifying_params(self):
        # LangChain reads this as a property (attribute access), not a method call.
        return {"model_name": self.model_name}

    @property
    def _llm_type(self):
        # LangChain reads this as a property (attribute access), not a method call.
        return "gemini"
# ===== MEDICAL ENTITY EXTRACTOR ===== | |
class MedicalEntityExtractor:
    """Extract medical entities from free text with an LLM, falling back to keywords.

    The LLM is asked for a JSON object. Its reply is normalized to a dict that
    maps every category in ``CATEGORIES`` (plus any extra keys the LLM returned)
    to a list of non-empty strings. When the LLM call fails or its reply holds
    no JSON object, a simple keyword scan is used instead.
    """

    # Categories requested in the extraction prompt; always present in results.
    CATEGORIES = ("symptoms", "medications", "conditions", "body_parts", "severity", "duration")

    def __init__(self, llm):
        # llm only needs a ``_call(prompt) -> str`` method.
        self.llm = llm
        self.extraction_prompt = """
Extract medical entities from the following text. Return a JSON object with these categories:
- symptoms: List of symptoms mentioned
- medications: List of medications mentioned
- conditions: List of medical conditions mentioned
- body_parts: List of body parts mentioned
- severity: Any severity indicators (mild, severe, etc.)
- duration: Any time-related information
Text: {text}
Return only valid JSON:
"""

    def extract_entities(self, text: str) -> Dict[str, List[str]]:
        """Return the entities found in *text*, keyed by category.

        Args:
            text: User text to analyze.

        Returns:
            ``{category: [item, ...]}`` containing at least every category in
            ``CATEGORIES``; values are always lists of strings.
        """
        try:
            prompt = self.extraction_prompt.format(text=text)
            response = self.llm._call(prompt)
        except Exception:
            # LLM unavailable: degrade to keyword matching rather than fail the turn.
            return self._fallback_extraction(text)

        parsed = self._parse_json_object(response)
        if parsed is None:
            return self._fallback_extraction(text)
        return self._normalize(parsed)

    @staticmethod
    def _parse_json_object(response) -> Optional[dict]:
        """Return the JSON object contained in an LLM reply, or None.

        Tries the whole reply first, then the span from the first "{" to the
        last "}" so replies wrapped in Markdown fences (```json ... ```) or
        surrounded by prose are still accepted. Non-object JSON is rejected.
        """
        if not isinstance(response, str):
            return None
        candidates = [response.strip()]
        start, end = response.find("{"), response.rfind("}")
        if start != -1 and end > start:
            candidates.append(response[start:end + 1])
        for candidate in candidates:
            try:
                data = json.loads(candidate)
            except ValueError:  # json.JSONDecodeError is a ValueError subclass
                continue
            if isinstance(data, dict):
                return data
        return None

    @classmethod
    def _normalize(cls, data: dict) -> Dict[str, List[str]]:
        """Coerce parsed JSON into ``{category: [str, ...]}`` for every category."""
        result: Dict[str, List[str]] = {category: [] for category in cls.CATEGORIES}
        for category, value in data.items():
            if value is None:
                items = []
            elif isinstance(value, (list, tuple)):
                items = value
            else:
                # A bare value such as "mild" becomes ["mild"] so callers never
                # iterate a string character by character.
                items = [value]
            result[category] = [
                str(item).strip() for item in items
                if item is not None and str(item).strip()
            ]
        return result

    def _fallback_extraction(self, text: str) -> Dict[str, List[str]]:
        """Keyword-based extraction used when the LLM result is unusable.

        Only symptoms and medications are detected (case-insensitive substring
        match against fixed keyword lists); the other categories are empty.
        """
        symptoms_keywords = ['fever', 'headache', 'cough', 'pain', 'nausea', 'vomiting', 'diarrhea', 'fatigue', 'weakness']
        medications_keywords = ['paracetamol', 'ibuprofen', 'aspirin', 'acetaminophen', 'antibiotic']
        text_lower = text.lower()
        return {
            "symptoms": [s for s in symptoms_keywords if s in text_lower],
            "medications": [m for m in medications_keywords if m in text_lower],
            "conditions": [],
            "body_parts": [],
            "severity": [],
            "duration": []
        }
# ===== LOAD FAISS DB ===== | |
def load_or_create_faiss():
    """Return the FAISS vector store, building it on first run.

    When nothing exists at DB_FAISS_PATH, the PDFs under DATA_PATH are loaded,
    chunked and handed to ``store_embeddings`` together with DB_FAISS_PATH.
    Otherwise the existing index is loaded with ``load_faiss_db``.
    """
    embeddings = get_embedding_model()

    if os.path.exists(DB_FAISS_PATH):
        print("β Existing FAISS index found. Loading it...")
        return load_faiss_db(DB_FAISS_PATH, embeddings)

    print("π FAISS index not found. Creating new index...")
    chunks = create_chunks(load_pdf_files(DATA_PATH))
    return store_embeddings(chunks, embeddings, DB_FAISS_PATH)
# Module-level singletons shared by the graph nodes below (created at import time).
db = load_or_create_faiss()  # vector store queried by retrieval_node
gemini_llm = GeminiLLM()  # LLM for clarification checks and final answers
entity_extractor = MedicalEntityExtractor(gemini_llm)  # used by entity_extraction_node
# ===== LANGGRAPH NODES ===== | |
def entity_extraction_node(state: ConversationState) -> ConversationState:
    """Extract medical entities from the current query and merge them into state.

    New items are appended to the per-category lists in
    ``state["medical_entities"]``. Duplicates are skipped and first-mention
    order is preserved.

    Args:
        state: Current graph state; reads ``current_query`` and ``medical_entities``.

    Returns:
        The same state dict with ``medical_entities`` updated.
    """
    new_entities = entity_extractor.extract_entities(state["current_query"])

    merged = state.get("medical_entities") or {}
    for category, items in new_entities.items():
        if isinstance(items, str):
            # A bare string is one entity; iterating it would add single characters.
            items = [items]
        elif not isinstance(items, (list, tuple)):
            # None or other unexpected values carry no usable entities.
            continue
        bucket = merged.setdefault(category, [])
        for item in items:
            if item not in bucket:
                bucket.append(item)

    state["medical_entities"] = merged
    return state
def context_builder_node(state: ConversationState) -> ConversationState:
    """Assemble the prompt context for the advisor and store it in the state.

    Combines the health profile, every tracked medical entity and the last ten
    user/assistant message lines (about five exchanges) into
    ``state["session_context"]``.

    Args:
        state: Current graph state; reads ``messages``, ``medical_entities``
            and ``health_profile``.

    Returns:
        The same state dict with ``session_context`` updated.
    """
    messages = state.get("messages") or []
    medical_entities = state.get("medical_entities") or {}
    health_profile = state.get("health_profile", "")

    # Scan the last 20 messages, keep only user/assistant turns; the final 10
    # of those lines go into the prompt below.
    conversation_context = []
    for msg in messages[-20:]:
        if isinstance(msg, HumanMessage):
            conversation_context.append(f"User: {msg.content}")
        elif isinstance(msg, AIMessage):
            conversation_context.append(f"Assistant: {msg.content}")

    entity_parts = []
    for category, items in medical_entities.items():
        if not items:
            continue
        # Tolerate a bare string value instead of joining its characters.
        values = items if isinstance(items, (list, tuple)) else [items]
        entity_parts.append(f"{category}: {', '.join(str(v) for v in values)}")
    entity_summary = " | ".join(entity_parts)

    recent_turns = "\n".join(conversation_context[-10:])
    # Built from explicit pieces so no source comment ends up inside the prompt text.
    state["session_context"] = (
        f"\nHEALTH PROFILE: {health_profile}\n"
        f"MEDICAL ENTITIES DISCUSSED: {entity_summary}\n"
        "RECENT CONVERSATION:\n"
        f"{recent_turns}\n"
    )
    return state
def retrieval_node(state: ConversationState) -> ConversationState:
    """Fetch the 5 most relevant knowledge-base chunks for the current query.

    The query is expanded with every medical entity tracked in the session, so
    follow-up questions can still match documents about entities mentioned
    in earlier turns.

    Args:
        state: Current graph state; reads ``current_query`` and ``medical_entities``.

    Returns:
        The same state dict with ``retrieved_documents`` set to the retrieved
        chunks' ``page_content`` strings.
    """
    current_query = state["current_query"]
    medical_entities = state.get("medical_entities") or {}

    entity_keywords = []
    for items in medical_entities.values():
        if isinstance(items, str):
            # A bare string is one keyword, not a sequence of characters.
            entity_keywords.append(items)
        elif isinstance(items, (list, tuple)):
            entity_keywords.extend(str(item) for item in items)

    # Equals current_query when no entities are tracked.
    enhanced_query = " ".join([current_query, *entity_keywords])

    retriever = db.as_retriever(search_kwargs={'k': 5})
    retrieved_docs = retriever.invoke(enhanced_query)

    state["retrieved_documents"] = [doc.page_content for doc in retrieved_docs]
    return state
def clarification_check_node(state: ConversationState) -> ConversationState:
    """Ask the LLM whether the query needs clarification and store its questions.

    Sets ``needs_clarification`` and ``clarifying_questions`` (at most 3).
    ``needs_clarification`` is True only when at least one question was
    parsed. LLM failures (replies starting with "Error generating response:",
    as produced by GeminiLLM) are treated as "no clarification needed" instead
    of being shown to the user as questions.

    Args:
        state: Current graph state; reads ``current_query`` and ``medical_entities``.

    Returns:
        The same state dict with the two clarification keys updated.
    """
    current_query = state["current_query"]
    medical_entities = state.get("medical_entities", {})

    clarification_prompt = f"""
As a medical expert, analyze if the following query needs clarification for proper medical advice:
Query: {current_query}
Medical Context: {medical_entities}
If clarification is needed, generate 1-3 specific clarifying questions.
If no clarification needed, respond with "NO_CLARIFICATION_NEEDED"
Consider asking about:
- Symptom duration, severity, frequency
- Associated symptoms
- Current medications
- Recent changes in health
- Specific circumstances
Format: If clarification needed, list questions separated by '|'
"""

    response = gemini_llm._call(clarification_prompt) or ""

    questions = []
    llm_failed = response.startswith("Error generating response:")
    if not llm_failed and "NO_CLARIFICATION_NEEDED" not in response:
        # The model is asked for '|' separators but may use one question per line.
        for part in re.split(r"[|\n]", response):
            # Drop list markers such as "-", "*", "1." or "2)".
            question = re.sub(r"^\s*(?:[-*•]|\d+[.)])\s*", "", part).strip()
            if question:
                questions.append(question)
        # Prefer real questions over preamble such as "Here are some questions:".
        questions = [q for q in questions if q.endswith("?")] or questions

    state["needs_clarification"] = bool(questions)
    state["clarifying_questions"] = questions[:3]  # Max 3 questions
    return state
def medical_advisor_node(state: ConversationState) -> ConversationState:
    """Generate the assistant's answer and append it to the conversation.

    Builds a prompt from the session context, the retrieved documents, the
    current question and any clarifying questions, then calls Gemini. Lines
    the model repeated verbatim are dropped, and the reply is appended to
    ``state["messages"]`` as an ``AIMessage``.

    Args:
        state: Current graph state; reads ``current_query``, ``session_context``,
            ``retrieved_documents`` and ``clarifying_questions``.

    Returns:
        The same state dict with the new AIMessage appended to ``messages``.
    """
    current_query = state["current_query"]
    session_context = state.get("session_context", "")
    retrieved_docs = state.get("retrieved_documents", [])
    clarifying_questions = state.get("clarifying_questions", [])

    context_docs = "\n\n".join(retrieved_docs)
    clarification_hint = (
        "CLARIFYING QUESTIONS TO CONSIDER: " + " | ".join(clarifying_questions)
        if clarifying_questions
        else ""
    )

    medical_prompt = f"""
You are an EXPERT MEDICAL ADVISOR. Use the complete session context to provide accurate, personalized medical advice.
IMPORTANT CONTEXT:
{session_context}
RELEVANT MEDICAL DOCUMENTS:
{context_docs}
CURRENT QUESTION: {current_query}
INSTRUCTIONS:
1. Use the COMPLETE conversation history and medical entities to understand the full context
2. Reference previous symptoms, conditions, and health information discussed
3. Provide personalized advice based on the user's health profile
4. If information is incomplete, mention what additional details would be helpful
5. Keep response under 300 words but comprehensive
6. Only provide information supported by the medical documents
7. If unsure, clearly state limitations
{clarification_hint}
Provide your medical advice:
"""

    response = gemini_llm._call(medical_prompt)

    # Remove lines repeated verbatim while keeping the reply's own line breaks
    # and punctuation. Splitting on "." instead would mangle decimals such as
    # "2.5 mg", abbreviations and Markdown lists.
    seen = set()
    kept_lines = []
    for line in (response or "").splitlines():
        key = line.strip()
        if key:
            if key in seen:
                continue
            seen.add(key)
        kept_lines.append(line)
    cleaned_response = "\n".join(kept_lines).strip()
    if not cleaned_response:
        cleaned_response = "Sorry, I could not generate a response. Please try again."

    state["messages"].append(AIMessage(content=cleaned_response))
    return state
def should_ask_clarification(state: ConversationState) -> str:
    """Choose the branch label taken after the clarification check.

    Returns "clarification" when ``needs_clarification`` is truthy in the
    state, and "response" otherwise (including when the key is absent).
    """
    if state.get("needs_clarification", False):
        return "clarification"
    return "response"
# ===== BUILD LANGGRAPH ===== | |
def create_medical_graph():
    """Assemble and compile the medical-advice StateGraph.

    Nodes run in a fixed order:
    entity_extraction -> context_builder -> retrieval -> clarification_check
    -> medical_advisor -> END.
    After clarification_check, ``should_ask_clarification`` picks a branch
    label, but both labels currently lead to medical_advisor.
    """
    graph = StateGraph(ConversationState)

    stages = [
        ("entity_extraction", entity_extraction_node),
        ("context_builder", context_builder_node),
        ("retrieval", retrieval_node),
        ("clarification_check", clarification_check_node),
        ("medical_advisor", medical_advisor_node),
    ]
    for stage_name, stage_fn in stages:
        graph.add_node(stage_name, stage_fn)

    graph.set_entry_point("entity_extraction")

    # Straight-line part of the pipeline, up to the clarification check.
    linear_chain = ["entity_extraction", "context_builder", "retrieval", "clarification_check"]
    for upstream, downstream in zip(linear_chain, linear_chain[1:]):
        graph.add_edge(upstream, downstream)

    branch_targets = {
        "clarification": "medical_advisor",
        "response": "medical_advisor",
    }
    graph.add_conditional_edges("clarification_check", should_ask_clarification, branch_targets)
    graph.add_edge("medical_advisor", END)

    return graph.compile()
# Compile the graph once at import time; every request reuses it.
medical_graph = create_medical_graph()
# ===== SESSION MANAGEMENT =====
# In-memory session store for the demo: session_id -> latest ConversationState.
# Nothing in this file persists or evicts entries, so state lives only as long
# as the process.
active_sessions: Dict[str, ConversationState] = dict()
def get_or_create_session(session_id: str) -> ConversationState:
    """Return the stored state for *session_id*, creating an empty one first if needed.

    The returned dict is the object held in ``active_sessions`` (not a copy).
    """
    try:
        return active_sessions[session_id]
    except KeyError:
        fresh_state = ConversationState(
            messages=[],
            current_query="",
            health_profile="",
            medical_entities={},
            conversation_summary="",
            retrieved_documents=[],
            clarifying_questions=[],
            needs_clarification=False,
            session_context=""
        )
        active_sessions[session_id] = fresh_state
        return fresh_state
# ===== MAIN API FUNCTION ===== | |
def ask_question(query: str, health_info: str = "No health profile provided", session_id: str = "default"):
    """Answer *query* inside a chat session.

    Main API function - preserves the original interface while adding session support.

    Args:
        query: The user's question.
        health_info: Free-text health profile; replaces the session's stored profile.
        session_id: Key into ``active_sessions``; reuse it to keep conversation context.

    Returns:
        ``(response_text, additional_info)``. On success additional_info has
        keys ``clarifying_questions``, ``medical_entities`` and ``session_id``;
        on failure, or for an empty query, it is ``{"error": True}``.
    """
    if not query or not query.strip():
        return "Please enter a medical question.", {"error": True}

    try:
        stored = get_or_create_session(session_id)

        # Run the graph on a copy of the session so that a failure part-way
        # through leaves the stored history and entity lists untouched.
        graph_input = dict(stored)
        graph_input["current_query"] = query
        graph_input["health_profile"] = health_info
        graph_input["messages"] = [*stored.get("messages", []), HumanMessage(content=query)]
        graph_input["medical_entities"] = {
            category: list(items) if isinstance(items, list) else items
            for category, items in (stored.get("medical_entities") or {}).items()
        }

        result = medical_graph.invoke(graph_input)
        # Commit the new state only after the graph finished successfully.
        active_sessions[session_id] = result

        last_message = result["messages"][-1]
        response_text = last_message.content if hasattr(last_message, 'content') else str(last_message)

        additional_info = {
            "clarifying_questions": result.get("clarifying_questions", []),
            "medical_entities": result.get("medical_entities", {}),
            "session_id": session_id
        }
        return response_text, additional_info
    except Exception as e:
        # Top-level boundary for the UI: keep the traceback in the logs and
        # show a short message to the user.
        logging.getLogger(__name__).exception("ask_question failed for session %r", session_id)
        return f"Error: {str(e)}", {"error": True}
# ===== GRADIO INTERFACE ===== | |
def gradio_interface(query, health_info, session_id):
    """Gradio adapter: return (advice text, Markdown summary of extra info)."""
    response, extra = ask_question(query, health_info, session_id)

    pieces = []

    questions = extra.get("clarifying_questions")
    if questions:
        pieces.append("**Clarifying Questions:**\n")
        pieces.extend(f"{number}. {question}\n" for number, question in enumerate(questions, 1))
        pieces.append("\n")

    entities = extra.get("medical_entities")
    if entities:
        pieces.append("**Medical Entities Tracked:**\n")
        pieces.extend(
            f"- {category.title()}: {', '.join(items)}\n"
            for category, items in entities.items()
            if items
        )

    return response, "".join(pieces)
# Create Gradio Interface
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Textbox(label="Question", placeholder="Enter your medical question here..."),
        gr.Textbox(label="Health Profile", placeholder="Enter your health information (optional)...", value="No health profile provided"),
        gr.Textbox(label="Session ID", placeholder="Enter session ID (optional)", value="default")
    ],
    outputs=[
        gr.Textbox(label="Medical Advice", lines=10),
        gr.Textbox(label="Additional Information", lines=5)
    ],
    title="🏥 Advanced Medical RAG Chatbot with LangGraph",
    description="""
**Features:**
- 🔄 Maintains conversation context across questions
- 🧠 Extracts and tracks medical entities (symptoms, medications, etc.)
- ❓ Asks clarifying questions when needed
- 👤 Personalizes responses based on health profile
- 📚 Uses medical knowledge base for accurate information

**Tips:**
- Use the same Session ID to maintain conversation context
- Provide detailed health profile for personalized advice
- Answer clarifying questions for better recommendations
""",
    examples=[
        ["I have been having fever for 2 days", "Age: 25, No chronic conditions", "user123"],
        ["What medicines should I take for this fever?", "Age: 25, No chronic conditions", "user123"],
        ["I also have a headache now", "Age: 25, No chronic conditions", "user123"]
    ],
    # The examples share session "user123" and gradio_interface is stateful.
    # Gradio may cache examples by default on Hugging Face Spaces; that would
    # run them at startup (calling Gemini) and store their conversation in
    # active_sessions. Disable caching so examples only run when clicked.
    cache_examples=False,
)

if __name__ == "__main__":
    iface.launch(share=True)