from transformers import AutoProcessor
from PIL import Image
import os
import torch
import pickle

## ACTUAL INPUT CONSTRUCTION
# Token count of the speaker "base" prompt, i.e. everything in a speaker
# sequence that is neither padding nor caption. create_speaker_caption_mask
# subtracts this value (plus the number of id-0 padding tokens) from the
# sequence length to decide how many trailing tokens belong to the caption.
# NOTE(review): presumably measured for the "information_after" template with
# the get_processor() settings — recompute if the prompt text or the processor
# configuration changes.
BASE_SPEAKER_LEN = 787

def joint_listener_input(processor, context_images, description, device):
    """Build paired listener and speaker model inputs for one image context.

    The listener batch holds a single prompt whose assistant turn answers
    "Image 0". The speaker batch holds one prompt per target index 0-9, and
    every one of them uses ``description`` as the caption.

    Args:
        processor: chat-template-capable processor (see ``get_processor``).
        context_images: image file names, resolved under ``tangram_pngs``.
        description: caption text shared by the listener and the speakers.
        device: device the processed tensors are moved to.

    Returns:
        Tuple ``(images, l_input_tokens, l_attn_mask, l_image_attn_mask,
        s_input_tokens, s_attn_mask, s_image_attn_mask, s_target_mask,
        s_target_tokens)``. The five speaker tensors gain a leading axis of
        size 1 via ``unsqueeze(0)``.
    """
    raw_images = process_images("tangram_pngs", context_images)

    # ---- Listener: a single prompt whose answer is fixed to index 0 ----
    listener_prompt = construct_listener_full_prompt(
        processor, description.lower(), 0, "verbose_instruction"
    )
    listener_batch = processor(
        text=[listener_prompt],
        images=[raw_images],
        return_tensors="pt"
    ).to(device)
    # Drop the final two tokens. NOTE(review): presumably the answer index and
    # the end-of-turn marker, so the model has to predict them — confirm.
    l_input_tokens = listener_batch['input_ids'][:, :-2]
    l_attn_mask = listener_batch['attention_mask'][:, :-2]
    # Token id 0 is treated as padding: never attend to it.
    l_attn_mask[(l_input_tokens == 0).bool()] = 0
    images = listener_batch['pixel_values']
    l_image_attn_mask = listener_batch['pixel_attention_mask']

    # ---- Speaker: same caption, one prompt per possible target index ----
    num_targets = 10
    speaker_prompts = [
        construct_speaker_full_prompt(processor, description, target, "information_after")
        for target in range(num_targets)
    ]
    speaker_batch = processor(
        text=speaker_prompts,
        images=[raw_images] * num_targets,
        padding='longest',
        return_tensors="pt"
    ).to(device)
    full_ids = speaker_batch['input_ids']

    # Next-token setup: inputs drop the last token, targets drop the first.
    s_input_tokens = full_ids[:, :-1]
    s_target_tokens = full_ids[:, 1:]
    s_attn_mask = speaker_batch['attention_mask'][:, :-1]
    s_attn_mask[(s_input_tokens == 0).bool()] = 0
    s_image_attn_mask = speaker_batch['pixel_attention_mask']
    s_target_mask = torch.stack(
        [create_speaker_caption_mask(full_ids[row], s_attn_mask[row]) for row in range(num_targets)],
        dim=0,
    )

    return (
        images, l_input_tokens, l_attn_mask, l_image_attn_mask,
        s_input_tokens.unsqueeze(0), s_attn_mask.unsqueeze(0),
        s_image_attn_mask.unsqueeze(0), s_target_mask.unsqueeze(0),
        s_target_tokens.unsqueeze(0),
    )

def joint_speaker_input(processor, image_paths, target_path, device):
    """Build the speaker generation input for one context and target.

    Args:
        processor: chat-template-capable processor (see ``get_processor``).
        image_paths: image file names, resolved under ``tangram_pngs``.
        target_path: the entry of ``image_paths`` the speaker must describe.
        device: device the processed tensors are moved to.

    Returns:
        Tuple ``(input_tokens, attn_mask, images, image_attn_mask, target)``.
        ``target`` is a one-element ``LongTensor`` holding the index of
        ``target_path`` in ``image_paths``.

    Raises:
        ValueError: from ``list.index`` if ``target_path`` is not in
            ``image_paths``.
    """
    raw_images = process_images("tangram_pngs", image_paths)
    target_idx = image_paths.index(target_path)
    # Templated user turn only, ending in the assistant generation prompt.
    base_prompt = construct_speaker_base_prompt(
        processor, target_idx, "information_after", process=True
    )

    batch = processor(
        text=[base_prompt],
        images=[raw_images],
        return_tensors="pt"
    ).to(device)

    input_tokens = batch['input_ids']
    attn_mask = batch['attention_mask']
    # Token id 0 is treated as padding: never attend to it.
    attn_mask[(input_tokens == 0).bool()] = 0

    target = torch.LongTensor([target_idx]).to(device)
    return input_tokens, attn_mask, batch['pixel_values'], batch['pixel_attention_mask'], target


