radna commited on
Commit
82b99e4
1 Parent(s): 2a742fb

Update triton_flash_atn.py

Browse files
Files changed (1) hide show
  1. triton_flash_atn.py +184 -493
triton_flash_atn.py CHANGED
@@ -25,27 +25,17 @@ import triton.language as tl
25
  # TORCH_HAS_FP8E4B8 = hasattr(torch, 'float8_e4m3fnuz')
26
 
27
  # AMD E5M2B16
28
- TORCH_HAS_FP8E5B16 = hasattr(torch, "float8_e5m2fnuz")
29
 
30
 
31
  @triton.jit
32
- def _attn_fwd_inner(
33
- acc,
34
- l_i,
35
- m_i,
36
- q,
37
- K_block_ptr,
38
- V_block_ptr,
39
- start_m,
40
- BLOCK_M: tl.constexpr,
41
- BLOCK_DMODEL: tl.constexpr,
42
- BLOCK_N: tl.constexpr,
43
- STAGE: tl.constexpr,
44
- offs_m: tl.constexpr,
45
- offs_n: tl.constexpr,
46
- N_CTX,
47
- pre_load_v: tl.constexpr,
48
- ):
49
  # range of values handled by this stage
50
  if STAGE == 1:
51
  lo, hi = 0, start_m * BLOCK_M
@@ -93,119 +83,37 @@ def _attn_fwd_inner(
93
  # re-tuning.
94
  @triton.autotune(
95
  configs=[
96
- triton.Config(
97
- {
98
- "BLOCK_M": 64,
99
- "BLOCK_N": 16,
100
- "waves_per_eu": 2,
101
- "slice_k_tile": 0,
102
- "pre_load_v": False,
103
- },
104
- num_stages=1,
105
- num_warps=2,
106
- ),
107
- triton.Config(
108
- {
109
- "BLOCK_M": 64,
110
- "BLOCK_N": 16,
111
- "waves_per_eu": 2,
112
- "slice_k_tile": 32,
113
- "pre_load_v": False,
114
- },
115
- num_stages=1,
116
- num_warps=2,
117
- ),
118
- triton.Config(
119
- {
120
- "BLOCK_M": 32,
121
- "BLOCK_N": 32,
122
- "waves_per_eu": 2,
123
- "slice_k_tile": 0,
124
- "pre_load_v": False,
125
- },
126
- num_stages=1,
127
- num_warps=1,
128
- ),
129
- triton.Config(
130
- {
131
- "BLOCK_M": 32,
132
- "BLOCK_N": 32,
133
- "waves_per_eu": 2,
134
- "slice_k_tile": 32,
135
- "pre_load_v": False,
136
- },
137
- num_stages=1,
138
- num_warps=1,
139
- ),
140
- triton.Config(
141
- {
142
- "BLOCK_M": 64,
143
- "BLOCK_N": 32,
144
- "waves_per_eu": 2,
145
- "slice_k_tile": 0,
146
- "pre_load_v": False,
147
- },
148
- num_stages=1,
149
- num_warps=2,
150
- ),
151
- triton.Config(
152
- {
153
- "BLOCK_M": 32,
154
- "BLOCK_N": 16,
155
- "waves_per_eu": 3,
156
- "slice_k_tile": 0,
157
- "pre_load_v": True,
158
- },
159
- num_stages=1,
160
- num_warps=1,
161
- ),
162
- triton.Config(
163
- {
164
- "BLOCK_M": 32,
165
- "BLOCK_N": 16,
166
- "waves_per_eu": 3,
167
- "slice_k_tile": 0,
168
- "pre_load_v": False,
169
- },
170
- num_stages=1,
171
- num_warps=1,
172
- ),
173
  ],
174
- key=["Z", "H", "N_CTX", "STAGE", "BLOCK_DMODEL"],
175
  )
176
  @triton.jit
177
- def _attn_fwd(
178
- Q,
179
- K,
180
- V,
181
- sm_scale,
182
- M,
183
- Out,
184
- stride_qz,
185
- stride_qh,
186
- stride_qm,
187
- stride_qk,
188
- stride_kz,
189
- stride_kh,
190
- stride_kn,
191
- stride_kk,
192
- stride_vz,
193
- stride_vh,
194
- stride_vk,
195
- stride_vn,
196
- stride_oz,
197
- stride_oh,
198
- stride_om,
199
- stride_on,
200
- Z,
201
- H,
202
- N_CTX,
203
- BLOCK_DMODEL: tl.constexpr,
204
- STAGE: tl.constexpr,
205
- BLOCK_M: tl.constexpr,
206
- BLOCK_N: tl.constexpr,
207
- pre_load_v: tl.constexpr,
208
- ):
209
  start_m = tl.program_id(0)
210
  off_hz = tl.program_id(1)
211
  qvk_offset = off_hz * stride_qh
@@ -261,45 +169,23 @@ def _attn_fwd(
261
  # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
262
  # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
263
  if STAGE & 1:
264
- acc, l_i, m_i = _attn_fwd_inner(
265
- acc,
266
- l_i,
267
- m_i,
268
- q,
269
- K_block_ptr,
270
- V_block_ptr,
271
- start_m,
272
- BLOCK_M,
273
- BLOCK_DMODEL,
274
- BLOCK_N,
275
- 4 - STAGE,
276
- offs_m,
277
- offs_n,
278
- N_CTX,
279
- pre_load_v,
280
- )
281
  # stage 2: on-band
282
  if STAGE & 2:
283
  # barrier makes it easier for compielr to schedule the
284
  # two loops independently
285
  tl.debug_barrier()
286
- acc, l_i, m_i = _attn_fwd_inner(
287
- acc,
288
- l_i,
289
- m_i,
290
- q,
291
- K_block_ptr,
292
- V_block_ptr,
293
- start_m,
294
- BLOCK_M,
295
- BLOCK_DMODEL,
296
- BLOCK_N,
297
- 2,
298
- offs_m,
299
- offs_n,
300
- N_CTX,
301
- pre_load_v,
302
- )
303
  # epilogue
304
  # write back m
305
  acc = acc / l_i[:, None]
@@ -309,46 +195,36 @@ def _attn_fwd(
309
 
310
 
311
  @triton.jit
312
- def _attn_bwd_preprocess(
313
- O, DO, Delta, Z, H, N_CTX, BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr
314
- ):
 
 
315
  off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
316
  off_hz = tl.program_id(1)
317
  off_n = tl.arange(0, D_HEAD)
318
- o = tl.load(O + off_hz * D_HEAD * N_CTX + off_m[:, None] * D_HEAD + off_n[None, :])
319
- do = tl.load(
320
- DO + off_hz * D_HEAD * N_CTX + off_m[:, None] * D_HEAD + off_n[None, :]
321
- ).to(tl.float32)
322
  delta = tl.sum(o * do, axis=1)
323
  tl.store(Delta + off_hz * N_CTX + off_m, delta)
324
 
325
 
326
  # The main inner-loop logic for computing dK and dV.
327
  @triton.jit
328
- def _attn_bwd_dkdv(
329
- dk,
330
- dv,
331
- Q,
332
- k,
333
- v,
334
- sm_scale,
335
- DO,
336
- M,
337
- D,
338
- # shared by Q/K/V/DO.
339
- stride_tok,
340
- stride_d,
341
- H,
342
- N_CTX,
343
- BLOCK_M1: tl.constexpr,
344
- BLOCK_N1: tl.constexpr,
345
- BLOCK_DMODEL: tl.constexpr,
346
- # Filled in by the wrapper.
347
- start_n,
348
- start_m,
349
- num_steps,
350
- MASK: tl.constexpr,
351
- ):
352
  offs_m = start_m + tl.arange(0, BLOCK_M1)
353
  offs_n = start_n + tl.arange(0, BLOCK_N1)
354
  offs_k = tl.arange(0, BLOCK_DMODEL)
@@ -358,7 +234,7 @@ def _attn_bwd_dkdv(
358
  strides=(stride_d, stride_tok),
359
  offsets=(0, start_m),
360
  block_shape=(BLOCK_DMODEL, BLOCK_M1),
361
- order=(0, 1),
362
  )
363
  DO_block_ptr = tl.make_block_ptr(
364
  base=DO,
@@ -366,7 +242,7 @@ def _attn_bwd_dkdv(
366
  strides=(stride_tok, stride_d),
367
  offsets=(start_m, 0),
368
  block_shape=(BLOCK_M1, BLOCK_DMODEL),
369
- order=(1, 0),
370
  )
371
  # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
372
  tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
@@ -381,7 +257,7 @@ def _attn_bwd_dkdv(
381
  pT = tl.math.exp2(qkT - m[None, :])
382
  # Autoregressive masking.
383
  if MASK:
384
- mask = offs_m[None, :] >= offs_n[:, None]
385
  pT = tl.where(mask, pT, 0.0)
386
  do = tl.load(DO_block_ptr)
387
  # Compute dV.
@@ -404,28 +280,17 @@ def _attn_bwd_dkdv(
404
 
405
  # the main inner-loop logic for computing dQ
406
  @triton.jit
407
- def _attn_bwd_dq(
408
- dq,
409
- q,
410
- K,
411
- V,
412
- do,
413
- m,
414
- D,
415
- # shared by Q/K/V/DO.
416
- stride_tok,
417
- stride_d,
418
- H,
419
- N_CTX,
420
- BLOCK_M2: tl.constexpr,
421
- BLOCK_N2: tl.constexpr,
422
- BLOCK_DMODEL: tl.constexpr,
423
- # Filled in by the wrapper.
424
- start_m,
425
- start_n,
426
- num_steps,
427
- MASK: tl.constexpr,
428
- ):
429
  offs_m = start_m + tl.arange(0, BLOCK_M2)
430
  offs_n = start_n + tl.arange(0, BLOCK_N2)
431
  offs_k = tl.arange(0, BLOCK_DMODEL)
@@ -435,7 +300,7 @@ def _attn_bwd_dq(
435
  strides=(stride_d, stride_tok),
436
  offsets=(0, start_n),
437
  block_shape=(BLOCK_DMODEL, BLOCK_N2),
438
- order=(0, 1),
439
  )
440
  VT_block_ptr = tl.make_block_ptr(
441
  base=V,
@@ -443,7 +308,7 @@ def _attn_bwd_dq(
443
  strides=(stride_d, stride_tok),
444
  offsets=(0, start_n),
445
  block_shape=(BLOCK_DMODEL, BLOCK_N2),
446
- order=(0, 1),
447
  )
448
  # D (= delta) is pre-divided by ds_scale.
449
  Di = tl.load(D + offs_m)
@@ -458,7 +323,7 @@ def _attn_bwd_dq(
458
  # Autoregressive masking.
459
  if MASK:
460
  offs_n = curr_n + tl.arange(0, BLOCK_N2)
461
- mask = offs_m[:, None] >= offs_n[None, :]
462
  p = tl.where(mask, p, 0.0)
463
  # Compute dP and dS.
464
  vT = tl.load(VT_block_ptr)
@@ -477,135 +342,42 @@ def _attn_bwd_dq(
477
 
478
  @triton.autotune(
479
  configs=[
480
- triton.Config(
481
- {
482
- "BLOCK_M1": 32,
483
- "BLOCK_N1": 64,
484
- "BLOCK_M2": 64,
485
- "BLOCK_N2": 32,
486
- "BLK_SLICE_FACTOR": 1,
487
- },
488
- num_stages=1,
489
- num_warps=4,
490
- ),
491
- triton.Config(
492
- {
493
- "BLOCK_M1": 32,
494
- "BLOCK_N1": 64,
495
- "BLOCK_M2": 64,
496
- "BLOCK_N2": 32,
497
- "BLK_SLICE_FACTOR": 2,
498
- },
499
- num_stages=1,
500
- num_warps=4,
501
- ),
502
- triton.Config(
503
- {
504
- "BLOCK_M1": 64,
505
- "BLOCK_N1": 128,
506
- "BLOCK_M2": 128,
507
- "BLOCK_N2": 64,
508
- "BLK_SLICE_FACTOR": 1,
509
- },
510
- num_stages=1,
511
- num_warps=4,
512
- ),
513
- triton.Config(
514
- {
515
- "BLOCK_M1": 64,
516
- "BLOCK_N1": 128,
517
- "BLOCK_M2": 128,
518
- "BLOCK_N2": 64,
519
- "BLK_SLICE_FACTOR": 2,
520
- },
521
- num_stages=1,
522
- num_warps=4,
523
- ),
524
- triton.Config(
525
- {
526
- "BLOCK_M1": 64,
527
- "BLOCK_N1": 64,
528
- "BLOCK_M2": 64,
529
- "BLOCK_N2": 64,
530
- "BLK_SLICE_FACTOR": 1,
531
- },
532
- num_stages=1,
533
- num_warps=4,
534
- ),
535
- triton.Config(
536
- {
537
- "BLOCK_M1": 64,
538
- "BLOCK_N1": 64,
539
- "BLOCK_M2": 64,
540
- "BLOCK_N2": 64,
541
- "BLK_SLICE_FACTOR": 2,
542
- },
543
- num_stages=1,
544
- num_warps=4,
545
- ),
546
- triton.Config(
547
- {
548
- "BLOCK_M1": 32,
549
- "BLOCK_N1": 128,
550
- "BLOCK_M2": 128,
551
- "BLOCK_N2": 32,
552
- "BLK_SLICE_FACTOR": 1,
553
- },
554
- num_stages=1,
555
- num_warps=4,
556
- ),
557
- triton.Config(
558
- {
559
- "BLOCK_M1": 32,
560
- "BLOCK_N1": 128,
561
- "BLOCK_M2": 128,
562
- "BLOCK_N2": 32,
563
- "BLK_SLICE_FACTOR": 2,
564
- },
565
- num_stages=1,
566
- num_warps=4,
567
- ),
568
- triton.Config(
569
- {
570
- "BLOCK_M1": 32,
571
- "BLOCK_N1": 128,
572
- "BLOCK_M2": 128,
573
- "BLOCK_N2": 32,
574
- "BLK_SLICE_FACTOR": 2,
575
- },
576
- num_stages=1,
577
- num_warps=8,
578
- ),
579
  ],
580
- key=["H", "N_CTX", "BLOCK_DMODEL"],
581
  )
582
  @triton.jit
583
- def _attn_bwd(
584
- Q,
585
- K,
586
- V,
587
- sm_scale,
588
- DO,
589
- DQ,
590
- DK,
591
- DV,
592
- M,
593
- D,
594
- # shared by Q/K/V/DO.
595
- stride_z,
596
- stride_h,
597
- stride_tok,
598
- stride_d,
599
- # H = 16, N_CTX = 1024
600
- H,
601
- N_CTX,
602
- BLOCK_DMODEL: tl.constexpr,
603
- BLOCK_M1: tl.constexpr,
604
- BLOCK_N1: tl.constexpr,
605
- BLOCK_M2: tl.constexpr,
606
- BLOCK_N2: tl.constexpr,
607
- BLK_SLICE_FACTOR: tl.constexpr,
608
- ):
609
  LN2: tl.constexpr = 0.6931471824645996 # = ln(2)
610
 
611
  bhid = tl.program_id(2)
@@ -661,54 +433,31 @@ def _attn_bwd(
661
 
662
  num_steps = BLOCK_N1 // MASK_BLOCK_M1
663
 
664
- dk, dv = _attn_bwd_dkdv(
665
- dk,
666
- dv,
667
- Q,
668
- k,
669
- v,
670
- sm_scale,
671
- DO,
672
- M,
673
- D,
674
- stride_tok,
675
- stride_d,
676
- H,
677
- N_CTX,
678
- MASK_BLOCK_M1,
679
- BLOCK_N1,
680
- BLOCK_DMODEL,
681
- start_n,
682
- start_m,
683
- num_steps,
684
- MASK=True,
685
- )
686
 
687
  start_m += num_steps * MASK_BLOCK_M1
688
  num_steps = (N_CTX - start_m) // BLOCK_M1
689
 
690
  # Compute dK and dV for non-masked blocks.
691
  dk, dv = _attn_bwd_dkdv(
692
- dk,
693
- dv,
694
- Q,
695
- k,
696
- v,
697
- sm_scale,
698
  DO,
699
- M,
700
- D,
701
- stride_tok,
702
- stride_d,
703
- H,
704
- N_CTX,
705
- BLOCK_M1,
706
- BLOCK_N1,
707
- BLOCK_DMODEL,
708
- start_n,
709
- start_m,
710
- num_steps,
711
- MASK=False,
712
  )
713
 
714
  DV_block_ptrs = tl.make_block_ptr(
@@ -717,7 +466,7 @@ def _attn_bwd(
717
  strides=(stride_tok, stride_d),
718
  offsets=(start_n, 0),
719
  block_shape=(BLOCK_N1, BLOCK_DMODEL),
720
- order=(1, 0),
721
  )
722
  tl.store(DV_block_ptrs, dv.to(tl.float16))
723
 
@@ -729,7 +478,7 @@ def _attn_bwd(
729
  strides=(stride_tok, stride_d),
730
  offsets=(start_n, 0),
731
  block_shape=(BLOCK_N1, BLOCK_DMODEL),
732
- order=(1, 0),
733
  )
734
  tl.store(DK_block_ptrs, dk.to(tl.float16))
735
 
@@ -746,7 +495,7 @@ def _attn_bwd(
746
  strides=(stride_tok, stride_d),
747
  offsets=(start_m, 0),
748
  block_shape=(BLOCK_M2, BLOCK_DMODEL),
749
- order=(1, 0),
750
  )
751
 
752
  DO_block_ptr = tl.make_block_ptr(
@@ -755,7 +504,7 @@ def _attn_bwd(
755
  strides=(stride_tok, stride_d),
756
  offsets=(start_m, 0),
757
  block_shape=(BLOCK_M2, BLOCK_DMODEL),
758
- order=(1, 0),
759
  )
760
  q = tl.load(Q_block_ptr)
761
  do = tl.load(DO_block_ptr)
@@ -770,49 +519,25 @@ def _attn_bwd(
770
  # not due to anything important. I just wanted to reuse the loop
771
  # structure for dK & dV above as much as possible.
772
  num_steps = BLOCK_M2 // MASK_BLOCK_N2
773
- dq = _attn_bwd_dq(
774
- dq,
775
- q,
776
- K,
777
- V,
778
- do,
779
- m,
780
- D,
781
- stride_tok,
782
- stride_d,
783
- H,
784
- N_CTX,
785
- BLOCK_M2,
786
- MASK_BLOCK_N2,
787
- BLOCK_DMODEL,
788
- start_m,
789
- end_n - num_steps * MASK_BLOCK_N2,
790
- num_steps,
791
- MASK=True,
792
- )
793
  end_n -= num_steps * MASK_BLOCK_N2
794
  # stage 2
795
  num_steps = end_n // BLOCK_N2
796
- dq = _attn_bwd_dq(
797
- dq,
798
- q,
799
- K,
800
- V,
801
- do,
802
- m,
803
- D,
804
- stride_tok,
805
- stride_d,
806
- H,
807
- N_CTX,
808
- BLOCK_M2,
809
- BLOCK_N2,
810
- BLOCK_DMODEL,
811
- start_m,
812
- end_n - num_steps * BLOCK_N2,
813
- num_steps,
814
- MASK=False,
815
- )
816
  # Write back dQ.
817
  DQ_block_ptr = tl.make_block_ptr(
818
  base=DQ,
@@ -820,7 +545,7 @@ def _attn_bwd(
820
  strides=(stride_tok, stride_d),
821
  offsets=(start_m, 0),
822
  block_shape=(BLOCK_M2, BLOCK_DMODEL),
823
- order=(1, 0),
824
  )
825
  dq *= LN2
826
  tl.store(DQ_block_ptr, dq.to(tl.float16))
@@ -849,41 +574,20 @@ class _attention(torch.autograd.Function):
849
  num_stages = 7 if Lk >= 64 else 3
850
  stage = 3 if causal else 1
851
 
852
- def grid(META):
853
- return (
854
- triton.cdiv(q.shape[2], META["BLOCK_M"]),
855
- q.shape[0] * q.shape[1],
856
- 1,
857
- )
858
-
859
- M = torch.empty(
860
- (q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32
861
  )
 
 
862
  _attn_fwd[grid](
863
- q,
864
- k,
865
- v,
866
- sm_scale,
867
- M,
868
- o,
869
- q.stride(0),
870
- q.stride(1),
871
- q.stride(2),
872
- q.stride(3),
873
- k.stride(0),
874
- k.stride(1),
875
- k.stride(2),
876
- k.stride(3),
877
- v.stride(0),
878
- v.stride(1),
879
- v.stride(2),
880
- v.stride(3),
881
- o.stride(0),
882
- o.stride(1),
883
- o.stride(2),
884
- o.stride(3),
885
- q.shape[0],
886
- q.shape[1],
887
  N_CTX=q.shape[2],
888
  BLOCK_DMODEL=Lk,
889
  STAGE=stage,
@@ -925,39 +629,26 @@ class _attention(torch.autograd.Function):
925
  pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD)
926
  delta = torch.empty_like(M)
927
  _attn_bwd_preprocess[pre_grid](
928
- o,
929
- do,
930
  delta,
931
- BATCH,
932
- N_HEAD,
933
- N_CTX,
934
- BLOCK_M=PRE_BLOCK,
935
- D_HEAD=ctx.BLOCK_DMODEL,
936
  )
937
 
938
- def grid(META):
939
- return (triton.cdiv(N_CTX, META["BLOCK_N1"]), 1, BATCH * N_HEAD)
940
-
 
 
941
  _attn_bwd[grid](
942
- q,
943
- arg_k,
944
- v,
945
- ctx.sm_scale,
946
- do,
947
- dq,
948
- dk,
949
- dv,
950
- M,
951
- delta,
952
- q.stride(0),
953
- q.stride(1),
954
- q.stride(2),
955
- q.stride(3),
956
- N_HEAD,
957
- N_CTX,
958
- BLOCK_DMODEL=ctx.BLOCK_DMODEL,
959
  )
960
 
961
  return dq, dk, dv, None, None
962
 
963
 
 
 
25
  # TORCH_HAS_FP8E4B8 = hasattr(torch, 'float8_e4m3fnuz')
26
 
27
  # AMD E5M2B16
28
+ TORCH_HAS_FP8E5B16 = hasattr(torch, 'float8_e5m2fnuz')
29
 
30
 
31
  @triton.jit
32
+ def _attn_fwd_inner(acc, l_i, m_i, q,
33
+ K_block_ptr, V_block_ptr,
34
+ start_m,
35
+ BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
36
+ STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr,
37
+ N_CTX,
38
+ pre_load_v: tl.constexpr):
 
 
 
 
 
 
 
 
 
 
39
  # range of values handled by this stage
40
  if STAGE == 1:
41
  lo, hi = 0, start_m * BLOCK_M
 
83
  # re-tuning.
84
  @triton.autotune(
85
  configs=[
86
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 16, 'waves_per_eu': 2,
87
+ 'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=2),
88
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 16, 'waves_per_eu': 2,
89
+ 'slice_k_tile': 32, 'pre_load_v': False}, num_stages=1, num_warps=2),
90
+ triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 2,
91
+ 'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=1),
92
+ triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 2,
93
+ 'slice_k_tile': 32, 'pre_load_v': False}, num_stages=1, num_warps=1),
94
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'waves_per_eu': 2,
95
+ 'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=2),
96
+ triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 3,
97
+ 'slice_k_tile': 0, 'pre_load_v': True}, num_stages=1, num_warps=1),
98
+ triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 3,
99
+ 'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=1),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  ],
101
+ key=['Z', 'H', 'N_CTX', 'STAGE', 'BLOCK_DMODEL'],
102
  )
103
  @triton.jit
104
+ def _attn_fwd(Q, K, V, sm_scale, M, Out,
105
+ stride_qz, stride_qh, stride_qm, stride_qk,
106
+ stride_kz, stride_kh, stride_kn, stride_kk,
107
+ stride_vz, stride_vh, stride_vk, stride_vn,
108
+ stride_oz, stride_oh, stride_om, stride_on,
109
+ Z, H,
110
+ N_CTX,
111
+ BLOCK_DMODEL: tl.constexpr,
112
+ STAGE: tl.constexpr,
113
+ BLOCK_M: tl.constexpr,
114
+ BLOCK_N: tl.constexpr,
115
+ pre_load_v: tl.constexpr,
116
+ ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  start_m = tl.program_id(0)
118
  off_hz = tl.program_id(1)
119
  qvk_offset = off_hz * stride_qh
 
169
  # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
170
  # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
171
  if STAGE & 1:
172
+ acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
173
+ start_m,
174
+ BLOCK_M, BLOCK_DMODEL, BLOCK_N,
175
+ 4 - STAGE, offs_m, offs_n, N_CTX,
176
+ pre_load_v,
177
+ )
 
 
 
 
 
 
 
 
 
 
 
178
  # stage 2: on-band
179
  if STAGE & 2:
180
  # barrier makes it easier for compielr to schedule the
181
  # two loops independently
182
  tl.debug_barrier()
183
+ acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
184
+ start_m,
185
+ BLOCK_M, BLOCK_DMODEL, BLOCK_N,
186
+ 2, offs_m, offs_n, N_CTX,
187
+ pre_load_v,
188
+ )
 
 
 
 
 
 
 
 
 
 
 
189
  # epilogue
190
  # write back m
191
  acc = acc / l_i[:, None]
 
195
 
196
 
197
  @triton.jit
198
+ def _attn_bwd_preprocess(O, DO,
199
+ Delta,
200
+ Z, H, N_CTX,
201
+ BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr
202
+ ):
203
  off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
204
  off_hz = tl.program_id(1)
205
  off_n = tl.arange(0, D_HEAD)
206
+ o = tl.load(O + off_hz * D_HEAD * N_CTX +
207
+ off_m[:, None] * D_HEAD + off_n[None, :])
208
+ do = tl.load(DO + off_hz * D_HEAD * N_CTX +
209
+ off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
210
  delta = tl.sum(o * do, axis=1)
211
  tl.store(Delta + off_hz * N_CTX + off_m, delta)
212
 
213
 
214
  # The main inner-loop logic for computing dK and dV.
215
  @triton.jit
216
+ def _attn_bwd_dkdv(dk, dv,
217
+ Q, k, v, sm_scale,
218
+ DO,
219
+ M, D,
220
+ # shared by Q/K/V/DO.
221
+ stride_tok, stride_d,
222
+ H, N_CTX, BLOCK_M1: tl.constexpr,
223
+ BLOCK_N1: tl.constexpr,
224
+ BLOCK_DMODEL: tl.constexpr,
225
+ # Filled in by the wrapper.
226
+ start_n, start_m, num_steps,
227
+ MASK: tl.constexpr):
 
 
 
 
 
 
 
 
 
 
 
 
228
  offs_m = start_m + tl.arange(0, BLOCK_M1)
229
  offs_n = start_n + tl.arange(0, BLOCK_N1)
230
  offs_k = tl.arange(0, BLOCK_DMODEL)
 
234
  strides=(stride_d, stride_tok),
235
  offsets=(0, start_m),
236
  block_shape=(BLOCK_DMODEL, BLOCK_M1),
237
+ order=(0, 1)
238
  )
239
  DO_block_ptr = tl.make_block_ptr(
240
  base=DO,
 
242
  strides=(stride_tok, stride_d),
243
  offsets=(start_m, 0),
244
  block_shape=(BLOCK_M1, BLOCK_DMODEL),
245
+ order=(1, 0)
246
  )
247
  # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
248
  tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
 
257
  pT = tl.math.exp2(qkT - m[None, :])
258
  # Autoregressive masking.
259
  if MASK:
260
+ mask = (offs_m[None, :] >= offs_n[:, None])
261
  pT = tl.where(mask, pT, 0.0)
262
  do = tl.load(DO_block_ptr)
263
  # Compute dV.
 
280
 
281
  # the main inner-loop logic for computing dQ
282
  @triton.jit
283
+ def _attn_bwd_dq(dq, q, K, V,
284
+ do, m, D,
285
+ # shared by Q/K/V/DO.
286
+ stride_tok, stride_d,
287
+ H, N_CTX,
288
+ BLOCK_M2: tl.constexpr,
289
+ BLOCK_N2: tl.constexpr,
290
+ BLOCK_DMODEL: tl.constexpr,
291
+ # Filled in by the wrapper.
292
+ start_m, start_n, num_steps,
293
+ MASK: tl.constexpr):
 
 
 
 
 
 
 
 
 
 
 
294
  offs_m = start_m + tl.arange(0, BLOCK_M2)
295
  offs_n = start_n + tl.arange(0, BLOCK_N2)
296
  offs_k = tl.arange(0, BLOCK_DMODEL)
 
300
  strides=(stride_d, stride_tok),
301
  offsets=(0, start_n),
302
  block_shape=(BLOCK_DMODEL, BLOCK_N2),
303
+ order=(0, 1)
304
  )
305
  VT_block_ptr = tl.make_block_ptr(
306
  base=V,
 
308
  strides=(stride_d, stride_tok),
309
  offsets=(0, start_n),
310
  block_shape=(BLOCK_DMODEL, BLOCK_N2),
311
+ order=(0, 1)
312
  )
313
  # D (= delta) is pre-divided by ds_scale.
314
  Di = tl.load(D + offs_m)
 
323
  # Autoregressive masking.
324
  if MASK:
325
  offs_n = curr_n + tl.arange(0, BLOCK_N2)
326
+ mask = (offs_m[:, None] >= offs_n[None, :])
327
  p = tl.where(mask, p, 0.0)
328
  # Compute dP and dS.
329
  vT = tl.load(VT_block_ptr)
 
342
 
343
  @triton.autotune(
344
  configs=[
345
+ triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 1},
346
+ num_stages=1, num_warps=4),
347
+ triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 2},
348
+ num_stages=1, num_warps=4),
349
+ triton.Config({'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'BLK_SLICE_FACTOR': 1},
350
+ num_stages=1, num_warps=4),
351
+ triton.Config({'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'BLK_SLICE_FACTOR': 2},
352
+ num_stages=1, num_warps=4),
353
+ triton.Config({'BLOCK_M1': 64, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 64, 'BLK_SLICE_FACTOR': 1},
354
+ num_stages=1, num_warps=4),
355
+ triton.Config({'BLOCK_M1': 64, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 64, 'BLK_SLICE_FACTOR': 2},
356
+ num_stages=1, num_warps=4),
357
+ triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 1},
358
+ num_stages=1, num_warps=4),
359
+ triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 2},
360
+ num_stages=1, num_warps=4),
361
+ triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 2},
362
+ num_stages=1, num_warps=8),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
363
  ],
364
+ key=['H', 'N_CTX', 'BLOCK_DMODEL'],
365
  )
