Coobiw committed
Commit 58240a8 · verified · 1 Parent(s): 7e54849

add flash-attn support

Files changed (1):
  1. modeling_internlm2.py +151 -20
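
For context, a minimal sketch of how the flash-attn path added by this commit might be enabled when loading the checkpoint. The repo id is a placeholder and the AutoModelForCausalLM entry point is an assumption, not something this commit defines; what is grounded in the diff is that the patched modeling code picks its attention class from config.attn_implementation, so the sketch sets that attribute explicitly.

# Sketch only: assumes the checkpoint's config exposes `attn_implementation`,
# which the patched modeling_internlm2.py reads to pick the attention class.
import torch
from transformers import AutoConfig, AutoModelForCausalLM

model_id = "your-org/your-internlm2-checkpoint"  # hypothetical repo id
config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
config.attn_implementation = "flash_attention_2"  # or "eager" for the original path

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    config=config,
    torch_dtype=torch.bfloat16,  # flash-attn kernels expect fp16/bf16 inputs
    trust_remote_code=True,
)
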
modeling_internlm2.py CHANGED
@@ -25,6 +25,7 @@ import torch
25
  import torch.utils.checkpoint
26
  from einops import rearrange
27
  from torch import nn
 
28
  from transformers.activations import ACT2FN
29
  from transformers.modeling_outputs import BaseModelOutputWithPast
30
  from transformers.modeling_utils import PreTrainedModel
@@ -42,6 +43,30 @@ logger = logging.get_logger(__name__)
42
 
43
  _CONFIG_FOR_DOC = 'InternLM2Config'
44

45
 
46
  # Copied from transformers.models.bart.modeling_bart._make_causal_mask
47
  def _make_causal_mask(input_ids_shape: torch.Size,
@@ -264,21 +289,21 @@ class InternLM2MLP(nn.Module):
264
  bias=False,
265
  lora_r=256,
266
  lora_alpha=256,
267
- lora_len=576)
268
  self.w3 = PLoRA(
269
  self.hidden_size,
270
  self.intermediate_size,
271
  bias=False,
272
  lora_r=256,
273
  lora_alpha=256,
274
- lora_len=576)
275
  self.w2 = PLoRA(
276
  self.intermediate_size,
277
  self.hidden_size,
278
  bias=False,
279
  lora_r=256,
280
  lora_alpha=256,
281
- lora_len=576)
282
 
283
  self.act_fn = ACT2FN[config.hidden_act]
284
 
@@ -332,7 +357,7 @@ class InternLM2Attention(nn.Module):
332
  bias=config.bias,
333
  lora_r=256,
334
  lora_alpha=256,
335
- lora_len=576)
336
 
337
  self.wo = PLoRA(
338
  self.num_heads * self.head_dim,
@@ -340,7 +365,7 @@ class InternLM2Attention(nn.Module):
340
  bias=config.bias,
341
  lora_r=256,
342
  lora_alpha=256,
343
- lora_len=576)
344
  self._init_rope()
345
 
346
  def _init_rope(self):
@@ -498,7 +523,7 @@ class InternLM2FlashAttention2(InternLM2Attention):
498
  qkv_states = rearrange(
499
  qkv_states,
500
  'b q (h gs d) -> b q h gs d',
501
- gs=self.num_heads + 2 * self.num_key_value_heads,
502
  d=self.head_dim,
503
  q=q_len,
504
  )
@@ -507,6 +532,10 @@ class InternLM2FlashAttention2(InternLM2Attention):
507
  query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d')
508
  key_states = qkv_states[..., -2, :]
509
  value_states = qkv_states[..., -1, :]
510
 
511
  kv_seq_len = key_states.shape[-2]
512
  if past_key_value is not None:
@@ -523,12 +552,12 @@ class InternLM2FlashAttention2(InternLM2Attention):
523
  value_states = torch.cat([past_key_value[1], value_states], dim=2)
524
 
525
  past_key_value = (key_states, value_states) if use_cache else None
526
-
527
  query_states = query_states.transpose(1, 2)
528
  key_states = key_states.transpose(1, 2)
529
  value_states = value_states.transpose(1, 2)
530
 
531
- dropout_rate = 0.0 if not self.training else self.attention_dropout
532
 
533
  # In PEFT, usually we cast the layer norms in float32 for training stability reasons
534
  # therefore the input hidden states gets silently casted in float32. Hence, we need
@@ -569,17 +598,110 @@ class InternLM2FlashAttention2(InternLM2Attention):
569
  attn_weights = None
570
 
571
  return attn_output, attn_weights, past_key_value
572

573

574
  class InternLM2DecoderLayer(nn.Module):
575
 
576
  def __init__(self, config: InternLM2Config):
577
  super().__init__()
578
  self.hidden_size = config.hidden_size
579
- self.attention = (
580
- InternLM2Attention(config=config)
581
- if not getattr(config, '_flash_attn_2_enabled', False) else
582
- InternLM2FlashAttention2(config=config))
583
  self.feed_forward = InternLM2MLP(config)
584
  self.attention_norm = InternLM2RMSNorm(
585
  config.hidden_size, eps=config.rms_norm_eps)
@@ -773,6 +895,8 @@ class InternLM2Model(InternLM2PreTrainedModel):
773
 
774
  def __init__(self, config: InternLM2Config):
775
  super().__init__(config)
776
  self.padding_idx = config.pad_token_id
777
  self.vocab_size = config.vocab_size
778
 
@@ -843,6 +967,9 @@ class InternLM2Model(InternLM2PreTrainedModel):
843
  use_cache = use_cache if use_cache is not None else self.config.use_cache
844
 
845
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
846
 
847
  # retrieve input_ids and inputs_embeds
848
  if input_ids is not None and inputs_embeds is not None:
@@ -876,14 +1003,18 @@ class InternLM2Model(InternLM2PreTrainedModel):
876
  inputs_embeds = self.tok_embeddings(input_ids)
877
  im_mask = torch.zeros(inputs_embeds.shape[:2]).to(
878
  inputs_embeds.device).bool()
879
- # embed positions
880
- if attention_mask is None:
881
- attention_mask = torch.ones((batch_size, seq_length_with_past),
882
- dtype=torch.bool,
883
- device=inputs_embeds.device)
884
- attention_mask = self._prepare_decoder_attention_mask(
885
- attention_mask, (batch_size, seq_length), inputs_embeds,
886
- past_key_values_length)
887
 
888
  # embed positions
889
  hidden_states = inputs_embeds

modeling_internlm2.py (updated file, additions shown with +)
25
  import torch.utils.checkpoint
26
  from einops import rearrange
27
  from torch import nn
28
+ import torch.nn.functional as F
29
  from transformers.activations import ACT2FN
30
  from transformers.modeling_outputs import BaseModelOutputWithPast
31
  from transformers.modeling_utils import PreTrainedModel
 
43
 
44
  _CONFIG_FOR_DOC = 'InternLM2Config'
45
 
46
+ flash_attn_func, flash_attn_varlen_func = None, None
47
+ pad_input, index_first_axis, unpad_input = None, None, None
48
+ def _import_flash_attn():
49
+ global flash_attn_func, flash_attn_varlen_func
50
+ global pad_input, index_first_axis, unpad_input
51
+ try:
52
+ from flash_attn import flash_attn_func as _flash_attn_func, flash_attn_varlen_func as _flash_attn_varlen_func
53
+ from flash_attn.bert_padding import pad_input as _pad_input, index_first_axis as _index_first_axis, unpad_input as _unpad_input
54
+ flash_attn_func, flash_attn_varlen_func = _flash_attn_func, _flash_attn_varlen_func
55
+ pad_input, index_first_axis, unpad_input = _pad_input, _index_first_axis, _unpad_input
56
+ except ImportError:
57
+ raise ImportError("flash_attn is not installed.")
58
+
59
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
60
+ def _get_unpad_data(attention_mask):
61
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
62
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
63
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
64
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
65
+ return (
66
+ indices,
67
+ cu_seqlens,
68
+ max_seqlen_in_batch,
69
+ )
70
 
71
  # Copied from transformers.models.bart.modeling_bart._make_causal_mask
72
  def _make_causal_mask(input_ids_shape: torch.Size,
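
As a quick illustration of the _get_unpad_data helper added in the hunk above, here is a self-contained toy check; the helper body mirrors the hunk and the mask values are made up. It shows the flattened token indices and cumulative sequence lengths that flash_attn_varlen_func consumes.

# Toy check (not part of the commit): mirrors _get_unpad_data from the hunk above.
import torch
import torch.nn.functional as F

def _get_unpad_data(attention_mask):
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch

# Two left-padded sequences of length 4 with real lengths 2 and 3.
mask = torch.tensor([[0, 0, 1, 1],
                     [0, 1, 1, 1]])
indices, cu_seqlens, max_len = _get_unpad_data(mask)
print(indices)     # tensor([2, 3, 5, 6, 7]) -> flat positions of real tokens
print(cu_seqlens)  # tensor([0, 2, 5], dtype=torch.int32) -> cumulative lengths
print(max_len)     # 3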
 
289
  bias=False,
290
  lora_r=256,
291
  lora_alpha=256,
292
+ lora_len=1225)
293
  self.w3 = PLoRA(
294
  self.hidden_size,
295
  self.intermediate_size,
296
  bias=False,
297
  lora_r=256,
298
  lora_alpha=256,
299
+ lora_len=1225)
300
  self.w2 = PLoRA(
301
  self.intermediate_size,
302
  self.hidden_size,
303
  bias=False,
304
  lora_r=256,
305
  lora_alpha=256,
306
+ lora_len=1225)
307
 
308
  self.act_fn = ACT2FN[config.hidden_act]
309
 
 
357
  bias=config.bias,
358
  lora_r=256,
359
  lora_alpha=256,
360
+ lora_len=1225)
361
 
362
  self.wo = PLoRA(
363
  self.num_heads * self.head_dim,
 
365
  bias=config.bias,
366
  lora_r=256,
367
  lora_alpha=256,
368
+ lora_len=1225)
369
  self._init_rope()
370
 
371
  def _init_rope(self):
 
523
  qkv_states = rearrange(
524
  qkv_states,
525
  'b q (h gs d) -> b q h gs d',
526
+ gs=2 + self.num_key_value_groups,
527
  d=self.head_dim,
528
  q=q_len,
529
  )
 
532
  query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d')
533
  key_states = qkv_states[..., -2, :]
534
  value_states = qkv_states[..., -1, :]
535
+
536
+ query_states = query_states.transpose(1, 2)
537
+ key_states = key_states.transpose(1, 2)
538
+ value_states = value_states.transpose(1, 2)
539
 
540
  kv_seq_len = key_states.shape[-2]
541
  if past_key_value is not None:
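
The gs=2 + self.num_key_value_groups change above packs, per key/value head, num_key_value_groups query heads plus one key and one value head. A standalone shape check of that split follows; the head counts are illustrative rather than taken from a real config, and the query slice is inferred from the surrounding code.

# Shape check only (not part of the commit); mirrors the grouped-QKV split above.
import torch
from einops import rearrange

bsz, q_len = 2, 7
num_heads, num_key_value_heads, head_dim = 8, 2, 16
num_key_value_groups = num_heads // num_key_value_heads  # 4

qkv = torch.randn(bsz, q_len, num_key_value_heads * (num_key_value_groups + 2) * head_dim)
qkv = rearrange(qkv, 'b q (h gs d) -> b q h gs d',
                gs=2 + num_key_value_groups, d=head_dim)

query = rearrange(qkv[..., :num_key_value_groups, :], 'b q h gs d -> b q (h gs) d')
key = qkv[..., -2, :]
value = qkv[..., -1, :]
print(query.shape, key.shape, value.shape)
# torch.Size([2, 7, 8, 16]) torch.Size([2, 7, 2, 16]) torch.Size([2, 7, 2, 16])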
 
552
  value_states = torch.cat([past_key_value[1], value_states], dim=2)
553
 
554
  past_key_value = (key_states, value_states) if use_cache else None
555
+
556
  query_states = query_states.transpose(1, 2)
557
  key_states = key_states.transpose(1, 2)
558
  value_states = value_states.transpose(1, 2)
559
 
560
+ dropout_rate = 0.0 if not self.training else getattr(self, "dropout_rate", 0.0)
561
 
562
  # In PEFT, usually we cast the layer norms in float32 for training stability reasons
563
  # therefore the input hidden states gets silently casted in float32. Hence, we need
 
598
  attn_weights = None
599
 
600
  return attn_output, attn_weights, past_key_value
601
+
602
+ def _flash_attention_forward(
603
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
604
+ ):
605
+ """
606
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
607
+ first unpads the input, then computes the attention scores and pads the final attention scores.
608
+ Args:
609
+ query_states (`torch.Tensor`):
610
+ Input query states to be passed to Flash Attention API
611
+ key_states (`torch.Tensor`):
612
+ Input key states to be passed to Flash Attention API
613
+ value_states (`torch.Tensor`):
614
+ Input value states to be passed to Flash Attention API
615
+ attention_mask (`torch.Tensor`):
616
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
617
+ position of padding tokens and 1 for the position of non-padding tokens.
618
+ dropout (`float`, *optional*):
619
+ Attention dropout
620
+ softmax_scale (`float`, *optional*):
621
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
622
+ """
623
+ # Contains at least one padding token in the sequence
624
+ causal = self.is_causal and query_length != 1
625
+ if attention_mask is not None:
626
+ batch_size = query_states.shape[0]
627
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._unpad_input(
628
+ query_states, key_states, value_states, attention_mask, query_length
629
+ )
630
 
631
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
632
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
633
+
634
+ attn_output_unpad = flash_attn_varlen_func(
635
+ query_states,
636
+ key_states,
637
+ value_states,
638
+ cu_seqlens_q=cu_seqlens_q,
639
+ cu_seqlens_k=cu_seqlens_k,
640
+ max_seqlen_q=max_seqlen_in_batch_q,
641
+ max_seqlen_k=max_seqlen_in_batch_k,
642
+ dropout_p=dropout,
643
+ softmax_scale=softmax_scale,
644
+ causal=causal,
645
+ )
646
+
647
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
648
+ else:
649
+ attn_output = flash_attn_func(
650
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
651
+ )
652
 
653
+ return attn_output
654
+
655
+ def _unpad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
656
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
657
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
658
+
659
+ key_layer = index_first_axis(
660
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
661
+ )
662
+ value_layer = index_first_axis(
663
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
664
+ )
665
+
666
+ if query_length == kv_seq_len:
667
+ query_layer = index_first_axis(
668
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
669
+ )
670
+ cu_seqlens_q = cu_seqlens_k
671
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
672
+ indices_q = indices_k
673
+ elif query_length == 1:
674
+ max_seqlen_in_batch_q = 1
675
+ cu_seqlens_q = torch.arange(
676
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
677
+ ) # There is a memcpy here, that is very bad.
678
+ indices_q = cu_seqlens_q[:-1]
679
+ query_layer = query_layer.squeeze(1)
680
+ else:
681
+ # The -q_len: slice assumes left padding.
682
+ attention_mask = attention_mask[:, -query_length:]
683
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
684
+
685
+ return (
686
+ query_layer,
687
+ key_layer,
688
+ value_layer,
689
+ indices_q.to(torch.int64),
690
+ (cu_seqlens_q, cu_seqlens_k),
691
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
692
+ )
693
+
694
+ INTERNLM2_ATTENTION_CLASSES = {
695
+ "eager": InternLM2Attention,
696
+ "flash_attention_2": InternLM2FlashAttention2,
697
+ }
698
  class InternLM2DecoderLayer(nn.Module):
699
 
700
  def __init__(self, config: InternLM2Config):
701
  super().__init__()
702
  self.hidden_size = config.hidden_size
703
+ self.attention = INTERNLM2_ATTENTION_CLASSES[config.attn_implementation](config=config)
704
+
705
  self.feed_forward = InternLM2MLP(config)
706
  self.attention_norm = InternLM2RMSNorm(
707
  config.hidden_size, eps=config.rms_norm_eps)
 
895
 
896
  def __init__(self, config: InternLM2Config):
897
  super().__init__(config)
898
+ print(f"Attention Implementation: {self.config.attn_implementation}")
899
+
900
  self.padding_idx = config.pad_token_id
901
  self.vocab_size = config.vocab_size
902
 
 
967
  use_cache = use_cache if use_cache is not None else self.config.use_cache
968
 
969
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
970
+
971
+ if self.config.attn_implementation == "flash_attention_2":
972
+ _import_flash_attn()
973
 
974
  # retrieve input_ids and inputs_embeds
975
  if input_ids is not None and inputs_embeds is not None:
 
1003
  inputs_embeds = self.tok_embeddings(input_ids)
1004
  im_mask = torch.zeros(inputs_embeds.shape[:2]).to(
1005
  inputs_embeds.device).bool()
1006
+
1007
+ if self.config.attn_implementation == "flash_attention_2":
1008
+ # 2d mask is passed through the layers
1009
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
1010
+ else:
1011
+ if attention_mask is None:
1012
+ attention_mask = torch.ones((batch_size, seq_length_with_past),
1013
+ dtype=torch.bool,
1014
+ device=inputs_embeds.device)
1015
+ attention_mask = self._prepare_decoder_attention_mask(
1016
+ attention_mask, (batch_size, seq_length), inputs_embeds,
1017
+ past_key_values_length)
1018
 
1019
  # embed positions
1020
  hidden_states = inputs_embeds
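
Finally, the mask handling added around new line 1009 only forwards the 2d attention mask when it actually contains padding; otherwise the mask is dropped so the padding-free flash_attn_func branch runs. A minimal check of that short-circuit (the mask values are made up):

# Not part of the commit: reproduces the mask short-circuit used in the flash path.
import torch

def reduce_mask(attention_mask):
    return attention_mask if (attention_mask is not None and 0 in attention_mask) else None

full = torch.ones(2, 5, dtype=torch.long)      # no padding -> mask dropped
padded = torch.tensor([[0, 0, 1, 1, 1],
                       [1, 1, 1, 1, 1]])       # contains padding -> mask kept
print(reduce_mask(full))    # None           -> flash_attn_func branch
print(reduce_mask(padded))  # 2d mask tensor -> flash_attn_varlen_func branch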