## UTILITIES

def get_processor():
    """Load the Idefics2-8b processor used throughout this module.

    Image splitting is disabled, and images are resized within a
    224 (shortest edge) / 448 (longest edge) pixel bound.
    """
    image_size = {"longest_edge": 448, "shortest_edge": 224}
    return AutoProcessor.from_pretrained(
        "HuggingFaceM4/idefics2-8b",
        do_image_splitting=False,
        size=image_size,
    )

def get_index_to_token():
    """Unpickle and return the object stored in ``index_to_token.pkl``.

    The path is relative to the current working directory.
    """
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open("index_to_token.pkl", "rb") as handle:
        return pickle.load(handle)

def process_images(img_dir, context_images):
    """Load each image named in ``context_images`` from ``img_dir`` as RGB.

    Args:
        img_dir: directory that the image file names are joined onto.
        context_images: iterable of image file names, in order.

    Returns:
        List of RGB ``PIL.Image`` objects, in the same order as
        ``context_images``.
    """
    raw_images = []
    for img in context_images:
        image_path = os.path.join(img_dir, img)
        # Open through a context manager so the file handle is closed
        # deterministically. ``convert`` returns a new, fully loaded image
        # that does not depend on the source file staying open.
        with Image.open(image_path) as opened:
            raw_images.append(opened.convert('RGB'))
    return raw_images

def create_speaker_caption_mask(all_token_ids, text_mask, base_len=None):
    """Mark the trailing caption positions of one speaker sequence.

    The sequence is taken to be laid out as padding + base prompt + caption.
    Padding tokens are those with id 0, and the base prompt is ``base_len``
    tokens long. The caption length is therefore
    ``len(all_token_ids) - (num_padding + base_len)``.

    Args:
        all_token_ids: 1-D tensor of token ids for the full sequence.
        text_mask: tensor whose shape and dtype the result copies. Only its
            last ``caption_tokens`` positions are set.
        base_len: length of the base prompt. Defaults to ``BASE_SPEAKER_LEN``.

    Returns:
        Bool tensor shaped like ``text_mask``. It is True on the last
        ``caption_tokens`` positions and all False when that count is <= 0.
    """
    if base_len is None:
        base_len = BASE_SPEAKER_LEN

    # Overall token composition: pad + base + caption
    padding_tokens = torch.sum(all_token_ids == 0).item()
    caption_tokens = all_token_ids.shape[0] - (padding_tokens + base_len)

    target_mask = torch.zeros_like(text_mask)
    # Guard: ``x[-0:]`` is ``x[0:]`` (the whole tensor), so an empty or
    # negative caption count would otherwise mark every position.
    if caption_tokens > 0:
        target_mask[-caption_tokens:] = 1

    return target_mask.bool()

def construct_listener_full_prompt(processor, target_anno, target_idx, comprehension_prompt_type="verbose_instruction"):
    """Build the full listener chat prompt (user question plus assistant answer).

    The user turn introduces 10 image slots labelled "Image 0".."Image 9",
    followed by the caption and a multiple-choice question. The assistant turn
    answers with ``target_idx``.

    Args:
        processor: object exposing ``apply_chat_template``.
        target_anno: caption text. It is lower-cased and stripped before use.
        target_idx: index written into the assistant's answer.
        comprehension_prompt_type: template name. Only
            ``"verbose_instruction"`` is supported.

    Returns:
        The templated prompt string with surrounding whitespace stripped.

    Raises:
        ValueError: if ``comprehension_prompt_type`` is not supported.
    """
    # Validate explicitly: a bare ``assert`` is stripped under ``python -O``,
    # which would silently template an empty conversation.
    if comprehension_prompt_type != "verbose_instruction":
        raise ValueError(
            f"Unsupported comprehension_prompt_type: {comprehension_prompt_type!r}"
        )

    target_anno = target_anno.lower().strip()

    # User side: intro
    user_content = [
        {"type" : "text", "text" : "You will be presented with a sequence of 10 images and a caption describing exactly one of them. "},
        {"type" : "text", "text" : "Your task is to guess which image the caption describes. "},
    ]

    # User side: labelled image slots; only the first label lacks a comma.
    for i in range(10):
        separator = " " if i == 0 else ", "
        user_content.append({"type" : "text", "text" : f"{separator}Image {i}: "})
        user_content.append({"type" : "image"})

    # User side: caption and question
    user_content.append({"type" : "text", "text" : f". Caption: {target_anno}"})
    user_content.append({"type" : "text", "text" : " Does this caption describe Image 0, 1, 2, 3, 4, 5, 6, 7, 8 or 9?"})

    messages = [
        {"role" : "user", "content" : user_content},
        # Model side: guess
        {
            "role" : "assistant",
            "content" : [
                {"type" : "text", "text" : f"The caption describes Image {target_idx}"}
            ]
        },
    ]

    return processor.apply_chat_template(messages, add_generation_prompt=False).strip()

