KoichiYasuoka commited on
Commit
a166afd
·
1 Parent(s): 2858ad2

support inputs_embeds

Browse files
Files changed (1) hide show
  1. modeling_modernbert.py +68 -26
modeling_modernbert.py CHANGED
@@ -206,12 +206,17 @@ class ModernBertEmbeddings(nn.Module):
206
  def compiled_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor:
207
  return self.drop(self.norm(self.tok_embeddings(input_ids)))
208
 
209
- def forward(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None) -> torch.Tensor:
210
- hidden_states = (
211
- self.compiled_embeddings(input_ids)
212
- if self.config.reference_compile
213
- else self.drop(self.norm(self.tok_embeddings(input_ids)))
214
- )
 
 
 
 
 
215
  return hidden_states
216
 
217
 
@@ -792,6 +797,10 @@ MODERNBERT_INPUTS_DOCSTRING = r"""
792
  config.n_positions - 1]`.
793
 
794
  [What are position IDs?](../glossary#position-ids)
 
 
 
 
795
  indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
796
  Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
797
  cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
@@ -843,10 +852,11 @@ class ModernBertModel(ModernBertPreTrainedModel):
843
  )
844
  def forward(
845
  self,
846
- input_ids: torch.LongTensor = None,
847
  attention_mask: Optional[torch.Tensor] = None,
848
  sliding_window_mask: Optional[torch.Tensor] = None,
849
  position_ids: Optional[torch.LongTensor] = None,
 
850
  indices: Optional[torch.Tensor] = None,
851
  cu_seqlens: Optional[torch.Tensor] = None,
852
  max_seqlen: Optional[int] = None,
@@ -862,35 +872,49 @@ class ModernBertModel(ModernBertPreTrainedModel):
862
  )
863
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
864
 
 
 
 
865
  all_hidden_states = () if output_hidden_states else None
866
  all_self_attentions = () if output_attentions else None
867
 
868
  self._maybe_set_compile()
869
- self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
 
 
870
 
871
  if batch_size is None and seq_len is None:
872
- batch_size, seq_len = input_ids.shape[:2]
 
 
 
 
873
 
874
  if attention_mask is None:
875
- attention_mask = torch.ones((batch_size, seq_len), device=input_ids.device, dtype=torch.bool)
876
 
877
  repad = False
878
  if self.config._attn_implementation == "flash_attention_2":
879
  if indices is None and cu_seqlens is None and max_seqlen is None:
880
  repad = True
881
- with torch.no_grad():
882
- input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
883
- inputs=input_ids, attention_mask=attention_mask
 
 
 
 
 
884
  )
885
  else:
886
  if position_ids is None:
887
- position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
888
 
889
  attention_mask, sliding_window_mask = self._update_attention_mask(
890
  attention_mask, output_attentions=output_attentions
891
  )
892
 
893
- hidden_states = self.embeddings(input_ids)
894
 
895
  for encoder_layer in self.layers:
896
  if output_hidden_states:
@@ -1026,10 +1050,11 @@ class ModernBertForMaskedLM(ModernBertPreTrainedModel):
1026
  )
1027
  def forward(
1028
  self,
1029
- input_ids: Optional[torch.Tensor],
1030
  attention_mask: Optional[torch.Tensor] = None,
1031
  sliding_window_mask: Optional[torch.Tensor] = None,
1032
  position_ids: Optional[torch.Tensor] = None,
 
1033
  labels: Optional[torch.Tensor] = None,
1034
  indices: Optional[torch.Tensor] = None,
1035
  cu_seqlens: Optional[torch.Tensor] = None,
@@ -1046,19 +1071,32 @@ class ModernBertForMaskedLM(ModernBertPreTrainedModel):
1046
 
1047
  if self.config._attn_implementation == "flash_attention_2":
1048
  if indices is None and cu_seqlens is None and max_seqlen is None:
1049
- batch_size, seq_len = input_ids.shape[:2]
 
 
 
 
 
 
1050
  if attention_mask is None:
1051
- attention_mask = torch.ones((batch_size, seq_len), device=input_ids.device, dtype=torch.bool)
1052
- with torch.no_grad():
1053
- input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
1054
- inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels
 
 
 
 
 
 
1055
  )
1056
 
1057
  outputs = self.model(
1058
- input_ids,
1059
  attention_mask=attention_mask,
1060
  sliding_window_mask=sliding_window_mask,
1061
  position_ids=position_ids,
 
1062
  indices=indices,
1063
  cu_seqlens=cu_seqlens,
1064
  max_seqlen=max_seqlen,
@@ -1131,10 +1169,11 @@ class ModernBertForSequenceClassification(ModernBertPreTrainedModel):
1131
  )
1132
  def forward(
1133
  self,
1134
- input_ids: Optional[torch.Tensor],
1135
  attention_mask: Optional[torch.Tensor] = None,
1136
  sliding_window_mask: Optional[torch.Tensor] = None,
1137
  position_ids: Optional[torch.Tensor] = None,
 
1138
  labels: Optional[torch.Tensor] = None,
1139
  indices: Optional[torch.Tensor] = None,
1140
  cu_seqlens: Optional[torch.Tensor] = None,
@@ -1156,10 +1195,11 @@ class ModernBertForSequenceClassification(ModernBertPreTrainedModel):
1156
  self._maybe_set_compile()
1157
 
1158
  outputs = self.model(
1159
- input_ids,
1160
  attention_mask=attention_mask,
1161
  sliding_window_mask=sliding_window_mask,
1162
  position_ids=position_ids,
 
1163
  indices=indices,
1164
  cu_seqlens=cu_seqlens,
1165
  max_seqlen=max_seqlen,
@@ -1242,10 +1282,11 @@ class ModernBertForTokenClassification(ModernBertPreTrainedModel):
1242
  )
1243
  def forward(
1244
  self,
1245
- input_ids: Optional[torch.Tensor],
1246
  attention_mask: Optional[torch.Tensor] = None,
1247
  sliding_window_mask: Optional[torch.Tensor] = None,
1248
  position_ids: Optional[torch.Tensor] = None,
 
1249
  labels: Optional[torch.Tensor] = None,
1250
  indices: Optional[torch.Tensor] = None,
1251
  cu_seqlens: Optional[torch.Tensor] = None,
@@ -1264,10 +1305,11 @@ class ModernBertForTokenClassification(ModernBertPreTrainedModel):
1264
  self._maybe_set_compile()
1265
 
1266
  outputs = self.model(
1267
- input_ids,
1268
  attention_mask=attention_mask,
1269
  sliding_window_mask=sliding_window_mask,
1270
  position_ids=position_ids,
 
1271
  indices=indices,
1272
  cu_seqlens=cu_seqlens,
1273
  max_seqlen=max_seqlen,
 
206
  def compiled_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor:
207
  return self.drop(self.norm(self.tok_embeddings(input_ids)))
208
 
209
+ def forward(
210
+ self, input_ids: torch.LongTensor = None, inputs_embeds: Optional[torch.Tensor] = None
211
+ ) -> torch.Tensor:
212
+ if inputs_embeds is not None:
213
+ hidden_states = self.drop(self.norm(inputs_embeds))
214
+ else:
215
+ hidden_states = (
216
+ self.compiled_embeddings(input_ids)
217
+ if self.config.reference_compile
218
+ else self.drop(self.norm(self.tok_embeddings(input_ids)))
219
+ )
220
  return hidden_states
221
 
222
 
 
797
  config.n_positions - 1]`.
798
 
799
  [What are position IDs?](../glossary#position-ids)
800
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
801
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
802
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
803
+ model's internal embedding lookup matrix.
804
  indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
805
  Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
806
  cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
 
852
  )
853
  def forward(
854
  self,
855
+ input_ids: Optional[torch.LongTensor] = None,
856
  attention_mask: Optional[torch.Tensor] = None,
857
  sliding_window_mask: Optional[torch.Tensor] = None,
858
  position_ids: Optional[torch.LongTensor] = None,
859
+ inputs_embeds: Optional[torch.Tensor] = None,
860
  indices: Optional[torch.Tensor] = None,
861
  cu_seqlens: Optional[torch.Tensor] = None,
862
  max_seqlen: Optional[int] = None,
 
872
  )
873
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
874
 
875
+ if (input_ids is None) ^ (inputs_embeds is not None):
876
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
877
+
878
  all_hidden_states = () if output_hidden_states else None
879
  all_self_attentions = () if output_attentions else None
880
 
881
  self._maybe_set_compile()
882
+
883
+ if input_ids is not None:
884
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
885
 
886
  if batch_size is None and seq_len is None:
887
+ if inputs_embeds is not None:
888
+ batch_size, seq_len = inputs_embeds.shape[:2]
889
+ else:
890
+ batch_size, seq_len = input_ids.shape[:2]
891
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
892
 
893
  if attention_mask is None:
894
+ attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
895
 
896
  repad = False
897
  if self.config._attn_implementation == "flash_attention_2":
898
  if indices is None and cu_seqlens is None and max_seqlen is None:
899
  repad = True
900
+ if inputs_embeds is None:
901
+ with torch.no_grad():
902
+ input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
903
+ inputs=input_ids, attention_mask=attention_mask
904
+ )
905
+ else:
906
+ inputs_embeds, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
907
+ inputs=inputs_embeds, attention_mask=attention_mask
908
  )
909
  else:
910
  if position_ids is None:
911
+ position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
912
 
913
  attention_mask, sliding_window_mask = self._update_attention_mask(
914
  attention_mask, output_attentions=output_attentions
915
  )
916
 
917
+ hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds)
918
 
919
  for encoder_layer in self.layers:
920
  if output_hidden_states:
 
1050
  )
1051
  def forward(
1052
  self,
1053
+ input_ids: Optional[torch.LongTensor] = None,
1054
  attention_mask: Optional[torch.Tensor] = None,
1055
  sliding_window_mask: Optional[torch.Tensor] = None,
1056
  position_ids: Optional[torch.Tensor] = None,
1057
+ inputs_embeds: Optional[torch.Tensor] = None,
1058
  labels: Optional[torch.Tensor] = None,
1059
  indices: Optional[torch.Tensor] = None,
1060
  cu_seqlens: Optional[torch.Tensor] = None,
 
1071
 
1072
  if self.config._attn_implementation == "flash_attention_2":
1073
  if indices is None and cu_seqlens is None and max_seqlen is None:
1074
+ if batch_size is None and seq_len is None:
1075
+ if inputs_embeds is not None:
1076
+ batch_size, seq_len = inputs_embeds.shape[:2]
1077
+ else:
1078
+ batch_size, seq_len = input_ids.shape[:2]
1079
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1080
+
1081
  if attention_mask is None:
1082
+ attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
1083
+
1084
+ if inputs_embeds is None:
1085
+ with torch.no_grad():
1086
+ input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
1087
+ inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels
1088
+ )
1089
+ else:
1090
+ inputs_embeds, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
1091
+ inputs=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, labels=labels
1092
  )
1093
 
1094
  outputs = self.model(
1095
+ input_ids=input_ids,
1096
  attention_mask=attention_mask,
1097
  sliding_window_mask=sliding_window_mask,
1098
  position_ids=position_ids,
1099
+ inputs_embeds=inputs_embeds,
1100
  indices=indices,
1101
  cu_seqlens=cu_seqlens,
1102
  max_seqlen=max_seqlen,
 
1169
  )
1170
  def forward(
1171
  self,
1172
+ input_ids: Optional[torch.LongTensor] = None,
1173
  attention_mask: Optional[torch.Tensor] = None,
1174
  sliding_window_mask: Optional[torch.Tensor] = None,
1175
  position_ids: Optional[torch.Tensor] = None,
1176
+ inputs_embeds: Optional[torch.Tensor] = None,
1177
  labels: Optional[torch.Tensor] = None,
1178
  indices: Optional[torch.Tensor] = None,
1179
  cu_seqlens: Optional[torch.Tensor] = None,
 
1195
  self._maybe_set_compile()
1196
 
1197
  outputs = self.model(
1198
+ input_ids=input_ids,
1199
  attention_mask=attention_mask,
1200
  sliding_window_mask=sliding_window_mask,
1201
  position_ids=position_ids,
1202
+ inputs_embeds=inputs_embeds,
1203
  indices=indices,
1204
  cu_seqlens=cu_seqlens,
1205
  max_seqlen=max_seqlen,
 
1282
  )
1283
  def forward(
1284
  self,
1285
+ input_ids: Optional[torch.LongTensor] = None,
1286
  attention_mask: Optional[torch.Tensor] = None,
1287
  sliding_window_mask: Optional[torch.Tensor] = None,
1288
  position_ids: Optional[torch.Tensor] = None,
1289
+ inputs_embeds: Optional[torch.Tensor] = None,
1290
  labels: Optional[torch.Tensor] = None,
1291
  indices: Optional[torch.Tensor] = None,
1292
  cu_seqlens: Optional[torch.Tensor] = None,
 
1305
  self._maybe_set_compile()
1306
 
1307
  outputs = self.model(
1308
+ input_ids=input_ids,
1309
  attention_mask=attention_mask,
1310
  sliding_window_mask=sliding_window_mask,
1311
  position_ids=position_ids,
1312
+ inputs_embeds=inputs_embeds,
1313
  indices=indices,
1314
  cu_seqlens=cu_seqlens,
1315
  max_seqlen=max_seqlen,