Spaces:
Sleeping
Sleeping
File size: 6,278 Bytes
2929135 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 |
# src/models/state.py
import operator
from datetime import datetime
from enum import Enum
from typing import Annotated, Any, Dict, List, Optional

from langchain_core.messages import AnyMessage
from langchain_core.messages import HumanMessage, SystemMessage, AnyMessage
from typing_extensions import TypedDict # Changed this import
class TaskType(str, Enum):
    """Closed set of task categories a hospital state can be working on.

    The ``str`` mix-in makes every member an instance of ``str`` that
    compares equal to its plain string value.
    """

    # Specialised areas.
    PATIENT_FLOW = "patient_flow"
    RESOURCE_MANAGEMENT = "resource_management"
    QUALITY_MONITORING = "quality_monitoring"
    STAFF_SCHEDULING = "staff_scheduling"

    # Catch-all category.
    GENERAL = "general"
class PriorityLevel(int, Enum):
    """Ordered urgency scale, from ``LOW`` (1) up to ``CRITICAL`` (5).

    The ``int`` mix-in lets members be compared and sorted numerically.
    """

    LOW = 1
    MEDIUM = 2
    HIGH = 3
    URGENT = 4
    CRITICAL = 5
class Department(TypedDict):
    """Snapshot of a single hospital department."""

    # Identity
    id: str
    name: str

    # Load
    capacity: int
    current_occupancy: int

    # Head-count per staff category (string key -> count).
    staff_count: Dict[str, int]

    # NOTE(review): unit of wait_time is not stated here - confirm.
    wait_time: int
class PatientFlowMetrics(TypedDict):
    """Patient-flow figures for the hospital as a whole."""

    # Bed usage
    total_beds: int
    occupied_beds: int

    # Queueing
    waiting_patients: int
    average_wait_time: float  # NOTE(review): unit not stated - confirm

    # Throughput
    # NOTE(review): the time base of both rates is not stated - confirm.
    admission_rate: float
    discharge_rate: float

    # Per-department breakdown; the key is a string chosen by the caller.
    department_metrics: Dict[str, "Department"]
class ResourceMetrics(TypedDict):
    """Resource-management figures: equipment, supplies and requests."""

    # Equipment name -> availability flag.
    equipment_availability: Dict[str, bool]

    # Supply name -> level (NOTE(review): scale not stated - confirm).
    supply_levels: Dict[str, float]

    resource_utilization: float
    pending_requests: int

    # Supplies flagged as critical, listed by name.
    critical_supplies: List[str]
class QualityMetrics(TypedDict):
    """Quality-monitoring figures: satisfaction, outcomes and audits."""

    # NOTE(review): scoring scales are not stated here - confirm.
    patient_satisfaction: float
    compliance_rate: float

    # Named outcome / score -> value.
    care_outcomes: Dict[str, float]
    quality_scores: Dict[str, float]

    incident_count: int
    last_audit_date: datetime
class StaffingMetrics(TypedDict):
    """Staff-scheduling figures: head-count, coverage and workload."""

    total_staff: int

    # Staff category -> number currently available.
    available_staff: Dict[str, int]

    # Shift name -> coverage (NOTE(review): scale not stated - confirm).
    shifts_coverage: Dict[str, float]

    overtime_hours: float
    skill_mix_index: float
    staff_satisfaction: float
class HospitalMetrics(TypedDict):
    """Top-level metrics container: one entry per category plus a timestamp."""

    # One nested mapping per metrics category.
    patient_flow: PatientFlowMetrics
    resources: ResourceMetrics
    quality: QualityMetrics
    staffing: StaffingMetrics

    # Refreshed by update_state_metrics whenever a category is modified.
    last_updated: datetime
class AnalysisResult(TypedDict):
    """Outcome of one analysis step produced by a node."""

    # Classification of the analysed task.
    category: TaskType
    priority: PriorityLevel

    # Free-text observations and suggested follow-ups.
    findings: List[str]
    recommendations: List[str]

    # Each action item is a string-to-string mapping.
    action_items: List[Dict[str, str]]

    # Metric name -> numeric impact.
    metrics_impact: Dict[str, float]
class HospitalState(TypedDict):
    """Main state management for the agent.

    This is the dictionary shape built by ``create_initial_state`` and
    checked by ``validate_state``.
    """

    # Message history.  The ``operator.add`` metadata is presumably read by
    # the graph framework as a reducer that concatenates message lists -
    # TODO confirm against the framework in use.
    messages: Annotated[List[AnyMessage], operator.add]

    # What is being worked on and how urgent it is.
    current_task: TaskType
    priority_level: PriorityLevel
    department: Optional[str]  # None when no department is selected

    metrics: HospitalMetrics
    analysis: Optional[AnalysisResult]  # None until an analysis exists

    # Free-form values keyed by string.  Typed with ``typing.Any``; the
    # previous ``any`` was the builtin function, not a type.
    context: Dict[str, Any]  # Will include routing information

    timestamp: datetime
    thread_id: str
