# ibrahimkettaneh's picture
# Upload folder using huggingface_hub
# 790bcc2 verified
from dataclasses import dataclass
import json
from typing import List, Dict, Any, Optional
from openai import OpenAI
"""
EXAMPLE OUTPUT:
****************************************
RUNNING QUERY: What's the weather for Paris, TX in fahrenheit?
Agent Issued Step 1
----------------------------------------
Agent Issued Step 2
----------------------------------------
Agent Issued Step 3
----------------------------------------
AGENT MESSAGE: The current weather in Paris, TX is 85 degrees fahrenheit. It is partly cloudy, with highs in the 90s.
Conversation Complete
****************************************
RUNNING QUERY: Who won the most recent PGA?
Agent Issued Step 1
----------------------------------------
Agent Issued Step 2
----------------------------------------
AGENT MESSAGE: I'm sorry, but I don't have the ability to provide sports information. I can help you with weather and location data. Is there anything else I can assist you with?
Conversation Complete
"""
@dataclass
class WeatherConfig:
    """Settings bundle for the OpenAI-compatible client and the agent loop."""

    # Key handed to the OpenAI client (the VLLM api_key).
    api_key: str = ""
    # Base URL handed to the OpenAI client (the VLLM api_base URL).
    api_base: str = ""
    # Model id to request; None leaves the choice to whoever consumes the config.
    model: Optional[str] = None
    # Upper bound on the number of steps the agent loop may take.
    max_steps: int = 5
class WeatherTools:
    """Collection of available tools/functions for the weather agent.

    Every tool is a static method, so the class is only a namespace. The
    weather and coordinate lookups are mocked with fixed data.
    """

    @staticmethod
    def get_current_weather(latitude: List[float], longitude: List[float], unit: str) -> str:
        """Get the (mocked) weather for the given coordinates.

        Args:
            latitude: Latitude value(s) of the location; unused by the mock.
            longitude: Longitude value(s) of the location; unused by the mock.
            unit: Temperature unit name, echoed verbatim into the report.

        Returns:
            A fixed weather sentence mentioning ``unit``.
        """
        # We are mocking the weather here, but in the real world, you will submit a request here.
        return f"The weather is 85 degrees {unit}. It is partly cloudy, with highs in the 90's."

    @staticmethod
    def get_geo_coordinates(city: str, state: str) -> str:
        """Get the coordinates for a given city from a small built-in table.

        Args:
            city: City name, matched exactly (case-sensitive).
            state: State abbreviation, matched exactly (case-sensitive).

        Returns:
            A sentence with the latitude and longitude. City/state pairs that
            are not in the table are reported as latitude 0, longitude 0.
        """
        coordinates = {
            "Dallas": {"TX": (32.7767, -96.7970)},
            "San Francisco": {"CA": (37.7749, -122.4194)},
            # Paris, TX is west of the prime meridian, so its longitude is
            # negative like the other US entries.
            "Paris": {"TX": (33.6609, -95.5555)},
        }
        lat, lon = coordinates.get(city, {}).get(state, (0, 0))
        # We are mocking the geocoding here, but in the real world, you will submit a request here.
        return f"The coordinates for {city}, {state} are: latitude {lat}, longitude {lon}"

    @staticmethod
    def no_relevant_function(user_query_span: str) -> str:
        """Return the fixed reply used when no other tool fits the query.

        Args:
            user_query_span: The part of the query that could not be handled;
                it is accepted but not used in the reply.
        """
        return "No relevant function for your request was found. We will stop here."

    @staticmethod
    def chat(chat_string: str) -> None:
        """Print ``chat_string`` to stdout as the agent's message to the user."""
        print("AGENT MESSAGE: ", chat_string)
class ToolRegistry:
    """Maps tool names to their callables and to the schemas sent to the model."""

    @staticmethod
    def _describe(name, description, properties, required):
        """Wrap one tool's details in the common function-tool envelope."""
        return {
            "type": "function",
            "function": {
                "name": name,
                "description": description,
                "parameters": {
                    "type": "object",
                    "properties": properties,
                    "required": required,
                },
            },
        }

    @property
    def available_functions(self) -> Dict[str, callable]:
        """Name-to-callable dispatch table, rebuilt on every access."""
        tool_names = (
            "get_current_weather",
            "get_geo_coordinates",
            "no_relevant_function",
            "chat",
        )
        # Each registered name is the name of the matching WeatherTools static method.
        return {tool_name: getattr(WeatherTools, tool_name) for tool_name in tool_names}

    @property
    def tool_schemas(self) -> List[Dict[str, Any]]:
        """Schemas for all four tools, rebuilt on every access."""
        weather = self._describe(
            "get_current_weather",
            "Get the current weather in a given location. Use exact coordinates.",
            {
                "latitude": {"type": "array", "description": "The latitude for the city."},
                "longitude": {"type": "array", "description": "The longitude for the city."},
                "unit": {
                    "type": "string",
                    "description": "The unit to fetch the temperature in",
                    "enum": ["celsius", "fahrenheit"],
                },
            },
            ["latitude", "longitude", "unit"],
        )
        geo = self._describe(
            "get_geo_coordinates",
            "Get the latitude and longitude for a given city",
            {
                "city": {"type": "string", "description": "The city to find coordinates for"},
                "state": {"type": "string", "description": "The two-letter state abbreviation"},
            },
            ["city", "state"],
        )
        fallback = self._describe(
            "no_relevant_function",
            "Call this when no other provided function can be called to answer the user query.",
            {
                "user_query_span": {
                    "type": "string",
                    "description": "The part of the user_query that cannot be answered by any other function calls.",
                },
            },
            ["user_query_span"],
        )
        chat = self._describe(
            "chat",
            "Call this tool when you want to chat with the user. The user won't see anything except for whatever you pass into this function.",
            {
                "chat_string": {
                    "type": "string",
                    "description": "The string to send to the user to chat back to them.",
                },
            },
            ["chat_string"],
        )
        return [weather, geo, fallback, chat]
