Shaltiel commited on
Commit
f5c007b
·
1 Parent(s): cbc604b

Added flash attention

Browse files
configuration_megatron_gpt.py CHANGED
@@ -81,6 +81,8 @@ class MegatronGPTConfig(PretrainedConfig):
81
  Whether to calculate and apply the relative position bias within the attention function.
82
  If this is False, then model.generate will require you to calculate the triangular attention
83
  mask and pass it through in the attention mask.
 
 
84
  rope_scaling (`Dict`, *optional*):
85
  Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports three scaling
86
  strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format
@@ -118,6 +120,7 @@ class MegatronGPTConfig(PretrainedConfig):
118
  eos_token_id=2,
119
  tie_word_embeddings=False,
120
  rope_scaling=None,
 
121
  **kwargs,
122
  ):
123
  super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
@@ -141,6 +144,7 @@ class MegatronGPTConfig(PretrainedConfig):
141
  self.use_cache = use_cache
142
  self.self_attention_relative_position_bias = self_attention_relative_position_bias
143
  self.tie_word_embeddings = tie_word_embeddings
 
144
  self.rope_scaling = rope_scaling
145
  self._rope_scaling_validation()
146
 
 
81
  Whether to calculate and apply the relative position bias within the attention function.
82
  If this is False, then model.generate will require you to calculate the triangular attention
83
  mask and pass it through in the attention mask.
84
+ use_flash_attention (`bool`, *optional*, defaults to `False`):
85
+ When calculating attention, whether to attempt to use flash attention if it's installed, or to always skip and use the regular method.
86
  rope_scaling (`Dict`, *optional*):
87
  Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports three scaling
88
  strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format
 
120
  eos_token_id=2,
121
  tie_word_embeddings=False,
122
  rope_scaling=None,
123
+ use_flash_attention=False,
124
  **kwargs,
125
  ):
126
  super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
 
144
  self.use_cache = use_cache
145
  self.self_attention_relative_position_bias = self_attention_relative_position_bias
146
  self.tie_word_embeddings = tie_word_embeddings
147
+ self.use_flash_attention = use_flash_attention
148
  self.rope_scaling = rope_scaling
149
  self._rope_scaling_validation()
150
 
modeling_megatron_gpt.py CHANGED
@@ -21,6 +21,7 @@
21
  """ PyTorch MegatronGPT model."""
22
 
23
  from dataclasses import dataclass
 
24
  from typing import Optional, Tuple, Union
25
 
26
  import torch
@@ -43,8 +44,21 @@ from transformers.modeling_outputs import (
43
  )
44
  from transformers.modeling_utils import PreTrainedModel
45
  from transformers.utils import logging
 
46
  from .configuration_megatron_gpt import MegatronGPTConfig
47
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  def get_activation(act):
49
  if act in ["gelu", "geglu", "fast-geglu"]:
50
  act = 'gelu'
@@ -111,9 +125,10 @@ class MegatronGPTAttention(nn.Module):
111
  self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False)
112
  self._init_rope()
113
 
 
114
  self.register_buffer(
115
  "norm_factor",
116
- torch.sqrt(torch.tensor(self.head_size if config.normalize_attention_scores else 1.0, dtype=torch.float32)).to(torch.get_default_dtype()),
117
  persistent=False,
118
  )
119
  self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=config.bias)
@@ -207,7 +222,10 @@ class MegatronGPTAttention(nn.Module):
207
  present = (key, value) if use_cache else None
208
 
209
  # Compute attention
210
- attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
 
 
 
211
 
212
  # Reshape outputs
213
  attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size)
@@ -244,6 +262,34 @@ class MegatronGPTAttention(nn.Module):
244
  # -> [bs, seq_len, hidden_size]
245
  return tensor
246
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
  def _attn(self, query, key, value, attention_mask=None, head_mask=None):
248
  # q, k, v: [bs, num_attention_heads, seq_len, attn_head_size]
249
  # compute causal mask from causal mask buffer
 
