diff --git a/.env.example b/.env.example new file mode 100644 index 0000000000000000000000000000000000000000..7e812a85972b72865741908073cd8d2631b9bd93 --- /dev/null +++ b/.env.example @@ -0,0 +1,16 @@ +# OpenAI Configuration +OPENAI_API_KEY= + +# Application Settings +LOG_LEVEL=INFO +MEMORY_TYPE=sqlite +MEMORY_URI=:memory: + +# Hospital Configuration +HOSPITAL_NAME=Example Hospital +TOTAL_BEDS=300 +DEPARTMENTS=ER,ICU,General,Surgery,Pediatrics + +# Development Settings +DEBUG=False +ENVIRONMENT=production diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..ce62710fd22fc80873218698c6601398e7af2dbe --- /dev/null +++ b/.gitignore @@ -0,0 +1,64 @@ +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# Environment +.env +.venv +env/ +venv/ +ENV/ + +# IDE +.idea/ +.vscode/ +*.swp +*.swo + +# Logs +*.log +logs/ + +# Testing +.coverage +htmlcov/ +.pytest_cache/ + +# macOS +.DS_Store +.AppleDouble +.LSOverride +Icon +._* +.DocumentRevisions-V100 +.fseventsd +.Spotlight-V100 +.TemporaryItems +.Trashes +.VolumeIcon.icns +.com.apple.timemachine.donotpresent + +# Directories potentially created on remote AFP share +.AppleDB +.AppleDesktop +Network Trash Folder +Temporary Items +.apdisk \ No newline at end of file diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000000000000000000000000000000000000..2533dae4678558ec841bc3e1d1171bef98190001 --- /dev/null +++ b/environment.yml @@ -0,0 +1,24 @@ +name: langgraph +channels: + - conda-forge + - defaults +dependencies: + - python=3.11 + - pip + - pip: + - langgraph>=0.0.15 + - langchain>=0.1.0 + - openai>=1.3.0 + - python-dotenv>=0.19.0 + - pydantic>=2.0.0 + - typing-extensions>=4.5.0 + - python-json-logger>=2.0.7 + - structlog>=24.1.0 + - pytest>=7.0.0 + - pytest-asyncio>=0.23.0 + - pytest-cov>=4.1.0 + - black>=23.0.0 + - isort>=5.12.0 + - flake8>=6.1.0 + - mkdocs>=1.5.0 + - mkdocs-material>=9.5.0 diff --git a/examples/usage_examples.py b/examples/usage_examples.py new file mode 100644 index 0000000000000000000000000000000000000000..ff688054ba3cc11f542de756692b1b76f15acf5b --- /dev/null +++ b/examples/usage_examples.py @@ -0,0 +1,87 @@ +# examples/usage_examples.py +import os +from dotenv import load_dotenv +from src.agent import HealthcareAgent + +# Load environment variables +load_dotenv() + +def basic_usage_example(): + """Basic usage example of the Healthcare Agent""" + agent = HealthcareAgent(os.getenv("OPENAI_API_KEY")) + + # Single query example + response = agent.process( + "What is the current ER wait time and bed availability?" + ) + print("Basic Query Response:", response) + +def conversation_example(): + """Example of maintaining conversation context""" + agent = HealthcareAgent() + thread_id = "example-conversation" + + # Series of related queries + queries = [ + "How many beds are currently available in the ER?", + "What is the current staffing level for that department?", + "Based on these metrics, what are your recommendations for optimization?" + ] + + for query in queries: + print(f"\nUser: {query}") + response = agent.process(query, thread_id=thread_id) + print(f"Assistant: {response['response']}") + +def department_analysis_example(): + """Example of department-specific analysis""" + agent = HealthcareAgent() + + # Context with department-specific metrics + context = { + "department": "ICU", + "metrics": { + "bed_capacity": 20, + "occupied_beds": 18, + "staff_count": {"doctors": 5, "nurses": 15}, + "average_stay": 4.5 # days + } + } + + response = agent.process( + "Analyze current ICU operations and suggest improvements", + context=context + ) + print("Department Analysis:", response) + +def async_streaming_example(): + """Example of using async streaming responses""" + import asyncio + + async def stream_response(): + agent = HealthcareAgent() + query = "Provide a complete analysis of current hospital operations" + + async for event in agent.graph.astream_events( + {"messages": [query]}, + {"configurable": {"thread_id": "streaming-example"}} + ): + if event["event"] == "on_chat_model_stream": + content = event["data"]["chunk"].content + if content: + print(content, end="", flush=True) + + asyncio.run(stream_response()) + +if __name__ == "__main__": + print("=== Basic Usage Example ===") + basic_usage_example() + + print("\n=== Conversation Example ===") + conversation_example() + + print("\n=== Department Analysis Example ===") + department_analysis_example() + + print("\n=== Streaming Example ===") + async_streaming_example()# Usage examples implementation diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000000000000000000000000000000000000..fe54d8c8e51086daf64eb9ffdd387df4e8cc7dcb --- /dev/null +++ b/pytest.ini @@ -0,0 +1,6 @@ +[pytest] +testpaths = tests +python_files = test_*.py +python_classes = Test +python_functions = test_* +addopts = -v --cov=src --cov-report=html diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..fa16123d19852e9b44ca7e7e86fe713d32806765 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,26 @@ +# Core Dependencies +langgraph>=0.0.15 +langchain>=0.1.0 +langchain-openai>=0.0.5 +openai>=1.3.0 +python-dotenv>=0.19.0 +typing-extensions>=4.5.0 + +# State Management +pydantic>=2.0.0 + +# Logging and Monitoring +python-json-logger>=2.0.7 +structlog>=24.1.0 + +# Development Dependencies +pytest>=7.0.0 +pytest-asyncio>=0.23.0 +pytest-cov>=4.1.0 +black>=23.0.0 +isort>=5.12.0 +flake8>=6.1.0 + +# Documentation +mkdocs>=1.5.0 +mkdocs-material>=9.5.0 \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..c77dce8c46bbcff22ba0fb4f81ab9f7e22f223be --- /dev/null +++ b/setup.py @@ -0,0 +1,47 @@ +# setup.py +from setuptools import setup, find_packages + +# Read requirements +with open('requirements.txt') as f: + requirements = f.read().splitlines() + +# Read README for long description +with open('README.md', encoding='utf-8') as f: + long_description = f.read() + +setup( + name='healthcare-ops-agent', + version='0.1.0', + description='Healthcare Operations Management Agent using LangGraph', + long_description=long_description, + long_description_content_type='text/markdown', + author='Your Name', + author_email='your.email@example.com', + url='https://github.com/yourusername/healthcare-ops-agent', + packages=find_packages(exclude=['tests*']), + install_requires=requirements, + classifiers=[ + 'Development Status :: 3 - Alpha', + 'Intended Audience :: Healthcare Industry', + 'License :: OSI Approved :: MIT License', + 'Programming Language :: Python :: 3.9', + 'Programming Language :: Python :: 3.10', + 'Programming Language :: Python :: 3.11', + ], + python_requires='>=3.9', + include_package_data=True, + extras_require={ + 'dev': [ + 'pytest>=7.0.0', + 'pytest-asyncio>=0.23.0', + 'pytest-cov>=4.1.0', + 'black>=23.0.0', + 'isort>=5.12.0', + 'flake8>=6.1.0', + ], + 'docs': [ + 'mkdocs>=1.5.0', + 'mkdocs-material>=9.5.0', + ], + } +) diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2eeafe29a665a11658e7d1c71abe0a59a38ea97b --- /dev/null +++ b/src/__init__.py @@ -0,0 +1,4 @@ +# src/__init__.py +from .agent import HealthcareAgent + +__version__ = "0.1.0" \ No newline at end of file diff --git a/src/agent.py b/src/agent.py new file mode 100644 index 0000000000000000000000000000000000000000..9686b2e7e03583950b6e0893b9f637f5a4cfff9c --- /dev/null +++ b/src/agent.py @@ -0,0 +1,264 @@ +# src/agent.py +from typing import Dict, Optional, List +import uuid +from datetime import datetime +from langchain_core.messages import HumanMessage, SystemMessage, AnyMessage +from langgraph.graph import StateGraph, END +from langchain_openai import ChatOpenAI + +# Remove this line as it's causing the error +# from langgraph.checkpoint import BaseCheckpointSaver + +from .config.settings import Settings +from .models.state import ( + HospitalState, + create_initial_state, + validate_state +) +from .nodes import ( + InputAnalyzerNode, + TaskRouterNode, + PatientFlowNode, + ResourceManagerNode, + QualityMonitorNode, + StaffSchedulerNode, + OutputSynthesizerNode +) +from .tools import ( + PatientTools, + ResourceTools, + QualityTools, + SchedulingTools +) + +from .utils.logger import setup_logger +from .utils.error_handlers import ( + ErrorHandler, + HealthcareError, + ValidationError, # Add this import + ProcessingError # Add this import +) + + +logger = setup_logger(__name__) + +class HealthcareAgent: + def __init__(self, api_key: Optional[str] = None): + try: + # Initialize settings and validate + self.settings = Settings() + if api_key: + self.settings.OPENAI_API_KEY = api_key + self.settings.validate_settings() + + # Initialize LLM + self.llm = ChatOpenAI( + model=self.settings.MODEL_NAME, + temperature=self.settings.MODEL_TEMPERATURE, + api_key=self.settings.OPENAI_API_KEY + ) + + # Initialize tools + self.tools = self._initialize_tools() + + # Initialize nodes + self.nodes = self._initialize_nodes() + + # Initialize conversation states (replacing checkpointer) + self.conversation_states = {} + + # Build graph + self.graph = self._build_graph() + + logger.info("Healthcare Agent initialized successfully") + + except Exception as e: + logger.error(f"Error initializing Healthcare Agent: {str(e)}") + raise HealthcareError( + message="Failed to initialize Healthcare Agent", + error_code="INIT_ERROR", + details={"error": str(e)} + ) + + def _initialize_tools(self) -> Dict: + """Initialize all tools used by the agent""" + return { + "patient": PatientTools(), + "resource": ResourceTools(), + "quality": QualityTools(), + "scheduling": SchedulingTools() + } + + def _initialize_nodes(self) -> Dict: + """Initialize all nodes in the agent workflow""" + return { + "input_analyzer": InputAnalyzerNode(self.llm), + "task_router": TaskRouterNode(), + "patient_flow": PatientFlowNode(self.llm), + "resource_manager": ResourceManagerNode(self.llm), + "quality_monitor": QualityMonitorNode(self.llm), + "staff_scheduler": StaffSchedulerNode(self.llm), + "output_synthesizer": OutputSynthesizerNode(self.llm) + } + + def _build_graph(self) -> StateGraph: + """Build the workflow graph with all nodes and edges""" + try: + # Initialize graph + builder = StateGraph(HospitalState) + + # Add all nodes + for name, node in self.nodes.items(): + builder.add_node(name, node) + + # Set entry point + builder.set_entry_point("input_analyzer") + + # Add edge from input analyzer to task router + builder.add_edge("input_analyzer", "task_router") + + # Define conditional routing based on task router output + def route_next(state: Dict): + return state["context"]["next_node"] + + # Add conditional edges from task router + builder.add_conditional_edges( + "task_router", + route_next, + { + "patient_flow": "patient_flow", + "resource_management": "resource_manager", + "quality_monitoring": "quality_monitor", + "staff_scheduling": "staff_scheduler", + "output_synthesis": "output_synthesizer" + } + ) + + # Add edges from functional nodes to output synthesizer + functional_nodes = [ + "patient_flow", + "resource_manager", + "quality_monitor", + "staff_scheduler" + ] + + for node in functional_nodes: + builder.add_edge(node, "output_synthesizer") + + # Add end condition + builder.add_edge("output_synthesizer", END) + + # Compile graph + return builder.compile() + + except Exception as e: + logger.error(f"Error building graph: {str(e)}") + raise HealthcareError( + message="Failed to build agent workflow graph", + error_code="GRAPH_BUILD_ERROR", + details={"error": str(e)} + ) + + @ErrorHandler.error_decorator + def process( + self, + input_text: str, + thread_id: Optional[str] = None, + context: Optional[Dict] = None + ) -> Dict: + """Process input through the healthcare operations workflow""" + try: + # Validate input + ErrorHandler.validate_input(input_text) + + # Create or use thread ID + thread_id = thread_id or str(uuid.uuid4()) + + # Initialize state + initial_state = create_initial_state(thread_id) + + # Add input message as HumanMessage object + initial_state["messages"].append( + HumanMessage(content=input_text) + ) + + # Add context if provided + if context: + initial_state["context"].update(context) + + # Validate state + validate_state(initial_state) + + # Store state in conversation states + self.conversation_states[thread_id] = initial_state + + # Process through graph + result = self.graph.invoke(initial_state) + + return self._format_response(result) + + except ValidationError as ve: + logger.error(f"Validation error: {str(ve)}") + raise + except Exception as e: + logger.error(f"Error processing input: {str(e)}") + raise HealthcareError( + message="Failed to process input", + error_code="PROCESSING_ERROR", + details={"error": str(e)} + ) + + def _format_response(self, result: Dict) -> Dict: + """Format the final response from the graph execution""" + try: + if not result or "messages" not in result: + raise ProcessingError( + message="Invalid result format", + error_code="INVALID_RESULT", + details={"result": str(result)} + ) + + return { + "response": result["messages"][-1].content if result["messages"] else "", + "analysis": result.get("analysis", {}), + "metrics": result.get("metrics", {}), + "timestamp": datetime.now() + } + except Exception as e: + logger.error(f"Error formatting response: {str(e)}") + raise HealthcareError( + message="Failed to format response", + error_code="FORMAT_ERROR", + details={"error": str(e)} + ) + + def get_conversation_history( + self, + thread_id: str + ) -> List[Dict]: + """Retrieve conversation history for a specific thread""" + try: + return self.conversation_states.get(thread_id, {}).get("messages", []) + except Exception as e: + logger.error(f"Error retrieving conversation history: {str(e)}") + raise HealthcareError( + message="Failed to retrieve conversation history", + error_code="HISTORY_ERROR", + details={"error": str(e)} + ) + + def reset_conversation( + self, + thread_id: str + ) -> bool: + """Reset conversation state for a specific thread""" + try: + self.conversation_states[thread_id] = create_initial_state(thread_id) + return True + except Exception as e: + logger.error(f"Error resetting conversation: {str(e)}") + raise HealthcareError( + message="Failed to reset conversation", + error_code="RESET_ERROR", + details={"error": str(e)} + ) \ No newline at end of file diff --git a/src/config/__init__.py b/src/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..42c0ae84a04b18568da045c755cb6a3c0aea8092 --- /dev/null +++ b/src/config/__init__.py @@ -0,0 +1,5 @@ +# src/config/__init__.py +from .settings import Settings +from .prompts import PROMPTS + +__all__ = ['Settings', 'PROMPTS'] \ No newline at end of file diff --git a/src/config/prompts.py b/src/config/prompts.py new file mode 100644 index 0000000000000000000000000000000000000000..b450e6a99145695f53fd3160efae17bf6a7e6578 --- /dev/null +++ b/src/config/prompts.py @@ -0,0 +1,75 @@ +# src/config/prompts.py +PROMPTS = { + "system": """You are an expert Healthcare Operations Management Assistant. +Your role is to optimize hospital operations through: +- Patient flow management +- Resource allocation +- Quality monitoring +- Staff scheduling + +Always maintain HIPAA compliance and healthcare standards in your responses. +Base your analysis on the provided metrics and department data.""", + + "input_analyzer": """Analyze the following input and determine: +1. Primary task category (patient_flow, resource_management, quality_monitoring, staff_scheduling) +2. Required context information +3. Priority level (1-5, where 5 is highest) +4. Relevant department(s) + +Current input: {input}""", + + "patient_flow": """Analyze patient flow based on: +- Current occupancy: {occupancy}% +- Waiting times: {wait_times} minutes +- Department capacity: {department_capacity} +- Admission rate: {admission_rate} per hour + +Provide specific recommendations for optimization.""", + + "resource_manager": """Evaluate resource utilization: +- Equipment availability: {equipment_status} +- Supply levels: {supply_levels} +- Resource allocation: {resource_allocation} +- Budget constraints: {budget_info} + +Recommend optimal resource distribution.""", + + "quality_monitor": """Review quality metrics: +- Patient satisfaction: {satisfaction_score}/10 +- Care outcomes: {care_outcomes} +- Compliance rates: {compliance_rates}% +- Incident reports: {incident_count} + +Identify areas for improvement.""", + + "staff_scheduler": """Optimize staff scheduling considering: +- Staff availability: {staff_available} +- Department needs: {department_needs} +- Skill mix requirements: {skill_requirements} +- Work hour regulations: {work_hours} + +Provide scheduling recommendations.""", + + "output_synthesis": """Synthesize findings and provide: +1. Key insights +2. Actionable recommendations +3. Priority actions +4. Implementation timeline + +Context: {context}""" +} + +# Error message templates +ERROR_MESSAGES = { + "invalid_input": "Invalid input provided. Please ensure all required information is included.", + "processing_error": "An error occurred while processing your request. Please try again.", + "data_validation": "Data validation failed. Please check the input format.", + "system_error": "System error encountered. Please contact support." +} + +# Response templates +RESPONSE_TEMPLATES = { + "confirmation": "Request received and being processed. Priority level: {priority}", + "completion": "Analysis complete. {summary}", + "error": "Error: {error_message}. Error code: {error_code}" +}# System prompts implementation diff --git a/src/config/settings.py b/src/config/settings.py new file mode 100644 index 0000000000000000000000000000000000000000..2a634b6f6570faa098ecee7df5eef3842009d5a4 --- /dev/null +++ b/src/config/settings.py @@ -0,0 +1,70 @@ +# src/config/settings.py +import os +from typing import Dict, Any +from dotenv import load_dotenv + +load_dotenv() + +class Settings: + """Configuration settings for the Healthcare Operations Management Agent""" + + # OpenAI Configuration + OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") + MODEL_NAME = "gpt-4o-mini-2024-07-18" + MODEL_TEMPERATURE = 0 + + # LangGraph Configuration + MEMORY_TYPE = os.getenv("MEMORY_TYPE", "sqlite") + MEMORY_URI = os.getenv("MEMORY_URI", ":memory:") + + # Hospital Configuration + HOSPITAL_SETTINGS = { + "total_beds": 300, + "departments": ["ER", "ICU", "General", "Surgery", "Pediatrics"], + "staff_roles": ["Doctor", "Nurse", "Specialist", "Support Staff"] + } + + # Application Settings + MAX_RETRIES = int(os.getenv("MAX_RETRIES", "3")) + REQUEST_TIMEOUT = int(os.getenv("REQUEST_TIMEOUT", "30")) + BATCH_SIZE = int(os.getenv("BATCH_SIZE", "10")) + + # Logging Configuration + LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO") + LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + LOG_FILE = "logs/healthcare_ops_agent.log" + + # Quality Metrics Thresholds + QUALITY_THRESHOLDS = { + "min_satisfaction_score": 7.0, + "max_wait_time_minutes": 45, + "optimal_bed_utilization": 0.85, + "min_staff_ratio": { + "ICU": 0.5, # 1 nurse per 2 patients + "General": 0.25 # 1 nurse per 4 patients + } + } + + @classmethod + def get_model_config(cls) -> Dict[str, Any]: + """Get model configuration""" + return { + "model": cls.MODEL_NAME, + "temperature": cls.MODEL_TEMPERATURE, + "api_key": cls.OPENAI_API_KEY + } + + @classmethod + def validate_settings(cls) -> bool: + """Validate required settings""" + required_settings = [ + "OPENAI_API_KEY", + "MODEL_NAME", + "MEMORY_TYPE" + ] + + for setting in required_settings: + if not getattr(cls, setting): + raise ValueError(f"Missing required setting: {setting}") + + return True# Configuration settings implementation diff --git a/src/models/__init__.py b/src/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03372bc1f10bfd43369230818fb8330b1d802a60 --- /dev/null +++ b/src/models/__init__.py @@ -0,0 +1,32 @@ +# src/models/__init__.py +from .state import ( + TaskType, + PriorityLevel, + Department, + HospitalState, + PatientFlowMetrics, + ResourceMetrics, + QualityMetrics, + StaffingMetrics, + HospitalMetrics, + AnalysisResult, + create_initial_state, + validate_state, + update_state_metrics +) + +__all__ = [ + 'TaskType', + 'PriorityLevel', + 'Department', + 'HospitalState', + 'PatientFlowMetrics', + 'ResourceMetrics', + 'QualityMetrics', + 'StaffingMetrics', + 'HospitalMetrics', + 'AnalysisResult', + 'create_initial_state', + 'validate_state', + 'update_state_metrics' +] \ No newline at end of file diff --git a/src/models/state.py b/src/models/state.py new file mode 100644 index 0000000000000000000000000000000000000000..70751073d56e79158a552cba61b8d90b9bf6fc79 --- /dev/null +++ b/src/models/state.py @@ -0,0 +1,202 @@ +# src/models/state.py +from typing import Annotated, List, Dict, Optional +from typing_extensions import TypedDict # Changed this import +from langchain_core.messages import AnyMessage +from datetime import datetime +import operator +from enum import Enum +from langchain_core.messages import HumanMessage, SystemMessage, AnyMessage + + +class TaskType(str, Enum): + PATIENT_FLOW = "patient_flow" + RESOURCE_MANAGEMENT = "resource_management" + QUALITY_MONITORING = "quality_monitoring" + STAFF_SCHEDULING = "staff_scheduling" + GENERAL = "general" + +class PriorityLevel(int, Enum): + LOW = 1 + MEDIUM = 2 + HIGH = 3 + URGENT = 4 + CRITICAL = 5 + +class Department(TypedDict): + """Department information""" + id: str + name: str + capacity: int + current_occupancy: int + staff_count: Dict[str, int] + wait_time: int + +class PatientFlowMetrics(TypedDict): + """Metrics related to patient flow""" + total_beds: int + occupied_beds: int + waiting_patients: int + average_wait_time: float + admission_rate: float + discharge_rate: float + department_metrics: Dict[str, "Department"] + +class ResourceMetrics(TypedDict): + """Metrics related to resource management""" + equipment_availability: Dict[str, bool] + supply_levels: Dict[str, float] + resource_utilization: float + pending_requests: int + critical_supplies: List[str] + +class QualityMetrics(TypedDict): + """Metrics related to quality monitoring""" + patient_satisfaction: float + care_outcomes: Dict[str, float] + compliance_rate: float + incident_count: int + quality_scores: Dict[str, float] + last_audit_date: datetime + +class StaffingMetrics(TypedDict): + """Metrics related to staff scheduling""" + total_staff: int + available_staff: Dict[str, int] + shifts_coverage: Dict[str, float] + overtime_hours: float + skill_mix_index: float + staff_satisfaction: float + +class HospitalMetrics(TypedDict): + """Combined hospital metrics""" + patient_flow: PatientFlowMetrics + resources: ResourceMetrics + quality: QualityMetrics + staffing: StaffingMetrics + last_updated: datetime + +class AnalysisResult(TypedDict): + """Analysis results from nodes""" + category: TaskType + priority: PriorityLevel + findings: List[str] + recommendations: List[str] + action_items: List[Dict[str, str]] + metrics_impact: Dict[str, float] + + +class HospitalState(TypedDict): + """Main state management for the agent""" + messages: Annotated[List[AnyMessage], operator.add] + current_task: TaskType + priority_level: PriorityLevel + department: Optional[str] + metrics: HospitalMetrics + analysis: Optional[AnalysisResult] + context: Dict[str, any] # Will include routing information + timestamp: datetime + thread_id: str + + +def create_initial_state(thread_id: str) -> HospitalState: + """Create initial state with default values""" + return { + "messages": [], + "current_task": TaskType.GENERAL, + "priority_level": PriorityLevel.MEDIUM, + "department": None, + "metrics": { + "patient_flow": { + "total_beds": 300, + "occupied_beds": 240, + "waiting_patients": 15, + "average_wait_time": 35.0, + "admission_rate": 4.2, + "discharge_rate": 3.8, + "department_metrics": {} + }, + "resources": { + "equipment_availability": {}, + "supply_levels": {}, + "resource_utilization": 0.75, + "pending_requests": 5, + "critical_supplies": [] + }, + "quality": { + "patient_satisfaction": 8.5, + "care_outcomes": {}, + "compliance_rate": 0.95, + "incident_count": 2, + "quality_scores": {}, + "last_audit_date": datetime.now() + }, + "staffing": { + "total_staff": 500, + "available_staff": { + "doctors": 50, + "nurses": 150, + "specialists": 30, + "support": 70 + }, + "shifts_coverage": {}, + "overtime_hours": 120.5, + "skill_mix_index": 0.85, + "staff_satisfaction": 7.8 + }, + "last_updated": datetime.now() + }, + "analysis": None, + "context": { + "next_node": None # Add routing context + }, + "timestamp": datetime.now(), + "thread_id": thread_id + } + +def validate_state(state: HospitalState) -> bool: + """Validate state structure and data types""" + try: + # Basic structure validation + required_keys = [ + "messages", "current_task", "priority_level", + "metrics", "timestamp", "thread_id" + ] + for key in required_keys: + if key not in state: + raise ValueError(f"Missing required key: {key}") + + # Validate messages + if not isinstance(state["messages"], list): + raise ValueError("Messages must be a list") + + # Validate each message has required attributes + for msg in state["messages"]: + if not hasattr(msg, 'content'): + raise ValueError("Invalid message format - missing content") + + # Validate types + if not isinstance(state["current_task"], TaskType): + raise ValueError("Invalid task type") + if not isinstance(state["priority_level"], PriorityLevel): + raise ValueError("Invalid priority level") + if not isinstance(state["timestamp"], datetime): + raise ValueError("Invalid timestamp") + + return True + + except Exception as e: + raise ValueError(f"State validation failed: {str(e)}") + +def update_state_metrics( + state: HospitalState, + new_metrics: Dict, + category: str +) -> HospitalState: + """Update specific category of metrics in state""" + if category not in state["metrics"]: + raise ValueError(f"Invalid metrics category: {category}") + + state["metrics"][category].update(new_metrics) + state["metrics"]["last_updated"] = datetime.now() + + return state \ No newline at end of file diff --git a/src/nodes/__init__.py b/src/nodes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ab24861dd700cda61bad08ec93ed1e29dc456a42 --- /dev/null +++ b/src/nodes/__init__.py @@ -0,0 +1,18 @@ +# src/nodes/__init__.py +from .input_analyzer import InputAnalyzerNode +from .task_router import TaskRouterNode +from .patient_flow import PatientFlowNode +from .resource_manager import ResourceManagerNode +from .quality_monitor import QualityMonitorNode +from .staff_scheduler import StaffSchedulerNode +from .output_synthesizer import OutputSynthesizerNode + +__all__ = [ + 'InputAnalyzerNode', + 'TaskRouterNode', + 'PatientFlowNode', + 'ResourceManagerNode', + 'QualityMonitorNode', + 'StaffSchedulerNode', + 'OutputSynthesizerNode' +] \ No newline at end of file diff --git a/src/nodes/input_analyzer.py b/src/nodes/input_analyzer.py new file mode 100644 index 0000000000000000000000000000000000000000..c84c894c28780766e01f7828361c21c1c5f94d4e --- /dev/null +++ b/src/nodes/input_analyzer.py @@ -0,0 +1,85 @@ +# src/nodes/input_analyzer.py +#from typing import Dict +from typing import Dict, List, Optional, Any +from typing_extensions import TypedDict # If using TypedDict +#from langchain_core.messages import SystemMessage, HumanMessage +from ..models.state import HospitalState, TaskType, PriorityLevel +from ..config.prompts import PROMPTS +from ..utils.logger import setup_logger +from langchain_core.messages import HumanMessage, SystemMessage, AnyMessage + +logger = setup_logger(__name__) + +class InputAnalyzerNode: + def __init__(self, llm): + self.llm = llm + self.system_prompt = PROMPTS["input_analyzer"] + + def __call__(self, state: HospitalState) -> Dict: + try: + # Get the latest message + if not state["messages"]: + raise ValueError("No messages in state") + + latest_message = state["messages"][-1] + + # Ensure message is a LangChain message object + if not hasattr(latest_message, 'content'): + raise ValueError("Invalid message format") + + # Prepare messages for LLM + messages = [ + SystemMessage(content=self.system_prompt), + latest_message if isinstance(latest_message, HumanMessage) + else HumanMessage(content=str(latest_message)) + ] + + # Get LLM response + response = self.llm.invoke(messages) + + # Parse response to determine task type and priority + parsed_result = self._parse_llm_response(response.content) + + return { + "current_task": parsed_result["task_type"], + "priority_level": parsed_result["priority"], + "department": parsed_result["department"], + "context": parsed_result["context"] + } + + except Exception as e: + logger.error(f"Error in input analysis: {str(e)}") + raise + + def _parse_llm_response(self, response: str) -> Dict: + """Parse LLM response to extract task type and other metadata""" + try: + # Default values + result = { + "task_type": TaskType.GENERAL, + "priority": PriorityLevel.MEDIUM, + "department": None, + "context": {} + } + + # Simple parsing logic (can be made more robust) + if "patient flow" in response.lower(): + result["task_type"] = TaskType.PATIENT_FLOW + elif "resource" in response.lower(): + result["task_type"] = TaskType.RESOURCE_MANAGEMENT + elif "quality" in response.lower(): + result["task_type"] = TaskType.QUALITY_MONITORING + elif "staff" in response.lower() or "schedule" in response.lower(): + result["task_type"] = TaskType.STAFF_SCHEDULING + + # Extract priority from response + if "urgent" in response.lower() or "critical" in response.lower(): + result["priority"] = PriorityLevel.CRITICAL + elif "high" in response.lower(): + result["priority"] = PriorityLevel.HIGH + + return result + + except Exception as e: + logger.error(f"Error parsing LLM response: {str(e)}") + return result diff --git a/src/nodes/output_synthesizer.py b/src/nodes/output_synthesizer.py new file mode 100644 index 0000000000000000000000000000000000000000..5d8ea8ca5274c7fb88a58625a0b5e062eef5fdb8 --- /dev/null +++ b/src/nodes/output_synthesizer.py @@ -0,0 +1,109 @@ +# src/nodes/output_synthesizer.py +#from typing import Dict, List +from typing import Dict, List, Optional, Any +from typing_extensions import TypedDict # If using TypedDict +from langchain_core.messages import SystemMessage +from ..models.state import HospitalState +from ..config.prompts import PROMPTS +from ..utils.logger import setup_logger + +logger = setup_logger(__name__) + +class OutputSynthesizerNode: + def __init__(self, llm): + self.llm = llm + self.system_prompt = PROMPTS["output_synthesis"] + + def __call__(self, state: HospitalState) -> Dict: + try: + # Get analysis results from previous nodes + analysis = state.get("analysis", {}) + + # Format prompt with context + formatted_prompt = self.system_prompt.format( + context=self._format_context(state) + ) + + # Get LLM synthesis + response = self.llm.invoke([ + SystemMessage(content=formatted_prompt) + ]) + + # Structure the final output + final_output = self._structure_final_output( + response.content, + state["current_task"], + state["priority_level"] + ) + + return { + "messages": [response], + "analysis": final_output + } + + except Exception as e: + logger.error(f"Error in output synthesis: {str(e)}") + raise + + def _format_context(self, state: HospitalState) -> str: + """Format all relevant context for synthesis""" + return f""" +Task Type: {state['current_task']} +Priority Level: {state['priority_level']} +Department: {state['department'] or 'All Departments'} +Key Metrics Summary: +- Patient Flow: {self._summarize_patient_flow(state)} +- Resources: {self._summarize_resources(state)} +- Quality: {self._summarize_quality(state)} +- Staffing: {self._summarize_staffing(state)} + """ + + def _structure_final_output(self, response: str, task_type: str, priority: int) -> Dict: + """Structure the final output in a standardized format""" + return { + "summary": self._extract_summary(response), + "key_findings": self._extract_key_findings(response), + "recommendations": self._extract_recommendations(response), + "action_items": self._extract_action_items(response), + "priority_level": priority, + "task_type": task_type + } + + def _summarize_patient_flow(self, state: HospitalState) -> str: + metrics = state["metrics"]["patient_flow"] + return f"Occupancy {(metrics['occupied_beds']/metrics['total_beds'])*100:.1f}%" + + def _summarize_resources(self, state: HospitalState) -> str: + metrics = state["metrics"]["resources"] + return f"Utilization {metrics['resource_utilization']*100:.1f}%" + + def _summarize_quality(self, state: HospitalState) -> str: + metrics = state["metrics"]["quality"] + return f"Satisfaction {metrics['patient_satisfaction']:.1f}/10" + + def _summarize_staffing(self, state: HospitalState) -> str: + metrics = state["metrics"]["staffing"] + return f"Staff Available: {sum(metrics['available_staff'].values())}" + + def _extract_summary(self, response: str) -> str: + """Extract high-level summary from response""" + # Implementation depends on response structure + return response.split('\n')[0] + + def _extract_key_findings(self, response: str) -> List[str]: + """Extract key findings from response""" + findings = [] + # Implementation for parsing findings + return findings + + def _extract_recommendations(self, response: str) -> List[str]: + """Extract recommendations from response""" + recommendations = [] + # Implementation for parsing recommendations + return recommendations + + def _extract_action_items(self, response: str) -> List[Dict]: + """Extract actionable items from response""" + action_items = [] + # Implementation for parsing action items + return action_items# output_synthesizer node implementation diff --git a/src/nodes/patient_flow.py b/src/nodes/patient_flow.py new file mode 100644 index 0000000000000000000000000000000000000000..05be055f366fa93294d8f950e747fac87d881e29 --- /dev/null +++ b/src/nodes/patient_flow.py @@ -0,0 +1,62 @@ +# src/nodes/patient_flow.py +#from typing import Dict +from typing import Dict, List, Optional, Any +from typing_extensions import TypedDict # If using TypedDict +from langchain_core.messages import SystemMessage +from ..models.state import HospitalState +from ..config.prompts import PROMPTS +from ..utils.logger import setup_logger + +logger = setup_logger(__name__) + +class PatientFlowNode: + def __init__(self, llm): + self.llm = llm + self.system_prompt = PROMPTS["patient_flow"] + + def __call__(self, state: HospitalState) -> Dict: + try: + # Get current metrics + metrics = state["metrics"]["patient_flow"] + + # Format prompt with current metrics + formatted_prompt = self.system_prompt.format( + occupancy=self._calculate_occupancy(metrics), + wait_times=metrics["average_wait_time"], + department_capacity=self._get_department_capacity(metrics), + admission_rate=metrics["admission_rate"] + ) + + # Get LLM analysis + response = self.llm.invoke([ + SystemMessage(content=formatted_prompt) + ]) + + # Parse and structure the response + analysis = self._structure_analysis(response.content) + + return { + "analysis": analysis, + "messages": [response] + } + + except Exception as e: + logger.error(f"Error in patient flow analysis: {str(e)}") + raise + + def _calculate_occupancy(self, metrics: Dict) -> float: + """Calculate current occupancy percentage""" + return (metrics["occupied_beds"] / metrics["total_beds"]) * 100 + + def _get_department_capacity(self, metrics: Dict) -> Dict: + """Get capacity details by department""" + return metrics.get("department_metrics", {}) + + def _structure_analysis(self, response: str) -> Dict: + """Structure the LLM response into a standardized format""" + return { + "findings": [], # Extract key findings + "recommendations": [], # Extract recommendations + "action_items": [], # Extract action items + "metrics_impact": {} # Expected impact on metrics + }# patient_flow node implementation diff --git a/src/nodes/quality_monitor.py b/src/nodes/quality_monitor.py new file mode 100644 index 0000000000000000000000000000000000000000..8eb1dd79a954f04d845e9f4a2eca63c9a1e50bdb --- /dev/null +++ b/src/nodes/quality_monitor.py @@ -0,0 +1,97 @@ +# src/nodes/quality_monitor.py + +from typing import Dict, List, Optional, Any +from typing_extensions import TypedDict # If using TypedDict +from langchain_core.messages import SystemMessage +from ..models.state import HospitalState +from ..config.prompts import PROMPTS +from ..utils.logger import setup_logger + +logger = setup_logger(__name__) + +class QualityMonitorNode: + def __init__(self, llm): + self.llm = llm + self.system_prompt = PROMPTS["quality_monitor"] + + def __call__(self, state: HospitalState) -> Dict: + try: + # Get current quality metrics + metrics = state["metrics"]["quality"] + + # Format prompt with current metrics + formatted_prompt = self.system_prompt.format( + satisfaction_score=metrics["patient_satisfaction"], + care_outcomes=self._format_care_outcomes(metrics), + compliance_rates=metrics["compliance_rate"] * 100, + incident_count=metrics["incident_count"] + ) + + # Get LLM analysis + response = self.llm.invoke([ + SystemMessage(content=formatted_prompt) + ]) + + # Process quality assessment + analysis = self._analyze_quality_metrics(response.content, metrics) + + return { + "analysis": analysis, + "messages": [response], + "context": { + "quality_scores": metrics["quality_scores"], + "last_audit": metrics["last_audit_date"] + } + } + + except Exception as e: + logger.error(f"Error in quality monitoring analysis: {str(e)}") + raise + + def _format_care_outcomes(self, metrics: Dict) -> str: + """Format care outcomes into readable text""" + outcomes = [] + for metric, value in metrics["care_outcomes"].items(): + outcomes.append(f"{metric}: {value:.1f}") + return ", ".join(outcomes) + + def _analyze_quality_metrics(self, response: str, metrics: Dict) -> Dict: + """Analyze quality metrics and identify areas for improvement""" + return { + "satisfaction_analysis": self._analyze_satisfaction(metrics), + "compliance_analysis": self._analyze_compliance(metrics), + "incident_analysis": self._analyze_incidents(metrics), + "recommendations": self._extract_recommendations(response), + "priority_improvements": [] + } + + def _analyze_satisfaction(self, metrics: Dict) -> Dict: + """Analyze patient satisfaction trends""" + satisfaction = metrics["patient_satisfaction"] + return { + "current_score": satisfaction, + "status": "Good" if satisfaction >= 8.0 else "Needs Improvement", + "trend": "Unknown" # Would need historical data + } + + def _analyze_compliance(self, metrics: Dict) -> Dict: + """Analyze compliance rates""" + return { + "rate": metrics["compliance_rate"], + "status": "Compliant" if metrics["compliance_rate"] >= 0.95 else "Review Required" + } + + def _analyze_incidents(self, metrics: Dict) -> Dict: + """Analyze incident reports""" + return { + "count": metrics["incident_count"], + "severity": "High" if metrics["incident_count"] > 5 else "Low" + } + + def _extract_recommendations(self, response: str) -> List[str]: + """Extract recommendations from LLM response""" + recommendations = [] + for line in response.split('\n'): + if 'recommend' in line.lower() or 'suggest' in line.lower(): + recommendations.append(line.strip()) + return recommendations \ No newline at end of file diff --git a/src/nodes/resource_manager.py b/src/nodes/resource_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..a45afa8a9fd3b23455bb87a806ea5f405777ccb3 --- /dev/null +++ b/src/nodes/resource_manager.py @@ -0,0 +1,78 @@ +# src/nodes/resource_manager.py +from typing import Dict, List, Optional, Any +from typing_extensions import TypedDict # If using TypedDict +#from typing import Dict +from langchain_core.messages import SystemMessage +from ..models.state import HospitalState +from ..config.prompts import PROMPTS +from ..utils.logger import setup_logger + +logger = setup_logger(__name__) + +class ResourceManagerNode: + def __init__(self, llm): + self.llm = llm + self.system_prompt = PROMPTS["resource_manager"] + + def __call__(self, state: HospitalState) -> Dict: + try: + # Get current resource metrics + metrics = state["metrics"]["resources"] + + # Format prompt with current metrics + formatted_prompt = self.system_prompt.format( + equipment_status=self._format_equipment_status(metrics), + supply_levels=self._format_supply_levels(metrics), + resource_allocation=metrics["resource_utilization"], + budget_info=self._get_budget_info(state) + ) + + # Get LLM analysis + response = self.llm.invoke([ + SystemMessage(content=formatted_prompt) + ]) + + # Update state with recommendations + analysis = self._parse_recommendations(response.content) + + return { + "analysis": analysis, + "messages": [response], + "context": { + "critical_supplies": metrics["critical_supplies"], + "pending_requests": metrics["pending_requests"] + } + } + + except Exception as e: + logger.error(f"Error in resource management analysis: {str(e)}") + raise + + def _format_equipment_status(self, metrics: Dict) -> str: + """Format equipment availability into readable text""" + status = [] + for equip, available in metrics["equipment_availability"].items(): + status.append(f"{equip}: {'Available' if available else 'In Use'}") + return ", ".join(status) + + def _format_supply_levels(self, metrics: Dict) -> str: + """Format supply levels into readable text""" + levels = [] + for item, level in metrics["supply_levels"].items(): + status = "Critical" if level < 0.2 else "Low" if level < 0.4 else "Adequate" + levels.append(f"{item}: {status} ({level*100:.0f}%)") + return ", ".join(levels) + + def _get_budget_info(self, state: HospitalState) -> str: + """Get budget information from context""" + return state.get("context", {}).get("budget_info", "Budget information not available") + + def _parse_recommendations(self, response: str) -> Dict: + """Parse LLM recommendations into structured format""" + return { + "resource_optimization": [], + "supply_management": [], + "equipment_maintenance": [], + "budget_allocation": [], + "priority_actions": [] + }# resource_manager node implementation diff --git a/src/nodes/staff_scheduler.py b/src/nodes/staff_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..2c90d4d79983df3987022e05e47828c0f55c6b6a --- /dev/null +++ b/src/nodes/staff_scheduler.py @@ -0,0 +1,76 @@ +# src/nodes/staff_scheduler.py +from typing import Dict, List, Optional, Any +from typing_extensions import TypedDict # If using TypedDict +from langchain_core.messages import SystemMessage +from ..models.state import HospitalState +from ..config.prompts import PROMPTS +from ..utils.logger import setup_logger + +logger = setup_logger(__name__) + +class StaffSchedulerNode: + def __init__(self, llm): + self.llm = llm + self.system_prompt = PROMPTS["staff_scheduler"] + + def __call__(self, state: HospitalState) -> Dict: + try: + # Get current staffing metrics + metrics = state["metrics"]["staffing"] + + # Format prompt with current metrics + formatted_prompt = self.system_prompt.format( + staff_available=self._format_staff_availability(metrics), + department_needs=self._get_department_needs(state), + skill_requirements=self._format_skill_requirements(metrics), + work_hours=metrics["overtime_hours"] + ) + + # Get LLM analysis + response = self.llm.invoke([ + SystemMessage(content=formatted_prompt) + ]) + + # Generate scheduling recommendations + analysis = self._generate_schedule_recommendations(response.content, metrics) + + return { + "analysis": analysis, + "messages": [response], + "context": { + "staff_satisfaction": metrics["staff_satisfaction"], + "skill_mix_index": metrics["skill_mix_index"] + } + } + + except Exception as e: + logger.error(f"Error in staff scheduling analysis: {str(e)}") + raise + + def _format_staff_availability(self, metrics: Dict) -> str: + """Format staff availability into readable text""" + return ", ".join([ + f"{role}: {count} available" + for role, count in metrics["available_staff"].items() + ]) + + def _get_department_needs(self, state: HospitalState) -> Dict: + """Get staffing needs by department""" + return { + dept: metrics + for dept, metrics in state["metrics"]["patient_flow"]["department_metrics"].items() + } + + def _format_skill_requirements(self, metrics: Dict) -> str: + """Format skill requirements into readable text""" + return f"Skill Mix Index: {metrics['skill_mix_index']:.2f}" + + def _generate_schedule_recommendations(self, response: str, metrics: Dict) -> Dict: + """Generate scheduling recommendations based on LLM response""" + return { + "shift_adjustments": [], + "staff_assignments": {}, + "overtime_recommendations": [], + "training_needs": [], + "efficiency_improvements": [] + }# staff_scheduler node implementation diff --git a/src/nodes/task_router.py b/src/nodes/task_router.py new file mode 100644 index 0000000000000000000000000000000000000000..7d7e894076269135f0ed92d8574799e6df83de55 --- /dev/null +++ b/src/nodes/task_router.py @@ -0,0 +1,45 @@ +# src/nodes/task_router.py +from typing import Literal +from typing import Dict, List, Optional, Any +from typing_extensions import TypedDict # If using TypedDict +from ..models.state import HospitalState, TaskType +from ..utils.logger import setup_logger + +logger = setup_logger(__name__) + +class TaskRouterNode: + def __call__(self, state: HospitalState) -> Dict: + """Route to appropriate node based on task type and return state update""" + try: + task_type = state["current_task"] + + # Create base state update + state_update = { + "messages": state.get("messages", []), + "current_task": task_type, + "priority_level": state.get("priority_level"), + "context": state.get("context", {}) + } + + # Add routing information to context + if task_type == TaskType.PATIENT_FLOW: + state_update["context"]["next_node"] = "patient_flow" + elif task_type == TaskType.RESOURCE_MANAGEMENT: + state_update["context"]["next_node"] = "resource_management" + elif task_type == TaskType.QUALITY_MONITORING: + state_update["context"]["next_node"] = "quality_monitoring" + elif task_type == TaskType.STAFF_SCHEDULING: + state_update["context"]["next_node"] = "staff_scheduling" + else: + state_update["context"]["next_node"] = "output_synthesis" + + return state_update + + except Exception as e: + logger.error(f"Error in task routing: {str(e)}") + # Return default routing to output synthesis on error + return { + "messages": state.get("messages", []), + "context": {"next_node": "output_synthesis"}, + "current_task": state.get("current_task") + } \ No newline at end of file diff --git a/src/tools/__init__.py b/src/tools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5a18a402c2c901a242fcbd716fbb1784d5e85d65 --- /dev/null +++ b/src/tools/__init__.py @@ -0,0 +1,12 @@ +# src/tools/__init__.py +from .patient_tools import PatientTools +from .resource_tools import ResourceTools +from .quality_tools import QualityTools +from .scheduling_tools import SchedulingTools + +__all__ = [ + 'PatientTools', + 'ResourceTools', + 'QualityTools', + 'SchedulingTools' +] \ No newline at end of file diff --git a/src/tools/patient_tools.py b/src/tools/patient_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..043b0e0a65ca49a91274e913b52419726bffe469 --- /dev/null +++ b/src/tools/patient_tools.py @@ -0,0 +1,171 @@ +# src/tools/patient_tools.py +#from typing import Dict, List, Optional, Any +from typing_extensions import TypedDict # If using TypedDict +from typing import Dict, List, Optional +from langchain_core.tools import tool +from datetime import datetime +from ..utils.logger import setup_logger +from ..models.state import Department + +logger = setup_logger(__name__) + +class PatientTools: + @tool + def calculate_wait_time( + self, + department: str, + current_queue: int, + staff_available: int + ) -> float: + """Calculate estimated wait time for a department based on queue and staff""" + try: + # Average time per patient (in minutes) + AVG_TIME_PER_PATIENT = 15 + + # Factor in staff availability + wait_time = (current_queue * AVG_TIME_PER_PATIENT) / max(staff_available, 1) + + return round(wait_time, 1) + + except Exception as e: + logger.error(f"Error calculating wait time: {str(e)}") + raise + + @tool + def analyze_bed_capacity( + self, + total_beds: int, + occupied_beds: int, + pending_admissions: int + ) -> Dict: + """Analyze bed capacity and provide utilization metrics""" + try: + capacity = { + "total_beds": total_beds, + "occupied_beds": occupied_beds, + "available_beds": total_beds - occupied_beds, + "utilization_rate": (occupied_beds / total_beds) * 100, + "pending_admissions": pending_admissions, + "status": "Normal" + } + + # Determine status based on utilization + if capacity["utilization_rate"] > 90: + capacity["status"] = "Critical" + elif capacity["utilization_rate"] > 80: + capacity["status"] = "High" + + return capacity + + except Exception as e: + logger.error(f"Error analyzing bed capacity: {str(e)}") + raise + + @tool + def predict_discharge_time( + self, + admission_date: datetime, + condition_type: str, + department: str + ) -> datetime: + """Predict expected discharge time based on condition and department""" + try: + # Average length of stay (in days) by condition + LOS_BY_CONDITION = { + "routine": 3, + "acute": 5, + "critical": 7, + "emergency": 2 + } + + # Get base length of stay + base_los = LOS_BY_CONDITION.get(condition_type.lower(), 4) + + # Adjust based on department + if department.lower() == "icu": + base_los *= 1.5 + + # Calculate expected discharge date + discharge_date = admission_date + timedelta(days=base_los) + + return discharge_date + + except Exception as e: + logger.error(f"Error predicting discharge time: {str(e)}") + raise + + @tool + def optimize_patient_flow( + self, + departments: List[Department], + waiting_patients: List[Dict] + ) -> Dict: + """Optimize patient flow across departments""" + try: + optimization_result = { + "department_recommendations": {}, + "patient_transfers": [], + "capacity_alerts": [] + } + + for dept in departments: + # Calculate department capacity + utilization = dept["current_occupancy"] / dept["capacity"] + + if utilization > 0.9: + optimization_result["capacity_alerts"].append({ + "department": dept["name"], + "alert": "Critical capacity", + "utilization": utilization + }) + + # Recommend transfers if needed + if utilization > 0.85: + optimization_result["patient_transfers"].append({ + "from_dept": dept["name"], + "recommended_transfers": max(1, int((utilization - 0.8) * dept["capacity"])) + }) + + return optimization_result + + except Exception as e: + logger.error(f"Error optimizing patient flow: {str(e)}") + raise + + @tool + def assess_admission_priority( + self, + patient_condition: str, + wait_time: float, + department_load: float + ) -> Dict: + """Assess admission priority based on multiple factors""" + try: + # Base priority scores + CONDITION_SCORES = { + "critical": 10, + "urgent": 8, + "moderate": 5, + "routine": 3 + } + + # Calculate priority score + base_score = CONDITION_SCORES.get(patient_condition.lower(), 3) + wait_factor = min(wait_time / 30, 2) # Cap wait time factor at 2 + load_penalty = department_load if department_load > 0.8 else 0 + + final_score = base_score + wait_factor - load_penalty + + return { + "priority_score": round(final_score, 2), + "priority_level": "High" if final_score > 7 else "Medium" if final_score > 4 else "Low", + "factors": { + "condition_score": base_score, + "wait_factor": round(wait_factor, 2), + "load_penalty": round(load_penalty, 2) + } + } + + except Exception as e: + logger.error(f"Error assessing admission priority: {str(e)}") + raise# patient_tools implementation diff --git a/src/tools/quality_tools.py b/src/tools/quality_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..d9871914ff8eaf7be3c942e320828c3b6f338464 --- /dev/null +++ b/src/tools/quality_tools.py @@ -0,0 +1,176 @@ +# src/tools/quality_tools.py +from typing import Dict, List, Optional, Any +from typing_extensions import TypedDict # If using TypedDict +from langchain_core.tools import tool +from datetime import datetime, timedelta +from ..utils.logger import setup_logger + +logger = setup_logger(__name__) + +class QualityTools: + @tool + def analyze_patient_satisfaction( + self, + satisfaction_scores: List[float], + feedback_comments: List[str], + department: Optional[str] = None + ) -> Dict: + """Analyze patient satisfaction scores and feedback""" + try: + analysis = { + "metrics": { + "average_score": sum(satisfaction_scores) / len(satisfaction_scores), + "total_responses": len(satisfaction_scores), + "score_distribution": {}, + "trend": "stable" + }, + "feedback_analysis": { + "positive_themes": [], + "negative_themes": [], + "improvement_areas": [] + }, + "recommendations": [] + } + + # Analyze score distribution + for score in satisfaction_scores: + category = int(score) + analysis["metrics"]["score_distribution"][category] = \ + analysis["metrics"]["score_distribution"].get(category, 0) + 1 + + # Basic sentiment analysis of feedback + positive_keywords = ["great", "excellent", "good", "satisfied", "helpful"] + negative_keywords = ["poor", "bad", "slow", "unhappy", "dissatisfied"] + + for comment in feedback_comments: + comment_lower = comment.lower() + + # Analyze positive feedback + for keyword in positive_keywords: + if keyword in comment_lower: + analysis["feedback_analysis"]["positive_themes"].append(keyword) + + # Analyze negative feedback + for keyword in negative_keywords: + if keyword in comment_lower: + analysis["feedback_analysis"]["negative_themes"].append(keyword) + + # Generate recommendations + if analysis["metrics"]["average_score"] < 7.0: + analysis["recommendations"].append("Implement immediate satisfaction improvement plan") + + return analysis + + except Exception as e: + logger.error(f"Error analyzing patient satisfaction: {str(e)}") + raise + + @tool + def monitor_clinical_outcomes( + self, + outcomes_data: List[Dict], + benchmark_metrics: Dict[str, float] + ) -> Dict: + """Monitor and analyze clinical outcomes against benchmarks""" + try: + analysis = { + "outcome_metrics": {}, + "benchmark_comparison": {}, + "critical_deviations": [], + "success_areas": [] + } + + # Analyze outcomes by category + for outcome in outcomes_data: + category = outcome["category"] + if category not in analysis["outcome_metrics"]: + analysis["outcome_metrics"][category] = { + "success_rate": 0, + "complication_rate": 0, + "readmission_rate": 0, + "total_cases": 0 + } + + # Update metrics + metrics = analysis["outcome_metrics"][category] + metrics["total_cases"] += 1 + metrics["success_rate"] = (metrics["success_rate"] * (metrics["total_cases"] - 1) + + outcome["success"]) / metrics["total_cases"] + + # Compare with benchmarks + if category in benchmark_metrics: + benchmark = benchmark_metrics[category] + deviation = metrics["success_rate"] - benchmark + + if deviation < -0.1: # More than 10% below benchmark + analysis["critical_deviations"].append({ + "category": category, + "deviation": deviation, + "current_rate": metrics["success_rate"], + "benchmark": benchmark + }) + elif deviation > 0.05: # More than 5% above benchmark + analysis["success_areas"].append({ + "category": category, + "improvement": deviation, + "current_rate": metrics["success_rate"] + }) + + return analysis + + except Exception as e: + logger.error(f"Error monitoring clinical outcomes: {str(e)}") + raise + + @tool + def track_compliance_metrics( + self, + compliance_data: List[Dict], + audit_period: str + ) -> Dict: + """Track and analyze compliance with medical standards and regulations""" + try: + analysis = { + "compliance_rate": 0, + "violations": [], + "risk_areas": [], + "audit_summary": { + "period": audit_period, + "total_checks": len(compliance_data), + "passed_checks": 0, + "failed_checks": 0 + } + } + + # Analyze compliance checks + for check in compliance_data: + if check["compliant"]: + analysis["audit_summary"]["passed_checks"] += 1 + else: + analysis["audit_summary"]["failed_checks"] += 1 + analysis["violations"].append({ + "standard": check["standard"], + "severity": check["severity"], + "date": check["date"] + }) + + # Identify risk areas + if check["severity"] == "high" or check.get("repeat_violation", False): + analysis["risk_areas"].append({ + "area": check["standard"], + "risk_level": "high", + "recommendations": ["Immediate action required", + "Staff training needed"] + }) + + # Calculate overall compliance rate + total_checks = analysis["audit_summary"]["total_checks"] + if total_checks > 0: + analysis["compliance_rate"] = (analysis["audit_summary"]["passed_checks"] / + total_checks * 100) + + return analysis + + except Exception as e: + logger.error(f"Error tracking compliance metrics: {str(e)}") + raise# quality_tools implementation diff --git a/src/tools/resource_tools.py b/src/tools/resource_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..72f0d7e601a6c6c55b8b6e870eae87df038e9102 --- /dev/null +++ b/src/tools/resource_tools.py @@ -0,0 +1,130 @@ +# src/tools/resource_tools.py +#from typing import Dict, List +from typing import Dict, List, Optional, Any +from typing_extensions import TypedDict # If using TypedDict +from langchain_core.tools import tool +from ..utils.logger import setup_logger + +logger = setup_logger(__name__) + +class ResourceTools: + @tool + def analyze_supply_levels( + self, + current_inventory: Dict[str, float], + consumption_rate: Dict[str, float], + reorder_thresholds: Dict[str, float] + ) -> Dict: + """Analyze supply levels and generate reorder recommendations""" + try: + analysis = { + "critical_items": [], + "reorder_needed": [], + "adequate_supplies": [], + "recommendations": [] + } + + for item, level in current_inventory.items(): + threshold = reorder_thresholds.get(item, 0.2) + consumption = consumption_rate.get(item, 0) + + # Days of supply remaining + days_remaining = level / consumption if consumption > 0 else float('inf') + + if level <= threshold: + if days_remaining < 2: + analysis["critical_items"].append({ + "item": item, + "current_level": level, + "days_remaining": days_remaining + }) + else: + analysis["reorder_needed"].append({ + "item": item, + "current_level": level, + "days_remaining": days_remaining + }) + else: + analysis["adequate_supplies"].append(item) + + return analysis + + except Exception as e: + logger.error(f"Error analyzing supply levels: {str(e)}") + raise + + @tool + def track_equipment_utilization( + self, + equipment_logs: List[Dict], + equipment_capacity: Dict[str, int] + ) -> Dict: + """Track and analyze equipment utilization rates""" + try: + utilization = { + "equipment_stats": {}, + "underutilized": [], + "optimal": [], + "overutilized": [] + } + + for equip, capacity in equipment_capacity.items(): + usage = len([log for log in equipment_logs if log["equipment"] == equip]) + utilization_rate = usage / capacity + + utilization["equipment_stats"][equip] = { + "usage": usage, + "capacity": capacity, + "utilization_rate": utilization_rate + } + + if utilization_rate < 0.3: + utilization["underutilized"].append(equip) + elif utilization_rate > 0.8: + utilization["overutilized"].append(equip) + else: + utilization["optimal"].append(equip) + + return utilization + + except Exception as e: + logger.error(f"Error tracking equipment utilization: {str(e)}") + raise + + @tool + def optimize_resource_allocation( + self, + department_demands: Dict[str, Dict], + available_resources: Dict[str, int] + ) -> Dict: + """Optimize resource allocation across departments""" + try: + allocation = { + "recommended_distribution": {}, + "unmet_demands": [], + "resource_sharing": [] + } + + total_demand = sum(dept["demand"] for dept in department_demands.values()) + + for dept, demand in department_demands.items(): + # Calculate fair share based on demand + for resource, available in available_resources.items(): + dept_share = int((demand["demand"] / total_demand) * available) + + allocation["recommended_distribution"][dept] = { + resource: dept_share + } + + if dept_share < demand.get("minimum", 0): + allocation["unmet_demands"].append({ + "department": dept, + "resource": resource, + "shortfall": demand["minimum"] - dept_share + }) + + return allocation + + except Exception as e: + logger.error(f"Error optimizing resource allocation: {str(e)}") + raise# resource_tools implementation diff --git a/src/tools/scheduling_tools.py b/src/tools/scheduling_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..2c080687b17562d4dbd9759e604fa40f5995b365 --- /dev/null +++ b/src/tools/scheduling_tools.py @@ -0,0 +1,160 @@ +# src/tools/scheduling_tools.py +#from typing import Dict, List, Optional +from typing import Dict, List, Optional, Any +from typing_extensions import TypedDict # If using TypedDict +from langchain_core.tools import tool +from datetime import datetime, timedelta +from ..utils.logger import setup_logger + +logger = setup_logger(__name__) + +class SchedulingTools: + @tool + def optimize_staff_schedule( + self, + staff_availability: List[Dict], + department_needs: Dict[str, Dict], + shift_preferences: Optional[List[Dict]] = None + ) -> Dict: + """Generate optimized staff schedules based on availability and department needs""" + try: + schedule = { + "shifts": {}, + "coverage_gaps": [], + "recommendations": [], + "staff_assignments": {} + } + + # Process each department's needs + for dept, needs in department_needs.items(): + schedule["shifts"][dept] = { + "morning": [], + "afternoon": [], + "night": [] + } + + required_staff = needs.get("required_staff", {}) + + # Match available staff to shifts + for staff in staff_availability: + if staff["department"] == dept and staff["available"]: + preferred_shift = "morning" # Default + if shift_preferences: + for pref in shift_preferences: + if pref["staff_id"] == staff["id"]: + preferred_shift = pref["preferred_shift"] + + schedule["shifts"][dept][preferred_shift].append(staff["id"]) + + # Identify coverage gaps + for shift in ["morning", "afternoon", "night"]: + required = required_staff.get(shift, 0) + assigned = len(schedule["shifts"][dept][shift]) + + if assigned < required: + schedule["coverage_gaps"].append({ + "department": dept, + "shift": shift, + "shortage": required - assigned + }) + + return schedule + + except Exception as e: + logger.error(f"Error optimizing staff schedule: {str(e)}") + raise + + @tool + def analyze_workforce_metrics( + self, + staff_data: List[Dict], + time_period: str + ) -> Dict: + """Analyze workforce metrics including overtime, satisfaction, and skill mix""" + try: + analysis = { + "workforce_metrics": { + "total_staff": len(staff_data), + "overtime_hours": 0, + "skill_distribution": {}, + "satisfaction_score": 0, + "turnover_rate": 0 + }, + "recommendations": [] + } + + total_satisfaction = 0 + total_overtime = 0 + + for staff in staff_data: + # Analyze overtime + total_overtime += staff.get("overtime_hours", 0) + + # Track skill distribution + role = staff.get("role", "unknown") + analysis["workforce_metrics"]["skill_distribution"][role] = \ + analysis["workforce_metrics"]["skill_distribution"].get(role, 0) + 1 + + # Track satisfaction + total_satisfaction += staff.get("satisfaction_score", 0) + + # Calculate averages + if staff_data: + analysis["workforce_metrics"]["overtime_hours"] = total_overtime / len(staff_data) + analysis["workforce_metrics"]["satisfaction_score"] = \ + total_satisfaction / len(staff_data) + + # Generate recommendations + if analysis["workforce_metrics"]["overtime_hours"] > 10: + analysis["recommendations"].append("Reduce overtime hours through better scheduling") + + if analysis["workforce_metrics"]["satisfaction_score"] < 7: + analysis["recommendations"].append("Implement staff satisfaction improvement measures") + + return analysis + + except Exception as e: + logger.error(f"Error analyzing workforce metrics: {str(e)}") + raise + + @tool + def calculate_staffing_needs( + self, + patient_census: Dict[str, int], + acuity_levels: Dict[str, float], + staff_ratios: Dict[str, float] + ) -> Dict: + """Calculate staffing needs based on patient census and acuity""" + try: + staffing_needs = { + "required_staff": {}, + "current_gaps": {}, + "recommendations": [] + } + + for department, census in patient_census.items(): + # Calculate base staffing need + acuity = acuity_levels.get(department, 1.0) + ratio = staff_ratios.get(department, 4) # default 1:4 ratio + + required_staff = ceil(census * acuity / ratio) + + staffing_needs["required_staff"][department] = { + "total_needed": required_staff, + "acuity_factor": acuity, + "patient_ratio": ratio + } + + # Generate staffing recommendations + if required_staff > current_staff.get(department, 0): + staffing_needs["recommendations"].append({ + "department": department, + "action": "increase_staff", + "amount": required_staff - current_staff.get(department, 0) + }) + + return staffing_needs + + except Exception as e: + logger.error(f"Error calculating staffing needs: {str(e)}") + raise# scheduling_tools implementation diff --git a/src/ui/__init__.py b/src/ui/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c2d131cb1e1e48291cac392ff2052ef7f2a03f5e --- /dev/null +++ b/src/ui/__init__.py @@ -0,0 +1,3 @@ +from .app import HealthcareUI + +__all__ = ['HealthcareUI'] \ No newline at end of file diff --git a/src/ui/app.py b/src/ui/app.py new file mode 100644 index 0000000000000000000000000000000000000000..5da1ce0c4e8f550a2914a152021e5edf4d4f8463 --- /dev/null +++ b/src/ui/app.py @@ -0,0 +1,522 @@ +import streamlit as st +from datetime import datetime +from typing import Optional, Dict, Any +import os + +from ..agent import HealthcareAgent +from ..models.state import TaskType, PriorityLevel +from ..utils.logger import setup_logger + +logger = setup_logger(__name__) + +class HealthcareUI: + def __init__(self): + """Initialize the Healthcare Operations Management UI""" + try: + # Set up Streamlit page configuration + st.set_page_config( + page_title="Healthcare Operations Assistant", + page_icon="๐Ÿฅ", + layout="wide", + initial_sidebar_state="expanded", + menu_items={ + 'About': "Healthcare Operations Management AI Assistant", + 'Report a bug': "https://github.com/yourusername/repo/issues", + 'Get Help': "https://your-docs-url" + } + ) + + # Apply custom theme + self.setup_theme() + + # Initialize the agent + self.agent = HealthcareAgent(os.getenv("OPENAI_API_KEY")) + + # Initialize session state variables only if not already set + if 'initialized' not in st.session_state: + st.session_state.initialized = True + st.session_state.messages = [] + st.session_state.thread_id = datetime.now().strftime("%Y%m%d-%H%M%S") + st.session_state.current_department = "All Departments" + st.session_state.metrics_history = [] + st.session_state.system_status = True + + except Exception as e: + logger.error(f"Error initializing UI: {str(e)}") + st.error("Failed to initialize the application. Please refresh the page.") + + def setup_theme(self): + """Configure the UI theme and styling""" + st.markdown(""" + + """, unsafe_allow_html=True) + + def render_header(self): + """Render the application header""" + try: + header_container = st.container() + with header_container: + col1, col2, col3 = st.columns([1, 4, 1]) + + with col1: + st.markdown("# ๐Ÿฅ") + + with col2: + st.title("Healthcare Operations Assistant") + st.markdown("*Your AI-powered healthcare operations management solution* ๐Ÿค–") + + with col3: + # System status indicator + status = "๐ŸŸข Online" if st.session_state.system_status else "๐Ÿ”ด Offline" + st.markdown(f"### {status}") + + except Exception as e: + logger.error(f"Error rendering header: {str(e)}") + st.error("Error loading header section") + + def render_metrics(self, metrics: Optional[Dict[str, Any]] = None): + """Render the metrics dashboard""" + try: + if not metrics: + metrics = { + "patient_flow": {"occupied_beds": 75, "total_beds": 100}, + "quality": {"patient_satisfaction": 8.5}, + "staffing": {"available_staff": {"doctors": 20, "nurses": 50}}, + "resources": {"resource_utilization": 0.75} + } + + st.markdown("### ๐Ÿ“Š Key Metrics Dashboard") + metrics_container = st.container() + + with metrics_container: + # First row - Key metrics + col1, col2, col3, col4 = st.columns(4) + + with col1: + occupancy = (metrics['patient_flow']['occupied_beds'] / + metrics['patient_flow']['total_beds'] * 100) + st.metric( + "Bed Occupancy ๐Ÿ›๏ธ", + f"{occupancy:.1f}%", + "Normal ๐ŸŸข" if occupancy < 85 else "High ๐ŸŸก" + ) + + with col2: + satisfaction = metrics['quality']['patient_satisfaction'] + st.metric( + "Patient Satisfaction ๐Ÿ˜Š", + f"{satisfaction}/10", + "โ†— +0.5" if satisfaction > 8 else "โ†˜ -0.3" + ) + + with col3: + total_staff = sum(metrics['staffing']['available_staff'].values()) + st.metric( + "Available Staff ๐Ÿ‘ฅ", + total_staff, + "Optimal ๐ŸŸข" if total_staff > 80 else "Low ๐Ÿ”ด" + ) + + with col4: + utilization = metrics['resources']['resource_utilization'] * 100 + st.metric( + "Resource Utilization ๐Ÿ“ฆ", + f"{utilization:.1f}%", + "โ†˜ -2%" + ) + + # Add metrics to history + st.session_state.metrics_history.append({ + 'timestamp': datetime.now(), + 'metrics': metrics + }) + + except Exception as e: + logger.error(f"Error rendering metrics: {str(e)}") + st.error("Error loading metrics dashboard") + + def render_chat(self): + """Render the chat interface""" + try: + st.markdown("### ๐Ÿ’ฌ Chat Interface") + chat_container = st.container() + + with chat_container: + # Display chat messages + for message in st.session_state.messages: + role = message["role"] + content = message["content"] + timestamp = message.get("timestamp", datetime.now()) + + with st.chat_message(role, avatar="๐Ÿค–" if role == "assistant" else "๐Ÿ‘ค"): + st.markdown(content) + st.caption(f":clock2: {timestamp.strftime('%H:%M')}") + + # Chat input + if prompt := st.chat_input("How can I assist you with healthcare operations today?"): + # Add user message + current_time = datetime.now() + st.session_state.messages.append({ + "role": "user", + "content": prompt, + "timestamp": current_time + }) + + # Display user message + with st.chat_message("user", avatar="๐Ÿ‘ค"): + st.markdown(prompt) + st.caption(f":clock2: {current_time.strftime('%H:%M')}") + + # Display assistant response + with st.chat_message("assistant", avatar="๐Ÿค–"): + with st.spinner("Processing your request... ๐Ÿ”„"): + try: + # Generate response based on query type + response = self._get_department_response(prompt) + + # Display structured response + st.markdown("### ๐Ÿ” Key Insights") + st.markdown(response["insights"]) + + st.markdown("### ๐Ÿ“‹ Actionable Recommendations") + st.markdown(response["recommendations"]) + + st.markdown("### โšก Priority Actions") + st.markdown(response["priority_actions"]) + + st.markdown("### โฐ Implementation Timeline") + st.markdown(response["timeline"]) + + # Update metrics if available + if "metrics" in response: + self.render_metrics(response["metrics"]) + + # Add to chat history + st.session_state.messages.append({ + "role": "assistant", + "content": response["full_response"], + "timestamp": datetime.now() + }) + + except Exception as e: + st.error(f"Error processing request: {str(e)} โŒ") + logger.error(f"Error in chat processing: {str(e)}") + + except Exception as e: + logger.error(f"Error rendering chat interface: {str(e)}") + st.error("Error loading chat interface") + + def _get_department_response(self, query: str) -> Dict[str, Any]: + """Generate response based on query type""" + query = query.lower() + + # Waiting times response + if "waiting" in query or "wait time" in query: + return { + "insights": """ + ๐Ÿ“Š Current Department Wait Times: + - ER: 45 minutes (โš ๏ธ Above target) + - ICU: 5 minutes (โœ… Within target) + - General Ward: 25 minutes (โœ… Within target) + - Surgery: 30 minutes (โšก Approaching target) + - Pediatrics: 20 minutes (โœ… Within target) + """, + "recommendations": """ + 1. ๐Ÿ‘ฅ Deploy additional triage nurses to ER + 2. ๐Ÿ”„ Optimize patient handoff procedures + 3. ๐Ÿ“ฑ Implement real-time wait time updates + 4. ๐Ÿฅ Activate overflow protocols where needed + """, + "priority_actions": """ + Immediate Actions Required: + - ๐Ÿšจ Redirect non-emergency cases from ER + - ๐Ÿ‘จโ€โš•๏ธ Increase ER staffing for next 2 hours + - ๐Ÿ“ข Update waiting patients every 15 minutes + """, + "timeline": """ + Implementation Schedule: + - ๐Ÿ• 0-1 hour: Staff reallocation + - ๐Ÿ•’ 1-2 hours: Process optimization + - ๐Ÿ•“ 2-4 hours: Situation reassessment + - ๐Ÿ•” 4+ hours: Long-term monitoring + """, + "metrics": { + "patient_flow": { + "occupied_beds": 85, + "total_beds": 100, + "waiting_patients": 18, + "average_wait_time": 35.0 + }, + "quality": {"patient_satisfaction": 7.8}, + "staffing": {"available_staff": {"doctors": 22, "nurses": 55}}, + "resources": {"resource_utilization": 0.82} + }, + "full_response": "Based on current data, we're seeing elevated wait times in the ER department. Immediate actions have been recommended to address this situation." + } + + # Bed occupancy response + elif "bed" in query or "occupancy" in query: + return { + "insights": """ + ๐Ÿ›๏ธ Current Bed Occupancy Status: + - Overall Occupancy: 85% + - Critical Care: 90% (โš ๏ธ Near capacity) + - General Wards: 82% (โœ… Optimal) + - Available Emergency Beds: 5 + """, + "recommendations": """ + 1. ๐Ÿ”„ Review discharge plans + 2. ๐Ÿฅ Prepare overflow areas + 3. ๐Ÿ“‹ Optimize bed turnover + 4. ๐Ÿ‘ฅ Adjust staff allocation + """, + "priority_actions": """ + Critical Actions: + - ๐Ÿšจ Expedite planned discharges + - ๐Ÿฅ Activate surge capacity plan + - ๐Ÿ“Š Hourly capacity monitoring + """, + "timeline": """ + Action Timeline: + - ๐Ÿ• Immediate: Discharge reviews + - ๐Ÿ•‘ 2 hours: Capacity reassessment + - ๐Ÿ•’ 4 hours: Staff reallocation + - ๐Ÿ•“ 8 hours: Full situation review + """, + "metrics": { + "patient_flow": { + "occupied_beds": 90, + "total_beds": 100, + "waiting_patients": 12, + "average_wait_time": 30.0 + }, + "quality": {"patient_satisfaction": 8.0}, + "staffing": {"available_staff": {"doctors": 25, "nurses": 58}}, + "resources": {"resource_utilization": 0.88} + }, + "full_response": "Current bed occupancy is at 85% with critical care areas approaching capacity. Immediate actions are being taken to optimize bed utilization." + } + + # Default response for other queries + else: + return { + "insights": """ + Please specify your request: + - ๐Ÿฅ Department specific information + - โฐ Wait time inquiries + - ๐Ÿ›๏ธ Bed capacity status + - ๐Ÿ‘ฅ Staffing information + - ๐Ÿ“Š Resource utilization + """, + "recommendations": "To better assist you, please provide more specific details about what you'd like to know.", + "priority_actions": "No immediate actions required. Awaiting specific inquiry.", + "timeline": "Timeline will be generated based on specific requests.", + "full_response": "I'm here to help! Please specify what information you need about healthcare operations." + } + + def render_sidebar(self): + """Render the sidebar with controls and filters""" + try: + with st.sidebar: + # Add custom CSS for consistent button styling + st.markdown(""" + + """, unsafe_allow_html=True) + + st.markdown("### โš™๏ธ Settings") + + # Department filter + if "department_filter" not in st.session_state: + st.session_state.department_filter = "All Departments" + + st.selectbox( + "Select Department", + ["All Departments", "ER", "ICU", "General Ward", "Surgery", "Pediatrics"], + key="department_filter" + ) + + # Priority filter + if "priority_filter" not in st.session_state: + st.session_state.priority_filter = "Medium" + + st.select_slider( + "Priority Level", + options=["Low", "Medium", "High", "Urgent", "Critical"], + key="priority_filter" + ) + + # Time range + if "time_range_filter" not in st.session_state: + st.session_state.time_range_filter = 8 + + st.slider( + "Time Range (hours)", + min_value=1, + max_value=24, + key="time_range_filter" + ) + + # Quick actions with consistent styling + st.markdown("### โšก Quick Actions") + + # Create two columns for buttons + col1, col2 = st.columns(2) + + with col1: + if st.button("๐Ÿ“Š Report"): + st.info("Generating comprehensive report...") + + with col2: + if st.button("๐Ÿ”„ Refresh"): + st.success("Data refreshed successfully!") + + # Emergency Mode + st.markdown("### ๐Ÿšจ Emergency Mode") + + if "emergency_mode" not in st.session_state: + st.session_state.emergency_mode = False + + st.toggle( + "Activate Emergency Protocol", + key="emergency_mode", + help="Enable emergency mode for critical situations" + ) + + if st.session_state.emergency_mode: + st.warning("Emergency Mode Active!") + + # Help section + st.markdown("### โ“ Help") + with st.expander("Usage Guide"): + st.markdown(""" + - ๐Ÿ’ฌ Use the chat to ask questions + - ๐Ÿ“Š Monitor real-time metrics + - โš™๏ธ Adjust filters as needed + - ๐Ÿ“‹ Generate reports for analysis + - ๐Ÿšจ Toggle emergency mode for critical situations + """) + + # Footer + st.markdown("---") + st.caption( + f"*Last updated: {datetime.now().strftime('%H:%M:%S')}*" + ) + + except Exception as e: + logger.error(f"Error rendering sidebar: {str(e)}") + st.error("Error loading sidebar") + + def run(self): + """Run the Streamlit application""" + try: + # Main application container + main_container = st.container() + + with main_container: + # Render components + self.render_header() + self.render_sidebar() + + # Main content area + content_container = st.container() + with content_container: + self.render_metrics() + st.markdown("
", unsafe_allow_html=True) # Spacing + self.render_chat() + + except Exception as e: + logger.error(f"Error running application: {str(e)}") + st.error(f"Application error: {str(e)} โŒ") \ No newline at end of file diff --git a/src/ui/assets/icons/.gitkeep b/src/ui/assets/icons/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/ui/assets/images/.gitkeep b/src/ui/assets/images/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/ui/components/__init__.py b/src/ui/components/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2d348218583c23ac33986cc4830cac23c53e867b --- /dev/null +++ b/src/ui/components/__init__.py @@ -0,0 +1,11 @@ +from .chat import ChatComponent +from .metrics import MetricsComponent +from .sidebar import SidebarComponent +from .header import HeaderComponent + +__all__ = [ + 'ChatComponent', + 'MetricsComponent', + 'SidebarComponent', + 'HeaderComponent' +] \ No newline at end of file diff --git a/src/ui/components/chat.py b/src/ui/components/chat.py new file mode 100644 index 0000000000000000000000000000000000000000..899c6ee5de7c6a1b00d32268a148f0a660853e4a --- /dev/null +++ b/src/ui/components/chat.py @@ -0,0 +1,78 @@ +import streamlit as st +from typing import Optional, Dict, Callable +from datetime import datetime + +class ChatComponent: + def __init__(self, process_message_callback: Callable): + """ + Initialize the chat component + + Args: + process_message_callback: Callback function to process messages + """ + self.process_message = process_message_callback + + # Initialize session state for messages if not exists + if 'messages' not in st.session_state: + st.session_state.messages = [] + + def _display_message(self, role: str, content: str, timestamp: Optional[datetime] = None): + """Display a single chat message""" + avatar = "๐Ÿค–" if role == "assistant" else "๐Ÿ‘ค" + with st.chat_message(role, avatar=avatar): + st.markdown(content) + if timestamp: + st.caption(f":clock2: {timestamp.strftime('%H:%M')}") + + def render(self): + """Render the chat interface""" + st.markdown("### ๐Ÿ’ฌ Healthcare Operations Chat") + + # Display chat messages + for message in st.session_state.messages: + self._display_message( + role=message["role"], + content=message["content"], + timestamp=message.get("timestamp") + ) + + # Chat input + if prompt := st.chat_input( + "Ask about patient flow, resources, quality metrics, or staff scheduling..." + ): + # Add user message + current_time = datetime.now() + st.session_state.messages.append({ + "role": "user", + "content": prompt, + "timestamp": current_time + }) + + # Display user message + self._display_message("user", prompt, current_time) + + # Process message and get response + with st.spinner("Processing your request... ๐Ÿ”„"): + try: + response = self.process_message(prompt) + + # Add and display assistant response + st.session_state.messages.append({ + "role": "assistant", + "content": response["response"], + "timestamp": datetime.now() + }) + + self._display_message( + "assistant", + response["response"], + datetime.now() + ) + + except Exception as e: + st.error(f"Error processing your request: {str(e)} โŒ") + + def clear_chat(self): + """Clear the chat history""" + st.session_state.messages = [] + st.success("Chat history cleared! ๐Ÿงน") \ No newline at end of file diff --git a/src/ui/components/header.py b/src/ui/components/header.py new file mode 100644 index 0000000000000000000000000000000000000000..d4af6531cc95f4c23b08f6f69df34e5d3b7a3622 --- /dev/null +++ b/src/ui/components/header.py @@ -0,0 +1,73 @@ +import streamlit as st +from datetime import datetime + +class HeaderComponent: + def __init__(self): + """Initialize the header component""" + # Initialize session state for notifications if not exists + if 'notifications' not in st.session_state: + st.session_state.notifications = [] + + def _add_notification(self, message: str, type: str = "info"): + """Add a notification to the session state""" + st.session_state.notifications.append({ + "message": message, + "type": type, + "timestamp": datetime.now() + }) + + def render(self): + """Render the header""" + # Main header container + header_container = st.container() + + with header_container: + # Top row with logo and title + col1, col2, col3 = st.columns([1, 4, 1]) + + with col1: + st.markdown("# ๐Ÿฅ") + + with col2: + st.title("Healthcare Operations Assistant") + st.markdown(""" +
+ *AI-Powered Healthcare Management System* ๐Ÿค– +
+ """, unsafe_allow_html=True) + + with col3: + # Status indicator + status = "๐ŸŸข Online" if st.session_state.get('system_status', True) else "๐Ÿ”ด Offline" + st.markdown(f"### {status}") + + # Notification area + if st.session_state.notifications: + with st.expander("๐Ÿ“ฌ Notifications", expanded=True): + for notif in st.session_state.notifications[-3:]: # Show last 3 + if notif["type"] == "info": + st.info(notif["message"]) + elif notif["type"] == "warning": + st.warning(notif["message"]) + elif notif["type"] == "error": + st.error(notif["message"]) + elif notif["type"] == "success": + st.success(notif["message"]) + + # System status bar + status_cols = st.columns(4) + with status_cols[0]: + st.markdown("**System Status:** Operational โœ…") + with status_cols[1]: + st.markdown("**API Status:** Connected ๐Ÿ”—") + with status_cols[2]: + st.markdown("**Load:** Normal ๐Ÿ“Š") + with status_cols[3]: + st.markdown(f"**Last Update:** {datetime.now().strftime('%H:%M')} ๐Ÿ•’") + + # Divider + st.markdown("---") + + def add_notification(self, message: str, type: str = "info"): + """Public method to add notifications""" + self._add_notification(message, type) \ No newline at end of file diff --git a/src/ui/components/metrics.py b/src/ui/components/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..b17578d0942cd1132c5af9b21ea7ff307e31afca --- /dev/null +++ b/src/ui/components/metrics.py @@ -0,0 +1,139 @@ +import streamlit as st +from typing import Dict, Any, Optional + +class MetricsComponent: + def __init__(self): + """Initialize the metrics component""" + self.default_metrics = { + "patient_flow": { + "occupied_beds": 75, + "total_beds": 100, + "waiting_time": 15, + "discharge_rate": 8 + }, + "quality": { + "patient_satisfaction": 8.5, + "compliance_rate": 0.95, + "incident_count": 2 + }, + "staffing": { + "available_staff": { + "doctors": 20, + "nurses": 50, + "specialists": 15 + }, + "shift_coverage": 0.92 + }, + "resources": { + "resource_utilization": 0.75, + "critical_supplies": 3, + "equipment_availability": 0.88 + } + } + + def _render_metric_card( + self, + title: str, + value: Any, + delta: Optional[str] = None, + help_text: Optional[str] = None + ): + """Render a single metric card""" + st.metric( + label=title, + value=value, + delta=delta, + help=help_text + ) + + def render(self, metrics: Optional[Dict[str, Any]] = None): + """ + Render the metrics dashboard + + Args: + metrics: Optional metrics data to display + """ + metrics = metrics or self.default_metrics + + st.markdown("### ๐Ÿ“Š Operational Metrics Dashboard") + + # Create two rows of metrics + row1_cols = st.columns(4) + row2_cols = st.columns(4) + + # First row - Key metrics + with row1_cols[0]: + occupancy = (metrics["patient_flow"]["occupied_beds"] / + metrics["patient_flow"]["total_beds"] * 100) + self._render_metric_card( + "Bed Occupancy ๐Ÿ›๏ธ", + f"{occupancy:.1f}%", + "Normal" if occupancy < 85 else "High", + "Current bed occupancy rate across all departments" + ) + + with row1_cols[1]: + satisfaction = metrics["quality"]["patient_satisfaction"] + self._render_metric_card( + "Patient Satisfaction ๐Ÿ˜Š", + f"{satisfaction}/10", + "โ†— +0.5" if satisfaction > 8 else "โ†˜ -0.3", + "Average patient satisfaction score" + ) + + with row1_cols[2]: + total_staff = sum(metrics["staffing"]["available_staff"].values()) + self._render_metric_card( + "Available Staff ๐Ÿ‘ฅ", + total_staff, + "Optimal" if total_staff > 80 else "Low", + "Total number of available staff across all roles" + ) + + with row1_cols[3]: + utilization = metrics["resources"]["resource_utilization"] * 100 + self._render_metric_card( + "Resource Utilization ๐Ÿ“ฆ", + f"{utilization:.1f}%", + "Efficient" if utilization < 80 else "High", + "Current resource utilization rate" + ) + + # Second row - Additional metrics + with row2_cols[0]: + self._render_metric_card( + "Waiting Time โฐ", + f"{metrics['patient_flow']['waiting_time']} min", + help_text="Average patient waiting time" + ) + + with row2_cols[1]: + self._render_metric_card( + "Compliance Rate โœ…", + f"{metrics['quality']['compliance_rate']*100:.1f}%", + help_text="Current compliance rate with protocols" + ) + + with row2_cols[2]: + self._render_metric_card( + "Critical Supplies โš ๏ธ", + metrics['resources']['critical_supplies'], + "Action needed" if metrics['resources']['critical_supplies'] > 0 else "All stocked", + "Number of supplies needing immediate attention" + ) + + with row2_cols[3]: + self._render_metric_card( + "Shift Coverage ๐Ÿ“…", + f"{metrics['staffing']['shift_coverage']*100:.1f}%", + help_text="Current shift coverage rate" + ) + + # Additional visualization if needed + with st.expander("๐Ÿ“ˆ Detailed Metrics Analysis"): + st.markdown(""" + ### Trend Analysis + - ๐Ÿ“ˆ Patient flow is within normal range + - ๐Ÿ“‰ Resource utilization shows optimization opportunities + - ๐Ÿ“Š Staff distribution is balanced across departments + """) \ No newline at end of file diff --git a/src/ui/components/sidebar.py b/src/ui/components/sidebar.py new file mode 100644 index 0000000000000000000000000000000000000000..d92fac2efc971463396624f2b0ab87f9704a4b13 --- /dev/null +++ b/src/ui/components/sidebar.py @@ -0,0 +1,121 @@ +import streamlit as st +from typing import Dict, Any, Callable +from datetime import datetime, timedelta + +class SidebarComponent: + def __init__(self, on_filter_change: Optional[Callable] = None): + """ + Initialize the sidebar component + + Args: + on_filter_change: Optional callback for filter changes + """ + self.on_filter_change = on_filter_change + + # Initialize session state for filters if not exists + if 'filters' not in st.session_state: + st.session_state.filters = { + 'department': 'All Departments', + 'priority': 'Medium', + 'time_range': 8, + 'view_mode': 'Standard' + } + + def render(self): + """Render the sidebar""" + with st.sidebar: + st.markdown("# โš™๏ธ Operations Control") + + # Department Selection + st.markdown("### ๐Ÿฅ Department") + department = st.selectbox( + "Select Department", + [ + "All Departments", + "Emergency Room", + "ICU", + "General Ward", + "Surgery", + "Pediatrics", + "Cardiology" + ], + index=0, + help="Filter data by department" + ) + + # Priority Filter + st.markdown("### ๐ŸŽฏ Priority Level") + priority = st.select_slider( + "Set Priority", + options=["Low", "Medium", "High", "Urgent", "Critical"], + value=st.session_state.filters['priority'], + help="Filter by priority level" + ) + + # Time Range + st.markdown("### ๐Ÿ•’ Time Range") + time_range = st.slider( + "Select Time Range", + min_value=1, + max_value=24, + value=st.session_state.filters['time_range'], + help="Time range for data analysis (hours)" + ) + + # View Mode + st.markdown("### ๐Ÿ‘๏ธ View Mode") + view_mode = st.radio( + "Select View Mode", + ["Standard", "Detailed", "Compact"], + help="Change the display density" + ) + + # Quick Actions + st.markdown("### โšก Quick Actions") + col1, col2 = st.columns(2) + with col1: + if st.button("๐Ÿ“Š Report", use_container_width=True): + st.info("Generating report...") + with col2: + if st.button("๐Ÿ”„ Refresh", use_container_width=True): + st.success("Data refreshed!") + + # Emergency Mode Toggle + st.markdown("### ๐Ÿšจ Emergency Mode") + emergency_mode = st.toggle( + "Activate Emergency Protocol", + help="Enable emergency mode for critical situations" + ) + if emergency_mode: + st.warning("Emergency Mode Active!") + + # Help & Documentation + with st.expander("โ“ Help & Tips"): + st.markdown(""" + ### Quick Guide + - ๐Ÿ” Use filters to focus on specific areas + - ๐Ÿ“ˆ Monitor real-time metrics + - ๐Ÿšจ Toggle emergency mode for critical situations + - ๐Ÿ“Š Generate reports for analysis + - ๐Ÿ’ก Access quick actions for common tasks + """) + + # Update filters in session state + st.session_state.filters.update({ + 'department': department, + 'priority': priority, + 'time_range': time_range, + 'view_mode': view_mode, + 'emergency_mode': emergency_mode + }) + + # Call filter change callback if provided + if self.on_filter_change: + self.on_filter_change(st.session_state.filters) + + # Footer + st.markdown("---") + st.markdown( + f"*Last updated: {datetime.now().strftime('%H:%M:%S')}*", + help="Last data refresh timestamp" + ) \ No newline at end of file diff --git a/src/ui/styles/__init__.py b/src/ui/styles/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e21f90669d36a4052f6d4ead072b423681ac4350 --- /dev/null +++ b/src/ui/styles/__init__.py @@ -0,0 +1,3 @@ +from .theme import HealthcareTheme + +__all__ = ['HealthcareTheme'] \ No newline at end of file diff --git a/src/ui/styles/custom.css b/src/ui/styles/custom.css new file mode 100644 index 0000000000000000000000000000000000000000..688e1dccc230cd41ba925f8fc139092465c97524 --- /dev/null +++ b/src/ui/styles/custom.css @@ -0,0 +1,224 @@ +/* Healthcare Operations Assistant Custom Styles */ + +/* Layout and Structure */ +.container { + max-width: 1200px; + margin: 0 auto; + padding: 1rem; +} + +/* Chat Interface */ +.chat-container { + background-color: #ffffff; + border-radius: 12px; + box-shadow: 0 2px 4px rgba(0,0,0,0.1); + padding: 1rem; + margin: 1rem 0; +} + +.user-message { + background-color: #e3f2fd; + padding: 1rem; + border-radius: 10px; + margin: 1rem 0; + border-left: 5px solid #1976d2; +} + +.assistant-message { + background-color: #f5f5f5; + padding: 1rem; + border-radius: 10px; + margin: 1rem 0; + border-left: 5px solid #4caf50; +} + +.message-timestamp { + font-size: 0.8rem; + color: #707070; + margin-top: 0.25rem; +} + +/* Metrics Dashboard */ +.metric-card { + background-color: white; + border-radius: 12px; + padding: 1.5rem; + box-shadow: 0 2px 4px rgba(0,0,0,0.1); + transition: transform 0.2s; +} + +.metric-card:hover { + transform: translateY(-2px); + box-shadow: 0 4px 6px rgba(0,0,0,0.1); +} + +.metric-title { + font-size: 1rem; + font-weight: 600; + color: #2c3e50; + margin-bottom: 0.5rem; +} + +.metric-value { + font-size: 1.5rem; + font-weight: 700; + color: #1976d2; +} + +.metric-trend { + font-size: 0.9rem; + color: #4caf50; +} + +.metric-trend.negative { + color: #f44336; +} + +/* Header Styling */ +.header { + background-color: white; + padding: 1rem; + border-bottom: 1px solid #e0e0e0; + margin-bottom: 2rem; +} + +.header-title { + font-size: 1.8rem; + font-weight: 700; + color: #2c3e50; +} + +.header-subtitle { + font-size: 1rem; + color: #707070; +} + +/* Sidebar Styling */ +.sidebar { + background-color: white; + padding: 1.5rem; + border-right: 1px solid #e0e0e0; +} + +.sidebar-section { + margin-bottom: 2rem; +} + +.sidebar-title { + font-size: 1.1rem; + font-weight: 600; + color: #2c3e50; + margin-bottom: 1rem; +} + +/* Status Indicators */ +.status-indicator { + display: inline-flex; + align-items: center; + padding: 0.25rem 0.75rem; + border-radius: 9999px; + font-size: 0.875rem; + font-weight: 500; +} + +.status-normal { + background-color: #4caf50; + color: white; +} + +.status-warning { + background-color: #ff9800; + color: white; +} + +.status-critical { + background-color: #f44336; + color: white; +} + +/* Buttons and Interactive Elements */ +.action-button { + background-color: #2196f3; + color: white; + border: none; + border-radius: 8px; + padding: 0.5rem 1rem; + font-weight: 500; + cursor: pointer; + transition: background-color 0.2s; +} + +.action-button:hover { + background-color: #1976d2; +} + +.action-button.secondary { + background-color: #f5f5f5; + color: #2c3e50; + border: 1px solid #e0e0e0; +} + +.action-button.secondary:hover { + background-color: #e0e0e0; +} + +/* Notifications */ +.notification { + padding: 0.75rem 1rem; + border-radius: 8px; + margin-bottom: 1rem; +} + +.notification.info { + background-color: #e3f2fd; + border-left: 4px solid #2196f3; +} + +.notification.success { + background-color: #e8f5e9; + border-left: 4px solid #4caf50; +} + +.notification.warning { + background-color: #fff3e0; + border-left: 4px solid #ff9800; +} + +.notification.error { + background-color: #ffebee; + border-left: 4px solid #f44336; +} + +/* Responsive Design */ +@media (max-width: 768px) { + .metric-card { + margin-bottom: 1rem; + } + + .header-title { + font-size: 1.5rem; + } + + .sidebar { + padding: 1rem; + } +} + +/* Animations */ +@keyframes fadeIn { + from { opacity: 0; } + to { opacity: 1; } +} + +.fade-in { + animation: fadeIn 0.3s ease-in; +} + +@keyframes slideIn { + from { transform: translateY(20px); opacity: 0; } + to { transform: translateY(0); opacity: 1; } +} + +.slide-in { + animation: slideIn 0.3s ease-out; +} \ No newline at end of file diff --git a/src/ui/styles/theme.py b/src/ui/styles/theme.py new file mode 100644 index 0000000000000000000000000000000000000000..f6402e9094f3c04037532f27e9e174fddf21cee5 --- /dev/null +++ b/src/ui/styles/theme.py @@ -0,0 +1,123 @@ +from dataclasses import dataclass +from typing import Dict, Any + +@dataclass +class HealthcareTheme: + """Healthcare UI theme configuration""" + + # Color palette + colors = { + 'primary': '#2196f3', # Main blue + 'primary_light': '#e3f2fd', # Light blue + 'primary_dark': '#1976d2', # Dark blue + 'success': '#4caf50', # Green + 'warning': '#ff9800', # Orange + 'error': '#f44336', # Red + 'info': '#2196f3', # Blue + 'background': '#f0f8ff', # Light blue background + 'surface': '#ffffff', # White + 'text': '#2c3e50', # Dark gray + 'text_secondary': '#707070' # Medium gray + } + + # Typography + fonts = { + 'primary': '"Source Sans Pro", -apple-system, BlinkMacSystemFont, sans-serif', + 'monospace': '"Roboto Mono", monospace' + } + + # Spacing + spacing = { + 'xs': '0.25rem', + 'sm': '0.5rem', + 'md': '1rem', + 'lg': '1.5rem', + 'xl': '2rem' + } + + # Border radius + radius = { + 'sm': '4px', + 'md': '8px', + 'lg': '12px', + 'xl': '16px', + 'pill': '9999px' + } + + # Shadows + shadows = { + 'sm': '0 1px 3px rgba(0,0,0,0.12)', + 'md': '0 2px 4px rgba(0,0,0,0.1)', + 'lg': '0 4px 6px rgba(0,0,0,0.1)', + 'xl': '0 8px 12px rgba(0,0,0,0.1)' + } + + @classmethod + def get_streamlit_config(cls) -> Dict[str, Any]: + """Get Streamlit theme configuration""" + return { + "theme": { + "primaryColor": cls.colors['primary'], + "backgroundColor": cls.colors['background'], + "secondaryBackgroundColor": cls.colors['surface'], + "textColor": cls.colors['text'], + "font": cls.fonts['primary'] + } + } + + @classmethod + def apply_theme(cls): + """Apply theme to Streamlit application""" + import streamlit as st + + # Apply theme configuration + st.set_page_config(**cls.get_streamlit_config()) + + # Apply custom CSS + st.markdown(""" + + """, unsafe_allow_html=True) \ No newline at end of file diff --git a/src/utils/__init__.py b/src/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d82b3f07e73072379f02f1132b7f28699c9b6040 --- /dev/null +++ b/src/utils/__init__.py @@ -0,0 +1,6 @@ +# src/utils/__init__.py +from .logger import setup_logger +from .error_handlers import ErrorHandler +from .validators import Validator + +__all__ = ['setup_logger', 'ErrorHandler', 'Validator'] \ No newline at end of file diff --git a/src/utils/error_handlers.py b/src/utils/error_handlers.py new file mode 100644 index 0000000000000000000000000000000000000000..648c3331165cfe943393654f76fe60fcf42a5060 --- /dev/null +++ b/src/utils/error_handlers.py @@ -0,0 +1,122 @@ +# src/utils/error_handlers.py +from typing import Dict, Any, Optional, Callable +from functools import wraps +import traceback +from .logger import setup_logger + +logger = setup_logger(__name__) + +class HealthcareError(Exception): + """Base exception class for healthcare operations""" + def __init__(self, message: str, error_code: str, details: Optional[Dict] = None): + self.message = message + self.error_code = error_code + self.details = details or {} + super().__init__(self.message) + +class ValidationError(HealthcareError): + """Raised when input validation fails""" + def __init__(self, message: str, details: Optional[Dict] = None): + super().__init__( + message=message, + error_code="INPUT_VALIDATION_ERROR", + details=details + ) + +class ProcessingError(HealthcareError): + """Raised when processing operations fail""" + pass + +class ResourceError(HealthcareError): + """Raised when resource-related operations fail""" + pass + +class ErrorHandler: + @staticmethod + def validate_input(input_text: str) -> None: + """Validate input text before processing""" + if not input_text or not input_text.strip(): + raise ValidationError( + message="Input text cannot be empty", + details={"provided_input": input_text} + ) + + @staticmethod + def handle_error(error: Exception) -> Dict[str, Any]: + """Handle different types of errors and return appropriate response""" + if isinstance(error, ValidationError): + logger.error(f"Validation Error: {error.message}", + extra={"error_code": error.error_code, "details": error.details}) + raise error # Re-raise ValidationError + elif isinstance(error, HealthcareError): + logger.error(f"Healthcare Error: {error.message}", + extra={"error_code": error.error_code, "details": error.details}) + return { + "error": True, + "error_code": error.error_code, + "message": error.message, + "details": error.details + } + else: + logger.error(f"Unexpected Error: {str(error)}\n{traceback.format_exc()}") + return { + "error": True, + "error_code": "UNEXPECTED_ERROR", + "message": "An unexpected error occurred", + "details": {"error_type": type(error).__name__} + } + + @staticmethod + def error_decorator(func: Callable) -> Callable: + """Decorator for handling errors in functions""" + @wraps(func) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except ValidationError: + # Let ValidationError propagate up + raise + except Exception as e: + return ErrorHandler.handle_error(e) + return wrapper + + @staticmethod + def retry_operation( + operation: Callable, + max_retries: int = 3, + retry_delay: float = 1.0 + ) -> Any: + """ + Retry an operation with exponential backoff + """ + from time import sleep + + for attempt in range(max_retries): + try: + return operation() + except Exception as e: + if attempt == max_retries - 1: + raise + + logger.warning( + f"Operation failed (attempt {attempt + 1}/{max_retries}): {str(e)}" + ) + sleep(retry_delay * (2 ** attempt)) + + @staticmethod + def safe_execute( + operation: Callable, + error_code: str, + default_value: Any = None + ) -> Any: + """ + Safely execute an operation with error handling + """ + try: + return operation() + except Exception as e: + logger.error(f"Operation failed: {str(e)}") + raise HealthcareError( + message=f"Operation failed: {str(e)}", + error_code=error_code + )# Error handling utilities implementation diff --git a/src/utils/logger.py b/src/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..8a503a628e1df8e7c2468c9edb6bcfc3d7d1f943 --- /dev/null +++ b/src/utils/logger.py @@ -0,0 +1,102 @@ +# src/utils/logger.py +import logging +import sys +from logging.handlers import RotatingFileHandler, TimedRotatingFileHandler +from pathlib import Path +from datetime import datetime +from typing import Optional +from ..config.settings import Settings + +class CustomFormatter(logging.Formatter): + """Custom formatter with color coding for different log levels""" + + COLORS = { + 'DEBUG': '\033[0;36m', # Cyan + 'INFO': '\033[0;32m', # Green + 'WARNING': '\033[0;33m', # Yellow + 'ERROR': '\033[0;31m', # Red + 'CRITICAL': '\033[0;37;41m' # White on Red + } + RESET = '\033[0m' + + def format(self, record): + # Add color to log level if on console + if hasattr(self, 'use_color') and self.use_color: + record.levelname = f"{self.COLORS.get(record.levelname, '')}{record.levelname}{self.RESET}" + return super().format(record) + +def setup_logger( + name: str, + log_level: Optional[str] = None, + log_file: Optional[str] = None +) -> logging.Logger: + """ + Set up logger with both file and console handlers + + Args: + name: Logger name + log_level: Optional override for log level + log_file: Optional override for log file path + + Returns: + Configured logger instance + """ + try: + # Create logger + logger = logging.getLogger(name) + logger.setLevel(log_level or Settings.LOG_LEVEL) + + # Avoid adding handlers if they already exist + if logger.handlers: + return logger + + # Create formatters + file_formatter = logging.Formatter( + '%(asctime)s - %(name)s - [%(levelname)s] - %(message)s' + ) + + console_formatter = CustomFormatter( + '%(asctime)s - %(name)s - [%(levelname)s] - %(message)s' + ) + console_formatter.use_color = True + + # Create and configure file handler + log_file = log_file or Settings.LOG_FILE + log_dir = Path(log_file).parent + log_dir.mkdir(parents=True, exist_ok=True) + + # Rotating file handler (size-based) + file_handler = RotatingFileHandler( + log_file, + maxBytes=10 * 1024 * 1024, # 10MB + backupCount=5 + ) + file_handler.setFormatter(file_formatter) + + # Time-based rotating handler for daily logs + daily_handler = TimedRotatingFileHandler( + str(log_dir / f"daily_{datetime.now():%Y-%m-%d}.log"), + when="midnight", + interval=1, + backupCount=30 + ) + daily_handler.setFormatter(file_formatter) + + # Console handler + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setFormatter(console_formatter) + + # Add handlers + logger.addHandler(file_handler) + logger.addHandler(daily_handler) + logger.addHandler(console_handler) + + return logger + + except Exception as e: + # Fallback to basic logging if setup fails + basic_logger = logging.getLogger(name) + basic_logger.setLevel(logging.INFO) + basic_logger.addHandler(logging.StreamHandler(sys.stdout)) + basic_logger.error(f"Error setting up logger: {str(e)}") + return basic_logger# Logging configuration implementation diff --git a/src/utils/validators.py b/src/utils/validators.py new file mode 100644 index 0000000000000000000000000000000000000000..4f670a1eb20e20ac82a3cdb1c1f94487fb81a8f0 --- /dev/null +++ b/src/utils/validators.py @@ -0,0 +1,132 @@ +# src/utils/validators.py +from typing import Dict, Any, List, Optional +from datetime import datetime +from .logger import setup_logger +from .error_handlers import ValidationError + +logger = setup_logger(__name__) + +class Validator: + @staticmethod + def validate_state(state: Dict[str, Any]) -> bool: + """Validate the state structure and data types""" + required_keys = ["messages", "current_task", "metrics", "timestamp"] + + try: + # Check required keys + for key in required_keys: + if key not in state: + raise ValidationError( + message=f"Missing required key: {key}", + error_code="INVALID_STATE_STRUCTURE" + ) + + # Validate timestamp + if not isinstance(state["timestamp"], datetime): + raise ValidationError( + message="Invalid timestamp format", + error_code="INVALID_TIMESTAMP" + ) + + return True + + except Exception as e: + logger.error(f"State validation failed: {str(e)}") + raise + + @staticmethod + def validate_metrics(metrics: Dict[str, Any]) -> bool: + """Validate metrics data structure and values""" + required_categories = [ + "patient_flow", + "resources", + "quality", + "staffing" + ] + + try: + # Check required categories + for category in required_categories: + if category not in metrics: + raise ValidationError( + message=f"Missing required metrics category: {category}", + error_code="INVALID_METRICS_STRUCTURE" + ) + + # Validate numeric values + Validator._validate_numeric_values(metrics) + + return True + + except Exception as e: + logger.error(f"Metrics validation failed: {str(e)}") + raise + + @staticmethod + def validate_tool_input( + tool_name: str, + params: Dict[str, Any], + required_params: List[str] + ) -> bool: + """Validate input parameters for tools""" + try: + # Check required parameters + for param in required_params: + if param not in params: + raise ValidationError( + message=f"Missing required parameter: {param}", + error_code="MISSING_PARAMETER", + details={"tool": tool_name, "parameter": param} + ) + + return True + + except Exception as e: + logger.error(f"Tool input validation failed: {str(e)}") + raise + + @staticmethod + def validate_department_data(department_data: Dict[str, Any]) -> bool: + """Validate department-specific data""" + required_fields = [ + "capacity", + "current_occupancy", + "staff_count" + ] + + try: + # Check required fields + for field in required_fields: + if field not in department_data: + raise ValidationError( + message=f"Missing required field: {field}", + error_code="INVALID_DEPARTMENT_DATA" + ) + + # Validate capacity constraints + if department_data["current_occupancy"] > department_data["capacity"]: + raise ValidationError( + message="Current occupancy exceeds capacity", + error_code="INVALID_OCCUPANCY" + ) + + return True + + except Exception as e: + logger.error(f"Department data validation failed: {str(e)}") + raise + + @staticmethod + def _validate_numeric_values(data: Dict[str, Any], path: str = "") -> None: + """Recursively validate numeric values in nested dictionary""" + for key, value in data.items(): + current_path = f"{path}.{key}" if path else key + + if isinstance(value, (int, float)): + if value < 0: + raise ValidationError( + message=f"Negative value not allowed: {current_path}", + error_code="INVALID_NUMERIC_VALUE" + ) + elif isinstance(value, dict): + Validator._validate_numeric_values(value, current_path)# Input validation utilities implementation diff --git a/streamlit_app.py b/streamlit_app.py new file mode 100644 index 0000000000000000000000000000000000000000..699861855b59064098f5f794fe5aef52e40bda01 --- /dev/null +++ b/streamlit_app.py @@ -0,0 +1,10 @@ +from src.ui import HealthcareUI +import os +from dotenv import load_dotenv + +# Load environment variables +load_dotenv() + +if __name__ == "__main__": + app = HealthcareUI() + app.run() \ No newline at end of file diff --git a/test_healthcare_agent_basic.py b/test_healthcare_agent_basic.py new file mode 100644 index 0000000000000000000000000000000000000000..3db88c47d3681d97c3f8b3eb3534c7d75f88daaa --- /dev/null +++ b/test_healthcare_agent_basic.py @@ -0,0 +1,66 @@ +# test_healthcare_agent_basic.py +import os +from datetime import datetime +from src.agent import HealthcareAgent +from src.models.state import TaskType, PriorityLevel +from src.utils.error_handlers import ValidationError, HealthcareError + +def main(): + """Basic test of the Healthcare Operations Management Agent""" + try: + # 1. Test Agent Initialization + print("\n=== Testing Agent Initialization ===") + agent = HealthcareAgent(os.getenv("OPENAI_API_KEY")) + print("โœ“ Agent initialized successfully") + + # 2. Test Basic Query - Patient Flow + print("\n=== Testing Patient Flow Query ===") + patient_query = "What is the current ER occupancy and wait time?" + response = agent.process( + input_text=patient_query, + thread_id="test-thread-1" + ) + print(f"Query: {patient_query}") + print(f"Response: {response.get('response', 'No response')}") + print(f"Analysis: {response.get('analysis', {})}") + + # 3. Test Resource Management Query + print("\n=== Testing Resource Management Query ===") + resource_query = "Check the current availability of ventilators and ICU beds" + response = agent.process( + input_text=resource_query, + thread_id="test-thread-1" + ) + print(f"Query: {resource_query}") + print(f"Response: {response.get('response', 'No response')}") + print(f"Analysis: {response.get('analysis', {})}") + + # 4. Test Conversation History + print("\n=== Testing Conversation History ===") + history = agent.get_conversation_history("test-thread-1") + print(f"Conversation history length: {len(history)}") + + # 5. Test Reset Conversation + print("\n=== Testing Conversation Reset ===") + reset_success = agent.reset_conversation("test-thread-1") + print(f"Reset successful: {reset_success}") + + # 6. Test Error Handling + print("\n=== Testing Error Handling ===") + try: + agent.process("") + print("โŒ Error handling test failed - empty input accepted") + except ValidationError as ve: + print(f"โœ“ Error handling working correctly: Empty input rejected with validation error") + except HealthcareError as he: + print(f"โœ“ Error handling working correctly: {str(he)}") + except Exception as e: + print(f"โŒ Unexpected error type: {type(e).__name__}: {str(e)}") + + except Exception as e: + print(f"\nโŒ Test failed with error: {str(e)}") + +if __name__ == "__main__": + print("Starting Healthcare Agent Basic Tests...") + print(f"Test Time: {datetime.now()}") + main() \ No newline at end of file diff --git a/test_healthcare_scenarios.py b/test_healthcare_scenarios.py new file mode 100644 index 0000000000000000000000000000000000000000..5d2eba4ab3e36b34bdd6e479f6496d1308d91fa4 --- /dev/null +++ b/test_healthcare_scenarios.py @@ -0,0 +1,163 @@ +import streamlit as st +from time import sleep +import pytest +from datetime import datetime + +class HealthcareAssistantTester: + def __init__(self): + self.test_results = [] + + def run_test_suite(self): + """Run all test scenarios""" + print("\n=== Starting Healthcare Assistant Test Suite ===\n") + + # Run all test categories + self.test_patient_flow() + self.test_resource_management() + self.test_staff_scheduling() + self.test_quality_metrics() + self.test_emergency_scenarios() + self.test_department_specific() + + # Print test summary + self.print_test_summary() + + def test_patient_flow(self): + """Test Patient Flow Related Queries""" + print("\n1. Testing Patient Flow Queries:") + queries = [ + "Show me waiting times across all departments", + "What is the current bed occupancy in the ER?", + "How many patients are currently waiting for admission?", + "What's the average wait time in the ICU?", + "Show patient flow trends for the last 8 hours", + "Which department has the longest waiting time right now?" + ] + self._run_test_batch("Patient Flow", queries) + + def test_resource_management(self): + """Test Resource Management Queries""" + print("\n2. Testing Resource Management Queries:") + queries = [ + "Check medical supplies inventory status", + "What is the current ventilator availability?", + "Are there any critical supply shortages?", + "Show resource utilization across departments", + "Which supplies need immediate reordering?", + "What's the equipment maintenance status?" + ] + self._run_test_batch("Resource Management", queries) + + def test_staff_scheduling(self): + """Test Staff Scheduling Queries""" + print("\n3. Testing Staff Scheduling Queries:") + queries = [ + "Show current staff distribution", + "How many nurses are available in ICU?", + "What is the current shift coverage?", + "Show staff overtime hours this week", + "Is there adequate staff coverage for next shift?", + "Which departments need additional staff right now?" + ] + self._run_test_batch("Staff Scheduling", queries) + + def test_quality_metrics(self): + """Test Quality Metrics Queries""" + print("\n4. Testing Quality Metrics Queries:") + queries = [ + "What's our current patient satisfaction score?", + "Show me compliance rates for the last 24 hours", + "Are there any quality metrics below target?", + "What's the current incident report status?", + "Show quality trends across departments", + "Which department has the highest patient satisfaction?" + ] + self._run_test_batch("Quality Metrics", queries) + + def test_emergency_scenarios(self): + """Test Emergency Scenario Queries""" + print("\n5. Testing Emergency Scenarios:") + queries = [ + "Activate emergency protocol for mass casualty incident", + "Need immediate bed availability status for emergency", + "Require rapid staff mobilization plan", + "Emergency resource allocation needed", + "Critical capacity alert in ER", + "Emergency department overflow protocol status" + ] + self._run_test_batch("Emergency Scenarios", queries) + + def test_department_specific(self): + """Test Department-Specific Queries""" + print("\n6. Testing Department-Specific Queries:") + queries = [ + "Show complete metrics for ER department", + "What's the ICU capacity and staff status?", + "General ward patient distribution", + "Surgery department resource utilization", + "Pediatrics department waiting times", + "Cardiology unit staff coverage" + ] + self._run_test_batch("Department-Specific", queries) + + def _run_test_batch(self, category: str, queries: list): + """Run a batch of test queries""" + for query in queries: + try: + print(f"\nTesting: {query}") + print("-" * 50) + + # Simulate processing time + print("Processing query...") + sleep(1) + + # Record test execution + self.test_results.append({ + 'category': category, + 'query': query, + 'timestamp': datetime.now(), + 'status': 'Success' + }) + + print("โœ“ Test completed successfully") + + except Exception as e: + print(f"โœ— Test failed: {str(e)}") + self.test_results.append({ + 'category': category, + 'query': query, + 'timestamp': datetime.now(), + 'status': 'Failed', + 'error': str(e) + }) + + def print_test_summary(self): + """Print summary of all test results""" + print("\n=== Test Execution Summary ===") + print(f"Total Tests Run: {len(self.test_results)}") + + # Calculate statistics + successful_tests = len([t for t in self.test_results if t['status'] == 'Success']) + failed_tests = len([t for t in self.test_results if t['status'] == 'Failed']) + + print(f"Successful Tests: {successful_tests}") + print(f"Failed Tests: {failed_tests}") + + # Print results by category + print("\nResults by Category:") + categories = set([t['category'] for t in self.test_results]) + for category in categories: + category_tests = [t for t in self.test_results if t['category'] == category] + category_success = len([t for t in category_tests if t['status'] == 'Success']) + print(f"{category}: {category_success}/{len(category_tests)} passed") + + print("\n=== Test Suite Completed ===") + +def main(): + """Main test execution function""" + # Initialize and run tests + tester = HealthcareAssistantTester() + tester.run_test_suite() + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e046e5b439ac906c661c4922bdee24ccd22036da --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,6 @@ +# tests/__init__.py +import os +import sys + +# Add project root to Python path +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..23e4557273ee97ad6e4400f30060fa9ed99ef423 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,108 @@ +# tests/conftest.py +import pytest +from datetime import datetime +from typing import Dict + +from src.config.settings import Settings +from src.models.state import HospitalState, TaskType, PriorityLevel + +@pytest.fixture +def mock_settings(): + """Fixture for test settings""" + return { + "OPENAI_API_KEY": "test-api-key", + "MODEL_NAME": "gpt-4o-mini-2024-07-18", + "MODEL_TEMPERATURE": 0, + "MEMORY_TYPE": "sqlite", + "MEMORY_URI": ":memory:", + "LOG_LEVEL": "DEBUG" + } + +@pytest.fixture +def mock_llm_response(): + """Fixture for mock LLM responses""" + return { + "input_analysis": { + "task_type": TaskType.PATIENT_FLOW, + "priority": PriorityLevel.HIGH, + "department": "ER", + "context": {"urgent": True} + }, + "patient_flow": { + "recommendations": ["Optimize bed allocation", "Increase staff in ER"], + "metrics": {"waiting_time": 25, "bed_utilization": 0.85} + }, + "quality_monitoring": { + "satisfaction_score": 8.5, + "compliance_rate": 0.95, + "recommendations": ["Maintain current standards"] + } + } + +@pytest.fixture +def mock_hospital_state() -> HospitalState: + """Fixture for mock hospital state""" + return { + "messages": [], + "current_task": TaskType.GENERAL, + "priority_level": PriorityLevel.MEDIUM, + "department": None, + "metrics": { + "patient_flow": { + "total_beds": 100, + "occupied_beds": 75, + "waiting_patients": 10, + "average_wait_time": 30.0 + }, + "resources": { + "equipment_availability": {"ventilators": True}, + "supply_levels": {"masks": 0.8}, + "resource_utilization": 0.75 + }, + "quality": { + "patient_satisfaction": 8.5, + "compliance_rate": 0.95, + "incident_count": 2 + }, + "staffing": { + "total_staff": 200, + "available_staff": {"doctors": 20, "nurses": 50}, + "overtime_hours": 45.5 + } + }, + "analysis": None, + "context": {}, + "timestamp": datetime.now(), + "thread_id": "test-thread-id" + } + +@pytest.fixture +def mock_tools_response(): + """Fixture for mock tool responses""" + return { + "patient_tools": { + "wait_time": 30.5, + "bed_capacity": {"available": 25, "total": 100}, + "discharge_time": datetime.now() + }, + "resource_tools": { + "supply_levels": {"critical": [], "reorder": ["masks"]}, + "equipment_status": {"available": ["xray"], "in_use": ["mri"]} + } + } + +@pytest.fixture +def mock_error_response(): + """Fixture for mock error responses""" + return { + "validation_error": { + "code": "INVALID_INPUT", + "message": "Invalid input parameters", + "details": {"field": "department", "issue": "required"} + }, + "processing_error": { + "code": "PROCESSING_FAILED", + "message": "Failed to process request", + "details": {"step": "analysis", "reason": "timeout"} + } + }# Test configuration implementation diff --git a/tests/test_agent.py b/tests/test_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..ffa8e0669db53617f99329accaaed9fd5435afc0 --- /dev/null +++ b/tests/test_agent.py @@ -0,0 +1,75 @@ +# tests/test_agent.py +import pytest +from src.agent import HealthcareAgent +from src.utils.error_handlers import HealthcareError + +class TestHealthcareAgent: + def test_agent_initialization(self, mock_settings): + """Test agent initialization""" + agent = HealthcareAgent(api_key=mock_settings["OPENAI_API_KEY"]) + assert agent is not None + assert agent.llm is not None + assert agent.tools is not None + assert agent.nodes is not None + + def test_process_input(self, mock_hospital_state): + """Test processing of input through agent""" + agent = HealthcareAgent() + result = agent.process( + "What is the current ER waiting time?", + thread_id="test-thread" + ) + + assert "response" in result + assert "analysis" in result + assert "metrics" in result + assert "timestamp" in result + + def test_conversation_history(self): + """Test conversation history retrieval""" + agent = HealthcareAgent() + thread_id = "test-thread" + + # Add some messages + agent.process("Test message 1", thread_id=thread_id) + agent.process("Test message 2", thread_id=thread_id) + + history = agent.get_conversation_history(thread_id) + assert len(history) >= 2 + + def test_error_handling(self): + """Test error handling in agent""" + agent = HealthcareAgent() + + with pytest.raises(HealthcareError): + agent.process("", thread_id="test-thread") + + def test_state_management(self, mock_hospital_state): + """Test state management""" + agent = HealthcareAgent() + thread_id = "test-thread" + + # Process message + result = agent.process("Test message", thread_id=thread_id) + assert result is not None + + # Reset conversation + reset_success = agent.reset_conversation(thread_id) + assert reset_success is True + + # Verify reset + history = agent.get_conversation_history(thread_id) + assert len(history) == 0 + + @pytest.mark.asyncio + async def test_async_processing(self): + """Test async processing capabilities""" + agent = HealthcareAgent() + thread_id = "test-thread" + + # Test streaming response + async for event in agent.graph.astream_events( + {"messages": ["Test message"]}, + {"configurable": {"thread_id": thread_id}} + ): + assert event is not None# Integration tests implementation diff --git a/tests/test_nodes/test_input_analyzer.py b/tests/test_nodes/test_input_analyzer.py new file mode 100644 index 0000000000000000000000000000000000000000..01b55036f28907fd5315666624ffbf8468d6724c --- /dev/null +++ b/tests/test_nodes/test_input_analyzer.py @@ -0,0 +1,27 @@ +# tests/test_nodes/test_input_analyzer.py +import pytest +from src.nodes.input_analyzer import InputAnalyzerNode +from src.models.state import TaskType, PriorityLevel + +def test_input_analyzer_initialization(mock_llm_response): + """Test InputAnalyzer node initialization""" + analyzer = InputAnalyzerNode(mock_llm_response) + assert analyzer is not None + +def test_input_analysis(mock_hospital_state, mock_llm_response): + """Test input analysis functionality""" + analyzer = InputAnalyzerNode(mock_llm_response) + result = analyzer(mock_hospital_state) + + assert "current_task" in result + assert "priority_level" in result + assert isinstance(result["current_task"], TaskType) + assert isinstance(result["priority_level"], PriorityLevel) + +def test_invalid_input_handling(mock_hospital_state): + """Test handling of invalid input""" + analyzer = InputAnalyzerNode(None) + mock_hospital_state["messages"] = [] + + with pytest.raises(ValueError): + analyzer(mock_hospital_state) \ No newline at end of file diff --git a/tests/test_nodes/test_patient_flow.py b/tests/test_nodes/test_patient_flow.py new file mode 100644 index 0000000000000000000000000000000000000000..7223ad615b070e4d6eafbf64f270096bc7516012 --- /dev/null +++ b/tests/test_nodes/test_patient_flow.py @@ -0,0 +1,22 @@ +# tests/test_nodes/test_patient_flow.py +import pytest +from src.nodes.patient_flow import PatientFlowNode + +def test_patient_flow_analysis(mock_hospital_state, mock_llm_response): + """Test patient flow analysis""" + node = PatientFlowNode(mock_llm_response) + result = node(mock_hospital_state) + + assert "analysis" in result + assert "messages" in result + assert "recommendations" in result["analysis"] + +def test_occupancy_calculation(mock_hospital_state): + """Test occupancy calculation logic""" + node = PatientFlowNode(None) + metrics = mock_hospital_state["metrics"]["patient_flow"] + + occupancy = node._calculate_occupancy(metrics) + expected = (metrics["occupied_beds"] / metrics["total_beds"]) * 100 + + assert occupancy == expected \ No newline at end of file diff --git a/tests/test_tools/test_patient_tools.py b/tests/test_tools/test_patient_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..a11784876316227a2ac41bb5ff786e9b885b7b69 --- /dev/null +++ b/tests/test_tools/test_patient_tools.py @@ -0,0 +1,27 @@ +# tests/test_tools/test_patient_tools.py +import pytest +from src.tools.patient_tools import PatientTools + +def test_wait_time_calculation(): + """Test wait time calculation""" + tools = PatientTools() + wait_time = tools.calculate_wait_time("ER", 10, 2) + + assert isinstance(wait_time, float) + assert wait_time > 0 + +def test_bed_capacity_analysis(): + """Test bed capacity analysis""" + tools = PatientTools() + result = tools.analyze_bed_capacity(100, 75, 5) + + assert "utilization_rate" in result + assert "status" in result + assert result["utilization_rate"] == 75.0 + +def test_invalid_capacity_input(): + """Test handling of invalid capacity input""" + tools = PatientTools() + + with pytest.raises(ValueError): + tools.analyze_bed_capacity(0, 10, 5) \ No newline at end of file