366
  @triton.jit
367
+ def _attn_bwd(Q, K, V, sm_scale,
368
+ DO,
369
+ DQ, DK, DV,
370
+ M, D,
371
+ # shared by Q/K/V/DO.
372
+ stride_z, stride_h, stride_tok, stride_d,
373
+ # H = 16, N_CTX = 1024
374
+ H, N_CTX,
375
+ BLOCK_DMODEL: tl.constexpr,
376
+ BLOCK_M1: tl.constexpr,
377
+ BLOCK_N1: tl.constexpr,
378
+ BLOCK_M2: tl.constexpr,
379
+ BLOCK_N2: tl.constexpr,
380
+ BLK_SLICE_FACTOR: tl.constexpr):
 
 
 
 
 
 
 
 
 
 
 
 
381
  LN2: tl.constexpr = 0.6931471824645996 # = ln(2)
382
 
383
  bhid = tl.program_id(2)
 
433
 
434
  num_steps = BLOCK_N1 // MASK_BLOCK_M1
435
 
436
+ dk, dv = _attn_bwd_dkdv(dk, dv,
437
+ Q, k, v, sm_scale,
438
+ DO,
439
+ M, D,
440
+ stride_tok, stride_d,
441
+ H, N_CTX,
442
+ MASK_BLOCK_M1, BLOCK_N1, BLOCK_DMODEL,
443
+ start_n, start_m, num_steps,
444
+ MASK=True
445
+ )
 
 
 
 
 
 
 
 
 
 
 
 
446
 
