d-Matrix committed
Commit: a26e7df
Parent(s): 64b8757
Update modeling_llama.py

modeling_llama.py  CHANGED  (+30 −51)
@@ -49,6 +49,7 @@ from transformers.utils import (
     replace_return_docstrings,
 )
 from .configuration_llama import LlamaConfig
+from transformers.models.llama.modeling_llama import LlamaRMSNorm
 
 
 if is_flash_attn_2_available():
@@ -56,6 +57,7 @@ if is_flash_attn_2_available():
     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
 
 
+
 logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "LlamaConfig"
@@ -72,24 +74,6 @@ def _get_unpad_data(attention_mask):
         max_seqlen_in_batch,
     )
 
-
-class LlamaRMSNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
-        """
-        LlamaRMSNorm is equivalent to T5LayerNorm
-        """
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.variance_epsilon = eps
-
-    def forward(self, hidden_states):
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        return self.weight * hidden_states.to(input_dtype)
-
-
 ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)
 
 
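Note: the deleted class matches the upstream implementation, so after this commit the RMS-norm behaviour comes from the `LlamaRMSNorm` imported at the top of the file. A minimal sanity-check sketch (the tensor shapes and tolerance are illustrative, not part of the commit):

```python
import torch
from transformers.models.llama.modeling_llama import LlamaRMSNorm

# Reference computation, copied from the class this commit deletes.
def rms_norm_reference(x, weight, eps=1e-6):
    variance = x.pow(2).mean(-1, keepdim=True)
    return weight * x * torch.rsqrt(variance + eps)

norm = LlamaRMSNorm(16, eps=1e-6)   # hidden_size=16
x = torch.randn(2, 4, 16)           # (batch, seq_len, hidden_size)
assert torch.allclose(norm(x), rms_norm_reference(x, norm.weight), atol=1e-6)
```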
@@ -183,7 +167,6 @@ def rotate_half(x):
 
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.
-
     Args:
         q (`torch.Tensor`): The query tensor.
         k (`torch.Tensor`): The key tensor.
@@ -505,7 +488,6 @@ class LlamaFlashAttention2(LlamaAttention):
         """
         Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
         first unpad the input, then computes the attention scores and pad the final attention scores.
-
         Args:
             query_states (`torch.Tensor`):
                 Input query states to be passed to Flash Attention API
@@ -656,7 +638,6 @@ class LlamaSdpaAttention(LlamaAttention):
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
         causal_mask = attention_mask
-        # if attention_mask is not None and cache_position is not None:
         if attention_mask is not None:
             causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
 
@@ -667,12 +648,15 @@ class LlamaSdpaAttention(LlamaAttention):
         key_states = key_states.contiguous()
         value_states = value_states.contiguous()
 
+        # In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash Attention 2 backend, rather
+        # relying on the `is_causal` argument.
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
             value_states,
             attn_mask=causal_mask,
             dropout_p=self.attention_dropout if self.training else 0.0,
+            is_causal=causal_mask is None and q_len > 1,
         )
 
         attn_output = attn_output.transpose(1, 2).contiguous()
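The new `is_causal` argument lets SDPA apply the causal structure itself whenever `_update_causal_mask` returns `None`, which keeps the call eligible for the fused flash / memory-efficient kernels. A standalone sketch of the same dispatch rule (tensor sizes are made up for illustration):

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 5, 64)   # (batch, heads, q_len, head_dim)
k = torch.randn(1, 8, 5, 64)
v = torch.randn(1, 8, 5, 64)
q_len = q.shape[-2]

causal_mask = None  # what _update_causal_mask returns when no explicit mask is needed

# Mirrors the edited call: pass an explicit mask if we have one, otherwise let SDPA
# enforce causality itself (only when q_len > 1; a single decoded token needs no mask).
out = F.scaled_dot_product_attention(
    q, k, v,
    attn_mask=causal_mask,
    dropout_p=0.0,
    is_causal=causal_mask is None and q_len > 1,
)
print(out.shape)  # torch.Size([1, 8, 5, 64])
```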
@@ -769,11 +753,9 @@ LLAMA_START_DOCSTRING = r"""
     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
     library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
     etc.)
-
     This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
     Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
     and behavior.
-
     Parameters:
         config ([`LlamaConfig`]):
             Model configuration class with all the parameters of the model. Initializing with a config file does not
@@ -834,50 +816,38 @@ LLAMA_INPUTS_DOCSTRING = r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
             Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
             it.
-
             Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
             [`PreTrainedTokenizer.__call__`] for details.
-
             [What are input IDs?](../glossary#input-ids)
         attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
             Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
-
             - 1 for tokens that are **not masked**,
             - 0 for tokens that are **masked**.
-
             [What are attention masks?](../glossary#attention-mask)
-
             Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
             [`PreTrainedTokenizer.__call__`] for details.
-
             If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
             `past_key_values`).
-
             If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
             and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
             information on the default strategy.
-
             - 1 indicates the head is **not masked**,
             - 0 indicates the head is **masked**.
         position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
             config.n_positions - 1]`.
-
             [What are position IDs?](../glossary#position-ids)
         past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
             Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
             blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
             returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
-
             Two formats are allowed:
             - a [`~cache_utils.Cache`] instance;
             - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
               shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
               cache format.
-
             The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
             legacy cache format will be returned.
-
             If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
             have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
             of shape `(batch_size, sequence_length)`.
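For reference, the 0/1 `attention_mask` described in this docstring is what the tokenizer produces when padding a batch. A small illustration (the checkpoint is the gated one used elsewhere in this file's docstrings, and assigning a pad token is an assumption since Llama-2 ships without one):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
tok.pad_token = tok.eos_token  # assumption: reuse EOS as padding, Llama-2 defines no pad token
enc = tok(["Hey, are you conscious?", "Hi"], padding=True, return_tensors="pt")
print(enc["attention_mask"])   # 1 = real token, 0 = padding
```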
@@ -910,7 +880,6 @@ LLAMA_INPUTS_DOCSTRING = r"""
 class LlamaModel(LlamaPreTrainedModel):
     """
     Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
-
     Args:
         config: LlamaConfig
     """
@@ -987,7 +956,7 @@ class LlamaModel(LlamaPreTrainedModel):
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
-        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
+        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens)
 
         # embed positions
         hidden_states = inputs_embeds
@@ -1051,16 +1020,31 @@ class LlamaModel(LlamaPreTrainedModel):
             attentions=all_self_attns,
         )
 
-
-
-
-
-
+    def _update_causal_mask(
+        self,
+        attention_mask: torch.Tensor,
+        input_tensor: torch.Tensor,
+        cache_position: torch.Tensor,
+        past_seen_tokens: int,
+    ):
+        # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
+        # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
+        # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
+        # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
+
         if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and 0.0 in attention_mask:
                 return attention_mask
             return None
 
+        if self.config._attn_implementation == "sdpa":
+            # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument,
+            # in order to dispatch on Flash Attention 2.
+            if AttentionMaskConverter._ignore_causal_mask_sdpa(
+                attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens
+            ):
+                return None
+
         dtype, device = input_tensor.dtype, input_tensor.device
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
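In the flash-attention branch above, `0.0 in attention_mask` is just a containment test on the 2D mask: any zero means at least one padded position, so the mask is kept for flash-attn's unpad path; otherwise `None` is returned and masking is skipped entirely. A tiny illustration:

```python
import torch

padded   = torch.tensor([[1.0, 1.0, 1.0, 0.0]])  # last position is padding
unpadded = torch.ones(1, 4)

print(0.0 in padded)    # True  -> keep the 2D mask so the padded tokens can be unpadded
print(0.0 in unpadded)  # False -> return None, no mask needed
```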
@@ -1068,7 +1052,9 @@ class LlamaModel(LlamaPreTrainedModel):
             target_length = self.config.max_position_embeddings
         else:  # dynamic cache
             target_length = (
-                attention_mask.shape[-1]
+                attention_mask.shape[-1]
+                if isinstance(attention_mask, torch.Tensor)
+                else past_seen_tokens + sequence_length + 1
             )
 
         causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
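The lines above only size the `(sequence_length, target_length)` mask and fill it with the dtype minimum. A simplified, self-contained sketch of how such a float mask can be finished; the unmasking rule via `cache_position` is an assumption for illustration, not code from this commit:

```python
import torch

def build_causal_mask_sketch(sequence_length, target_length, cache_position, dtype=torch.float32):
    min_dtype = torch.finfo(dtype).min
    # Start fully masked, then zero out every key position a query may attend to,
    # i.e. keys at absolute positions <= that query's own absolute position.
    mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
    allowed = torch.arange(target_length) <= cache_position.reshape(-1, 1)
    return mask.masked_fill(allowed, 0.0)

# Decoding the 6th token (absolute position 5) against a cache of length 8:
print(build_causal_mask_sketch(1, 8, torch.tensor([5])))
# positions 0-5 are 0.0 (attended), positions 6-7 stay at the dtype minimum (masked)
```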
@@ -1160,20 +1146,14 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
         Returns:
-
         Example:
-
         ```python
         >>> from transformers import AutoTokenizer, LlamaForCausalLM
-
         >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
         >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
-
         >>> prompt = "Hey, are you conscious? Can you talk to me?"
         >>> inputs = tokenizer(prompt, return_tensors="pt")
-
         >>> # Generate
         >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
         >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
@@ -1328,10 +1308,8 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
 @add_start_docstrings(
     """
     The LLaMa Model transformer with a sequence classification head on top (linear layer).
-
     [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
     (e.g. GPT-2) do.
-
     Since it does classification on the last token, it requires to know the position of the last token. If a
     `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
     no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
@@ -1545,3 +1523,4 @@ class LlamaForQuestionAnswering(LlamaPreTrainedModel):
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
+