# Provenance (residue from the Hugging Face file viewer, kept as a comment):
#   uploaded by oldcai -- "Upload folder using huggingface_hub"
#   commit d7a7846 (verified), file size 6.49 kB
import json
import operator
from pathlib import Path
from dotenv import load_dotenv
from datetime import datetime
from typing import List, Dict, Any, TypedDict, Annotated, Optional
from langgraph.graph import StateGraph, END
from langchain_core.messages import AnyMessage, SystemMessage, ToolMessage
from langchain_core.language_models import BaseLanguageModel
from langchain_core.tools import BaseTool
# Runs at import time. Presumably this loads variables from a local ``.env``
# file into the process environment (python-dotenv) -- confirm against the
# deployment setup. The return value is bound to ``_`` and never used.
_ = load_dotenv()
class ToolCallLog(TypedDict):
    """
    Shape of a single record in the tool-call log.

    Keys:
        timestamp: When the record was created, as a string.
        tool_call_id: Identifier of the tool call this record describes.
        name: Name of the tool that was invoked.
        args: Arguments the tool was invoked with (any type).
        content: Result of the tool call, as text.
    """

    timestamp: str
    tool_call_id: str
    name: str
    args: Any
    content: str
class AgentState(TypedDict):
    """
    State carried between the nodes of the agent graph.

    Keys:
        messages: The conversation so far. The field is annotated with
            ``operator.add`` as its reducer, so a list of messages returned by
            a node is concatenated onto the existing list instead of
            replacing it.
    """

    messages: Annotated[List[AnyMessage], operator.add]
class Agent:
    """
    An agent that alternates between querying a language model and running
    the tools that the model requests.

    The workflow is a two-node graph: ``"process"`` invokes the model and
    ``"execute"`` runs the requested tools. Control loops between the two until
    the model's latest response carries no tool calls.

    Attributes:
        model: The language model with the tools bound via ``bind_tools``.
        tools (Dict[str, BaseTool]): Available tools, keyed by tool name.
        system_prompt (str): The system instructions for the agent.
        workflow: The compiled workflow for the agent's processing.
        log_tools (bool): Whether to log tool calls.
        log_path (Path): Directory where tool call logs are written.

    Note:
        The ``checkpointer`` passed to ``__init__`` is handed to
        ``workflow.compile`` and is not stored on the instance.
    """

    def __init__(
        self,
        model: BaseLanguageModel,
        tools: List[BaseTool],
        checkpointer: Any = None,
        system_prompt: str = "",
        log_tools: bool = True,
        log_dir: Optional[str] = "logs",
    ):
        """
        Initialize the Agent.

        Args:
            model (BaseLanguageModel): The language model to use.
            tools (List[BaseTool]): A list of available tools.
            checkpointer (Any, optional): State persistence manager passed to
                the compiled workflow. Defaults to None.
            system_prompt (str, optional): System instructions. Defaults to "".
            log_tools (bool, optional): Whether to log tool calls. Defaults to True.
            log_dir (str, optional): Directory to save logs. ``None`` or an
                empty string falls back to 'logs'. Defaults to 'logs'.
        """
        self.system_prompt = system_prompt
        self.log_tools = log_tools

        # Always record the log location so that enabling ``log_tools`` after
        # construction does not hit a missing attribute; the directory itself
        # is only created up front when logging is enabled.
        self.log_path = Path(log_dir or "logs")
        if self.log_tools:
            # parents=True so a nested directory such as "out/logs" works.
            self.log_path.mkdir(parents=True, exist_ok=True)

        # Define the agent workflow
        workflow = StateGraph(AgentState)
        workflow.add_node("process", self.process_request)
        workflow.add_node("execute", self.execute_tools)
        workflow.add_conditional_edges(
            "process", self.has_tool_calls, {True: "execute", False: END}
        )
        workflow.add_edge("execute", "process")
        workflow.set_entry_point("process")

        self.workflow = workflow.compile(checkpointer=checkpointer)
        self.tools = {t.name: t for t in tools}
        self.model = model.bind_tools(tools)

    def process_request(self, state: AgentState) -> Dict[str, List[AnyMessage]]:
        """
        Process the request using the language model.

        The system prompt, when non-empty, is prepended to the messages for
        this invocation only; it is not part of the returned state update.

        Args:
            state (AgentState): The current state of the agent.

        Returns:
            Dict[str, List[AnyMessage]]: A dictionary containing the model's response.
        """
        messages = state["messages"]
        if self.system_prompt:
            messages = [SystemMessage(content=self.system_prompt)] + messages
        response = self.model.invoke(messages)
        return {"messages": [response]}

    def has_tool_calls(self, state: AgentState) -> bool:
        """
        Check if the latest message contains any tool calls.

        Args:
            state (AgentState): The current state of the agent.

        Returns:
            bool: True if tool calls exist. False when there are none, when the
            latest message has no ``tool_calls`` attribute, or when the message
            history is empty.
        """
        messages = state["messages"]
        if not messages:
            return False
        # Not every message object exposes ``tool_calls``; treat a missing
        # attribute the same as "no tool calls" instead of raising.
        return bool(getattr(messages[-1], "tool_calls", None))

    def execute_tools(self, state: AgentState) -> Dict[str, List[ToolMessage]]:
        """
        Execute tool calls from the model's response.

        A call naming an unknown tool is not an error: it produces a tool
        message asking the model to retry. Exceptions raised by a tool itself
        are not caught here.

        Args:
            state (AgentState): The current state of the agent.

        Returns:
            Dict[str, List[ToolMessage]]: A dictionary containing tool execution results.
        """
        tool_calls = state["messages"][-1].tool_calls
        results = []
        for call in tool_calls:
            print(f"Executing tool: {call}")
            if call["name"] not in self.tools:
                print("\n....invalid tool....")
                result = "invalid tool, please retry"
            else:
                result = self.tools[call["name"]].invoke(call["args"])

            results.append(
                ToolMessage(
                    tool_call_id=call["id"],
                    name=call["name"],
                    # Carried on the message so _save_tool_calls can log it.
                    args=call["args"],
                    content=str(result),
                )
            )

        self._save_tool_calls(results)
        print("Returning to model processing!")
        return {"messages": results}

    def _save_tool_calls(self, tool_calls: List[ToolMessage]) -> None:
        """
        Save tool calls to a JSON file with timestamp-based naming.

        The file is named ``tool_calls_<YYYYmmdd_HHMMSS>.json``. If that name
        is already taken (several tool rounds within the same second), a
        numeric suffix ``_1``, ``_2``, ... is appended so that an earlier log
        is never overwritten. Does nothing when ``log_tools`` is false.

        Args:
            tool_calls (List[ToolMessage]): List of tool calls to save.
        """
        if not self.log_tools:
            return

        # The directory may be absent if logging was enabled after
        # construction or the directory was removed in the meantime.
        self.log_path.mkdir(parents=True, exist_ok=True)

        logs: List[ToolCallLog] = [
            {
                "tool_call_id": call.tool_call_id,
                "name": call.name,
                "args": call.args,
                "content": call.content,
                "timestamp": datetime.now().isoformat(),
            }
            for call in tool_calls
        ]

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = self.log_path / f"tool_calls_{timestamp}.json"
        collisions = 0
        while True:
            try:
                # Mode "x" creates the file exclusively and fails if it
                # already exists, which is what makes the name unique.
                log_file = open(filename, "x", encoding="utf-8")
                break
            except FileExistsError:
                collisions += 1
                filename = self.log_path / f"tool_calls_{timestamp}_{collisions}.json"

        with log_file:
            json.dump(logs, log_file, indent=4)