# Page-scrape residue (not part of the program): "Spaces: Sleeping Sleeping"
import re

import json_repair
from openai import OpenAI
from openai import RateLimitError
from tenacity import retry, wait_fixed, stop_after_attempt, retry_if_exception_type
from transformers import AutoTokenizer

from prompts import *
class ChatbotSimulation:
    """Text-based simulation of navigating an app, with page content produced by an LLM.

    The simulation keeps a page history (for "back" navigation), a rolling
    conversation that is trimmed to a token budget, and a ``user_state`` dict
    that the model is asked to update after every user action.
    """

    # Settings for every chat-completion request issued by this class.
    _MODEL = "gpt-4"
    _TEMPERATURE = 0.7
    # Page assumed when the history is empty or a page name is missing.
    _DEFAULT_PAGE = 'Home'
    # Keys that every user-state dict must contain after an update.
    _REQUIRED_STATE_KEYS = ('current_page', 'task_completed', 'back')
    # Upper bound on "please reformat this as JSON" round-trips, so a model
    # that never returns a JSON object cannot keep the loop running forever.
    _MAX_STATE_REPAIR_ATTEMPTS = 3

    def __init__(self, app_name, site_map, page_details, user_state, system_data, user_data, task, solution,
                 log_location, openai_api_key, agent='human',
                 max_steps=50, max_tokens=8192, buffer_tokens=500):
        """Set up the simulation state, the OpenAI client and the tokenizer.

        Args:
            app_name: Name of the simulated app; used in messages and prompts.
            site_map: Mapping of page name -> dict; the ``'uid'`` and
                ``'links_to'`` entries of each page dict are read.
            page_details: Mapping of page uid -> page detail object.
            user_state: Initial user-state mapping. It is copied, then
                ``current_page``, ``task_completed`` and ``back`` are reset.
            system_data: Passed through to the system prompt builder.
            user_data: Passed through to the system prompt builder.
            task: Task description shown to the user and given to the model.
            solution: Reference solution given to the state-update prompt.
            log_location: Stored on the instance; not used by this class.
            openai_api_key: API key for the OpenAI client.
            agent: ``'human'`` or ``'llm'`` (case-insensitive).
            max_steps: Stored on the instance; not enforced by this class.
            max_tokens: Token budget used when trimming the conversation.
            buffer_tokens: Completion size requested from the model; twice
                this value is reserved when trimming the conversation.

        Raises:
            ValueError: If ``agent`` is neither ``'human'`` nor ``'llm'``.
        """
        self.app_name = app_name
        self.sitemap = site_map
        self.page_details = page_details
        # Work on a copy so the caller's dict is not modified behind its back.
        self.user_state = dict(user_state)
        self.user_state['current_page'] = self._DEFAULT_PAGE  # Initialize current page
        self.user_state['task_completed'] = 'False'
        self.user_state['back'] = 'False'
        self.system_data = system_data
        self.user_data = user_data
        self.task = task
        self.solution = solution
        self.log_location = log_location
        self.agent = agent.lower()
        if self.agent not in ['human', 'llm']:
            raise ValueError("Invalid agent type. Expected 'Human' or 'llm'.")
        self.max_steps = max_steps
        self.max_tokens = max_tokens
        self.buffer_tokens = buffer_tokens
        self.conversation = []  # Stores recent conversation snippets
        self.prompt_count = 0
        self.client = OpenAI(api_key=openai_api_key)
        self.actions = []
        self.tokenizer = AutoTokenizer.from_pretrained("gpt2", clean_up_tokenization_spaces=True)
        # Stack of visited pages; the last element is the current page.
        self.page_history = [self._DEFAULT_PAGE]

    def _current_page(self):
        """Return the page on top of the history, or the default page if it is empty."""
        return self.page_history[-1] if self.page_history else self._DEFAULT_PAGE

    def _get_page_uid(self, page_name):
        """Retrieve the UID of the given page from the sitemap (``None`` if unknown)."""
        return self.sitemap.get(page_name, {}).get('uid')

    def _get_page_details(self, page_name):
        """Retrieve the page details using its UID (empty dict if unknown)."""
        uid = self._get_page_uid(page_name)
        return self.page_details.get(uid, {})

    def _generate_system_prompt(self):
        """Create a dynamic system prompt based on the current state."""
        current_page = self._current_page()
        last_page = self.page_history[-2] if len(self.page_history) > 1 else self._DEFAULT_PAGE
        page_info = self._get_page_details(current_page)
        page_linkage = self.sitemap.get(current_page, {}).get('links_to')
        return get_system_prompt(app_name=self.app_name,
                                 system_data=self.system_data,
                                 task=self.task,
                                 user_data=self.user_data,
                                 current_page=current_page,
                                 last_page=last_page,
                                 actions=self.actions,
                                 user_state=self.user_state,
                                 page_info=page_info,
                                 page_linkage=page_linkage
                                 )

    def _get_openai_response(self, prompt):
        """Trim the stored conversation, then fetch one completion for ``prompt``.

        ``prompt`` is the list of chat messages to send. When it is
        ``self.conversation`` itself, the trimming above applies to it because
        trimming is done in place. No retry policy is applied here; API errors
        propagate to the caller.

        Returns:
            The text content of the first completion choice.
        """
        self._trim_conversation()
        response = self.client.chat.completions.create(
            model=self._MODEL,
            messages=prompt,
            max_tokens=self.buffer_tokens,  # Adjusted max_tokens if needed
            temperature=self._TEMPERATURE,
        )
        return response.choices[0].message.content

    def _calculate_token_count(self, conversation):
        """Count the tokens of every entry's ``content`` with ``self.tokenizer``.

        The tokenizer loaded in ``__init__`` is gpt2 while requests go to
        gpt-4, so the result is an estimate for budgeting, not an exact count.
        """
        return sum(
            len(self.tokenizer.encode(entry['content'], truncation=False, add_special_tokens=False))
            for entry in conversation
        )

    def _trim_conversation(self):
        """Drop the oldest entries until the conversation fits the token budget.

        The budget is ``max_tokens - 2 * buffer_tokens``. Each entry is
        tokenized once instead of re-tokenizing the whole history after every
        removal, and the loop stops when the history is empty so a
        non-positive budget cannot raise ``IndexError``.
        """
        limit = self.max_tokens - self.buffer_tokens * 2
        sizes = [self._calculate_token_count([entry]) for entry in self.conversation]
        total = sum(sizes)
        drop = 0
        while drop < len(sizes) and total >= limit:
            total -= sizes[drop]
            drop += 1
        if drop:
            # Delete in place: callers may hold a reference to this very list.
            del self.conversation[:drop]

    def _render_current_page(self):
        """Ask the model to render the current page and record the exchange."""
        system_prompt = self._generate_system_prompt()
        # GPT generates the page instructions
        self.conversation.append({"role": "system", "content": system_prompt})
        gpt_instruction = self._get_openai_response(self.conversation)
        self.conversation.append({"role": "assistant", "content": gpt_instruction})
        return gpt_instruction

    def _request_state_update(self, user_input):
        """Ask the model for the updated user state and return its raw reply."""
        update_prompt = get_user_state_update_prompt(user_input=user_input,
                                                     current_page=self._current_page(),
                                                     task=self.task,
                                                     solution=self.solution,
                                                     user_state=self.user_state,
                                                     sitemap=self.sitemap)
        update_entry = {"role": "assistant", "content": update_prompt}
        self.conversation.append(update_entry)
        try:
            return self._get_openai_response(self.conversation)
        finally:
            # The update prompt doesn't have to stay in the conversation
            # history; remove it even when the request fails.
            if self.conversation and self.conversation[-1] is update_entry:
                self.conversation.pop()

    def _parse_state(self, reply):
        """Turn the model's state-update reply into a dict holding all required keys."""
        # The state is expected after an "UPDATED" marker; fall back to the
        # whole reply when the marker is missing instead of raising.
        _, marker, tail = (reply or "").partition("UPDATED")
        state = json_repair.loads((tail if marker else (reply or "")).strip())

        attempts = 0
        while not isinstance(state, dict) and attempts < self._MAX_STATE_REPAIR_ATTEMPTS:
            transform_prompt = (
                f"\nTransform {state} to a properly formatted JSON file.\n"
                "Example Output Format:\n"
                "{\n"
                "'current_page': 'Home',\n"
                "'task_completed': False,\n"
                "'back': False\n"
                "}\n"
            )
            state = json_repair.loads(
                self._get_openai_response([{"role": "system", "content": transform_prompt}]) or ""
            )
            attempts += 1
        if not isinstance(state, dict):
            # Give up on the reply; the defaults below keep the user in place.
            state = {}

        # Manually add missing required keys
        for key in self._REQUIRED_STATE_KEYS:
            if key not in state:
                state[key] = self._current_page() if key == 'current_page' else False
        return state

    def one_conversation_round(self, user_input):
        """Conduct one round of conversation between the user and the assistant.

        Validates ``user_input`` against the options of the last rendered
        page, asks the model for the updated user state, updates the page
        history and returns the newly rendered page.

        Returns:
            An "Invalid input" message, a "Task completed" message, or the
            model's instructions for the page the user is now on.
        """
        is_valid, reason = self._is_valid_input(user_input)
        if not is_valid:
            return f"\n{self.app_name}: Invalid input. {reason}"

        self.actions.append(f'{user_input} on {self.user_state["current_page"]} page')
        self.conversation.append({"role": "user", "content": user_input})
        self.prompt_count += 1

        # Update user state using GPT's response
        updated_state = self._parse_state(self._request_state_update(user_input))

        if str(updated_state['task_completed']).lower() == 'true':
            return f"Task completed! You took {self.prompt_count} steps."

        self.user_state = updated_state
        if str(updated_state['back']).lower() == 'false':
            self.page_history.append(updated_state['current_page'])
        elif self.page_history:
            self.page_history.pop()

        # No need to store the old system prompt while we get a new one.
        # Slice-assign to keep the same list object.
        self.conversation[:] = [entry for entry in self.conversation if entry["role"] != "system"]
        return self._render_current_page()

    def start_conversation(self):
        """Render the first page and return it prefixed with a greeting."""
        greeting = f'\nWelcome to {self.app_name} simulator! Your task is: {self.task} \n'
        return greeting + self._render_current_page()

    def _extract_buttons(self):
        """Extract buttons and their action types from the latest assistant message.

        Returns:
            ``{button name: action type}`` (both lower-cased), or an empty
            dict when there is no message yet, the last message is not from
            the assistant, or it contains no "you have the following options:"
            section.
        """
        if not self.conversation:
            return {}
        last_message = self.conversation[-1]
        if last_message.get("role") != "assistant":
            return {}
        message_content = last_message.get("content") or ""
        # Case-insensitive split on the phrase that introduces the options.
        options_split = re.split(r"you have the following options:", message_content, flags=re.IGNORECASE)
        if len(options_split) < 2:
            return {}
        # Each option looks like "<number>. <button name>: <action_type>".
        buttons = re.findall(r"\d+\.\s+(.*?):\s+([a-zA-Z_]+)", options_split[1])
        return {name.strip().lower(): action_type.strip().lower() for name, action_type in buttons}

    def _is_valid_input(self, user_input):
        """Validate user input format against the currently offered buttons.

        Expected format is ``action_type(button name)`` or, for ``type``
        actions, ``type(button name, query)``. Surrounding whitespace is
        ignored.

        Returns:
            ``[ok, message]`` where ``ok`` is a bool and ``message`` explains
            the outcome.
        """
        valid_buttons = self._extract_buttons()
        if not valid_buttons:
            # No options were offered, so any input is accepted.
            return [True, "Enter Anything is empty"]

        pattern = r"^(?P<action_type>\w+)\((?P<button_name>[^,]+)(?:,\s*(?P<query>.+))?\)$"
        match = re.match(pattern, user_input.strip())
        if not match:
            return [False, "Your input doesn't match the format: action_type(button name), OR if type, use type(button name, query)"]

        action_type = match.group("action_type").lower()
        button_name = match.group("button_name").strip().lower()
        query = match.group("query")  # Optional query for `type`

        # Button name must match exactly (case insensitive)
        if button_name not in valid_buttons:
            return [False,
                    "Invalid Button name! Recall: Each button is in the format: `number. button name: action_type`"]
        # Action type must match the button's specified type
        if action_type != valid_buttons[button_name]:
            return [False,
                    "Invalid action type! Recall: Each button is in the format: `number. button name: action_type`"]
        # `type` action requires a query
        if action_type == "type" and query is None:
            return [False,
                    "Missing Query for action type 'type'! Recall: use the format: `type(button name, query)`"]
        # Non-`type` actions must not have a query
        if action_type != "type" and query is not None:
            return [False,
                    "Non-`type` action_type cannot take query!"]
        return [True, 'Pass']