# Copyright (c) OpenMMLab. All rights reserved.
from collections import OrderedDict

import torch
import torch.nn as nn
from mmengine.config import Config, ConfigDict
from mmengine.model import BaseModel
from peft import get_peft_model, prepare_model_for_kbit_training

from xtuner.registry import BUILDER
from .modules import ProjectorConfig, ProjectorModel, dispatch_modules
from .utils import (LoadWoInit, find_all_linear_names,
                    get_peft_model_state_dict, guess_load_checkpoint,
                    make_inputs_require_grad,
                    prepare_inputs_labels_for_multimodal, traverse_dict)


class LLaVAModel(BaseModel):
    """LLaVA-style multimodal model that fuses SigLIP and DINO visual
    features, projects them into the LLM embedding space, and optionally
    applies LoRA and activation checkpointing to each component."""

    def __init__(self,
                 llm,
                 siglip,
                 dino,
                 freeze_llm=False,
                 freeze_visual_encoder=False,
                 visual_select_layer=-2,
                 pretrained_pth=None,
                 projector_depth=2,
                 llm_lora=None,
                 visual_encoder_lora=None,
                 use_activation_checkpointing=True):
        super().__init__()
        self.freeze_llm = freeze_llm
        self.freeze_visual_encoder = freeze_visual_encoder
        with LoadWoInit():
            self.llm = self._build_from_cfg_or_module(llm)
            self.siglip = self._build_from_cfg_or_module(siglip)
            self.dino = self._build_from_cfg_or_module(dino)
        self.llm.config.use_cache = False
        dispatch_modules(self.llm)

        # The projector consumes the channel-wise concatenation of both
        # visual encoders, so its input width is the sum of their hidden sizes.
        projector_config = ProjectorConfig(
            visual_hidden_size=self.siglip.config.hidden_size +
            self.dino.config.hidden_size,
            llm_hidden_size=self.llm.config.hidden_size,
            depth=projector_depth)
        self.projector = ProjectorModel(projector_config).to(
            self.siglip.dtype)

        if self.freeze_llm:
            self.llm.requires_grad_(False)
        if self.freeze_visual_encoder:
            self.siglip.requires_grad_(False)
            self.dino.requires_grad_(False)

        if use_activation_checkpointing:
            # For backward compatibility
            if hasattr(self.llm, 'enable_input_require_grads'):
                self.llm.enable_input_require_grads()
            else:
                self.llm.get_input_embeddings().register_forward_hook(
                    make_inputs_require_grad)
            if hasattr(self.siglip, 'enable_input_require_grads'):
                self.siglip.enable_input_require_grads()
            else:
                self.siglip.get_input_embeddings().register_forward_hook(
                    make_inputs_require_grad)
            if hasattr(self.dino, 'enable_input_require_grads'):
                self.dino.enable_input_require_grads()
            else:
                self.dino.get_input_embeddings().register_forward_hook(
                    make_inputs_require_grad)
            self.projector.enable_input_require_grads()

            # enable gradient (activation) checkpointing for memory efficiency
            self.gradient_checkpointing_enable()

        self.use_llm_lora = llm_lora is not None
        self.use_visual_encoder_lora = visual_encoder_lora is not None

        if self.use_llm_lora:
            self._prepare_llm_for_lora(llm_lora, use_activation_checkpointing)
        if self.use_visual_encoder_lora:
            self._prepare_visual_encoder_for_lora(
                visual_encoder_lora, use_activation_checkpointing)

        if pretrained_pth is not None:
            pretrained_state_dict = guess_load_checkpoint(pretrained_pth)
            self.load_state_dict(pretrained_state_dict, strict=False)
            print(f'Load pretrained weight from {pretrained_pth}')

        self.visual_select_layer = visual_select_layer

        self._is_init = True

    def _parse_lora_config(self, lora_config):
        if isinstance(lora_config, (dict, Config, ConfigDict)):
            lora_config = BUILDER.build(lora_config)
        return lora_config

    def _prepare_llm_for_lora(self,
                              lora_config,
                              use_activation_checkpointing=True):
        lora_config = self._parse_lora_config(lora_config)
        self.llm = prepare_model_for_kbit_training(
            self.llm, use_activation_checkpointing)
        if lora_config.target_modules is None:
            modules = find_all_linear_names(self.llm)
            lora_config.target_modules = modules
        self.llm = get_peft_model(self.llm, lora_config)

    def _prepare_visual_encoder_for_lora(self,
                                         lora_config,
                                         use_activation_checkpointing=True):
        lora_config = self._parse_lora_config(lora_config)

        modules = find_all_linear_names(self.siglip)
        lora_config.target_modules = modules
        self.siglip = get_peft_model(self.siglip, lora_config)

        modules = find_all_linear_names(self.dino)
        lora_config.target_modules = modules
        self.dino = get_peft_model(self.dino, lora_config)

    def gradient_checkpointing_enable(self):
        self.activation_checkpointing_enable()

    def activation_checkpointing_enable(self):
        self.llm.gradient_checkpointing_enable()
        self.siglip.gradient_checkpointing_enable()
        self.dino.gradient_checkpointing_enable()
        self.projector.gradient_checkpointing_enable()

    def gradient_checkpointing_disable(self):
        self.activation_checkpointing_disable()

    def activation_checkpointing_disable(self):
        self.llm.gradient_checkpointing_disable()
        self.siglip.gradient_checkpointing_disable()
        self.dino.gradient_checkpointing_disable()
        self.projector.gradient_checkpointing_disable()

    def init_weights(self):
        pass

    def state_dict(self, *args, **kwargs):
        state_dict = super().state_dict(*args, **kwargs)
        to_return = OrderedDict()
        # Step 1. visual_encoder
        if self.use_visual_encoder_lora:
            to_return.update(
                get_peft_model_state_dict(self.siglip, state_dict=state_dict))
            to_return.update(
                get_peft_model_state_dict(self.dino, state_dict=state_dict))
        elif not self.freeze_visual_encoder:
            to_return.update(
                {k: v for k, v in state_dict.items() if 'siglip.' in k})
            to_return.update(
                {k: v for k, v in state_dict.items() if 'dino.' in k})
        # Step 2. LLM
        if self.use_llm_lora:
            to_return.update(
                get_peft_model_state_dict(self.llm, state_dict=state_dict))
        elif not self.freeze_llm:
            to_return.update(
                {k: v for k, v in state_dict.items() if 'llm.' in k})
        # Step 3. Projector
        to_return.update(
            {k: v for k, v in state_dict.items() if 'projector.' in k})
        return to_return

    def _build_from_cfg_or_module(self, cfg_or_mod):
        if isinstance(cfg_or_mod, nn.Module):
            return cfg_or_mod
        elif isinstance(cfg_or_mod, dict):
            traverse_dict(cfg_or_mod)
            return BUILDER.build(cfg_or_mod)
        else:
            raise NotImplementedError

    def forward(self, data, data_samples=None, mode='loss'):
        if 'pixel_values' in data:
            # SigLIP features from the selected hidden layer
            siglip_out = self.siglip(
                data['pixel_values'],
                output_hidden_states=True).hidden_states[
                    self.visual_select_layer]
            # DINO features from the last hidden layer; drop the CLS token so
            # the token count matches the SigLIP patch tokens
            dino_out = self.dino(
                data['pixel_values'],
                output_hidden_states=True).hidden_states[-1][:, 1:]
            # concatenate along the channel dim and project into the LLM space
            visual_out = torch.cat((siglip_out, dino_out), dim=-1)
            pixel_values = self.projector(visual_out)
            data['pixel_values'] = pixel_values
            data = prepare_inputs_labels_for_multimodal(llm=self.llm, **data)

        if mode == 'loss':
            return self.compute_loss(data, data_samples)
        elif mode == 'predict':
            return self.predict(data, data_samples)
        elif mode == 'tensor':
            return self._forward(data, data_samples)
        else:
            raise NotImplementedError

    def _forward(self, data, data_samples=None):
        outputs = self.llm(**data)
        return outputs

    def predict(self, data, data_samples=None):
        outputs = self.llm(**data)
        logits_dict = [{'logits': logits} for logits in outputs.logits]
        return logits_dict

    def compute_loss(self, data, data_samples=None):
        outputs = self.llm(**data)
        loss_dict = {'loss': outputs.loss}
        return loss_dict

    def __getattr__(self, name: str):
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.llm, name)
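

# ---------------------------------------------------------------------------
# Hedged sketch (not part of the upstream module): illustrates the feature
# fusion contract assumed by ``ProjectorConfig`` and ``forward`` above. The
# batch size, token count, and hidden sizes below are made-up placeholders,
# not values tied to any specific SigLIP/DINO checkpoint; only the CLS-drop
# and channel-wise concatenation mirror the real forward pass.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    batch, num_tokens = 2, 576           # placeholder values
    siglip_dim, dino_dim = 1152, 1024    # placeholder hidden sizes
    siglip_out = torch.randn(batch, num_tokens, siglip_dim)
    # DINO-style encoders prepend a CLS token, hence num_tokens + 1 here and
    # the ``[:, 1:]`` slice in ``forward`` that aligns the token counts.
    dino_hidden = torch.randn(batch, num_tokens + 1, dino_dim)
    dino_out = dino_hidden[:, 1:]
    visual_out = torch.cat((siglip_out, dino_out), dim=-1)
    # the projector would consume this (batch, tokens, siglip_dim + dino_dim)
    # tensor and map it to the LLM hidden size
    assert visual_out.shape == (batch, num_tokens, siglip_dim + dino_dim)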