import dataclasses
from enum import auto, Enum
from typing import List, Tuple
import os


class SeparatorStyle(Enum):
    """How the turns of a conversation are joined into a prompt.

    Values are the consecutive integers 1..6 in declaration order. Only
    SINGLE and LLAMA_2 have rendering support in ``Conversation.get_prompt``;
    the other members are declared but not formatted there.
    """
    SINGLE = 1
    TWO = 2
    MPT = 3
    PLAIN = 4
    LLAMA_2 = 5
    MISTRAL = 6

# video_helper_map = {
#     # 'Chips Making Deal Video' : {'path' : '/data/videos/ChipmakingDeal/sub-videos/', 'prefix' : 'ChipmakingDeal_split'},
#     'Keynote 2023' : {'path' : '/data/videos/PatsKeynote23/sub-videos/', 'prefix' : 'keynotes23_split'},
#     'Intel Behind the Bell' : {'path' : '/data/videos/BehindTheBell/sub-videos/', 'prefix' : 'Behind the Bell Intel_split'},
#     'CEOs Talk' : {'path' : '/data/videos/SamPatTalkAI/sub-videos/', 'prefix' : 'Sam Altman and Pat Gelsinger Talk Artificial Intelligence_split'},
#     'Chips Act Funding Announcement' : {'path' : '/data/videos/IntelChipsFundingAnnounce/sub-videos/', 'prefix' : 'Intel Celebrates CHIPS and Science Act Direct Funding Announcement (Replay)_split'},
#     '22nm-Chip Technology' : {'path' : '/data/videos/MarkBohrExplains22nm/sub-videos/', 'prefix' : 'Video Animation Mark Bohr Gets Small 22nm Explained  Intel_split'},
#     '14nm-Chip Technology' : {'path' : '/data/videos/MarkBohrExplains14nm/sub-videos/', 'prefix' : 'Explanation of Intels 14nm Process_split'},    
# }

# Maps a video title to the folder holding its pre-split sub-video clips and
# the filename prefix shared by those clips. Every folder lives under the same
# root, so only the per-video folder name is listed below.
_VIDEOS_ROOT = "/data1/tile_gh/Multimodal-RAG/videos"

video_helper_map = {
    title: {"path": f"{_VIDEOS_ROOT}/{folder}/sub-videos/", "prefix": prefix}
    for title, folder, prefix in (
        # Disabled: ("Chips Making Deal Video", "ChipmakingDeal", "ChipmakingDeal_split"),
        ("Innovation-2023", "PatsKeynote23", "keynotes23_split"),
        ("Behind-the-Bell-Intel", "BehindTheBell", "Behind the Bell Intel_split"),
        (
            "Foundry-Connect",
            "SamPatTalkAI",
            "Sam Altman and Pat Gelsinger Talk Artificial Intelligence_split",
        ),
        (
            "Chips Act Funding Announcement",
            "IntelChipsFundingAnnounce",
            "Intel Celebrates CHIPS and Science Act Direct Funding Announcement (Replay)_split",
        ),
        (
            "22nm-transistor-animation",
            "MarkBohrExplains22nm",
            # NOTE: the double space before "Intel" is part of the real filename.
            "Video Animation Mark Bohr Gets Small 22nm Explained  Intel_split",
        ),
        (
            "14nm-transistor-animation",
            "MarkBohrExplains14nm",
            "Explanation of Intels 14nm Process_split",
        ),
    )
}