class WeatherAgent:
    """Main agent class that handles the conversation and tool execution.

    The agent sends the running message history to an OpenAI-compatible chat
    completions endpoint, executes the tool calls that come back, appends
    their results to the history, and repeats until the model stops
    requesting tools or ``config.max_steps`` requests have been made.
    """

    def __init__(self, config: "WeatherConfig"):
        """Create the API client, the tool registry and an empty history.

        Args:
            config: Connection and loop settings. When ``config.model`` is
                falsy it is overwritten in place with the id of the first
                model returned by the server's model listing.
        """
        self.config = config
        self.client = OpenAI(api_key=config.api_key, base_url=config.api_base)
        self.tools = ToolRegistry()
        self.messages = []
        if not config.model:
            models = self.client.models.list()
            self.config.model = models.data[0].id

    def _serialize_tool_call(self, tool_call) -> Dict[str, Any]:
        """Convert a tool call object into a plain, serializable dict."""
        return {
            "id": tool_call.id,
            "type": tool_call.type,
            "function": {
                "name": tool_call.function.name,
                "arguments": tool_call.function.arguments,
            },
        }

    def _execute_tool_call(self, tool_call) -> Any:
        """Run one tool call and return its result, or an error string.

        Problems caused by the model's request (unknown tool name, malformed
        JSON arguments, arguments that do not fit the tool's signature) are
        returned as text instead of raised, so the caller can still answer
        the tool call and let the model correct itself.
        """
        name = tool_call.function.name
        registry = self.tools.available_functions
        handler = registry.get(name)
        if handler is None:
            known = ", ".join(registry)
            return f"Error: unknown function '{name}'. Available functions: {known}"
        try:
            # A missing/empty argument string is treated as "no arguments".
            arguments = json.loads(tool_call.function.arguments or "{}")
        except json.JSONDecodeError as exc:
            return f"Error: arguments for '{name}' are not valid JSON: {exc}"
        if not isinstance(arguments, dict):
            return f"Error: arguments for '{name}' must be a JSON object."
        try:
            return handler(**arguments)
        except TypeError as exc:
            # Missing or unexpected keyword arguments surface as TypeError.
            return f"Error: could not call '{name}': {exc}"

    def process_tool_calls(self, message) -> None:
        """Execute every tool call on ``message`` and record the results.

        One ``tool`` message is appended to ``self.messages`` per tool call,
        including calls that failed, so each call id receives a reply.

        Args:
            message: Assistant message whose ``tool_calls`` is iterable.
        """
        for tool_call in message.tool_calls:
            function_response = self._execute_tool_call(tool_call)
            self.messages.append({
                "role": "tool",
                "content": json.dumps(function_response),
                "tool_call_id": tool_call.id,
                "name": tool_call.function.name,
            })

    def run_conversation(self, initial_query: str) -> None:
        """Run the main conversation loop for a single user query.

        The history is reset to the system prompt plus ``initial_query``.
        Progress is printed to stdout; nothing is returned.

        Args:
            initial_query: The user's question.
        """
        self.messages = [
            {"role": "system", "content": "Make sure to use the chat() function to provide the final answer to the user."},
            {"role": "user", "content": initial_query},
        ]
        print("\n" * 5)
        print("*" * 40)
        print(f"RUNNING QUERY: {initial_query}")
        for step in range(self.config.max_steps):
            response = self.client.chat.completions.create(
                messages=self.messages,
                model=self.config.model,
                tools=self.tools.tool_schemas,
                temperature=0.0,
            )
            message = response.choices[0].message
            if not message.tool_calls:
                print("Conversation Complete")
                break
            print(f"\nAgent Issued Step {step + 1}")
            print("-" * 40)
            self.messages.append({
                "role": "assistant",
                # Keep the assistant text as plain text; JSON-encoding it
                # would replay "null" or an escaped, quoted string to the model.
                "content": message.content or "",
                "tool_calls": [self._serialize_tool_call(tc) for tc in message.tool_calls],
            })
            self.process_tool_calls(message)
        else:
            # Reached only when the loop ran out of steps without a break.
            print("Maximum steps reached")
def main():
    """Run the demo: one weather question, then one the tools cannot answer."""
    agent = WeatherAgent(WeatherConfig())
    demo_queries = (
        # Example usage
        "What's the weather for Paris, TX in fahrenheit?",
        # Example OOD usage
        "Who won the most recent PGA?",
    )
    for query in demo_queries:
        agent.run_conversation(query)


# Run the demo only when this file is executed as a script.
if __name__ == "__main__":
    main()