"""
================================================
@author: Jaron
@time: 2024/08/21 17:41:52
@email: [email protected]
@description: Video-CCAM
================================================
"""
import torch
import os.path as osp

from PIL import Image
from peft import PeftModel
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, SiglipVisionModel, SiglipImageProcessor, GenerationConfig

from .configuration_videoccam import VideoCCAMConfig


class VideoCCAM(PreTrainedModel):
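    """Video-CCAM multimodal chat model.

    Wraps a SigLIP vision encoder, a projector loaded via remote code, and a
    causal language model; the sub-modules are loaded from the paths recorded
    in `VideoCCAMConfig`, and this class handles prompt assembly and
    generation.
    """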
    config_class = VideoCCAMConfig
    _auto_class = 'AutoModel'
    supports_gradient_checkpointing = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(self, config, device_map: str = 'auto'):
        super().__init__(config)
        self.image_token = config.image_token
        self.video_token = config.video_token
        self.vision_select_layer = config.vision_select_layer
        self.vision_max_chunk_size = config.vision_max_chunk_size
        self.gradient_checkpointing = False

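        # Load the sub-modules from the paths recorded in the config. The
        # projector ships its own modeling code (trust_remote_code) and falls
        # back to SDPA when flash_attention_2 is requested for the other parts.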
        self.projector = AutoModel.from_pretrained(
            config.projector_name_or_path,
            device_map=device_map,
            trust_remote_code=True,
            torch_dtype=config.torch_dtype,
            attn_implementation='sdpa' if config._attn_implementation == 'flash_attention_2' else config._attn_implementation
        )
        self.llm = AutoModelForCausalLM.from_pretrained(
            config.llm_name_or_path,
            device_map=device_map,
            torch_dtype=config.torch_dtype,
            attn_implementation=config._attn_implementation
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            config.llm_name_or_path,
            additional_special_tokens=[self.image_token, self.video_token]
        )
        self.generation_config = GenerationConfig.from_pretrained(config.llm_name_or_path)
        self.image_token_id, self.video_token_id = self.tokenizer.convert_tokens_to_ids([self.image_token, self.video_token])
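        # SigLIP vision tower and its image processor.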
        self.vision_encoder = SiglipVisionModel.from_pretrained(
            config.vision_encoder_name_or_path,
            device_map=device_map,
            torch_dtype=config.torch_dtype,
            attn_implementation=config._attn_implementation
        )
        self.image_processor = SiglipImageProcessor.from_pretrained(
            config.vision_encoder_name_or_path
        )

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        if gradient_checkpointing_kwargs is None:
            gradient_checkpointing_kwargs = dict(use_reentrant=False)
        self.llm.gradient_checkpointing_enable(gradient_checkpointing_kwargs)
        self.vision_encoder.gradient_checkpointing_enable(gradient_checkpointing_kwargs)

    def forward_visual_embeds(self, pixel_values: torch.Tensor) -> torch.Tensor:
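        """Return patch features from the layer selected by `vision_select_layer`;
        -1 or the final layer index maps to `last_hidden_state`, which avoids
        materialising every hidden state."""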
        if self.vision_select_layer in {-1, self.vision_encoder.config.num_hidden_layers}:
            visual_embeds = self.vision_encoder(pixel_values, output_hidden_states=False).last_hidden_state
        else:
            visual_embeds = self.vision_encoder(pixel_values, output_hidden_states=True).hidden_states[self.vision_select_layer]
        return visual_embeds

    @torch.inference_mode()
    def chat(
        self,
        messages: list[list[dict]],
        images: list[Image.Image | list[Image.Image]] = None,
        generation_config=None,
        batch_generate: bool = False,
        visual_embeds: torch.Tensor = None,
        return_visual_embeds: bool = False,
        **kwargs
    ):
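        """Generate replies for a batch of conversations.

        `messages` holds one conversation per sample in chat-template format.
        Each entry of `images` is either a single PIL image or a list of video
        frames; its features replace the matching `image_token`/`video_token`
        placeholder in the prompt. Precomputed `visual_embeds` skip the vision
        encoder, and `return_visual_embeds=True` also returns them so they can
        be reused in follow-up turns.
        """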
        if generation_config is None:
            generation_config = self.generation_config

        if visual_embeds is None:
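            # Flatten the inputs: a single image contributes one frame, a video
            # (list of frames) contributes len(frames); split_size records how
            # many frames belong to each sample so the features can be
            # regrouped after encoding.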
            _images, split_size = [], []
            for i in images:
                if isinstance(i, Image.Image):
                    _images.append(i)
                    split_size.append(1)
                else:
                    _images += i
                    split_size.append(len(i))
            pixel_values = self.image_processor(
                _images,
                return_tensors='pt'
            )['pixel_values'].to(
                dtype=self.vision_encoder.get_input_embeddings().weight.dtype,
                device=self.vision_encoder.get_input_embeddings().weight.device
            )
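            # Run the vision encoder, optionally in chunks of at most
            # vision_max_chunk_size frames to bound peak memory.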
            if 0 < self.vision_max_chunk_size < len(pixel_values):
                split_idx = list(range(0, len(pixel_values), self.vision_max_chunk_size)) + [len(pixel_values)]
                visual_embeds = torch.cat([
                    self.forward_visual_embeds(pixel_values[le:ri])
                    for le, ri in zip(split_idx[:-1], split_idx[1:])
                ], dim=0)
            else:
                visual_embeds = self.forward_visual_embeds(pixel_values)
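            # Regroup frame features per sample and project them into the LLM
            # embedding space; the projector receives one tensor per sample.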
            visual_embeds = self.projector(visual_embeds.split(split_size, dim=0))

        device = self.llm.get_input_embeddings().weight.device
        input_ids = self.tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True)
        _input_ids, split_idx = [], [0]
        for i in input_ids:
            _input_ids += i
            split_idx.append(split_idx[-1] + len(i))
        _input_ids = torch.tensor(_input_ids, dtype=torch.long, device=device)
        visual_idx = torch.where((_input_ids == self.image_token_id) | (_input_ids == self.video_token_id))[0].tolist()
        assert len(visual_idx) == len(visual_embeds), f'The number of visual tokens ({len(visual_idx)}) should be equal to the number of visual features ({len(visual_embeds)}).'

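        # Splice visual features into the text embeddings: each image/video
        # placeholder token is replaced by its visual embedding sequence, with
        # split_idx marking the conversation boundaries.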
        _input_ids[visual_idx] = 0
        _inputs_embeds = self.llm.get_input_embeddings()(_input_ids)
        inputs_embeds, cur_visual_pointer = [], 0
        for start_idx, end_idx in zip(split_idx[:-1], split_idx[1:]):
            if cur_visual_pointer < len(visual_idx) and visual_idx[cur_visual_pointer] < end_idx:
                mid_idx = visual_idx[cur_visual_pointer]
                embeds = [_inputs_embeds[start_idx:mid_idx], visual_embeds[cur_visual_pointer]]
                cur_visual_pointer += 1
                while cur_visual_pointer < len(visual_idx) and visual_idx[cur_visual_pointer] < end_idx:
                    embeds += [_inputs_embeds[mid_idx+1:visual_idx[cur_visual_pointer]], visual_embeds[cur_visual_pointer]]
                    mid_idx = visual_idx[cur_visual_pointer]
                    cur_visual_pointer += 1
                embeds.append(_inputs_embeds[mid_idx+1:end_idx])
                inputs_embeds.append(torch.cat(embeds, dim=0))
            else:
                inputs_embeds.append(_inputs_embeds[start_idx:end_idx])

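        # Batched generation left-pads every prompt with the pad-token
        # embedding so all sequences end at the same position; otherwise the
        # conversations are generated one at a time.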
        if batch_generate:
            B, L = len(inputs_embeds), max(i.size(0) for i in inputs_embeds)
            pad_embeds = self.llm.get_input_embeddings()(
                torch.tensor([self.tokenizer.pad_token_id], dtype=torch.long, device=device)
            )
            inputs_embeds_list = []
            attention_mask = torch.zeros(B, L, dtype=torch.long, device=device)
            for i, embeds in enumerate(inputs_embeds):
                l = embeds.size(0)
                inputs_embeds_list += [pad_embeds.expand(L - l, -1), embeds]
                attention_mask[i, -l:] = 1
            inputs_embeds = torch.cat(inputs_embeds_list, dim=0).view(B, L, -1)
            output_ids = self.llm.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                generation_config=generation_config,
                **kwargs
            )
        else:
            output_ids = []
            for embeds in inputs_embeds:
                output_ids.append(self.llm.generate(
                    inputs_embeds=embeds[None],
                    attention_mask=torch.ones(1, embeds.size(0), dtype=torch.long, device=device),
                    generation_config=generation_config,
                    **kwargs
                )[0])
        prediction = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)

        if return_visual_embeds:
            return prediction, visual_embeds
        else:
            return prediction

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        *args,
        config: VideoCCAMConfig = None,
        torch_dtype: torch.dtype = torch.bfloat16,
        device_map: str = 'auto',
        **kwargs
    ) -> PreTrainedModel:
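        """Build a VideoCCAM model from a checkpoint directory.

        The directory must contain a `projector` folder and may contain `llm`,
        `vision_encoder`, `llm_adapter` and `vision_encoder_adapter` folders.
        LoRA adapters are merged into their base modules unless
        `merge_pretrained_lora=False` is passed.
        """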
        merge_pretrained_lora = kwargs.pop('merge_pretrained_lora', True)

        # Fall back to the config stored in the checkpoint when none is given.
        if config is None:
            config = VideoCCAMConfig.from_pretrained(pretrained_model_name_or_path)
        config.torch_dtype = torch_dtype
        config.projector_name_or_path = osp.join(pretrained_model_name_or_path, 'projector')
        if osp.isdir(cur_path := osp.join(pretrained_model_name_or_path, 'llm')):
            config.llm_name_or_path = cur_path
        if osp.isdir(cur_path := osp.join(pretrained_model_name_or_path, 'vision_encoder')):
            config.vision_encoder_name_or_path = cur_path
        model = cls(config, device_map)

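        # Optionally load LoRA adapters saved alongside the checkpoint and
        # merge them into the base LLM / vision encoder.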
        if osp.exists(cur_path := osp.join(pretrained_model_name_or_path, 'llm_adapter')):
            model.llm = PeftModel.from_pretrained(model.llm, cur_path, device_map=device_map)
            print(f'Load LLM adapter from {cur_path}.')
            if merge_pretrained_lora:
                model.llm = model.llm.merge_and_unload()
        if osp.exists(cur_path := osp.join(pretrained_model_name_or_path, 'vision_encoder_adapter')):
            model.vision_encoder = PeftModel.from_pretrained(model.vision_encoder, cur_path, device_map=device_map)
            print(f'Load vision encoder adapter from {cur_path}.')
            if merge_pretrained_lora:
                model.vision_encoder = model.vision_encoder.merge_and_unload()

        return model
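
# Minimal usage sketch (illustrative; the checkpoint directory and image path
# are placeholders), assuming the checkpoint registers this class for AutoModel
# via trust_remote_code:
#
#   from PIL import Image
#   from transformers import AutoModel
#
#   model = AutoModel.from_pretrained('<checkpoint_dir>', trust_remote_code=True)
#   reply = model.chat(
#       messages=[[{'role': 'user', 'content': f'{model.image_token}\nDescribe the image.'}]],
#       images=[Image.open('<image_path>')],
#   )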