def construct_speaker_full_prompt(processor, target_anno, target_idx,
                                  generation_prompt_type="information_after"):
    """Build the full speaker chat prompt: the base user turn plus the caption.

    Args:
        processor: object exposing ``apply_chat_template``.
        target_anno: caption used as the assistant's reply. It is lower-cased
            and stripped.
        target_idx: index of the target image announced in the user turn.
        generation_prompt_type: template name forwarded to
            ``construct_speaker_base_prompt``.

    Returns:
        The templated prompt string with surrounding whitespace stripped.
    """
    conversation = construct_speaker_base_prompt(processor, target_idx, generation_prompt_type)

    caption = target_anno.lower().strip()
    conversation.append({
        "role" : "assistant",
        "content" : [{"type" : "text", "text" : caption}],
    })

    templated = processor.apply_chat_template(conversation, add_generation_prompt=False)
    return templated.strip()

def construct_speaker_base_prompt(processor, target_idx, generation_prompt_type="information_after", process=False):
    """Build the speaker's user turn: instructions, 10 image slots, target index.

    Args:
        processor: object exposing ``apply_chat_template``. It is only used
            when ``process`` is true.
        target_idx: index of the image the speaker must describe.
        generation_prompt_type: template name. Only ``"information_after"``
            is supported.
        process: if true, return the chat-templated string with the
            assistant generation prompt appended. Otherwise return the raw
            message list.

    Returns:
        Either the stripped prompt string (``process=True``) or a one-element
        list containing the user message dict.

    Raises:
        ValueError: if ``generation_prompt_type`` is not supported.
    """
    # Validate explicitly: a bare ``assert`` is stripped under ``python -O``,
    # which would silently return an empty conversation.
    if generation_prompt_type != "information_after":
        raise ValueError(
            f"Unsupported generation_prompt_type: {generation_prompt_type!r}"
        )

    # User side: intro
    content = [
        {"type" : "text", "text" : "You will be presented with a sequence of 10 images and be assigned a target image. "},
        {"type" : "text", "text" : "Your task is to produce a caption for your target image such that anyone could guess the image from your description. "},
    ]

    # User side: labelled image slots; only the first label lacks a comma.
    for i in range(10):
        separator = " " if i == 0 else ", "
        content.append({"type" : "text", "text" : f"{separator}Image {i}: "})
        content.append({"type" : "image"})

    # User side: target assignment
    content.append({"type" : "text", "text" : f". Your target image is Image {target_idx}. Produce your caption now."})

    messages = [{"role" : "user", "content" : content}]

    if process:
        return processor.apply_chat_template(messages, add_generation_prompt=True).strip()
    return messages

def process_idefics_listener_generation_input(speaker_context, captions, processor, img_dir, num_samples, device):
    """Batch listener prompts for ``num_samples`` captions per speaker context.

    Args:
        speaker_context: sequence of contexts, each a list of image file names.
        captions: flat caption list, grouped ``num_samples`` per context.
        processor: chat-template-capable processor (see ``get_processor``).
        img_dir: directory holding the image files.
        num_samples: number of captions per context.
        device: device the resulting tensors are moved to.

    Returns:
        Tuple ``(input_tokens, attn_mask, images, image_attn_mask)``.
    """
    prompts, raw_images = get_listener_generation_prompts(
        speaker_context, captions, num_samples, img_dir, processor
    )

    batch = processor(
        text=prompts,
        images=raw_images,
        padding='longest',
        return_tensors='pt'
    )

    def _trim_last_two(key):
        # Drop the final two tokens of each row before moving to the device.
        return batch[key][:, :-2].to(device)

    input_tokens = _trim_last_two('input_ids')
    attn_mask = _trim_last_two('attention_mask')
    # Token id 0 is treated as padding: never attend to it.
    attn_mask[input_tokens == 0] = 0

    images, image_attn_mask = (
        batch[key].to(device) for key in ('pixel_values', 'pixel_attention_mask')
    )
    return input_tokens, attn_mask, images, image_attn_mask

def get_listener_generation_prompts(speaker_contexts, captions, num_samples, img_dir, processor):
    """Pair each caption with its context's images and a listener prompt.

    ``captions`` is read as consecutive groups of ``num_samples``: group ``k``
    belongs to ``speaker_contexts[k]``. Each context's images are loaded once,
    and the same image list object is reused for all of that context's samples.

    Returns:
        Tuple ``(prompts, all_raw_images)`` of equal-length lists.
    """
    prompts, all_raw_images = [], []

    for ctx_idx, context in enumerate(speaker_contexts):
        context_images = process_images(img_dir, context)
        group_start = ctx_idx * num_samples
        for caption_idx in range(group_start, group_start + num_samples):
            prompts.append(
                construct_listener_full_prompt(processor, captions[caption_idx], 0, "verbose_instruction")
            )
            all_raw_images.append(context_images)

    return prompts, all_raw_images