# babylm2024-git-txt / modeling_git.py
# Uploaded by AlinaKl ("Upload 10 files"), commit 861d88f (verified), 9.59 kB.
import transformers
from transformers import AutoProcessor, AutoModelForCausalLM
from transformers import ViTFeatureExtractor, ViTModel, ViTConfig
from typing import List, Optional, Tuple, Union
import warnings
import ipdb
import os
import torch
from torch import nn
from torch.nn import CrossEntropyLoss, BCEWithLogitsLoss, MSELoss
from itertools import product
import numpy as np
import transformers.models.git.modeling_git as modeling_git
import transformers.models.vit.modeling_vit as modeling_vit
from transformers.models.opt.modeling_opt import OPTConfig
import transformers.models.opt.modeling_opt as hg_opt
import transformers.models.clip.modeling_clip as modeling_clip
from transformers.modeling_outputs import SequenceClassifierOutputWithPast
class GitForCausalLM(modeling_git.GitForCausalLM):
    """GIT causal language model whose image encoder is a pretrained DINO ViT-B/16.

    Differences from ``transformers``' ``GitForCausalLM``:

    * the LM head ``self.output`` is rebuilt as a bias-free ``nn.Linear``;
    * ``git.image_encoder`` is replaced by ``ViTModel`` loaded from
      ``facebook/dino-vitb16`` and ``git.visual_projection`` is rebuilt for the
      DINO hidden size;
    * the visual-prefix length used by GIT's self-attention layers is updated
      on every layer to match the DINO token count.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Replace the LM head with a bias-free projection onto the vocabulary.
        del self.output
        self.output = nn.Linear(
            self.config.hidden_size,
            self.config.vocab_size,
            bias=False)
        # NOTE(review): post_init() runs before the DINO encoder and the new
        # visual projection are created, so neither is passed through it.
        self.post_init()
        self._use_dino_image_encoder()

    def _use_dino_image_encoder(self):
        """Replace ``self.git.image_encoder`` with pretrained DINO ViT-B/16.

        Also rebuilds ``self.git.visual_projection`` from the updated config and
        sets the visual-prefix token count on every GIT self-attention layer.
        """
        del self.git.image_encoder
        self.git.image_encoder = ViTModel.from_pretrained('facebook/dino-vitb16')
        dino_cfg = self.git.image_encoder.config
        config = self.git.config
        # The projection is rebuilt from this mutated config; presumably
        # GitProjection sizes its input from vision_config.hidden_size — confirm.
        config.vision_config.hidden_size = dino_cfg.hidden_size
        del self.git.visual_projection
        self.git.visual_projection = modeling_git.GitProjection(config)
        # One token per image patch plus one extra token (presumably ViT's [CLS]).
        num_tks = (dino_cfg.image_size // dino_cfg.patch_size) ** 2 + 1
        # Mirror transformers' GitSelfAttention, which multiplies the per-image
        # count by num_image_with_embedding when that option is set.
        num_image_with_embedding = getattr(config, 'num_image_with_embedding', None)
        if num_image_with_embedding is not None:
            num_tks *= num_image_with_embedding
        # Each GIT self-attention module holds its own image_patch_tokens, so
        # every layer must be updated (not only layer 0). Otherwise layers 1..N
        # keep the count derived from the original vision config.
        for layer in self.git.encoder.layer:
            layer.attention.self.image_patch_tokens = num_tks

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.Tensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], modeling_git.CausalLMOutputWithPast]:
        """Run GIT and project its hidden states to vocabulary logits.

        When ``labels`` is given, caching is disabled and a next-token
        cross-entropy loss is computed. Logits from position
        ``num_image_tokens`` up to (excluding) the last position are scored
        against ``labels[:, 1:]``. ``num_image_tokens`` is the visual-prefix
        length when ``pixel_values`` is given, else 0.

        Returns:
            ``CausalLMOutputWithPast`` when ``return_dict`` is true (defaults to
            ``config.use_return_dict``). Otherwise returns the tuple
            ``(logits, *remaining GIT outputs)``, prefixed with ``loss`` when
            ``labels`` was given.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if labels is not None:
            use_cache = False

        outputs = self.git(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            pixel_values=pixel_values,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        logits = self.output(sequence_output)

        loss = None
        if labels is not None:
            # Next-token prediction: skip the visual prefix (if any), drop the
            # last logit and the first label so position t predicts label t+1.
            if pixel_values is not None:
                num_image_tokens = self.git.encoder.layer[0].attention.self.image_patch_tokens
            else:
                num_image_tokens = 0
            shifted_logits = logits[:, num_image_tokens:-1, :].contiguous()
            labels = labels[:, 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shifted_logits.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return modeling_git.CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
class GitForSequenceClassification(modeling_git.GitPreTrainedModel):
    """GIT model with a bias-free linear head for sequence classification/regression.

    The image encoder is replaced by a pretrained DINO ViT-B/16
    (``facebook/dino-vitb16``). The classifier is applied at every output
    position, and the logits at the last non-padding text token are pooled.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.num_labels = self.config.num_labels
        self.classifier = nn.Linear(
            self.config.hidden_size,
            self.config.num_labels,
            bias=False)
        # NOTE(review): post_init() runs before self.git and the DINO encoder
        # are created, so neither is passed through it.
        self.post_init()
        self.git = modeling_git.GitModel(self.config)
        self._use_dino_image_encoder()

    def _use_dino_image_encoder(self):
        """Replace ``self.git.image_encoder`` with pretrained DINO ViT-B/16.

        Also rebuilds ``self.git.visual_projection`` from the updated config and
        sets the visual-prefix token count on every GIT self-attention layer.
        """
        del self.git.image_encoder
        self.git.image_encoder = ViTModel.from_pretrained('facebook/dino-vitb16')
        dino_cfg = self.git.image_encoder.config
        config = self.git.config
        # The projection is rebuilt from this mutated config; presumably
        # GitProjection sizes its input from vision_config.hidden_size — confirm.
        config.vision_config.hidden_size = dino_cfg.hidden_size
        del self.git.visual_projection
        self.git.visual_projection = modeling_git.GitProjection(config)
        # One token per image patch plus one extra token (presumably ViT's [CLS]).
        num_tks = (dino_cfg.image_size // dino_cfg.patch_size) ** 2 + 1
        # Mirror transformers' GitSelfAttention, which multiplies the per-image
        # count by num_image_with_embedding when that option is set.
        num_image_with_embedding = getattr(config, 'num_image_with_embedding', None)
        if num_image_with_embedding is not None:
            num_tks *= num_image_with_embedding
        # Each GIT self-attention module holds its own image_patch_tokens, so
        # every layer must be updated (not only layer 0).
        for layer in self.git.encoder.layer:
            layer.attention.self.image_patch_tokens = num_tks

    def _classification_loss(self, pooled_logits, labels):
        """Return the loss of ``pooled_logits`` against ``labels``.

        If ``config.problem_type`` is unset, it is inferred and stored:
        ``"regression"`` for one label; ``"single_label_classification"`` for
        several labels with ``torch.long``/``torch.int`` targets; otherwise
        ``"multi_label_classification"``. Returns ``None`` for any other
        ``problem_type``.
        """
        if self.config.problem_type is None:
            if self.num_labels == 1:
                self.config.problem_type = "regression"
            elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                self.config.problem_type = "single_label_classification"
            else:
                self.config.problem_type = "multi_label_classification"

        if self.config.problem_type == "regression":
            loss_fct = MSELoss()
            if self.num_labels == 1:
                return loss_fct(pooled_logits.squeeze(), labels.squeeze())
            return loss_fct(pooled_logits, labels)
        if self.config.problem_type == "single_label_classification":
            loss_fct = CrossEntropyLoss()
            return loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
        if self.config.problem_type == "multi_label_classification":
            loss_fct = BCEWithLogitsLoss()
            return loss_fct(pooled_logits, labels)
        return None

    def forward(
            self,
            input_ids: Optional[torch.LongTensor] = None,
            attention_mask: Optional[torch.FloatTensor] = None,
            position_ids: Optional[torch.Tensor] = None,
            pixel_values: Optional[torch.Tensor] = None,
            head_mask: Optional[torch.FloatTensor] = None,
            past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
            inputs_embeds: Optional[torch.FloatTensor] = None,
            labels: Optional[torch.LongTensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            *args, **kwargs) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Logits are pooled at the last text token before the first
        `pad_token_id` in `input_ids`. If no pad token is configured, or
        `inputs_embeds` is used instead of `input_ids`, the final position is
        used. When `pixel_values` is given, the index is shifted past the
        visual tokens that precede the text in GIT's output.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # Extra positional args follow input_ids, exactly as before.
        outputs = self.git(
            input_ids,
            *args,
            attention_mask=attention_mask,
            position_ids=position_ids,
            pixel_values=pixel_values,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs)
        hidden_states = outputs[0]
        logits = self.classifier(hidden_states)

        if input_ids is not None:
            batch_size, sequence_length = input_ids.shape[:2]
        else:
            batch_size, sequence_length = inputs_embeds.shape[:2]

        if self.config.pad_token_id is None or input_ids is None:
            # -1 selects the final output position, i.e. the last text token.
            # NOTE: padding cannot be detected when only inputs_embeds is given.
            sequence_lengths = -1
        else:
            # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
            sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
            sequence_lengths = sequence_lengths % input_ids.shape[-1]
            # The index above is in text coordinates, but GIT's output places
            # any visual tokens before the text. Shift past that prefix (0
            # when no pixel_values are given). Without this shift, an image
            # position would be pooled.
            num_prefix_tokens = logits.shape[1] - sequence_length
            sequence_lengths = (sequence_lengths + num_prefix_tokens).to(logits.device)

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = self._classification_loss(pooled_logits, labels) if labels is not None else None

        if not return_dict:
            output = (pooled_logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )