Andrei Panferov commited on
Commit
43230db
·
1 Parent(s): 088443f

New inference

Browse files
config.json CHANGED
@@ -1,40 +1,138 @@
1
  {
2
- "_name_or_path": "/home/blacksamorez/models/Mixtral-8x7b-AQLM-2Bit-1x16-hf/",
3
- "aqlm": {
4
- "in_group_size": 8,
5
- "nbits_per_codebook": 16,
6
- "num_codebooks": 1,
7
- "out_group_size": 1
8
- },
9
- "architectures": [
10
- "MixtralForCausalLM"
11
- ],
12
- "attention_dropout": 0.0,
13
- "auto_map": {
14
- "AutoConfig": "configuration_mixtral_aqlm.MixtralConfig",
15
- "AutoModelForCausalLM": "modeling_mixtral_aqlm.MixtralForCausalLM"
16
- },
17
- "bos_token_id": 1,
18
- "eos_token_id": 2,
19
- "hidden_act": "silu",
20
- "hidden_size": 4096,
21
- "initializer_range": 0.02,
22
- "intermediate_size": 14336,
23
- "max_position_embeddings": 32768,
24
- "model_type": "mixtral_aqlm",
25
- "num_attention_heads": 32,
26
- "num_experts_per_tok": 2,
27
- "num_hidden_layers": 32,
28
- "num_key_value_heads": 8,
29
- "num_local_experts": 8,
30
- "output_router_logits": false,
31
- "rms_norm_eps": 1e-05,
32
- "rope_theta": 1000000.0,
33
- "router_aux_loss_coef": 0.02,
34
- "sliding_window": null,
35
- "tie_word_embeddings": false,
36
- "torch_dtype": "float16",
37
- "transformers_version": "4.37.0",
38
- "use_cache": true,
39
- "vocab_size": 32000
40
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  {
2
+ "vocab_size": 32000,
3
+ "max_position_embeddings": 32768,
4
+ "hidden_size": 4096,
5
+ "intermediate_size": 14336,
6
+ "num_hidden_layers": 32,
7
+ "num_attention_heads": 32,
8
+ "sliding_window": null,
9
+ "num_key_value_heads": 8,
10
+ "hidden_act": "silu",
11
+ "initializer_range": 0.02,
12
+ "rms_norm_eps": 1e-05,
13
+ "use_cache": true,
14
+ "rope_theta": 1000000.0,
15
+ "attention_dropout": 0.0,
16
+ "num_experts_per_tok": 2,
17
+ "num_local_experts": 8,
18
+ "output_router_logits": false,
19
+ "router_aux_loss_coef": 0.02,
20
+ "torch_dtype": "float16",
21
+ "tie_word_embeddings": false,
22
+ "architectures": [
23
+ "MixtralForCausalLM"
24
+ ],
25
+ "bos_token_id": 1,
26
+ "eos_token_id": 2,
27
+ "_name_or_path": "mistralai/Mixtral-8x7B-v0.1",
28
+ "transformers_version": "4.38.0.dev0",
29
+ "model_type": "mixtral",
30
+ "quantization_config": {
31
+ "quant_method": "aqlm",
32
+ "nbits_per_codebook": 16,
33
+ "num_codebooks": 1,
34
+ "out_group_size": 1,
35
+ "in_group_size": 8,
36
+ "linear_weights_not_to_quantize": [
37
+ "model.layers.0.block_sparse_moe.gate.weight",
38
+ "model.layers.0.input_layernorm.weight",
39
+ "model.layers.0.post_attention_layernorm.weight",
40
+ "model.layers.1.block_sparse_moe.gate.weight",
41
+ "model.layers.1.input_layernorm.weight",
42
+ "model.layers.1.post_attention_layernorm.weight",
43
+ "model.layers.2.block_sparse_moe.gate.weight",
44
+ "model.layers.2.input_layernorm.weight",
45
+ "model.layers.2.post_attention_layernorm.weight",
46
+ "model.layers.3.block_sparse_moe.gate.weight",
47
+ "model.layers.3.input_layernorm.weight",
48
+ "model.layers.3.post_attention_layernorm.weight",
49
+ "model.layers.4.block_sparse_moe.gate.weight",
50
+ "model.layers.4.input_layernorm.weight",
51
+ "model.layers.4.post_attention_layernorm.weight",
52
+ "model.layers.5.block_sparse_moe.gate.weight",
53
+ "model.layers.5.input_layernorm.weight",
54
+ "model.layers.5.post_attention_layernorm.weight",
55
+ "model.layers.6.block_sparse_moe.gate.weight",
56
+ "model.layers.6.input_layernorm.weight",
57
+ "model.layers.6.post_attention_layernorm.weight",
58
+ "model.layers.7.block_sparse_moe.gate.weight",
59
+ "model.layers.7.input_layernorm.weight",
60
+ "model.layers.7.post_attention_layernorm.weight",
61
+ "model.layers.8.block_sparse_moe.gate.weight",
62
+ "model.layers.8.input_layernorm.weight",
63
+ "model.layers.8.post_attention_layernorm.weight",
64
+ "model.layers.9.block_sparse_moe.gate.weight",
65
+ "model.layers.9.input_layernorm.weight",
66
+ "model.layers.9.post_attention_layernorm.weight",
67
+ "model.layers.10.block_sparse_moe.gate.weight",
68
+ "model.layers.10.input_layernorm.weight",
69
+ "model.layers.10.post_attention_layernorm.weight",
70
+ "model.layers.11.block_sparse_moe.gate.weight",
71
+ "model.layers.11.input_layernorm.weight",
72
+ "model.layers.11.post_attention_layernorm.weight",
73
+ "model.layers.12.block_sparse_moe.gate.weight",
74
+ "model.layers.12.input_layernorm.weight",
75
+ "model.layers.12.post_attention_layernorm.weight",
76
+ "model.layers.13.block_sparse_moe.gate.weight",
77
+ "model.layers.13.input_layernorm.weight",
78
+ "model.layers.13.post_attention_layernorm.weight",
79
+ "model.layers.14.block_sparse_moe.gate.weight",
80
+ "model.layers.14.input_layernorm.weight",
81
+ "model.layers.14.post_attention_layernorm.weight",
82
+ "model.layers.15.block_sparse_moe.gate.weight",
83
+ "model.layers.15.input_layernorm.weight",
84
+ "model.layers.15.post_attention_layernorm.weight",
85
+ "model.layers.16.block_sparse_moe.gate.weight",
86
+ "model.layers.16.input_layernorm.weight",
87
+ "model.layers.16.post_attention_layernorm.weight",
88
+ "model.layers.17.block_sparse_moe.gate.weight",
89
+ "model.layers.17.input_layernorm.weight",
90
+ "model.layers.17.post_attention_layernorm.weight",
91
+ "model.layers.18.block_sparse_moe.gate.weight",
92
+ "model.layers.18.input_layernorm.weight",
93
+ "model.layers.18.post_attention_layernorm.weight",
94
+ "model.layers.19.block_sparse_moe.gate.weight",
95
+ "model.layers.19.input_layernorm.weight",
96
+ "model.layers.19.post_attention_layernorm.weight",
97
+ "model.layers.20.block_sparse_moe.gate.weight",
98
+ "model.layers.20.input_layernorm.weight",
99
+ "model.layers.20.post_attention_layernorm.weight",
100
+ "model.layers.21.block_sparse_moe.gate.weight",
101
+ "model.layers.21.input_layernorm.weight",
102
+ "model.layers.21.post_attention_layernorm.weight",
103
+ "model.layers.22.block_sparse_moe.gate.weight",
104
+ "model.layers.22.input_layernorm.weight",
105
+ "model.layers.22.post_attention_layernorm.weight",
106
+ "model.layers.23.block_sparse_moe.gate.weight",
107
+ "model.layers.23.input_layernorm.weight",
108
+ "model.layers.23.post_attention_layernorm.weight",
109
+ "model.layers.24.block_sparse_moe.gate.weight",
110
+ "model.layers.24.input_layernorm.weight",
111
+ "model.layers.24.post_attention_layernorm.weight",
112
+ "model.layers.25.block_sparse_moe.gate.weight",
113
+ "model.layers.25.input_layernorm.weight",
114
+ "model.layers.25.post_attention_layernorm.weight",
115
+ "model.layers.26.block_sparse_moe.gate.weight",
116
+ "model.layers.26.input_layernorm.weight",
117
+ "model.layers.26.post_attention_layernorm.weight",
118
+ "model.layers.27.block_sparse_moe.gate.weight",
119
+ "model.layers.27.input_layernorm.weight",
120
+ "model.layers.27.post_attention_layernorm.weight",
121
+ "model.layers.28.block_sparse_moe.gate.weight",
122
+ "model.layers.28.input_layernorm.weight",
123
+ "model.layers.28.post_attention_layernorm.weight",
124
+ "model.layers.29.block_sparse_moe.gate.weight",
125
+ "model.layers.29.input_layernorm.weight",
126
+ "model.layers.29.post_attention_layernorm.weight",
127
+ "model.layers.30.block_sparse_moe.gate.weight",
128
+ "model.layers.30.input_layernorm.weight",
129
+ "model.layers.30.post_attention_layernorm.weight",
130
+ "model.layers.31.block_sparse_moe.gate.weight",
131
+ "model.layers.31.input_layernorm.weight",
132
+ "model.layers.31.post_attention_layernorm.weight",
133
+ "model.embed_tokens.weight",
134
+ "model.norm.weight",
135
+ "lm_head.weight"
136
+ ]
137
+ }
138
+ }
configuration_mixtral_aqlm.py DELETED
@@ -1,18 +0,0 @@
1
- from transformers import MixtralConfig as OrigLlamaConfig
2
-
3
-
4
- class MixtralConfig(OrigLlamaConfig):
5
- model_type = "mixtral_aqlm"
6
-
7
- def __init__(
8
- self,
9
- aqlm: dict[str, int] = {
10
- "nbits_per_codebook": 16,
11
- "num_codebooks": 1,
12
- "out_group_size": 8,
13
- "in_group_size": 1,
14
- },
15
- **kwargs,
16
- ):
17
- super().__init__(**kwargs)
18
- self.aqlm = aqlm
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
generation_config.json CHANGED
@@ -2,5 +2,5 @@
2
  "_from_model_config": true,
3
  "bos_token_id": 1,
4
  "eos_token_id": 2,
5
- "transformers_version": "4.37.0"
6
  }
 
