import ast
import math
from einops import rearrange, repeat
from einops_exts import rearrange_many
from einops import rearrange
from PIL import Image
import torch
from torch import einsum, nn


from typing import List, Optional, Tuple, Union
import torch.nn.functional as F
from transformers.modeling_outputs import CausalLMOutputWithPast
from dataclasses import dataclass
from transformers import CLIPVisionModel
from transformers import PreTrainedModel, AutoModelForCausalLM, AutoModel
from transformers import PretrainedConfig, logging, CONFIG_MAPPING
from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer


# Module-level logger. `logging` here is `transformers.logging` (imported from
# transformers at the top of the file), not the standard-library module.
logger = logging.get_logger(__name__)


class XGenMMVisionEncoderConfig(PretrainedConfig):
    """Configuration for the XGen-MM vision encoder.

    Args:
        model_name: identifier of the vision backbone checkpoint.
        anyres_grids: candidate resolutions (pairs of ints) used for
            any-resolution image handling. Presumably consumed as
            (width, height) pairs like `select_best_resolution` expects — confirm.
            Defaults to a fresh copy of the built-in grid list on every call.
        **kwargs: forwarded to `PretrainedConfig`.
    """

    model_type = "xgenmm_vision_encoder"

    def __init__(
        self,
        model_name: str = "google/siglip-so400m-patch14-384",
        anyres_grids: Optional[List[List[int]]] = None,
        **kwargs,
    ):
        if anyres_grids is None:
            # Build a new list per instance: a list literal as the default
            # argument would be shared (and mutable) across every config.
            anyres_grids = [
                [384, 768],
                [768, 384],
                [768, 768],
                [1152, 384],
                [384, 1152],
            ]
        self.model_name = model_name
        self.anyres_grids = anyres_grids
        super().__init__(**kwargs)


class XGenMMVisionTokenizerConfig(PretrainedConfig):
    """Configuration for the XGen-MM vision tokenizer.

    Args:
        vis_feature_dim: width of the incoming vision features.
        lang_embedding_dim: width of the language-model embeddings.
        num_vis_tokens: number of visual tokens to produce.
        image_aspect_ratio: image handling mode (default ``"anyres"``).
        **kwargs: forwarded to `PretrainedConfig`.
    """

    model_type = "xgenmm_vision_tokenizer"

    def __init__(
        self,
        vis_feature_dim: int = 1152,
        lang_embedding_dim: int = 3072,
        num_vis_tokens: int = 128,
        image_aspect_ratio: str = "anyres",
        **kwargs,
    ):
        # Feature widths on either side of the tokenizer.
        self.vis_feature_dim = vis_feature_dim
        self.lang_embedding_dim = lang_embedding_dim
        # Output token budget and preprocessing mode.
        self.num_vis_tokens = num_vis_tokens
        self.image_aspect_ratio = image_aspect_ratio
        super().__init__(**kwargs)


class XGenMMConfig(PretrainedConfig):
    """Top-level XGen-MM configuration bundling three sub-configs.

    Args:
        vision_encoder_config: dict of `XGenMMVisionEncoderConfig` arguments, or an
            already-built config object (kept as-is). ``None`` selects defaults.
        vision_tokenizer_config: dict of `XGenMMVisionTokenizerConfig` arguments, or
            an already-built config object. ``None`` selects defaults.
        text_config: dict of text-model config arguments (its ``model_type`` picks
            the class from ``CONFIG_MAPPING``, defaulting to ``"phi3"``), or an
            already-built config object. ``None`` selects the Phi-3 defaults below.
        **kwargs: forwarded to `PretrainedConfig`.

    Raises:
        ValueError: if the text config lacks ``initial_tokenizer_len`` or ``pad_token_id``.
    """

    model_type = "xgenmm"

    def __init__(
        self,
        vision_encoder_config: dict = None,
        vision_tokenizer_config: dict = None,
        text_config: dict = None,
        **kwargs,
    ):

        if vision_encoder_config is None:
            vision_encoder_config = {
                "image_aspect_ratio": "anyres",
                "anyres_patch_sampling": True,
            }
            logger.info(
                "vision_encoder_config is None. initializing the XGenMMVisionEncoderConfig with default values."
            )

        if vision_tokenizer_config is None:
            vision_tokenizer_config = {}
            logger.info(
                "vision_tokenizer_config is None. Initializing the XGenMMVisionTokenizerConfig with default values."
            )

        if text_config is None:
            text_config = {
                "initial_tokenizer_len": 32012,
                # NOTE(review): this literal previously listed pad_token_id twice
                # (32011 here and 32000 further down); Python keeps the last value,
                # so 32000 was the effective one and is preserved. Confirm whether
                # 32011 was the intended padding id. bos/eos duplicates were removed
                # with their effective values kept.
                "pad_token_id": 32000,
                "bos_token_id": 1,
                "eos_token_id": 32000,
                "vocab_size": 32064,
                "hidden_size": 3072,
                "intermediate_size": 8192,
                "num_hidden_layers": 32,
                "num_attention_heads": 32,
                "num_key_value_heads": 32,
                "resid_pdrop": 0.0,
                "embd_pdrop": 0.0,
                "attention_dropout": 0.0,
                "hidden_act": "silu",
                "max_position_embeddings": 4096,
                "original_max_position_embeddings": 4096,
                "initializer_range": 0.02,
                "rms_norm_eps": 1e-05,
                "use_cache": True,
                "rope_theta": 10000.0,
                "rope_scaling": None,
                "sliding_window": 2047,
                "return_dict": True,
                "output_hidden_states": False,
                "output_attentions": False,
                "torchscript": False,
                "torch_dtype": "bfloat16",
                "use_bfloat16": False,
                "tf_legacy_loss": False,
                "pruned_heads": {},
                "tie_word_embeddings": False,
                "chunk_size_feed_forward": 0,
                "is_encoder_decoder": False,
                "is_decoder": False,
                "cross_attention_hidden_size": None,
                "add_cross_attention": False,
                "tie_encoder_decoder": False,
                "max_length": 20,
                "min_length": 0,
                "do_sample": False,
                "early_stopping": False,
                "num_beams": 1,
                "num_beam_groups": 1,
                "diversity_penalty": 0.0,
                "temperature": 1.0,
                "top_k": 50,
                "top_p": 1.0,
                "typical_p": 1.0,
                "repetition_penalty": 1.0,
                "length_penalty": 1.0,
                "no_repeat_ngram_size": 0,
                "encoder_no_repeat_ngram_size": 0,
                "bad_words_ids": None,
                "num_return_sequences": 1,
                "output_scores": False,
                "return_dict_in_generate": False,
                "forced_bos_token_id": None,
                "forced_eos_token_id": None,
                "remove_invalid_values": False,
                "exponential_decay_length_penalty": None,
                "suppress_tokens": None,
                "begin_suppress_tokens": None,
                "finetuning_task": None,
                "id2label": {0: "LABEL_0", 1: "LABEL_1"},
                "label2id": {"LABEL_0": 0, "LABEL_1": 1},
                "tokenizer_class": None,
                "prefix": None,
                "sep_token_id": None,
                "decoder_start_token_id": None,
                "task_specific_params": None,
                "problem_type": None,
                "model_type": "phi3",
                "_attn_implementation": "flash_attention_2",
            }
            logger.info(
                "text_config is None. Initializing the text config with default values (`Phi3Config`)."
            )

        # Sub-configs may arrive as plain dicts (e.g. parsed JSON) or as config
        # objects built by the caller; only dicts need to be expanded.
        if isinstance(vision_encoder_config, PretrainedConfig):
            self.vision_encoder_config = vision_encoder_config
        else:
            self.vision_encoder_config = XGenMMVisionEncoderConfig(
                **vision_encoder_config
            )

        if isinstance(vision_tokenizer_config, PretrainedConfig):
            self.vision_tokenizer_config = vision_tokenizer_config
        else:
            self.vision_tokenizer_config = XGenMMVisionTokenizerConfig(
                **vision_tokenizer_config
            )

        if isinstance(text_config, PretrainedConfig):
            self.text_config = text_config
        else:
            text_model_type = (
                text_config["model_type"] if "model_type" in text_config else "phi3"
            )
            self.text_config = CONFIG_MAPPING[text_model_type](**text_config)

        for key in ["initial_tokenizer_len", "pad_token_id"]:
            if key not in self.text_config.to_dict():
                raise ValueError(f"The key `{key}` is missing in the text_config.")

        super().__init__(**kwargs)


def hasattr_recursive(obj, att):
    """
    Check if obj has a nested (dot-separated) attribute.

    Example: hasattr_recursive(obj, 'a.b.c') is equivalent to
    hasattr(obj, 'a') and hasattr(obj.a, 'b') and hasattr(obj.a.b, 'c').

    Args:
        obj: root object.
        att: dot-separated attribute path; ``""`` is treated as present.

    Returns:
        bool: True if every segment of the path resolves.
    """
    if att == "":
        return True
    i = att.find(".")
    if i < 0:
        return hasattr(obj, att)
    try:
        return hasattr_recursive(getattr(obj, att[:i]), att[i + 1 :])
    except Exception:
        # Any failure while resolving an intermediate attribute (including
        # errors raised by properties) means "not present". BaseException
        # subclasses such as KeyboardInterrupt/SystemExit still propagate.
        return False


def getattr_recursive(obj, att):
    """
    Resolve a dot-separated attribute path on obj.

    Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c.
    An empty path returns obj itself; a missing segment raises AttributeError.
    """
    current = obj
    remaining = att
    # Peel off one leading segment at a time until the path is consumed.
    while remaining:
        head, _, remaining = remaining.partition(".")
        current = getattr(current, head)
    return current


def setattr_recursive(obj, att, val):
    """
    Assign val to a dot-separated attribute path on obj.

    Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val.
    Every segment except the last must already exist.
    """
    parent_path, dot, leaf = att.rpartition(".")
    target = getattr_recursive(obj, parent_path) if dot else obj
    setattr(target, leaf, val)


def check_embedding_fns(lang_model):
    """Checks for and attempts to set {get/set}_{input/output}_embeddings functions to the model.

    Missing accessors are attached as instance attributes for the attribute
    layouts recognised here: ``transformer.wte`` (MPT) or
    ``model.decoder.embed_tokens`` (OPT) for the input embeddings, and
    ``lm_head`` for the output embeddings.

    Raises:
        ValueError: if an accessor is missing and none of the known attribute
            paths exist on ``lang_model``.
    """
    if not has_fn(lang_model, "get_input_embeddings"):
        if hasattr_recursive(lang_model, "transformer.wte"):  # MPT
            lang_model.get_input_embeddings = lambda: lang_model.transformer.wte
        elif hasattr_recursive(lang_model, "model.decoder.embed_tokens"):  # OPT
            # Must follow the same path that was just checked; the `.model`
            # segment was previously missing, so the getter raised AttributeError.
            lang_model.get_input_embeddings = (
                lambda: lang_model.model.decoder.embed_tokens
            )
        else:
            raise ValueError(
                "We require the language encoder to have a get_input_embeddings method but we couldn't determine the name of the input embeddings attribute. Please supply this manually in factory.py."
            )

    if not has_fn(lang_model, "set_input_embeddings"):
        if hasattr_recursive(lang_model, "transformer.wte"):  # MPT
            lang_model.set_input_embeddings = lambda x: setattr_recursive(
                lang_model, "transformer.wte", x
            )
        elif hasattr_recursive(lang_model, "model.decoder.embed_tokens"):  # OPT
            lang_model.set_input_embeddings = lambda x: setattr_recursive(
                lang_model, "model.decoder.embed_tokens", x
            )
        else:
            raise ValueError(
                "We require the language encoder to have a set_input_embeddings method but we couldn't determine the name of the input embeddings attribute. Please supply this manually in factory.py."
            )

    if not has_fn(lang_model, "get_output_embeddings"):
        if hasattr_recursive(lang_model, "lm_head"):
            lang_model.get_output_embeddings = lambda: lang_model.lm_head
        else:
            raise ValueError(
                "We require the language encoder to have a get_output_embeddings method but we couldn't determine the name of the output embeddings attribute. Please supply this manually in factory.py."
            )

    if not has_fn(lang_model, "set_output_embeddings"):
        if hasattr_recursive(lang_model, "lm_head"):
            lang_model.set_output_embeddings = lambda x: setattr_recursive(
                lang_model, "lm_head", x
            )
        else:
            raise ValueError(
                "We require the language encoder to have a set_output_embeddings method but we couldn't determine the name of the output embeddings attribute. Please supply this manually in factory.py."
            )