def create_initial_state(thread_id: str) -> HospitalState:
    """Build a fresh ``HospitalState`` filled with default values.

    Args:
        thread_id: Identifier stored under the ``"thread_id"`` key.

    Returns:
        A new state dict with no messages, ``TaskType.GENERAL`` as the
        task, ``PriorityLevel.MEDIUM`` as the priority, no department, no
        analysis, and hard-coded placeholder metrics.  Every call builds
        new nested containers, so two states never share mutable data.
    """
    # Read the clock once so "timestamp", "last_updated" and
    # "last_audit_date" all describe the same instant instead of drifting
    # apart between successive datetime.now() calls.
    now = datetime.now()

    patient_flow = {
        "total_beds": 300,
        "occupied_beds": 240,
        "waiting_patients": 15,
        "average_wait_time": 35.0,
        "admission_rate": 4.2,
        "discharge_rate": 3.8,
        "department_metrics": {},
    }
    resources = {
        "equipment_availability": {},
        "supply_levels": {},
        "resource_utilization": 0.75,
        "pending_requests": 5,
        "critical_supplies": [],
    }
    quality = {
        "patient_satisfaction": 8.5,
        "care_outcomes": {},
        "compliance_rate": 0.95,
        "incident_count": 2,
        "quality_scores": {},
        "last_audit_date": now,
    }
    staffing = {
        "total_staff": 500,
        "available_staff": {
            "doctors": 50,
            "nurses": 150,
            "specialists": 30,
            "support": 70,
        },
        "shifts_coverage": {},
        "overtime_hours": 120.5,
        "skill_mix_index": 0.85,
        "staff_satisfaction": 7.8,
    }

    return {
        "messages": [],
        "current_task": TaskType.GENERAL,
        "priority_level": PriorityLevel.MEDIUM,
        "department": None,
        "metrics": {
            "patient_flow": patient_flow,
            "resources": resources,
            "quality": quality,
            "staffing": staffing,
            "last_updated": now,
        },
        "analysis": None,
        "context": {
            "next_node": None,  # routing context; no next node chosen yet
        },
        "timestamp": now,
        "thread_id": thread_id,
    }
def validate_state(state: HospitalState) -> bool:
    """Validate the structure and basic value types of ``state``.

    Checks, in order:
      * the keys ``messages``, ``current_task``, ``priority_level``,
        ``metrics``, ``timestamp`` and ``thread_id`` are present;
      * ``messages`` is a list whose items all have a ``content`` attribute;
      * ``current_task`` is a ``TaskType``, ``priority_level`` is a
        ``PriorityLevel`` and ``timestamp`` is a ``datetime``.

    ``department``, ``analysis``, ``context`` and the contents of
    ``metrics`` are not inspected.

    Args:
        state: The state mapping to check.

    Returns:
        ``True`` when every check passes.  The function never returns
        ``False``; failures are signalled by raising.

    Raises:
        ValueError: If a check fails, or if any other exception occurs
            while inspecting ``state`` (for example when it is not a
            mapping).  The message is ``"State validation failed: "``
            followed by the underlying error text, and the underlying
            exception is attached as ``__cause__``.
    """
    try:
        # Basic structure validation
        required_keys = [
            "messages", "current_task", "priority_level",
            "metrics", "timestamp", "thread_id"
        ]
        for key in required_keys:
            if key not in state:
                raise ValueError(f"Missing required key: {key}")

        # Validate messages
        if not isinstance(state["messages"], list):
            raise ValueError("Messages must be a list")

        # Validate each message has required attributes
        for msg in state["messages"]:
            if not hasattr(msg, 'content'):
                raise ValueError("Invalid message format - missing content")

        # Validate types
        if not isinstance(state["current_task"], TaskType):
            raise ValueError("Invalid task type")
        if not isinstance(state["priority_level"], PriorityLevel):
            raise ValueError("Invalid priority level")
        if not isinstance(state["timestamp"], datetime):
            raise ValueError("Invalid timestamp")
    except Exception as e:
        # Single failure type for callers; chain explicitly so the root
        # cause stays visible in tracebacks.
        raise ValueError(f"State validation failed: {str(e)}") from e
    return True
def update_state_metrics(
    state: HospitalState,
    new_metrics: Dict,
    category: str
) -> HospitalState:
    """Merge ``new_metrics`` into one metrics category of ``state``.

    The update happens in place: ``state["metrics"][category]`` is updated
    with ``new_metrics`` and ``state["metrics"]["last_updated"]`` is set to
    the current time.

    Args:
        state: State whose metrics are modified.
        new_metrics: Key/value pairs to merge into the chosen category.
        category: Name of a key in ``state["metrics"]`` that holds a dict.

    Returns:
        The same ``state`` object that was passed in.

    Raises:
        ValueError: If ``category`` is not a key of ``state["metrics"]``,
            or names an entry that is not a dict (e.g. ``"last_updated"``).
    """
    metrics = state["metrics"]
    if category not in metrics:
        raise ValueError(f"Invalid metrics category: {category}")

    target = metrics[category]
    # "last_updated" lives next to the categories but holds a datetime;
    # reject it as invalid instead of failing with AttributeError on
    # .update().
    if not isinstance(target, dict):
        raise ValueError(f"Invalid metrics category: {category}")

    target.update(new_metrics)
    metrics["last_updated"] = datetime.now()
    return state