File size: 15,053 Bytes
3f9c425 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 |
from functools import partial
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
try:
import cudnn
except ImportError:
cudnn = None
from einops import rearrange, repeat
# from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler
from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler
from flash_attn.flash_attn_interface import flash_attn_func
from flash_attn_interface import flash_attn_func as flash_attn_func_v3, flash_attn_varlen_func as flash_attn_varlen_func_v3
# Need to install triton nightly:
# pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly
try:
from triton_fused_attention import attention as triton_attention
except ImportError:
triton_attention = None
def flops(batch, nheads, seqlen_q, seqlen_k, headdim, causal=False, mode='fwd'):
    """Return the nominal FLOP count of one attention call.

    The forward cost is ``4 * batch * nheads * seqlen_q * seqlen_k * headdim``
    (the two attention matmuls, QK^T and PV, at 2 FLOPs per multiply-add).
    A causal mask is counted as halving that work. The backward pass is
    counted as 2.5x the forward cost and ``fwd_bwd`` as 3.5x.

    Args:
        batch: batch size.
        nheads: number of query heads.
        seqlen_q: query sequence length.
        seqlen_k: key/value sequence length.
        headdim: per-head dimension.
        causal: whether a causal mask is applied.
        mode: one of ``"fwd"``, ``"bwd"``, ``"fwd_bwd"``.

    Returns:
        The FLOP count (an ``int`` for ``"fwd"``, a ``float`` otherwise).
    """
    assert mode in ["fwd", "bwd", "fwd_bwd"]
    # Use both sequence lengths; seqlen_q and seqlen_k may differ.
    f = 4 * batch * seqlen_q * seqlen_k * nheads * headdim // (2 if causal else 1)
    return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)
def convert_to_cudnn_type(torch_type):
    """Translate a torch dtype into the corresponding ``cudnn.data_type`` member.

    Supported dtypes: float16, bfloat16, float32, int32 and int64. Anything
    else raises ``ValueError``.
    """
    dtype_to_member = (
        (torch.float16, "HALF"),
        (torch.bfloat16, "BFLOAT16"),
        (torch.float32, "FLOAT"),
        (torch.int32, "INT32"),
        (torch.int64, "INT64"),
    )
    for candidate, member_name in dtype_to_member:
        if torch_type == candidate:
            # Look the enum member up only on a match, so unsupported dtypes
            # are rejected without touching the cudnn module.
            return getattr(cudnn.data_type, member_name)
    raise ValueError("Unsupported tensor data type.")
def _finalize_cudnn_graph(graph):
    """Validate a cuDNN frontend graph and build its execution plans.

    Runs the sequence shared by the forward and backward SDPA graphs:
    validate, build the operation graph, create heuristic execution plans
    (mode A with FALLBACK as backup), check support, and build the plans.
    """
    graph.validate()
    graph.build_operation_graph()
    graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
    graph.check_support()
    graph.build_plans()


