# --------------------------------------------------------
# InternVL
# Copyright (c) 2024 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import warnings
from typing import List, Optional, Tuple, Union
import torch.utils.checkpoint
import transformers
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import (AutoModel, GenerationConfig, LlamaForCausalLM,
Qwen2ForCausalLM)
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import ModelOutput, logging
from transformers import WhisperConfig, WhisperModel, WhisperProcessor
from .configuration_internvl_chat import InternVLChatConfig
from .conversation import get_conv_template
from .modeling_intern_vit import InternVisionModel, has_flash_attn
# Module-level logger obtained through transformers' logging utilities (used for info/debug output below).
logger = logging.get_logger(__name__)
def version_cmp(v1, v2, op='eq'):
    """Compare two version strings using a named function from the ``operator`` module.

    Args:
        v1: left-hand version string, parsed with ``packaging.version.parse``.
        v2: right-hand version string, parsed the same way.
        op: name of an ``operator`` comparison function such as ``'eq'``, ``'ge'`` or ``'lt'``.

    Returns:
        The result of ``operator.<op>(parse(v1), parse(v2))``.
    """
    import operator
    from packaging import version

    compare = getattr(operator, op)
    left, right = (version.parse(v) for v in (v1, v2))
    return compare(left, right)
class InternVLChatModel(PreTrainedModel):
config_class = InternVLChatConfig
main_input_name = 'pixel_values'
base_model_prefix = 'language_model'
_supports_flash_attn_2 = True
_no_split_modules = ['InternVisionModel', 'LlamaDecoderLayer', 'Qwen2DecoderLayer']
def __init__(self, config: InternVLChatConfig, vision_model=None, language_model=None, use_flash_attn=True,
             audio_model_path=None):
    """Build the vision tower, language model, Whisper audio encoder and their projectors.

    Args:
        config: composite config carrying ``vision_config``, ``llm_config`` and ``audio_config``.
        vision_model: optional pre-built vision tower; otherwise ``InternVisionModel`` is
            built from ``config.vision_config``.
        language_model: optional pre-built LLM; otherwise a ``LlamaForCausalLM`` or
            ``Qwen2ForCausalLM`` is built from ``config.llm_config`` according to its first
            listed architecture.
        use_flash_attn: request flash attention. It is downgraded to eager attention when
            ``has_flash_attn`` is false.
        audio_model_path: local directory or hub id of the Whisper checkpoint (model and
            processor). When ``None``, ``config.audio_model_path`` is used if present, else
            the historical hard-coded location.

    Raises:
        NotImplementedError: if the LLM architecture is neither Llama nor Qwen2.
    """
    super().__init__(config)

    assert version_cmp(transformers.__version__, '4.37.0', 'ge')
    image_size = config.force_image_size or config.vision_config.image_size
    patch_size = config.vision_config.patch_size
    self.patch_size = patch_size
    self.select_layer = config.select_layer
    self.template = config.template
    # Tokens per image tile after pixel shuffle: (grid side)^2 scaled by downsample_ratio^2.
    self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio ** 2))
    self.downsample_ratio = config.downsample_ratio
    self.ps_version = config.ps_version
    use_flash_attn = use_flash_attn if has_flash_attn else False
    config.vision_config.use_flash_attn = bool(use_flash_attn)
    config.llm_config._attn_implementation = 'flash_attention_2' if use_flash_attn else 'eager'

    logger.info(f'num_image_token: {self.num_image_token}')
    logger.info(f'ps_version: {self.ps_version}')

    if vision_model is not None:
        self.vision_model = vision_model
    else:
        self.vision_model = InternVisionModel(config.vision_config)

    if language_model is not None:
        self.language_model = language_model
    else:
        architecture = config.llm_config.architectures[0]
        if architecture == 'LlamaForCausalLM':
            self.language_model = LlamaForCausalLM(config.llm_config)
        elif architecture == 'Qwen2ForCausalLM':
            self.language_model = Qwen2ForCausalLM(config.llm_config)
        else:
            raise NotImplementedError(f'{architecture} is not implemented.')

    # Whisper audio encoder. The checkpoint location is configurable instead of being
    # pinned to one developer machine; the old path stays as the last-resort default.
    if audio_model_path is None:
        audio_model_path = getattr(config, 'audio_model_path', None) or (
            "/data/nvme5n1p1/vladimir_workspace/audio_internvl/models/whisper-large-v3-turbo")
    whisper_config = WhisperConfig(**self.config.audio_config)
    self.audio_model = WhisperModel.from_pretrained(
        audio_model_path,
        config=whisper_config,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    )
    # Only the encoder is used; drop the decoder to save memory.
    del self.audio_model.decoder
    # Turns raw waveforms into encoder input features.
    self.audio_processor = WhisperProcessor.from_pretrained(audio_model_path)

    vit_hidden_size = config.vision_config.hidden_size
    llm_hidden_size = config.llm_config.hidden_size
    whisper_hidden_size = self.audio_model.config.d_model

    # Vision projector: its input width accounts for channels multiplied by pixel shuffle.
    self.mlp1 = nn.Sequential(
        nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
        nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
        nn.GELU(),
        nn.Linear(llm_hidden_size, llm_hidden_size)
    )
    # Audio projector: Whisper d_model -> LLM hidden size.
    self.mlp2 = nn.Sequential(
        nn.LayerNorm(whisper_hidden_size),
        nn.Linear(whisper_hidden_size, llm_hidden_size),
        nn.GELU(),
        nn.Linear(llm_hidden_size, llm_hidden_size)
    )

    # Placeholder token ids are resolved later from the tokenizer.
    self.audio_context_token_id = None
    self.img_context_token_id = None
    self.conv_template = get_conv_template(self.template)
    self.system_message = self.conv_template.system_message