def has_fn(model, fn_name):
    """Return True if `model` has an attribute `fn_name` that is callable."""
    try:
        candidate = getattr(model, fn_name)
    except AttributeError:
        return False
    return callable(candidate)


def stack_with_padding(list_of_tensors, padding_value=0, padding_side="right"):
    """
    Stack tensors of different lengths along dim 0 into one batch, padding on one side.

    Every tensor is padded along its first dimension up to the longest one. All
    trailing dimensions must already match. Tensors of any rank >= 1 are supported.

    Args:
        list_of_tensors (Iterable[torch.Tensor]): Tensors to stack.
        padding_value (int, optional): Value to pad with. Defaults to 0.
        padding_side (str, optional): "right" pads after the data; any other
            value pads before it. Defaults to "right".
    Returns:
        torch.Tensor: Shape (len(list_of_tensors), max_len, *trailing_dims).
    Raises:
        ValueError: If no tensors are given.
    """
    tensors = list(list_of_tensors)
    if not tensors:
        raise ValueError("stack_with_padding requires at least one tensor.")
    max_tokens = max(tensor.size(0) for tensor in tensors)
    padded_tensors = []
    for tensor in tensors:
        # Padding block shares dtype/device and every trailing dim with `tensor`,
        # so 1-D, 2-D and higher-rank inputs are all handled uniformly.
        padding = tensor.new_full(
            (max_tokens - tensor.size(0), *tensor.shape[1:]), padding_value
        )
        padded_tensor = (
            torch.cat((tensor, padding), dim=0)
            if padding_side == "right"
            else torch.cat((padding, tensor), dim=0)
        )
        padded_tensors.append(padded_tensor)
    return torch.stack(padded_tensors)


def unpad_image(tensor, original_size, keep_original_shape=False):
    """
    Undo the symmetric padding added when an image was letterboxed into `tensor`.

    Args:
        tensor (torch.Tensor): Padded image, CxHxW.
        original_size (tuple): Original image size as (width, height).
        keep_original_shape (bool): If True, return `tensor` unchanged together
            with an HxW mask that is 0 over the padded border and 1 elsewhere.

    Returns:
        tuple: (cropped_tensor, None) by default, or (tensor, mask) when
        `keep_original_shape` is True.
    """
    orig_w, orig_h = original_size
    cur_h, cur_w = tensor.shape[1:]

    # A relatively wider original means the padding went on top/bottom;
    # otherwise it went on the left/right.
    pad_rows = (orig_w / orig_h) > (cur_w / cur_h)
    if pad_rows:
        content_h = int(orig_h * (cur_w / orig_w))
        pad = (cur_h - content_h) // 2
    else:
        content_w = int(orig_w * (cur_h / orig_h))
        pad = (cur_w - content_w) // 2

    if not keep_original_shape:
        if pad_rows:
            return tensor[:, pad : cur_h - pad, :], None
        return tensor[:, :, pad : cur_w - pad], None

    mask = torch.ones((cur_h, cur_w), device=tensor.device)
    if pad_rows:
        mask[:pad, :] = 0
        mask[cur_h - pad :, :] = 0
    else:
        mask[:, :pad] = 0
        mask[:, cur_w - pad :] = 0
    return tensor, mask


def select_best_resolution(original_size, possible_resolutions):
    """
    Pick the candidate resolution that preserves the most image detail.

    The image is scaled, keeping its aspect ratio, to fit each candidate. The
    candidate with the largest effective area (capped at the original area)
    wins. Ties go to the one wasting the least canvas area, and the first
    candidate wins among exact ties.

    Args:
        original_size (tuple): (width, height) of the source image.
        possible_resolutions (list): candidate (width, height) pairs.

    Returns:
        tuple | None: the chosen (width, height), or None if nothing qualified.
    """
    src_w, src_h = original_size
    src_area = src_w * src_h
    best_fit = None
    best_effective = 0
    best_wasted = float("inf")

    for cand_w, cand_h in possible_resolutions:
        ratio = min(cand_w / src_w, cand_h / src_h)
        fitted_area = int(src_w * ratio) * int(src_h * ratio)
        effective = min(fitted_area, src_area)
        wasted = cand_w * cand_h - effective

        more_detail = effective > best_effective
        same_detail_less_waste = (
            effective == best_effective and wasted < best_wasted
        )
        if not (more_detail or same_detail_less_waste):
            continue
        best_fit = (cand_w, cand_h)
        best_effective, best_wasted = effective, wasted

    return best_fit


def resize_and_pad_image(image, target_resolution):
    """
    Letterbox `image` into `target_resolution` without distorting it.

    The side with the smaller scale factor is fitted exactly. The other side is
    scaled with the same factor, rounded up and clamped to the canvas. The
    result is centered on a black RGB canvas.

    Args:
        image (PIL.Image.Image): The input image.
        target_resolution (tuple): Output (width, height).

    Returns:
        PIL.Image.Image: The resized and padded image.
    """
    src_w, src_h = image.size
    dst_w, dst_h = target_resolution
    ratio_w = dst_w / src_w
    ratio_h = dst_h / src_h

    if ratio_w < ratio_h:
        fit_w, fit_h = dst_w, min(math.ceil(src_h * ratio_w), dst_h)
    else:
        fit_w, fit_h = min(math.ceil(src_w * ratio_h), dst_w), dst_h

    scaled = image.resize((fit_w, fit_h))
    canvas = Image.new("RGB", (dst_w, dst_h), (0, 0, 0))
    offset = ((dst_w - fit_w) // 2, (dst_h - fit_h) // 2)
    canvas.paste(scaled, offset)
    return canvas


def divide_to_patches(image, patch_size):
    """
    Split an image into square tiles of side `patch_size`, row by row.

    Tiles are produced top-to-bottom, then left-to-right within a row. Crop
    boxes on the right/bottom edges may extend past the image bounds.

    Args:
        image (PIL.Image.Image): The input image.
        patch_size (int): Side length of each tile.

    Returns:
        list: The cropped tiles.
    """
    width, height = image.size
    return [
        image.crop((left, top, left + patch_size, top + patch_size))
        for top in range(0, height, patch_size)
        for left in range(0, width, patch_size)
    ]


def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    """
    Calculate the shape of the image patch grid after the preprocessing for images of any resolution.

    Args:
        image_size (tuple): The size of the input image in the format (width, height).
        grid_pinpoints (list | tuple | str): Candidate (width, height) resolutions,
            either as a sequence or as its string representation (parsed with
            `ast.literal_eval`).
        patch_size (int): The size of each image patch.

    Returns:
        tuple: Number of patches along (width, height) for the selected resolution.
    """
    # Accept tuples and list subclasses too: passing them to ast.literal_eval
    # (which only takes strings/AST nodes) would raise ValueError.
    if isinstance(grid_pinpoints, (list, tuple)):
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)
    width, height = select_best_resolution(image_size, possible_resolutions)
    return width // patch_size, height // patch_size


def process_anyres_image(image, processor, grid_pinpoints):
    """
    Process an image with variable resolutions.

    The image is letterboxed into the best-matching candidate resolution and cut
    into square tiles. A plain square resize of the whole image is prepended as
    the first entry. Every view is passed through `processor` and stacked.

    Args:
        image (PIL.Image.Image): The input image to be processed.
        processor: Callable image transform. Its first entry in `.transforms`
            must expose an indexable `.size` whose first element is the tile
            side (presumably a torchvision Compose/Resize — confirm).
        grid_pinpoints (list | tuple | str): Candidate (width, height) resolutions,
            as a sequence or its string representation.

    Returns:
        torch.Tensor: Processed views stacked along dim 0 (global view first).
    """
    # FIXME: determine grid_pinpoints from image sizes.
    # Accept tuples and list subclasses too; ast.literal_eval only parses strings.
    if isinstance(grid_pinpoints, (list, tuple)):
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)
    best_resolution = select_best_resolution(image.size, possible_resolutions)
    image_padded = resize_and_pad_image(image, best_resolution)

    processor_size = processor.transforms[0].size
    patches = divide_to_patches(image_padded, processor_size[0])

    # Global low-resolution view of the full image (not aspect-preserving).
    image_original_resize = image.resize((processor_size[0], processor_size[0]))

    image_patches = [image_original_resize] + patches
    image_patches = [processor(image_patch) for image_patch in image_patches]
    return torch.stack(image_patches, dim=0)