def cudnn_sdpa_setup(q, k, v, grad, o, stats, causal=False, varlen=False, seqlens=None):
    """Build cuDNN SDPA forward/backward graphs and return closures that run them.

    Args:
        q: query tensor, unpacked as (batch, nheads, seqlen_q, headdim).
        k: key tensor, unpacked as (batch, nheads_kv, seqlen_k, headdim).
        v: value tensor; must have shape (batch, nheads_kv, seqlen_k, headdim).
        grad: upstream gradient dO consumed by the backward graph
            (presumably shaped like ``o`` -- TODO confirm).
        o: preallocated output buffer written by the forward graph.
        stats: preallocated statistics buffer written by the forward graph
            (float32) and read by the backward graph.
        causal: apply a causal mask.
        varlen: enable cuDNN's padding mask driven by ``seqlens``.
        seqlens: per-batch sequence lengths, required when ``varlen`` is True.
            The same tensor is used for both query and key/value lengths.

    Returns:
        ``(run_fwd, run_bwd)``. ``run_fwd()`` executes the forward graph and
        returns ``(o, stats)``. ``run_bwd()`` executes the backward graph and
        returns ``(dQ, dK, dV)``. Those buffers are allocated once here and
        overwritten on every call. Both closures ignore any arguments.
    """
    b, nheads, seqlen_q, headdim = q.shape
    _, nheads_kv, seqlen_k, _ = k.shape
    assert v.shape == (b, nheads_kv, seqlen_k, headdim)
    assert cudnn is not None, 'CUDNN is not available'
    assert not varlen or seqlens is not None, 'seqlens is required when varlen=True'

    # Both graphs take the inputs' dtype for I/O, so bf16 works for backward too.
    io_data_type = convert_to_cudnn_type(q.dtype)
    attn_scale = 1.0 / math.sqrt(headdim)

    # ---- Forward graph ----
    graph_forward = cudnn.pygraph(
        io_data_type=io_data_type,
        intermediate_data_type=cudnn.data_type.FLOAT,
        compute_data_type=cudnn.data_type.FLOAT,
    )
    q_forward = graph_forward.tensor_like(q.detach())
    k_forward = graph_forward.tensor_like(k.detach())
    v_forward = graph_forward.tensor_like(v.detach())
    seq_len_q_fwd = graph_forward.tensor_like(seqlens.detach()) if varlen else None
    seq_len_kv_fwd = graph_forward.tensor_like(seqlens.detach()) if varlen else None
    o_forward, stats_forward = graph_forward.sdpa(
        name="sdpa",
        q=q_forward,
        k=k_forward,
        v=v_forward,
        # Training mode: the graph also emits `stats`, which the backward graph reads.
        is_inference=False,
        attn_scale=attn_scale,
        use_causal_mask=causal,
        use_padding_mask=varlen,
        seq_len_q=seq_len_q_fwd,
        seq_len_kv=seq_len_kv_fwd,
    )
    o_forward.set_output(True).set_dim(o.shape).set_stride(o.stride())
    stats_forward.set_output(True).set_data_type(cudnn.data_type.FLOAT)
    _finalize_cudnn_graph(graph_forward)

    variant_pack_forward = {
        q_forward: q,
        k_forward: k,
        v_forward: v,
        o_forward: o,
        stats_forward: stats,
    }
    # Sequence-length tensors exist in the graph only for varlen; don't bind None keys.
    if varlen:
        variant_pack_forward[seq_len_q_fwd] = seqlens
        variant_pack_forward[seq_len_kv_fwd] = seqlens

    # ---- Backward graph ----
    dQ = torch.empty_like(q)
    dK = torch.empty_like(k)
    dV = torch.empty_like(v)
    graph_backward = cudnn.pygraph(
        io_data_type=io_data_type,
        intermediate_data_type=cudnn.data_type.FLOAT,
        compute_data_type=cudnn.data_type.FLOAT,
    )
    q_backward = graph_backward.tensor_like(q.detach())
    k_backward = graph_backward.tensor_like(k.detach())
    v_backward = graph_backward.tensor_like(v.detach())
    o_backward = graph_backward.tensor_like(o.detach())
    dO_backward = graph_backward.tensor_like(grad.detach())
    stats_backward = graph_backward.tensor_like(stats.detach())
    seq_len_q_bwd = graph_backward.tensor_like(seqlens.detach()) if varlen else None
    seq_len_kv_bwd = graph_backward.tensor_like(seqlens.detach()) if varlen else None
    dQ_backward, dK_backward, dV_backward = graph_backward.sdpa_backward(
        name="sdpa_backward",
        q=q_backward,
        k=k_backward,
        v=v_backward,
        o=o_backward,
        dO=dO_backward,
        stats=stats_backward,
        attn_scale=attn_scale,
        use_causal_mask=causal,
        use_padding_mask=varlen,
        seq_len_q=seq_len_q_bwd,
        seq_len_kv=seq_len_kv_bwd,
    )
    dQ_backward.set_output(True).set_dim(dQ.size()).set_stride(dQ.stride())
    dK_backward.set_output(True).set_dim(dK.size()).set_stride(dK.stride())
    dV_backward.set_output(True).set_dim(dV.size()).set_stride(dV.stride())
    _finalize_cudnn_graph(graph_backward)

    variant_pack_backward = {
        q_backward: q,
        k_backward: k,
        v_backward: v,
        o_backward: o,
        dO_backward: grad,
        stats_backward: stats,
        dQ_backward: dQ,
        dK_backward: dK,
        dV_backward: dV,
    }
    if varlen:
        variant_pack_backward[seq_len_q_bwd] = seqlens
        variant_pack_backward[seq_len_kv_bwd] = seqlens

    # One workspace, sized for the larger of the two graphs, shared by both.
    workspace = torch.empty(
        max(graph_forward.get_workspace_size(), graph_backward.get_workspace_size()),
        device=q.device,
        dtype=torch.uint8,
    )

    def run_fwd(*args, **kwargs):
        graph_forward.execute(variant_pack_forward, workspace)
        return o, stats

    def run_bwd(*args, **kwargs):
        graph_backward.execute(variant_pack_backward, workspace)
        return dQ, dK, dV

    return run_fwd, run_bwd