def process_audio_feature(self, audio_values, audio_flags, target_length=300):
    """Encode raw audio with Whisper and project it into the LLM embedding space.

    The steps are:
    1. ``self.audio_processor`` turns the waveform(s) into ``input_features``.
    2. The Whisper encoder runs with CUDA autocast disabled.
    3. The time axis is average-pooled to ``target_length`` frames.
    4. The frames are standardised per token, clipped and layer-normed, then
       standardised again, scaled by 0.05 and projected by ``self.mlp2``.
    5. The projected tokens are standardised once more and scaled by 0.009.

    Args:
        audio_values: waveform tensor. A 2-D tensor is split into one clip per row;
            any other rank is handed to the processor as a single clip. It is passed
            with ``sampling_rate=16000``, so the audio is assumed to be 16 kHz
            (TODO confirm against the data pipeline).
        audio_flags: accepted for symmetry with ``image_flags``. It is not used here;
            only its shape is logged at debug level.
        target_length: number of audio tokens produced per clip (default 300). It
            should match the number of audio context tokens reserved in the prompt.

    Returns:
        float32 tensor of shape ``(num_clips, target_length, llm_hidden_size)``.
    """
    import logging as _std_logging  # stdlib level constants for isEnabledFor

    # min/max statistics force a device sync, so only compute them when debugging.
    debug = logger.isEnabledFor(_std_logging.DEBUG)

    def _standardize(t):
        # Zero-mean / unit-std along the feature dim; std is clamped to avoid division by zero.
        mean = t.mean(dim=-1, keepdim=True)
        std = torch.clamp(t.std(dim=-1, keepdim=True), min=1e-6)
        return (t - mean) / std

    audio_values = audio_values.to(torch.float32)
    if debug:
        logger.debug(f'Input audio shape: {audio_values.shape}, '
                     f'audio flags shape: {getattr(audio_flags, "shape", None)}')
        logger.debug(f'Audio values min/max: {audio_values.min():.3f}/{audio_values.max():.3f}')

    # The processor is fed a list of numpy arrays, one per clip.
    if audio_values.dim() == 2:
        audio_list = [clip.cpu().numpy() for clip in audio_values]
    else:
        audio_list = [audio_values.cpu().numpy()]
    processed_audio = self.audio_processor(audio_list, sampling_rate=16000, return_tensors="pt")

    # Feed the encoder in its own precision and device. Its weights are loaded in
    # float16, so forcing float32 inputs fails with a dtype mismatch unless the whole
    # model was cast to float32 beforehand. When the encoder is float32 this is a no-op.
    encoder_param = next(self.audio_model.encoder.parameters())
    audio_features = processed_audio["input_features"].to(
        device=encoder_param.device, dtype=encoder_param.dtype)
    if debug:
        logger.debug(f'Processed audio features shape: {audio_features.shape}')

    with torch.autocast(device_type='cuda', enabled=False):
        audio_embeds = self.audio_model.encoder(audio_features).last_hidden_state
    audio_embeds = audio_embeds.float()
    if debug:
        logger.debug(f'Whisper encoder output shape: {audio_embeds.shape}, '
                     f'min/max: {audio_embeds.min():.3f}/{audio_embeds.max():.3f}')

    # Resample the encoder's time axis to exactly `target_length` tokens.
    audio_embeds = nn.functional.adaptive_avg_pool1d(audio_embeds.transpose(1, 2), target_length)
    audio_embeds = audio_embeds.transpose(1, 2)  # (num_clips, target_length, d_model)

    # Pre-projection normalisation: standardise, clip to [-2, 2], then an affine-free
    # layer norm with eps=1e-4 (unit weight / zero bias, so no parameters are needed).
    audio_embeds = torch.clamp(_standardize(audio_embeds), -2.0, 2.0)
    audio_embeds = nn.functional.layer_norm(audio_embeds, (audio_embeds.shape[-1],), eps=1e-4)
    if debug:
        logger.debug(f'Pre-MLP2 stats - min: {audio_embeds.min():.3f}, max: {audio_embeds.max():.3f}')

    with torch.autocast(device_type='cuda', enabled=False):
        audio_embeds = _standardize(audio_embeds) * 0.05
        # Match mlp2's dtype/device (it may have been cast to half precision with the model).
        mlp2_param = next(self.mlp2.parameters())
        audio_embeds = self.mlp2(audio_embeds.to(device=mlp2_param.device, dtype=mlp2_param.dtype)).float()

    if not torch.isfinite(audio_embeds).all():
        logger.warning("NaN/Inf detected after MLP2! Using robust recovery...")
        audio_embeds = torch.nan_to_num(audio_embeds, nan=0.0, posinf=1.0, neginf=-1.0)

    # Standardise the projected tokens, then rescale them to a fixed std of 0.009
    # (presumably the std of the LLM token embeddings — TODO confirm).
    audio_embeds = _standardize(audio_embeds)
    if self.training:
        # Small jitter only while training, so inference stays deterministic.
        audio_embeds = audio_embeds + torch.randn_like(audio_embeds) * 0.0001
    llm_std = 0.009
    return audio_embeds * llm_std
