import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

from .rotary_emb import apply_rotary_emb
from .utils import nearest_power_of_two

try:
    from flash_attn import flash_attn_func as fa2
except ImportError as e:
    print(
        f"Unable to import Triton-based flash attention: {e}. No alternative currently available."
    )

# TODO: Add FlexAttention + local attention mask when it's in stable release


class Attention(nn.Module):
    def __init__(self, config):
        super(Attention, self).__init__()
        if isinstance(config.torch_dtype, str):
            torch_dtype = getattr(torch, config.torch_dtype)
        else:
            torch_dtype = config.torch_dtype
        assert torch.cuda.is_available(), "CUDA is required."
        assert config.n_embd % config.n_heads == 0

        self.n_heads = config.n_heads
        self.device = torch.device("cuda")
        self.bsz = config.bsz

        self.attn = nn.Linear(
            config.n_embd, 3 * config.n_embd, bias=config.bias, dtype=torch_dtype
        )
        self.o_proj = nn.Linear(
            config.n_embd, config.n_embd, bias=config.bias, dtype=torch_dtype
        )
        self.o_proj.SCALE_INIT = 1

        self.dropout = config.dropout
        self.resid_dropout = nn.Dropout(self.dropout)

        self.alibi_slopes = self._get_alibi_slopes(self.n_heads)
        self.window_size = config.window_size
        self.softcap = config.softcap

    def _generate_slopes(self, n: int):
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
        return [start * (start**i) for i in range(n)]

    def _get_alibi_slopes(self, n_heads: int, interpolation_factor: float = 0.25):
        # If n_heads is a power of 2, generate slopes directly
        if math.log2(n_heads).is_integer():
            slopes = self._generate_slopes(n_heads)
        else:
            # Get slopes for the nearest (smaller) power of two
            n = nearest_power_of_two(n_heads, round_up=False)
            slopes_power_of_two = self._generate_slopes(n)

            # Generate extra slopes and interleave to cover the remaining heads
            extra_slopes = self._generate_slopes(2 * n)
            extra_slopes_trunc = extra_slopes[0::2][: n_heads - n]
            slopes = slopes_power_of_two + extra_slopes_trunc
        slopes = torch.tensor(slopes, device=self.device)
        slopes = slopes * interpolation_factor  # https://arxiv.org/pdf/2310.13017
        return slopes.to(torch.float32)  # Ensure slopes are in float32

    def forward(self, x):
        bsz, seq_len, d_in = x.size()

        qkv = self.attn(x)
        q, k, v = torch.chunk(qkv, 3, dim=2)

        # FlashAttention-2 expects (bsz, seq_len, n_heads, head_dim)
        q = q.view(bsz, seq_len, self.n_heads, d_in // self.n_heads)
        k = k.view(bsz, seq_len, self.n_heads, d_in // self.n_heads)
        v = v.view(bsz, seq_len, self.n_heads, d_in // self.n_heads)

        y = fa2(  # https://arxiv.org/pdf/2307.08691
            q,
            k,
            v,
            dropout_p=self.dropout if self.training else 0.0,
            causal=True,
            window_size=(self.window_size, 0),
            alibi_slopes=self.alibi_slopes,  # https://arxiv.org/pdf/2108.12409
            softcap=self.softcap,  # https://arxiv.org/pdf/2408.00118
        )

        y = y.contiguous().view(bsz, seq_len, d_in)
        y = self.resid_dropout(self.o_proj(y))
        return y
class AttentionSDPA(nn.Module):
    def __init__(self, config):
        super(AttentionSDPA, self).__init__()
        if isinstance(config.torch_dtype, str):
            torch_dtype = getattr(torch, config.torch_dtype)
        else:
            torch_dtype = config.torch_dtype
        assert torch.cuda.is_available(), "CUDA is required."
        assert config.n_embd % config.n_heads == 0

        self.n_heads = config.n_heads
        self.device = torch.device("cuda")  # Technically don't need CUDA for SDPA
        self.bsz = config.bsz

        self.attn = nn.Linear(
            config.n_embd, 3 * config.n_embd, bias=config.bias, dtype=torch_dtype
        )
        self.o_proj = nn.Linear(
            config.n_embd, config.n_embd, bias=config.bias, dtype=torch_dtype
        )

        self.dropout = config.dropout
        self.resid_dropout = nn.Dropout(self.dropout)

    def forward(self, x):
        bsz, seq_len, d_in = x.size()

        qkv = self.attn(x)
        q, k, v = torch.chunk(qkv, 3, dim=2)

        # SDPA expects (bsz, n_heads, seq_len, head_dim)
        q = q.view(bsz, seq_len, self.n_heads, d_in // self.n_heads).transpose(1, 2)
        k = k.view(bsz, seq_len, self.n_heads, d_in // self.n_heads).transpose(1, 2)
        v = v.view(bsz, seq_len, self.n_heads, d_in // self.n_heads).transpose(1, 2)

        y = F.scaled_dot_product_attention(
            q, k, v, is_causal=True, dropout_p=self.dropout if self.training else 0.0
        )

        y = y.transpose(1, 2).contiguous().view(bsz, seq_len, d_in)
        y = self.resid_dropout(self.o_proj(y))
        return y


class FlexAttention(nn.Module):
    """
    Generalized multihead attention that supports arbitrary attention masks
    via FlexAttention, as well as rotary positional embeddings.
    """

    def __init__(self, config, mask_mod, score_mod=None):
        """
        Initializes the FlexAttention module.

        Args:
            config: Model config providing dim, num_heads, seq_len, and device.
            mask_mod (Callable): Mask to modify attention scores, e.g. causal.
            score_mod (Callable, optional): Score modification, e.g. a positional bias.
        """
        super().__init__()
        self.dim, self.num_heads = config.dim, config.num_heads
        assert config.dim % config.num_heads == 0, (
            f"dim ({self.dim}) must be divisible by num_heads ({self.num_heads})"
        )
        self.head_dim = config.dim // config.num_heads

        self.wq = nn.Linear(config.dim, config.dim)
        self.wk = nn.Linear(config.dim, config.dim)
        self.wv = nn.Linear(config.dim, config.dim)

        self.mask_mod = mask_mod
        self.score_mod = score_mod
        self.block_mask = create_block_mask(
            mask_mod=self.mask_mod,
            B=None,  # Broadcast over batch
            H=None,  # Broadcast over heads
            Q_LEN=config.seq_len,
            KV_LEN=config.seq_len,
            device=config.device,
        )

        self.o_proj = nn.Linear(config.dim, config.dim)
        self.o_proj.SCALE_INIT = 1

    def forward(
        self,
        x: torch.Tensor = None,
        q: torch.Tensor = None,
        k: torch.Tensor = None,
        v: torch.Tensor = None,
        freqs_cis: torch.Tensor = None,
    ) -> torch.Tensor:
        if x is not None:
            q = k = v = x
        if any(t is None for t in [q, k, v]):
            raise ValueError(
                "Must provide either x for self-attention or q/k/v for cross-attention."
            )

        bsz, q_len, _ = q.shape
        _, k_len, _ = k.shape
        _, v_len, _ = v.shape

        # Split heads, then move them ahead of the sequence dimension:
        # (bsz, seq_len, dim) -> (bsz, num_heads, seq_len, head_dim).
        # Reshaping directly to the head-first layout would scramble positions
        # across heads, so view first and transpose afterwards.
        Q = self.wq(q).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.wk(k).view(bsz, k_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.wv(v).view(bsz, v_len, self.num_heads, self.head_dim).transpose(1, 2)

        Q, K = apply_rotary_emb(Q, K, freqs_cis=freqs_cis)

        output = flex_attention(Q, K, V, block_mask=self.block_mask, score_mod=self.score_mod)

        # Merge heads back: (bsz, num_heads, q_len, head_dim) -> (bsz, q_len, dim)
        output = output.transpose(1, 2).contiguous().view(bsz, q_len, self.dim)
        output = self.o_proj(output)
        return output
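

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the library above): shows how a causal,
# sliding-window mask_mod can be built and exercised with flex_attention
# directly, mirroring what the FlexAttention class does internally. The
# window size, tensor shapes, and the standalone invocation below are
# illustrative assumptions, not the project's actual configuration; a
# mask_mod like this could also be passed as FlexAttention(config, mask_mod=...)
# once a config and rotary frequencies are available. Requires PyTorch >= 2.5.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    SEQ_LEN, WINDOW = 128, 32  # illustrative values

    def sliding_window_causal(b, h, q_idx, kv_idx):
        # Attend only to past positions within a fixed lookback window.
        causal = q_idx >= kv_idx
        in_window = q_idx - kv_idx <= WINDOW
        return causal & in_window

    device = "cuda" if torch.cuda.is_available() else "cpu"
    block_mask = create_block_mask(
        sliding_window_causal,
        B=None,
        H=None,
        Q_LEN=SEQ_LEN,
        KV_LEN=SEQ_LEN,
        device=device,
    )

    # Random (bsz, num_heads, seq_len, head_dim) tensors, the layout
    # flex_attention expects and that FlexAttention.forward builds internally.
    q = torch.randn(2, 4, SEQ_LEN, 16, device=device)
    k = torch.randn(2, 4, SEQ_LEN, 16, device=device)
    v = torch.randn(2, 4, SEQ_LEN, 16, device=device)

    out = flex_attention(q, k, v, block_mask=block_mask)
    print(out.shape)  # torch.Size([2, 4, 128, 16])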