# Global benchmark settings. Seed first so every random tensor created later
# is reproducible. `causal`, `seqlen` and `headdim` are only initial values:
# the sweep loops below rebind them for each configuration.
torch.manual_seed(0)
repeats, dropout_p = 100, 0.0
causal = False
dtype, device = torch.float16, 'cuda'
verbose = False
batch_size = 2
seqlen = 8192  # other values tried: 2048, 4096, 2047
dim = 2048  # model width; the sweep derives nheads as dim // headdim
headdim = 256  # other values tried: 64, 128
# Benchmark sweep. For every (mode, headdim, seqlen, causal) configuration, time
# FlashAttention-2 (also used as the numerical reference), Triton (forward only),
# cuDNN (dense and varlen) and FlashAttention-3 (dense and varlen). cuDNN and FA3
# results are checked against FA2, then latency and TFLOPS are printed per backend.
for mode in ['fwd', 'bwd']:
    # for mode in ['bwd']:
    for headdim in [64, 128, 256]:
        # for headdim in [128]:
        for seqlen in [1024, 2048, 4096, 8192, 16384, 32768]:
            # for seqlen in [8192]:
            nheads = dim // headdim  # hold the model width `dim` fixed
            nheads_kv = nheads
            # nheads_kv = 1
            # Backend gates for this (mode, headdim). Setup, benchmarking and
            # printing all use these, so the three stages always agree.
            small_headdim = headdim <= 128
            run_triton = small_headdim and mode == 'fwd' and triton_attention is not None and nheads_kv == nheads
            run_cudnn = small_headdim and cudnn is not None
            run_fav3 = small_headdim or mode == 'fwd'

            q = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
            k = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype, requires_grad=True)
            v = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype, requires_grad=True)
            # Contiguous (batch, nheads, seqlen, headdim) copies for the Triton kernel.
            q_t = q.transpose(1, 2).contiguous().detach().requires_grad_()
            k_t = k.transpose(1, 2).contiguous().detach().requires_grad_()
            v_t = v.transpose(1, 2).contiguous().detach().requires_grad_()
            grad = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype)
            # Output and stats buffers shared by the dense and varlen cuDNN graphs.
            o_t = torch.empty_like(q.transpose(1, 2))
            stats = torch.empty(batch_size, nheads, seqlen, 1, dtype=torch.float32, device=q.device)
            bench_fn = benchmark_forward if mode == 'fwd' else partial(benchmark_backward, grad=grad)
            # Var-seq-len metadata. Every sequence is full length, so the varlen
            # kernels should match the dense ones. It does not depend on `causal`,
            # so it is built once per seqlen.
            lens = torch.full([q.shape[0]], seqlen, dtype=torch.int32)
            seqlens_cudnn = lens.reshape(batch_size, 1, 1, 1).contiguous().cuda()
            cu_seqlens = torch.cat([torch.tensor([0], dtype=torch.int32), torch.cumsum(lens, dim=0, dtype=torch.int32)]).cuda()

            for causal in [False, True]:
                # for causal in [True]:
                print(f"\n### {mode = }, {batch_size = }, {headdim = }, {seqlen = }, {causal = } ###")
                if run_cudnn:
                    cudnn_sdpa_fwd, cudnn_sdpa_bwd = cudnn_sdpa_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), grad.transpose(1, 2), o_t, stats, causal=causal)
                    cudnn_sdpa_fwd_varlen, cudnn_sdpa_bwd_varlen = cudnn_sdpa_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), grad.transpose(1, 2), o_t, stats, causal=causal, varlen=True, seqlens=seqlens_cudnn)
                f = flops(batch_size, nheads, seqlen, seqlen, headdim, causal=causal, mode=mode)
                ref_o = flash_attn_func(q, k, v, dropout_p, causal=causal)
                _, m0 = bench_fn(flash_attn_func, q, k, v, dropout_p, causal=causal, repeats=repeats, verbose=verbose, desc='Fav2')
                if mode == 'bwd':
                    # Keep the gradients left on q/k/v by the FA2 backward benchmark as the reference.
                    ref_dv, v.grad = v.grad.clone(), None
                    ref_dk, k.grad = k.grad.clone(), None
                    ref_dq, q.grad = q.grad.clone(), None
                # pytorch_profiler(flash_attn_func, q, k, v, dropout_p, causal=causal, backward=False)

                if run_triton:
                    time.sleep(1)  # Sleep to avoid residual power throttling from the previous benchmark
                    _, m3 = benchmark_forward(triton_attention, q_t, k_t, v_t, causal, 1 / math.sqrt(headdim), repeats=repeats, verbose=verbose, desc='Triton')
                # TODO: fix Triton numeric errors, then also benchmark the Triton backward
                # pass and check the q_t/k_t/v_t grads (transposed back) against ref_dq/ref_dk/ref_dv.

                if run_cudnn:
                    time.sleep(1)  # Sleep to avoid residual power throttling from the previous benchmark
                    if mode == 'fwd':
                        _, m2 = benchmark_forward(cudnn_sdpa_fwd, repeats=repeats, verbose=verbose, desc='CuDNN')
                        _, m2_var = benchmark_forward(cudnn_sdpa_fwd_varlen, repeats=repeats, verbose=verbose, desc='CuDNN varlen')
                        cudnn_sdpa_fwd()
                        torch.testing.assert_close(ref_o, o_t.transpose(1, 2), atol=0.05, rtol=0.05)
                        cudnn_sdpa_fwd_varlen()
                        torch.testing.assert_close(ref_o, o_t.transpose(1, 2), atol=0.05, rtol=0.05)
                    else:
                        # The backward graphs read o_t/stats, so run the forward graph first.
                        cudnn_sdpa_fwd()
                        # The cuDNN closures are plain callables, so time them with benchmark_forward.
                        _, m2 = benchmark_forward(cudnn_sdpa_bwd, repeats=repeats, verbose=verbose, desc='CuDNN')
                        _, m2_var = benchmark_forward(cudnn_sdpa_bwd_varlen, repeats=repeats, verbose=verbose, desc='CuDNN varlen')
                        dq, dk, dv = cudnn_sdpa_bwd()
                        torch.testing.assert_close(ref_dv, dv.transpose(1, 2), atol=0.05, rtol=0.05)
                        torch.testing.assert_close(ref_dk, dk.transpose(1, 2), atol=0.05, rtol=0.05)
                        torch.testing.assert_close(ref_dq, dq.transpose(1, 2), atol=0.05, rtol=0.05)
                        dq, dk, dv = cudnn_sdpa_bwd_varlen()
                        torch.testing.assert_close(ref_dv, dv.transpose(1, 2), atol=0.05, rtol=0.05)
                        torch.testing.assert_close(ref_dk, dk.transpose(1, 2), atol=0.05, rtol=0.05)
                        torch.testing.assert_close(ref_dq, dq.transpose(1, 2), atol=0.05, rtol=0.05)

                if run_fav3:
                    time.sleep(1)
                    _, m1 = bench_fn(flash_attn_func_v3, q, k, v, causal=causal, repeats=repeats, verbose=verbose, desc='Fav3')
                    # Flatten (batch, seqlen, nheads, headdim) to (batch * seqlen, nheads, headdim) for the varlen API.
                    q_var = q.reshape(-1, q.shape[-2], q.shape[-1])
                    k_var = k.reshape(-1, k.shape[-2], k.shape[-1])
                    v_var = v.reshape(-1, v.shape[-2], v.shape[-1])
                    time.sleep(1)
                    if mode == 'bwd':
                        dv, v.grad = v.grad.clone(), None
                        dk, k.grad = k.grad.clone(), None
                        dq, q.grad = q.grad.clone(), None
                        torch.testing.assert_close(ref_dv, dv, atol=0.05, rtol=0.05)
                        torch.testing.assert_close(ref_dk, dk, atol=0.05, rtol=0.05)
                        torch.testing.assert_close(ref_dq, dq, atol=0.05, rtol=0.05)
                    bench_var_fn = bench_fn
                    if mode == 'bwd':
                        grad_var = grad.reshape(-1, grad.shape[-2], grad.shape[-1])
                        bench_var_fn = partial(benchmark_backward, grad=grad_var)
                    _, m1_var = bench_var_fn(flash_attn_varlen_func_v3, q_var, k_var, v_var, cu_seqlens, cu_seqlens, seqlen, seqlen, causal=causal, repeats=repeats, verbose=verbose, desc='Fav3 var len')
                # pytorch_profiler(flash_attn_func_v3, q, k, v, causal=causal, backward=False)

                print(f'Fav2: {m0.mean * 1e3:.3f}ms, {(f / m0.mean * 1e-12):.1f} TFLOPS')
                if run_triton:
                    print(f'Triton: {m3.mean * 1e3:.3f}ms, {(f / m3.mean * 1e-12):.1f} TFLOPS')
                if run_cudnn:
                    print(f'CuDNN: {m2.mean * 1e3:.3f}ms, {(f / m2.mean * 1e-12):.1f} TFLOPS')
                    print(f'CuDNN varlen: {m2_var.mean * 1e3:.3f}ms, {(f / m2_var.mean * 1e-12):.1f} TFLOPS')
                if run_fav3:
                    print(f'Fav3: {m1.mean * 1e3:.3f}ms, {(f / m1.mean * 1e-12):.1f} TFLOPS')
                    print(f'Fav3 varlen: {m1_var.mean * 1e3:.3f}ms, {(f / m1_var.mean * 1e-12):.1f} TFLOPS')
|