def expand2square(pil_img, background_color):
    """Pad `pil_img` to a square canvas, centering it along its shorter side.

    Square inputs are returned unchanged (same object).
    """
    width, height = pil_img.size
    if width == height:
        return pil_img
    side = max(width, height)
    canvas = Image.new(pil_img.mode, (side, side), background_color)
    if width > height:
        offset = (0, (width - height) // 2)
    else:
        offset = ((height - width) // 2, 0)
    canvas.paste(pil_img, offset)
    return canvas


class VisionTokenizer(nn.Module):
    """Base module for vision tokenizers.

    Args:
        dim_media: width of the incoming media features.
        num_tokens_per_media: number of tokens produced for each media item.
    """

    def __init__(self, dim_media, num_tokens_per_media):
        super().__init__()
        self.num_tokens_per_media = num_tokens_per_media
        self.dim_media = dim_media


class PerceiverAttention(nn.Module):
    """Multi-head attention where latent queries attend over media features plus the latents.

    Args:
        dim: feature width shared by the media features and the latents.
        dim_head: width of each attention head.
        heads: number of attention heads.
    """

    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        # 1/sqrt(dim_head) scaling applied to the queries before the dot product.
        self.scale = dim_head**-0.5
        self.heads = heads
        inner_dim = dim_head * heads

        # Separate pre-norms for media features and latent queries.
        self.norm_media = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        # Keys and values share one projection; forward() splits it in half.
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents, vision_attn_masks=None):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, T, n1, D)
            latents (torch.Tensor): latent features (the queries)
                shape (b, T, n2, D)
            vision_attn_masks (torch.Tensor, optional): per-token mask over the
                n1 media tokens, shape (b, n1); zero/False entries are excluded
                from attention. The same mask is broadcast over heads and T.
        Returns:
            torch.Tensor: shape (b, T, n2, D)
        """
        x = self.norm_media(x)
        latents = self.norm_latents(latents)

        h = self.heads

        q = self.to_q(latents)
        # Keys/values are computed over media tokens followed by the latents.
        kv_input = torch.cat(
            (x, latents), dim=-2
        )  # TODO: Change the shape of vision attention mask according to this.
        if vision_attn_masks is not None:
            # Extend the media mask with ones so the appended latent keys are always visible.
            vision_attn_masks = torch.cat(
                (
                    vision_attn_masks,
                    torch.ones(
                        (latents.shape[0], latents.shape[-2]),
                        dtype=latents.dtype,
                        device=latents.device,
                    ),
                ),
                dim=-1,
            )
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)
        q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h)
        q = q * self.scale

        # attention
        sim = einsum("... i d, ... j d  -> ... i j", q, k)
        # Apply vision attention mask here.
        # Reference: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention
        if vision_attn_masks is not None:
            # Additive bias of shape (b, 1, 1, n2, n1 + n2): -inf where masked, 0 elsewhere;
            # the singleton dims broadcast over heads and T.
            attn_bias = torch.zeros(
                (q.size(0), 1, 1, q.size(-2), k.size(-2)),
                dtype=q.dtype,
                device=q.device,
            )
            vision_attn_masks = repeat(
                vision_attn_masks, "b n -> b 1 1 l n", l=q.size(-2)
            )
            attn_bias.masked_fill_(vision_attn_masks.logical_not(), float("-inf"))
            sim += attn_bias

        # Subtract the row max for numerical stability; softmax is shift-invariant.
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        out = einsum("... i j, ... j d -> ... i d", attn, v)
        out = rearrange(out, "b h t n d -> b t n (h d)", h=h)
        return self.to_out(out)


def FeedForward(dim, mult=4):
    """Build a pre-norm MLP: LayerNorm -> Linear(dim, dim*mult) -> GELU -> Linear(back to dim).

    Both linear layers are bias-free. The hidden width is ``int(dim * mult)``.
    """
    hidden = int(dim * mult)
    layers = [
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden, bias=False),
        nn.GELU(),
        nn.Linear(hidden, dim, bias=False),
    ]
    return nn.Sequential(*layers)


def num_params(module, filter_to_trainable=False):
    """Count the elements of all parameters in `module`.

    If `filter_to_trainable` is True, only parameters with requires_grad=True are counted.
    """
    params = module.parameters()
    if filter_to_trainable:
        params = (p for p in params if p.requires_grad)
    return sum(p.numel() for p in params)


class PerceiverResampler(VisionTokenizer):
    def __init__(
        self,
        *,
        dim,
        dim_inner=None,
        depth=6,
        dim_head=96,
        heads=16,
        num_latents=128,
        max_num_media=None,
        max_num_frames=None,
        ff_mult=4,
    ):
        """
        Perceiver module which takes in image features and outputs image tokens.
        Args:
            dim (int): dimension of the incoming image features (and of the latents).
            dim_inner (int, optional): output dimension. If given, the normalized
                output latents are linearly projected from `dim` to `dim_inner`;
                if None, no projection is applied and outputs have width `dim`.
            depth (int, optional): number of layers. Defaults to 6.
            dim_head (int, optional): dimension of each head. Defaults to 96.
            heads (int, optional): number of heads. Defaults to 16.
            num_latents (int, optional): number of latent tokens to use in the Perceiver;
                also corresponds to number of tokens per sequence to output. Defaults to 128.
            max_num_media (int, optional): maximum number of media per sequence to input into the Perceiver
                and keep positional embeddings for. If None, no positional embeddings are used.
            max_num_frames (int, optional): maximum number of frames to input into the Perceiver
                and keep positional embeddings for. If None, no positional embeddings are used.
            ff_mult (int, optional): dimension multiplier for the feedforward network. Defaults to 4.
        """
        # Created before the other submodules so parameter-initialization
        # (RNG consumption) order is unchanged.
        projection = nn.Linear(dim, dim_inner) if dim_inner is not None else None
        super().__init__(dim_media=dim, num_tokens_per_media=num_latents)
        self.projection = projection
        self.latents = nn.Parameter(torch.randn(num_latents, dim))

        # positional embeddings
        self.frame_embs = (
            nn.Parameter(torch.randn(max_num_frames, dim))
            if max_num_frames is not None
            else None
        )
        self.media_time_embs = (
            nn.Parameter(torch.randn(max_num_media, 1, dim))
            if max_num_media is not None
            else None
        )

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
            )

        self.norm = nn.LayerNorm(dim)

    def forward(self, x, vision_attn_masks):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, T, F, v, D)
            vision_attn_masks (torch.Tensor | None): attention masks for padded vision
                tokens (i.e., x), passed unchanged to every attention layer. It must
                cover the flattened F*v media tokens, i.e. shape (b, F*v) — which is
                (b, v) when F == 1.
        Returns:
            shape (b, T, n, D) where n is the number of latents (D becomes
            dim_inner when a projection is configured)
        """
        b, T, num_frames, v = x.shape[:4]

        # frame and media time embeddings
        if self.frame_embs is not None:
            frame_embs = repeat(
                self.frame_embs[:num_frames], "F d -> b T F v d", b=b, T=T, v=v
            )
            x = x + frame_embs
        x = rearrange(
            x, "b T F v d -> b T (F v) d"
        )  # flatten the frame and spatial dimensions
        if self.media_time_embs is not None:
            x = x + self.media_time_embs[:T]

        # blocks: residual attention + residual feed-forward on the latents
        latents = repeat(self.latents, "n d -> b T n d", b=b, T=T)
        for attn, ff in self.layers:
            latents = attn(x, latents, vision_attn_masks) + latents
            latents = ff(latents) + latents

        latents = self.norm(latents)
        if self.projection is not None:
            return self.projection(latents)
        return latents


class DecoupledEmbedding(nn.Embedding):
    # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/sparse.html#Embedding
    """
    Implements a decoupling of parameters to allow freezing (or not) a subset of the embeddings. In practice, the
    regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `num_additional_embeddings` > 0,
    then it will create `num_additional_embeddings` additional parameters that are always trained. If
    `num_additional_embeddings=0`, then the module defaults back to the regular behavior of `nn.Embedding`.
    """

    def __init__(
        self,
        max_original_id: int,
        num_additional_embeddings: int = 0,
        _weight: torch.Tensor = None,
        num_original_embeddings: int = None,
        embedding_dim: int = None,
        partially_freeze=True,
        device=None,
        dtype=None,
        pad_token_id=None,
    ) -> None:
        """
        Args:
            max_original_id (`int`):
                The largest token id that should be embedded using the regular embedding (regular `weight`).
                This is usually len(tokenizer) - 1 before additional tokens are added.
                Note that this may not equal self.weight.shape[0]
            num_additional_embeddings (`int`):
                Number of additional tokens to initialize an Embedding matrix for (`additional_weight`).
            _weight (`torch.Tensor`, *optional*, defaults to `None`): The regular weight tensor.
                If provided, this sets the `num_original_embeddings` and `embedding_dim` parameters.
            num_original_embeddings (`int`):
                self.weight.shape[0]
            embedding_dim (`int`):
                The size of each embedding vector
            partially_freeze: (`bool`, *optional*, defaults to `True`):
                If `True`, the regular `weight` will be frozen. `additional_weight` is never frozen.
            device, dtype: forwarded to the underlying embedding(s).
            pad_token_id (`int`, *optional*):
                Padding index of the regular embedding; must be <= max_original_id.

        Note: there are a lot of other parameters to initialize a standard `nn.Embedding` such as `padding_idx`,
        `max_norm` or `norm_type`. We are not supporting these.
        """
        # validate args
        if pad_token_id is not None and pad_token_id > max_original_id:
            raise ValueError(
                f"pad_token_id must be <= max_original_id. Got {pad_token_id} and {max_original_id}."
                + "If the original tokenizer does not have a pad_token_id, use pad_token_id=None."
            )
        if _weight is not None:
            assert (num_original_embeddings is None) or (
                _weight.shape[0] == num_original_embeddings
            ), f"num_original_embeddings={num_original_embeddings} but _weight.shape[0]={_weight.shape[0]}"
            assert (embedding_dim is None) or (
                _weight.shape[1] == embedding_dim
            ), f"embedding_dim={embedding_dim} but _weight.shape[1]={_weight.shape[1]}"
            num_original_embeddings = _weight.shape[0]
            embedding_dim = _weight.shape[1]
        else:
            assert (
                num_original_embeddings is not None
            ), "num_original_embeddings must be provided if _weight is not provided"
            assert (
                embedding_dim is not None
            ), "embedding_dim must be provided if _weight is not provided"

        super().__init__(
            num_embeddings=num_original_embeddings,
            embedding_dim=embedding_dim,
            device=device,
            dtype=dtype,
            padding_idx=pad_token_id,
            _weight=_weight,
        )
        self.max_original_id = max_original_id
        self.padding_idx = pad_token_id
        self.num_additional_embeddings = num_additional_embeddings
        if self.num_additional_embeddings > 0:
            self.additional_embedding = nn.Embedding(
                num_embeddings=self.num_additional_embeddings,
                embedding_dim=embedding_dim,
                device=device,
                dtype=dtype,
            )
        self.set_requires_grad(
            require_regular_grad=not partially_freeze, require_additional_grad=True
        )

    def set_requires_grad(self, require_regular_grad, require_additional_grad):
        """
        Helper function to separately set the requires_grad flag for the regular weight and the additional weight.

        `require_additional_grad` is ignored when there are no additional embeddings
        (the `additional_embedding` submodule is only created when
        `num_additional_embeddings > 0`; touching it otherwise raised AttributeError).
        """
        self.weight.requires_grad_(require_regular_grad)
        if self.num_additional_embeddings > 0:
            self.additional_embedding.requires_grad_(require_additional_grad)

    def forward(self, input_ids):
        """
        we have 2 embeddings, with different indices - one pretrained self.weight and another
        self.additional_embedding.weight that is being trained.

        in order to make a lookup of the input ids, we:
        1. find out the indices of the entries belonging to the 2nd embedding
        2. extract those values while subtracting the size of the first embedding (max_original_id + 1), since the 2nd
        embedding starts from 0 and not max_original_id + 1
        3. perform the 2nd embedding lookup
        4. now we handle the 1st embedding, we overwrite indices belonging to the 2nd embedding with index 0
        5. perform the 1st embedding lookup
        6. now we overwrite the values in the 1st embedding lookup with the values of the 2nd embedding lookup

        note: for the 1st embedding lookup we could have looked up only the low indices and not do the padding, but
        then we have to create a new tensor and populate it with 2 tensors that are spread out across various indices -
        i.e. not a simple concat - I haven't benchmarked the complex case if it's any faster, given that seqlens are
        usually relatively short it's probably not faster or if faster not by much - but might be a good idea to
        measure.

        """
        if self.num_additional_embeddings == 0:
            return F.embedding(input_ids, self.weight)

        # Clone so that we don't modify the original input_ids later on
        input_ids = input_ids.clone()
        additional_vocab_indices = torch.where(input_ids > self.max_original_id)
        input_ids_additional_vocab = input_ids[additional_vocab_indices]
        additional_embeddings = self.additional_embedding(
            input_ids_additional_vocab - self.max_original_id - 1
        )

        # for successful lookup replace input_ids with 0, the results of these will be discarded anyway
        input_ids[additional_vocab_indices] = 0
        full_vector = F.embedding(input_ids, self.weight)

        # overwrite the records with high indices
        full_vector[additional_vocab_indices] = additional_embeddings

        return full_vector

    def extra_repr(self) -> str:
        return "num_original_embeddings={}, num_additional_embeddings={}, embedding_dim={}, partially_freeze={}".format(
            self.max_original_id + 1,
            self.num_additional_embeddings,
            self.embedding_dim,
            (not self.weight.requires_grad),
        )


