babylm committed · Commit 2cd6995 · verified · 1 Parent(s): 0e6cae6

add support for sequence classification

Files changed (2):
  1. config.json  +6 -5
  2. modeling_git.py  +123 -1
config.json CHANGED
@@ -1,13 +1,14 @@
 {
-  "_commit_hash": "58597eee6783b5d4405df333573fe4b4368bce29",
+  "_name_or_path": "babylm/git-2024",
   "architectures": [
     "GitForCausalLM"
   ],
+  "attention_probs_dropout_prob": 0.1,
   "auto_map": {
     "AutoConfig": "configuration_git.GitConfig",
-    "AutoModelForCausalLM": "modeling_git.GitForCausalLM"
+    "AutoModelForCausalLM": "modeling_git.GitForCausalLM",
+    "AutoModelForSequenceClassification": "modeling_git.GitForSequenceClassification"
   },
-  "attention_probs_dropout_prob": 0.1,
   "bos_token_id": 101,
   "classifier_dropout": null,
   "eos_token_id": 102,
@@ -24,11 +25,11 @@
   "num_image_with_embedding": null,
   "pad_token_id": 0,
   "position_embedding_type": "absolute",
-  "tie_word_embeddings": true,
   "torch_dtype": "float32",
-  "transformers_version": null,
+  "transformers_version": "4.26.0",
   "use_cache": true,
   "vision_config": {
+    "_commit_hash": null,
     "_name_or_path": "",
     "add_cross_attention": false,
     "architectures": null,
modeling_git.py CHANGED
@@ -7,7 +7,7 @@ import ipdb
 import os
 import torch
 from torch import nn
-from torch.nn import CrossEntropyLoss
+from torch.nn import CrossEntropyLoss, BCEWithLogitsLoss, MSELoss
 from itertools import product
 import numpy as np
 import transformers.models.git.modeling_git as modeling_git
@@ -15,6 +15,7 @@ import transformers.models.vit.modeling_vit as modeling_vit
 from transformers.models.opt.modeling_opt import OPTConfig
 import transformers.models.opt.modeling_opt as hg_opt
 import transformers.models.clip.modeling_clip as modeling_clip
+from transformers.modeling_outputs import SequenceClassifierOutputWithPast
 
 
 class GitForCausalLM(modeling_git.GitForCausalLM):
@@ -98,3 +99,124 @@ class GitForCausalLM(modeling_git.GitForCausalLM):
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
+
+
+class GitForSequenceClassification(modeling_git.GitPreTrainedModel):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.num_labels = self.config.num_labels
+        self.classifier = nn.Linear(
+            self.config.hidden_size,
+            self.config.num_labels,
+            bias=False)
+        self.post_init()
+        self.git = modeling_git.GitModel(self.config)
+
+        del self.git.image_encoder
+        self.git.image_encoder = ViTModel.from_pretrained('facebook/dino-vitb16')
+        dino_cfg = self.git.image_encoder.config
+        config = self.git.config
+        config.vision_config.hidden_size = dino_cfg.hidden_size
+
+        del self.git.visual_projection
+        self.git.visual_projection = modeling_git.GitProjection(config)
+        num_tks = (dino_cfg.image_size // dino_cfg.patch_size) ** 2 + 1
+        self.git.encoder.layer[0].attention.self.image_patch_tokens = num_tks
+
+    def forward(
+            self,
+            input_ids: Optional[torch.LongTensor] = None,
+            attention_mask: Optional[torch.FloatTensor] = None,
+            position_ids: Optional[torch.Tensor] = None,
+            pixel_values: Optional[torch.Tensor] = None,
+            head_mask: Optional[torch.FloatTensor] = None,
+            past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+            inputs_embeds: Optional[torch.FloatTensor] = None,
+            labels: Optional[torch.LongTensor] = None,
+            use_cache: Optional[bool] = None,
+            output_attentions: Optional[bool] = None,
+            output_hidden_states: Optional[bool] = None,
+            return_dict: Optional[bool] = None,
+            *args, **kwargs) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.git(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            pixel_values=pixel_values,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            *args, **kwargs)
+
+        hidden_states = outputs[0]
+        logits = self.classifier(hidden_states)
+
+        if input_ids is not None:
+            batch_size, sequence_length = input_ids.shape[:2]
+        else:
+            batch_size, sequence_length = inputs_embeds.shape[:2]
+
+        if self.config.pad_token_id is None:
+            sequence_lengths = -1
+        else:
+            if input_ids is not None:
+                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+                sequence_lengths = sequence_lengths % input_ids.shape[-1]
+                sequence_lengths = sequence_lengths.to(logits.device)
+            else:
+                sequence_lengths = -1
+                # logger.warning(
+                #     f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                #     "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+                # )
+
+        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+
+        loss = None
+        if labels is not None:
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(pooled_logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(pooled_logits, labels)
+
+        if not return_dict:
+            output = (pooled_logits,) + outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutputWithPast(
+            loss=loss,
+            logits=pooled_logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )