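# NOTE(review): the stray lines "Spaces: / Sleeping / Sleeping" that preceded the
# imports look like page-header text (apparently a Hugging Face Spaces status
# banner) captured together with this file; they are not part of the program.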
import re
from difflib import get_close_matches

import json_repair
from openai import OpenAI
from openai import RateLimitError
from tenacity import retry, wait_fixed, stop_after_attempt, retry_if_exception_type
from transformers import AutoTokenizer

from prompts import *


class ChatbotSimulation:
    """Simulate a user's walk through a web app, using an OpenAI chat model as the app.

    Each round, the user's action is validated against the buttons listed in the
    model's last reply. The model is then asked for the resulting user state, and
    finally describes the page the user lands on.

    State kept between rounds:

    * ``conversation`` -- messages sent to the chat API, trimmed from the front
      so they fit the token budget.
    * ``trajectory`` -- transcript of user inputs and simulator replies recorded
      by ``one_conversation_round``.
    * ``page_history`` -- stack of visited page ids backing the "back" action;
      its last entry is the current page.
    * ``actions`` -- readable log of accepted actions, passed to the system prompt.
    """

    # Retry policy for OpenAI rate-limit errors (see ``_get_openai_response``).
    _RATE_LIMIT_WAIT_SECONDS = 10
    _RATE_LIMIT_MAX_ATTEMPTS = 5
    # Max "reformat this as JSON" follow-up requests before keeping the old state.
    _MAX_STATE_REPAIR_ATTEMPTS = 3
    # Keys every parsed user-state update is guaranteed to contain.
    _REQUIRED_STATE_KEYS = ('current_page', 'task_completed', 'back')

    def __init__(self, app_name, app_description, site_map, relevant_tables_per_page,
                 database, jinjia_prerender_page, task, solution,
                 log_location, openai_api_key, agent='human',
                 max_steps=30, max_tokens=8192, buffer_tokens=500, model="gpt-4"):
        """Set up simulator state, the OpenAI client and the tokenizer.

        Args:
            app_name: Display name used in prompts and messages.
            app_description: Free-text description forwarded to the system prompt.
            site_map: Dict with a ``'pages'`` list whose items are dicts with an
                ``'id'``. The first page is the starting page.
            relevant_tables_per_page: Maps page id -> list of database table names
                shown to the model on that page.
            database: Maps table name -> table contents.
            jinjia_prerender_page: Maps page id -> pre-rendered page content
                (presumably rendered Jinja templates -- confirm with caller).
            task: Task the simulated user has to complete.
            solution: Reference solution forwarded to the state-update prompt.
            log_location: Stored on the instance; not read by this class's methods.
            openai_api_key: API key for the OpenAI client.
            agent: ``'human'`` or ``'llm'`` (case-insensitive).
            max_steps: Stored on the instance; not read by this class's methods.
            max_tokens: Token budget for ``conversation``.
            buffer_tokens: Max tokens per model reply; twice this amount is
                reserved from ``max_tokens`` when trimming.
            model: OpenAI chat model name.

        Raises:
            ValueError: If ``agent`` is neither ``'human'`` nor ``'llm'``.
        """
        self.app_name = app_name
        self.app_description = app_description
        self.sitemap = site_map
        self.relevant_tables_per_page = relevant_tables_per_page
        self.database = database
        self.jinjia_prerender_page = jinjia_prerender_page
        self.task = task
        self.solution = solution
        start_page = self.sitemap['pages'][0]['id']
        self.user_state = {
            'current_page': start_page,
            'task_completed': 'False',
            'logged_in': 'False',
            'back': 'False',
        }
        self.log_location = log_location
        self.agent = agent.lower()
        if self.agent not in ['human', 'llm']:
            raise ValueError("Invalid agent type. Expected 'Human' or 'llm'.")
        self.max_steps = max_steps
        self.max_tokens = max_tokens
        self.buffer_tokens = buffer_tokens
        self.model = model
        self.conversation = []  # Messages sent to the chat API (trimmed to fit max_tokens).
        self.trajectory = [{"role": "system", "content": f"Welcome to {app_name} simulator! Your task is: {task}"}]
        self.prompt_count = 0
        self.client = OpenAI(api_key=openai_api_key)
        self.actions = []
        # GPT-2 tokenizer, used only to approximate message sizes when trimming.
        self.tokenizer = AutoTokenizer.from_pretrained("gpt2", clean_up_tokenization_spaces=True)
        # Back-button stack; starts on the same page as user_state['current_page'].
        self.page_history = [start_page]

    def _current_page(self):
        """Return the page on top of the history stack, or the sitemap's first page if empty."""
        if self.page_history:
            return self.page_history[-1]
        return self.sitemap['pages'][0]['id']

    @staticmethod
    def _is_true(value):
        """Interpret a model-reported flag (a bool or a string like 'True') as a boolean."""
        return str(value).strip().lower() == 'true'

    def _get_relevant_data(self, current_page):
        """Return the subset of the database relevant to ``current_page``.

        Without an exact entry, the closest page id (difflib ratio >= 0.5) is
        used. If nothing is close enough, the whole database is returned. Table
        names that are missing from the database are skipped.
        """
        if current_page in self.relevant_tables_per_page:
            relevant_tables = self.relevant_tables_per_page[current_page]
        else:
            closest_match = get_close_matches(current_page, self.relevant_tables_per_page.keys(), n=1, cutoff=0.5)
            if not closest_match:
                return self.database
            relevant_tables = self.relevant_tables_per_page[closest_match[0]]
        return {table: self.database[table] for table in relevant_tables if table in self.database}

    def _get_prerender_page(self, current_page):
        """Return the pre-rendered content for ``current_page``.

        Without an exact key, the closest page id is used. ``cutoff=0`` lets any
        key qualify, so this raises ``IndexError`` only when the mapping is empty.
        """
        if current_page in self.jinjia_prerender_page:
            return self.jinjia_prerender_page[current_page]
        closest_match = get_close_matches(current_page, self.jinjia_prerender_page.keys(), n=1, cutoff=0)
        return self.jinjia_prerender_page[closest_match[0]]

    def _generate_system_prompt(self):
        """Build the system prompt for the current page, its data and the action log."""
        first_page = self.sitemap['pages'][0]['id']
        current_page = self._current_page()
        last_page = self.page_history[-2] if len(self.page_history) > 1 else first_page
        relevant_database = self._get_relevant_data(current_page)
        # An unknown page id falls back to the full list of sitemap pages.
        relevant_sitemap = next((page for page in self.sitemap["pages"] if page["id"] == current_page),
                                self.sitemap["pages"])
        prerender_page = self._get_prerender_page(current_page)
        return get_system_prompt(app_name=self.app_name,
                                 app_description=self.app_description,
                                 relevant_database=relevant_database,
                                 task=self.task,
                                 current_page=current_page,
                                 last_page=last_page,
                                 actions=self.actions,
                                 sitemap_page=relevant_sitemap,
                                 jinjia_prerender=prerender_page,
                                 )

    def _get_openai_response(self, prompt):
        """Send ``prompt`` (a list of chat messages) to the API and return the reply text.

        ``self.conversation`` is trimmed first to stay within the token budget.
        A ``RateLimitError`` is retried with a fixed wait, via tenacity. After
        ``_RATE_LIMIT_MAX_ATTEMPTS`` failed attempts the last error is re-raised.
        """
        self._trim_conversation()
        create_with_retry = retry(
            retry=retry_if_exception_type(RateLimitError),
            wait=wait_fixed(self._RATE_LIMIT_WAIT_SECONDS),
            stop=stop_after_attempt(self._RATE_LIMIT_MAX_ATTEMPTS),
            reraise=True,
        )(self.client.chat.completions.create)
        response = create_with_retry(
            model=self.model,
            messages=prompt,
            max_tokens=self.buffer_tokens,
            temperature=0.7,
        )
        return response.choices[0].message.content

    def _calculate_token_count(self, conversation):
        """Approximate the total token count of ``conversation``'s message contents.

        Uses the GPT-2 tokenizer loaded in ``__init__``. The chat model's own
        tokenizer may count differently, so treat the result as an estimate.
        """
        total_tokens = 0
        for entry in conversation:
            tokens = self.tokenizer.encode(entry['content'], truncation=False, add_special_tokens=False)
            total_tokens += len(tokens)
        return total_tokens

    def _trim_conversation(self):
        """Drop the oldest messages until ``conversation`` fits the token budget.

        The budget is ``max_tokens - 2 * buffer_tokens``. The most recent message
        is always kept, so the next request is never empty, even when that message
        alone exceeds the budget or the budget is non-positive. The list is
        mutated in place, because callers may hold a reference to it.
        """
        limit = self.max_tokens - self.buffer_tokens * 2
        # Tokenize each message once instead of re-counting the whole list per drop.
        sizes = [self._calculate_token_count([entry]) for entry in self.conversation]
        total = sum(sizes)
        dropped = 0
        while total >= limit and len(self.conversation) - dropped > 1:
            total -= sizes[dropped]
            dropped += 1
        del self.conversation[:dropped]

    def _request_state_update(self, user_input, current_page):
        """Ask the model how ``user_input`` changes the user state, and return the new state.

        The returned dict always contains ``current_page``, ``task_completed``
        and ``back``. Sometimes the reply cannot be parsed into a dict, even after
        ``_MAX_STATE_REPAIR_ATTEMPTS`` reformatting requests. In that case the
        previous state is kept, with the user staying on ``current_page``.
        """
        update_prompt = get_user_state_update_prompt(user_input=user_input,
                                                     current_page=current_page,
                                                     task=self.task,
                                                     database=self.database,
                                                     solution=self.solution,
                                                     user_state=self.user_state,
                                                     sitemap=self.sitemap)
        self.conversation.append({"role": "user", "content": update_prompt})
        try:
            raw_reply = self._get_openai_response(self.conversation) or ""
        finally:
            # The update prompt is a one-off instruction; never leave it in history.
            self.conversation.pop(-1)

        # The new state is expected after an "UPDATED" marker. If the marker is
        # missing, parse the whole reply instead of raising IndexError.
        _, marker, after_marker = raw_reply.partition("UPDATED")
        updated_state = json_repair.loads((after_marker if marker else raw_reply).strip())

        repair_attempts = 0
        while not isinstance(updated_state, dict):
            if repair_attempts >= self._MAX_STATE_REPAIR_ATTEMPTS:
                # Give up instead of looping (and paying for API calls) forever.
                updated_state = dict(self.user_state, current_page=current_page, back=False)
                break
            repair_attempts += 1
            transform_prompt = (
                "\n"
                f"Transform {updated_state} to a properly formatted JSON file.\n"
                "Example Output Format:\n"
                "{\n"
                "'current_page': 'Home',\n"
                "'task_completed': False,\n"
                "'back': False\n"
                "}\n"
            )
            reply = self._get_openai_response([{"role": "system", "content": transform_prompt}])
            updated_state = json_repair.loads(reply or "")

        for key in self._REQUIRED_STATE_KEYS:
            if key not in updated_state:
                updated_state[key] = current_page if key == 'current_page' else False
        return updated_state

    def _apply_navigation(self, updated_state):
        """Push ``updated_state['current_page']`` onto the history, or pop it on a back action.

        Only an explicit true ``back`` flag counts as going back. Going back from
        the first page leaves the history unchanged, so it is never emptied.
        """
        if self._is_true(updated_state['back']):
            if len(self.page_history) > 1:
                self.page_history.pop()
        else:
            self.page_history.append(updated_state['current_page'])

    def _render_current_page(self):
        """Install a fresh system prompt for the current page and return the model's description."""
        # Only the latest system prompt matters; drop stale ones first.
        self.conversation = [entry for entry in self.conversation if entry["role"] != "system"]
        self.conversation.append({"role": "system", "content": self._generate_system_prompt()})
        gpt_instruction = self._get_openai_response(self.conversation)
        self.conversation.append({"role": "assistant", "content": gpt_instruction})
        return gpt_instruction

    def one_conversation_round(self, user_input):
        """Run one simulator turn for ``user_input`` and return the reply text.

        Invalid input (see ``_is_valid_input``) gets an "Invalid input" message
        without calling the model. Otherwise, the model is asked for the updated
        user state. If it reports the task as completed, a completion message is
        returned. If not, the page history is updated and the model describes the
        page the user is now on. Every call counts as one step in ``prompt_count``.
        """
        self.trajectory.append({"role": "user", "content": f'Human: {user_input}'})
        is_valid, reason = self._is_valid_input(user_input)
        if not is_valid:
            self.prompt_count += 1
            invalid_input_message = f"\n{self.app_name}: Invalid input. {reason}"
            self.trajectory.append({"role": "assistant", "content": invalid_input_message})
            return invalid_input_message

        current_page = self._current_page()
        # Log the action against the page the user actually saw (top of the history).
        self.actions.append(f'{user_input} on {current_page} page')
        self.conversation.append({"role": "user", "content": user_input})
        self.prompt_count += 1

        updated_state = self._request_state_update(user_input, current_page)
        if self._is_true(updated_state['task_completed']):
            complete_message = f"{self.app_name}: Task completed! You took {self.prompt_count} steps."
            self.trajectory.append({"role": "assistant", "content": complete_message})
            return complete_message

        self.user_state = updated_state
        self._apply_navigation(updated_state)
        gpt_instruction = self._render_current_page()
        self.trajectory.append({"role": "assistant", "content": gpt_instruction})
        return gpt_instruction

    def start_conversation(self):
        """Describe the starting page and return it, prefixed with a welcome greeting."""
        greeting = f'\nWelcome to {self.app_name} simulator! Your task is: {self.task} \n'
        return greeting + self._render_current_page()

    def _extract_buttons(self):
        """Map button number -> lower-cased action type, parsed from the last assistant message.

        Buttons are read from the text after "you have the following options:"
        (case-insensitive), using lines shaped like ``3. Button name: action_type``.
        Returns ``{}`` when the conversation is empty, the last message is not
        from the assistant, or no options section is present.
        """
        if not self.conversation:
            return {}
        last_message = self.conversation[-1]
        if last_message.get("role") != "assistant":
            return {}
        message_content = last_message.get("content", "")
        options_split = re.split(r"you have the following options:", message_content, flags=re.IGNORECASE)
        if len(options_split) < 2:
            return {}
        button_section = options_split[1]
        pattern = r"(\d+)\.\s+(.*?):\s+([a-zA-Z_]+)"  # number, button name, action type
        buttons = re.findall(pattern, button_section)
        return {number: action_type.strip().lower() for number, _, action_type in buttons}

    def _is_valid_input(self, user_input):
        """Check ``user_input`` against the buttons offered in the last assistant message.

        Expected forms are ``action_type(number)``, or ``text_box(number, query)``
        for text boxes. When no buttons were offered, any input is accepted.

        Returns:
            ``[True, message]`` when valid, else ``[False, explanation]``.
        """
        valid_buttons = self._extract_buttons()
        if valid_buttons == {}:
            return [True, "Enter Anything is empty"]
        pattern = r"^(?P<action_type>\w+)\((?P<button_number>[^,]+)(?:,\s*(?P<query>.+))?\)$"
        match = re.match(pattern, user_input)
        if not match:
            return [False,
                    "Your input doesn't match the format: action_type(button number), OR if text_box, use text_box(button number, query), eg. noop(12). No indent before input and No extra input before or after action_type(button number)!"]
        action_type = match.group("action_type").lower()
        button_name = match.group("button_number").strip().lower()
        query = match.group("query")  # Only text_box actions carry a query.
        if button_name not in valid_buttons:
            return [False,
                    "Invalid Button number! Recall: Each button is in the format: `number. button name: action_type`. Correct example: link(3), text_box(2, query)"]
        if action_type != valid_buttons[button_name]:
            return [False,
                    "Invalid action type! Recall: Each button is in the format: `number. button name: action_type`"]
        if action_type == "text_box" and query is None:
            return [False,
                    "Missing Query for action type 'text_box'! Recall: use the format: `text_box(button number, query)`"]
        if action_type != "text_box" and query is not None:
            return [False,
                    "Non-`text_box` action_type cannot take query!"]
        return [True, 'Pass']