duzx16 commited on
Commit
a7eaddd
·
1 Parent(s): 835c717

Add support for flash attention 2

Browse files
Files changed (1) hide show
  1. modeling_chatglm.py +117 -4
modeling_chatglm.py CHANGED
@@ -21,12 +21,17 @@ from transformers.modeling_outputs import (
21
  SequenceClassifierOutputWithPast,
22
  )
23
  from transformers.modeling_utils import PreTrainedModel
24
- from transformers.utils import logging, is_torch_npu_available
 
25
  from transformers.generation.logits_process import LogitsProcessor
26
  from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
27
 
28
  from .configuration_chatglm import ChatGLMConfig
29
 
 
 
 
 
30
  # flags required to enable jit fusion kernels
31
 
32
  if sys.platform != 'darwin' and not is_torch_npu_available():
@@ -160,12 +165,13 @@ class RMSNorm(torch.nn.Module):
160
  class CoreAttention(torch.nn.Module):
161
  def __init__(self, config: ChatGLMConfig, layer_number):
162
  super(CoreAttention, self).__init__()
163
-
164
  self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
165
  self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
166
  if self.apply_query_key_layer_scaling:
167
  self.attention_softmax_in_fp32 = True
168
  self.layer_number = max(1, layer_number)
 
169
 
170
  projection_size = config.kv_channels * config.num_attention_heads
171
 
@@ -259,21 +265,122 @@ class SdpaAttention(CoreAttention):
259
  def forward(self, query_layer, key_layer, value_layer, attention_mask):
260
  if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
261
  context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
262
- is_causal=True)
 
263
  else:
264
  if attention_mask is not None:
265
  attention_mask = ~attention_mask
266
  context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
267
- attention_mask)
 
268
  context_layer = context_layer.transpose(1, 2).contiguous()
269
  new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
270
  context_layer = context_layer.reshape(*new_context_layer_shape)
271
  return context_layer
272
 
273
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
274
  CORE_ATTENTION_CLASSES = {
275
  "eager": CoreAttention,
276
  "sdpa": SdpaAttention,
 
277
  }
278
 
279
 
@@ -652,12 +759,18 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
652
  config_class = ChatGLMConfig
653
  base_model_prefix = "transformer"
654
  _no_split_modules = ["GLMBlock"]
 
 
655
 
656
  def _init_weights(self, module: nn.Module):
657
  """Initialize the weights."""
658
  return
659
 
660
  def get_masks(self, input_ids, past_key_values, padding_mask=None):
 
 
 
 
661
  batch_size, seq_length = input_ids.shape
662
  full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
663
  full_attention_mask.tril_()
 
21
  SequenceClassifierOutputWithPast,
22
  )
23
  from transformers.modeling_utils import PreTrainedModel
24
+ from transformers.utils import logging, is_torch_npu_available, is_flash_attn_greater_or_equal_2_10, \
25
+ is_flash_attn_2_available
26
  from transformers.generation.logits_process import LogitsProcessor
27
  from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
28
 
29
  from .configuration_chatglm import ChatGLMConfig
30
 
31
+ if is_flash_attn_2_available():
32
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
33
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
34
+
35
  # flags required to enable jit fusion kernels
36
 
37
  if sys.platform != 'darwin' and not is_torch_npu_available():
 
165
  class CoreAttention(torch.nn.Module):
166
  def __init__(self, config: ChatGLMConfig, layer_number):
167
  super(CoreAttention, self).__init__()
168
+ self.config = config
169
  self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
170
  self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
171
  if self.apply_query_key_layer_scaling:
172
  self.attention_softmax_in_fp32 = True
173
  self.layer_number = max(1, layer_number)
174
+ self.is_causal = True
175
 
176
  projection_size = config.kv_channels * config.num_attention_heads
177
 
 
265
  def forward(self, query_layer, key_layer, value_layer, attention_mask):
266
  if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
267
  context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
268
+ is_causal=True,
269
+ dropout_p=self.config.attention_dropout if self.training else 0.0)
270
  else:
271
  if attention_mask is not None:
272
  attention_mask = ~attention_mask
273
  context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
274
+ attention_mask,
275
+ dropout_p=self.config.attention_dropout if self.training else 0.0)
276
  context_layer = context_layer.transpose(1, 2).contiguous()
277
  new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
