import torch
import torch.nn as nn
from einops import rearrange

try:
    from .triton_flash_atn import _attention
    from .triton_bert_pading import pad_input, unpad_input
except ImportError:
    print("FlashAttention is not installed.")


class FlashAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.

    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                       (default: 1/sqrt(d_keys) where d_keys is computed at runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(
        self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None
    ):
        super().__init__()
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout

    def forward(
        self,
        qkv,
        key_padding_mask=None,
        causal=False,
        cu_seqlens=None,
        max_s=None,
        need_weights=False,
    ):
        """Implements the multihead softmax attention.

        Arguments
        ---------
            qkv: The tensor containing the query, key, and value.
                 (B, S, 3, H, D) if key_padding_mask is None
                 if unpadded: (nnz, 3, h, d)
            key_padding_mask: a bool tensor of shape (B, S)
        """
        assert not need_weights
        assert qkv.dtype in [torch.float16, torch.bfloat16]
        assert qkv.is_cuda

        if cu_seqlens is None:
            batch_size = qkv.shape[0]
            seqlen = qkv.shape[1]
            if key_padding_mask is None:
                # No padding: flatten batch and sequence dims and build the
                # cumulative sequence lengths expected by the varlen kernel.
                qkv = rearrange(qkv, "b s ... -> (b s) ...")
                max_s = seqlen
                cu_seqlens = torch.arange(
                    0,
                    (batch_size + 1) * seqlen,
                    step=seqlen,
                    dtype=torch.int32,
                    device=qkv.device,
                )
                output = _attention.apply(
                    qkv,
                    cu_seqlens,
                    max_s,
                    self.dropout_p if self.training else 0.0,
                    self.softmax_scale,
                    causal,
                )
                output = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
            else:
                # Padded input: drop the padded tokens before the kernel call,
                # then scatter the outputs back to the padded layout.
                nheads = qkv.shape[-2]
                x = rearrange(qkv, "b s three h d -> b s (three h d)")
                x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
                x_unpad = rearrange(
                    x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
                )
                output_unpad = _attention.apply(
                    x_unpad,
                    cu_seqlens,
                    max_s,
                    self.dropout_p if self.training else 0.0,
                    self.softmax_scale,
                    causal,
                )
                output = rearrange(
                    pad_input(
                        rearrange(output_unpad, "nnz h d -> nnz (h d)"),
                        indices,
                        batch_size,
                        seqlen,
                    ),
                    "b s (h d) -> b s h d",
                    h=nheads,
                )
        else:
            # Inputs are already unpadded as (nnz, 3, h, d) with precomputed
            # cu_seqlens; max_s must be supplied by the caller.
            assert max_s is not None
            output = _attention.apply(
                qkv,
                cu_seqlens,
                max_s,
                self.dropout_p if self.training else 0.0,
                self.softmax_scale,
                causal,
            )

        return output, None
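

# Minimal usage sketch (illustrative only, not part of the module): it assumes
# a CUDA device is available and that the Triton kernels imported above are
# installed. Shapes follow the forward() docstring: qkv is (B, S, 3, H, D) in
# fp16/bf16, and key_padding_mask is an optional (B, S) bool tensor where True
# marks valid tokens. The tensor sizes below are arbitrary example values.
if __name__ == "__main__":
    batch, seqlen, nheads, headdim = 2, 128, 8, 64
    attn = FlashAttention(attention_dropout=0.1)
    qkv = torch.randn(
        batch, seqlen, 3, nheads, headdim, device="cuda", dtype=torch.float16
    )
    # Treat the last 16 positions of every sequence as padding.
    key_padding_mask = torch.ones(batch, seqlen, dtype=torch.bool, device="cuda")
    key_padding_mask[:, -16:] = False
    out, _ = attn(qkv, key_padding_mask=key_padding_mask, causal=True)
    print(out.shape)  # expected: (batch, seqlen, nheads, headdim)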