# Copyright (c) OpenMMLab. All rights reserved.
from collections import OrderedDict

import torch
import torch.nn as nn
from mmengine.config import Config, ConfigDict
from mmengine.model import BaseModel
from peft import get_peft_model, prepare_model_for_kbit_training

from xtuner.registry import BUILDER
from .modules import ProjectorConfig, ProjectorModel, dispatch_modules
from .utils import (LoadWoInit, find_all_linear_names,
                    get_peft_model_state_dict, guess_load_checkpoint,
                    make_inputs_require_grad,
                    prepare_inputs_labels_for_multimodal, traverse_dict)


class LLaVAModel(BaseModel):
    """LLaVA-style multimodal model with two vision towers (SigLIP and DINO).

    The features of both towers are concatenated along the channel dimension
    and mapped into the LLM embedding space by the projector.
    """

def __init__(self,
llm,
siglip,
dino,
freeze_llm=False,
freeze_visual_encoder=False,
visual_select_layer=-2,
pretrained_pth=None,
projector_depth=2,
llm_lora=None,
visual_encoder_lora=None,
use_activation_checkpointing=True):
super().__init__()
self.freeze_llm = freeze_llm
self.freeze_visual_encoder = freeze_visual_encoder
with LoadWoInit():
self.llm = self._build_from_cfg_or_module(llm)
self.siglip = self._build_from_cfg_or_module(siglip)
self.dino = self._build_from_cfg_or_module(dino)
self.llm.config.use_cache = False
dispatch_modules(self.llm)

        # The projector consumes the channel-wise concatenation of the SigLIP
        # and DINO features, so its input width is the sum of both hidden sizes.
        projector_config = ProjectorConfig(
            visual_hidden_size=self.siglip.config.hidden_size +
            self.dino.config.hidden_size,
            llm_hidden_size=self.llm.config.hidden_size,
            depth=projector_depth)
        self.projector = ProjectorModel(projector_config).to(
            self.siglip.dtype)
if self.freeze_llm:
self.llm.requires_grad_(False)
if self.freeze_visual_encoder:
self.siglip.requires_grad_(False)
self.dino.requires_grad_(False)
if use_activation_checkpointing:
# For backward compatibility
if hasattr(self.llm, 'enable_input_require_grads'):
self.llm.enable_input_require_grads()
else:
self.llm.get_input_embeddings().register_forward_hook(
make_inputs_require_grad)
if hasattr(self.siglip, 'enable_input_require_grads'):
self.siglip.enable_input_require_grads()
else:
self.siglip.get_input_embeddings(
).register_forward_hook(make_inputs_require_grad)
if hasattr(self.dino, 'enable_input_require_grads'):
self.dino.enable_input_require_grads()
else:
self.dino.get_input_embeddings(
).register_forward_hook(make_inputs_require_grad)
self.projector.enable_input_require_grads()
# enable gradient (activation) checkpointing for memory efficiency
self.gradient_checkpointing_enable()
self.use_llm_lora = llm_lora is not None
self.use_visual_encoder_lora = visual_encoder_lora is not None
if self.use_llm_lora:
self._prepare_llm_for_lora(llm_lora, use_activation_checkpointing)
if self.use_visual_encoder_lora:
self._prepare_visual_encoder_for_lora(
visual_encoder_lora, use_activation_checkpointing)

        if pretrained_pth is not None:
            pretrained_state_dict = guess_load_checkpoint(pretrained_pth)
            self.load_state_dict(pretrained_state_dict, strict=False)
            print(f'Loaded pretrained weights from {pretrained_pth}')
self.visual_select_layer = visual_select_layer
self._is_init = True

    def _parse_lora_config(self, lora_config):
        if isinstance(lora_config, (dict, Config, ConfigDict)):
            lora_config = BUILDER.build(lora_config)
        return lora_config

    def _prepare_llm_for_lora(self,
                              lora_config,
                              use_activation_checkpointing=True):
        lora_config = self._parse_lora_config(lora_config)
        self.llm = prepare_model_for_kbit_training(
            self.llm, use_activation_checkpointing)
        if lora_config.target_modules is None:
            # By default, attach LoRA to every linear layer of the LLM.
            modules = find_all_linear_names(self.llm)
            lora_config.target_modules = modules
        self.llm = get_peft_model(self.llm, lora_config)

    def _prepare_visual_encoder_for_lora(self,
                                         lora_config,
                                         use_activation_checkpointing=True):
        lora_config = self._parse_lora_config(lora_config)
        # Both vision towers share the same LoRA settings but receive their
        # own adapters, applied to all of their linear layers.
        modules = find_all_linear_names(self.siglip)
        lora_config.target_modules = modules
        self.siglip = get_peft_model(self.siglip, lora_config)
        modules = find_all_linear_names(self.dino)
        lora_config.target_modules = modules
        self.dino = get_peft_model(self.dino, lora_config)

    def gradient_checkpointing_enable(self):
        self.activation_checkpointing_enable()

    def activation_checkpointing_enable(self):
        self.llm.gradient_checkpointing_enable()
        self.siglip.gradient_checkpointing_enable()
        self.dino.gradient_checkpointing_enable()
        self.projector.gradient_checkpointing_enable()

    def gradient_checkpointing_disable(self):
        self.activation_checkpointing_disable()

    def activation_checkpointing_disable(self):
        self.llm.gradient_checkpointing_disable()
        self.siglip.gradient_checkpointing_disable()
        self.dino.gradient_checkpointing_disable()
        self.projector.gradient_checkpointing_disable()

    def init_weights(self):
        # All sub-modules are loaded from pretrained checkpoints, so the
        # default weight initialization is skipped.
        pass

    def state_dict(self, *args, **kwargs):
state_dict = super().state_dict(*args, **kwargs)
to_return = OrderedDict()
# Step 1. visual_encoder
if self.use_visual_encoder_lora:
to_return.update(
get_peft_model_state_dict(
self.siglip, state_dict=state_dict))
to_return.update(
get_peft_model_state_dict(
self.dino, state_dict=state_dict))
elif not self.freeze_visual_encoder:
to_return.update({
k: v
for k, v in state_dict.items() if 'siglip.' in k
})
to_return.update({
k: v
for k, v in state_dict.items() if 'dino.' in k
})
# Step 2. LLM
if self.use_llm_lora:
to_return.update(
get_peft_model_state_dict(self.llm, state_dict=state_dict))
elif not self.freeze_llm:
to_return.update(
{k: v
for k, v in state_dict.items() if 'llm.' in k})
# Step 3. Projector
to_return.update(
{k: v
for k, v in state_dict.items() if 'projector.' in k})
return to_return

    def _build_from_cfg_or_module(self, cfg_or_mod):
if isinstance(cfg_or_mod, nn.Module):
return cfg_or_mod
elif isinstance(cfg_or_mod, dict):
traverse_dict(cfg_or_mod)
return BUILDER.build(cfg_or_mod)
else:
raise NotImplementedError

    def forward(self, data, data_samples=None, mode='loss'):
        if 'pixel_values' in data:
            # SigLIP: take all patch tokens from the selected hidden layer
            # (SigLIP has no [CLS] token).
            siglip_out = self.siglip(
                data['pixel_values'],
                output_hidden_states=True).hidden_states[
                    self.visual_select_layer]
            # DINO: take the last hidden states and drop the leading
            # [CLS] token.
            dino_out = self.dino(
                data['pixel_values'],
                output_hidden_states=True).hidden_states[-1][:, 1:]
            # Concatenate along the channel dimension and project the fused
            # features into the LLM embedding space.
            visual_out = torch.cat((siglip_out, dino_out), dim=-1)
            pixel_values = self.projector(visual_out)
            data['pixel_values'] = pixel_values
            data = prepare_inputs_labels_for_multimodal(llm=self.llm, **data)

        if mode == 'loss':
            return self.compute_loss(data, data_samples)
        elif mode == 'predict':
            return self.predict(data, data_samples)
        elif mode == 'tensor':
            return self._forward(data, data_samples)
        else:
            raise NotImplementedError

    def _forward(self, data, data_samples=None):
        outputs = self.llm(**data)
        return outputs

    def predict(self, data, data_samples=None):
        outputs = self.llm(**data)
        logits_dict = [{'logits': logits} for logits in outputs.logits]
        return logits_dict

    def compute_loss(self, data, data_samples=None):
        outputs = self.llm(**data)
        loss_dict = {'loss': outputs.loss}
        return loss_dict

    def __getattr__(self, name: str):
        # Fall back to the wrapped LLM so that attributes defined on the
        # underlying HuggingFace model (e.g. `config`) resolve transparently.
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.llm, name)
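

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module; checkpoint names are
# illustrative placeholders). In the real training configs these components
# are built through xtuner's BUILDER registry from config dicts; here they are
# passed in as ready-made nn.Module instances, which
# `_build_from_cfg_or_module` also accepts.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from transformers import (AutoModel, AutoModelForCausalLM,
                              SiglipVisionModel)

    # Any causal LM plus two vision backbones with `config.hidden_size` can
    # be plugged in; the projector width is derived from the two encoders.
    llm = AutoModelForCausalLM.from_pretrained(
        'internlm/internlm2-chat-7b', trust_remote_code=True)
    siglip = SiglipVisionModel.from_pretrained(
        'google/siglip-so400m-patch14-384')
    dino = AutoModel.from_pretrained('facebook/dinov2-large')

    model = LLaVAModel(
        llm=llm,
        siglip=siglip,
        dino=dino,
        freeze_llm=True,
        freeze_visual_encoder=True,
        use_activation_checkpointing=False)
    print(model)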