2
  "_from_model_config": true,
3
  "bos_token_id": 1,
4
  "eos_token_id": 2,
5
+ "transformers_version": "4.38.0.dev0"
6
  }
modeling_mixtral_aqlm.py DELETED
@@ -1,1603 +0,0 @@
1
- # coding=utf-8
2
- # This code is a modification of transformers/models/mixtral/modeling_mixtral.py , which is has the following copyright:
3
- # Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
4
- #
5
- # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
6
- # and OPT implementations in this library. It has been modified from its
7
- # original forms to accommodate minor architectural differences compared
8
- # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
9
- #
10
- # Licensed under the Apache License, Version 2.0 (the "License");
11
- # you may not use this file except in compliance with the License.
12
- # You may obtain a copy of the License at
13
- #
14
- # http://www.apache.org/licenses/LICENSE-2.0
15
- #
16
- # Unless required by applicable law or agreed to in writing, software
17
- # distributed under the License is distributed on an "AS IS" BASIS,
18
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
- # See the License for the specific language governing permissions and
20
- # limitations under the License.
21
- """ PyTorch Mixtral model."""
22
- import inspect
23
- import math
24
- import warnings
25
- from typing import List, Optional, Tuple, Union
26
-
27
- import torch
28
- import torch.nn.functional as F
29
- import torch.utils.checkpoint
30
- from aqlm import QuantizedLinear
31
- from torch import nn
32
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
33
- from transformers.activations import ACT2FN
34
- from transformers.cache_utils import Cache, DynamicCache
35
- from transformers.modeling_attn_mask_utils import (
36
- _prepare_4d_causal_attention_mask,
37
- _prepare_4d_causal_attention_mask_for_sdpa,
38
- )
39
- from transformers.modeling_outputs import (
40
- MoeCausalLMOutputWithPast,
41
- MoeModelOutputWithPast,
42
- SequenceClassifierOutputWithPast,
43
- )
44
- from transformers.modeling_utils import PreTrainedModel
45
- from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_13
46
- from transformers.utils import (
47
- add_start_docstrings,
48
- add_start_docstrings_to_model_forward,
49
- is_flash_attn_2_available,
50
- is_flash_attn_greater_or_equal_2_10,
51
- logging,
52
- replace_return_docstrings,
53
- )
54
- from transformers.utils.import_utils import is_torch_fx_available
55
-
56
- from .configuration_mixtral_aqlm import MixtralConfig
57
-
58
- if is_flash_attn_2_available():
59
- try:
60
- from flash_attn import flash_attn_func, flash_attn_varlen_func
61
- from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
62
- except:
63
- pass
64
-
65
- _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
66
-
67
- # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
68
- # It means that the function will not be traced through and simply appear as a node in the graph.
69
- if is_torch_fx_available():
70
- if not is_torch_greater_or_equal_than_1_13:
71
- import torch.fx
72
-
73
- _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
74
-
75
-
76
- logger = logging.get_logger(__name__)
77
-
78
- _CONFIG_FOR_DOC = "MixtralConfig"
79
-
80
-
81
- def load_balancing_loss_func(
82
- gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2, attention_mask: Optional[torch.Tensor] = None
83
- ) -> float:
84
- r"""
85
- Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
86
-
87
- See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
88
- function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
89
- experts is too unbalanced.
90
-
91
- Args:
92
- gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]):
93
- Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
94
- shape [batch_size X sequence_length, num_experts].
95
- attention_mask (`torch.Tensor`, None):
96
- The attention_mask used in forward function
97
- shape [batch_size X sequence_length] if not None.
98
- num_experts (`int`, *optional*):
99
- Number of experts
100
-
101
- Returns:
102
- The auxiliary loss.
103
- """
104
- if gate_logits is None or not isinstance(gate_logits, tuple):
105
- return 0
106
-
107
- if isinstance(gate_logits, tuple):
108
- compute_device = gate_logits[0].device
109
- concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
110
-
111
- routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
112
-
113
- _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
114
-
115
- expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
116
-
117
- if attention_mask is None:
118
- # Compute the percentage of tokens routed to each experts
119
- tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
120
-
121
- # Compute the average probability of routing to these experts
122
- router_prob_per_expert = torch.mean(routing_weights, dim=0)
123
- else:
124
- batch_size, sequence_length = attention_mask.shape
125
- num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
126
-
127
- # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
128
- expert_attention_mask = (
129
- attention_mask[None, :, :, None, None]
130
- .expand((num_hidden_layers, batch_size, sequence_length, 2, num_experts))
131
- .reshape(-1, 2, num_experts)
132
- .to(compute_device)
133
- )
134
-
135
- # Compute the percentage of tokens routed to each experts
136
- tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
137
- expert_attention_mask, dim=0
138
- )
139
-
140
- # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
141
- router_per_expert_attention_mask = (
142
- attention_mask[None, :, :, None]
143
- .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
144
- .reshape(-1, num_experts)
145
- .to(compute_device)
146
- )
147
-
148
- # Compute the average probability of routing to these experts
149
- router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
150
- router_per_expert_attention_mask, dim=0
151
- )
152
-
153
- overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
154
- return overall_loss * num_experts
155
-
156
-
157
- # Copied from transformers.models.llama.modeling_llama._get_unpad_data
158
- def _get_unpad_data(attention_mask):
159
- seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
160
- indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
161
- max_seqlen_in_batch = seqlens_in_batch.max().item()
162
- cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
163
- return (
164
- indices,
165
- cu_seqlens,
166
- max_seqlen_in_batch,
167
- )
168
-
169
-
170
- # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mixtral
171
- class MixtralRMSNorm(nn.Module):
172
- def __init__(self, hidden_size, eps=1e-6):
173
- """
174
- MixtralRMSNorm is equivalent to T5LayerNorm
175
- """
176
- super().__init__()
177
- self.weight = nn.Parameter(torch.ones(hidden_size))
178
- self.variance_epsilon = eps
179
-
180
- def forward(self, hidden_states):
181
- input_dtype = hidden_states.dtype
182
- hidden_states = hidden_states.to(torch.float32)
183
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
184
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
185
- return self.weight * hidden_states.to(input_dtype)
186
-
187
-
188
- # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mixtral
189
- class MixtralRotaryEmbedding(nn.Module):
190
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
191
- super().__init__()
192
-
193
- self.dim = dim
194
- self.max_position_embeddings = max_position_embeddings
195
- self.base = base
196
- inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
197
- self.register_buffer("inv_freq", inv_freq, persistent=False)
198
-
199
- # Build here to make `torch.jit.trace` work.
200
- self._set_cos_sin_cache(
201
- seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
202
- )
203
-
204
- def _set_cos_sin_cache(self, seq_len, device, dtype):
205
- self.max_seq_len_cached = seq_len
206
- t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
207
-
208
- freqs = torch.outer(t, self.inv_freq)
209
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
210
- emb = torch.cat((freqs, freqs), dim=-1)
211
- self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
212
- self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
213
-
214
- def forward(self, x, seq_len=None):
215
- # x: [bs, num_attention_heads, seq_len, head_size]
216
- if seq_len > self.max_seq_len_cached:
217
- self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
218
-
219
- return (
220
- self.cos_cached[:seq_len].to(dtype=x.dtype),
221
- self.sin_cached[:seq_len].to(dtype=x.dtype),
222
- )
223
-
224
-
225
- # Copied from transformers.models.llama.modeling_llama.rotate_half
226
- def rotate_half(x):
227
- """Rotates half the hidden dims of the input."""
228
- x1 = x[..., : x.shape[-1] // 2]
229
- x2 = x[..., x.shape[-1] // 2 :]
230
- return torch.cat((-x2, x1), dim=-1)
231
-
232
-
233
- # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
234
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
235
- """Applies Rotary Position Embedding to the query and key tensors.
236
-
237
- Args:
238
- q (`torch.Tensor`): The query tensor.
239
- k (`torch.Tensor`): The key tensor.
240
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
241
- sin (`torch.Tensor`): The sine part of the rotary embedding.
242
- position_ids (`torch.Tensor`):
243
- The position indices of the tokens corresponding to the query and key tensors. For example, this can be
244
- used to pass offsetted position ids when working with a KV-cache.
245
- unsqueeze_dim (`int`, *optional*, defaults to 1):
246
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
247
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
248
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
249
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
250
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
251
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
252
- Returns:
253
- `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
254
- """
255
- cos = cos[position_ids].unsqueeze(unsqueeze_dim)
256
- sin = sin[position_ids].unsqueeze(unsqueeze_dim)
257
- q_embed = (q * cos) + (rotate_half(q) * sin)
258
- k_embed = (k * cos) + (rotate_half(k) * sin)
259
- return q_embed, k_embed
260
-
261
-
262
- # Copied from transformers.models.llama.modeling_llama.repeat_kv
263
- def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
264
- """
265
- This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
266
- num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
267
- """
268
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
269
- if n_rep == 1:
270
- return hidden_states
271
- hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
272
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
273
-
274
-
275
- # Copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Mixtral
276
- class MixtralAttention(nn.Module):
277
- """
278
- Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
279
- and "Generating Long Sequences with Sparse Transformers".
280
- """
281
-
282
- def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None):
283
- super().__init__()
284
- self.config = config
285
- self.layer_idx = layer_idx
286
- if layer_idx is None:
287
- logger.warning_once(
288
- f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
289
- "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
290
- "when creating this class."
291
- )
292
-
293
- self.hidden_size = config.hidden_size
294
- self.num_heads = config.num_attention_heads
295
- self.head_dim = self.hidden_size // self.num_heads
296
- self.num_key_value_heads = config.num_key_value_heads
297
- self.num_key_value_groups = self.num_heads // self.num_key_value_heads
298
- self.max_position_embeddings = config.max_position_embeddings
299
- self.rope_theta = config.rope_theta
300
- self.is_causal = True
301
- self.attention_dropout = config.attention_dropout
302
-
303
- if (self.head_dim * self.num_heads) != self.hidden_size:
304
- raise ValueError(
305
- f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
306
- f" and `num_heads`: {self.num_heads})."
307
- )
308
- self.q_proj = QuantizedLinear(self.hidden_size, self.num_heads * self.head_dim, bias=False, **config.aqlm)
309
- self.k_proj = QuantizedLinear(
310
- self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False, **config.aqlm
311
- )
312
- self.v_proj = QuantizedLinear(
313
- self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False, **config.aqlm
314
- )
315
- self.o_proj = QuantizedLinear(self.num_heads * self.head_dim, self.hidden_size, bias=False, **config.aqlm)
316
-
317
- self.rotary_emb = MixtralRotaryEmbedding(
318
- self.head_dim,
319
- max_position_embeddings=self.max_position_embeddings,
320
- base=self.rope_theta,
321
- )
322
-
323
- def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
324
- return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
325
-
326
- def forward(
327
- self,
328
- hidden_states: torch.Tensor,
329
- attention_mask: Optional[torch.Tensor] = None,
330
- position_ids: Optional[torch.LongTensor] = None,
331
- past_key_value: Optional[Cache] = None,
332
- output_attentions: bool = False,
333
- use_cache: bool = False,
334
- **kwargs,
335
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
336
- if "padding_mask" in kwargs:
337
- warnings.warn(
338
- "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
339
- )
340
- bsz, q_len, _ = hidden_states.size()
341
-
342
- query_states = self.q_proj(hidden_states)
343
- key_states = self.k_proj(hidden_states)
344
- value_states = self.v_proj(hidden_states)
345
-
346
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
347
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
348
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
349
-
350
- kv_seq_len = key_states.shape[-2]
351
- if past_key_value is not None:
352
- if self.layer_idx is None:
353
- raise ValueError(
354
- f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
355
- "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
356
- "with a layer index."
357
- )
358
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
359
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
360
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
361
-
362
- if past_key_value is not None:
363
- cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
364
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
365
-
366
- # repeat k/v heads if n_kv_heads < n_heads
367
- key_states = repeat_kv(key_states, self.num_key_value_groups)
368
- value_states = repeat_kv(value_states, self.num_key_value_groups)
369
-
370
- attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
371
-
372
- if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
373
- raise ValueError(
374
- f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
375
- f" {attn_weights.size()}"
376
- )
377
-
378
- if attention_mask is not None:
379
- if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
380
- raise ValueError(
381
- f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
382
- )
383
-
384
- attn_weights = attn_weights + attention_mask
385
-
386
- # upcast attention to fp32
387
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
388
- attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
389
- attn_output = torch.matmul(attn_weights, value_states)
390
-
391
- if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
392
- raise ValueError(
393
- f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
394
- f" {attn_output.size()}"
395
- )
396
-
397
- attn_output = attn_output.transpose(1, 2).contiguous()
398
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
399
-
400
- attn_output = self.o_proj(attn_output)
401
-
402
- if not output_attentions:
403
- attn_weights = None
404
-
405
- return attn_output, attn_weights, past_key_value
406
-
407
-
408
- # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Mixtral
409
- class MixtralFlashAttention2(MixtralAttention):
410
- """
411
- Mixtral flash attention module. This module inherits from `MixtralAttention` as the weights of the module stays
412
- untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
413
- flash attention and deal with padding tokens in case the input contains any of them.
414
- """
415
-
416
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
417
- def __init__(self, *args, **kwargs):
418
- super().__init__(*args, **kwargs)
419
-
420
- # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
421
- # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
422
- # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
423
- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
424
-
425
- def forward(
426
- self,
427
- hidden_states: torch.Tensor,
428
- attention_mask: Optional[torch.Tensor] = None,
429
- position_ids: Optional[torch.LongTensor] = None,
430
- past_key_value: Optional[Cache] = None,
431
- output_attentions: bool = False,
432
- use_cache: bool = False,
433
- **kwargs,
434
- ):
435
- if "padding_mask" in kwargs:
436
- warnings.warn(
437
- "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
438
- )
439
-
440
- # overwrite attention_mask with padding_mask
441
- attention_mask = kwargs.pop("padding_mask")
442
- bsz, q_len, _ = hidden_states.size()
443
-
444
- query_states = self.q_proj(hidden_states)
445
- key_states = self.k_proj(hidden_states)
446
- value_states = self.v_proj(hidden_states)
447
-
448
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
449
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
450
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
451
-
452
- kv_seq_len = key_states.shape[-2]
453
- if past_key_value is not None:
454
- if self.layer_idx is None:
455
- raise ValueError(
456
- f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
457
- "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
458
- "with a layer index."
459
- )
460
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
461
-
462
- # Because the input can be padded, the absolute sequence length depends on the max position id.
463
- rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
464
- cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
465
-
466
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
467
-
468
- use_sliding_windows = (
469
- _flash_supports_window_size
470
- and getattr(self.config, "sliding_window", None) is not None
471
- and kv_seq_len > self.config.sliding_window
472
- )
473
-
474
- if not _flash_supports_window_size:
475
- logger.warning_once(
476
- "The current flash attention version does not support sliding window attention, for a more memory efficient implementation"
477
- " make sure to upgrade flash-attn library."
478
- )
479
-
480
- if past_key_value is not None:
481
- # Activate slicing cache only if the config has a value `sliding_windows` attribute
482
- cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
483
- if (
484
- getattr(self.config, "sliding_window", None) is not None
485
- and kv_seq_len > self.config.sliding_window
486
- and cache_has_contents
487
- ):
488
- slicing_tokens = 1 - self.config.sliding_window
489
-
490
- past_key = past_key_value[self.layer_idx][0]
491
- past_value = past_key_value[self.layer_idx][1]
492
-
493
- past_key = past_key[:, :, slicing_tokens:, :].contiguous()
494
- past_value = past_value[:, :, slicing_tokens:, :].contiguous()
495
-
496
- if past_key.shape[-2] != self.config.sliding_window - 1:
497
- raise ValueError(
498
- f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
499
- f" {past_key.shape}"
500
- )
501
-
502
- if attention_mask is not None:
503
- attention_mask = attention_mask[:, slicing_tokens:]
504
- attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
505
-
506
- cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
507
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
508
-
509
- # repeat k/v heads if n_kv_heads < n_heads
510
- key_states = repeat_kv(key_states, self.num_key_value_groups)
511
- value_states = repeat_kv(value_states, self.num_key_value_groups)
512
- dropout_rate = 0.0 if not self.training else self.attention_dropout
513
-
514
- # In PEFT, usually we cast the layer norms in float32 for training stability reasons
515
- # therefore the input hidden states gets silently casted in float32. Hence, we need
516
- # cast them back in float16 just to be sure everything works as expected.
517
- input_dtype = query_states.dtype
518
- if input_dtype == torch.float32:
519
- if torch.is_autocast_enabled():
520
- target_dtype = torch.get_autocast_gpu_dtype()
521
- # Handle the case where the model is quantized
522
- elif hasattr(self.config, "_pre_quantization_dtype"):
523
- target_dtype = self.config._pre_quantization_dtype
524
- else:
525
- target_dtype = self.q_proj.weight.dtype
526
-
527
- logger.warning_once(
528
- f"The input hidden states seems to be silently casted in float32, this might be related to"
529
- f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
530
- f" {target_dtype}."
531
- )
532
-
533
- query_states = query_states.to(target_dtype)
534
- key_states = key_states.to(target_dtype)
535
- value_states = value_states.to(target_dtype)
536
-
537
- # Reashape to the expected shape for Flash Attention
538
- query_states = query_states.transpose(1, 2)
539
- key_states = key_states.transpose(1, 2)
540
- value_states = value_states.transpose(1, 2)
541
-
542
- attn_output = self._flash_attention_forward(
543
- query_states,
544
- key_states,
545
- value_states,
546
- attention_mask,
547
- q_len,
548
- dropout=dropout_rate,
549
- use_sliding_windows=use_sliding_windows,
550
- )
551
-
552
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
553
- attn_output = self.o_proj(attn_output)
554
-
555
- if not output_attentions:
556
- attn_weights = None
557
-
558
- return attn_output, attn_weights, past_key_value
559
-
560
- def _flash_attention_forward(
561
- self,
562
- query_states,
563
- key_states,
564
- value_states,
565
- attention_mask,
566
- query_length,
567
- dropout=0.0,
568
- softmax_scale=None,
569
- use_sliding_windows=False,
570
- ):
571
- """
572
- Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
573
- first unpad the input, then computes the attention scores and pad the final attention scores.
574
-
575
- Args:
576
- query_states (`torch.Tensor`):
577
- Input query states to be passed to Flash Attention API
578
- key_states (`torch.Tensor`):
579
- Input key states to be passed to Flash Attention API
580
- value_states (`torch.Tensor`):
581
- Input value states to be passed to Flash Attention API
582
- attention_mask (`torch.Tensor`):
583
- The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
584
- position of padding tokens and 1 for the position of non-padding tokens.
585
- dropout (`int`, *optional*):
586
- Attention dropout
587
- softmax_scale (`float`, *optional*):
588
- The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
589
- use_sliding_windows (`bool`, *optional*):
590
- Whether to activate sliding window attention.
591
- """
592
- if not self._flash_attn_uses_top_left_mask:
593
- causal = self.is_causal
594
- else:
595
- # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
596
- causal = self.is_causal and query_length != 1
597
-
598
- # Contains at least one padding token in the sequence
599
- if attention_mask is not None:
600
- batch_size = query_states.shape[0]
601
- query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
602
- query_states, key_states, value_states, attention_mask, query_length
603
- )
604
-
605
- cu_seqlens_q, cu_seqlens_k = cu_seq_lens
606
- max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
607
-
608
- if not use_sliding_windows:
609
- attn_output_unpad = flash_attn_varlen_func(
610
- query_states,
611
- key_states,
612
- value_states,
613
- cu_seqlens_q=cu_seqlens_q,
614
- cu_seqlens_k=cu_seqlens_k,
615
- max_seqlen_q=max_seqlen_in_batch_q,
616
- max_seqlen_k=max_seqlen_in_batch_k,
617
- dropout_p=dropout,
618
- softmax_scale=softmax_scale,
619
- causal=causal,
620
- )
621
- else:
622
- attn_output_unpad = flash_attn_varlen_func(
623
- query_states,
624
- key_states,
625
- value_states,
626
- cu_seqlens_q=cu_seqlens_q,
627
- cu_seqlens_k=cu_seqlens_k,
628
- max_seqlen_q=max_seqlen_in_batch_q,
629
- max_seqlen_k=max_seqlen_in_batch_k,
630
- dropout_p=dropout,
631
- softmax_scale=softmax_scale,
632
- causal=causal,
633
- window_size=(self.config.sliding_window, self.config.sliding_window),
634
- )
635
-
636
- attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
637
- else:
638
- if not use_sliding_windows:
639
- attn_output = flash_attn_func(
640
- query_states,
641
- key_states,
642
- value_states,
643
- dropout,
644
- softmax_scale=softmax_scale,
645
- causal=causal,
646
- )
647
- else:
648
- attn_output = flash_attn_func(
649
- query_states,
650
- key_states,
651
- value_states,
652
- dropout,
653
- softmax_scale=softmax_scale,
654
- causal=causal,
655
- window_size=(self.config.sliding_window, self.config.sliding_window),
656
- )
657
-
658
- return attn_output
659
-
660
- def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
661
- batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
662
-
663
- # On the first iteration we need to properly re-create the padding mask
664
- # by slicing it on the proper place
665
- if kv_seq_len != attention_mask.shape[-1]:
666
- attention_mask_num_tokens = attention_mask.shape[-1]
667
- attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
668
-
669
- indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
670
-
671
- key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
672
- value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
673
-
674
- if query_length == kv_seq_len:
675
- query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
676
- cu_seqlens_q = cu_seqlens_k
677
- max_seqlen_in_batch_q = max_seqlen_in_batch_k
678
- indices_q = indices_k
679
- elif query_length == 1:
680
- max_seqlen_in_batch_q = 1
681
- cu_seqlens_q = torch.arange(
682
- batch_size + 1, dtype=torch.int32, device=query_layer.device
683
- ) # There is a memcpy here, that is very bad.
684
- indices_q = cu_seqlens_q[:-1]
685
- query_layer = query_layer.squeeze(1)
686
- else:
687
- # The -q_len: slice assumes left padding.
688
- attention_mask = attention_mask[:, -query_length:]
689
- query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
690
-
691
- return (
692
- query_layer,
693
- key_layer,
694
- value_layer,
695
- indices_q,
696
- (cu_seqlens_q, cu_seqlens_k),
697
- (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
698
- )
699
-
700
-
701
- # Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mixtral
702
- class MixtralSdpaAttention(MixtralAttention):
703
- """
704
- Mixtral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
705
- `MixtralAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
706
- SDPA API.
707
- """
708
-
709
- # Adapted from MixtralAttention.forward
710
- def forward(
711
- self,
712
- hidden_states: torch.Tensor,
713
- attention_mask: Optional[torch.Tensor] = None,
714
- position_ids: Optional[torch.LongTensor] = None,
715
- past_key_value: Optional[Cache] = None,
716
- output_attentions: bool = False,
717
- use_cache: bool = False,
718
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
719
- if output_attentions:
720
- # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
721
- logger.warning_once(
722
- "MixtralModel is using MixtralSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
723
- 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
724
- )
725
- return super().forward(
726
- hidden_states=hidden_states,
727
- attention_mask=attention_mask,
728
- position_ids=position_ids,
729
- past_key_value=past_key_value,
730
- output_attentions=output_attentions,
731
- use_cache=use_cache,
732
- )
733
-
734
- bsz, q_len, _ = hidden_states.size()
735
-
736
- query_states = self.q_proj(hidden_states)
737
- key_states = self.k_proj(hidden_states)
738
- value_states = self.v_proj(hidden_states)
739
-
740
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
741
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
742
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
743
-
744
- kv_seq_len = key_states.shape[-2]
745
- if past_key_value is not None:
746
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
747
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
748
-
749
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
750
-
751
- if past_key_value is not None:
752
- cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
753
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
754
-
755
- key_states = repeat_kv(key_states, self.num_key_value_groups)
756
- value_states = repeat_kv(value_states, self.num_key_value_groups)
757
-
758
- if attention_mask is not None:
759
- if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
760
- raise ValueError(
761
- f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
762
- )
763
-
764
- # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
765
- # Reference: https://github.com/pytorch/pytorch/issues/112577.
766
- if query_states.device.type == "cuda" and attention_mask is not None:
767
- query_states = query_states.contiguous()
768
- key_states = key_states.contiguous()
769
- value_states = value_states.contiguous()
770
-
771
- attn_output = torch.nn.functional.scaled_dot_product_attention(
772
- query_states,
773
- key_states,
774
- value_states,
775
- attn_mask=attention_mask,
776
- dropout_p=self.attention_dropout if self.training else 0.0,
777
- # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
778
- is_causal=self.is_causal and attention_mask is None and q_len > 1,
779
- )
780
-
781
- attn_output = attn_output.transpose(1, 2).contiguous()
782
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
783
-
784
- attn_output = self.o_proj(attn_output)
785
-
786
- return attn_output, None, past_key_value
787
-
788
-
789
- MIXTRAL_ATTENTION_CLASSES = {
790
- "eager": MixtralAttention,
791
- "flash_attention_2": MixtralFlashAttention2,
792
- "sdpa": MixtralSdpaAttention,
793
- }
794
-
795
-
796
- class MixtralBLockSparseTop2MLP(nn.Module):
797
- def __init__(self, config: MixtralConfig):
798
- super().__init__()
799
- self.ffn_dim = config.intermediate_size
800
- self.hidden_dim = config.hidden_size
801
-
802
- self.w1 = QuantizedLinear(self.hidden_dim, self.ffn_dim, bias=False, **config.aqlm)
803
- self.w2 = QuantizedLinear(self.ffn_dim, self.hidden_dim, bias=False, **config.aqlm)
804
- self.w3 = QuantizedLinear(self.hidden_dim, self.ffn_dim, bias=False, **config.aqlm)
805
-
806
- self.act_fn = ACT2FN[config.hidden_act]
807
-
808
- def forward(self, hidden_states):
809
- current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
810
- current_hidden_states = self.w2(current_hidden_states)
811
- return current_hidden_states
812
-
813
-
814
- class MixtralSparseMoeBlock(nn.Module):
815
- """
816
- This implementation is
817
- strictly equivalent to standard MoE with full capacity (no
818
- dropped tokens). It's faster since it formulates MoE operations
819
- in terms of block-sparse operations to accomodate imbalanced
820
- assignments of tokens to experts, whereas standard MoE either
821
- (1) drop tokens at the cost of reduced performance or (2) set
822
- capacity factor to number of experts and thus waste computation
823
- and memory on padding.
824
- """
825
-
826
- def __init__(self, config):
827
- super().__init__()
828
- self.hidden_dim = config.hidden_size
829
- self.ffn_dim = config.intermediate_size
830
- self.num_experts = config.num_local_experts
831
- self.top_k = config.num_experts_per_tok
832
-
833
- # gating
834
- self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
835
-
836
- self.experts = nn.ModuleList([MixtralBLockSparseTop2MLP(config) for _ in range(self.num_experts)])
837
-
838
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
839
- """ """
840
- batch_size, sequence_length, hidden_dim = hidden_states.shape
841
- hidden_states = hidden_states.view(-1, hidden_dim)
842
- # router_logits: (batch * sequence_length, n_experts)
843
- router_logits = self.gate(hidden_states)
844
-
845
- routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
846
- routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
847
- routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
848
- # we cast back to the input dtype
849
- routing_weights = routing_weights.to(hidden_states.dtype)
850
-
851
- final_hidden_states = torch.zeros(
852
- (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
853
- )
854
-
855
- # One hot encode the selected experts to create an expert mask
856
- # this will be used to easily index which expert is going to be sollicitated
857
- expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
858
-
859
- # Loop over all available experts in the model and perform the computation on each expert
860
- for expert_idx in range(self.num_experts):
861
- expert_layer = self.experts[expert_idx]
862
- idx, top_x = torch.where(expert_mask[expert_idx])
863
-
864
- if top_x.shape[0] == 0:
865
- continue
866
-
867
- # in torch it is faster to index using lists than torch tensors
868
- top_x_list = top_x.tolist()
869
- idx_list = idx.tolist()
870
-
871
- # Index the correct hidden states and compute the expert hidden state for
872
- # the current expert. We need to make sure to multiply the output hidden
873
- # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
874
- current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
875
- current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]
876
-
877
- # However `index_add_` only support torch tensors for indexing so we'll use
878
- # the `top_x` tensor here.
879
- final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
880
- final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
881
- return final_hidden_states, router_logits
882
-
883
-
884
- class MixtralDecoderLayer(nn.Module):
885
- def __init__(self, config: MixtralConfig, layer_idx: int):
886
- super().__init__()
887
- self.hidden_size = config.hidden_size
888
-
889
- self.self_attn = MIXTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
890
-
891
- self.block_sparse_moe = MixtralSparseMoeBlock(config)
892
- self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
893
- self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
894
-
895
- def forward(
896
- self,
897
- hidden_states: torch.Tensor,
898
- attention_mask: Optional[torch.Tensor] = None,
899
- position_ids: Optional[torch.LongTensor] = None,
900
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
901
- output_attentions: Optional[bool] = False,
902
- output_router_logits: Optional[bool] = False,
903
- use_cache: Optional[bool] = False,
904
- **kwargs,
905
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
906
- if "padding_mask" in kwargs:
907
- warnings.warn(
908
- "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
909
- )
910
- """
911
- Args:
912
- hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
913
- attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
914
- `(batch, sequence_length)` where padding elements are indicated by 0.
915
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
916
- output_attentions (`bool`, *optional*):
917
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
918
- returned tensors for more detail.
919
- output_router_logits (`bool`, *optional*):
920
- Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
921
- should not be returned during inference.
922
- use_cache (`bool`, *optional*):
923
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
924
- (see `past_key_values`).
925
- """
926
-
927
- residual = hidden_states
928
-
929
- hidden_states = self.input_layernorm(hidden_states)
930
-
931
- # Self Attention
932
- hidden_states, self_attn_weights, present_key_value = self.self_attn(
933
- hidden_states=hidden_states,
934
- attention_mask=attention_mask,
935
- position_ids=position_ids,
936
- past_key_value=past_key_value,
937
- output_attentions=output_attentions,
938
- use_cache=use_cache,
939
- )
940
- hidden_states = residual + hidden_states
941
-
942
- # Fully Connected
943
- residual = hidden_states
944
- hidden_states = self.post_attention_layernorm(hidden_states)
945
- hidden_states, router_logits = self.block_sparse_moe(hidden_states)
946
- hidden_states = residual + hidden_states
947
-
948
- outputs = (hidden_states,)
949
-
950
- if output_attentions:
951
- outputs += (self_attn_weights,)
952
-
953
- if use_cache:
954
- outputs += (present_key_value,)
955
-
956
- if output_router_logits:
957
- outputs += (router_logits,)
958
-
959
- return outputs
960
-
961
-
962
- MIXTRAL_START_DOCSTRING = r"""
963
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
964
- library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
965
- etc.)
966
-
967
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
968
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
969
- and behavior.
970
-
971
- Parameters:
972
- config ([`MixtralConfig`]):
973
- Model configuration class with all the parameters of the model. Initializing with a config file does not
974
- load the weights associated with the model, only the configuration. Check out the
975
- [`~PreTrainedModel.from_pretrained`] method to load the model weights.
976
- """
977
-
978
-
979
- @add_start_docstrings(
980
- "The bare Mixtral Model outputting raw hidden-states without any specific head on top.",
981
- MIXTRAL_START_DOCSTRING,
982
- )
983
- # Copied from transformers.models.mistral.modeling_mistral.MistralPreTrainedModel with Mistral->Mixtral
984
- class MixtralPreTrainedModel(PreTrainedModel):
985
- config_class = MixtralConfig
986
- base_model_prefix = "model"
987
- supports_gradient_checkpointing = True
988
- _no_split_modules = ["MixtralDecoderLayer"]
989
- _skip_keys_device_placement = "past_key_values"
990
- _supports_flash_attn_2 = True
991
- _supports_sdpa = True
992
- _supports_cache_class = True
993
-
994
- def _init_weights(self, module):
995
- std = self.config.initializer_range
996
- if isinstance(module, nn.Linear):
997
- module.weight.data.normal_(mean=0.0, std=std)
998
- if module.bias is not None:
999
- module.bias.data.zero_()
1000
- elif isinstance(module, nn.Embedding):
1001
- module.weight.data.normal_(mean=0.0, std=std)
1002
- if module.padding_idx is not None:
1003
- module.weight.data[module.padding_idx].zero_()
1004
-
1005
-
1006
- MIXTRAL_INPUTS_DOCSTRING = r"""
1007
- Args:
1008
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1009
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
1010
- it.
1011
-
1012
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1013
- [`PreTrainedTokenizer.__call__`] for details.
1014
-
1015
- [What are input IDs?](../glossary#input-ids)
1016
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1017
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1018
-
1019
- - 1 for tokens that are **not masked**,
1020
- - 0 for tokens that are **masked**.
1021
-
1022
- [What are attention masks?](../glossary#attention-mask)
1023
-
1024
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1025
- [`PreTrainedTokenizer.__call__`] for details.
1026
-
1027
- If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
1028
- `past_key_values`).
1029
-
1030
- If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
1031
- and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
1032
- information on the default strategy.
1033
-
1034
- - 1 indicates the head is **not masked**,
1035
- - 0 indicates the head is **masked**.
1036
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1037
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
1038
- config.n_positions - 1]`.
1039
-
1040
- [What are position IDs?](../glossary#position-ids)
1041
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
1042
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
1043
- `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
1044
- `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
1045
-
1046
- Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
1047
- blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
1048
-
1049
- If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
1050
- don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
1051
- `decoder_input_ids` of shape `(batch_size, sequence_length)`.
1052
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1053
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1054
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1055
- model's internal embedding lookup matrix.
1056
- use_cache (`bool`, *optional*):
1057
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1058
- `past_key_values`).
1059
- output_attentions (`bool`, *optional*):
1060
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1061
- tensors for more detail.
1062
- output_hidden_states (`bool`, *optional*):
1063
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1064
- more detail.
1065
- output_router_logits (`bool`, *optional*):
1066
- Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
1067
- should not be returned during inference.
1068
- return_dict (`bool`, *optional*):
1069
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1070
- """
1071
-
1072
-
1073
- @add_start_docstrings(
1074
- "The bare Mixtral Model outputting raw hidden-states without any specific head on top.",
1075
- MIXTRAL_START_DOCSTRING,
1076
- )
1077
- # Copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MIXTRAL,Mistral->Mixtral
1078
- class MixtralModel(MixtralPreTrainedModel):
1079
- """
1080
- Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MixtralDecoderLayer`]
1081
-
1082
- Args:
1083
- config: MixtralConfig
1084
- """
1085
-
1086
- def __init__(self, config: MixtralConfig):
1087
- super().__init__(config)
1088
- self.padding_idx = config.pad_token_id
1089
- self.vocab_size = config.vocab_size
1090
-
1091
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1092
- self.layers = nn.ModuleList(
1093
- [MixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
1094
- )
1095
- self._attn_implementation = config._attn_implementation
1096
- self.norm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1097
-
1098
- self.gradient_checkpointing = False
1099
- # Initialize weights and apply final processing
1100
- self.post_init()
1101
-
1102
- def get_input_embeddings(self):
1103
- return self.embed_tokens
1104
-
1105
- def set_input_embeddings(self, value):
1106
- self.embed_tokens = value
1107
-
1108
- # Ignore copy
1109
- @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
1110
- def forward(
1111
- self,
1112
- input_ids: torch.LongTensor = None,
1113
- attention_mask: Optional[torch.Tensor] = None,
1114
- position_ids: Optional[torch.LongTensor] = None,
1115
- past_key_values: Optional[List[torch.FloatTensor]] = None,
1116
- inputs_embeds: Optional[torch.FloatTensor] = None,
1117
- use_cache: Optional[bool] = None,
1118
- output_attentions: Optional[bool] = None,
1119
- output_hidden_states: Optional[bool] = None,
1120
- output_router_logits: Optional[bool] = None,
1121
- return_dict: Optional[bool] = None,
1122
- ) -> Union[Tuple, MoeModelOutputWithPast]:
1123
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1124
- output_router_logits = (
1125
- output_router_logits if output_router_logits is not None else self.config.output_router_logits
1126
- )
1127
- output_hidden_states = (
1128
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1129
- )
1130
- use_cache = use_cache if use_cache is not None else self.config.use_cache
1131
-
1132
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1133
-
1134
- # retrieve input_ids and inputs_embeds
1135
- if input_ids is not None and inputs_embeds is not None:
1136
- raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
1137
- elif input_ids is not None:
1138
- batch_size, seq_length = input_ids.shape
1139
- elif inputs_embeds is not None:
1140
- batch_size, seq_length, _ = inputs_embeds.shape
1141
- else:
1142
- raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
1143
-
1144
- past_key_values_length = 0
1145
-
1146
- if self.gradient_checkpointing and self.training:
1147
- if use_cache:
1148
- logger.warning_once(
1149
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1150
- )
1151
- use_cache = False
1152
-
1153
- if use_cache:
1154
- use_legacy_cache = not isinstance(past_key_values, Cache)
1155
- if use_legacy_cache:
1156
- past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1157
- past_key_values_length = past_key_values.get_usable_length(seq_length)
1158
-
1159
- if position_ids is None:
1160
- device = input_ids.device if input_ids is not None else inputs_embeds.device
1161
- position_ids = torch.arange(
1162
- past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
1163
- )
1164
- position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
1165
- else:
1166
- position_ids = position_ids.view(-1, seq_length).long()
1167
-
1168
- if inputs_embeds is None:
1169
- inputs_embeds = self.embed_tokens(input_ids)
1170
-
1171
- if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
1172
- is_padding_right = attention_mask[:, -1].sum().item() != batch_size
1173
- if is_padding_right:
1174
- raise ValueError(
1175
- "You are attempting to perform batched generation with padding_side='right'"
1176
- " this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to "
1177
- " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
1178
- )
1179
-
1180
- if self._attn_implementation == "flash_attention_2":
1181
- # 2d mask is passed through the layers
1182
- attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
1183
- elif self._attn_implementation == "sdpa" and not output_attentions:
1184
- # output_attentions=True can not be supported when using SDPA, and we fall back on
1185
- # the manual implementation that requires a 4D causal mask in all cases.
1186
- attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
1187
- attention_mask,
1188
- (batch_size, seq_length),
1189
- inputs_embeds,
1190
- past_key_values_length,
1191
- )
1192
- else:
1193
- # 4d mask is passed through the layers
1194
- attention_mask = _prepare_4d_causal_attention_mask(
1195
- attention_mask,
1196
- (batch_size, seq_length),
1197
- inputs_embeds,
1198
- past_key_values_length,
1199
- sliding_window=self.config.sliding_window,
1200
- )
1201
-
1202
- hidden_states = inputs_embeds
1203
-
1204
- # decoder layers
1205
- all_hidden_states = () if output_hidden_states else None
1206
- all_self_attns = () if output_attentions else None
1207
- all_router_logits = () if output_router_logits else None
1208
- next_decoder_cache = None
1209
-
1210
- for decoder_layer in self.layers:
1211
- if output_hidden_states:
1212
- all_hidden_states += (hidden_states,)
1213
-
1214
- if self.gradient_checkpointing and self.training:
1215
- layer_outputs = self._gradient_checkpointing_func(
1216
- decoder_layer.__call__,
1217
- hidden_states,
1218
- attention_mask,
1219
- position_ids,
1220
- past_key_values,
1221
- output_attentions,
1222
- output_router_logits,
1223
- use_cache,
1224
- )
1225
- else:
1226
- layer_outputs = decoder_layer(
1227
- hidden_states,
1228
- attention_mask=attention_mask,
1229
- position_ids=position_ids,
1230
- past_key_value=past_key_values,
1231
- output_attentions=output_attentions,
1232
- output_router_logits=output_router_logits,
1233
- use_cache=use_cache,
1234
- )
1235
-
1236
- hidden_states = layer_outputs[0]
1237
-
1238
- if use_cache:
1239
- next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1240
-
1241
- if output_attentions:
1242
- all_self_attns += (layer_outputs[1],)
1243
-
1244
- if output_router_logits:
1245
- all_router_logits += (layer_outputs[-1],)
1246
-
1247
- hidden_states = self.norm(hidden_states)
1248
-
1249
- # add hidden states from the last decoder layer
1250
- if output_hidden_states:
1251
- all_hidden_states += (hidden_states,)
1252
-
1253
- next_cache = None
1254
- if use_cache:
1255
- next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
1256
-
1257
- if not return_dict:
1258
- return tuple(
1259
- v
1260
- for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
1261
- if v is not None
1262
- )
1263
- return MoeModelOutputWithPast(
1264
- last_hidden_state=hidden_states,
1265
- past_key_values=next_cache,
1266
- hidden_states=all_hidden_states,
1267
- attentions=all_self_attns,
1268
- router_logits=all_router_logits,
1269
- )
1270
-
1271
-
1272
- class MixtralForCausalLM(MixtralPreTrainedModel):
1273
- _tied_weights_keys = ["lm_head.weight"]
1274
-
1275
- def __init__(self, config):
1276
- super().__init__(config)
1277
- self.model = MixtralModel(config)
1278
- self.vocab_size = config.vocab_size
1279
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1280
- self.router_aux_loss_coef = config.router_aux_loss_coef
1281
- self.num_experts = config.num_local_experts
1282
- self.num_experts_per_tok = config.num_experts_per_tok
1283
- # Initialize weights and apply final processing
1284
- self.post_init()
1285
-
1286
- def get_input_embeddings(self):
1287
- return self.model.embed_tokens
1288
-
1289
- def set_input_embeddings(self, value):
1290
- self.model.embed_tokens = value
1291
-
1292
- def get_output_embeddings(self):
1293
- return self.lm_head
1294
-
1295
- def set_output_embeddings(self, new_embeddings):
1296
- self.lm_head = new_embeddings
1297
-
1298
- def set_decoder(self, decoder):
1299
- self.model = decoder
1300
-
1301
- def get_decoder(self):
1302
- return self.model
1303
-
1304
- @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
1305
- @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1306
- # Ignore copy
1307
- def forward(
1308
- self,
1309
- input_ids: torch.LongTensor = None,
1310
- attention_mask: Optional[torch.Tensor] = None,
1311
- position_ids: Optional[torch.LongTensor] = None,
1312
- past_key_values: Optional[List[torch.FloatTensor]] = None,
1313
- inputs_embeds: Optional[torch.FloatTensor] = None,
1314
- labels: Optional[torch.LongTensor] = None,
1315
- use_cache: Optional[bool] = None,
1316
- output_attentions: Optional[bool] = None,
1317
- output_hidden_states: Optional[bool] = None,
1318
- output_router_logits: Optional[bool] = None,
1319
- return_dict: Optional[bool] = None,
1320
- ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
1321
- r"""
1322
- Args:
1323
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1324
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1325
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1326
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1327
-
1328
- Returns:
1329
-
1330
- Example:
1331
-
1332
- ```python
1333
- >>> from transformers import AutoTokenizer, MixtralForCausalLM
1334
-
1335
- >>> model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
1336
- >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
1337
-
1338
- >>> prompt = "Hey, are you conscious? Can you talk to me?"
1339
- >>> inputs = tokenizer(prompt, return_tensors="pt")
1340
-
1341
- >>> # Generate
1342
- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1343
- >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1344
- "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1345
- ```"""
1346
-
1347
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1348
- output_router_logits = (
1349
- output_router_logits if output_router_logits is not None else self.config.output_router_logits
1350
- )
1351
-
1352
- output_hidden_states = (
1353
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1354
- )
1355
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1356
-
1357
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1358
- outputs = self.model(
1359
- input_ids=input_ids,
1360
- attention_mask=attention_mask,
1361
- position_ids=position_ids,
1362
- past_key_values=past_key_values,
1363
- inputs_embeds=inputs_embeds,
1364
- use_cache=use_cache,
1365
- output_attentions=output_attentions,
1366
- output_hidden_states=output_hidden_states,
1367
- output_router_logits=output_router_logits,
1368
- return_dict=return_dict,
1369
- )
1370
-
1371
- hidden_states = outputs[0]
1372
- logits = self.lm_head(hidden_states)
1373
- logits = logits.float()
1374
-
1375
- loss = None
1376
- if labels is not None:
1377
- # Shift so that tokens < n predict n
1378
- shift_logits = logits[..., :-1, :].contiguous()
1379
- shift_labels = labels[..., 1:].contiguous()
1380
- # Flatten the tokens
1381
- loss_fct = CrossEntropyLoss()
1382
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
1383
- shift_labels = shift_labels.view(-1)
1384
- # Enable model parallelism
1385
- shift_labels = shift_labels.to(shift_logits.device)
1386
- loss = loss_fct(shift_logits, shift_labels)
1387
-
1388
- aux_loss = None
1389
- if output_router_logits:
1390
- aux_loss = load_balancing_loss_func(
1391
- outputs.router_logits if return_dict else outputs[-1],
1392
- self.num_experts,
1393
- self.num_experts_per_tok,
1394
- attention_mask,
1395
- )
1396
- if labels is not None:
1397
- loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
1398
-
1399
- if not return_dict:
1400
- output = (logits,) + outputs[1:]
1401
- if output_router_logits:
1402
- output = (aux_loss,) + output
1403
- return (loss,) + output if loss is not None else output
1404
-
1405
- return MoeCausalLMOutputWithPast(
1406
- loss=loss,
1407
- aux_loss=aux_loss,
1408
- logits=logits,
1409
- past_key_values=outputs.past_key_values,
1410
- hidden_states=outputs.hidden_states,
1411
- attentions=outputs.attentions,
1412
- router_logits=outputs.router_logits,
1413
- )
1414
-
1415
- def prepare_inputs_for_generation(
1416
- self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1417
- ):
1418
- # Omit tokens covered by past_key_values
1419
- if past_key_values is not None:
1420
- if isinstance(past_key_values, Cache):
1421
- cache_length = past_key_values.get_seq_length()
1422
- past_length = past_key_values.seen_tokens
1423
- max_cache_length = past_key_values.get_max_length()
1424
- else:
1425
- cache_length = past_length = past_key_values[0][0].shape[2]
1426
- max_cache_length = None
1427
-
1428
- # Keep only the unprocessed tokens:
1429
- # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1430
- # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1431
- # input)
1432
- if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1433
- input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1434
- # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1435
- # input_ids based on the past_length.
1436
- elif past_length < input_ids.shape[1]:
1437
- input_ids = input_ids[:, past_length:]
1438
- # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1439
-
1440
- # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1441
- if (
1442
- max_cache_length is not None
1443
- and attention_mask is not None
1444
- and cache_length + input_ids.shape[1] > max_cache_length
1445
- ):
1446
- attention_mask = attention_mask[:, -max_cache_length:]
1447
-
1448
- position_ids = kwargs.get("position_ids", None)
1449
- if attention_mask is not None and position_ids is None:
1450
- # create position_ids on the fly for batch generation
1451
- position_ids = attention_mask.long().cumsum(-1) - 1
1452
- position_ids.masked_fill_(attention_mask == 0, 1)
1453
- if past_key_values:
1454
- position_ids = position_ids[:, -input_ids.shape[1] :]
1455
-
1456
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1457
- if inputs_embeds is not None and past_key_values is None:
1458
- model_inputs = {"inputs_embeds": inputs_embeds}
1459
- else:
1460
- model_inputs = {"input_ids": input_ids}
1461
-
1462
- model_inputs.update(
1463
- {
1464
- "position_ids": position_ids,
1465
- "past_key_values": past_key_values,
1466
- "use_cache": kwargs.get("use_cache"),
1467
- "attention_mask": attention_mask,
1468
- }
1469
- )
1470
- return model_inputs
1471
-
1472
- @staticmethod
1473
- def _reorder_cache(past_key_values, beam_idx):
1474
- reordered_past = ()
1475
- for layer_past in past_key_values:
1476
- reordered_past += (
1477
- tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1478
- )
1479
- return reordered_past
1480
-
1481
-
1482
- @add_start_docstrings(
1483
- """
1484
- The Mixtral Model transformer with a sequence classification head on top (linear layer).
1485
-
1486
- [`MixtralForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1487
- (e.g. GPT-2) do.
1488
-
1489
- Since it does classification on the last token, it requires to know the position of the last token. If a
1490
- `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1491
- no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1492
- padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1493
- each row of the batch).
1494
- """,
1495
- MIXTRAL_START_DOCSTRING,
1496
- )
1497
- # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mixtral, LLAMA->MIXTRAL
1498
- class MixtralForSequenceClassification(MixtralPreTrainedModel):
1499
- def __init__(self, config):
1500
- super().__init__(config)
1501
- self.num_labels = config.num_labels
1502
- self.model = MixtralModel(config)
1503
- self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1504
-
1505
- # Initialize weights and apply final processing
1506
- self.post_init()
1507
-
1508
- def get_input_embeddings(self):
1509
- return self.model.embed_tokens
1510
-
1511
- def set_input_embeddings(self, value):
1512
- self.model.embed_tokens = value
1513
-
1514
- @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
1515
- def forward(
1516
- self,
1517
- input_ids: torch.LongTensor = None,
1518
- attention_mask: Optional[torch.Tensor] = None,
1519
- position_ids: Optional[torch.LongTensor] = None,
1520
- past_key_values: Optional[List[torch.FloatTensor]] = None,
1521
- inputs_embeds: Optional[torch.FloatTensor] = None,
1522
- labels: Optional[torch.LongTensor] = None,
1523
- use_cache: Optional[bool] = None,
1524
- output_attentions: Optional[bool] = None,
1525
- output_hidden_states: Optional[bool] = None,
1526
- return_dict: Optional[bool] = None,
1527
- ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1528
- r"""
1529
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1530
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1531
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1532
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1533
- """
1534
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1535
-
1536
- transformer_outputs = self.model(
1537
- input_ids,
1538
- attention_mask=attention_mask,
1539
- position_ids=position_ids,
1540
- past_key_values=past_key_values,
1541
- inputs_embeds=inputs_embeds,
1542
- use_cache=use_cache,
1543
- output_attentions=output_attentions,
1544
- output_hidden_states=output_hidden_states,
1545
- return_dict=return_dict,
1546
- )
1547
- hidden_states = transformer_outputs[0]
1548
- logits = self.score(hidden_states)
1549
-
1550
- if input_ids is not None:
1551
- batch_size = input_ids.shape[0]
1552
- else:
1553
- batch_size = inputs_embeds.shape[0]
1554
-
1555
- if self.config.pad_token_id is None and batch_size != 1:
1556
- raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1557
- if self.config.pad_token_id is None:
1558
- sequence_lengths = -1
1559
- else:
1560
- if input_ids is not None:
1561
- # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1562
- sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1563
- sequence_lengths = sequence_lengths % input_ids.shape[-1]
1564
- sequence_lengths = sequence_lengths.to(logits.device)
1565
- else:
1566
- sequence_lengths = -1
1567
-
1568
- pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1569
-
1570
- loss = None
1571
- if labels is not None:
1572
- labels = labels.to(logits.device)
1573
- if self.config.problem_type is None:
1574
- if self.num_labels == 1:
1575
- self.config.problem_type = "regression"
1576
- elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1577
- self.config.problem_type = "single_label_classification"
1578
- else:
1579
- self.config.problem_type = "multi_label_classification"
1580
-
1581
- if self.config.problem_type == "regression":
1582
- loss_fct = MSELoss()
1583
- if self.num_labels == 1:
1584
- loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1585
- else:
1586
- loss = loss_fct(pooled_logits, labels)
1587
- elif self.config.problem_type == "single_label_classification":
1588
- loss_fct = CrossEntropyLoss()
1589
- loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1590
- elif self.config.problem_type == "multi_label_classification":
1591
- loss_fct = BCEWithLogitsLoss()
1592
- loss = loss_fct(pooled_logits, labels)
1593
- if not return_dict:
1594
- output = (pooled_logits,) + transformer_outputs[1:]
1595
- return ((loss,) + output) if loss is not None else output
1596
-
1597
- return SequenceClassifierOutputWithPast(
1598
- loss=loss,
1599
- logits=pooled_logits,
1600
- past_key_values=transformer_outputs.past_key_values,
1601
- hidden_states=transformer_outputs.hidden_states,
1602
- attentions=transformer_outputs.attentions,
1603
- )