DocUA's picture
add .env
2929135
raw
history blame
9.05 kB
# src/agent.py
from typing import Dict, Optional, List
import uuid
from datetime import datetime
from langchain_core.messages import HumanMessage, SystemMessage, AnyMessage
from langgraph.graph import StateGraph, END
from langchain_openai import ChatOpenAI
# Remove this line as it's causing the error
# from langgraph.checkpoint import BaseCheckpointSaver
from .config.settings import Settings
from .models.state import (
HospitalState,
create_initial_state,
validate_state
)
from .nodes import (
InputAnalyzerNode,
TaskRouterNode,
PatientFlowNode,
ResourceManagerNode,
QualityMonitorNode,
StaffSchedulerNode,
OutputSynthesizerNode
)
from .tools import (
PatientTools,
ResourceTools,
QualityTools,
SchedulingTools
)
from .utils.logger import setup_logger
from .utils.error_handlers import (
ErrorHandler,
HealthcareError,
ValidationError, # Add this import
ProcessingError # Add this import
)
logger = setup_logger(__name__)
class HealthcareAgent:
def __init__(self, api_key: Optional[str] = None):
try:
# Initialize settings and validate
self.settings = Settings()
if api_key:
self.settings.OPENAI_API_KEY = api_key
self.settings.validate_settings()
# Initialize LLM
self.llm = ChatOpenAI(
model=self.settings.MODEL_NAME,
temperature=self.settings.MODEL_TEMPERATURE,
api_key=self.settings.OPENAI_API_KEY
)
# Initialize tools
self.tools = self._initialize_tools()
# Initialize nodes
self.nodes = self._initialize_nodes()
# Initialize conversation states (replacing checkpointer)
self.conversation_states = {}
# Build graph
self.graph = self._build_graph()
logger.info("Healthcare Agent initialized successfully")
except Exception as e:
logger.error(f"Error initializing Healthcare Agent: {str(e)}")
raise HealthcareError(
message="Failed to initialize Healthcare Agent",
error_code="INIT_ERROR",
details={"error": str(e)}
)
def _initialize_tools(self) -> Dict:
"""Initialize all tools used by the agent"""
return {
"patient": PatientTools(),
"resource": ResourceTools(),
"quality": QualityTools(),
"scheduling": SchedulingTools()
}
def _initialize_nodes(self) -> Dict:
"""Initialize all nodes in the agent workflow"""
return {
"input_analyzer": InputAnalyzerNode(self.llm),
"task_router": TaskRouterNode(),
"patient_flow": PatientFlowNode(self.llm),
"resource_manager": ResourceManagerNode(self.llm),
"quality_monitor": QualityMonitorNode(self.llm),
"staff_scheduler": StaffSchedulerNode(self.llm),
"output_synthesizer": OutputSynthesizerNode(self.llm)
}
def _build_graph(self) -> StateGraph:
"""Build the workflow graph with all nodes and edges"""
try:
# Initialize graph
builder = StateGraph(HospitalState)
# Add all nodes
for name, node in self.nodes.items():
builder.add_node(name, node)
# Set entry point
builder.set_entry_point("input_analyzer")
# Add edge from input analyzer to task router
builder.add_edge("input_analyzer", "task_router")
# Define conditional routing based on task router output
def route_next(state: Dict):
return state["context"]["next_node"]
# Add conditional edges from task router
builder.add_conditional_edges(
"task_router",
route_next,
{
"patient_flow": "patient_flow",
"resource_management": "resource_manager",
"quality_monitoring": "quality_monitor",
"staff_scheduling": "staff_scheduler",
"output_synthesis": "output_synthesizer"
}
)
# Add edges from functional nodes to output synthesizer
functional_nodes = [
"patient_flow",
"resource_manager",
"quality_monitor",
"staff_scheduler"
]
for node in functional_nodes:
builder.add_edge(node, "output_synthesizer")
# Add end condition
builder.add_edge("output_synthesizer", END)
# Compile graph
return builder.compile()
except Exception as e:
logger.error(f"Error building graph: {str(e)}")
raise HealthcareError(
message="Failed to build agent workflow graph",
error_code="GRAPH_BUILD_ERROR",
details={"error": str(e)}
)
@ErrorHandler.error_decorator
def process(
self,
input_text: str,
thread_id: Optional[str] = None,
context: Optional[Dict] = None
) -> Dict:
"""Process input through the healthcare operations workflow"""
try:
# Validate input
ErrorHandler.validate_input(input_text)
# Create or use thread ID
thread_id = thread_id or str(uuid.uuid4())
# Initialize state
initial_state = create_initial_state(thread_id)
# Add input message as HumanMessage object
initial_state["messages"].append(
HumanMessage(content=input_text)
)
# Add context if provided
if context:
initial_state["context"].update(context)
# Validate state
validate_state(initial_state)
# Store state in conversation states
self.conversation_states[thread_id] = initial_state
# Process through graph
result = self.graph.invoke(initial_state)
return self._format_response(result)
except ValidationError as ve:
logger.error(f"Validation error: {str(ve)}")
raise
except Exception as e:
logger.error(f"Error processing input: {str(e)}")
raise HealthcareError(
message="Failed to process input",
error_code="PROCESSING_ERROR",
details={"error": str(e)}
)
def _format_response(self, result: Dict) -> Dict:
"""Format the final response from the graph execution"""
try:
if not result or "messages" not in result:
raise ProcessingError(
message="Invalid result format",
error_code="INVALID_RESULT",
details={"result": str(result)}
)
return {
"response": result["messages"][-1].content if result["messages"] else "",
"analysis": result.get("analysis", {}),
"metrics": result.get("metrics", {}),
"timestamp": datetime.now()
}
except Exception as e:
logger.error(f"Error formatting response: {str(e)}")
raise HealthcareError(
message="Failed to format response",
error_code="FORMAT_ERROR",
details={"error": str(e)}
)
def get_conversation_history(
self,
thread_id: str
) -> List[Dict]:
"""Retrieve conversation history for a specific thread"""
try:
return self.conversation_states.get(thread_id, {}).get("messages", [])
except Exception as e:
logger.error(f"Error retrieving conversation history: {str(e)}")
raise HealthcareError(
message="Failed to retrieve conversation history",
error_code="HISTORY_ERROR",
details={"error": str(e)}
)
def reset_conversation(
self,
thread_id: str
) -> bool:
"""Reset conversation state for a specific thread"""
try:
self.conversation_states[thread_id] = create_initial_state(thread_id)
return True
except Exception as e:
logger.error(f"Error resetting conversation: {str(e)}")
raise HealthcareError(
message="Failed to reset conversation",
error_code="RESET_ERROR",
details={"error": str(e)}
)