447
  start_m += num_steps * MASK_BLOCK_M1
448
  num_steps = (N_CTX - start_m) // BLOCK_M1
449
 
450
  # Compute dK and dV for non-masked blocks.
451
  dk, dv = _attn_bwd_dkdv(
452
+ dk, dv,
453
+ Q, k, v, sm_scale,
 
 
 
 
454
  DO,
455
+ M, D,
456
+ stride_tok, stride_d,
457
+ H, N_CTX,
458
+ BLOCK_M1, BLOCK_N1, BLOCK_DMODEL,
459
+ start_n, start_m, num_steps,
460
+ MASK=False
 
 
 
 
 
 
 
461
  )
462
 
463
  DV_block_ptrs = tl.make_block_ptr(
 
466
  strides=(stride_tok, stride_d),
467
  offsets=(start_n, 0),
468
  block_shape=(BLOCK_N1, BLOCK_DMODEL),
469
+ order=(1, 0)
470
  )
471
  tl.store(DV_block_ptrs, dv.to(tl.float16))
472
 
 
478
  strides=(stride_tok, stride_d),
479
  offsets=(start_n, 0),
480
  block_shape=(BLOCK_N1, BLOCK_DMODEL),
481
+ order=(1, 0)
482
  )
483
  tl.store(DK_block_ptrs, dk.to(tl.float16))