@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history.

    Besides the chat turns it carries optional image context (an image path,
    the title of the video the image belongs to, and a caption). It renders the
    history either into a model prompt (``get_prompt``) or into Gradio chatbot
    pairs (``to_gradio_chatbot``).

    ``append_message`` calls ``messages.append``, so an instance whose
    ``messages`` is a tuple (e.g. built with ``messages=()``) must be
    ``copy()``-ed first; ``copy`` always produces a list.
    """
    system: str  # system prompt; only rendered by the LLAMA_2 style
    roles: List[str]  # (user_role, assistant_role); LLAMA_2 requires turn 0 to use roles[0]
    messages: List[List[str]]  # [role, message] pairs; a None message marks a pending reply
    offset: int  # number of leading messages skipped by to_gradio_chatbot
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "\n"  # SINGLE: appended after each turn; LLAMA_2: placed before each [INST]
    sep2: str = None  # LLAMA_2: appended after each assistant answer
    version: str = "Unknown"
    path_to_img: str = None  # image attached to this conversation, if any
    video_title: str = None  # get_path_to_subvideos looks this up in video_helper_map
    caption: str = None  # injected into the first user message by get_prompt

    skip_next: bool = False  # not read by any method of this class

    def _template_caption(self):
        """Return the caption sentence for the prompt, or "" when there is no caption."""
        out = ""
        if self.caption is not None:
            out = f"The caption associated with the image is '{self.caption}'. "
        return out

    def get_prompt(self):
        """Build the text prompt for the current conversation.

        * If an answered turn exists (``messages[1][1]`` is not None) and the
          first user message does not contain ``<image>`` yet, the prompt is
          built from a copy of the history whose first message is prefixed with
          an ``<image>`` line and the caption sentence (if any).
          ``self.messages`` itself is never modified.
        * If the first reply is still pending (``messages[1][1] is None``), the
          prompt is just the first user message: the bare query used for RAG.
        * Otherwise the history is rendered according to ``sep_style``.

        Raises:
            ValueError: if ``sep_style`` is neither SINGLE nor LLAMA_2.
        """
        messages = self.messages
        # messages[1] is read here, so require two entries. The previous
        # `len(messages) > 0` guard raised IndexError for a single user turn.
        if len(messages) > 1 and messages[1][1] is not None and "<image>" not in messages[0][1]:
            # Build a new list rather than calling .copy(), so tuple histories and
            # tuple message pairs work too.
            messages = list(messages)
            init_role, init_msg = messages[0]
            messages[0] = (init_role, "<image>\n" + self._template_caption() + init_msg)

        if len(messages) > 1 and messages[1][1] is None:
            # Need to do RAG: the prompt is the query only.
            return messages[0][1]

        if self.sep_style == SeparatorStyle.SINGLE:
            parts = []
            for role, message in messages:
                if message:
                    if isinstance(message, tuple):
                        # (text, image, image_process_mode) turn: keep the text,
                        # matching the LLAMA_2 branch below.
                        message, _, _ = message
                    parts.append(role + ": " + message + self.sep)
                else:
                    # Empty/None message: emit just "ROLE:" with no separator.
                    parts.append(role + ":")
            return "".join(parts)

        if self.sep_style == SeparatorStyle.LLAMA_2:
            def wrap_sys(msg):
                return f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if msg else msg

            def wrap_inst(msg):
                return f"[INST] {msg} [/INST]"

            ret = ""
            for i, (role, message) in enumerate(messages):
                if i == 0:
                    assert message, "first message should not be none"
                    assert role == self.roles[0], "first message should come from user"
                if not message:
                    # Pending/empty turns contribute nothing.
                    continue
                if isinstance(message, tuple):
                    message, _, _ = message
                if i == 0:
                    message = wrap_sys(self.system) + message
                if i % 2 == 0:
                    # Even index: user turn, wrapped in [INST] ... [/INST].
                    ret += self.sep + wrap_inst(message)
                else:
                    ret += " " + message + " " + self.sep2
            # NOTE: str.lstrip strips any leading characters found in `sep`
            # (a character set), not the `sep` string as a prefix.
            return ret.lstrip(self.sep)

        raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role, message):
        """Append a ``[role, message]`` pair; ``message`` may be None for a pending reply."""
        self.messages.append([role, message])

    def get_images(self, return_pil=False):
        """Return the images attached to this conversation.

        Args:
            return_pil: accepted for API compatibility but currently ignored.

        Returns:
            ``[path_to_img]`` when an image path is set, otherwise ``[]``.
        """
        return [self.path_to_img] if self.path_to_img is not None else []

    def to_gradio_chatbot(self):
        """Convert the history (starting at ``offset``) into Gradio chatbot pairs.

        Returns a list of ``[user_message, assistant_message]`` lists. Even
        positions open a new pair; odd positions fill in its reply. A user
        message given as a ``(text, image, image_process_mode)`` tuple (with a
        PIL image) is rendered as an inline base64 JPEG ``<img>`` tag followed
        by the text with ``<image>`` removed. The image is resized with its
        aspect ratio preserved so the shorter edge is at most 400 px and the
        longer edge at most 800 px, without going above the original shorter
        edge.
        """
        import base64
        from io import BytesIO

        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if isinstance(msg, tuple):
                    msg, image, _ = msg
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if H > W:
                        H, W = longest_edge, shortest_edge
                    else:
                        H, W = shortest_edge, longest_edge
                    image = image.resize((W, H))
                    if image.mode not in ("RGB", "L"):
                        # JPEG cannot store alpha or palette modes (e.g. RGBA, P);
                        # saving them as-is raises an error in Pillow.
                        image = image.convert("RGB")
                    buffered = BytesIO()
                    image.save(buffered, format="JPEG")
                    img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                    # The bytes are JPEG, so the data URI must say image/jpeg
                    # (it previously claimed image/png).
                    img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
                    msg = img_str + msg.replace('<image>', '').strip()
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        """Return a copy that owns a fresh ``messages`` list.

        Each message pair is rebuilt as a new ``[role, message]`` list (the
        message objects themselves are shared). The image context
        (``path_to_img``, ``video_title``, ``caption``) is carried over too,
        consistent with ``dict()``. ``skip_next`` is reset to its default.
        """
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            version=self.version,
            path_to_img=self.path_to_img,
            video_title=self.video_title,
            caption=self.caption,
        )

    def dict(self):
        """Return a plain-dict snapshot of the state.

        ``sep_style``, ``version`` and ``skip_next`` are not included.
        """
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
            "path_to_img": self.path_to_img,
            "video_title" : self.video_title,
            "caption" : self.caption,
        }

    def get_path_to_subvideos(self):
        """Print ``video_title`` and ``path_to_img`` (debug output) and return None.

        NOTE(review): the early ``return None`` below disables the sub-video
        lookup, so everything after it is currently unreachable. When it is
        enabled, the lookup maps ``video_title`` through ``video_helper_map``
        to ``<path>/<prefix><index>.mp4``. ``<index>`` is the text after the
        last ``_`` in the image filename, with ``.jpg`` removed. If there is no
        title, it falls back to ``path_to_img``.
        """
        print(f"self.video_title {self.video_title}")
        print(f"self.path_to_image {self.path_to_img}")
        return None
        if self.video_title is not None and self.path_to_img is not None:
            info = video_helper_map[self.video_title]
            path = info['path']
            prefix = info['prefix']
            vid_index = self.path_to_img.split('/')[-1]
            vid_index = vid_index.split('_')[-1]
            vid_index = vid_index.replace('.jpg', '')
            ret = f"{prefix}{vid_index}.mp4"
            ret = os.path.join(path, ret)
            return ret
        elif self.path_to_img is not None:
            return self.path_to_img
        return None

# Template for the multimodal RAG flow: turns rendered as "ROLE: message" and
# separated by "\n". The image fields (path_to_img, video_title, caption) keep
# their dataclass default of None.
multimodal_rag = Conversation(
    system="",
    roles=("USER", "ASSISTANT"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.SINGLE,
    sep="\n",
)

# LLAMA_2-style template ("[INST] ... [/INST]" user turns, answers closed by
# "</s>"), registered below under the "llavamed_rag" key.
conv_mistral_instruct = Conversation(
    system="",
    roles=("USER", "ASSISTANT"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.LLAMA_2,
    sep="",
    sep2="</s>",
    version="llama_v2",
)

default_conversation = multimodal_rag

# Name -> template lookup. Templates hold messages=(), so copy() one before
# appending messages to it.
conv_templates = dict(
    default=multimodal_rag,
    multimodal_rag=multimodal_rag,
    llavamed_rag=conv_mistral_instruct,
)


if __name__ == "__main__":
    # Smoke check: the default template has no messages, so this prints an empty prompt.
    print(default_conversation.get_prompt())