Spaces:
Sleeping
Sleeping
# src/agent.py | |
from typing import Dict, Optional, List | |
import uuid | |
from datetime import datetime | |
from langchain_core.messages import HumanMessage, SystemMessage, AnyMessage | |
from langgraph.graph import StateGraph, END | |
from langchain_openai import ChatOpenAI | |
# NOTE: the BaseCheckpointSaver import below is disabled because it was causing an error.
# from langgraph.checkpoint import BaseCheckpointSaver
from .config.settings import Settings | |
from .models.state import ( | |
HospitalState, | |
create_initial_state, | |
validate_state | |
) | |
from .nodes import ( | |
InputAnalyzerNode, | |
TaskRouterNode, | |
PatientFlowNode, | |
ResourceManagerNode, | |
QualityMonitorNode, | |
StaffSchedulerNode, | |
OutputSynthesizerNode | |
) | |
from .tools import ( | |
PatientTools, | |
ResourceTools, | |
QualityTools, | |
SchedulingTools | |
) | |
from .utils.logger import setup_logger | |
from .utils.error_handlers import ( | |
ErrorHandler, | |
HealthcareError, | |
ValidationError, # Add this import | |
ProcessingError # Add this import | |
) | |
# Module-level logger used by every HealthcareAgent method for error and
# status reporting; configured by the project's setup_logger helper.
logger = setup_logger(__name__)
class HealthcareAgent:
    """Run hospital-operations requests through a LangGraph workflow.

    The workflow built by ``_build_graph`` is:
    ``input_analyzer`` -> ``task_router`` -> one functional node
    (``patient_flow``, ``resource_manager``, ``quality_monitor`` or
    ``staff_scheduler``) -> ``output_synthesizer`` -> ``END``.
    The router may also jump straight to ``output_synthesizer``.

    Per-thread state is kept in the in-memory ``conversation_states`` dict
    (keyed by thread id), which stands in for a LangGraph checkpointer.
    """

    def __init__(self, api_key: Optional[str] = None):
        """Load settings, then create the LLM, tools, nodes and graph.

        Args:
            api_key: Optional OpenAI API key. When truthy it overrides
                ``Settings.OPENAI_API_KEY`` before the settings are validated.

        Raises:
            HealthcareError: If any initialization step fails. An error that
                is already a ``HealthcareError`` (for example the
                ``GRAPH_BUILD_ERROR`` raised by ``_build_graph``) is re-raised
                unchanged; anything else is wrapped with code ``INIT_ERROR``.
        """
        try:
            # Initialize settings and validate
            self.settings = Settings()
            if api_key:
                self.settings.OPENAI_API_KEY = api_key
            self.settings.validate_settings()

            # Initialize LLM
            self.llm = ChatOpenAI(
                model=self.settings.MODEL_NAME,
                temperature=self.settings.MODEL_TEMPERATURE,
                api_key=self.settings.OPENAI_API_KEY,
            )

            # Initialize tools and nodes (nodes need self.llm)
            self.tools = self._initialize_tools()
            self.nodes = self._initialize_nodes()

            # Initialize conversation states (replacing checkpointer)
            self.conversation_states = {}

            # Build graph (needs self.nodes)
            self.graph = self._build_graph()
            logger.info("Healthcare Agent initialized successfully")
        except HealthcareError as e:
            # Already a domain error: keep its specific error code instead of
            # overwriting it with INIT_ERROR.
            logger.error(f"Error initializing Healthcare Agent: {str(e)}")
            raise
        except Exception as e:
            logger.error(f"Error initializing Healthcare Agent: {str(e)}")
            raise HealthcareError(
                message="Failed to initialize Healthcare Agent",
                error_code="INIT_ERROR",
                details={"error": str(e)},
            ) from e

    def _initialize_tools(self) -> Dict:
        """Initialize all tools used by the agent.

        Returns:
            A dict mapping the keys ``patient``, ``resource``, ``quality`` and
            ``scheduling`` to freshly constructed tool objects.
        """
        return {
            "patient": PatientTools(),
            "resource": ResourceTools(),
            "quality": QualityTools(),
            "scheduling": SchedulingTools(),
        }

    def _initialize_nodes(self) -> Dict:
        """Initialize all nodes in the agent workflow.

        Every node except the task router receives the shared LLM.

        Returns:
            A dict mapping graph node names to node objects.
        """
        return {
            "input_analyzer": InputAnalyzerNode(self.llm),
            "task_router": TaskRouterNode(),
            "patient_flow": PatientFlowNode(self.llm),
            "resource_manager": ResourceManagerNode(self.llm),
            "quality_monitor": QualityMonitorNode(self.llm),
            "staff_scheduler": StaffSchedulerNode(self.llm),
            "output_synthesizer": OutputSynthesizerNode(self.llm),
        }

    def _build_graph(self) -> StateGraph:
        """Build the workflow graph with all nodes and edges.

        Returns:
            The result of ``builder.compile()``, i.e. the compiled graph that
            ``process`` later calls ``invoke`` on.

        Raises:
            HealthcareError: With code ``GRAPH_BUILD_ERROR`` if any step of the
                graph construction fails.
        """
        try:
            builder = StateGraph(HospitalState)

            # Register every node under its dict key.
            for name, node in self.nodes.items():
                builder.add_node(name, node)

            # Entry point, then a fixed hop into the router.
            builder.set_entry_point("input_analyzer")
            builder.add_edge("input_analyzer", "task_router")

            def route_next(state: Dict):
                """Return the routing key stored at state["context"]["next_node"]."""
                return state["context"]["next_node"]

            # Routing key (left) -> graph node name (right). The keys differ
            # from the node names for three of the five targets.
            builder.add_conditional_edges(
                "task_router",
                route_next,
                {
                    "patient_flow": "patient_flow",
                    "resource_management": "resource_manager",
                    "quality_monitoring": "quality_monitor",
                    "staff_scheduling": "staff_scheduler",
                    "output_synthesis": "output_synthesizer",
                },
            )

            # Every functional node hands off to the output synthesizer.
            functional_nodes = [
                "patient_flow",
                "resource_manager",
                "quality_monitor",
                "staff_scheduler",
            ]
            for node in functional_nodes:
                builder.add_edge(node, "output_synthesizer")

            # The synthesizer is the last step of the workflow.
            builder.add_edge("output_synthesizer", END)

            return builder.compile()
        except Exception as e:
            logger.error(f"Error building graph: {str(e)}")
            raise HealthcareError(
                message="Failed to build agent workflow graph",
                error_code="GRAPH_BUILD_ERROR",
                details={"error": str(e)},
            ) from e

    def process(
        self,
        input_text: str,
        thread_id: Optional[str] = None,
        context: Optional[Dict] = None
    ) -> Dict:
        """Process input through the healthcare operations workflow.

        Args:
            input_text: The user request; checked by
                ``ErrorHandler.validate_input``.
            thread_id: Conversation identifier. A random UUID4 string is
                generated when omitted or falsy.
            context: Optional entries merged into the initial state's
                ``context`` dict.

        Returns:
            The dict produced by ``_format_response`` (keys ``response``,
            ``analysis``, ``metrics``, ``timestamp``).

        Raises:
            ValidationError: Re-raised unchanged.
            HealthcareError: An existing ``HealthcareError`` is re-raised
                unchanged so its error code survives; any other failure is
                wrapped with code ``PROCESSING_ERROR``.
        """
        try:
            # Validate input
            ErrorHandler.validate_input(input_text)

            # Create or use thread ID
            thread_id = thread_id or str(uuid.uuid4())

            # Fresh state seeded with the user's message
            initial_state = create_initial_state(thread_id)
            initial_state["messages"].append(HumanMessage(content=input_text))

            # Add context if provided
            if context:
                initial_state["context"].update(context)

            validate_state(initial_state)

            # Record the pre-run state so the thread is known even if the
            # graph run below fails.
            self.conversation_states[thread_id] = initial_state

            # Process through graph
            result = self.graph.invoke(initial_state)
            response = self._format_response(result)

            # Replace the stored state with the post-run state so that
            # get_conversation_history also returns what the graph produced,
            # not just the input message. Done only after formatting succeeds,
            # so an invalid result never overwrites the stored state.
            self.conversation_states[thread_id] = result
            return response
        except ValidationError as ve:
            logger.error(f"Validation error: {str(ve)}")
            raise
        except HealthcareError as e:
            # Keep the more specific error code (e.g. FORMAT_ERROR) instead of
            # re-wrapping it as PROCESSING_ERROR.
            logger.error(f"Error processing input: {str(e)}")
            raise
        except Exception as e:
            logger.error(f"Error processing input: {str(e)}")
            raise HealthcareError(
                message="Failed to process input",
                error_code="PROCESSING_ERROR",
                details={"error": str(e)},
            ) from e

    def _format_response(self, result: Dict) -> Dict:
        """Format the final response from the graph execution.

        Args:
            result: The state returned by the graph; must be non-empty and
                contain a ``messages`` key.

        Returns:
            A dict with ``response`` (the ``content`` of the last message, or
            ``""`` when there are no messages), ``analysis`` and ``metrics``
            (each defaulting to ``{}``), and ``timestamp`` (``datetime.now()``).

        Raises:
            HealthcareError: With code ``FORMAT_ERROR`` for unexpected
                failures. The ``ProcessingError`` raised for an invalid result
                propagates unchanged if it is a ``HealthcareError`` subclass
                (not visible from this file); otherwise it is wrapped as
                ``FORMAT_ERROR``.
        """
        try:
            if not result or "messages" not in result:
                raise ProcessingError(
                    message="Invalid result format",
                    error_code="INVALID_RESULT",
                    details={"result": str(result)},
                )

            messages = result["messages"]
            return {
                "response": messages[-1].content if messages else "",
                "analysis": result.get("analysis", {}),
                "metrics": result.get("metrics", {}),
                "timestamp": datetime.now(),
            }
        except HealthcareError as e:
            # Do not bury a specific domain error under FORMAT_ERROR.
            logger.error(f"Error formatting response: {str(e)}")
            raise
        except Exception as e:
            logger.error(f"Error formatting response: {str(e)}")
            raise HealthcareError(
                message="Failed to format response",
                error_code="FORMAT_ERROR",
                details={"error": str(e)},
            ) from e

    def get_conversation_history(
        self,
        thread_id: str
    ) -> List[AnyMessage]:
        """Retrieve conversation history for a specific thread.

        Args:
            thread_id: Conversation identifier used as the key into
                ``conversation_states``.

        Returns:
            A shallow copy of the stored ``messages`` list, so callers cannot
            mutate the agent's internal state by accident. An empty list is
            returned for an unknown thread or a state without messages.

        Raises:
            HealthcareError: With code ``HISTORY_ERROR`` if the lookup fails.
        """
        try:
            state = self.conversation_states.get(thread_id, {})
            return list(state.get("messages") or [])
        except Exception as e:
            logger.error(f"Error retrieving conversation history: {str(e)}")
            raise HealthcareError(
                message="Failed to retrieve conversation history",
                error_code="HISTORY_ERROR",
                details={"error": str(e)},
            ) from e

    def reset_conversation(
        self,
        thread_id: str
    ) -> bool:
        """Reset conversation state for a specific thread.

        The stored state is replaced with a fresh one from
        ``create_initial_state``; an unknown thread id simply gets a new entry.

        Args:
            thread_id: Conversation identifier to reset.

        Returns:
            ``True`` once the state has been replaced.

        Raises:
            HealthcareError: With code ``RESET_ERROR`` if creating the new
                state fails.
        """
        try:
            self.conversation_states[thread_id] = create_initial_state(thread_id)
            return True
        except Exception as e:
            logger.error(f"Error resetting conversation: {str(e)}")
            raise HealthcareError(
                message="Failed to reset conversation",
                error_code="RESET_ERROR",
                details={"error": str(e)},
            ) from e