import math

import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from flash_attn import flash_attn_func
except ImportError as e:
    print(
        f"Unable to import Triton-based flash attention: {e}. No alternative currently available."
    )


def nearest_power_of_two(x: int, round_up: bool = False) -> int:
    return (
        1 << math.floor(math.log2(x)) if not round_up else 1 << math.ceil(math.log2(x))
    )
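
# Illustrative behavior (not in the original file): round_up=False floors to the
# previous power of two, round_up=True ceils to the next one, e.g.
#   nearest_power_of_two(12)                -> 8
#   nearest_power_of_two(12, round_up=True) -> 16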


def _generate_slopes(self, n: int):
    start = 2 ** (-(2 ** -(math.log2(n) - 3)))
    return [start * (start**i) for i in range(n)]


def _get_alibi_slopes(self, n_heads: int, interpolation_factor: float = 0.25):
    # If n_heads is a power of 2, generate slopes directly
    if math.log2(n_heads).is_integer():
        slopes = self._generate_slopes(n_heads)
    else:
        # Get slopes for the nearest power of two
        n = nearest_power_of_two(n_heads, round_up=False)
        slopes_power_of_two = self._generate_slopes(n)

        # Generate extra slopes
        extra_slopes = self._generate_slopes(2 * n)
        extra_slopes_trunc = extra_slopes[0::2][: n_heads - n]
        slopes = slopes_power_of_two + extra_slopes_trunc
    slopes = torch.tensor(slopes, device=self.device, dtype=torch.float32)
    slopes = slopes * interpolation_factor  # https://arxiv.org/pdf/2310.13017
    return slopes


def precompute_freqs_cis(head_dim: int, max_seq_len: int, theta: float = 10000.0):
    # For half the dimensions, build the scale factor:
    freq_seq = torch.arange(0, head_dim, 2).float() / head_dim
    freqs = 1.0 / (theta**freq_seq)

    # Outer product with positions
    t = torch.arange(max_seq_len, dtype=torch.float32)
    angles = torch.outer(t, freqs)

    # Build a complex exponential e^{i * theta}
    freqs_cis = torch.polar(torch.ones_like(angles), angles)
    return freqs_cis
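
# Shape check (illustrative): precompute_freqs_cis(head_dim=64, max_seq_len=2048)
# returns a complex64 tensor of shape [2048, 32], i.e. one rotation per position
# for each of the head_dim // 2 frequency pairs.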


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """
    x is [B, seq_len, n_heads, head_dim_as_complex] (the layout used by
    Attention.forward below), so we broadcast freqs_cis from
    [max_seq_len, half_dim] to [1, seq_len, 1, half_dim] so that each
    position, not each head, gets its own rotation.
    """
    seq_len = x.shape[1]
    freqs_cis = freqs_cis[:seq_len]  # slice down to current seq_len
    return freqs_cis.view(1, seq_len, 1, -1)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Convert real -> complex by grouping the last dim in pairs:
    # [B, seq_len, n_heads, head_dim] -> [B, seq_len, n_heads, head_dim//2, 2] -> complex [B, seq_len, n_heads, head_dim//2]
    xq_complex = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_complex = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))

    # Broadcast the frequencies to match [B, seq_len, n_heads, head_dim//2]
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_complex)

    # Multiply => apply the rotation
    xq_complex = xq_complex * freqs_cis
    xk_complex = xk_complex * freqs_cis

    # Convert back to real => shape [B, seq_len, n_heads, head_dim]
    xq_out = torch.view_as_real(xq_complex).reshape(*xq.shape)
    xk_out = torch.view_as_real(xk_complex).reshape(*xk.shape)
    return xq_out.type_as(xq), xk_out.type_as(xk)
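
# Illustrative call (assumed shapes, matching Attention.forward below): for q, k of
# shape [B, S, H, 64] and freqs_cis = precompute_freqs_cis(64, max_seq_len=S),
#   q_rot, k_rot = apply_rotary_emb(q, k, freqs_cis=freqs_cis)
# returns rotated tensors with the same shape and dtype as q and k.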


class Attention(nn.Module):
    def __init__(self, config):
        super(Attention, self).__init__()
        self.dim, self.num_heads = config.dim, config.num_heads
        assert config.dim % config.num_heads == 0, f"dim ({self.dim}) must be divisible by num_heads ({self.num_heads})"
        self.head_dim = config.dim // config.num_heads

        self.c_attn = nn.Linear(self.dim, 3 * self.dim, bias=config.bias)
        self.c_proj = nn.Linear(config.dim, config.dim, bias=config.bias)
        self.c_proj.SCALE_INIT = 1

        self.alibi_slopes = self._get_alibi_slopes(self.num_heads) if config.use_alibi else None
        self.window_size = config.window_size
        self.softcap = config.softcap

        self.dropout = config.dropout
        self.resid_dropout = nn.Dropout(self.dropout)

    def _generate_slopes(self, n: int):
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
        return [start * (start**i) for i in range(n)]

    def _get_alibi_slopes(self, num_heads: int, interpolation_factor: float = 0.25):
        # If num_heads is a power of 2, generate slopes directly
        if math.log2(num_heads).is_integer():
            slopes = self._generate_slopes(num_heads)
        else:
            # Get slopes for the nearest power of two
            n = nearest_power_of_two(num_heads, round_up=False)
            slopes_power_of_two = self._generate_slopes(n)

            # Generate extra slopes
            extra_slopes = self._generate_slopes(2 * n)
            extra_slopes_trunc = extra_slopes[0::2][: num_heads - n]
            slopes = slopes_power_of_two + extra_slopes_trunc
        slopes = torch.tensor(slopes, device=torch.device("cuda"), dtype=torch.float32)
        slopes = slopes * interpolation_factor  # https://arxiv.org/pdf/2310.13017
        return slopes
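
    # Worked example (illustrative, not in the original file): for num_heads=8,
    # _generate_slopes(8) starts at 2**-1 and yields [2**-1, 2**-2, ..., 2**-8];
    # with interpolation_factor=0.25 the returned ALiBi slopes are those values
    # scaled by 0.25, one slope per attention head.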

    def forward(
        self,
        x: torch.Tensor = None,
        q: torch.Tensor = None,
        k: torch.Tensor = None,
        v: torch.Tensor = None,
        freqs_cis: torch.Tensor = None,
    ) -> torch.Tensor:
        if x is not None:
            q = k = v = x
        if any(t is None for t in [q, k, v]):
            raise ValueError("Must provide either x for self-attention or q/k/v for cross-attention.")

        bsz, q_len, dim = q.shape
        _, k_len, _ = k.shape
        _, v_len, _ = v.shape

        # Fused QKV projection, then split into per-head views in the
        # [batch, seq_len, num_heads, head_dim] layout expected by flash_attn.
        qkv = self.c_attn(x)
        q, k, v = torch.chunk(qkv, 3, dim=2)

        q = q.view(bsz, q_len, self.num_heads, self.head_dim)
        k = k.view(bsz, k_len, self.num_heads, self.head_dim)
        v = v.view(bsz, v_len, self.num_heads, self.head_dim)

        if self.alibi_slopes is None:  # Use either ALiBi or RoPE
            q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis)

        y = flash_attn_func(  # https://arxiv.org/pdf/2307.08691
            q=q, k=k, v=v,
            dropout_p=self.dropout if self.training else 0.0,
            causal=True,
            window_size=(self.window_size, 0),  # Set to config.seq_len for full attention
            alibi_slopes=self.alibi_slopes,  # https://arxiv.org/pdf/2108.12409
            softcap=self.softcap,  # https://arxiv.org/pdf/2408.00118
        )

        y = y.contiguous().view(bsz, q_len, -1)
        y = self.resid_dropout(self.c_proj(y))
        return y
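

# Note (inferred from the attribute reads above): Attention expects a config
# object exposing dim, num_heads, bias, use_alibi, window_size, softcap, and
# dropout; MLP below additionally reads mlp_scale. See the usage sketch at the
# bottom of the file for a hypothetical config wired through AttentionLayer.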


class MLP(nn.Module):
    def __init__(self, config):
        # https://arxiv.org/pdf/2002.05202
        super().__init__()
        self.hidden_size = config.dim
        self.intermediate_size = config.dim * config.mlp_scale
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        gate = self.gate_proj(x)
        gate = F.gelu(gate, approximate="tanh")
        up = self.up_proj(x)
        fuse = gate * up
        outputs = self.down_proj(fuse)
        outputs = self.dropout(outputs)
        return outputs
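

# In formula form, the block above computes a GeGLU-style gated MLP
# (one of the GLU variants from https://arxiv.org/pdf/2002.05202):
#   MLP(x) = dropout( down_proj( gelu(gate_proj(x)) * up_proj(x) ) )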


class AttentionLayer(nn.Module):
    def __init__(self, config) -> None:
        super(AttentionLayer, self).__init__()
        self.attn_norm = nn.RMSNorm(config.dim)
        self.attn = Attention(config=config)
        self.mlp_norm = nn.RMSNorm(config.dim)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor = None) -> torch.Tensor:
        x = x + self.attn(x=self.attn_norm(x), freqs_cis=freqs_cis)
        x = x + self.mlp(self.mlp_norm(x))
        return x
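

# --- Minimal usage sketch ---------------------------------------------------
# This block is illustrative and not part of the original module: the config is a
# hypothetical SimpleNamespace whose fields are inferred from the attributes the
# classes above read. Running it requires flash-attn, a CUDA device, and
# fp16/bf16 activations.
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        dim=512,
        num_heads=8,
        bias=False,
        use_alibi=False,   # False -> RoPE path; True -> ALiBi slopes (built on CUDA at init)
        window_size=1024,  # left half of flash_attn's (left, right) sliding window
        softcap=0.0,
        dropout=0.0,
        mlp_scale=4,
    )

    layer = AttentionLayer(config).cuda().to(torch.bfloat16)
    freqs_cis = precompute_freqs_cis(config.dim // config.num_heads, max_seq_len=1024).cuda()

    x = torch.randn(2, 1024, config.dim, device="cuda", dtype=torch.bfloat16)
    y = layer(x, freqs_cis=freqs_cis)
    print(y.shape)  # expected: torch.Size([2, 1024, 512])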