App_Simulator / chatbot_simulator.py
jjz5463's picture
update retries
42b2eba
raw
history blame
11.3 kB
from openai import OpenAI
import json_repair
from transformers import AutoTokenizer
from prompts import *
import re
from tenacity import retry, wait_fixed, stop_after_attempt, retry_if_exception_type
from openai import RateLimitError
class ChatbotSimulation:
    """Simulate a user navigating an app whose pages are narrated by an OpenAI chat model.

    On every round the model is (a) asked to update a small user-state dict
    (``current_page``, ``task_completed``, ``back``) after the user's action, and
    (b) given a fresh system prompt for the current page so it can describe the
    page and its available buttons.
    """

    # How many times a non-dict state update is sent back to the model for
    # reformatting before giving up and falling back to an empty dict (whose
    # required keys are then filled with defaults).
    _MAX_FORMAT_ATTEMPTS = 5

    def __init__(self, app_name, site_map, page_details, user_state, system_data, user_data, task, solution,
                 log_location, openai_api_key, agent='human',
                 max_steps=50, max_tokens=8192, buffer_tokens=500,
                 model="gpt-4", temperature=0.7):
        """Set up simulation state, the OpenAI client and the tokenizer.

        Args:
            app_name: Display name of the simulated app (used in prompts/messages).
            site_map: Mapping of page name -> dict; the ``'uid'`` and
                ``'links_to'`` keys are read.
            page_details: Mapping of page uid -> page-detail dict.
            user_state: Initial user-state dict. It is copied (the caller's dict
                is not mutated), then ``current_page``/``task_completed``/``back``
                are initialised.
            system_data, user_data, task, solution: Forwarded to the prompt builders.
            log_location: Stored on the instance; not read within this class.
            openai_api_key: API key for the OpenAI client.
            agent: ``'human'`` or ``'llm'`` (case-insensitive).
            max_steps: Stored on the instance; not read within this class.
            max_tokens: Approximate token budget for the stored conversation.
            buffer_tokens: Completion-length cap per request; twice this amount
                is reserved from ``max_tokens`` when trimming the conversation.
            model: OpenAI chat model name used for every request.
            temperature: Sampling temperature used for every request.

        Raises:
            ValueError: If ``agent`` is not ``'human'`` or ``'llm'``.
        """
        self.app_name = app_name
        self.sitemap = site_map
        self.page_details = page_details
        # Copy so the caller's dict is not modified behind its back.
        self.user_state = dict(user_state)
        self.user_state['current_page'] = 'Home'  # Initialize current page
        self.user_state['task_completed'] = 'False'
        self.user_state['back'] = 'False'
        self.system_data = system_data
        self.user_data = user_data
        self.task = task
        self.solution = solution
        self.log_location = log_location
        self.agent = agent.lower()
        if self.agent not in ['human', 'llm']:
            raise ValueError("Invalid agent type. Expected 'human' or 'llm'.")
        self.max_steps = max_steps
        self.max_tokens = max_tokens
        self.buffer_tokens = buffer_tokens
        self.model = model
        self.temperature = temperature
        self.conversation = []  # Stores recent conversation snippets
        self.prompt_count = 0
        self.client = OpenAI(api_key=openai_api_key)
        self.actions = []
        self.tokenizer = AutoTokenizer.from_pretrained("gpt2", clean_up_tokenization_spaces=True)
        # Navigation stack used to implement the back button.
        self.page_history = ['Home']

    def _current_page(self):
        """Return the page on top of the navigation history, or ``'Home'`` if it is empty."""
        return self.page_history[-1] if self.page_history else "Home"

    def _get_page_uid(self, page_name):
        """Retrieve the UID of the given page from the sitemap (``None`` if unknown)."""
        return self.sitemap.get(page_name, {}).get('uid')

    def _get_page_details(self, page_name):
        """Retrieve the page details using its UID (``{}`` if unknown)."""
        uid = self._get_page_uid(page_name)
        return self.page_details.get(uid, {})

    def _generate_system_prompt(self):
        """Build the system prompt for the current page from the current simulation state."""
        current_page = self._current_page()
        last_page = self.page_history[-2] if len(self.page_history) > 1 else "Home"
        page_info = self._get_page_details(current_page)
        page_linkage = self.sitemap.get(current_page, {}).get('links_to')
        return get_system_prompt(app_name=self.app_name,
                                 system_data=self.system_data,
                                 task=self.task,
                                 user_data=self.user_data,
                                 current_page=current_page,
                                 last_page=last_page,
                                 actions=self.actions,
                                 user_state=self.user_state,
                                 page_info=page_info,
                                 page_linkage=page_linkage
                                 )

    @retry(
        retry=retry_if_exception_type(RateLimitError),
        wait=wait_fixed(5),  # Waits 5 seconds between retries
        stop=stop_after_attempt(50000),  # ~70 hours of retries: effectively "until the rate limit clears"
    )
    def _get_openai_response(self, prompt):
        """Send ``prompt`` (a list of chat messages) to OpenAI and return the reply text.

        Only ``RateLimitError`` is retried (every 5 s, via tenacity); other API
        errors propagate. Note that ``self.conversation`` is trimmed before each
        attempt even when ``prompt`` is a different message list.

        Returns:
            The reply content, or ``""`` if the API returned no content.
        """
        self._trim_conversation()
        response = self.client.chat.completions.create(
            model=self.model,
            messages=prompt,
            max_tokens=self.buffer_tokens,  # completion length cap
            temperature=self.temperature,
        )
        return response.choices[0].message.content or ""

    def _calculate_token_count(self, conversation):
        """Return the total token count of the messages' ``content`` fields.

        Tokens are counted with the GPT-2 tokenizer loaded in ``__init__``, so the
        result only approximates the chat model's own token count.
        """
        return sum(
            len(self.tokenizer.encode(entry['content'], truncation=False, add_special_tokens=False))
            for entry in conversation
        )

    def _trim_conversation(self):
        """Drop the oldest messages until the conversation fits the token budget.

        The budget is ``max_tokens - 2 * buffer_tokens``. If it is not positive,
        the conversation is emptied instead of raising on an empty list.
        """
        limit = self.max_tokens - self.buffer_tokens * 2
        # Count each message once instead of re-tokenizing everything per pop.
        counts = [self._calculate_token_count([entry]) for entry in self.conversation]
        total = sum(counts)
        drop = 0
        while drop < len(counts) and total >= limit:
            total -= counts[drop]
            drop += 1
        if drop:
            del self.conversation[:drop]

    def one_conversation_round(self, user_input):
        """Run one user action through the simulator.

        Validates ``user_input`` against the buttons offered in the last
        assistant message, asks the model for the updated user state, updates
        the navigation history, and asks the model to describe the new page.

        Returns:
            An "Invalid input" message, a "Task completed!" message, or the
            model's description of the next page.
        """
        is_valid, reason = self._is_valid_input(user_input)
        if not is_valid:
            return f"\n{self.app_name}: Invalid input. {reason}"

        self.actions.append(f'{user_input} on {self.user_state["current_page"]} page')
        self.conversation.append({"role": "user", "content": user_input})
        self.prompt_count += 1

        updated_state = self._coerce_state_dict(self._request_state_update(user_input))
        self._fill_required_keys(updated_state)

        if str(updated_state['task_completed']).lower() == 'true':
            return f"Task completed! You took {self.prompt_count} steps."

        self.user_state = updated_state
        self._apply_navigation(updated_state)

        # No need to keep the old system prompt; a new one is generated for the new page.
        self.conversation = [entry for entry in self.conversation if entry["role"] != "system"]
        return self._refresh_page_instructions()

    def _request_state_update(self, user_input):
        """Ask the model for the post-action user state and return the parsed value (may not be a dict)."""
        update_prompt = get_user_state_update_prompt(user_input=user_input,
                                                     current_page=self._current_page(),
                                                     task=self.task,
                                                     solution=self.solution,
                                                     user_state=self.user_state,
                                                     sitemap=self.sitemap)
        # NOTE(review): the update prompt is sent with role "assistant" (kept from the
        # original); confirm that is intended rather than "user"/"system".
        update_message = {"role": "assistant", "content": update_prompt}
        self.conversation.append(update_message)
        try:
            response = self._get_openai_response(self.conversation)
        finally:
            # The update prompt need not stay in the history. Trimming only removes
            # messages from the front, so if it survived it is still the last entry.
            if self.conversation and self.conversation[-1] is update_message:
                self.conversation.pop()
        # The reply is expected to contain "UPDATED" followed by the state; if the
        # marker is missing, parse the whole reply instead of raising IndexError.
        _, marker, state_text = response.partition("UPDATED")
        if not marker:
            state_text = response
        return json_repair.loads(state_text.strip())

    def _coerce_state_dict(self, state):
        """Ask the model to reformat ``state`` until it parses as a dict.

        At most ``_MAX_FORMAT_ATTEMPTS`` reformatting requests are made. Returns
        ``{}`` if ``state`` still is not a dict after that.
        """
        for _ in range(self._MAX_FORMAT_ATTEMPTS):
            if isinstance(state, dict):
                return state
            transform_prompt = (
                f"\nTransform {state} to a properly formatted JSON file.\n"
                "Example Output Format:\n"
                "{\n"
                "'current_page': 'Home',\n"
                "'task_completed': False,\n"
                "'back': False\n"
                "}\n"
            )
            state = json_repair.loads(
                self._get_openai_response([{"role": "system", "content": transform_prompt}])
            )
        return state if isinstance(state, dict) else {}

    def _fill_required_keys(self, state):
        """Add any missing ``current_page``/``task_completed``/``back`` keys to ``state`` in place."""
        state.setdefault('current_page', self._current_page())
        state.setdefault('task_completed', False)
        state.setdefault('back', False)

    def _apply_navigation(self, state):
        """Push ``state['current_page']``, or pop one page when ``back`` is not the string/bool false."""
        if str(state['back']).lower() == 'false':
            self.page_history.append(state['current_page'])
        elif self.page_history:
            self.page_history.pop()

    def _refresh_page_instructions(self):
        """Append a fresh system prompt, get the model's page description, record it and return it."""
        self.conversation.append({"role": "system", "content": self._generate_system_prompt()})
        gpt_instruction = self._get_openai_response(self.conversation)
        self.conversation.append({"role": "assistant", "content": gpt_instruction})
        return gpt_instruction

    def start_conversation(self):
        """Return the welcome greeting followed by the model's description of the Home page."""
        greeting = f'\nWelcome to {self.app_name} simulator! Your task is: {self.task} \n'
        return greeting + self._refresh_page_instructions()

    def _extract_buttons(self):
        """Parse buttons from the latest message if it is from the assistant.

        Looks for the (case-insensitive) phrase "you have the following options:"
        followed by lines like ``1. button name: action_type``.

        Returns:
            ``{button name: action_type}`` (both lower-cased), or ``{}`` if the
            conversation is empty, the last message is not from the assistant,
            or no options section is found.
        """
        if not self.conversation:
            return {}
        last_message = self.conversation[-1]
        if last_message.get("role") != "assistant":
            return {}
        message_content = last_message.get("content") or ""
        options_split = re.split(r"you have the following options:", message_content, flags=re.IGNORECASE)
        if len(options_split) < 2:
            return {}
        # Only the text between the first and second "options" markers is scanned.
        buttons = re.findall(r"\d+\.\s+(.*?):\s+([a-zA-Z_]+)", options_split[1])
        return {name.strip().lower(): action_type.strip().lower() for name, action_type in buttons}

    def _is_valid_input(self, user_input):
        """Validate ``user_input`` as ``action_type(button name)`` or ``type(button name, query)``.

        Any input is accepted when no buttons can be extracted from the last
        assistant message.

        Returns:
            ``[True, message]`` if valid, otherwise ``[False, reason]``.
        """
        valid_buttons = self._extract_buttons()
        if not valid_buttons:
            return [True, "Enter Anything is empty"]
        pattern = r"^(?P<action_type>\w+)\((?P<button_name>[^,]+)(?:,\s*(?P<query>.+))?\)$"
        match = re.match(pattern, user_input)
        if not match:
            return [False, "Your input doesn't match the format: action_type(button name), OR if type, use type(button name, query)"]
        action_type = match.group("action_type").lower()
        button_name = match.group("button_name").strip().lower()
        query = match.group("query")  # Optional query, only allowed for `type`
        # Button name must match exactly (case-insensitive).
        if button_name not in valid_buttons:
            return [False,
                    "Invalid Button name! Recall: Each button is in the format: `number. button name: action_type`"]
        # Action type must match the button's declared type.
        if action_type != valid_buttons[button_name]:
            return [False,
                    "Invalid action type! Recall: Each button is in the format: `number. button name: action_type`"]
        if action_type == "type" and query is None:
            return [False,
                    "Missing Query for action type 'type'! Recall: use the format: `type(button name, query)`"]
        if action_type != "type" and query is not None:
            return [False,
                    "Non-`type` action_type cannot take query!"]
        return [True, 'Pass']