484
 
 
495
  strides=(stride_tok, stride_d),
496
  offsets=(start_m, 0),
497
  block_shape=(BLOCK_M2, BLOCK_DMODEL),
498
+ order=(1, 0)
499
  )
500
 
501
  DO_block_ptr = tl.make_block_ptr(
 
504
  strides=(stride_tok, stride_d),
505
  offsets=(start_m, 0),
506
  block_shape=(BLOCK_M2, BLOCK_DMODEL),
507
+ order=(1, 0)
508
  )
509
  q = tl.load(Q_block_ptr)
510
  do = tl.load(DO_block_ptr)
 
519
  # not due to anything important. I just wanted to reuse the loop
520
  # structure for dK & dV above as much as possible.
521
  num_steps = BLOCK_M2 // MASK_BLOCK_N2
522
+ dq = _attn_bwd_dq(dq, q, K, V,
523
+ do, m, D,
524
+ stride_tok, stride_d,
525
+ H, N_CTX,
526
+ BLOCK_M2, MASK_BLOCK_N2, BLOCK_DMODEL,
527
+ start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps,
528
+ MASK=True
529
+ )
 
 
 
 
 
 
 
 
 
 
 
 
530
  end_n -= num_steps * MASK_BLOCK_N2
531
  # stage 2
