import torch
import torch.nn as nn
import torch.utils.checkpoint
from transformers import PreTrainedModel
from transformers.activations import ACT2FN

from .configuration_projector import ProjectorConfig


class ProjectorModel(PreTrainedModel):
    """MLP projector that maps visual encoder features into the LLM hidden space."""

    _auto_class = 'AutoModel'
    config_class = ProjectorConfig
    base_model_prefix = 'model'
    supports_gradient_checkpointing = True

    def __init__(self, config: ProjectorConfig) -> None:
        super().__init__(config)
        self.gradient_checkpointing = False

        # First linear layer maps visual features into the LLM hidden size;
        # each additional depth level appends an activation followed by an
        # llm_hidden_size -> llm_hidden_size linear layer.
        modules = [
            nn.Linear(
                config.visual_hidden_size,
                config.llm_hidden_size,
                bias=config.bias)
        ]
        for _ in range(1, config.depth):
            modules.append(ACT2FN[config.hidden_act])
            modules.append(
                nn.Linear(
                    config.llm_hidden_size,
                    config.llm_hidden_size,
                    bias=config.bias))
        self.model = nn.Sequential(*modules)

    def enable_input_require_grads(self):

        def make_inputs_require_grad(module, input, output):
            # Mark the projector output as requiring gradients so that
            # gradients can propagate even when upstream inputs (e.g. from a
            # frozen visual encoder) do not require them.
            output.requires_grad_(True)

        self.model.register_forward_hook(make_inputs_require_grad)

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, ProjectorModel):
            module.gradient_checkpointing = value

    def forward(self, x):
        # During training with checkpointing enabled, recompute activations in
        # the backward pass to trade compute for memory; otherwise run the MLP
        # directly.
        if self.gradient_checkpointing and self.training:
            layer_outputs = torch.utils.checkpoint.checkpoint(self.model, x)
        else:
            layer_outputs = self.model(x)
        return layer_outputs
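
# Illustrative usage sketch (not part of the module itself). It assumes
# ProjectorConfig exposes the visual_hidden_size, llm_hidden_size, depth,
# hidden_act and bias fields consumed in __init__ above; the dimensions used
# here are made-up examples:
#
#     config = ProjectorConfig(
#         visual_hidden_size=1024,
#         llm_hidden_size=4096,
#         depth=2,
#         hidden_act='gelu',
#         bias=True)
#     projector = ProjectorModel(config)
#     visual_feats = torch.randn(2, 576, 1024)  # (batch, num_patches, visual_hidden_size)
#     llm_embeds = projector(visual_feats)      # -> shape (2, 576, 4096)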