def forward(
    self,
    pixel_values: torch.FloatTensor = None,
    audio_values: torch.FloatTensor = None,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    image_flags: Optional[torch.LongTensor] = None,
    audio_flags: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
    """Run the language model on text embeddings with image/audio features spliced in.

    Image features from ``extract_feature`` overwrite the positions whose token id equals
    ``self.img_context_token_id``. Audio features from ``process_audio_feature`` overwrite
    positions equal to ``self.audio_context_token_id``.

    Args:
        pixel_values: image patches, or ``None`` for no images.
        audio_values: raw audio; only used when ``audio_flags`` is also given.
        input_ids: token ids of shape ``(B, N)``.
        image_flags: per-patch flags; patches whose flag is not 1 are dropped. When
            omitted, every patch is kept.
        audio_flags: must be non-None for audio to be used; forwarded to
            ``process_audio_feature``.
        labels: target ids. When given, a shifted next-token cross-entropy loss is computed.
        attention_mask, position_ids, past_key_values, use_cache, output_attentions,
        output_hidden_states: forwarded unchanged to the language model.
        return_dict: defaults to ``self.config.use_return_dict``.

    Returns:
        ``CausalLMOutputWithPast`` when ``return_dict`` is true. Otherwise a tuple
        ``(logits, *rest)``, prefixed with the loss when ``labels`` were given.
    """
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    input_embeds = self.language_model.get_input_embeddings()(input_ids).clone()
    B, N, C = input_embeds.shape
    input_embeds = input_embeds.reshape(B * N, C)
    input_ids = input_ids.reshape(B * N)

    def _splice(embeds, selected, features, kind):
        # Overwrite the rows of `embeds` picked by the boolean mask `selected` with `features`.
        # Index assignment requires identical dtypes (audio features come back as float32),
        # so cast to the LLM embedding dtype/device first.
        features = features.reshape(-1, C).to(device=embeds.device, dtype=embeds.dtype)
        try:
            # `* 0.0 +` keeps the text embeddings in the autograd graph (with zero gradient).
            embeds[selected] = embeds[selected] * 0.0 + features
        except RuntimeError as err:
            # More features than placeholder tokens: keep only the leading features.
            n_token = int(selected.sum())
            logger.warning(f'{kind}: {features.shape[0]} features for {n_token} context tokens '
                           f'({err}); truncating features.')
            embeds[selected] = embeds[selected] * 0.0 + features[:n_token]

    if pixel_values is not None:
        vit_embeds = self.extract_feature(pixel_values)
        if image_flags is not None:
            image_flags = image_flags.squeeze(-1)
            vit_embeds = vit_embeds[image_flags == 1]
        vit_batch_size = pixel_values.shape[0]
        if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
            print(f'dynamic ViT batch size: {vit_batch_size}, images per sample: {vit_batch_size / B}, dynamic token length: {N}')
        _splice(input_embeds, input_ids == self.img_context_token_id, vit_embeds, 'image')

    if audio_values is not None and audio_flags is not None:
        audio_flags = audio_flags.squeeze(-1)
        audio_embeds = self.process_audio_feature(audio_values, audio_flags)
        audio_batch_size = audio_values.shape[0] if len(audio_values.shape) > 1 else 1
        if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
            print(f'dynamic Audio batch size: {audio_batch_size}, audio per sample: {audio_batch_size / B}, dynamic token length: {N}')
        _splice(input_embeds, input_ids == self.audio_context_token_id, audio_embeds, 'audio')

    input_embeds = input_embeds.reshape(B, N, C)

    outputs = self.language_model(
        inputs_embeds=input_embeds,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    # With return_dict=False the language model returns a plain tuple (no `.logits`);
    # no labels were passed to it, so the logits are its first element.
    logits = outputs.logits if return_dict else outputs[0]

    loss = None
    if labels is not None:
        # Shift so that tokens < n predict token n.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        loss_fct = CrossEntropyLoss()
        shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
        shift_labels = shift_labels.view(-1)
        shift_labels = shift_labels.to(shift_logits.device)
        loss = loss_fct(shift_logits, shift_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
def pixel_shuffle(self, x, scale_factor=0.5):
    """Fold spatial positions into channels (space-to-depth) on a 4-D feature map.

    Args:
        x: tensor laid out as ``(N, W, H, C)``. ``W`` and ``H`` are assumed to be
            divisible by ``1 / scale_factor``; otherwise ``view`` fails.
        scale_factor: spatial shrink factor ``s``. For example, 0.5 halves each side
            and quadruples the channels.

    Returns:
        ``(N, W*s, H*s, C/s**2)`` unless ``self.ps_version == 'v1'``. In that case the
        spatial axes stay swapped, giving ``(N, H*s, W*s, C/s**2)``, and a warning is issued.
    """
    n, w, h, c = x.size()
    new_w = int(w * scale_factor)
    new_h = int(h * scale_factor)
    # Merge neighbouring H positions into the channel axis: (N, W, H*s, C/s).
    out = x.view(n, w, new_h, int(c / scale_factor))
    # Bring H*s forward so that W can be folded next: (N, H*s, W, C/s).
    out = out.permute(0, 2, 1, 3).contiguous()
    # Merge neighbouring W positions into channels: (N, H*s, W*s, C/s**2).
    out = out.view(n, new_h, new_w, int(c / (scale_factor * scale_factor)))
    if self.ps_version != 'v1':
        # Swap the spatial axes back: (N, W*s, H*s, C/s**2).
        return out.permute(0, 2, 1, 3).contiguous()
    warnings.warn("In ps_version 'v1', the height and width have not been swapped back, "
                  'which results in a transposed image.')
    return out
def extract_feature(self, pixel_values):
    """Encode image patches with the vision tower and project them to the LLM width.

    The vision tower's last hidden state is used when ``self.select_layer == -1``;
    otherwise ``hidden_states[self.select_layer]`` is used. The first token of each
    sequence is dropped, and the remaining tokens are assumed to form a square grid
    (side ``int(sqrt(len))``). That grid is pixel-shuffled with
    ``self.downsample_ratio``, flattened and passed through ``self.mlp1``.

    Args:
        pixel_values: image tensor accepted by ``self.vision_model``.

    Returns:
        Tensor of shape ``(num_images, tokens_after_shuffle, llm_hidden_size)``.
    """
    use_intermediate = self.select_layer != -1
    vision_out = self.vision_model(
        pixel_values=pixel_values,
        output_hidden_states=use_intermediate,
        return_dict=True)
    if use_intermediate:
        feats = vision_out.hidden_states[self.select_layer]
    else:
        feats = vision_out.last_hidden_state

    # Drop the leading token (presumably a class token — confirm) and lay the rest out as a grid.
    feats = feats[:, 1:, :]
    side = int(feats.shape[1] ** 0.5)
    num_images = feats.shape[0]
    feats = feats.reshape(num_images, side, side, -1)

    feats = self.pixel_shuffle(feats, scale_factor=self.downsample_ratio)
    feats = feats.reshape(feats.shape[0], -1, feats.shape[-1])
    return self.mlp1(feats)
def batch_chat(self, tokenizer, pixel_values, questions, generation_config, num_patches_list=None,
               history=None, return_history=False, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>',
               IMG_CONTEXT_TOKEN='<IMG_CONTEXT>', verbose=False, image_counts=None):
    """Answer a batch of single-turn questions, optionally grounded on images.

    Each question is wrapped in the conversation template. Its ``<image>`` placeholder is
    expanded to ``IMG_START_TOKEN + IMG_CONTEXT_TOKEN * num_image_token * num_patches +
    IMG_END_TOKEN``, and the whole batch is left-padded and passed to ``self.generate``.

    Args:
        tokenizer: tokenizer whose vocabulary contains the special image tokens.
        pixel_values: stacked image patches for the whole batch, or ``None`` for text-only.
        questions: one prompt per sample. When images are given and a question has no
            ``<image>`` placeholder, one is prepended.
        generation_config: dict of keyword arguments for ``self.generate``. It is copied,
            not mutated; ``eos_token_id`` is set on the copy.
        num_patches_list: number of image patches per question, in ``questions`` order.
            It may be omitted only for text-only batches (``pixel_values is None``).
        history, return_history: multi-turn chat is unsupported; passing either raises.
        IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN: special tokens around and
            inside the image span.
        verbose: print the ViT batch size.
        image_counts: deprecated alias of ``num_patches_list``.

    Returns:
        List of decoded responses, one per question, cut at the template separator.

    Raises:
        NotImplementedError: if ``history`` is given or ``return_history`` is true.
        ValueError: if ``pixel_values`` is given without ``num_patches_list``.
    """
    if history is not None or return_history:
        print('Now multi-turn chat is not supported in batch_chat.')
        raise NotImplementedError('Now multi-turn chat is not supported in batch_chat.')
    if image_counts is not None:
        num_patches_list = image_counts
        print('Warning: `image_counts` is deprecated. Please use `num_patches_list` instead.')
    if num_patches_list is None:
        if pixel_values is not None:
            raise ValueError('`num_patches_list` is required when `pixel_values` is given.')
        num_patches_list = [0] * len(questions)

    img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
    self.img_context_token_id = img_context_token_id

    if verbose and pixel_values is not None:
        image_bs = pixel_values.shape[0]
        print(f'dynamic ViT batch size: {image_bs}')

    queries = []
    for idx, num_patches in enumerate(num_patches_list):
        question = questions[idx]
        if pixel_values is not None and '<image>' not in question:
            question = '<image>\n' + question
        template = get_conv_template(self.template)
        template.system_message = self.system_message
        template.append_message(template.roles[0], question)
        template.append_message(template.roles[1], None)
        query = template.get_prompt()

        image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
        query = query.replace('<image>', image_tokens, 1)
        queries.append(query)

    # Decoder-only generation needs left padding.
    tokenizer.padding_side = 'left'
    model_inputs = tokenizer(queries, return_tensors='pt', padding=True)
    input_ids = model_inputs['input_ids'].to(self.device)
    attention_mask = model_inputs['attention_mask'].to(self.device)

    # The template separator doubles as the stop token. Take it from a fresh template
    # instead of relying on the loop variable.
    separator = get_conv_template(self.template).sep.strip()
    eos_token_id = tokenizer.convert_tokens_to_ids(separator)
    generation_config = dict(generation_config, eos_token_id=eos_token_id)

    generation_output = self.generate(
        pixel_values=pixel_values,
        input_ids=input_ids,
        attention_mask=attention_mask,
        **generation_config
    )
    responses = tokenizer.batch_decode(generation_output, skip_special_tokens=True)
    return [response.split(separator)[0].strip() for response in responses]
def chat(self, tokenizer, pixel_values=None, question=None, generation_config=None,
history=None, return_history=False, num_patches_list=None,
IMG_START_TOKEN='', IMG_END_TOKEN='', IMG_CONTEXT_TOKEN='',
AUDIO_START_TOKEN='', AUDIO_CONTEXT_TOKEN='',
verbose=False, **kwargs): # Add **kwargs to catch extra arguments
"""Chat function that handles both text-only and multimodal inputs"""
print("=== Starting Chat Process ===")
print(f"Question: {question}")
print(f"Input types - Pixel values: {type(pixel_values)}, Audio values: {type(kwargs.get('audio_values'))}")
# Basic input validation
if question is None:
raise ValueError("Question cannot be None")
if not isinstance(question, str):
raise ValueError(f"Question must be string, got {type(question)}")
audio_values = kwargs.get('audio_values', None)
# Handle image prompt
if history is None and pixel_values is not None and '' not in question:
question = '\n' + question
print("Added image token to question")
# Handle audio prompt
if history is None and audio_values is not None and '