# Transformer_500M/attn.py
import math
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
import torch.nn as nn
import torch.nn.functional as F
from .rotary_emb import apply_rotary_emb
from .utils import nearest_power_of_two
try:
    from flash_attn import flash_attn_func as fa2
except ImportError as e:
    fa2 = None
    print(
        f"Unable to import flash_attn's flash_attn_func: {e}. No alternative currently available."
    )
# TODO: Add FlexAttention + local attention mask when it's in stable release
class Attention(nn.Module):
def __init__(self, config):
super(Attention, self).__init__()
if isinstance(config.torch_dtype, str):
torch_dtype = getattr(torch, config.torch_dtype)
else:
torch_dtype = config.torch_dtype
assert torch.cuda.is_available(), "CUDA is required."
assert config.n_embd % config.n_heads == 0
self.n_heads = config.n_heads
self.device = torch.device("cuda")
self.bsz = config.bsz
self.attn = nn.Linear(
config.n_embd, 3 * config.n_embd, bias=config.bias, dtype=torch_dtype
)
self.o_proj = nn.Linear(
config.n_embd, config.n_embd, bias=config.bias, dtype=torch_dtype
)
self.o_proj.SCALE_INIT = 1
self.dropout = config.dropout
self.resid_dropout = nn.Dropout(self.dropout)
self.alibi_slopes = self._get_alibi_slopes(self.n_heads)
self.window_size = config.window_size
self.softcap = config.softcap
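    # ALiBi slope helpers (https://arxiv.org/pdf/2108.12409): for a power-of-two
    # head count n, _generate_slopes returns the geometric sequence
    # 2**(-8/n), 2**(-16/n), ..., 2**(-8); e.g. n = 8 gives
    # [0.5, 0.25, 0.125, ..., 0.00390625] before the interpolation factor
    # in _get_alibi_slopes is applied.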
def _generate_slopes(self, n: int):
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
return [start * (start**i) for i in range(n)]
def _get_alibi_slopes(self, n_heads: int, interpolation_factor: float = 0.25):
# If n_heads is a power of 2, generate slopes directly
if math.log2(n_heads).is_integer():
slopes = self._generate_slopes(n_heads)
else:
# Get slopes for the nearest power of two
n = nearest_power_of_two(n_heads, round_up=False)
slopes_power_of_two = self._generate_slopes(n)
# Generate extra slopes
extra_slopes = self._generate_slopes(2 * n)
extra_slopes_trunc = extra_slopes[0::2][: n_heads - n]
slopes = slopes_power_of_two + extra_slopes_trunc
slopes = torch.tensor(slopes, device=self.device)
slopes = slopes * interpolation_factor # https://arxiv.org/pdf/2310.13017
return slopes.to(torch.float32) # Ensure slopes are in float32
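    # Forward pass: fused QKV projection, per-head split, then FlashAttention-2
    # with a causal sliding window, ALiBi biases, and logit softcapping,
    # followed by the output projection and residual dropout.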
def forward(self, x):
bsz, seq_len, d_in = x.size()
qkv = self.attn(x)
q, k, v = torch.chunk(qkv, 3, dim=2)
q = q.view(bsz, seq_len, self.n_heads, d_in // self.n_heads)
k = k.view(bsz, seq_len, self.n_heads, d_in // self.n_heads)
v = v.view(bsz, seq_len, self.n_heads, d_in // self.n_heads)
y = fa2( # https://arxiv.org/pdf/2307.08691
q,
k,
v,
dropout_p=self.dropout if self.training else 0.0,
causal=True,
window_size=(self.window_size, 0),
alibi_slopes=self.alibi_slopes, # https://arxiv.org/pdf/2108.12409
softcap=self.softcap, # https://arxiv.org/pdf/2408.00118
)
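        # fa2 returns (bsz, seq_len, n_heads, head_dim), so the heads can be
        # flattened back into the embedding dimension without a transpose.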
y = y.contiguous().view(bsz, seq_len, d_in)
y = self.resid_dropout(self.o_proj(y))
return y
class AttentionSDPA(nn.Module):
def __init__(self, config):
        super(AttentionSDPA, self).__init__()
if isinstance(config.torch_dtype, str):
torch_dtype = getattr(torch, config.torch_dtype)
else:
torch_dtype = config.torch_dtype
assert torch.cuda.is_available(), "CUDA is required."
assert config.n_embd % config.n_heads == 0
self.n_heads = config.n_heads
self.device = torch.device("cuda") # Technically don't need CUDA for SDPA
self.bsz = config.bsz
self.attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias, dtype=torch_dtype)
self.o_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias, dtype=torch_dtype)
self.dropout = config.dropout
self.resid_dropout = nn.Dropout(self.dropout)
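    # Same flow as Attention.forward, but with PyTorch's built-in
    # scaled_dot_product_attention: SDPA expects (bsz, n_heads, seq_len, head_dim),
    # hence the transposes around the attention call.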
def forward(self, x):
bsz, seq_len, d_in = x.size()
qkv = self.attn(x)
q, k, v = torch.chunk(qkv, 3, dim=2)
q = q.view(bsz, seq_len, self.n_heads, d_in // self.n_heads).transpose(1, 2)
k = k.view(bsz, seq_len, self.n_heads, d_in // self.n_heads).transpose(1, 2)
v = v.view(bsz, seq_len, self.n_heads, d_in // self.n_heads).transpose(1, 2)
y = F.scaled_dot_product_attention(
q, k, v,
is_causal=True,
dropout_p=self.dropout if self.training else 0.0
)
y = y.transpose(1, 2).contiguous().view(bsz, seq_len, d_in)
y = self.resid_dropout(self.o_proj(y))
return y
class FlexAttention(nn.Module):
"""
    Generalized multi-head attention that supports arbitrary attention masks
    as well as rotary positional embeddings.
"""
def __init__(self, config, mask_mod, score_mod=None):
"""
        Initializes the FlexAttention module.
        Args:
            config: Model config providing dim, num_heads, seq_len, and device.
            mask_mod (Callable): Mask to modify attention scores, e.g. causal.
            score_mod (Callable, optional): Score modification applied inside attention.
"""
super().__init__()
self.dim, self.num_heads = config.dim, config.num_heads
        assert config.dim % config.num_heads == 0, f"dim ({self.dim}) must be divisible by num_heads ({self.num_heads})"
self.head_dim = config.dim // config.num_heads
self.wq = nn.Linear(config.dim, config.dim)
self.wk = nn.Linear(config.dim, config.dim)
self.wv = nn.Linear(config.dim, config.dim)
self.mask_mod = mask_mod
self.score_mod = score_mod
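        # mask_mod follows FlexAttention's (b, h, q_idx, kv_idx) -> bool convention;
        # a causal mask, for example, would return q_idx >= kv_idx (illustrative,
        # the actual mask is supplied by the caller).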
self.block_mask = create_block_mask(
mask_mod=self.mask_mod,
B=None, # Broadcast
H=None, # Broadcast
Q_LEN=config.seq_len,
KV_LEN=config.seq_len,
device=config.device,
)
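        # The block mask is built once for fixed Q_LEN == KV_LEN == config.seq_len;
        # sequences of a different length would require rebuilding it.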
self.o_proj = nn.Linear(config.dim, config.dim)
self.o_proj.SCALE_INIT = 1
def forward(
self,
x: torch.Tensor = None,
q: torch.Tensor = None,
k: torch.Tensor = None,
v: torch.Tensor = None,
freqs_cis: torch.Tensor = None,
) -> torch.Tensor:
if x is not None:
q = k = v = x
if any(t is None for t in [q, k, v]):
raise ValueError("Must provide either x for self-attention or q/k/v for cross-attention.")
bsz, q_len, _ = q.shape
_, k_len, _ = k.shape
_, v_len, _ = v.shape
        # Project and split heads: (bsz, seq_len, dim) -> (bsz, num_heads, seq_len, head_dim).
        # A direct reshape to the heads-first shape would scramble the sequence and head axes.
        Q = self.wq(q).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.wk(k).view(bsz, k_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.wv(v).view(bsz, v_len, self.num_heads, self.head_dim).transpose(1, 2)
Q, K = apply_rotary_emb(Q, K, freqs_cis=freqs_cis)
output = flex_attention(Q, K, V, block_mask=self.block_mask, score_mod=self.score_mod)
        # flex_attention returns (bsz, num_heads, q_len, head_dim); restore the
        # (bsz, q_len, dim) layout before the output projection.
        output = output.transpose(1, 2).contiguous().view(bsz, q_len, self.dim)
output = self.o_proj(output)
return output
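

# Minimal smoke-test sketch (illustrative only, not part of the original module).
# It assumes a config object exposing the attributes that AttentionSDPA reads
# (n_embd, n_heads, bias, dropout, torch_dtype, bsz); a SimpleNamespace stands
# in for the real config class here.
if __name__ == "__main__":
    from types import SimpleNamespace

    if torch.cuda.is_available():  # the constructors above assert CUDA
        cfg = SimpleNamespace(
            n_embd=256,
            n_heads=8,
            bias=False,
            dropout=0.0,
            torch_dtype="float32",
            bsz=2,
        )
        attn = AttentionSDPA(cfg).to("cuda")
        x = torch.randn(cfg.bsz, 16, cfg.n_embd, device="cuda")
        y = attn(x)
        print(y.shape)  # expected: torch.Size([2, 16, 256])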