d-Matrix committed
Commit a26e7df
1 Parent(s): 64b8757

Update modeling_llama.py

Files changed (1)
  1. modeling_llama.py +30 -51
modeling_llama.py CHANGED
@@ -49,6 +49,7 @@ from transformers.utils import (
     replace_return_docstrings,
 )
 from .configuration_llama import LlamaConfig
+from transformers.models.llama.modeling_llama import LlamaRMSNorm
 
 
 if is_flash_attn_2_available():
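
The import added here replaces the local `LlamaRMSNorm` definition that a later hunk deletes. As an illustrative, hedged check (not part of the commit, and assuming a transformers version that exposes the class at this path), the upstream implementation can be exercised directly to confirm it provides the expected RMS normalization:

```python
# Illustrative check only; tensor shapes are arbitrary.
import torch
from transformers.models.llama.modeling_llama import LlamaRMSNorm

norm = LlamaRMSNorm(hidden_size=16, eps=1e-6)
x = torch.randn(2, 4, 16)
out = norm(x)

# RMSNorm rescales each vector so its root-mean-square over the last dimension is ~1,
# then multiplies by the learned per-channel weight (initialized to ones).
print(out.shape)                          # torch.Size([2, 4, 16])
print(out.pow(2).mean(-1).sqrt().mean())  # close to 1.0 with default weights
```
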
@@ -56,6 +57,7 @@ if is_flash_attn_2_available():
     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
 
 
+
 logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "LlamaConfig"
@@ -72,24 +74,6 @@ def _get_unpad_data(attention_mask):
         max_seqlen_in_batch,
     )
 
-
-class LlamaRMSNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
-        """
-        LlamaRMSNorm is equivalent to T5LayerNorm
-        """
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.variance_epsilon = eps
-
-    def forward(self, hidden_states):
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        return self.weight * hidden_states.to(input_dtype)
-
-
 ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)
 
 
@@ -183,7 +167,6 @@ def rotate_half(x):
 
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.
-
     Args:
         q (`torch.Tensor`): The query tensor.
         k (`torch.Tensor`): The key tensor.
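
The docstring touched above belongs to `apply_rotary_pos_emb`. For orientation only, here is a self-contained sketch of the rotate-half formulation it refers to; it is a simplified stand-in, not the file's exact code (the real function also accepts `position_ids`):

```python
# Simplified RoPE sketch; shapes follow the (batch, heads, seq, head_dim) layout used above.
import torch

def rotate_half(x):
    # Split the last dimension in half and swap the halves with a sign flip.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    # cos/sin come in as (batch, seq, head_dim); unsqueeze so they broadcast over heads.
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

q = torch.randn(1, 4, 6, 8)
k = torch.randn(1, 4, 6, 8)
cos = torch.randn(1, 6, 8)
sin = torch.randn(1, 6, 8)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
print(q_rot.shape, k_rot.shape)  # torch.Size([1, 4, 6, 8]) twice
```
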
@@ -505,7 +488,6 @@ class LlamaFlashAttention2(LlamaAttention):
         """
         Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
         first unpad the input, then computes the attention scores and pad the final attention scores.
-
         Args:
             query_states (`torch.Tensor`):
                 Input query states to be passed to Flash Attention API
@@ -656,7 +638,6 @@ class LlamaSdpaAttention(LlamaAttention):
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
         causal_mask = attention_mask
-        # if attention_mask is not None and cache_position is not None:
         if attention_mask is not None:
             causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
 
@@ -667,12 +648,15 @@ class LlamaSdpaAttention(LlamaAttention):
             key_states = key_states.contiguous()
             value_states = value_states.contiguous()
 
+        # In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash Attention 2 backend, rather
+        # relying on the `is_causal` argument.
         attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
+           is_causal=causal_mask is None and q_len > 1,
         )
 
         attn_output = attn_output.transpose(1, 2).contiguous()
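
The new `is_causal=causal_mask is None and q_len > 1` argument lets SDPA apply its built-in causal masking when no explicit mask is needed, which is what allows it to dispatch to its Flash Attention backend. A stand-alone sketch (not from the commit) showing that the two formulations agree for a prefill-shaped, unpadded input:

```python
# Compare an explicit additive causal mask with SDPA's built-in is_causal=True.
import torch
import torch.nn.functional as F

bsz, heads, q_len, head_dim = 1, 2, 5, 8
q = torch.randn(bsz, heads, q_len, head_dim)
k = torch.randn(bsz, heads, q_len, head_dim)
v = torch.randn(bsz, heads, q_len, head_dim)

# Explicit mask: 0 on visible positions, dtype minimum above the diagonal.
min_val = torch.finfo(q.dtype).min
explicit_mask = torch.triu(torch.full((q_len, q_len), min_val), diagonal=1)

out_explicit = F.scaled_dot_product_attention(q, k, v, attn_mask=explicit_mask)
out_builtin = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(out_explicit, out_builtin, atol=1e-6))  # True
```

The `q_len > 1` guard matters because SDPA's built-in causal mask is aligned to the top-left corner: for a single decoded query against a longer cache it would hide every cached key except the first, so the explicit mask path is kept for that case.
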
@@ -769,11 +753,9 @@ LLAMA_START_DOCSTRING = r"""
     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
     library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
     etc.)
-
     This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
     Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
     and behavior.
-
     Parameters:
         config ([`LlamaConfig`]):
             Model configuration class with all the parameters of the model. Initializing with a config file does not
@@ -834,50 +816,38 @@ LLAMA_INPUTS_DOCSTRING = r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
             Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
             it.
-
             Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
             [`PreTrainedTokenizer.__call__`] for details.
-
             [What are input IDs?](../glossary#input-ids)
         attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
             Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
-
             - 1 for tokens that are **not masked**,
             - 0 for tokens that are **masked**.
-
             [What are attention masks?](../glossary#attention-mask)
-
             Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
             [`PreTrainedTokenizer.__call__`] for details.
-
             If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
             `past_key_values`).
-
             If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
             and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
             information on the default strategy.
-
             - 1 indicates the head is **not masked**,
             - 0 indicates the head is **masked**.
         position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
             config.n_positions - 1]`.
-
             [What are position IDs?](../glossary#position-ids)
         past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
             Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
             blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
             returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
-
             Two formats are allowed:
             - a [`~cache_utils.Cache`] instance;
             - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
             shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
             cache format.
-
             The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
             legacy cache format will be returned.
-
             If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
             have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
             of shape `(batch_size, sequence_length)`.
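
The `past_key_values` description above allows either a `Cache` instance or the legacy tuple format. As a hedged illustration (assuming a transformers release that ships `cache_utils.DynamicCache`, which this file already relies on), converting between the two looks roughly like this:

```python
# Illustrative only; shapes are tiny placeholders.
import torch
from transformers.cache_utils import DynamicCache

batch_size, num_heads, seq_len, head_dim, n_layers = 1, 2, 3, 4, 2

# Legacy format: tuple of length n_layers, each entry a (key, value) pair of tensors with
# shape (batch_size, num_heads, sequence_length, embed_size_per_head).
legacy = tuple(
    (
        torch.randn(batch_size, num_heads, seq_len, head_dim),
        torch.randn(batch_size, num_heads, seq_len, head_dim),
    )
    for _ in range(n_layers)
)

# Wrap the same data in a Cache object, then convert back.
cache = DynamicCache.from_legacy_cache(legacy)
roundtrip = cache.to_legacy_cache()
print(len(roundtrip), roundtrip[0][0].shape)  # 2 torch.Size([1, 2, 3, 4])
```
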
@@ -910,7 +880,6 @@ LLAMA_INPUTS_DOCSTRING = r"""
 class LlamaModel(LlamaPreTrainedModel):
     """
     Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
-
     Args:
         config: LlamaConfig
     """
@@ -987,7 +956,7 @@ class LlamaModel(LlamaPreTrainedModel):
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
-        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
+        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens)
 
         # embed positions
         hidden_states = inputs_embeds
@@ -1051,16 +1020,31 @@ class LlamaModel(LlamaPreTrainedModel):
             attentions=all_self_attns,
         )
 
-    # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
-    # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
-    # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
-    # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
-    def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
+    def _update_causal_mask(
+        self,
+        attention_mask: torch.Tensor,
+        input_tensor: torch.Tensor,
+        cache_position: torch.Tensor,
+        past_seen_tokens: int,
+    ):
+        # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
+        # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
+        # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
+        # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
+
         if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and 0.0 in attention_mask:
                 return attention_mask
             return None
 
+        if self.config._attn_implementation == "sdpa":
+            # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument,
+            # in order to dispatch on Flash Attention 2.
+            if AttentionMaskConverter._ignore_causal_mask_sdpa(
+                attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens
+            ):
+                return None
+
         dtype, device = input_tensor.dtype, input_tensor.device
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
@@ -1068,7 +1052,9 @@ class LlamaModel(LlamaPreTrainedModel):
             target_length = self.config.max_position_embeddings
         else:  # dynamic cache
             target_length = (
-                attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1
+                attention_mask.shape[-1]
+                if isinstance(attention_mask, torch.Tensor)
+                else past_seen_tokens + sequence_length + 1
             )
 
         causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
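
For context, the code around this hunk builds an additive mask of shape `(sequence_length, target_length)` filled with the dtype minimum and then clears the positions each query is allowed to see. A toy, stand-alone sketch of that construction (simplified; the real method also handles padding and batching):

```python
# Toy reconstruction of the additive causal mask; not the file's exact code.
import torch

dtype = torch.float32
sequence_length = 3                                      # new query positions
past_seen_tokens = 4                                     # tokens already in the KV cache
target_length = past_seen_tokens + sequence_length + 1   # dynamic-cache sizing used above

min_dtype = torch.finfo(dtype).min
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)

# Each new query sits at absolute position past_seen_tokens + i; key positions strictly
# beyond that stay at the dtype minimum, everything visible is multiplied down to (-)0.0.
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + sequence_length)
causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
print(causal_mask)
```
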
@@ -1160,20 +1146,14 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
                 Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                 config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                 (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
         Returns:
-
         Example:
-
         ```python
         >>> from transformers import AutoTokenizer, LlamaForCausalLM
-
         >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
         >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
-
         >>> prompt = "Hey, are you conscious? Can you talk to me?"
         >>> inputs = tokenizer(prompt, return_tensors="pt")
-
         >>> # Generate
         >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
         >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
@@ -1328,10 +1308,8 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
 @add_start_docstrings(
     """
     The LLaMa Model transformer with a sequence classification head on top (linear layer).
-
     [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
     (e.g. GPT-2) do.
-
     Since it does classification on the last token, it requires to know the position of the last token. If a
     `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
     no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
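
The docstring above describes selecting the last non-padding token per row. A small, hedged sketch of that selection (illustrative, not the model's exact code; it assumes right-padded `input_ids` and a known `pad_token_id`):

```python
# Find the last non-padding position per row and gather the hidden state there.
import torch

pad_token_id = 0
input_ids = torch.tensor(
    [
        [5, 7, 9, 0, 0],   # 3 real tokens, then padding
        [4, 4, 4, 4, 2],   # no padding
    ]
)

# Index of the last token that is not a padding token in each row.
sequence_lengths = (input_ids != pad_token_id).sum(dim=-1) - 1
print(sequence_lengths)  # tensor([2, 4])

# `hidden_states` is a stand-in for the decoder output of shape (batch, seq, hidden).
hidden_states = torch.randn(2, 5, 8)
pooled = hidden_states[torch.arange(2), sequence_lengths]
print(pooled.shape)  # torch.Size([2, 8])
```
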
@@ -1545,3 +1523,4 @@ class LlamaForQuestionAnswering(LlamaPreTrainedModel):
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
+
 
 