278
  context_layer = context_layer.reshape(*new_context_layer_shape)
279
  return context_layer
280
 
281
 
282
+ def _get_unpad_data(attention_mask):
283
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
284
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
285
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
286
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
287
+ return (
288
+ indices,
289
+ cu_seqlens,
290
+ max_seqlen_in_batch,
291
+ )
292
+
293
+
294
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2
295
+ class FlashAttention2(CoreAttention):
296
+ def __init__(self, *args, **kwargs):
297
+ super().__init__(*args, **kwargs)
298
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
299
+
300
+ def forward(self, query_states, key_states, value_states, attention_mask):
301
+ query_states = query_states.transpose(1, 2)
302
+ key_states = key_states.transpose(1, 2)
303
+ value_states = value_states.transpose(1, 2)
304
+ batch_size, query_length = query_states.shape[:2]
305
+ if not self._flash_attn_uses_top_left_mask:
306
+ causal = self.is_causal
307
+ else:
308
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
309
+ causal = self.is_causal and query_length != 1
310
+ dropout = self.config.attention_dropout if self.training else 0.0
311
+ # Contains at least one padding token in the sequence
312
+ if attention_mask is not None:
313
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
314
+ query_states, key_states, value_states, attention_mask, query_length
315
+ )
316
+
317
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
318
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
319
+
320
+ attn_output_unpad = flash_attn_varlen_func(
321
+ query_states,
322
+ key_states,
323
+ value_states,
324
+ cu_seqlens_q=cu_seqlens_q,
325
+ cu_seqlens_k=cu_seqlens_k,
326
+ max_seqlen_q=max_seqlen_in_batch_q,
327
+ max_seqlen_k=max_seqlen_in_batch_k,
328
+ dropout_p=dropout,
329
+ softmax_scale=None,
330
+ causal=causal,
331
+ )
332
+
333
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
334
+ else:
335
+ attn_output = flash_attn_func(
336
+ query_states, key_states, value_states, dropout, softmax_scale=None, causal=causal
337
+ )
338
+ attn_output = attn_output.reshape(batch_size, query_length, self.hidden_size_per_partition).contiguous()
339
+ return attn_output
340
+
341
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
342
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
343
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
344
+
345
+ key_layer = index_first_axis(
346
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
347
+ )
348
+ value_layer = index_first_axis(
349
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
350
+ )
351
+ if query_length == kv_seq_len:
352
+ query_layer = index_first_axis(
353
+ query_layer.reshape(batch_size * kv_seq_len, self.num_attention_heads_per_partition, head_dim), indices_k
354
+ )
355
+ cu_seqlens_q = cu_seqlens_k
356
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
357
+ indices_q = indices_k
358
+ elif query_length == 1:
359
+ max_seqlen_in_batch_q = 1
360
+ cu_seqlens_q = torch.arange(
361
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
362
+ ) # There is a memcpy here, that is very bad.
363
+ indices_q = cu_seqlens_q[:-1]
364
+ query_layer = query_layer.squeeze(1)
365
+ else:
366
+ # The -q_len: slice assumes left padding.
367
+ attention_mask = attention_mask[:, -query_length:]
368
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
369
+
370
+ return (
371
+ query_layer,
372
+ key_layer,
373
+ value_layer,
374
+ indices_q,
375
+ (cu_seqlens_q, cu_seqlens_k),
376
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
377
+ )
378
+
379
+
380
  CORE_ATTENTION_CLASSES = {
381
  "eager": CoreAttention,
382
  "sdpa": SdpaAttention,
383
+ "flash_attention_2": FlashAttention2
384
  }
385
 
386
 
 
759
  config_class = ChatGLMConfig
760
  base_model_prefix = "transformer"
761
  _no_split_modules = ["GLMBlock"]
762
+ _supports_flash_attn_2 = True
763
+ _supports_sdpa = True
764
 
765
  def _init_weights(self, module: nn.Module):
766
  """Initialize the weights."""
767
  return
768
 
769
  def get_masks(self, input_ids, past_key_values, padding_mask=None):
770
+ if self.config._attn_implementation == "flash_attention_2":
771
+ if padding_mask is not None and not padding_mask.all():
772
+ return padding_mask
773
+ return None
774
  batch_size, seq_length = input_ids.shape
775
  full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
776
  full_attention_mask.tril_()