import streamlit as st
import pandas as pd
import json
import os
from typing import Union, List, Dict, Optional, Tuple
from groq import Groq
from duckduckgo_search import DDGS
from datetime import datetime, timedelta
import time
import numpy as np
import pickle
from dataclasses import dataclass, asdict
import hashlib
from collections import defaultdict
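
# NOTE (runtime assumptions, not part of the original listing): the Groq client below reads a
# GROQ_API_KEY environment variable, and the imports above require the streamlit, pandas, numpy,
# groq and duckduckgo_search packages to be installed. The script is assumed to be launched with
# `streamlit run app.py` (file name assumed).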
# Set page configuration
st.set_page_config(
    page_title="MedAssist - AI Medical Preconsultation",
    layout="wide",
    initial_sidebar_state="expanded",
    page_icon="🏥"
)
# Enhanced CSS for medical theme
st.markdown("""
<style>
/* Medical theme styling */
html, body, .stApp, .main {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
    color: #ffffff !important;
}
.medical-header {
    background: linear-gradient(45deg, #2c5aa0, #4a90e2) !important;
    color: white !important;
    padding: 2rem !important;
    border-radius: 15px !important;
    text-align: center !important;
    margin-bottom: 2rem !important;
    box-shadow: 0 8px 32px rgba(31, 38, 135, 0.37) !important;
}
.chat-container {
    background: rgba(255, 255, 255, 0.1) !important;
    border-radius: 15px !important;
    padding: 1rem !important;
    backdrop-filter: blur(10px) !important;
    border: 1px solid rgba(255, 255, 255, 0.2) !important;
    margin-bottom: 1rem !important;
    max-height: 500px !important;
    overflow-y: auto !important;
}
.user-message {
    background: linear-gradient(45deg, #4CAF50, #66BB6A) !important;
    color: white !important;
    padding: 1rem !important;
    border-radius: 15px 15px 5px 15px !important;
    margin: 0.5rem 0 !important;
    margin-left: 2rem !important;
    box-shadow: 0 4px 15px rgba(76, 175, 80, 0.4) !important;
}
.assistant-message {
    background: rgba(255, 255, 255, 0.15) !important;
    color: white !important;
    padding: 1rem !important;
    border-radius: 15px 15px 15px 5px !important;
    margin: 0.5rem 0 !important;
    margin-right: 2rem !important;
    border-left: 4px solid #2196F3 !important;
    backdrop-filter: blur(5px) !important;
}
.agent-status-card {
    background: rgba(255, 255, 255, 0.15) !important;
    border: 1px solid rgba(255, 255, 255, 0.3) !important;
    border-radius: 12px !important;
    padding: 1rem !important;
    margin: 0.5rem 0 !important;
    backdrop-filter: blur(5px) !important;
}
.evolution-metrics {
    background: linear-gradient(45deg, #FF6B6B, #FF8E8E) !important;
    color: white !important;
    padding: 1rem !important;
    border-radius: 10px !important;
    margin: 0.5rem 0 !important;
}
.warning-box {
    background: rgba(255, 152, 0, 0.2) !important;
    border: 2px solid #FF9800 !important;
    border-radius: 10px !important;
    padding: 1.5rem !important;
    margin: 1rem 0 !important;
    color: white !important;
}
.stButton > button {
    background: linear-gradient(45deg, #2196F3, #64B5F6) !important;
    color: white !important;
    border: none !important;
    border-radius: 25px !important;
    font-weight: bold !important;
    padding: 0.75rem 2rem !important;
    transition: all 0.3s ease !important;
}
.stButton > button:hover {
    transform: translateY(-2px) !important;
    box-shadow: 0 8px 25px rgba(33, 150, 243, 0.6) !important;
}
.chat-input {
    position: sticky !important;
    bottom: 0 !important;
    background: rgba(255, 255, 255, 0.1) !important;
    padding: 1rem !important;
    border-radius: 15px !important;
    backdrop-filter: blur(10px) !important;
}
.spinner {
    border: 2px solid rgba(255, 255, 255, 0.3);
    border-radius: 50%;
    border-top: 2px solid #ffffff;
    width: 20px;
    height: 20px;
    animation: spin 1s linear infinite;
    display: inline-block;
}
@keyframes spin {
    0% { transform: rotate(0deg); }
    100% { transform: rotate(360deg); }
}
</style>
""", unsafe_allow_html=True)
@dataclass
class ConversationEntry:
    """Data structure for storing conversation entries"""
    timestamp: str
    user_input: str
    assistant_response: str
    symptoms: List[str]
    severity_score: float
    confidence_score: float
    search_queries_used: List[str]
    user_feedback: Optional[int] = None  # 1-5 rating
    was_helpful: Optional[bool] = None


@dataclass
class AgentPerformance:
    """Track agent performance metrics"""
    agent_name: str
    total_queries: int = 0
    successful_responses: int = 0
    average_confidence: float = 0.0
    user_satisfaction: float = 0.0
    learning_rate: float = 0.01
    expertise_areas: Dict[str, float] = None

    def __post_init__(self):
        if self.expertise_areas is None:
            self.expertise_areas = defaultdict(float)
class MedicalSearchTool:
    """Enhanced medical search tool with domain-specific optimization"""

    def __init__(self):
        self.ddgs = DDGS()
        self.medical_sources = [
            "mayoclinic.org", "webmd.com", "healthline.com", "medlineplus.gov",
            "nih.gov", "who.int", "cdc.gov", "ncbi.nlm.nih.gov"
        ]

    def search_medical_info(self, query: str, search_type: str = "symptoms") -> str:
        """Search for medical information with safety considerations"""
        try:
            # Add medical context to search
            medical_queries = {
                "symptoms": f"medical symptoms {query} causes diagnosis",
                "treatment": f"medical treatment {query} therapy options",
                "prevention": f"disease prevention {query} health tips",
                "general": f"medical information {query} health facts"
            }
            enhanced_query = medical_queries.get(search_type, medical_queries["general"])
            # Perform search with medical focus
            search_results = list(self.ddgs.text(
                enhanced_query,
                max_results=5,
                region='wt-wt',
                safesearch='on'
            ))
            if not search_results:
                return "No relevant medical information found. Please consult with a healthcare professional."
            # Filter and format results with medical authority preference
            formatted_results = []
            for idx, result in enumerate(search_results, 1):
                title = result.get('title', 'No title')
                snippet = result.get('body', 'No description')
                url = result.get('href', 'No URL')
                # Prioritize trusted medical sources
                source_trust = "✓" if any(source in url for source in self.medical_sources) else ""
                formatted_results.append(
                    f"{idx}. {source_trust} {title}\n"
                    f"   Information: {snippet}\n"
                    f"   Source: {url}\n"
                )
            return "\n".join(formatted_results)
        except Exception as e:
            return f"Search temporarily unavailable: {str(e)}"
class GroqLLM:
    """Medical-optimized LLM client"""

    def __init__(self, model_name="openai/gpt-oss-20b"):
        self.client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
        self.model_name = model_name
        self.medical_context = """
        You are a medical AI assistant for preconsultation guidance.
        IMPORTANT: Always remind users that this is not a substitute for professional medical advice.
        Provide helpful information while emphasizing the need for proper medical consultation.
        """

    def generate_response(self, prompt: str, conversation_history: List[str] = None) -> Tuple[str, float]:
        """Generate response with confidence scoring"""
        try:
            # Build context with conversation history
            context = self.medical_context
            if conversation_history:
                context += f"\n\nConversation History:\n{chr(10).join(conversation_history[-5:])}"
            full_prompt = f"{context}\n\nUser Query: {prompt}\n\nPlease provide helpful medical guidance while emphasizing the importance of professional medical consultation."
            completion = self.client.chat.completions.create(
                model=self.model_name,
                messages=[{"role": "user", "content": full_prompt}],
                temperature=0.3,  # Lower temperature for medical accuracy
                max_tokens=1500,
                stream=False
            )
            response = completion.choices[0].message.content if completion.choices else "Unable to generate response"
            # Calculate confidence score based on response characteristics
            confidence = self._calculate_confidence(response, prompt)
            return response, confidence
        except Exception as e:
            return f"LLM temporarily unavailable: {str(e)}", 0.0

    def _calculate_confidence(self, response: str, query: str) -> float:
        """Calculate confidence score based on response quality"""
        confidence_factors = 0.0
        # Check for medical disclaimers (increases confidence in safety)
        if any(phrase in response.lower() for phrase in ["consult", "doctor", "medical professional", "healthcare provider"]):
            confidence_factors += 0.3
        # Check response length (adequate detail)
        if 200 <= len(response) <= 1000:
            confidence_factors += 0.2
        # Check for structured information
        if any(marker in response for marker in ["1.", "•", "-", "**"]):
            confidence_factors += 0.2
        # Check for balanced information (not overly certain)
        if any(phrase in response.lower() for phrase in ["may", "might", "could", "possible", "typically"]):
            confidence_factors += 0.3
        return min(confidence_factors, 1.0)
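
# Usage sketch (my example, not from the original listing; assumes GROQ_API_KEY is set):
#   text, confidence = GroqLLM().generate_response("What can cause a mild fever?")
# `confidence` is the heuristic score from _calculate_confidence above, clamped to [0.0, 1.0].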
class EvolutionaryMedicalAgent:
    """Evolutionary agent with reinforcement learning capabilities"""

    def __init__(self, agent_id: str, specialization: str):
        self.agent_id = agent_id
        self.specialization = specialization
        self.performance = AgentPerformance(agent_name=agent_id)
        self.knowledge_base = defaultdict(float)
        self.response_patterns = {}
        self.learning_memory = []

    def process_query(self, query: str, context: str, search_results: str) -> Tuple[str, float]:
        """Process query and adapt based on specialization"""
        # Update query count
        self.performance.total_queries += 1
        # Extract key terms for learning
        key_terms = self._extract_medical_terms(query)
        # Build specialized response based on agent's expertise
        specialized_prompt = f"""
        As a {self.specialization} specialist, analyze this medical query:
        Query: {query}
        Context: {context}
        Search Results: {search_results}
        Provide specialized insights based on your expertise in {self.specialization}.
        Always emphasize the need for professional medical consultation.
        """
        # Simulate processing (in a real implementation, specialized_prompt would be sent to the LLM)
        response = f"Based on my specialization in {self.specialization}, {query.lower()} suggests several considerations. However, please consult with a healthcare professional for proper diagnosis and treatment."
        confidence = 0.7 + (self.performance.average_confidence * 0.3)
        # Update expertise in relevant areas
        for term in key_terms:
            self.knowledge_base[term] += 0.1
        return response, confidence

    def update_from_feedback(self, query: str, response: str, feedback_score: int, was_helpful: bool):
        """Update agent based on user feedback (reinforcement learning)"""
        # Calculate reward signal
        reward = (feedback_score - 3) / 2  # Convert 1-5 scale to -1 to 1
        if was_helpful:
            reward += 0.2
        # Update performance metrics
        if feedback_score >= 3:
            self.performance.successful_responses += 1
        # Update satisfaction (guard against agents that have not processed a query yet)
        total = max(1, self.performance.total_queries)
        self.performance.user_satisfaction = (
            (self.performance.user_satisfaction * (total - 1) + feedback_score) / total
        )
        # Store learning memory
        self.learning_memory.append({
            'query': query,
            'response': response,
            'reward': reward,
            'timestamp': datetime.now().isoformat()
        })
        # Adapt learning rate based on performance
        if self.performance.user_satisfaction > 4.0:
            self.performance.learning_rate *= 0.95  # Slow down learning when performing well
        elif self.performance.user_satisfaction < 3.0:
            self.performance.learning_rate *= 1.1  # Speed up learning when performing poorly
        # Update expertise areas based on feedback
        terms = self._extract_medical_terms(query)
        for term in terms:
            self.knowledge_base[term] += reward * self.performance.learning_rate

    def _extract_medical_terms(self, text: str) -> List[str]:
        """Extract medical terms from text for learning"""
        medical_keywords = [
            'pain', 'fever', 'headache', 'nausea', 'fatigue', 'cough', 'cold', 'flu',
            'diabetes', 'hypertension', 'infection', 'allergy', 'asthma', 'arthritis',
            'anxiety', 'depression', 'insomnia', 'migraine', 'rash', 'swelling'
        ]
        found_terms = []
        text_lower = text.lower()
        for term in medical_keywords:
            if term in text_lower:
                found_terms.append(term)
        return found_terms

    def get_expertise_summary(self) -> Dict:
        """Get summary of agent's learned expertise"""
        return {
            'specialization': self.specialization,
            'total_queries': self.performance.total_queries,
            'success_rate': (self.performance.successful_responses / max(1, self.performance.total_queries)) * 100,
            'user_satisfaction': self.performance.user_satisfaction,
            'learning_rate': self.performance.learning_rate,
            'top_expertise_areas': dict(sorted(self.knowledge_base.items(), key=lambda x: x[1], reverse=True)[:5])
        }
class MedicalConsultationSystem:
    """Main medical consultation system with evolutionary agents"""

    def __init__(self):
        self.llm = GroqLLM()
        self.search_tool = MedicalSearchTool()
        self.agents = self._initialize_agents()
        self.conversation_history = []
        self.conversation_data = []

    def _initialize_agents(self) -> Dict[str, EvolutionaryMedicalAgent]:
        """Initialize specialized medical agents"""
        return {
            "general_practitioner": EvolutionaryMedicalAgent("gp", "General Practice Medicine"),
            "symptom_analyzer": EvolutionaryMedicalAgent("symptom", "Symptom Analysis and Triage"),
            "wellness_advisor": EvolutionaryMedicalAgent("wellness", "Preventive Care and Wellness"),
            "mental_health": EvolutionaryMedicalAgent("mental", "Mental Health and Psychology"),
            "emergency_assessor": EvolutionaryMedicalAgent("emergency", "Emergency Assessment and Urgent Care")
        }

    def process_medical_query(self, user_query: str) -> Dict:
        """Process medical query through evolutionary agent system"""
        timestamp = datetime.now().isoformat()
        # Determine which agents should handle this query
        relevant_agents = self._select_relevant_agents(user_query)
        # Search for medical information
        search_results = self.search_tool.search_medical_info(user_query, "symptoms")
        # Build conversation context
        context = "\n".join(self.conversation_history[-3:]) if self.conversation_history else ""
        # Get responses from relevant agents
        agent_responses = {}
        for agent_name in relevant_agents:
            agent = self.agents[agent_name]
            response, confidence = agent.process_query(user_query, context, search_results)
            agent_responses[agent_name] = {
                'response': response,
                'confidence': confidence,
                'specialization': agent.specialization
            }
        # Generate main LLM response
        main_response, main_confidence = self.llm.generate_response(
            f"{user_query}\n\nRelevant Information: {search_results}",
            self.conversation_history
        )
        # Combine responses intelligently
        final_response = self._combine_responses(main_response, agent_responses)
        # Update conversation history
        self.conversation_history.extend([
            f"User: {user_query}",
            f"Assistant: {final_response}"
        ])
        # Extract symptoms for analysis
        symptoms = self._extract_symptoms(user_query)
        severity_score = self._assess_severity(user_query, symptoms)
        # Store conversation data
        conversation_entry = ConversationEntry(
            timestamp=timestamp,
            user_input=user_query,
            assistant_response=final_response,
            symptoms=symptoms,
            severity_score=severity_score,
            confidence_score=main_confidence,
            search_queries_used=[user_query]
        )
        self.conversation_data.append(conversation_entry)
        return {
            'response': final_response,
            'confidence': main_confidence,
            'severity_score': severity_score,
            'symptoms_detected': symptoms,
            'agents_consulted': relevant_agents,
            'agent_responses': agent_responses,
            'search_performed': True
        }

    def _select_relevant_agents(self, query: str) -> List[str]:
        """Select most relevant agents for the query"""
        query_lower = query.lower()
        relevant_agents = ["general_practitioner"]  # Always include GP
        # Mental health keywords
        mental_health_keywords = ["stress", "anxiety", "depression", "sleep", "mood", "worry", "panic", "sad"]
        if any(keyword in query_lower for keyword in mental_health_keywords):
            relevant_agents.append("mental_health")
        # Emergency keywords
        emergency_keywords = ["severe", "intense", "emergency", "urgent", "chest pain", "difficulty breathing", "blood"]
        if any(keyword in query_lower for keyword in emergency_keywords):
            relevant_agents.append("emergency_assessor")
        # Wellness keywords
        wellness_keywords = ["prevention", "healthy", "nutrition", "exercise", "lifestyle", "diet"]
        if any(keyword in query_lower for keyword in wellness_keywords):
            relevant_agents.append("wellness_advisor")
        # Include symptom analyzer for symptom-related queries
        if any(keyword in query_lower for keyword in ["pain", "ache", "hurt", "symptom", "feel"]):
            relevant_agents.append("symptom_analyzer")
        return list(set(relevant_agents))

    def _combine_responses(self, main_response: str, agent_responses: Dict) -> str:
        """Intelligently combine responses from multiple agents"""
        if not agent_responses:
            return main_response
        combined = main_response + "\n\n**Specialist Insights:**\n"
        for agent_name, data in agent_responses.items():
            if data['confidence'] > 0.6:  # Only include confident responses
                combined += f"\n• **{data['specialization']}**: {data['response'][:200]}...\n"
        return combined

    def _extract_symptoms(self, query: str) -> List[str]:
        """Extract symptoms from user query"""
        common_symptoms = [
            'fever', 'headache', 'nausea', 'pain', 'cough', 'fatigue', 'dizziness',
            'rash', 'swelling', 'shortness of breath', 'chest pain', 'abdominal pain'
        ]
        query_lower = query.lower()
        detected_symptoms = [symptom for symptom in common_symptoms if symptom in query_lower]
        return detected_symptoms

    def _assess_severity(self, query: str, symptoms: List[str]) -> float:
        """Assess severity of reported symptoms (0-10 scale)"""
        severity_score = 0.0
        query_lower = query.lower()
        # High severity indicators
        high_severity = ["severe", "intense", "unbearable", "emergency", "chest pain", "difficulty breathing"]
        medium_severity = ["moderate", "persistent", "recurring", "worse", "concerning"]
        if any(indicator in query_lower for indicator in high_severity):
            severity_score += 7.0
        elif any(indicator in query_lower for indicator in medium_severity):
            severity_score += 4.0
        else:
            severity_score += 2.0
        # Add points for multiple symptoms
        severity_score += min(len(symptoms) * 0.5, 2.0)
        return min(severity_score, 10.0)

    def update_agent_performance(self, query_index: int, feedback_score: int, was_helpful: bool):
        """Update agent performance based on user feedback"""
        if query_index < len(self.conversation_data):
            entry = self.conversation_data[query_index]
            entry.user_feedback = feedback_score
            entry.was_helpful = was_helpful
            # Update all agents that were involved in this query
            for agent in self.agents.values():
                agent.update_from_feedback(entry.user_input, entry.assistant_response, feedback_score, was_helpful)

    def get_system_metrics(self) -> Dict:
        """Get comprehensive system performance metrics"""
        total_conversations = len(self.conversation_data)
        if total_conversations == 0:
            return {"status": "No conversations yet"}
        avg_confidence = np.mean([entry.confidence_score for entry in self.conversation_data])
        avg_severity = np.mean([entry.severity_score for entry in self.conversation_data])
        feedback_entries = [entry for entry in self.conversation_data if entry.user_feedback is not None]
        avg_feedback = np.mean([entry.user_feedback for entry in feedback_entries]) if feedback_entries else 0
        return {
            "total_conversations": total_conversations,
            "average_confidence": avg_confidence,
            "average_severity": avg_severity,
            "average_user_feedback": avg_feedback,
            "agent_performance": {name: agent.get_expertise_summary() for name, agent in self.agents.items()}
        }
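
# Everything below is the Streamlit UI layer; the core logic lives in the classes above.
# A single MedicalConsultationSystem instance is kept per browser session via st.session_state.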
# Initialize session state
if 'medical_system' not in st.session_state:
    st.session_state.medical_system = MedicalConsultationSystem()
if 'chat_messages' not in st.session_state:
    st.session_state.chat_messages = []
medical_system = st.session_state.medical_system
# Main interface
st.markdown("""
<div class="medical-header">
    <h1>🏥 MedAssist - AI Medical Preconsultation</h1>
    <p>Advanced AI-powered medical guidance with evolutionary learning agents</p>
</div>
""", unsafe_allow_html=True)

# Medical disclaimer
st.markdown("""
<div class="warning-box">
    <h3>⚠️ Important Medical Disclaimer</h3>
    <p>This AI system provides general health information and is NOT a substitute for professional medical advice, diagnosis, or treatment. Always consult with qualified healthcare professionals for medical concerns. In case of emergency, contact emergency services immediately.</p>
</div>
""", unsafe_allow_html=True)
# Main layout
col1, col2 = st.columns([3, 1])

with col1:
    st.markdown("### 💬 Medical Consultation Chat")
    # Chat display area
    chat_container = st.container()
    with chat_container:
        st.markdown('<div class="chat-container">', unsafe_allow_html=True)
        for i, message in enumerate(st.session_state.chat_messages):
            if message["role"] == "user":
                st.markdown(f'<div class="user-message">👤 <strong>You:</strong> {message["content"]}</div>', unsafe_allow_html=True)
            else:
                st.markdown(f'<div class="assistant-message">🤖 <strong>MedAssist:</strong> {message["content"]}</div>', unsafe_allow_html=True)
                # Add feedback buttons for assistant messages
                col_a, col_b, col_c = st.columns([1, 1, 8])
                with col_a:
                    if st.button("👍", key=f"helpful_{i}"):
                        medical_system.update_agent_performance(i // 2, 5, True)
                        st.success("Feedback recorded!")
                with col_b:
                    if st.button("👎", key=f"not_helpful_{i}"):
                        medical_system.update_agent_performance(i // 2, 2, False)
                        st.info("Feedback recorded. We'll improve!")
        st.markdown('</div>', unsafe_allow_html=True)
    # Chat input
    with st.container():
        st.markdown('<div class="chat-input">', unsafe_allow_html=True)
        user_input = st.text_input("Describe your symptoms or health concerns:",
                                   placeholder="e.g., I've been having headaches for 3 days...",
                                   key="medical_input")
        col_send, col_clear = st.columns([1, 4])
        with col_send:
            send_message = st.button("Send 📤", type="primary")
        with col_clear:
            if st.button("Clear Chat 🗑️"):
                st.session_state.chat_messages = []
                st.rerun()
        st.markdown('</div>', unsafe_allow_html=True)
with col2:
    st.markdown("### 🤖 AI Agent Status")
    # Agent status display
    for agent_name, agent in medical_system.agents.items():
        expertise = agent.get_expertise_summary()
        st.markdown(f"""
        <div class="agent-status-card">
            <h4>{agent.specialization}</h4>
            <p><strong>Queries:</strong> {expertise['total_queries']}</p>
            <p><strong>Success Rate:</strong> {expertise['success_rate']:.1f}%</p>
            <p><strong>Satisfaction:</strong> {expertise['user_satisfaction']:.1f}/5</p>
            <p><strong>Learning Rate:</strong> {expertise['learning_rate']:.3f}</p>
        </div>
        """, unsafe_allow_html=True)

    st.markdown("### 📊 System Metrics")
    metrics = medical_system.get_system_metrics()
    if "total_conversations" in metrics:
        st.markdown(f"""
        <div class="evolution-metrics">
            <p><strong>Total Chats:</strong> {metrics['total_conversations']}</p>
            <p><strong>Avg Confidence:</strong> {metrics['average_confidence']:.2f}</p>
            <p><strong>Avg Severity:</strong> {metrics['average_severity']:.1f}/10</p>
            <p><strong>User Rating:</strong> {metrics['average_user_feedback']:.1f}/5</p>
        </div>
        """, unsafe_allow_html=True)
# Process user input
if send_message and user_input:
    # Add user message
    st.session_state.chat_messages.append({"role": "user", "content": user_input})
    # Show thinking indicator
    with st.spinner("🧠 AI agents are analyzing your query..."):
        # Process the query
        result = medical_system.process_medical_query(user_input)
        # Add assistant response
        response_content = result['response']
        # Add severity and confidence info
        if result['severity_score'] > 7:
            response_content += f"\n\n⚠️ **High severity detected ({result['severity_score']:.1f}/10). Please seek immediate medical attention if symptoms are severe.**"
        elif result['severity_score'] > 4:
            response_content += f"\n\n⚡ **Moderate severity detected ({result['severity_score']:.1f}/10). Consider scheduling a medical appointment.**"
        if result['symptoms_detected']:
            response_content += f"\n\n🔍 **Detected symptoms:** {', '.join(result['symptoms_detected'])}"
        response_content += f"\n\n🤖 **Confidence Score:** {result['confidence']:.2f} | **Agents Consulted:** {', '.join(result['agents_consulted'])}"
        st.session_state.chat_messages.append({"role": "assistant", "content": response_content})
    st.rerun()
# Sidebar with additional features
with st.sidebar:
    st.markdown("### 🛠️ System Controls")
    if st.button("🔄 Reset System"):
        st.session_state.medical_system = MedicalConsultationSystem()
        st.session_state.chat_messages = []
        st.rerun()

    st.markdown("### 📈 Learning Analytics")
    if st.button("📊 View Detailed Analytics"):
        st.session_state.show_analytics = True

    if st.button("💾 Export Chat History"):
        if st.session_state.chat_messages:
            chat_data = {
                'timestamp': datetime.now().isoformat(),
                'messages': st.session_state.chat_messages,
                'system_metrics': medical_system.get_system_metrics()
            }
            st.download_button(
                label="Download Chat Data",
                data=json.dumps(chat_data, indent=2),
                file_name=f"medical_chat_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json",
                mime="application/json"
            )
        else:
            st.warning("No chat history to export")

    st.markdown("### 🎯 Quick Health Topics")
    quick_topics = [
        "Common cold symptoms",
        "Headache causes",
        "Stress management",
        "Sleep problems",
        "Healthy diet tips",
        "Exercise recommendations"
    ]
    for topic in quick_topics:
        if st.button(f"💡 {topic}", key=f"topic_{topic.replace(' ', '_')}"):
            st.session_state.chat_messages.append({"role": "user", "content": f"Tell me about {topic.lower()}"})
            with st.spinner("🧠 Processing..."):
                result = medical_system.process_medical_query(f"Tell me about {topic.lower()}")
                response_content = result['response']
                if result['symptoms_detected']:
                    response_content += f"\n\n🔍 **Related symptoms:** {', '.join(result['symptoms_detected'])}"
                response_content += f"\n\n🤖 **Confidence:** {result['confidence']:.2f}"
                st.session_state.chat_messages.append({"role": "assistant", "content": response_content})
            st.rerun()
# Analytics Dashboard (if requested)
if st.session_state.get('show_analytics', False):
    st.markdown("---")
    st.markdown("## 📊 Detailed System Analytics")
    metrics = medical_system.get_system_metrics()
    if "agent_performance" in metrics:
        # Agent Performance Comparison
        st.markdown("### 🤖 Agent Performance Analysis")
        agent_data = []
        for agent_name, performance in metrics["agent_performance"].items():
            agent_data.append({
                'Agent': performance['specialization'],
                'Success Rate (%)': performance['success_rate'],
                'User Satisfaction': performance['user_satisfaction'],
                'Learning Rate': performance['learning_rate'],
                'Total Queries': performance['total_queries']
            })
        if agent_data:
            df_agents = pd.DataFrame(agent_data)
            st.dataframe(df_agents, use_container_width=True)
            # Performance charts
            col1, col2 = st.columns(2)
            with col1:
                st.markdown("#### Success Rate by Agent")
                if not df_agents.empty:
                    st.bar_chart(df_agents.set_index('Agent')['Success Rate (%)'])
            with col2:
                st.markdown("#### User Satisfaction by Agent")
                if not df_agents.empty:
                    st.bar_chart(df_agents.set_index('Agent')['User Satisfaction'])

    # Conversation Analysis
    st.markdown("### 💬 Conversation Analysis")
    if medical_system.conversation_data:
        conversation_df = pd.DataFrame([asdict(entry) for entry in medical_system.conversation_data])
        col1, col2, col3 = st.columns(3)
        with col1:
            st.metric("Total Conversations", len(conversation_df))
            avg_confidence = conversation_df['confidence_score'].mean()
            st.metric("Average Confidence", f"{avg_confidence:.2f}")
        with col2:
            avg_severity = conversation_df['severity_score'].mean()
            st.metric("Average Severity", f"{avg_severity:.1f}/10")
            feedback_data = conversation_df[conversation_df['user_feedback'].notna()]
            if not feedback_data.empty:
                avg_feedback = feedback_data['user_feedback'].mean()
                st.metric("Average User Rating", f"{avg_feedback:.1f}/5")
        with col3:
            symptoms_detected = sum(len(symptoms) for symptoms in conversation_df['symptoms'])
            st.metric("Total Symptoms Detected", symptoms_detected)
            helpful_responses = conversation_df['was_helpful'].sum() if 'was_helpful' in conversation_df else 0
            st.metric("Helpful Responses", helpful_responses)
        # Severity distribution
        st.markdown("#### Severity Score Distribution")
        severity_counts = conversation_df['severity_score'].value_counts().sort_index()
        st.bar_chart(severity_counts)
        # Most common symptoms
        st.markdown("#### Most Common Symptoms")
        all_symptoms = []
        for symptoms_list in conversation_df['symptoms']:
            all_symptoms.extend(symptoms_list)
        if all_symptoms:
            symptom_counts = pd.Series(all_symptoms).value_counts().head(10)
            st.bar_chart(symptom_counts)
        else:
            st.info("No symptoms data available yet")
        # Timeline analysis
        st.markdown("#### Usage Timeline")
        conversation_df['timestamp'] = pd.to_datetime(conversation_df['timestamp'])
        daily_usage = conversation_df.groupby(conversation_df['timestamp'].dt.date).size()
        st.line_chart(daily_usage)
    else:
        st.info("No conversation data available for analysis yet")

    # Learning Progress
    st.markdown("### 🧠 AI Learning Progress")
    for agent_name, agent in medical_system.agents.items():
        with st.expander(f"📚 {agent.specialization} Learning Details"):
            expertise = agent.get_expertise_summary()
            st.write(f"**Total Experience:** {expertise['total_queries']} queries processed")
            st.write(f"**Current Learning Rate:** {expertise['learning_rate']:.4f}")
            st.write(f"**Performance Trend:** {'Improving' if expertise['user_satisfaction'] > 3.5 else 'Learning'}")
            if expertise['top_expertise_areas']:
                st.write("**Top Expertise Areas:**")
                for area, score in expertise['top_expertise_areas'].items():
                    st.write(f"  • {area.title()}: {score:.2f}")
            # Learning memory (last few interactions)
            if hasattr(agent, 'learning_memory') and agent.learning_memory:
                st.write("**Recent Learning Events:**")
                for memory in agent.learning_memory[-3:]:
                    reward_emoji = "✅" if memory['reward'] > 0 else "❌" if memory['reward'] < 0 else "➡️"
                    st.write(f"  {reward_emoji} Reward: {memory['reward']:.2f} | Query: {memory['query'][:50]}...")

    if st.button("🔒 Close Analytics"):
        st.session_state.show_analytics = False
        st.rerun()
# Health Tips Section
st.markdown("---")
st.markdown("### 🍎 Daily Health Tips")
health_tips = [
    "💧 Stay hydrated: Aim for 8-10 glasses of water daily",
    "🚶 Take regular walks: Even 10 minutes can boost your mood",
    "😴 Maintain sleep hygiene: 7-9 hours of quality sleep is essential",
    "🥗 Eat colorful foods: Variety ensures you get different nutrients",
    "🧘 Practice mindfulness: Just 5 minutes of meditation can reduce stress",
    "📱 Take breaks from screens: Follow the 20-20-20 rule",
    "🤝 Stay connected: Social connections are vital for mental health",
    "☀️ Get sunlight: 15 minutes of sunlight helps with Vitamin D"
]

# Display a random tip
import random
daily_tip = random.choice(health_tips)
st.info(f"**💡 Today's Health Tip:** {daily_tip}")

# Emergency Resources Section
st.markdown("### 🚨 Emergency Resources")
emergency_col1, emergency_col2 = st.columns(2)
with emergency_col1:
    st.markdown("""
    **🚑 When to Seek Immediate Help:**
    - Chest pain or difficulty breathing
    - Severe allergic reactions
    - Loss of consciousness
    - Severe bleeding
    - Signs of stroke (FAST test)
    - Severe burns
    """)
with emergency_col2:
    st.markdown("""
    **📞 Emergency Contacts:**
    - Emergency Services: 911 (US), 112 (EU)
    - Poison Control: 1-800-222-1222 (US)
    - Mental Health Crisis: 988 (US)
    - Text HOME to 741741 (Crisis Text Line)

    **🏥 Find Nearest Hospital:**
    Use your maps app or call emergency services
    """)
# Data Persistence and Learning Enhancement
class DataPersistence:
    """Handle data persistence for learning and analytics"""

    def __init__(self, data_dir: str = "medical_ai_data"):
        self.data_dir = data_dir
        os.makedirs(data_dir, exist_ok=True)

    def save_conversation_data(self, system: MedicalConsultationSystem):
        """Save conversation data for future learning"""
        try:
            data_file = os.path.join(self.data_dir, f"conversations_{datetime.now().strftime('%Y%m%d')}.json")
            conversations = []
            for entry in system.conversation_data:
                conversations.append(asdict(entry))
            with open(data_file, 'w') as f:
                json.dump(conversations, f, indent=2)
            return True
        except Exception as e:
            st.error(f"Failed to save data: {str(e)}")
            return False

    def save_agent_knowledge(self, system: MedicalConsultationSystem):
        """Save agent learning data"""
        try:
            for agent_name, agent in system.agents.items():
                agent_file = os.path.join(self.data_dir, f"agent_{agent_name}_knowledge.pkl")
                agent_data = {
                    'knowledge_base': dict(agent.knowledge_base),
                    'performance': asdict(agent.performance),
                    'learning_memory': agent.learning_memory[-100:]  # Keep last 100 entries
                }
                with open(agent_file, 'wb') as f:
                    pickle.dump(agent_data, f)
            return True
        except Exception as e:
            st.error(f"Failed to save agent knowledge: {str(e)}")
            return False

    def load_agent_knowledge(self, system: MedicalConsultationSystem):
        """Load previously saved agent knowledge"""
        try:
            for agent_name, agent in system.agents.items():
                agent_file = os.path.join(self.data_dir, f"agent_{agent_name}_knowledge.pkl")
                if os.path.exists(agent_file):
                    with open(agent_file, 'rb') as f:
                        agent_data = pickle.load(f)
                    # Restore knowledge base
                    agent.knowledge_base = defaultdict(float, agent_data.get('knowledge_base', {}))
                    # Restore learning memory
                    agent.learning_memory = agent_data.get('learning_memory', [])
                    # Restore performance metrics
                    if 'performance' in agent_data:
                        perf_data = agent_data['performance']
                        agent.performance.total_queries = perf_data.get('total_queries', 0)
                        agent.performance.successful_responses = perf_data.get('successful_responses', 0)
                        agent.performance.average_confidence = perf_data.get('average_confidence', 0.0)
                        agent.performance.user_satisfaction = perf_data.get('user_satisfaction', 0.0)
                        agent.performance.learning_rate = perf_data.get('learning_rate', 0.01)
            return True
        except Exception as e:
            st.error(f"Failed to load agent knowledge: {str(e)}")
            return False
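
# Note (not in the original listing): DataPersistence writes daily JSON conversation logs and
# per-agent pickle files under ./medical_ai_data, relative to the working directory. On
# ephemeral hosts this directory may not survive restarts, so treat it as a best-effort cache.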
# Initialize data persistence
if 'data_persistence' not in st.session_state:
    st.session_state.data_persistence = DataPersistence()

# Load previous learning data when system starts
if 'knowledge_loaded' not in st.session_state:
    st.session_state.data_persistence.load_agent_knowledge(medical_system)
    st.session_state.knowledge_loaded = True

# Auto-save functionality
if len(st.session_state.chat_messages) > 0 and len(st.session_state.chat_messages) % 10 == 0:
    # Save data every 10 messages
    st.session_state.data_persistence.save_conversation_data(medical_system)
    st.session_state.data_persistence.save_agent_knowledge(medical_system)
# Footer with system information
st.markdown("---")
st.markdown("""
<div style="text-align: center; padding: 2rem; opacity: 0.8;">
    <p><strong>MedAssist v1.0</strong> | AI-Powered Medical Preconsultation System</p>
    <p>🤖 Evolutionary Learning Agents • 🔍 Real-time Medical Search • 💬 Intelligent Chat Interface</p>
    <p><small>⚠️ This system is for informational purposes only and is not a substitute for professional medical advice</small></p>
</div>
""", unsafe_allow_html=True)