class ModalityType(Enum):
    """Modality identifiers whose members compare equal to their string values.

    A member is equal to another member carrying the same value and to a
    plain ``str`` equal to its value; every other operand compares unequal.
    Hashing is delegated to the value, so a member and its string form are
    interchangeable as dictionary keys.
    """

    TEXT = "text"
    IMAGE = "image"
    VIDEO = "video"
    AUDIO = "audio"
    VISION = "vision"  # For Imagebind

    def __hash__(self):
        # Hash like the underlying string so hashing agrees with __eq__.
        return hash(self.value)

    def __eq__(self, other):
        # Reduce another member to its value; reject anything that is
        # neither a member nor a string.
        if isinstance(other, ModalityType):
            other = other.value
        elif not isinstance(other, str):
            return False
        return self.value == other

    def __str__(self):
        return self.value
class AnyModelConfig(PretrainedConfig):
    """Configuration for the ``any_model`` multimodal model.

    Holds the modality configuration, the text backbone configuration and
    the token indices associated with image, video and audio inputs.

    Args:
        modality_config: Stored unchanged on the instance.
        text_config: Text backbone configuration. A ``dict`` is converted
            into the config class registered in ``CONFIG_MAPPING`` under its
            ``"model_type"`` entry (``"llama"`` when that entry is absent);
            ``None`` produces a default ``"llama"`` config; any other value
            is stored as given.
        ignore_index: Stored unchanged on the instance.
        image_token_index: Token index stored for the image modality.
        video_token_index: Token index stored for the video modality.
        audio_token_index: Token index stored for the audio modality.
        projector_hidden_act: Name of the projector activation, stored
            unchanged on the instance.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.

    Raises:
        KeyError: If the ``"model_type"`` of a ``dict`` ``text_config`` is
            not registered in ``CONFIG_MAPPING``.
    """

    model_type = "any_model"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        modality_config=None,
        text_config=None,
        ignore_index=-100,
        image_token_index=128256,
        video_token_index=128257,
        audio_token_index=128258,
        projector_hidden_act="gelu",
        **kwargs,
    ):
        if isinstance(text_config, dict):
            # Work on a shallow copy so the default "model_type" is not
            # written back into the caller's dictionary.
            text_config = dict(text_config)
            backbone_type = text_config.setdefault("model_type", "llama")
            text_config = CONFIG_MAPPING[backbone_type](**text_config)
        elif text_config is None:
            text_config = CONFIG_MAPPING["llama"]()

        self.modality_config = modality_config
        self.text_config = text_config
        self.ignore_index = ignore_index
        self.image_token_index = image_token_index
        self.video_token_index = video_token_index
        self.audio_token_index = audio_token_index
        self.projector_hidden_act = projector_hidden_act

        # The base initializer runs last, after all attributes are set,
        # matching the original statement order.
        super().__init__(**kwargs)