from openai import OpenAI
import json_repair
from transformers import AutoTokenizer
from prompts import *
import re
from tenacity import retry, wait_fixed, stop_after_attempt, retry_if_exception_type
from openai import RateLimitError


class ChatbotSimulation:
    """Text-based simulation of an app, driven by an OpenAI chat model.

    The model plays the app: it describes the current page and lists the
    available buttons ("You have the following options: ..."). The user (a
    human or an LLM agent) answers with an action such as
    ``click(button name)`` or ``type(button name, query)``. After every action
    the model is asked for an updated user state, which drives page navigation
    (``page_history``) and task-completion detection.
    """

    # How many times to ask the model to reformat a state update that did not
    # parse into a dict before giving up and falling back to default values.
    MAX_STATE_REPAIR_ATTEMPTS = 3

    # Keys every user-state dict must carry after an update.
    REQUIRED_STATE_KEYS = ('current_page', 'task_completed', 'back')

    # Header that precedes the numbered button list in assistant messages.
    _OPTIONS_HEADER_RE = re.compile(r"you have the following options:", re.IGNORECASE)
    # One button definition: "<number>. <button name>: <action_type>".
    _BUTTON_RE = re.compile(r"\d+\.\s+(.*?):\s+([a-zA-Z_]+)")
    # A user action: "action_type(button name)" or "type(button name, query)".
    _ACTION_RE = re.compile(r"^(?P<action_type>\w+)\((?P<button_name>[^,]+)(?:,\s*(?P<query>.+))?\)$")

    def __init__(self, app_name, site_map, page_details, user_state, system_data, user_data, task, solution,
                 log_location, openai_api_key, agent='human',
                 max_steps=50, max_tokens=8192, buffer_tokens=500):
        """Set up the simulation state and the OpenAI client.

        Args:
            app_name: Name shown in the greeting and error messages and passed to the prompts.
            site_map: Maps page name -> dict; this class reads its ``'uid'`` and ``'links_to'`` entries.
            page_details: Maps page uid -> page description passed to the system prompt.
            user_state: Initial user-state dict. It is copied (the caller's dict is not
                modified), then ``current_page``, ``task_completed`` and ``back`` are reset.
            system_data, user_data, task: Passed through to the system prompt.
            solution: Passed through to the state-update prompt.
            log_location: Stored on the instance; not used by this class.
            openai_api_key: API key for the OpenAI client.
            agent: ``'human'`` or ``'llm'`` (case-insensitive).
            max_steps: Stored on the instance; not enforced by this class.
            max_tokens: Token budget for the conversation history.
            buffer_tokens: Cap on each model reply; twice this is kept free when trimming history.

        Raises:
            ValueError: If ``agent`` is not ``'human'`` or ``'llm'``.
        """
        self.agent = agent.lower()
        if self.agent not in ('human', 'llm'):
            raise ValueError("Invalid agent type. Expected 'Human' or 'llm'.")

        self.app_name = app_name
        self.sitemap = site_map
        self.page_details = page_details
        # Work on a copy so the caller's dict is not mutated behind their back.
        self.user_state = dict(user_state)
        self.user_state['current_page'] = 'Home'  # every simulation starts on Home
        self.user_state['task_completed'] = 'False'
        self.user_state['back'] = 'False'
        self.system_data = system_data
        self.user_data = user_data
        self.task = task
        self.solution = solution

        self.log_location = log_location
        self.max_steps = max_steps
        self.max_tokens = max_tokens
        self.buffer_tokens = buffer_tokens
        self.conversation = []  # chat messages sent to the model (trimmed from the front)
        self.prompt_count = 0  # number of valid user actions so far
        self.client = OpenAI(api_key=openai_api_key)
        self.actions = []  # human-readable log of actions, fed back into the system prompt
        self.tokenizer = AutoTokenizer.from_pretrained("gpt2", clean_up_tokenization_spaces=True)

        # Navigation stack used for the "back" button; the root entry is never popped.
        self.page_history = ['Home']

    def _current_page(self):
        """Return the page at the top of the navigation history ("Home" if it is empty)."""
        return self.page_history[-1] if self.page_history else "Home"

    @staticmethod
    def _is_true(value):
        """Interpret a model-produced flag (bool or string) as a boolean."""
        return str(value).lower() == 'true'

    def _get_page_uid(self, page_name):
        """Retrieve the UID of the given page from the sitemap (None if unknown)."""
        return self.sitemap.get(page_name, {}).get('uid')

    def _get_page_details(self, page_name):
        """Retrieve the page details using its UID ({} if unknown)."""
        uid = self._get_page_uid(page_name)
        return self.page_details.get(uid, {})

    def _generate_system_prompt(self):
        """Build the system prompt describing the current page, its links and the task."""
        current_page = self._current_page()
        last_page = self.page_history[-2] if len(self.page_history) > 1 else "Home"
        page_info = self._get_page_details(current_page)
        page_linkage = self.sitemap.get(current_page, {}).get('links_to')

        return get_system_prompt(app_name=self.app_name,
                                 system_data=self.system_data,
                                 task=self.task,
                                 user_data=self.user_data,
                                 current_page=current_page,
                                 last_page=last_page,
                                 actions=self.actions,
                                 user_state=self.user_state,
                                 page_info=page_info,
                                 page_linkage=page_linkage
                                 )

    @retry(
        retry=retry_if_exception_type(RateLimitError),
        wait=wait_fixed(5),  # wait 5 seconds between retries
        stop=stop_after_attempt(50000)  # effectively unbounded: up to ~69 hours of retrying
    )
    def _get_openai_response(self, prompt):
        """Send ``prompt`` (a list of chat messages) to GPT-4 and return the reply text.

        Retries on ``RateLimitError``. ``self.conversation`` is trimmed in place
        first, which also trims ``prompt`` when the caller passes
        ``self.conversation`` itself. Returns "" if the reply has no content.
        """
        self._trim_conversation()
        response = self.client.chat.completions.create(
            model="gpt-4",
            messages=prompt,
            # Reply cap; _trim_conversation keeps 2 * buffer_tokens of headroom for it.
            max_tokens=self.buffer_tokens,
            temperature=0.7,
        )
        # message.content may be None; callers always expect a string.
        return response.choices[0].message.content or ""

    def _calculate_token_count(self, conversation):
        """Return the total token count of all message contents in ``conversation``.

        Uses the GPT-2 tokenizer loaded in ``__init__``, so counts may differ
        from the tokenization of the GPT-4 model that is actually called.
        """
        return sum(len(self.tokenizer.encode(entry['content'], truncation=False, add_special_tokens=False))
                   for entry in conversation)

    def _trim_conversation(self):
        """Drop the oldest messages until the history fits the token budget.

        The budget is ``max_tokens - 2 * buffer_tokens``. Each message is
        tokenized once rather than re-tokenizing the whole history per removal.
        Trimming stops once the history is empty.
        """
        limit = self.max_tokens - self.buffer_tokens * 2
        counts = [self._calculate_token_count([entry]) for entry in self.conversation]
        total = sum(counts)
        dropped = 0
        while dropped < len(counts) and total >= limit:
            total -= counts[dropped]
            dropped += 1
        if dropped:
            # Delete in place: callers may hold a reference to self.conversation.
            del self.conversation[:dropped]

    def one_conversation_round(self, user_input):
        """Process one user action and return the app's reply.

        Returns an "Invalid input" message when the action does not match the
        buttons offered in the last assistant message, a "Task completed!"
        message when the model reports the task as done, and otherwise the
        model's description of the (possibly new) current page.
        """
        is_valid, reason = self._is_valid_input(user_input)
        if not is_valid:
            return f"\n{self.app_name}: Invalid input. {reason}"

        self.actions.append(f'{user_input} on {self.user_state["current_page"]} page')
        self.conversation.append({"role": "user", "content": user_input})
        self.prompt_count += 1

        updated_state = self._request_state_update(user_input)

        # Store the new state before returning so it also reflects completion.
        self.user_state = updated_state
        if self._is_true(updated_state['task_completed']):
            return f"Task completed! You took {self.prompt_count} steps."

        self._update_page_history(updated_state)
        return self._respond_with_system_prompt()

    def _request_state_update(self, user_input):
        """Ask the model how ``user_input`` changed the user state.

        Returns a dict that contains at least ``REQUIRED_STATE_KEYS``. Missing
        keys default to the current page (``current_page``) or False.
        """
        update_prompt = get_user_state_update_prompt(user_input=user_input,
                                                     current_page=self._current_page(),
                                                     task=self.task,
                                                     solution=self.solution,
                                                     user_state=self.user_state,
                                                     sitemap=self.sitemap)

        # The update prompt is scaffolding, not dialogue: send it along with the
        # history, then remove it again -- even if the API call raised.
        update_message = {"role": "assistant", "content": update_prompt}
        self.conversation.append(update_message)
        try:
            response = self._get_openai_response(self.conversation)
        finally:
            if self.conversation and self.conversation[-1] is update_message:
                self.conversation.pop()

        # The state JSON follows the "UPDATED" marker; use the whole reply
        # instead of raising IndexError when the model omits the marker.
        _, marker, tail = response.partition("UPDATED")
        updated_state = json_repair.loads((tail if marker else response).strip())
        updated_state = self._coerce_state_to_dict(updated_state)

        for key in self.REQUIRED_STATE_KEYS:
            if key not in updated_state:
                updated_state[key] = self._current_page() if key == 'current_page' else False
        return updated_state

    def _coerce_state_to_dict(self, updated_state):
        """Re-prompt the model until ``updated_state`` parses to a dict.

        Gives up after ``MAX_STATE_REPAIR_ATTEMPTS`` re-prompts and returns an
        empty dict, so the caller fills in defaults instead of looping forever.
        """
        attempts = 0
        while not isinstance(updated_state, dict):
            if attempts >= self.MAX_STATE_REPAIR_ATTEMPTS:
                return {}
            attempts += 1
            transform_prompt = f"""
            Transform {updated_state} to a properly formatted JSON file.
            Example Output Format:
            {{
               'current_page': 'Home',
               'task_completed': False,
               'back': False
            }}
            """
            repaired = self._get_openai_response([{"role": "system", "content": transform_prompt}])
            updated_state = json_repair.loads(repaired)
        return updated_state

    def _update_page_history(self, updated_state):
        """Apply the navigation implied by ``updated_state`` to ``page_history``."""
        if self._is_true(updated_state['back']):
            # Leave the current page, but never drop the root entry.
            if len(self.page_history) > 1:
                self.page_history.pop()
            # Keep user_state consistent with where the history now points.
            updated_state['current_page'] = self._current_page()
        elif updated_state['current_page'] != self._current_page():
            # Only record real page changes: pushing a duplicate when the user
            # stays on a page would make a later "back" stay on that same page.
            self.page_history.append(updated_state['current_page'])

    def _respond_with_system_prompt(self):
        """Refresh the system prompt, ask the model to render the current page, and return its reply."""
        # Only the latest system prompt matters; older ones describe stale state.
        self.conversation = [entry for entry in self.conversation if entry["role"] != "system"]
        self.conversation.append({"role": "system", "content": self._generate_system_prompt()})
        gpt_instruction = self._get_openai_response(self.conversation)
        self.conversation.append({"role": "assistant", "content": gpt_instruction})
        return gpt_instruction

    def start_conversation(self):
        """Greet the user with the task and return the model's rendering of the start page."""
        greeting = f'\nWelcome to {self.app_name} simulator! Your task is: {self.task} \n'
        return greeting + self._respond_with_system_prompt()

    def _extract_buttons(self):
        """Map button name -> action type, parsed from the latest assistant message.

        Looks for lines of the form ``<number>. <button name>: <action_type>``
        after the case-insensitive header "you have the following options:".
        Names and action types are lower-cased. Returns {} when the conversation
        is empty, the last message is not from the assistant, or the header is absent.
        """
        if not self.conversation:
            return {}
        last_message = self.conversation[-1]
        if last_message.get("role") != "assistant":
            return {}

        sections = self._OPTIONS_HEADER_RE.split(last_message.get("content", ""))
        if len(sections) < 2:
            return {}

        # Only the text between the first and second header occurrence is scanned.
        return {name.strip().lower(): action_type.strip().lower()
                for name, action_type in self._BUTTON_RE.findall(sections[1])}

    def _is_valid_input(self, user_input):
        """Check ``user_input`` against the buttons offered in the last assistant message.

        Returns a two-element list ``[is_valid, message]``. Any input is accepted
        when no buttons can be extracted. Otherwise the input must be
        ``action_type(button name)`` (or ``type(button name, query)`` for ``type``
        actions) naming an offered button with its action type, case-insensitively.
        """
        valid_buttons = self._extract_buttons()
        if not valid_buttons:
            return [True, "Enter Anything is empty"]

        match = self._ACTION_RE.match(user_input)
        if not match:
            return [False, "Your input doesn't match the format: action_type(button name), OR if type, use type(button name, query)"]

        action_type = match.group("action_type").lower()
        button_name = match.group("button_name").strip().lower()
        query = match.group("query")  # only present for "name, query" inputs

        if button_name not in valid_buttons:
            return [False,
                    "Invalid Button name! Recall: Each button is in the format: `number. button name: action_type`"]
        if action_type != valid_buttons[button_name]:
            return [False,
                    "Invalid action type! Recall: Each button is in the format: `number. button name: action_type`"]
        if action_type == "type" and query is None:
            return [False,
                    "Missing Query for action type 'type'! Recall: use the format: `type(button name, query)`"]
        if action_type != "type" and query is not None:
            return [False,
                    "Non-`type` action_type cannot take query!"]
        return [True, 'Pass']