21
  """ PyTorch MegatronGPT model."""
22
 
23
  from dataclasses import dataclass
24
+ import math
25
  from typing import Optional, Tuple, Union
26
 
27
  import torch
 
44
  )
45
  from transformers.modeling_utils import PreTrainedModel
46
  from transformers.utils import logging
47
+ # try to load using a relative path, but if it fails try loading it directly
48
  from .configuration_megatron_gpt import MegatronGPTConfig
49
 
50
+ try:
51
+ from flash_attn.bert_padding import unpad_input, pad_input
52
+ from flash_attn import flash_attn_varlen_func as flash_attn_func
53
+ HAS_FLASH = True
54
+ except:
55
+ try:
56
+ from flash_attn.flash_attn_interface import flash_attn_unpadded_func as flash_attn_func
57
+ HAS_FLASH = True
58
+ except:
59
+ HAS_FLASH = False
60
+
61
+
62
  def get_activation(act):
63
  if act in ["gelu", "geglu", "fast-geglu"]:
64
  act = 'gelu'
 
125
  self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False)
126
  self._init_rope()
127
 
128
+ self.norm_factor_float = math.sqrt(self.head_size if config.normalize_attention_scores else 1.0)
129
  self.register_buffer(
130
  "norm_factor",
131
+ torch.tensor(self.norm_factor_float, dtype=torch.float32).to(torch.get_default_dtype()),
132
  persistent=False,
133
  )
134
  self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=config.bias)
 
222
  present = (key, value) if use_cache else None
223
 
224
  # Compute attention
225
+ if not HAS_FLASH or output_attentions or head_mask is not None or not self.config.use_flash_attention:
226
+ attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
227
+ else:
228
+ attn_output = self._flash_attn(query, key, value, attention_mask)
229
 
230
  # Reshape outputs
231
  attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size)
 
262
  # -> [bs, seq_len, hidden_size]
263
  return tensor
264
 
265
+ def _flash_attn(self, query, key, value, attention_mask=None):
266
+ # q, k, v: [bs, num_attention_heads, seq_len, attn_head_size]
267
+ # compute causal mask from causal mask buffer
268
+ batch_size, num_attention_heads, query_seq_length, attn_head_size = query.size()
269
+
270
+ # transpose_for_scores_flash returns b s h d
271
+ query_layer = query.transpose(1, 2).half()
272
+ key_layer = key.transpose(1, 2).half()
273
+ value_layer = value.transpose(1, 2).half()
274
+
275
+ # fix the mask
276
+ attention_mask = (attention_mask == 0).int().squeeze(1).squeeze(1)
277
+ query_layer, query_indicies, cu_seqlens_q, max_seqlen_q = unpad_input(query_layer, attention_mask[:, -query_seq_length:])
278
+ key_layer, _, cu_seqlens_k, max_seqlen_k = unpad_input(key_layer, attention_mask)
279
+ value_layer, _, _, _ = unpad_input(value_layer, attention_mask)
280
+
281
+ # returns [batch * seq, nheads, headdim]
282
+ context_layer = flash_attn_func(query_layer, key_layer, value_layer,
283
+ cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
284
+ dropout_p=self.config.attention_dropout, softmax_scale=1 / self.norm_factor_float, causal=self.self_attention_relative_position_bias if max_seqlen_q > 1 else False)
285
+
286
+ # fix the shape to be [bs, num_attention_heads, seq_len, attn_head_size]
287
+ context_layer = pad_input(context_layer, query_indicies, batch_size, query_seq_length)
288
+ context_layer = context_layer.view(batch_size, query_seq_length, num_attention_heads, attn_head_size) \
289
+ .transpose(1, 2)
290
+
291
+ return context_layer.to(value.dtype)
292
+
293
  def _attn(self, query, key, value, attention_mask=None, head_mask=None):
294
  # q, k, v: [bs, num_attention_heads, seq_len, attn_head_size]
295
  # compute causal mask from causal mask buffer