from collections import OrderedDict

import torch
import torch.nn as nn
from mmengine.config import Config, ConfigDict
from mmengine.model import BaseModel
from peft import get_peft_model, prepare_model_for_kbit_training

from xtuner.registry import BUILDER
from .modules import ProjectorConfig, ProjectorModel, dispatch_modules
from .utils import (LoadWoInit, find_all_linear_names,
                    get_peft_model_state_dict, guess_load_checkpoint,
                    make_inputs_require_grad,
                    prepare_inputs_labels_for_multimodal, traverse_dict)


class LLaVAModel(BaseModel):
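    """LLaVA-style multimodal model that concatenates SigLIP and DINO visual
    features and feeds them to a language model through an MLP projector.

    Args:
        llm: Config dict or ``nn.Module`` of the language model.
        siglip: Config dict or ``nn.Module`` of the SigLIP vision encoder.
        dino: Config dict or ``nn.Module`` of the DINO vision encoder.
        freeze_llm (bool): Freeze the LLM weights. Defaults to False.
        freeze_visual_encoder (bool): Freeze both vision encoders.
            Defaults to False.
        visual_select_layer (int): Index of the SigLIP hidden state used as
            visual features. Defaults to -2.
        pretrained_pth (str): Optional checkpoint path loaded non-strictly
            into the whole model. Defaults to None.
        projector_depth (int): Depth of the projector MLP. Defaults to 2.
        llm_lora (dict): Optional LoRA config for the LLM. Defaults to None.
        visual_encoder_lora (dict): Optional LoRA config for both vision
            encoders. Defaults to None.
        use_activation_checkpointing (bool): Enable gradient (activation)
            checkpointing. Defaults to True.

    Example (illustrative sketch only; the model paths below are placeholders,
    and the configs follow the ``type=...from_pretrained`` pattern commonly
    used with the xtuner ``BUILDER``):

        >>> from transformers import (AutoModel, AutoModelForCausalLM,
        ...                           SiglipVisionModel)
        >>> model = LLaVAModel(
        ...     llm=dict(
        ...         type=AutoModelForCausalLM.from_pretrained,
        ...         pretrained_model_name_or_path='path/to/llm'),
        ...     siglip=dict(
        ...         type=SiglipVisionModel.from_pretrained,
        ...         pretrained_model_name_or_path='path/to/siglip'),
        ...     dino=dict(
        ...         type=AutoModel.from_pretrained,
        ...         pretrained_model_name_or_path='path/to/dino'))
    """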

    def __init__(self,
                 llm,
                 siglip,
                 dino,
                 freeze_llm=False,
                 freeze_visual_encoder=False,
                 visual_select_layer=-2,
                 pretrained_pth=None,
                 projector_depth=2,
                 llm_lora=None,
                 visual_encoder_lora=None,
                 use_activation_checkpointing=True):
        super().__init__()
        self.freeze_llm = freeze_llm
        self.freeze_visual_encoder = freeze_visual_encoder
        with LoadWoInit():
            self.llm = self._build_from_cfg_or_module(llm)
            self.siglip = self._build_from_cfg_or_module(siglip)
            self.dino = self._build_from_cfg_or_module(dino)
        self.llm.config.use_cache = False
        dispatch_modules(self.llm)

        # The projector consumes the concatenated SigLIP and DINO features,
        # so its input width is the sum of both hidden sizes.
        projector_config = ProjectorConfig(
            visual_hidden_size=self.siglip.config.hidden_size +
            self.dino.config.hidden_size,
            llm_hidden_size=self.llm.config.hidden_size,
            depth=projector_depth)
        self.projector = ProjectorModel(projector_config).to(
            self.siglip.dtype)

        if self.freeze_llm:
            self.llm.requires_grad_(False)
        if self.freeze_visual_encoder:
            self.siglip.requires_grad_(False)
            self.dino.requires_grad_(False)

        if use_activation_checkpointing:
            # Keep input embeddings' outputs requiring grad so activation
            # checkpointing can backpropagate through frozen towers.
            if hasattr(self.llm, 'enable_input_require_grads'):
                self.llm.enable_input_require_grads()
            else:
                self.llm.get_input_embeddings().register_forward_hook(
                    make_inputs_require_grad)
            if hasattr(self.siglip, 'enable_input_require_grads'):
                self.siglip.enable_input_require_grads()
            else:
                self.siglip.get_input_embeddings(
                ).register_forward_hook(make_inputs_require_grad)
            if hasattr(self.dino, 'enable_input_require_grads'):
                self.dino.enable_input_require_grads()
            else:
                self.dino.get_input_embeddings(
                ).register_forward_hook(make_inputs_require_grad)
            self.projector.enable_input_require_grads()

            # Enable gradient (activation) checkpointing for memory efficiency.
            self.gradient_checkpointing_enable()

        self.use_llm_lora = llm_lora is not None
        self.use_visual_encoder_lora = visual_encoder_lora is not None

        if self.use_llm_lora:
            self._prepare_llm_for_lora(llm_lora, use_activation_checkpointing)
        if self.use_visual_encoder_lora:
            self._prepare_visual_encoder_for_lora(
                visual_encoder_lora, use_activation_checkpointing)

        if pretrained_pth is not None:
            pretrained_state_dict = guess_load_checkpoint(pretrained_pth)
            self.load_state_dict(pretrained_state_dict, strict=False)
            print(f'Loaded pretrained weights from {pretrained_pth}')

        self.visual_select_layer = visual_select_layer

        self._is_init = True

    def _parse_lora_config(self, lora_config):
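        """Build a LoRA config from a dict/``Config``/``ConfigDict`` via the
        registry; objects that are already built are returned unchanged."""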
        if isinstance(lora_config, (dict, Config, ConfigDict)):
            lora_config = BUILDER.build(lora_config)
        return lora_config

    def _prepare_llm_for_lora(self,
                              lora_config,
                              use_activation_checkpointing=True):
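        """Prepare the LLM for k-bit training and wrap it with PEFT LoRA.

        If ``target_modules`` is not set in the config, all linear layers
        found in the LLM are used.
        """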
        lora_config = self._parse_lora_config(lora_config)
        self.llm = prepare_model_for_kbit_training(
            self.llm, use_activation_checkpointing)
        if lora_config.target_modules is None:
            modules = find_all_linear_names(self.llm)
            lora_config.target_modules = modules
        self.llm = get_peft_model(self.llm, lora_config)

    def _prepare_visual_encoder_for_lora(self,
                                         lora_config,
                                         use_activation_checkpointing=True):
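        """Wrap both vision encoders (SigLIP and DINO) with PEFT LoRA,
        targeting every linear layer found in each encoder."""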
        lora_config = self._parse_lora_config(lora_config)
        modules = find_all_linear_names(self.siglip)
        lora_config.target_modules = modules
        self.siglip = get_peft_model(self.siglip, lora_config)
        modules = find_all_linear_names(self.dino)
        lora_config.target_modules = modules
        self.dino = get_peft_model(self.dino, lora_config)

    def gradient_checkpointing_enable(self):
        self.activation_checkpointing_enable()

    def activation_checkpointing_enable(self):
        self.llm.gradient_checkpointing_enable()
        self.siglip.gradient_checkpointing_enable()
        self.dino.gradient_checkpointing_enable()
        self.projector.gradient_checkpointing_enable()

    def gradient_checkpointing_disable(self):
        self.activation_checkpointing_disable()

    def activation_checkpointing_disable(self):
        self.llm.gradient_checkpointing_disable()
        self.siglip.gradient_checkpointing_disable()
        self.dino.gradient_checkpointing_disable()
        self.projector.gradient_checkpointing_disable()

    def init_weights(self):
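        """Intentionally a no-op: submodules are built with their own
        (pretrained) weights in ``__init__``, so no re-initialization is
        performed here."""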
        pass

    def state_dict(self, *args, **kwargs):
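        """Return only the parameters worth saving: LoRA weights or full
        weights of unfrozen components, plus the projector."""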
        state_dict = super().state_dict(*args, **kwargs)
        to_return = OrderedDict()
        # Step 1. Vision encoders (SigLIP and DINO).
        if self.use_visual_encoder_lora:
            to_return.update(
                get_peft_model_state_dict(
                    self.siglip, state_dict=state_dict))
            to_return.update(
                get_peft_model_state_dict(
                    self.dino, state_dict=state_dict))
        elif not self.freeze_visual_encoder:
            to_return.update({
                k: v
                for k, v in state_dict.items() if 'siglip.' in k
            })
            to_return.update({
                k: v
                for k, v in state_dict.items() if 'dino.' in k
            })
        # Step 2. LLM.
        if self.use_llm_lora:
            to_return.update(
                get_peft_model_state_dict(self.llm, state_dict=state_dict))
        elif not self.freeze_llm:
            to_return.update(
                {k: v
                 for k, v in state_dict.items() if 'llm.' in k})
        # Step 3. Projector (always saved).
        to_return.update(
            {k: v
             for k, v in state_dict.items() if 'projector.' in k})
        return to_return

    def _build_from_cfg_or_module(self, cfg_or_mod):
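        """Return the argument as-is if it is already an ``nn.Module``;
        otherwise build it from a config dict through the registry."""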
        if isinstance(cfg_or_mod, nn.Module):
            return cfg_or_mod
        elif isinstance(cfg_or_mod, dict):
            traverse_dict(cfg_or_mod)
            return BUILDER.build(cfg_or_mod)
        else:
            raise NotImplementedError

    def forward(self, data, data_samples=None, mode='loss'):
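        """Run one step in ``loss``, ``predict`` or ``tensor`` mode.

        When ``pixel_values`` is present, SigLIP and DINO features are
        extracted, concatenated along the channel dimension, projected to
        the LLM hidden size, and merged with the text inputs.
        """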
        if 'pixel_values' in data:
            siglip_out = self.siglip(
                data['pixel_values'],
                output_hidden_states=True).hidden_states[
                    self.visual_select_layer]
            # Use the last DINO hidden state and drop the leading CLS token.
            dino_out = self.dino(
                data['pixel_values'],
                output_hidden_states=True).hidden_states[-1][:, 1:]
            visual_out = torch.cat((siglip_out, dino_out), dim=-1)
            pixel_values = self.projector(visual_out)
            data['pixel_values'] = pixel_values
            data = prepare_inputs_labels_for_multimodal(llm=self.llm, **data)

        if mode == 'loss':
            return self.compute_loss(data, data_samples)
        elif mode == 'predict':
            return self.predict(data, data_samples)
        elif mode == 'tensor':
            return self._forward(data, data_samples)
        else:
            raise NotImplementedError

    def _forward(self, data, data_samples=None):
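        """Forward the already-prepared inputs through the LLM and return
        the raw model outputs."""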
        outputs = self.llm(**data)
        return outputs

    def predict(self, data, data_samples=None):
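        """Return per-sample logits as a list of ``{'logits': ...}`` dicts."""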
        outputs = self.llm(**data)
        logits_dict = [{'logits': logits} for logits in outputs.logits]
        return logits_dict

    def compute_loss(self, data, data_samples=None):
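        """Return the loss computed by the LLM as ``{'loss': ...}``."""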
        outputs = self.llm(**data)
        loss_dict = {'loss': outputs.loss}
        return loss_dict

    def __getattr__(self, name: str):
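        """Fall back to the wrapped LLM for attributes not found on this
        module (e.g. generation-related methods)."""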
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.llm, name)