radna committed on
Commit
45d84e9
1 Parent(s): 68c0edf

Update flash_attention.py

Files changed (1)
  1. flash_attention.py +72 -75
flash_attention.py CHANGED
@@ -1,75 +1,72 @@
- import torch
- import torch.nn as nn
- from einops import rearrange
-
- try:  # v1
-     from flash_attn.flash_attn_interface import \
-         flash_attn_unpadded_qkvpacked_func
- except:  # v2
-     from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
-
- from flash_attn.bert_padding import pad_input, unpad_input
-
-
- class FlashAttention(nn.Module):
-     """Implement the scaled dot product attention with softmax.
-     Arguments
-     ---------
-         softmax_scale: The temperature to use for the softmax attention.
-                        (default: 1/sqrt(d_keys) where d_keys is computed at
-                        runtime)
-         attention_dropout: The dropout rate to apply to the attention
-                            (default: 0.0)
-     """
-
-     def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
-         super().__init__()
-         self.softmax_scale = softmax_scale
-         self.dropout_p = attention_dropout
-
-     def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,
-                 max_s=None, need_weights=False):
-         """Implements the multihead softmax attention.
-         Arguments
-         ---------
-             qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
-                 if unpadded: (nnz, 3, h, d)
-             key_padding_mask: a bool tensor of shape (B, S)
-         """
-         assert not need_weights
-         assert qkv.dtype in [torch.float16, torch.bfloat16]
-         assert qkv.is_cuda
-
-         if cu_seqlens is None:
-             batch_size = qkv.shape[0]
-             seqlen = qkv.shape[1]
-             if key_padding_mask is None:
-                 qkv = rearrange(qkv, 'b s ... -> (b s) ...')
-                 max_s = seqlen
-                 cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
-                                           device=qkv.device)
-                 output = flash_attn_unpadded_qkvpacked_func(
-                     qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
-                     softmax_scale=self.softmax_scale, causal=causal
-                 )
-                 output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
-             else:
-                 nheads = qkv.shape[-2]
-                 x = rearrange(qkv, 'b s three h d -> b s (three h d)')
-                 x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
-                 x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
-                 output_unpad = flash_attn_unpadded_qkvpacked_func(
-                     x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
-                     softmax_scale=self.softmax_scale, causal=causal
-                 )
-                 output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
-                                              indices, batch_size, seqlen),
-                                    'b s (h d) -> b s h d', h=nheads)
-         else:
-             assert max_s is not None
-             output = flash_attn_unpadded_qkvpacked_func(
-                 qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
-                 softmax_scale=self.softmax_scale, causal=causal
-             )
-
-         return output, None
 
+ import torch
+ import torch.nn as nn
+ from einops import rearrange
+
+ from triton_flash_atn import _attention
+
+ from triton_bert_pading import pad_input, unpad_input
+
+
+
+ class FlashAttention(nn.Module):
+     """Implement the scaled dot product attention with softmax.
+     Arguments
+     ---------
+         softmax_scale: The temperature to use for the softmax attention.
+                        (default: 1/sqrt(d_keys) where d_keys is computed at
+                        runtime)
+         attention_dropout: The dropout rate to apply to the attention
+                            (default: 0.0)
+     """
+
+     def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
+         super().__init__()
+         self.softmax_scale = softmax_scale
+         self.dropout_p = attention_dropout
+
+     def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,
+                 max_s=None, need_weights=False):
+         """Implements the multihead softmax attention.
+         Arguments
+         ---------
+             qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
+                 if unpadded: (nnz, 3, h, d)
+             key_padding_mask: a bool tensor of shape (B, S)
+         """
+         assert not need_weights
+         assert qkv.dtype in [torch.float16, torch.bfloat16]
+         assert qkv.is_cuda
+
+         if cu_seqlens is None:
+             batch_size = qkv.shape[0]
+             seqlen = qkv.shape[1]
+             if key_padding_mask is None:
+                 qkv = rearrange(qkv, 'b s ... -> (b s) ...')
+                 max_s = seqlen
+                 cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
+                                           device=qkv.device)
+                 output = _attention.apply(
+                     qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
+                     softmax_scale=self.softmax_scale, causal=causal
+                 )
+                 output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
+             else:
+                 nheads = qkv.shape[-2]
+                 x = rearrange(qkv, 'b s three h d -> b s (three h d)')
+                 x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
+                 x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
+                 output_unpad = _attention.apply(
+                     x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
+                     softmax_scale=self.softmax_scale, causal=causal
+                 )
+                 output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
+                                              indices, batch_size, seqlen),
+                                    'b s (h d) -> b s h d', h=nheads)
+         else:
+             assert max_s is not None
+             output = _attention.apply(
+                 qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
+                 softmax_scale=self.softmax_scale, causal=causal
+             )
+
+         return output, None
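
For reference, the commit swaps the flash_attn CUDA kernels (flash_attn_unpadded_qkvpacked_func and flash_attn.bert_padding) for Triton-based replacements (_attention.apply from triton_flash_atn, pad_input/unpad_input from triton_bert_pading) while leaving the module's interface unchanged. Below is a minimal usage sketch of the updated module; it assumes the file is importable as flash_attention and that the Triton helper modules are on the path, and the tensor sizes and mask are illustrative rather than taken from the commit:

import torch
from flash_attention import FlashAttention  # assumed module path for this file

# Illustrative sizes: batch, sequence length, heads, head dimension.
B, S, H, D = 2, 128, 8, 64

attn = FlashAttention(attention_dropout=0.0).cuda().eval()

# Packed query/key/value tensor of shape (B, S, 3, H, D); fp16 on GPU,
# matching the dtype/device asserts in forward().
qkv = torch.randn(B, S, 3, H, D, device='cuda', dtype=torch.float16)

# Optional bool key padding mask of shape (B, S): True marks real tokens.
key_padding_mask = torch.ones(B, S, dtype=torch.bool, device='cuda')

out, _ = attn(qkv, key_padding_mask=key_padding_mask, causal=True)
print(out.shape)  # (B, S, H, D) -> torch.Size([2, 128, 8, 64])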