class DecoupledLinear(nn.Linear):
    # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear
    """
    Implements a decoupling of parameters to allow freezing (or not) a subset of the parameters. In practice, the
    regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `additional_out_features` > 0,
    then it will create `additional_out_features * in_features` additional parameters that are always trained. If
    `additional_out_features=0`, then the module defaults back to the regular behavior of `nn.Linear`.

    The forward output is the regular projection truncated to its first `max_original_id + 1` features,
    followed (when `additional_out_features > 0`) by the `additional_out_features` features of `additional_fc`.
    """

    def __init__(
        self,
        max_original_id: int,
        additional_out_features: int = 0,
        _weight: torch.Tensor = None,
        _bias: torch.Tensor = None,
        in_features: int = None,
        original_out_features: int = None,
        bias: bool = True,
        partially_freeze: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        """
        Args:
            max_original_id (`int`): The largest token id that should be extracted from the regular weight.
                This is usually len(tokenizer) - 1 before additional tokens are added.
                Note that this may not equal original_out_features - 1
            _weight: torch.Tensor, *optional*, defaults to `None`. The regular weight tensor.
                If provided, this sets the `in_features` and `original_out_features` parameters.
            _bias: torch.Tensor, *optional*, defaults to `None`. The regular bias tensor.
            in_features: int. Input hidden size.
            original_out_features: int. Original out_features of the language model's get_output_embeddings() function.
            additional_out_features: int. Number of additional trainable dimensions.
                When 0, no `additional_fc` submodule is created.
            bias: bool. Whether to include a bias term.
            partially_freeze: bool, *optional*, defaults to `True`. If `True`, the regular `weight` (and bias) will be frozen.

        Raises:
            AssertionError: if the provided shapes/arguments are inconsistent.
        """
        # argument validation
        if _weight is not None:
            assert (_weight.shape[0] == original_out_features) or (
                original_out_features is None
            ), f"original_out_features={original_out_features} but _weight.shape[0]={_weight.shape[0]}"
            assert (_weight.shape[1] == in_features) or (
                in_features is None
            ), f"in_features={in_features} but _weight.shape[1]={_weight.shape[1]}"
            in_features = _weight.shape[1]
            original_out_features = _weight.shape[0]
        else:
            assert (
                in_features is not None
            ), "in_features must be provided if _weight is not provided"
            assert (
                original_out_features is not None
            ), "original_out_features must be provided if _weight is not provided"

        if _bias is not None:
            assert bias is True, "bias must be True if _bias is provided"

        # initialize original linear
        super().__init__(in_features, original_out_features, bias, device, dtype)

        # set weight and bias manually (shares storage with the provided tensors)
        if _weight is not None:
            self.weight = nn.Parameter(_weight)
        if _bias is not None:
            self.bias = nn.Parameter(_bias)

        self.in_features = in_features
        self.original_out_features = original_out_features
        self.max_original_id = max_original_id

        # initialize additional linear (only when extra output features are requested)
        self.additional_out_features = additional_out_features
        self.has_bias = bias
        if additional_out_features > 0:
            self.additional_fc = nn.Linear(
                in_features=in_features,
                out_features=additional_out_features,
                bias=self.has_bias,
                device=device,
                dtype=dtype,
            )
        self.set_requires_grad(
            require_regular_grad=not partially_freeze, require_additional_grad=True
        )

    def set_requires_grad(self, require_regular_grad, require_additional_grad):
        """
        Helper function to separately set the requires_grad flag for the regular weight and the additional weight.

        Args:
            require_regular_grad (bool): flag applied to the regular `weight` and `bias` (if any).
            require_additional_grad (bool): flag applied to `additional_fc`; ignored when
                `additional_out_features == 0` because no additional head exists.
        """
        self.weight.requires_grad_(require_regular_grad)
        if self.bias is not None:
            self.bias.requires_grad_(require_regular_grad)
        # `additional_fc` is only created when additional_out_features > 0.
        if self.additional_out_features > 0:
            self.additional_fc.requires_grad_(require_additional_grad)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Project `input` with the regular weight, keep the first `max_original_id + 1` features,
        and append the features of the additional head (if any) along the last dimension.
        """
        output = F.linear(input, self.weight, self.bias)
        output = output[..., : self.max_original_id + 1]

        if self.additional_out_features > 0:
            additional_features = F.linear(
                input, self.additional_fc.weight, self.additional_fc.bias
            )
            output = torch.cat((output, additional_features), -1)
        return output

    def extra_repr(self) -> str:
        """Overwriting `nn.Linear.extra_repr` to include new parameters."""
        # A missing bias (bias=False) cannot be "unfrozen", so only the weight decides then.
        partially_frozen = not self.weight.requires_grad or (
            self.bias is not None and not self.bias.requires_grad
        )
        return "in_features={}, out_features={}, additional_out_features={}, bias={}, partially_freeze={}".format(
            self.in_features,
            self.max_original_id + 1,
            self.additional_out_features,
            self.bias is not None,
            partially_frozen,
        )


class VLM(nn.Module):
    """
    Generic vision-language model (VLM) class.
    A VLM consists of four components:
        1. A vision encoder that extracts features from pixels, e.g. CLIP
            input: (B, T_img, F, C, H, W)
            output: (B, T_img, F, v, d)
        2. A vision tokenizer that converts these features to visual token-like embeddings, e.g. Perceiver, or a linear projection head
            input: (B, T_img, F, v, d)
            output: (B, T_img, n, d)
        3. A fusion method that allows the language model to attend to these tokens, e.g. cross-attention, or placing the tokens directly in the language model's input sequence
        4. A language model

    This base class does not implement the fusion step itself. `forward` and `generate` call
    `_prepare_inputs_for_forward`, `_postprocess_outputs_from_forward` and `_post_forward_hook`,
    and `__init__` reads `self._special_tokens` (through `special_tokens`). None of these are
    defined in this class; they are expected to be provided by a subclass.
    """

    def __init__(
        self,
        vision_encoder: nn.Module,
        vision_tokenizer: nn.Module,
        lang_model: nn.Module,
        initial_tokenizer_len: int,
        pad_token_id: int,
        gradient_checkpointing: bool = False,
    ):
        """
        Args:
            vision_encoder (nn.Module): e.g. CLIP
            vision_tokenizer (nn.Module): e.g. PerceiverResampler
                Must expose `dim_media` and `num_tokens_per_media`.
            lang_model (nn.Module): e.g. MPT
                Its input and output embeddings are replaced in place by
                DecoupledEmbedding / DecoupledLinear wrappers.
            initial_tokenizer_len (int): size of the original tokenizer vocab
            pad_token_id (int): id of the pad token
            gradient_checkpointing (bool, optional): Whether to use gradient checkpointing. Defaults to False.
                Stored as `vision_tokenizer._use_gradient_checkpointing`.

        Note:
            `self.special_tokens` is read below (one additional embedding row / output
            feature per special token), so `self._special_tokens` must already be
            available when this constructor runs.
        """
        super().__init__()

        # save dimension information
        self.lang_embedding_dim = lang_model.get_input_embeddings().weight.shape[1]
        if hasattr(lang_model.config, "d_model"):
            self.lang_hidden_dim = lang_model.config.d_model  # mpt uses d_model
        else:
            self.lang_hidden_dim = lang_model.config.hidden_size
        self.vis_embedding_dim = vision_tokenizer.dim_media
        self.num_tokens_per_vis = vision_tokenizer.num_tokens_per_media

        # core components
        self.vision_encoder = vision_encoder
        self.vision_tokenizer = vision_tokenizer
        self.lang_model = lang_model

        # lm embeddings
        # Replace the LM input embedding with a DecoupledEmbedding: ids up to
        # initial_tokenizer_len - 1 look up the original weight, and each special
        # token gets its own row in a separate `additional_embedding` table.
        self.pad_token_id = pad_token_id
        self.initial_tokenizer_len = initial_tokenizer_len
        input_embeds = DecoupledEmbedding(
            max_original_id=initial_tokenizer_len - 1,
            num_additional_embeddings=len(self.special_tokens),
            _weight=self.lang_model.get_input_embeddings().weight,
            pad_token_id=self.pad_token_id,
        ).to(self.lang_model.dtype)
        # Initialize the new rows from N(0, initializer_range), falling back to std=0.02.
        if hasattr(input_embeds, "additional_embedding"):
            input_embeds.additional_embedding.weight.data.normal_(
                mean=0.0,
                std=(
                    self.lang_model.config.initializer_range
                    if hasattr(self.lang_model.config, "initializer_range")
                    else 0.02
                ),
            )
        self.lang_model.set_input_embeddings(input_embeds)

        # Likewise wrap the output projection: logits for the special tokens come from a
        # separate `additional_fc` head whose features are appended after the original
        # vocabulary (DecoupledLinear.forward concatenates them on the last dim).
        # NOTE(review): an nn.Linear built with bias=False still has a `bias` attribute
        # (set to None), so the hasattr check below passes None; DecoupledLinear's default
        # bias=True then makes nn.Linear allocate a fresh bias. Confirm this is intended
        # (e.g. that saved checkpoints contain it) before changing.
        out_embeds = DecoupledLinear(
            max_original_id=initial_tokenizer_len - 1,
            additional_out_features=len(self.special_tokens),
            _weight=self.lang_model.get_output_embeddings().weight,
            _bias=(
                self.lang_model.get_output_embeddings().bias
                if hasattr(self.lang_model.get_output_embeddings(), "bias")
                else None
            ),
        ).to(self.lang_model.dtype)
        if hasattr(out_embeds, "additional_fc"):
            out_embeds.additional_fc.weight.data.normal_(
                mean=0.0,
                std=(
                    self.lang_model.config.initializer_range
                    if hasattr(self.lang_model.config, "initializer_range")
                    else 0.02
                ),
            )
        self.lang_model.set_output_embeddings(out_embeds)

        # gradient checkpointing
        # Only flags the module; `init_gradient_checkpointing` performs the actual wrapping.
        self.vision_tokenizer._use_gradient_checkpointing = gradient_checkpointing

    def forward(
        self,
        vision_x: Optional[torch.Tensor],
        lang_x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        past_key_values: Optional[
            List[Union[torch.Tensor, Tuple[torch.Tensor]]]
        ] = None,
        past_media_locations: Optional[torch.Tensor] = None,
        past_vision_tokens: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = False,
        **kwargs,
    ):
        """
        Args:
            vision_x: Vision input
                shape (B, T_img, F, C, H, W) with F=1
                only F = 1 is supported (single-frame videos)
                if T_img > the number of media tokens in the corresponding input_ids (lang_x),
                only the first number of media tokens in lang_x are used
            lang_x: Language input ids, with media tokens denoting where
                visual media should be inserted.
                shape (B, T_txt)
            attention_mask: Attention mask. Defaults to None.
            labels: Labels. Defaults to None.
                shape (B, T_txt)
            past_key_values (List[Union[torch.Tensor, Tuple[torch.Tensor]]], optional): Past key value pairs for each of the T_txt previous tokens in the language model. Defaults to None.
                list of length = number of decoder layers in the LM
                exact implementation depends on LM, see Hugging Face docs
            past_media_locations (torch.Tensor, optional): boolean mask denoting which of the previous T_txt tokens were media tokens. Defaults to None.
                shape (B, T_txt)
            past_vision_tokens (torch.Tensor, optional): Previous vision tokens. Defaults to None.
            use_cache (Optional[bool], optional): Whether to use cache. Defaults to False.
                If True, includes key_values, media_locations, and vision_tokens in the output.
            **kwargs: forwarded unchanged to `self.lang_model(...)`.

        Returns:
            Whatever `_postprocess_outputs_from_forward` returns for the language-model output.

        Raises:
            AssertionError: if exactly one of past_vision_tokens / past_media_locations is given.
        """
        # `not` binds looser than `^`: this asserts both are None or both are set.
        assert not (past_vision_tokens is None) ^ (
            past_media_locations is None
        ), "past_vision_tokens and past_media_locations must both be None or both be not None"

        # convert pixels to vision tokens
        if vision_x is not None:
            vision_features = self._encode_vision_x(vision_x=vision_x)
            vision_tokens = self.vision_tokenizer(vision_features)
        else:
            vision_tokens = None

        # fuse the vision and language tokens
        new_inputs = self._prepare_inputs_for_forward(
            vision_tokens=vision_tokens,
            lang_x=lang_x,
            attention_mask=attention_mask,
            labels=labels,
            past_key_values=past_key_values,
            past_media_locations=past_media_locations,
            padding_side="right",
            past_vision_tokens=past_vision_tokens,
        )
        output = self.lang_model(
            **new_inputs,
            use_cache=use_cache,
            past_key_values=past_key_values,
            **kwargs,
        )

        # postprocessing may be needed, e.g. to remove extra tokens from logits that were inserted into the language stream
        # or to add the past_vision_tokens and past_media_locations to the output
        output = self._postprocess_outputs_from_forward(
            output=output,
            lang_x=lang_x,
            vision_tokens=vision_tokens,
            use_cache=use_cache,
            past_vision_tokens=past_vision_tokens,
            past_media_locations=past_media_locations,
        )

        # postforward hooks
        self._post_forward_hook()
        return output

    def _encode_vision_x_anyres(self, samples, device):
        """
        Encode "anyres" images, where each image has been split into a variable number of patches.

        Args:
            samples (dict): must contain
                "image": list with one entry per sample. Each entry is a patch tensor that is
                    squeezed on dim 0 (giving [N_patch, C, H, W]) or, when a sample holds
                    several images, a list of such tensors.
                "image_size": the matching original image sizes (flat list, or list of lists
                    in the multi-image case); passed to `get_anyres_image_grid_shape` and
                    `unpad_image`.
            device: device the concatenated patches are moved to before encoding.

        Returns:
            If `self.anyres_patch_sampling` is truthy: `(new_image_embeds, patch_attn_masks)`,
            per-image lists that keep the patches separate so each can be downsampled
            independently.
            Otherwise: `(image_embeds, image_atts)`, where `image_embeds` is zero-padded to the
            longest image and shaped [N_img, 1, 1, N_tok_longest, C], and `image_atts` is a
            long tensor [N_img, N_tok_longest] with 0 on padded positions.

        Notes:
            For images with more than one patch, patch 0 is treated as the base view and the
            remaining patches are merged spatially. The vision encoder runs under
            `torch.no_grad()`. Relies on `self.anyres_grids`, `self.anyres_patch_sampling` and
            (optionally) `self.image_newline`, none of which are set in this class.
        """
        assert self.anyres_grids is not None
        image_raw = samples[
            "image"
        ]  # list of per-sample patch tensors, each of shape [1, N_patch, C, H, W]
        image_sizes = samples["image_size"]

        # Image_raw can be a list of list of patches, when a `samples` has multiple images.
        if isinstance(image_raw[0], list):
            # Flatten the per-sample lists so every image is handled independently below.
            images = [x.squeeze(0) for sample_img in image_raw for x in sample_img]
            image_sizes = [s for sample_sizes in image_sizes for s in sample_sizes]
        else:
            # assert isinstance(image_raw[0], torch.Tensor), f"Unkown image type: {image_raw[0]}"
            # concate list of patches into one big patch for any res encoding.
            images = [x.squeeze(0) for x in image_raw]  # [N_patch, C, H, W]
        image = torch.cat(images, dim=0)  # [\sum{B}{N_patch_i}, C, H, W]
        image = image.to(device)

        # Encode all patches of all images in a single batch (frozen encoder).
        with torch.no_grad():
            if self.vision_encoder.__class__.__name__ == "TimmModel":
                image_embeds = self.vision_encoder.trunk.forward_features(image)
            elif self.vision_encoder.__class__.__name__ in [
                "CLIPVisionModel",
                "SiglipVisionTransformer",
            ]:
                image_embeds = self.vision_encoder(image).last_hidden_state
            else:
                image_embeds = self.vision_encoder(image)[1]  # OpenCLIP returns tuples

        if isinstance(self.vision_encoder, CLIPVisionModel) or isinstance(
            self.vision_encoder, SiglipVisionTransformer
        ):
            base_img_size = self.vision_encoder.config.image_size
        else:
            base_img_size = self.vision_encoder.image_size[0]

        # (height, width) of the token grid produced for a single patch.
        if self.vision_encoder.__class__.__name__ == "TimmModel":
            grid_size = self.vision_encoder.trunk.patch_embed.grid_size
        elif self.vision_encoder.__class__.__name__ in [
            "CLIPVisionModel",
            "SiglipVisionTransformer",
        ]:
            grid_size_base = (
                self.vision_encoder.config.image_size
                // self.vision_encoder.config.patch_size
            )
            grid_size = (grid_size_base, grid_size_base)
        else:
            grid_size = self.vision_encoder.grid_size
        height, width = grid_size

        if not image_embeds.shape[1] == height * width:
            assert (
                image_embeds.shape[1] == height * width + 1
            )  # For vision encoders that has [CLS] token.
            image_embeds = image_embeds[:, 1:, :]  # Drop the cls token for each patch.
        n_vis_token_per_patch = image_embeds.shape[1]

        # Split encoded patches and merge patch features
        # 1. Get the raw sizes from samples, and split the image embeds [\sum_{B}(N_patch_i), N_tok(16*16), C]
        split_sizes = [image.shape[0] for image in images]
        image_embeds = torch.split(image_embeds, split_sizes, dim=0)
        # 2. For each image (consist of a list of patches), merge the patches spatially (of shape [C, n_patch_height, n_patch_width])
        new_image_embeds = []
        patch_attn_masks = []
        max_n_img_token = -1
        for idx, patch_embeds in enumerate(image_embeds):
            if patch_embeds.shape[0] > 1:
                # 3. Flatten the patch features and get [C, n_patch_height * (n_patch_width+1)]
                base_patch_embeds = patch_embeds[
                    0
                ]  # TODO: prepend the CLS token for th base patch embeds (of the resized entire image).
                patch_embeds = patch_embeds[1:]

                assert height * width == base_patch_embeds.shape[0]

                num_patch_width, num_patch_height = get_anyres_image_grid_shape(
                    image_sizes[idx], self.anyres_grids, base_img_size
                )  # Hardcoded grid_pinpoints.
                patch_embeds = patch_embeds.view(
                    num_patch_height, num_patch_width, height, width, -1
                )

                # -> [C, n_patch_height * height, n_patch_width * width]: a single spatial feature map.
                patch_embeds = patch_embeds.permute(4, 0, 2, 1, 3).contiguous()
                patch_embeds = patch_embeds.flatten(1, 2).flatten(2, 3)
                patch_embeds, patch_attn_mask = unpad_image(
                    patch_embeds, image_sizes[idx], self.anyres_patch_sampling
                )
                if hasattr(self, "image_newline"):
                    patch_embeds = torch.cat(
                        (
                            patch_embeds,
                            self.image_newline[:, None, None].expand(
                                *patch_embeds.shape[:-1], 1
                            ),
                        ),
                        dim=-1,
                    )
                if self.anyres_patch_sampling:
                    # Keep patches separate: [1 + n_patches, height * width, C] plus a matching mask.
                    patch_embeds = patch_embeds.view(
                        -1, num_patch_height, num_patch_width, height * width
                    )
                    patch_embeds = patch_embeds.flatten(1, 2).permute(1, 2, 0)
                    assert patch_attn_mask is not None
                    patch_attn_mask = patch_attn_mask.view(
                        num_patch_height, num_patch_width, height * width
                    )
                    patch_attn_mask = patch_attn_mask.flatten(0, 1)
                    patch_embeds = torch.cat(
                        (base_patch_embeds.unsqueeze(0), patch_embeds), dim=0
                    )
                    patch_attn_mask = torch.cat(
                        (
                            torch.ones(
                                n_vis_token_per_patch, device=patch_embeds.device
                            ).unsqueeze(0),
                            patch_attn_mask,
                        ),
                        dim=0,
                    )
                else:
                    # Flatten the spatial map to tokens [N_tok, C] and put the base view's tokens first.
                    patch_embeds = patch_embeds.flatten(1, 2).transpose(0, 1)
                    patch_embeds = torch.cat((base_patch_embeds, patch_embeds), dim=0)
            else:
                # Single-patch image: use its tokens directly.
                patch_embeds = (
                    patch_embeds[0].unsqueeze(0)
                    if self.anyres_patch_sampling
                    else patch_embeds[0]
                )
                patch_attn_mask = (
                    torch.ones(
                        n_vis_token_per_patch, device=patch_embeds.device
                    ).unsqueeze(0)
                    if self.anyres_patch_sampling
                    else None
                )
                if hasattr(self, "image_newline"):
                    patch_embeds = torch.cat(
                        (patch_embeds, self.image_newline[None]), dim=0
                    )
            if not self.anyres_patch_sampling:
                max_n_img_token = max(patch_embeds.shape[0], max_n_img_token)

            new_image_embeds.append(patch_embeds)
            patch_attn_masks.append(patch_attn_mask)

        if self.anyres_patch_sampling:
            # Return individual patches for independent token downsampling.
            return new_image_embeds, patch_attn_masks

        # 4. Pad and concat the list of image_embeds [N_tok_i, C] together into a batch. Also modify the query attention mask.
        image_embeds = []
        image_atts = []
        for image_embed in new_image_embeds:
            n_img_token = image_embed.shape[0]
            img_attn = torch.ones(
                (max_n_img_token), dtype=torch.long, device=image_embed.device
            )
            if n_img_token < max_n_img_token:
                padded_embed = torch.zeros(
                    (max_n_img_token, image_embed.shape[-1]),
                    dtype=image_embed.dtype,
                    device=image_embed.device,
                )
                padded_embed[:n_img_token, :] = image_embed
                img_attn[n_img_token:] = 0  # Mask out the padded entries.
            else:
                padded_embed = image_embed
            image_embeds.append(padded_embed)
            image_atts.append(img_attn)
        image_embeds = torch.stack(
            image_embeds, dim=0
        )  # Shape [N_img, N_tok_longest, C_dim]
        image_atts = torch.stack(image_atts, dim=0)  # Shape [N_img, N_tok_longest]
        # TODO: reshape image_embeds and image_atts to "b T F v d"
        image_embeds = image_embeds[:, None, None, :, :]
        # image_atts = image_atts[:, None, None, :, :]

        return image_embeds, image_atts

    def _encode_vision_x(self, vision_x: torch.Tensor):
        """
        Compute media tokens from vision input by passing it through vision encoder and conditioning language model.
        Args:
            vision_x: Vision input
                shape (B, T_img, F, C, H, W)
                Images in the same chunk are collated along T_img, and frames are collated along F
                Currently only F=1 is supported (single-frame videos)

        Returns:
            torch.Tensor of shape (B, T_img, F, v, d): per-frame encoder features.
            The vision encoder runs under `torch.no_grad()`.

        rearrange code based on https://github.com/dhansmair/flamingo-mini
        """
        assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
        b, T, F = vision_x.shape[:3]

        # Fold batch, image and frame dims together so the encoder sees plain images.
        vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
        with torch.no_grad():
            if self.vision_encoder.__class__.__name__ == "TimmModel":
                vision_x = self.vision_encoder.trunk.forward_features(vision_x)
            elif self.vision_encoder.__class__.__name__ in [
                "CLIPVisionModel",
                "SiglipVisionTransformer",
            ]:
                vision_x = self.vision_encoder(vision_x).last_hidden_state
            else:
                vision_x = self.vision_encoder(vision_x)[1]  # OpenCLIP returns tuples
        vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
        return vision_x

    def _concat_vision_cache(
        self, lang_x, vision_tokens, past_vision_tokens, past_media_locations, use_cache
    ):
        """
        Helper function to include the past vision tokens and past media locations in the output.

        Returns:
            (updated_vision_tokens, updated_media_locations). Both are None when `use_cache` is
            falsy. Otherwise, new media locations (`lang_x == self.media_token_id`) and new
            vision tokens are appended along dim 1 to the past ones, when past values exist.
        """
        if use_cache:
            if past_media_locations is not None and past_vision_tokens is not None:
                if vision_tokens is not None:
                    updated_vision_tokens = torch.cat(
                        [
                            past_vision_tokens,
                            vision_tokens,
                        ],
                        dim=1,
                    )
                else:
                    updated_vision_tokens = past_vision_tokens
                updated_media_locations = torch.cat(
                    [
                        past_media_locations,
                        lang_x == self.media_token_id,
                    ],
                    dim=1,
                )
            else:
                updated_vision_tokens = vision_tokens
                updated_media_locations = lang_x == self.media_token_id

        else:
            updated_vision_tokens = None
            updated_media_locations = None

        return updated_vision_tokens, updated_media_locations

    def generate(
        self,
        vision_x: torch.Tensor,
        lang_x: torch.Tensor,
        attention_mask: torch.Tensor = None,
        past_key_values: Optional[
            List[Union[torch.Tensor, Tuple[torch.Tensor]]]
        ] = None,
        past_media_locations: Optional[torch.Tensor] = None,
        past_vision_tokens: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """
        Generate text conditioned on vision and language inputs.
        Args:
            vision_x (torch.Tensor): Vision input
                shape (B, T_img, F, C, H, W)
                see documentation for forward
            lang_x (torch.Tensor): Language input
                shape (B, T_txt)
            attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
            past_key_values, past_media_locations, past_vision_tokens: see forward.
            **kwargs: see generate documentation in Hugging Face CausalLM models.
                `num_beams` (default 1) is popped and passed both to input preparation and
                to `lang_model.generate`. `use_cache` is always True; passing it in kwargs
                would collide with that explicit keyword.
        Returns:
            torch.Tensor: lang_x with generated tokens appended to it
        """
        num_beams = kwargs.pop("num_beams", 1)

        # convert pixels to vision tokens
        if vision_x is not None:
            vision_features = self._encode_vision_x(vision_x=vision_x)
            vision_tokens = self.vision_tokenizer(vision_features)
        else:
            vision_tokens = None

        # fuse the vision and language tokens
        # for xattn, vision_x and media_location are repeat_interleaved s.t.
        # the total batch size is B * num_beams
        new_inputs = self._prepare_inputs_for_forward(
            vision_tokens=vision_tokens,
            lang_x=lang_x,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            past_media_locations=past_media_locations,
            past_vision_tokens=past_vision_tokens,
            padding_side="left",
            num_beams=num_beams,
        )
        output = self.lang_model.generate(
            **new_inputs,
            past_key_values=past_key_values,
            num_beams=num_beams,
            use_cache=True,
            **kwargs,
        )
        self._post_forward_hook()
        return output

    @property
    def num_trainable_params(self):
        """Return the number of trainable parameters (via `num_params(self, filter_to_trainable=True)`)."""
        return num_params(self, filter_to_trainable=True)

    def set_trainable(self):
        """
        Freeze appropriate parameters in the model.

        Must be implemented by subclasses.
        """
        raise NotImplementedError

    def group_params_by_weight_decay(self):
        """
        Return a tuple of (params to optimize w/ weight decay, params to optimize w/o weight decay)

        Only parameters with `requires_grad=True` are included; the split is decided per
        parameter name by `_should_apply_weight_decay`.
        """
        params_with_wd, params_without_wd = [], []
        for n, p in self.named_parameters():
            if p.requires_grad:
                if self._should_apply_weight_decay(n):
                    params_with_wd.append(p)
                else:
                    params_without_wd.append(p)
        return params_with_wd, params_without_wd

    def _should_apply_weight_decay(self, parameter_name):
        """
        Return whether weight decay should be applied to a parameter.

        Must be implemented by subclasses.
        """
        raise NotImplementedError

    @property
    def special_tokens(self):
        """
        Returns a dict mapping from the attribute name of a special token to its string format,
         e.g. "media_token": "<image>"
        """
        assert (
            "media_token" in self._special_tokens
        ), "VLMs need to request that the tokenizer add a media_token and call set_special_token_ids to set self.media_token_id"
        return self._special_tokens

    @property
    def special_token_ids(self):
        """
        Returns a list of the special token ids

        Reads the `<att_name>_id` attributes set by `set_special_token_ids`.
        """
        return [getattr(self, f"{att_name}_id") for att_name in self.special_tokens]

    def set_special_token_ids(self, string_to_ids):
        """
        Args:
            string_to_ids (dict): mapping from token string to id

        Sets `<att_name>_id` on both this model and `self.lang_model` for every special token.

        Raises:
            AssertionError: if any special token string is missing from `string_to_ids`.
        """
        assert set(self.special_tokens.values()).issubset(set(string_to_ids.keys()))
        for att_name, token_str in self.special_tokens.items():
            token_id = string_to_ids[token_str]
            setattr(self, f"{att_name}_id", token_id)
            setattr(self.lang_model, f"{att_name}_id", token_id)

    def init_gradient_checkpointing(self):
        """
        Wrap every submodule whose `_use_gradient_checkpointing` attribute is truthy (and that
        is not already a CheckpointWrapper) with a non-reentrant activation-checkpoint wrapper.
        """
        from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
            checkpoint_wrapper,
            CheckpointWrapper,
            CheckpointImpl,
            apply_activation_checkpointing,
        )
        from functools import partial

        non_reentrant_wrapper = partial(
            checkpoint_wrapper,
            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
        )
        apply_activation_checkpointing(
            self,
            checkpoint_wrapper_fn=non_reentrant_wrapper,
            check_fn=lambda m: getattr(m, "_use_gradient_checkpointing", False)
            and not isinstance(m, CheckpointWrapper),
        )


@dataclass
class VLMOutputWithPast(CausalLMOutputWithPast):
    """
    Causal-LM output extended with the vision-side cache.

    On top of every field of `CausalLMOutputWithPast`, two optional fields are added:
        past_media_locations: intended to hold the mask of which earlier tokens were media tokens.
        past_vision_tokens: intended to hold the vision tokens seen so far.
    Both default to None.
    """

    past_media_locations: Union[torch.Tensor, None] = None
    past_vision_tokens: Union[torch.Tensor, None] = None


def exists(val):
    """Return ``True`` for any value other than ``None`` (falsy values such as 0 still count)."""
    if val is None:
        return False
    return True


def FeedForward(dim, mult=4):
    """Build a pre-norm MLP: LayerNorm -> Linear(dim, hidden) -> GELU -> Linear(hidden, dim).

    Args:
        dim: input and output width.
        mult: expansion factor; the hidden width is ``int(dim * mult)``.

    Returns:
        An ``nn.Sequential`` whose two projections have no bias terms.
    """
    hidden = int(dim * mult)
    layers = [
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden, bias=False),
        nn.GELU(),
        nn.Linear(hidden, dim, bias=False),
    ]
    return nn.Sequential(*layers)


class VLMWithLanguageStream(VLM):
    """
    VLM that fuses modalities by inserting vision tokens directly into the language stream.

    Every media token (``self.media_token_id``) in ``lang_x`` is replaced, in embedding
    space, by the vision tokens of the corresponding image. The attention mask and the
    labels are expanded to match. The output logits are collapsed back to the original
    text length afterwards.
    """

    def __init__(
        self,
        vision_encoder: nn.Module,
        vision_tokenizer: nn.Module,
        lang_model: nn.Module,
        initial_tokenizer_len: int,
        pad_token_id: int,
        decoder_layers_attr_name: str = None,
        gradient_checkpointing: bool = False,
    ):
        """
        Args:
            vision_encoder (nn.Module): vision backbone, forwarded to ``VLM.__init__``.
            vision_tokenizer (nn.Module): module producing vision tokens, forwarded to ``VLM.__init__``.
            lang_model (nn.Module): causal language model, forwarded to ``VLM.__init__``.
            initial_tokenizer_len (int): tokenizer vocabulary size, forwarded to ``VLM.__init__``.
            pad_token_id (int): id of the padding token, forwarded to ``VLM.__init__``.
            decoder_layers_attr_name (str, optional): attribute path (resolved with
                ``getattr_recursive``) to the decoder blocks of ``lang_model``. When given,
                each block's ``_use_gradient_checkpointing`` is set to ``gradient_checkpointing``.
            gradient_checkpointing (bool, optional): whether to mark decoder blocks for
                gradient checkpointing. Defaults to False.
        """
        super().__init__(
            vision_encoder=vision_encoder,
            vision_tokenizer=vision_tokenizer,
            lang_model=lang_model,
            initial_tokenizer_len=initial_tokenizer_len,
            pad_token_id=pad_token_id,
            gradient_checkpointing=gradient_checkpointing,
        )
        self.decoder_layers_attr_name = decoder_layers_attr_name
        if decoder_layers_attr_name is not None:
            for block in getattr_recursive(
                self.lang_model, self.decoder_layers_attr_name
            ):
                block._use_gradient_checkpointing = gradient_checkpointing

    def _prepare_inputs_for_forward(
        self,
        vision_tokens: torch.Tensor,
        lang_x: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: torch.Tensor = None,
        past_key_values=None,
        vision_attention_mask: Optional[torch.Tensor] = None,
        past_media_locations: torch.Tensor = None,
        past_vision_tokens: torch.Tensor = None,
        padding_side: str = "left",
        num_beams: int = 1,
    ):
        """
        Insert the vision tokens directly into the language stream.

        This requires us to modify the input_ids, attention_mask, and labels:
        - each media token in ``lang_x[i]`` is replaced by ``vision_tokens[i][img_num]``;
        - the attention mask receives the matching vision-token mask;
        - labels receive ``-100`` over the vision positions.
        The per-sample sequences are then stacked with padding on ``padding_side``.

        Args:
            vision_tokens: per-sample, per-image vision token embeddings indexed as
                ``vision_tokens[i][img_num]`` (presumably each of shape
                (num_vis_tokens, dim) — TODO confirm), or None for text-only input.
            lang_x: input ids of shape (B, T_txt).
            attention_mask: mask for ``lang_x``. In the multimodal path a None mask is
                treated as all ones. When ``past_key_values`` is given, the mask must
                also cover the cached positions.
            labels: optional labels aligned with ``lang_x``.
            past_key_values: KV cache; only used here to validate the mask length.
            vision_attention_mask: optional vision-token mask indexed per sample
                (only read when ``image_aspect_ratio == "anyres"``).
            past_media_locations, past_vision_tokens, num_beams: not used here; kept
                for interface compatibility.
            padding_side: side on which shorter sequences are padded when stacking.

        Returns:
            dict: ``{"input_ids", "attention_mask", "labels"}`` unchanged when
            ``vision_tokens`` is None, otherwise
            ``{"inputs_embeds", "attention_mask", "labels"}``.
        """
        if past_key_values is not None:
            past_len = past_key_values[0][0].shape[2]
            assert attention_mask.shape[1] == past_len + lang_x.shape[1], (
                "Attention_mask must be as long as the entire past len (including image tokens) and current input IDs. "
                + "Check that you've expanded the attention mask to account for past image tokens."
            )

        if vision_tokens is None:
            return {
                "input_ids": lang_x,
                "attention_mask": attention_mask,
                "labels": labels,
            }

        # The per-sample splicing below needs a concrete mask (``.clone()``, slicing,
        # ``.device``); without one, attend to every text position.
        if attention_mask is None:
            attention_mask = torch.ones_like(lang_x)

        # get the language embeddings
        lang_embeds = self.lang_model.get_input_embeddings()(lang_x)

        # build up the multimodal embeddings
        B = lang_x.shape[0]
        has_labels = labels is not None
        multimodal_embeds = []
        multimodal_attention_mask = []
        multimodal_labels = [] if has_labels else None
        for i in range(B):
            new_embed = lang_embeds[i].clone()
            new_attention_mask = attention_mask[i].clone()
            new_label = labels[i].clone() if has_labels else None

            # Positions of the <image> tokens in the ORIGINAL (unexpanded) lang_x[i].
            image_token_idxs = torch.where(lang_x[i] == self.media_token_id)[0].tolist()
            # Each insertion replaces 1 media token with num_vis_tokens tokens, so later
            # positions shift right by the accumulated (num_vis_tokens - 1).
            offset = 0
            for img_num, orig_idx in enumerate(image_token_idxs):
                img_idx = orig_idx + offset
                img_vision_tokens = vision_tokens[i][img_num]
                if self.image_aspect_ratio == "anyres":
                    # Get vision token attention mask for padded llava-style any resolution image tokens.
                    num_vis_tokens = img_vision_tokens.shape[0]
                    if vision_attention_mask is not None:
                        # NOTE(review): indexed by sample only, so every image of sample i
                        # uses the same mask — confirm for multi-image inputs.
                        vis_attention_mask = vision_attention_mask[i]
                    else:
                        vis_attention_mask = torch.ones(
                            num_vis_tokens,
                            dtype=torch.long,
                            device=attention_mask.device,
                        )
                else:
                    assert img_vision_tokens.shape[0] == self.num_tokens_per_vis, (
                        f"vision token number mismatch: image embedding ({img_vision_tokens.shape[0]}) "
                        f"vs. model.num_tokens_per_vis ({self.num_tokens_per_vis})"
                    )
                    # By default, vision tokens are not padded.
                    num_vis_tokens = self.num_tokens_per_vis
                    vis_attention_mask = torch.ones(
                        num_vis_tokens, dtype=torch.long, device=attention_mask.device
                    )
                offset += num_vis_tokens - 1

                new_embed = torch.cat(
                    (new_embed[:img_idx], img_vision_tokens, new_embed[img_idx + 1 :]),
                    dim=0,
                )
                new_attention_mask = torch.cat(
                    (
                        new_attention_mask[:img_idx],
                        vis_attention_mask,
                        new_attention_mask[img_idx + 1 :],
                    ),
                    dim=0,
                )
                if has_labels:
                    # Vision positions are excluded from the loss.
                    ignore_labels = torch.full(
                        (num_vis_tokens,), -100, dtype=torch.long, device=labels.device
                    )
                    new_label = torch.cat(
                        (new_label[:img_idx], ignore_labels, new_label[img_idx + 1 :]),
                        dim=0,
                    )

            multimodal_embeds.append(new_embed)
            multimodal_attention_mask.append(new_attention_mask)
            if has_labels:
                multimodal_labels.append(new_label)

        # stack
        multimodal_embeds = stack_with_padding(
            multimodal_embeds,
            padding_value=self.pad_token_id,
            padding_side=padding_side,
        )
        multimodal_attention_mask = stack_with_padding(
            multimodal_attention_mask,
            padding_value=0,
            padding_side=padding_side,
        )
        if has_labels:
            multimodal_labels = stack_with_padding(
                multimodal_labels,
                padding_value=-100,
                padding_side=padding_side,
            )

        return {
            "inputs_embeds": multimodal_embeds,
            "attention_mask": multimodal_attention_mask,
            "labels": multimodal_labels,
        }

    def _postprocess_outputs_from_forward(
        self,
        output: CausalLMOutputWithPast,
        lang_x: torch.Tensor,
        vision_tokens: torch.Tensor,
        past_vision_tokens: torch.Tensor,
        past_media_locations: torch.Tensor,
        use_cache: bool = False,
    ):
        """
        Collapse multimodal logits back to the text length and attach the vision cache.

        For every position of ``lang_x`` one logit row is kept. For a media token, the
        row kept is the one at its first vision-token position; the remaining rows
        inserted by ``_prepare_inputs_for_forward`` are skipped. The number skipped
        mirrors that method:
        - ``vision_tokens[i][img_num].shape[0]`` under anyres;
        - ``self.num_tokens_per_vis`` otherwise;
        - none when ``vision_tokens`` is None, because nothing was inserted.
        This assumes the same ``vision_tokens`` object was passed to both methods —
        TODO confirm in the caller.

        Returns:
            VLMOutputWithPast with logits of shape (B, T_txt, vocab_size) and the
            updated ``past_vision_tokens`` / ``past_media_locations``.
        """
        # Include the past vision tokens and past media locations in the output
        updated_vision_tokens, updated_media_locations = self._concat_vision_cache(
            lang_x=lang_x,
            vision_tokens=vision_tokens,
            past_vision_tokens=past_vision_tokens,
            past_media_locations=past_media_locations,
            use_cache=use_cache,
        )

        # return logits that are the same shape as the original input_ids
        logits = output.logits
        batch_logits = []
        B, T_txt = lang_x.shape
        for i in range(B):
            sequence_logits = []
            logits_j = 0
            img_num = 0
            for j in range(T_txt):
                # For a media token this is the logit of its first vision-token position.
                # note: the model actually learns to predict <im_patch>, not <image>
                sequence_logits.append(logits[i, logits_j])
                if vision_tokens is not None and lang_x[i, j] == self.media_token_id:
                    if self.image_aspect_ratio == "anyres":
                        logits_j += vision_tokens[i][img_num].shape[0]
                    else:
                        logits_j += self.num_tokens_per_vis
                    img_num += 1
                else:
                    logits_j += 1
            batch_logits.append(
                torch.stack(sequence_logits, dim=0)
            )  # (T_txt, vocab_size)

        batch_logits = torch.stack(batch_logits, dim=0)  # (B, T_txt, vocab_size)
        # The final logits shape should be the same as the original input_ids shape
        assert batch_logits.shape[:2] == (B, T_txt)

        # assemble the output
        output = VLMOutputWithPast(
            loss=output.loss,
            logits=batch_logits,
            past_key_values=output.past_key_values,
            hidden_states=output.hidden_states,
            attentions=output.attentions,
            past_media_locations=updated_media_locations,
            past_vision_tokens=updated_vision_tokens,
        )

        return output

    def _post_forward_hook(self):
        """No-op hook; this fusion strategy has no per-forward state to reset."""
        pass

    @property
    def num_params_per_module(self):
        """Return a multi-line summary of the parameter count of each sub-module."""
        return "\n".join(
            [
                f"Vision encoder: {num_params(self.vision_encoder):,} parameters",
                f"Vision tokenizer: {num_params(self.vision_tokenizer):,} parameters",
                f"Language model: {num_params(self.lang_model):,} parameters",
            ]
        )

    @property
    def num_trainable_params_per_module(self):
        """Return a multi-line summary of the trainable parameter count of each sub-module."""
        return "\n".join(
            [
                f"Vision encoder: {num_params(self.vision_encoder, filter_to_trainable=True):,} trainable parameters",
                f"Vision tokenizer: {num_params(self.vision_tokenizer, filter_to_trainable=True):,} trainable parameters",
                f"Language model: {num_params(self.lang_model, filter_to_trainable=True):,} trainable parameters",
            ]
        )


class XGenMMPerceiver(VLMWithLanguageStream):
    """
    Language-stream VLM used by xgen-mm.

    Adds optional any-resolution ("anyres") image encoding and patch sampling on top
    of ``VLMWithLanguageStream``.
    """

    def __init__(
        self,
        vision_encoder: nn.Module,
        vision_tokenizer: nn.Module,
        lang_model: nn.Module,
        initial_tokenizer_len: int,
        pad_token_id: int,
        decoder_layers_attr_name: str = None,
        gradient_checkpointing: bool = False,
        image_aspect_ratio: str = "anyres",
        anyres_patch_sampling: bool = True,
        anyres_grids: Optional[List[List[int]]] = None,
    ):
        """
        Args:
            vision_encoder (nn.Module): vision backbone used to encode images.
            vision_tokenizer (nn.Module): module called as
                ``vision_tokenizer(vision_features, vision_attn_masks)`` to produce
                vision tokens.
            lang_model (nn.Module): HF causal language model.
            initial_tokenizer_len (int): size of the tokenizer vocab.
            pad_token_id (int): id of the padding token, forwarded to the base VLM.
            decoder_layers_attr_name (str, optional): name of the decoder layers
                attribute. Defaults to None.
            gradient_checkpointing (bool, optional): whether to use gradient
                checkpointing. Defaults to False.
            image_aspect_ratio (str, optional): "anyres" encodes images with
                ``_encode_vision_x_anyres``; any other value uses ``_encode_vision_x``.
                Defaults to "anyres".
            anyres_patch_sampling (bool, optional): flatten all image patches into one
                batch before the vision tokenizer and regroup them per sample
                afterwards. Defaults to True.
            anyres_grids (list, optional): candidate any-resolution grids, stored as-is.
        """
        # Set before super().__init__(), which presumably reads the special tokens.
        self._special_tokens = {
            "media_token": "<image>",
            "image_placeholder_token": "<image placeholder>",
            "end_of_trunk_token": "<|endofchunk|>",
        }
        super().__init__(
            vision_encoder=vision_encoder,
            vision_tokenizer=vision_tokenizer,
            lang_model=lang_model,
            initial_tokenizer_len=initial_tokenizer_len,
            gradient_checkpointing=gradient_checkpointing,
            decoder_layers_attr_name=decoder_layers_attr_name,
            pad_token_id=pad_token_id,
        )
        self.image_aspect_ratio = image_aspect_ratio
        self.anyres_patch_sampling = anyres_patch_sampling
        self.anyres_grids = anyres_grids

    def set_trainable(self):
        """
        Unfreeze everything except the vision_encoder
        """
        self.requires_grad_(True)
        self.vision_encoder.requires_grad_(False)

    def _should_apply_weight_decay(self, parameter_name):
        """
        Apply weight decay to every parameter (following Kosmos, which applies 0.01
        weight decay to everything).
        """
        return True

    def generate(
        self,
        vision_x: torch.Tensor,
        lang_x: torch.Tensor,
        image_size: Optional[Tuple] = None,
        attention_mask: torch.Tensor = None,
        past_key_values: Optional[
            List[Union[torch.Tensor, Tuple[torch.Tensor]]]
        ] = None,
        past_media_locations: Optional[torch.Tensor] = None,
        past_vision_tokens: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """
        Generate text conditioned on vision and language inputs.

        Args:
            vision_x: vision input, or None for text-only generation. With
                ``anyres_patch_sampling`` it must be a list: one entry per sample, or
                a list of per-image entries per sample for multi-image input.
            lang_x (torch.Tensor): language input of shape (B, T_txt).
            image_size (optional): passed to ``_encode_vision_x_anyres`` as
                ``image_size`` (anyres only).
            attention_mask (torch.Tensor, optional): attention mask. Defaults to None.
            past_key_values (optional): KV cache, forwarded to ``lang_model.generate``
                when given.
            past_media_locations, past_vision_tokens (optional): forwarded to
                ``_prepare_inputs_for_forward``.
            **kwargs: see generate documentation in Hugging Face CausalLM models;
                ``num_beams`` defaults to 1.

        Returns:
            Whatever ``self.lang_model.generate`` returns (presumably generated token
            ids — confirm against the HF version in use).
        """
        num_beams = kwargs.pop("num_beams", 1)

        # convert pixels to vision tokens
        # NOTE(review): never populated here, so _prepare_inputs_for_forward falls back
        # to all-ones vision masks.
        vision_attention_mask = None
        if vision_x is not None:
            if self.image_aspect_ratio == "anyres":
                input_dict = dict(image=vision_x, image_size=image_size)
                vision_features, vision_attn_masks = self._encode_vision_x_anyres(
                    input_dict, lang_x.device
                )
            else:
                vision_features = self._encode_vision_x(vision_x=vision_x)
                vision_attn_masks = None
            # If doing patch sampling, then flatten patches of shape [b, Np_i, v, d] -> [b*Np, v, d]
            # Same for attention masks: [b, Np, v] -> [b*Np, v]
            if self.anyres_patch_sampling:
                split_sizes = [feature.shape[0] for feature in vision_features]
                # Multi-image samples: each vision_x entry is itself a list of images.
                is_multi_image = isinstance(vision_x[0], list)
                if is_multi_image:
                    # Nested splits: group the per-image patch counts by sample.
                    split_split_sizes = []
                    img_id = 0
                    for images in vision_x:
                        split_split_sizes.append(
                            split_sizes[img_id : img_id + len(images)]
                        )
                        img_id += len(images)
                vision_features = torch.cat(vision_features, dim=0)
                vision_features = vision_features[
                    :, None, None, :, :
                ]  # Expand dimensions.
                vision_attn_masks = torch.cat(vision_attn_masks, dim=0)
            vision_tokens = self.vision_tokenizer(vision_features, vision_attn_masks)

            # Post-processing: Split the batches into groups of patches and concatenate them together.
            if self.anyres_patch_sampling:
                assert isinstance(vision_x, list)
                if is_multi_image:
                    vision_token_groups = torch.split(
                        vision_tokens,
                        [sum(sizes) for sizes in split_split_sizes],
                        dim=0,
                    )
                    vision_tokens = []
                    for patch_vis_tokens, sizes in zip(
                        vision_token_groups, split_split_sizes
                    ):
                        # [Np*nt, 1, v, d] -> [[Np_t, 1, v, d], ...]
                        per_image_tokens = torch.split(patch_vis_tokens, sizes, dim=0)
                        # [Np, 1, v, d] -> [Np*v, d] for each image of the sample
                        vision_tokens.append(
                            [tokens.flatten(0, 2) for tokens in per_image_tokens]
                        )
                else:
                    vision_token_groups = torch.split(vision_tokens, split_sizes, dim=0)
                    # [Np, 1, v, d] -> [Np*v, d], then add the nt dimension.
                    vision_tokens = [
                        group.flatten(0, 2).unsqueeze(0)
                        for group in vision_token_groups
                    ]
        else:
            vision_tokens = None

        # fuse the vision and language tokens
        # for xattn, vision_x and media_location are repeat_interleaved s.t.
        # the total batch size is B * num_beams
        new_inputs = self._prepare_inputs_for_forward(
            vision_tokens=vision_tokens,
            lang_x=lang_x,
            attention_mask=attention_mask,
            vision_attention_mask=vision_attention_mask,
            past_key_values=past_key_values,
            past_media_locations=past_media_locations,
            past_vision_tokens=past_vision_tokens,
            padding_side="left",
            num_beams=num_beams,
        )
        # Only hand the cache to the LM when one was supplied.
        if past_key_values is not None:
            kwargs["past_key_values"] = past_key_values
        output = self.lang_model.generate(
            **new_inputs,
            num_beams=num_beams,
            use_cache=True,
            **kwargs,
        )
        self._post_forward_hook()
        return output


class XGenMMVisionEncoder(PreTrainedModel):
    """
    Stand-alone vision encoder wrapper.

    Only ``google/siglip-so400m-patch14-384`` is accepted; it is loaded with
    ``AutoModel.from_pretrained``.
    """

    main_input_name = "pixel_values"
    config_class = XGenMMVisionEncoderConfig

    def __init__(self, config: XGenMMVisionEncoderConfig):
        super().__init__(config)
        if config.model_name != "google/siglip-so400m-patch14-384":
            raise ValueError(
                f"Unsupported model {config.model_name}. New vision models will be added soon."
            )
        self.model = AutoModel.from_pretrained(config.model_name)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Encode ``pixel_values`` into image features.

        ``encode_image`` is used when the wrapped model provides it. Otherwise the call
        falls back to ``get_image_features``. NOTE: for the supported checkpoint,
        ``AutoModel`` returns a Hugging Face ``SiglipModel``, which exposes
        ``get_image_features`` but no ``encode_image``. An unconditional
        ``encode_image`` call therefore raised ``AttributeError``.
        """
        encode_image = getattr(self.model, "encode_image", None)
        if callable(encode_image):
            return encode_image(pixel_values)
        return self.model.get_image_features(pixel_values=pixel_values)


# vision tokenizer
class XGenMMVisionTokenizer(PreTrainedModel):
    """
    Hugging Face wrapper around a ``PerceiverResampler``.

    The resampler is built from the config: ``vis_feature_dim`` as ``dim``,
    ``lang_embedding_dim`` as ``dim_inner``, and ``num_vis_tokens`` as ``num_latents``.
    """

    config_class = XGenMMVisionTokenizerConfig

    def __init__(self, config: XGenMMVisionTokenizerConfig):
        super().__init__(config)
        resampler_kwargs = dict(
            dim=config.vis_feature_dim,
            dim_inner=config.lang_embedding_dim,
            num_latents=config.num_vis_tokens,
        )
        self.model = PerceiverResampler(**resampler_kwargs)

    def forward(self, vision_features: torch.Tensor, vision_attn_masks: torch.Tensor):
        """Run the resampler on the vision features and their attention masks."""
        resampler = self.model
        return resampler(vision_features, vision_attn_masks)


# XGenMM model
class XGenMMModelForConditionalGeneration(PreTrainedModel):
    """
    Hugging Face entry point that assembles the xgen-mm VLM from an ``XGenMMConfig``.

    It builds three parts and combines them into ``self.vlm`` (an ``XGenMMPerceiver``):
    - the vision tower (``.vision_model`` of the pretrained model named by
      ``vision_encoder_config.model_name``);
    - the causal language model (from ``text_config``);
    - the perceiver vision tokenizer.
    """

    config_class = XGenMMConfig

    def __init__(self, config: XGenMMConfig):
        super().__init__(config)

        # vision encoder initialization
        vision_encoder = AutoModel.from_pretrained(
            config.vision_encoder_config.model_name,
            torch_dtype=config.text_config.torch_dtype,
        ).vision_model

        # language model initialization
        language_model = AutoModelForCausalLM.from_config(
            config.text_config,
            torch_dtype=config.text_config.torch_dtype,
        )
        check_embedding_fns(language_model)
        # Update _tied_weights_keys using the base model used. The language model is
        # stored at ``self.vlm.lang_model``, so its parameters are named
        # "vlm.lang_model.<key>" in this model.
        if language_model._tied_weights_keys is not None:
            self._tied_weights_keys = [
                f"vlm.lang_model.{k}" for k in language_model._tied_weights_keys
            ]

        # vision tokenizer initialization
        lang_embedding_dim = language_model.get_input_embeddings().weight.shape[1]
        if config.vision_tokenizer_config.lang_embedding_dim != lang_embedding_dim:
            config.vision_tokenizer_config.lang_embedding_dim = lang_embedding_dim
            logger.warning(
                "The language embedding dimension in the vision tokenizer config is "
                "different from the language model's embedding dimension. Overwriting "
                "the language embedding dimension in the vision tokenizer config to %s.",
                lang_embedding_dim,
            )

        vision_tokenizer = XGenMMVisionTokenizer(config.vision_tokenizer_config).model.to(language_model.dtype)

        self.vlm = XGenMMPerceiver(
            vision_encoder=vision_encoder,
            vision_tokenizer=vision_tokenizer,
            lang_model=language_model,
            initial_tokenizer_len=config.text_config.initial_tokenizer_len,
            pad_token_id=config.text_config.pad_token_id,
            image_aspect_ratio=config.vision_encoder_config.image_aspect_ratio,
            anyres_patch_sampling=config.vision_encoder_config.anyres_patch_sampling,
            anyres_grids=config.vision_encoder_config.anyres_grids,
        )
        # Initialize weights and apply final processing
        self.post_init()

    @torch.no_grad()
    def generate(
        self,
        pixel_values: torch.FloatTensor,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        **generate_kwargs,
    ) -> torch.LongTensor:
        """
        Switch the VLM to eval mode and delegate to ``XGenMMPerceiver.generate``.

        ``pixel_values`` is passed as ``vision_x`` and ``input_ids`` as ``lang_x``; all
        other keyword arguments are forwarded unchanged.
        """
        self.vlm = self.vlm.eval()
        return self.vlm.generate(
            vision_x=pixel_values,
            lang_x=input_ids,
            attention_mask=attention_mask,
            **generate_kwargs,
        )

    def update_special_tokens(self, tokenizer):
        """
        Register the VLM's special tokens with ``tokenizer`` and record their ids.

        Also sets ``lang_model.config.vocab_size`` to the new tokenizer length. Returns
        the (mutated) tokenizer.
        """
        tokenizer.add_special_tokens(
            {"additional_special_tokens": list(self.vlm.special_tokens.values())}
        )
        self.vlm.lang_model.config.vocab_size = len(tokenizer)
        self.vlm.set_special_token_ids(
            {
                v: tokenizer.convert_tokens_to_ids(v)
                for v in self.vlm.special_tokens.values()
            }
        )
        return tokenizer