radna committed on
Commit 2a742fb
1 Parent(s): 6e45b59

Update triton_flash_atn.py

Files changed (1)
  1. triton_flash_atn.py +963 -964
triton_flash_atn.py CHANGED
@@ -1,964 +1,963 @@
"""
Fused Attention
===============

This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf)
Credits: OpenAI kernel team

Extra Credits:
- Original flash attention paper (https://arxiv.org/abs/2205.14135)
- Rabe and Staats (https://arxiv.org/pdf/2112.05682v2.pdf)

"""

import pytest
import torch

import triton
import triton.language as tl

# Pick the fp8 data type

# AMD E4M3B8
# Note: When picking this f8 data type, scaling is required when using f8
# for the second gemm
# TORCH_HAS_FP8E4B8 = hasattr(torch, 'float8_e4m3fnuz')

# AMD E5M2B16
TORCH_HAS_FP8E5B16 = hasattr(torch, "float8_e5m2fnuz")


@triton.jit
def _attn_fwd_inner(
    acc,
    l_i,
    m_i,
    q,
    K_block_ptr,
    V_block_ptr,
    start_m,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    STAGE: tl.constexpr,
    offs_m: tl.constexpr,
    offs_n: tl.constexpr,
    N_CTX,
    pre_load_v: tl.constexpr,
):
    # range of values handled by this stage
    if STAGE == 1:
        lo, hi = 0, start_m * BLOCK_M
    elif STAGE == 2:
        lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
        lo = tl.multiple_of(lo, BLOCK_M)
        K_block_ptr = tl.advance(K_block_ptr, (0, lo))
        V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
    # causal = False
    else:
        lo, hi = 0, N_CTX
    # loop over k, v and update accumulator
    for start_n in range(lo, hi, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = tl.load(K_block_ptr)
        if pre_load_v:
            v = tl.load(V_block_ptr)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        if STAGE == 2:
            mask = offs_m[:, None] >= (start_n + offs_n[None, :])
            qk = tl.where(mask, qk, float("-inf"))
        qk += tl.dot(q, k)
        m_ij = tl.maximum(m_i, tl.max(qk, 1))
        qk = qk - m_ij[:, None]
        p = tl.math.exp2(qk)
        # -- update output accumulator --
        alpha = tl.math.exp2(m_i - m_ij)
        acc = acc * alpha[:, None]
        if not pre_load_v:
            v = tl.load(V_block_ptr)
        acc += tl.dot(p.to(v.dtype), v)
        # -- update m_i and l_i
        l_ij = tl.sum(p, 1)
        l_i = l_i * alpha + l_ij
        # update m_i and l_i
        m_i = m_ij
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
    return acc, l_i, m_i


# We don't run auto-tuning every time to keep the tutorial fast. Uncommenting
# the code below and commenting out the equivalent parameters is convenient for
# re-tuning.
@triton.autotune(
    configs=[
        triton.Config(
            {
                "BLOCK_M": 64,
                "BLOCK_N": 16,
                "waves_per_eu": 2,
                "slice_k_tile": 0,
                "pre_load_v": False,
            },
            num_stages=1,
            num_warps=2,
        ),
        triton.Config(
            {
                "BLOCK_M": 64,
                "BLOCK_N": 16,
                "waves_per_eu": 2,
                "slice_k_tile": 32,
                "pre_load_v": False,
            },
            num_stages=1,
            num_warps=2,
        ),
        triton.Config(
            {
                "BLOCK_M": 32,
                "BLOCK_N": 32,
                "waves_per_eu": 2,
                "slice_k_tile": 0,
                "pre_load_v": False,
            },
            num_stages=1,
            num_warps=1,
        ),
        triton.Config(
            {
                "BLOCK_M": 32,
                "BLOCK_N": 32,
                "waves_per_eu": 2,
                "slice_k_tile": 32,
                "pre_load_v": False,
            },
            num_stages=1,
            num_warps=1,
        ),
        triton.Config(
            {
                "BLOCK_M": 64,
                "BLOCK_N": 32,
                "waves_per_eu": 2,
                "slice_k_tile": 0,
                "pre_load_v": False,
            },
            num_stages=1,
            num_warps=2,
        ),
        triton.Config(
            {
                "BLOCK_M": 32,
                "BLOCK_N": 16,
                "waves_per_eu": 3,
                "slice_k_tile": 0,
                "pre_load_v": True,
            },
            num_stages=1,
            num_warps=1,
        ),
        triton.Config(
            {
                "BLOCK_M": 32,
                "BLOCK_N": 16,
                "waves_per_eu": 3,
                "slice_k_tile": 0,
                "pre_load_v": False,
            },
            num_stages=1,
            num_warps=1,
        ),
    ],
    key=["Z", "H", "N_CTX", "STAGE", "BLOCK_DMODEL"],
)
@triton.jit
def _attn_fwd(
    Q,
    K,
    V,
    sm_scale,
    M,
    Out,
    stride_qz,
    stride_qh,
    stride_qm,
    stride_qk,
    stride_kz,
    stride_kh,
    stride_kn,
    stride_kk,
    stride_vz,
    stride_vh,
    stride_vk,
    stride_vn,
    stride_oz,
    stride_oh,
    stride_om,
    stride_on,
    Z,
    H,
    N_CTX,
    BLOCK_DMODEL: tl.constexpr,
    STAGE: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    pre_load_v: tl.constexpr,
):
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    qvk_offset = off_hz * stride_qh

    # block pointers
    Q_block_ptr = tl.make_block_ptr(
        base=Q + qvk_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    V_block_ptr = tl.make_block_ptr(
        base=V + qvk_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_vk, stride_vn),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0),
    )
    K_block_ptr = tl.make_block_ptr(
        base=K + qvk_offset,
        shape=(BLOCK_DMODEL, N_CTX),
        strides=(stride_kk, stride_kn),
        offsets=(0, 0),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1),
    )
    O_block_ptr = tl.make_block_ptr(
        base=Out + qvk_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_om, stride_on),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # scale sm_scale by log_2(e) and use
    # 2^x instead of exp in the loop because CSE and LICM
    # don't work as expected with `exp` in the loop
    qk_scale = sm_scale * 1.44269504
    # load q: it will stay in SRAM throughout on NV GPUs but in VGPRs on AMD GPUs
    q = tl.load(Q_block_ptr)
    q = (q * qk_scale).to(q.dtype)
    # stage 1: off-band
    # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
    # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
    if STAGE & 1:
        acc, l_i, m_i = _attn_fwd_inner(
            acc,
            l_i,
            m_i,
            q,
            K_block_ptr,
            V_block_ptr,
            start_m,
            BLOCK_M,
            BLOCK_DMODEL,
            BLOCK_N,
            4 - STAGE,
            offs_m,
            offs_n,
            N_CTX,
            pre_load_v,
        )
    # stage 2: on-band
    if STAGE & 2:
        # barrier makes it easier for the compiler to schedule the
        # two loops independently
        tl.debug_barrier()
        acc, l_i, m_i = _attn_fwd_inner(
            acc,
            l_i,
            m_i,
            q,
            K_block_ptr,
            V_block_ptr,
            start_m,
            BLOCK_M,
            BLOCK_DMODEL,
            BLOCK_N,
            2,
            offs_m,
            offs_n,
            N_CTX,
            pre_load_v,
        )
    # epilogue
    # write back m
    acc = acc / l_i[:, None]
    m_ptrs = M + off_hz * N_CTX + offs_m
    tl.store(m_ptrs, m_i + tl.math.log2(l_i))
    tl.store(O_block_ptr, acc.to(Out.type.element_ty))


@triton.jit
def _attn_bwd_preprocess(
    O, DO, Delta, Z, H, N_CTX, BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr
):
    off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    off_hz = tl.program_id(1)
    off_n = tl.arange(0, D_HEAD)
    o = tl.load(O + off_hz * D_HEAD * N_CTX + off_m[:, None] * D_HEAD + off_n[None, :])
    do = tl.load(
        DO + off_hz * D_HEAD * N_CTX + off_m[:, None] * D_HEAD + off_n[None, :]
    ).to(tl.float32)
    delta = tl.sum(o * do, axis=1)
    tl.store(Delta + off_hz * N_CTX + off_m, delta)


# The main inner-loop logic for computing dK and dV.
@triton.jit
def _attn_bwd_dkdv(
    dk,
    dv,
    Q,
    k,
    v,
    sm_scale,
    DO,
    M,
    D,
    # shared by Q/K/V/DO.
    stride_tok,
    stride_d,
    H,
    N_CTX,
    BLOCK_M1: tl.constexpr,
    BLOCK_N1: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    # Filled in by the wrapper.
    start_n,
    start_m,
    num_steps,
    MASK: tl.constexpr,
):
    offs_m = start_m + tl.arange(0, BLOCK_M1)
    offs_n = start_n + tl.arange(0, BLOCK_N1)
    offs_k = tl.arange(0, BLOCK_DMODEL)
    QT_block_ptr = tl.make_block_ptr(
        base=Q,
        shape=(BLOCK_DMODEL, N_CTX),
        strides=(stride_d, stride_tok),
        offsets=(0, start_m),
        block_shape=(BLOCK_DMODEL, BLOCK_M1),
        order=(0, 1),
    )
    DO_block_ptr = tl.make_block_ptr(
        base=DO,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_tok, stride_d),
        offsets=(start_m, 0),
        block_shape=(BLOCK_M1, BLOCK_DMODEL),
        order=(1, 0),
    )
    # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
    curr_m = start_m
    step_m = BLOCK_M1
    for blk_idx in range(num_steps):
        qT = tl.load(QT_block_ptr)
        # Load m before computing qk to reduce pipeline stall.
        offs_m = curr_m + tl.arange(0, BLOCK_M1)
        m = tl.load(M + offs_m)
        qkT = tl.dot(k, qT)
        pT = tl.math.exp2(qkT - m[None, :])
        # Autoregressive masking.
        if MASK:
            mask = offs_m[None, :] >= offs_n[:, None]
            pT = tl.where(mask, pT, 0.0)
        do = tl.load(DO_block_ptr)
        # Compute dV.
        ppT = pT
        ppT = ppT.to(tl.float16)
        dv += tl.dot(ppT, do)
        # D (= delta) is pre-divided by ds_scale.
        Di = tl.load(D + offs_m)
        # Compute dP and dS.
        dpT = tl.dot(v, tl.trans(do))
        dsT = pT * (dpT - Di[None, :])
        dsT = dsT.to(tl.float16)
        dk += tl.dot(dsT, tl.trans(qT))
        # Increment pointers.
        curr_m += step_m
        QT_block_ptr = tl.advance(QT_block_ptr, (0, step_m))
        DO_block_ptr = tl.advance(DO_block_ptr, (step_m, 0))
    return dk, dv


# the main inner-loop logic for computing dQ
@triton.jit
def _attn_bwd_dq(
    dq,
    q,
    K,
    V,
    do,
    m,
    D,
    # shared by Q/K/V/DO.
    stride_tok,
    stride_d,
    H,
    N_CTX,
    BLOCK_M2: tl.constexpr,
    BLOCK_N2: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    # Filled in by the wrapper.
    start_m,
    start_n,
    num_steps,
    MASK: tl.constexpr,
):
    offs_m = start_m + tl.arange(0, BLOCK_M2)
    offs_n = start_n + tl.arange(0, BLOCK_N2)
    offs_k = tl.arange(0, BLOCK_DMODEL)
    KT_block_ptr = tl.make_block_ptr(
        base=K,
        shape=(BLOCK_DMODEL, N_CTX),
        strides=(stride_d, stride_tok),
        offsets=(0, start_n),
        block_shape=(BLOCK_DMODEL, BLOCK_N2),
        order=(0, 1),
    )
    VT_block_ptr = tl.make_block_ptr(
        base=V,
        shape=(BLOCK_DMODEL, N_CTX),
        strides=(stride_d, stride_tok),
        offsets=(0, start_n),
        block_shape=(BLOCK_DMODEL, BLOCK_N2),
        order=(0, 1),
    )
    # D (= delta) is pre-divided by ds_scale.
    Di = tl.load(D + offs_m)
    # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
    curr_n = start_n
    step_n = BLOCK_N2
    for blk_idx in range(num_steps):
        kT = tl.load(KT_block_ptr)
        qk = tl.dot(q, kT)
        p = tl.math.exp2(qk - m)
        # Autoregressive masking.
        if MASK:
            offs_n = curr_n + tl.arange(0, BLOCK_N2)
            mask = offs_m[:, None] >= offs_n[None, :]
            p = tl.where(mask, p, 0.0)
        # Compute dP and dS.
        vT = tl.load(VT_block_ptr)
        dp = tl.dot(do, vT).to(tl.float32)
        ds = p * (dp - Di[:, None])
        ds = ds.to(tl.float16)
        # Compute dQ.
        # NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
        dq += tl.dot(ds, tl.trans(kT))
        # Increment pointers.
        curr_n += step_n
        KT_block_ptr = tl.advance(KT_block_ptr, (0, step_n))
        VT_block_ptr = tl.advance(VT_block_ptr, (0, step_n))
    return dq


@triton.autotune(
    configs=[
        triton.Config(
            {
                "BLOCK_M1": 32,
                "BLOCK_N1": 64,
                "BLOCK_M2": 64,
                "BLOCK_N2": 32,
                "BLK_SLICE_FACTOR": 1,
            },
            num_stages=1,
            num_warps=4,
        ),
        triton.Config(
            {
                "BLOCK_M1": 32,
                "BLOCK_N1": 64,
                "BLOCK_M2": 64,
                "BLOCK_N2": 32,
                "BLK_SLICE_FACTOR": 2,
            },
            num_stages=1,
            num_warps=4,
        ),
        triton.Config(
            {
                "BLOCK_M1": 64,
                "BLOCK_N1": 128,
                "BLOCK_M2": 128,
                "BLOCK_N2": 64,
                "BLK_SLICE_FACTOR": 1,
            },
            num_stages=1,
            num_warps=4,
        ),
        triton.Config(
            {
                "BLOCK_M1": 64,
                "BLOCK_N1": 128,
                "BLOCK_M2": 128,
                "BLOCK_N2": 64,
                "BLK_SLICE_FACTOR": 2,
            },
            num_stages=1,
            num_warps=4,
        ),
        triton.Config(
            {
                "BLOCK_M1": 64,
                "BLOCK_N1": 64,
                "BLOCK_M2": 64,
                "BLOCK_N2": 64,
                "BLK_SLICE_FACTOR": 1,
            },
            num_stages=1,
            num_warps=4,
        ),
        triton.Config(
            {
                "BLOCK_M1": 64,
                "BLOCK_N1": 64,
                "BLOCK_M2": 64,
                "BLOCK_N2": 64,
                "BLK_SLICE_FACTOR": 2,
            },
            num_stages=1,
            num_warps=4,
        ),
        triton.Config(
            {
                "BLOCK_M1": 32,
                "BLOCK_N1": 128,
                "BLOCK_M2": 128,
                "BLOCK_N2": 32,
                "BLK_SLICE_FACTOR": 1,
            },
            num_stages=1,
            num_warps=4,
        ),
        triton.Config(
            {
                "BLOCK_M1": 32,
                "BLOCK_N1": 128,
                "BLOCK_M2": 128,
                "BLOCK_N2": 32,
                "BLK_SLICE_FACTOR": 2,
            },
            num_stages=1,
            num_warps=4,
        ),
        triton.Config(
            {
                "BLOCK_M1": 32,
                "BLOCK_N1": 128,
                "BLOCK_M2": 128,
                "BLOCK_N2": 32,
                "BLK_SLICE_FACTOR": 2,
            },
            num_stages=1,
            num_warps=8,
        ),
    ],
    key=["H", "N_CTX", "BLOCK_DMODEL"],
)
@triton.jit
def _attn_bwd(
    Q,
    K,
    V,
    sm_scale,
    DO,
    DQ,
    DK,
    DV,
    M,
    D,
    # shared by Q/K/V/DO.
    stride_z,
    stride_h,
    stride_tok,
    stride_d,
    # H = 16, N_CTX = 1024
    H,
    N_CTX,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_M1: tl.constexpr,
    BLOCK_N1: tl.constexpr,
    BLOCK_M2: tl.constexpr,
    BLOCK_N2: tl.constexpr,
    BLK_SLICE_FACTOR: tl.constexpr,
):
    LN2: tl.constexpr = 0.6931471824645996  # = ln(2)

    bhid = tl.program_id(2)
    off_chz = (bhid * N_CTX).to(tl.int64)
    adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64)
    pid = tl.program_id(0)

    # offset pointers for batch/head
    Q += adj
    K += adj
    V += adj
    DO += adj
    DQ += adj
    DK += adj
    DV += adj
    M += off_chz
    D += off_chz

    offs_k = tl.arange(0, BLOCK_DMODEL)

    start_n = pid * BLOCK_N1
    # This assignment is important. It is what allows us to pick the diagonal
    # blocks. Later, when we want to do the lower triangular, we update start_m
    # after the first dkdv call.
    start_m = start_n

    MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
    offs_n = start_n + tl.arange(0, BLOCK_N1)

    dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
    dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)

    K_block_ptr = tl.make_block_ptr(
        base=K,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_tok, stride_d),
        offsets=(start_n, 0),
        block_shape=(BLOCK_N1, BLOCK_DMODEL),
        order=(1, 0),
    )
    V_block_ptr = tl.make_block_ptr(
        base=V,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_tok, stride_d),
        offsets=(start_n, 0),
        block_shape=(BLOCK_N1, BLOCK_DMODEL),
        order=(1, 0),
    )

    # load K and V: they stay in SRAM throughout the inner loop for dkdv.
    k = tl.load(K_block_ptr)
    v = tl.load(V_block_ptr)

    num_steps = BLOCK_N1 // MASK_BLOCK_M1

    dk, dv = _attn_bwd_dkdv(
        dk,
        dv,
        Q,
        k,
        v,
        sm_scale,
        DO,
        M,
        D,
        stride_tok,
        stride_d,
        H,
        N_CTX,
        MASK_BLOCK_M1,
        BLOCK_N1,
        BLOCK_DMODEL,
        start_n,
        start_m,
        num_steps,
        MASK=True,
    )

    start_m += num_steps * MASK_BLOCK_M1
    num_steps = (N_CTX - start_m) // BLOCK_M1

    # Compute dK and dV for non-masked blocks.
    dk, dv = _attn_bwd_dkdv(
        dk,
        dv,
        Q,
        k,
        v,
        sm_scale,
        DO,
        M,
        D,
        stride_tok,
        stride_d,
        H,
        N_CTX,
        BLOCK_M1,
        BLOCK_N1,
        BLOCK_DMODEL,
        start_n,
        start_m,
        num_steps,
        MASK=False,
    )

    DV_block_ptrs = tl.make_block_ptr(
        base=DV,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_tok, stride_d),
        offsets=(start_n, 0),
        block_shape=(BLOCK_N1, BLOCK_DMODEL),
        order=(1, 0),
    )
    tl.store(DV_block_ptrs, dv.to(tl.float16))

    # Write back dK.
    dk *= sm_scale
    DK_block_ptrs = tl.make_block_ptr(
        base=DK,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_tok, stride_d),
        offsets=(start_n, 0),
        block_shape=(BLOCK_N1, BLOCK_DMODEL),
        order=(1, 0),
    )
    tl.store(DK_block_ptrs, dk.to(tl.float16))

    # THIS BLOCK DOES DQ:
    start_m = pid * BLOCK_M2
    end_n = start_m + BLOCK_M2

    MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
    offs_m = start_m + tl.arange(0, BLOCK_M2)

    Q_block_ptr = tl.make_block_ptr(
        base=Q,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_tok, stride_d),
        offsets=(start_m, 0),
        block_shape=(BLOCK_M2, BLOCK_DMODEL),
        order=(1, 0),
    )

    DO_block_ptr = tl.make_block_ptr(
        base=DO,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_tok, stride_d),
        offsets=(start_m, 0),
        block_shape=(BLOCK_M2, BLOCK_DMODEL),
        order=(1, 0),
    )
    q = tl.load(Q_block_ptr)
    do = tl.load(DO_block_ptr)
    dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32)

    m = tl.load(M + offs_m)
    m = m[:, None]

    # Compute dQ for masked (diagonal) blocks.
    # NOTE: This code scans each row of QK^T backward (from right to left,
    # but inside each call to _attn_bwd_dq, from left to right), but that's
    # not due to anything important. I just wanted to reuse the loop
    # structure for dK & dV above as much as possible.
    num_steps = BLOCK_M2 // MASK_BLOCK_N2
    dq = _attn_bwd_dq(
        dq,
        q,
        K,
        V,
        do,
        m,
        D,
        stride_tok,
        stride_d,
        H,
        N_CTX,
        BLOCK_M2,
        MASK_BLOCK_N2,
        BLOCK_DMODEL,
        start_m,
        end_n - num_steps * MASK_BLOCK_N2,
        num_steps,
        MASK=True,
    )
    end_n -= num_steps * MASK_BLOCK_N2
    # stage 2
    num_steps = end_n // BLOCK_N2
    dq = _attn_bwd_dq(
        dq,
        q,
        K,
        V,
        do,
        m,
        D,
        stride_tok,
        stride_d,
        H,
        N_CTX,
        BLOCK_M2,
        BLOCK_N2,
        BLOCK_DMODEL,
        start_m,
        end_n - num_steps * BLOCK_N2,
        num_steps,
        MASK=False,
    )
    # Write back dQ.
    DQ_block_ptr = tl.make_block_ptr(
        base=DQ,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_tok, stride_d),
        offsets=(start_m, 0),
        block_shape=(BLOCK_M2, BLOCK_DMODEL),
        order=(1, 0),
    )
    dq *= LN2
    tl.store(DQ_block_ptr, dq.to(tl.float16))


empty = torch.empty(128, device="cuda")


class _attention(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v, causal, sm_scale):
        # shape constraints
        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
        assert Lq == Lk and Lk == Lv
        assert Lk in {16, 32, 64, 128}
        o = torch.empty_like(q, dtype=v.dtype)
        if torch.version.hip is None:
            BLOCK_M = 128
            BLOCK_N = 64 if Lk <= 64 else 32
            num_stages = 4 if Lk <= 64 else 3
            num_warps = 4 if Lk <= 64 else 8
            # Tuning for H100
            if torch.cuda.get_device_capability()[0] == 9:
                num_warps = 8
                num_stages = 7 if Lk >= 64 else 3
        stage = 3 if causal else 1

        def grid(META):
            return (
                triton.cdiv(q.shape[2], META["BLOCK_M"]),
                q.shape[0] * q.shape[1],
                1,
            )

        M = torch.empty(
            (q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32
        )
        _attn_fwd[grid](
            q,
            k,
            v,
            sm_scale,
            M,
            o,
            q.stride(0),
            q.stride(1),
            q.stride(2),
            q.stride(3),
            k.stride(0),
            k.stride(1),
            k.stride(2),
            k.stride(3),
            v.stride(0),
            v.stride(1),
            v.stride(2),
            v.stride(3),
            o.stride(0),
            o.stride(1),
            o.stride(2),
            o.stride(3),
            q.shape[0],
            q.shape[1],
            N_CTX=q.shape[2],
            BLOCK_DMODEL=Lk,
            STAGE=stage,
        )

        # restore the grid for bwd kernel
        best_config = _attn_fwd.get_best_config()
        block_m = int(best_config.__str__().split(",")[0].split("BLOCK_M:")[1])
        grid = (triton.cdiv(q.shape[2], block_m), q.shape[0] * q.shape[1], 1)

        ctx.save_for_backward(q, k, v, o, M)
        ctx.grid = grid
        ctx.sm_scale = sm_scale
        ctx.BLOCK_DMODEL = Lk
        ctx.causal = causal
        return o

    @staticmethod
    def backward(ctx, do):
        if torch.version.hip is not None:
            BLOCK = 64
        else:
            BLOCK = 128
        q, k, v, o, M = ctx.saved_tensors
        assert do.is_contiguous()
        assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()
        dq = torch.empty_like(q)
        dk = torch.empty_like(k)
        dv = torch.empty_like(v)
        BATCH, N_HEAD, N_CTX = q.shape[:3]
        PRE_BLOCK = 128
        NUM_WARPS, NUM_STAGES = 4, 1
        BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 64, 64, 32
        BLK_SLICE_FACTOR = 2
        RCP_LN2 = 1.4426950408889634  # = 1.0 / ln(2)
        arg_k = k
        arg_k = arg_k * (ctx.sm_scale * RCP_LN2)
        assert N_CTX % PRE_BLOCK == 0
        pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD)
        delta = torch.empty_like(M)
        _attn_bwd_preprocess[pre_grid](
            o,
            do,
            delta,
            BATCH,
            N_HEAD,
            N_CTX,
            BLOCK_M=PRE_BLOCK,
            D_HEAD=ctx.BLOCK_DMODEL,
        )

        def grid(META):
            return (triton.cdiv(N_CTX, META["BLOCK_N1"]), 1, BATCH * N_HEAD)

        _attn_bwd[grid](
            q,
            arg_k,
            v,
            ctx.sm_scale,
            do,
            dq,
            dk,
            dv,
            M,
            delta,
            q.stride(0),
            q.stride(1),
            q.stride(2),
            q.stride(3),
            N_HEAD,
            N_CTX,
            BLOCK_DMODEL=ctx.BLOCK_DMODEL,
        )

        return dq, dk, dv, None, None


attention = _attention.apply
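An editor's sketch of the base-2 exponential trick the kernels rely on (the qk_scale = sm_scale * 1.44269504 scaling in _attn_fwd, undone by the LN2 / RCP_LN2 factors in the backward pass): since exp(y) == 2 ** (y * log2(e)), folding log2(e) into the scale lets the kernel use exp2 while still computing an ordinary softmax. The snippet below is illustrative only; the tensor and scale values are arbitrary.

import torch

x = torch.randn(8)
s = 0.125                      # stand-in for sm_scale, arbitrary
qk_scale = s * 1.44269504      # same log2(e) constant as in _attn_fwd
ref = torch.softmax(s * x, dim=-1)
p = torch.exp2(qk_scale * x - (qk_scale * x).max())
alt = p / p.sum()
print(torch.allclose(ref, alt, atol=1e-6))  # expected: True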
 
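For reference, a minimal usage sketch of the exported attention autograd function. This is an illustration, not part of the commit: the shapes, dtype, and scale below are assumptions (4-D fp16 CUDA tensors of shape (batch, heads, seq_len, head_dim), head_dim in {16, 32, 64, 128}, and seq_len a multiple of 128 so the backward preprocess grid divides evenly).

import torch
# from triton_flash_atn import attention  # assumed import path for this file

batch, heads, seq_len, head_dim = 2, 4, 1024, 64
q = torch.randn(batch, heads, seq_len, head_dim,
                device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

sm_scale = head_dim ** -0.5                 # 1/sqrt(d) softmax scaling
out = attention(q, k, v, True, sm_scale)    # causal=True
out.sum().backward()                        # populates q.grad, k.grad, v.grad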