import aiohttp
import json
import urllib

from networks import (
    ChathubRequestPayloadConstructor,
    ConversationRequestHeadersConstructor,
    MessageParser,
    OpenaiStreamOutputer,
)
from conversations import ConversationStyle
from utils.logger import logger
from utils.enver import enver


class ConversationConnector:
    """
    Send one prompt over the ChatHub websocket and consume the streamed reply.

    Input params:

    - `conversation_style` (str):
        - One of the `ConversationStyle` values (case-insensitive).
        - Unknown or empty values fall back to `ConversationStyle.PRECISE`.
    - `sec_access_token`, `client_id`, `conversation_id`
        - Generated by `ConversationCreator`
    - `invocation_id` (int):
      - For 1st request, this value must be `0`.
      - For all requests after, any integer is valid.
      - To make it simple, use `1` for all requests after the 1st one.
      - It is incremented each time a streamed turn finishes.
    - `cookies` (dict, optional):
      - Cookies for the aiohttp session used to open the websocket.
    """

    def __init__(
        self,
        conversation_style: str = ConversationStyle.PRECISE.value,
        sec_access_token: str = "",
        client_id: str = "",
        conversation_id: str = "",
        invocation_id: int = 0,
        cookies=None,
    ):
        conversation_style_enum_values = [
            style.value for style in ConversationStyle.__members__.values()
        ]

        # Anything unrecognised (including None / "") falls back to PRECISE.
        requested_style = (conversation_style or "").lower()
        if requested_style in conversation_style_enum_values:
            self.conversation_style = requested_style
        else:
            self.conversation_style = ConversationStyle.PRECISE.value
        print(f"Model: [{self.conversation_style}]")

        self.sec_access_token = sec_access_token
        self.client_id = client_id
        self.conversation_id = conversation_id
        self.invocation_id = invocation_id
        # Fresh dict per instance instead of a shared mutable default.
        self.cookies = {} if cookies is None else cookies

        # Populated by `connect()`; kept as None until then so `close()` is safe.
        self.aiohttp_session = None
        self.wss = None

    async def wss_send(self, message):
        """Send `message` as JSON text terminated by the "\\x1e" record separator."""
        serialized_websocket_message = json.dumps(message, ensure_ascii=False) + "\x1e"
        await self.wss.send_str(serialized_websocket_message)

    async def init_handshake(self):
        """
        Perform the opening handshake on a freshly connected websocket.

        NOTE(review): the message shapes ("protocol": "json", "\\x1e" framing,
        type 6) look like the SignalR JSON hub protocol — confirm upstream.
        """
        await self.wss_send({"protocol": "json", "version": 1})
        # The handshake reply's content is not inspected.
        await self.wss.receive_str()
        await self.wss_send({"type": 6})

    async def connect(self):
        """
        Open the ChatHub websocket (through `enver.proxy`) and run the handshake.

        If connecting or the handshake fails, the websocket and HTTP session are
        closed before the exception is re-raised.
        """
        # A bare `import urllib` does not guarantee `urllib.parse` is loaded.
        from urllib.parse import quote

        self.quotelized_sec_access_token = quote(self.sec_access_token)
        self.ws_url = (
            "wss://sydney.bing.com/sydney/ChatHub"
            f"?sec_access_token={self.quotelized_sec_access_token}"
        )
        headers_constructor = ConversationRequestHeadersConstructor()
        enver.set_envs(proxies=True)
        # Created last so nothing between creation and the try block can leak it.
        self.aiohttp_session = aiohttp.ClientSession(cookies=self.cookies)
        try:
            self.wss = await self.aiohttp_session.ws_connect(
                self.ws_url,
                headers=headers_constructor.request_headers,
                proxy=enver.proxy,
            )
            await self.init_handshake()
        except BaseException:
            await self.close()
            raise

    async def close(self):
        """Close the websocket and the HTTP session if they are open; safe to repeat."""
        if self.wss is not None and not self.wss.closed:
            await self.wss.close()
        if self.aiohttp_session is not None and not self.aiohttp_session.closed:
            await self.aiohttp_session.close()

    async def send_chathub_request(self, prompt: str, system_prompt: str = None):
        """Build the ChatHub request payload for `prompt` and send it."""
        payload_constructor = ChathubRequestPayloadConstructor(
            prompt=prompt,
            conversation_style=self.conversation_style,
            client_id=self.client_id,
            conversation_id=self.conversation_id,
            invocation_id=self.invocation_id,
            system_prompt=system_prompt,
        )
        self.connect_request_payload = payload_constructor.request_payload
        await self.wss_send(self.connect_request_payload)

    async def stream_chat(
        self, prompt: str = "", system_prompt: str = None, yield_output=False
    ):
        """
        Send `prompt`, then read the stream until a type-3 (end) message arrives.

        Behaviour depends on `yield_output`:

        - True: this async generator yields, in order,
            - a "Role" chunk,
            - whatever `MessageParser.parse(..., return_output=True)` returns for
              each type-1 message (lists are flattened, falsy results skipped),
            - a final "Finished" chunk.
        - False: type-1 messages are only passed to `MessageParser.parse` and
          nothing is yielded.

        The websocket and HTTP session are always closed on exit. This includes
        errors (e.g. `receive_str` raising when the server closes the socket)
        and the consumer abandoning the generator early.
        """
        await self.connect()
        try:
            await self.send_chathub_request(prompt=prompt, system_prompt=system_prompt)
            message_parser = MessageParser(outputer=OpenaiStreamOutputer())
            if yield_output:
                yield message_parser.outputer.output(content="", content_type="Role")
            while not self.wss.closed:
                # `receive_str` returns str or raises, so no type check is needed.
                response_lines_str = await self.wss.receive_str()
                # One frame may carry several "\x1e"-separated JSON records.
                for line in response_lines_str.split("\x1e"):
                    if not line:
                        continue
                    data = json.loads(line)
                    message_type = data.get("type")
                    # Stream: Meaningful Messages
                    if message_type == 1:
                        if yield_output:
                            output = message_parser.parse(data, return_output=True)
                            if isinstance(output, list):
                                for item in output:
                                    yield item
                            elif output:
                                yield output
                        else:
                            message_parser.parse(data)
                    # Stream: End of Conversation
                    elif message_type == 3:
                        finished_str = "\n[Finished]"
                        logger.success(finished_str)
                        self.invocation_id += 1
                        await self.close()
                        if yield_output:
                            yield message_parser.outputer.output(
                                content=finished_str, content_type="Finished"
                            )
                        # The socket is now closed, so the while-loop ends too.
                        break
                    # Ignored: type 2 (list of all messages in the whole
                    # conversation), type 6 (heartbeat), and unknown types.
        finally:
            await self.close()