532
  num_steps = end_n // BLOCK_N2
533
+ dq = _attn_bwd_dq(dq, q, K, V,
534
+ do, m, D,
535
+ stride_tok, stride_d,
536
+ H, N_CTX,
537
+ BLOCK_M2, BLOCK_N2, BLOCK_DMODEL,
538
+ start_m, end_n - num_steps * BLOCK_N2, num_steps,
539
+ MASK=False
540
+ )
 
 
 
 
 
 
 
 
 
 
 
 
541
  # Write back dQ.
542
  DQ_block_ptr = tl.make_block_ptr(
543
  base=DQ,
 
545
  strides=(stride_tok, stride_d),
546
  offsets=(start_m, 0),
547
  block_shape=(BLOCK_M2, BLOCK_DMODEL),
548
+ order=(1, 0)
549
  )
550
  dq *= LN2
551
  tl.store(DQ_block_ptr, dq.to(tl.float16))
 
574
  num_stages = 7 if Lk >= 64 else 3
575
  stage = 3 if causal else 1
576
 
577
+ def grid(META): return (
578
+ triton.cdiv(q.shape[2], META['BLOCK_M']),
579
+ q.shape[0] * q.shape[1],
580
+ 1
 
 
 
 
 
581
  )
582
+ M = torch.empty((q.shape[0] * q.shape[1], q.shape[2]),
583
+ device=q.device, dtype=torch.float32)
584
  _attn_fwd[grid](
585
+ q, k, v, sm_scale, M, o,
586
+ q.stride(0), q.stride(1), q.stride(2), q.stride(3),
587
+ k.stride(0), k.stride(1), k.stride(2), k.stride(3),
588
+ v.stride(0), v.stride(1), v.stride(2), v.stride(3),
589
+ o.stride(0), o.stride(1), o.stride(2), o.stride(3),
590
+ q.shape[0], q.shape[1],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
591
  N_CTX=q.shape[2],
592
  BLOCK_DMODEL=Lk,
593
  STAGE=stage,
 
629
  pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD)
630
  delta = torch.empty_like(M)
631
  _attn_bwd_preprocess[pre_grid](
632
+ o, do,
 
633
  delta,
634
+ BATCH, N_HEAD, N_CTX,
635
+ BLOCK_M=PRE_BLOCK, D_HEAD=ctx.BLOCK_DMODEL
 
 
 
636
  )
637
 
638
+ def grid(META): return (
639
+ triton.cdiv(N_CTX, META['BLOCK_N1']),
640
+ 1,
641
+ BATCH * N_HEAD
642
+ )
643
  _attn_bwd[grid](
644
+ q, arg_k, v, ctx.sm_scale, do, dq, dk, dv,
645
+ M, delta,
646
+ q.stride(0), q.stride(1), q.stride(2), q.stride(3),
647
+ N_HEAD, N_CTX,
648
+ BLOCK_DMODEL=ctx.BLOCK_DMODEL
 
 
 
 
 
 
 
 
 
 
 
 
649
  )
650
 
651
  return dq, dk, dv, None, None
652
 
653
 
654
+ attention = _attention.apply