import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

from .rotary_emb import apply_rotary_emb
from .utils import nearest_power_of_two

try:
    from flash_attn import flash_attn_func as fa2
except ImportError as e:
    print(
        f"Unable to import flash_attn (FlashAttention-2): {e}. No alternative currently available."
    )


class Attention(nn.Module):
    def __init__(self, config):
        super(Attention, self).__init__()
        if isinstance(config.torch_dtype, str):
            torch_dtype = getattr(torch, config.torch_dtype)
        else:
            torch_dtype = config.torch_dtype
        assert torch.cuda.is_available(), "CUDA is required."
        assert config.n_embd % config.n_heads == 0
        self.n_heads = config.n_heads

        self.device = torch.device("cuda")
        self.bsz = config.bsz
        self.attn = nn.Linear(
            config.n_embd, 3 * config.n_embd, bias=config.bias, dtype=torch_dtype
        )
        self.o_proj = nn.Linear(
            config.n_embd, config.n_embd, bias=config.bias, dtype=torch_dtype
        )
        self.o_proj.SCALE_INIT = 1
        self.dropout = config.dropout
        self.resid_dropout = nn.Dropout(self.dropout)
        self.alibi_slopes = self._get_alibi_slopes(self.n_heads)
        self.window_size = config.window_size
        self.softcap = config.softcap

    def _generate_slopes(self, n: int):
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
        return [start * (start**i) for i in range(n)]
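    # Illustrative arithmetic (comment only, not executed): for n = 8 heads,
    # start = 2 ** -(2 ** -(log2(8) - 3)) = 2 ** -(2 ** 0) = 0.5, so the method
    # returns [0.5 ** 1, 0.5 ** 2, ..., 0.5 ** 8] = [1/2, 1/4, ..., 1/256],
    # the standard ALiBi slope sequence for 8 heads (before the
    # interpolation_factor scaling applied below).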
    def _get_alibi_slopes(self, n_heads: int, interpolation_factor: float = 0.25):
        # If n_heads is a power of two, the geometric slope sequence applies directly.
        if math.log2(n_heads).is_integer():
            slopes = self._generate_slopes(n_heads)
        else:
            # Otherwise, take the slopes for the nearest smaller power of two and
            # fill the remaining heads with every other slope from the 2 * n sequence.
            n = nearest_power_of_two(n_heads, round_up=False)
            slopes_power_of_two = self._generate_slopes(n)

            extra_slopes = self._generate_slopes(2 * n)
            extra_slopes_trunc = extra_slopes[0::2][: n_heads - n]
            slopes = slopes_power_of_two + extra_slopes_trunc
        slopes = torch.tensor(slopes, device=self.device)
        slopes = slopes * interpolation_factor
        return slopes.to(torch.float32)
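    # Illustrative arithmetic (comment only, not executed): for n_heads = 12, the
    # nearest smaller power of two is n = 8, contributing [1/2, 1/4, ..., 1/256];
    # the remaining 4 heads take every other slope from the 16-head sequence
    # (start = 2 ** -0.5 ~= 0.7071), i.e. roughly [0.7071, 0.3536, 0.1768, 0.0884].
    # All 12 slopes are then scaled by interpolation_factor = 0.25.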
    def forward(self, x):
        bsz, seq_len, d_in = x.size()

        qkv = self.attn(x)
        q, k, v = torch.chunk(qkv, 3, dim=2)

        q = q.view(bsz, seq_len, self.n_heads, d_in // self.n_heads)
        k = k.view(bsz, seq_len, self.n_heads, d_in // self.n_heads)
        v = v.view(bsz, seq_len, self.n_heads, d_in // self.n_heads)

        y = fa2(
            q,
            k,
            v,
            dropout_p=self.dropout if self.training else 0.0,
            causal=True,
            window_size=(self.window_size, 0),
            alibi_slopes=self.alibi_slopes,
            softcap=self.softcap,
        )

        y = y.contiguous().view(bsz, seq_len, d_in)
        y = self.resid_dropout(self.o_proj(y))
        return y
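# Usage sketch (comment only): Attention only reads the config fields used in
# __init__ above, so a minimal, hypothetical config object suffices. Requires a
# CUDA device, the flash_attn package, and a half-precision dtype (fp16/bf16).
#
#   from types import SimpleNamespace
#   cfg = SimpleNamespace(
#       torch_dtype="bfloat16", n_embd=512, n_heads=8, bsz=2, bias=False,
#       dropout=0.0, window_size=128, softcap=0.0,
#   )
#   attn = Attention(cfg).to("cuda")
#   x = torch.randn(cfg.bsz, 256, cfg.n_embd, device="cuda", dtype=torch.bfloat16)
#   y = attn(x)  # -> (2, 256, 512)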
class AttentionSDPA(nn.Module):
    def __init__(self, config):
        super(AttentionSDPA, self).__init__()
        if isinstance(config.torch_dtype, str):
            torch_dtype = getattr(torch, config.torch_dtype)
        else:
            torch_dtype = config.torch_dtype
        assert torch.cuda.is_available(), "CUDA is required."
        assert config.n_embd % config.n_heads == 0
        self.n_heads = config.n_heads

        self.device = torch.device("cuda")
        self.bsz = config.bsz
        self.attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias, dtype=torch_dtype)
        self.o_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias, dtype=torch_dtype)
        self.dropout = config.dropout
        self.resid_dropout = nn.Dropout(self.dropout)

    def forward(self, x):
        bsz, seq_len, d_in = x.size()

        qkv = self.attn(x)
        q, k, v = torch.chunk(qkv, 3, dim=2)

        q = q.view(bsz, seq_len, self.n_heads, d_in // self.n_heads).transpose(1, 2)
        k = k.view(bsz, seq_len, self.n_heads, d_in // self.n_heads).transpose(1, 2)
        v = v.view(bsz, seq_len, self.n_heads, d_in // self.n_heads).transpose(1, 2)

        y = F.scaled_dot_product_attention(
            q, k, v,
            is_causal=True,
            dropout_p=self.dropout if self.training else 0.0,
        )

        y = y.transpose(1, 2).contiguous().view(bsz, seq_len, d_in)

        y = self.resid_dropout(self.o_proj(y))
        return y
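# Note (sketch, assumes PyTorch >= 2.3): scaled_dot_product_attention selects a
# backend automatically; it can be pinned explicitly if desired, e.g.
#
#   from torch.nn.attention import SDPBackend, sdpa_kernel
#   with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
#       y = attn_sdpa(x)  # attn_sdpa: an AttentionSDPA instance (hypothetical name)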
class FlexAttention(nn.Module):
    """
    Generalized multihead attention that supports various attention masks.
    Supports Rotary Positional Embeddings.
    """

    def __init__(self, config, mask_mod, score_mod=None):
        """
        Initializes the FlexAttention module.

        Args:
            config: Configuration object providing dim, num_heads, seq_len, and device.
            mask_mod (Callable): Mask to modify attention scores, e.g. causal.
            score_mod (Callable, optional): Function that modifies attention scores before softmax.
        """
        super().__init__()
        self.dim, self.num_heads = config.dim, config.num_heads
        assert config.dim % config.num_heads == 0, f"dim ({self.dim}) must be divisible by num_heads ({self.num_heads})"
        self.head_dim = config.dim // config.num_heads

        self.wq = nn.Linear(config.dim, config.dim)
        self.wk = nn.Linear(config.dim, config.dim)
        self.wv = nn.Linear(config.dim, config.dim)

        self.mask_mod = mask_mod
        self.score_mod = score_mod
        self.block_mask = create_block_mask(
            mask_mod=self.mask_mod,
            B=None,
            H=None,
            Q_LEN=config.seq_len,
            KV_LEN=config.seq_len,
            device=config.device,
        )

        self.o_proj = nn.Linear(config.dim, config.dim)
        self.o_proj.SCALE_INIT = 1

    def forward(
        self,
        x: torch.Tensor = None,
        q: torch.Tensor = None,
        k: torch.Tensor = None,
        v: torch.Tensor = None,
        freqs_cis: torch.Tensor = None,
    ) -> torch.Tensor:
        if x is not None:
            q = k = v = x
        if any(t is None for t in [q, k, v]):
            raise ValueError("Must provide either x for self-attention or q/k/v for cross-attention.")

        bsz, q_len, _ = q.shape
        _, k_len, _ = k.shape
        _, v_len, _ = v.shape

        # Split heads: (bsz, seq_len, dim) -> (bsz, num_heads, seq_len, head_dim).
        Q = self.wq(q).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.wk(k).view(bsz, k_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.wv(v).view(bsz, v_len, self.num_heads, self.head_dim).transpose(1, 2)

        Q, K = apply_rotary_emb(Q, K, freqs_cis=freqs_cis)

        output = flex_attention(Q, K, V, block_mask=self.block_mask, score_mod=self.score_mod)
        # Merge heads: (bsz, num_heads, q_len, head_dim) -> (bsz, q_len, dim).
        output = output.transpose(1, 2).contiguous().view(bsz, q_len, self.dim)
        output = self.o_proj(output)
        return output
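# Usage sketch (comment only): mask_mod and score_mod follow the
# torch.nn.attention.flex_attention calling conventions, e.g. a causal mask and
# a no-op score modifier:
#
#   def causal_mask(b, h, q_idx, kv_idx):
#       return q_idx >= kv_idx
#
#   def identity_score(score, b, h, q_idx, kv_idx):
#       return score
#
#   flex_attn = FlexAttention(cfg, mask_mod=causal_mask, score_mod=identity_score)
#
# where cfg is a hypothetical config exposing dim, num_heads, seq_len, and
# device, as read in __init__ above.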