thomwolf (HF staff) committed on
Commit
d045a91
·
1 Parent(s): a8a77bd
Files changed (36)
  1. dist/assets/images/activation_recomputation.png +3 -0
  2. dist/assets/images/conclusion_llama3_parallelism.png +3 -0
  3. dist/assets/images/dp_diagram.png +3 -0
  4. dist/assets/images/ep_schema.png +3 -0
  5. dist/assets/images/flashattn.png +3 -0
  6. dist/assets/images/flashattn2.png +3 -0
  7. dist/assets/images/fp8_diagram.png +3 -0
  8. dist/assets/images/fp8_divergence.png +3 -0
  9. dist/assets/images/fused_kernels1.png +3 -0
  10. dist/assets/images/fused_kernels2.png +3 -0
  11. dist/assets/images/gradaccumulation_diag.png +3 -0
  12. dist/assets/images/memorycoalescing.png +3 -0
  13. dist/assets/images/memorycoalescing2.png +3 -0
  14. dist/assets/images/memorycoalescing3.png +3 -0
  15. dist/assets/images/memorycoalescing4.png +3 -0
  16. dist/assets/images/memorycoalescing5.png +3 -0
  17. dist/assets/images/mixedprecision.png +3 -0
  18. dist/assets/images/mixedprecision_2.png +3 -0
  19. dist/assets/images/pp_1f1b_scaling.png +3 -0
  20. dist/assets/images/pp_bubblesize.png +3 -0
  21. dist/assets/images/pp_llama3.1_schedule.png +3 -0
  22. dist/assets/images/pp_zerobubble_compgraph.png +3 -0
  23. dist/assets/images/pp_zerobubble_dualpipe.png +3 -0
  24. dist/assets/images/pp_zerobubble_ppschedule.png +3 -0
  25. dist/assets/images/ring-attention.gif +0 -0
  26. dist/assets/images/threadcoarsening.png +3 -0
  27. dist/assets/images/tiling.png +3 -0
  28. dist/assets/images/tp_diagram.png +3 -0
  29. dist/assets/images/tp_diagram2.png +3 -0
  30. dist/assets/images/tp_diagram3.png +3 -0
  31. dist/assets/images/tp_diagram4.png +3 -0
  32. dist/assets/images/tp_full_diagram.png +3 -0
  33. dist/assets/images/tp_sp_diagram.png +3 -0
  34. dist/assets/images/tp_sp_diagram_zoomed.png +3 -0
  35. dist/index.html +84 -81
  36. src/index.html +0 -10
dist/assets/images/activation_recomputation.png ADDED

Git LFS Details

  • SHA256: 322496303f8133466e128f152e8cb2248bc2a0d5665a57b7894d80048612e64f
  • Pointer size: 130 Bytes
  • Size of remote file: 74.5 kB
dist/assets/images/conclusion_llama3_parallelism.png ADDED

Git LFS Details

  • SHA256: e7282f28522dc436f176a5ceca8891397d9bb8b8522f9b020bdf20e31258d324
  • Pointer size: 131 Bytes
  • Size of remote file: 348 kB
dist/assets/images/dp_diagram.png ADDED

Git LFS Details

  • SHA256: 70ad6657c4dd1dc1e2f4ad132206a7c4c8682e44a8277f638753532de9aa7f71
  • Pointer size: 131 Bytes
  • Size of remote file: 127 kB
dist/assets/images/ep_schema.png ADDED

Git LFS Details

  • SHA256: 63bf8bb1bbe2ff46b4da5cf874df9880532c077995880565e7306cd26a9053b0
  • Pointer size: 131 Bytes
  • Size of remote file: 137 kB
dist/assets/images/flashattn.png ADDED

Git LFS Details

  • SHA256: 2ca3528348a2cc037d31521c11ad44cf7078653a7f453483e346508ba139ab4d
  • Pointer size: 130 Bytes
  • Size of remote file: 98.3 kB
dist/assets/images/flashattn2.png ADDED

Git LFS Details

  • SHA256: 4312d0a3b349219f2215887926555a08261507e92b992a93337659fd7aff1157
  • Pointer size: 131 Bytes
  • Size of remote file: 396 kB
dist/assets/images/fp8_diagram.png ADDED

Git LFS Details

  • SHA256: 2517479bff358569b4410ffe302d63ff530fa2883722603012b6e14a18fefd75
  • Pointer size: 131 Bytes
  • Size of remote file: 127 kB
dist/assets/images/fp8_divergence.png ADDED

Git LFS Details

  • SHA256: 81e8495d96c8e40fbd36ee1030fc5325adabd5c2541b4a5c6041b07320ef76c6
  • Pointer size: 131 Bytes
  • Size of remote file: 241 kB
dist/assets/images/fused_kernels1.png ADDED

Git LFS Details

  • SHA256: 51c0e08c1d245d4bf529a97990eb85b15f31c9dea10f9bfdb18de6969957c20d
  • Pointer size: 131 Bytes
  • Size of remote file: 141 kB
dist/assets/images/fused_kernels2.png ADDED

Git LFS Details

  • SHA256: 949e208e9f7e140395c303aac6684ded8d120e7cd788c20e5fd395c2df5b5b91
  • Pointer size: 130 Bytes
  • Size of remote file: 73 kB
dist/assets/images/gradaccumulation_diag.png ADDED

Git LFS Details

  • SHA256: 0a7acb4c1e4832272beb247588f2a154a703d5b6f468b5e0b7dcffbcda41bbdc
  • Pointer size: 131 Bytes
  • Size of remote file: 116 kB
dist/assets/images/memorycoalescing.png ADDED

Git LFS Details

  • SHA256: 96ed02089819123c2ec48b178d4b673cc4f628f4c903f02ad98c5588cf3e1931
  • Pointer size: 131 Bytes
  • Size of remote file: 155 kB
dist/assets/images/memorycoalescing2.png ADDED

Git LFS Details

  • SHA256: c1708d3f4588768350a78a5b38d7e2a968fb6115d1be8bc0f02f7f81dc6e767c
  • Pointer size: 130 Bytes
  • Size of remote file: 36.4 kB
dist/assets/images/memorycoalescing3.png ADDED

Git LFS Details

  • SHA256: 2fa6b2066aaac9a5dad1a96489414478c680a53bc39ebea704931c466af8d343
  • Pointer size: 130 Bytes
  • Size of remote file: 56.8 kB
dist/assets/images/memorycoalescing4.png ADDED

Git LFS Details

  • SHA256: 62621f72d70635d79b48c7815127cd31da119292981bea58ab20a6b578d3aff3
  • Pointer size: 130 Bytes
  • Size of remote file: 59.3 kB
dist/assets/images/memorycoalescing5.png ADDED

Git LFS Details

  • SHA256: c33982566e567cc075f544aae349bb37dca63b6ce16e2d7e753293826e4a06dd
  • Pointer size: 130 Bytes
  • Size of remote file: 36.4 kB
dist/assets/images/mixedprecision.png ADDED

Git LFS Details

  • SHA256: 8891a3a71a819f217c5b2bfa39d68c794eca480f24bfbb74b618aebde2971fc8
  • Pointer size: 131 Bytes
  • Size of remote file: 241 kB
dist/assets/images/mixedprecision_2.png ADDED

Git LFS Details

  • SHA256: d4cac7b16899d1c36f4936ddcf6751ce96391831397199735a3ef64b6daa0a07
  • Pointer size: 131 Bytes
  • Size of remote file: 190 kB
dist/assets/images/pp_1f1b_scaling.png ADDED

Git LFS Details

  • SHA256: 5191c89bcffed1ead467742eb4cec4c89c53f22e8c7391d115ca06ece15cf21c
  • Pointer size: 131 Bytes
  • Size of remote file: 224 kB
dist/assets/images/pp_bubblesize.png ADDED

Git LFS Details

  • SHA256: 784528719df2d3cbb4765802463b1ab14e1b20d80a593b27ea328086fb67a5cb
  • Pointer size: 131 Bytes
  • Size of remote file: 178 kB
dist/assets/images/pp_llama3.1_schedule.png ADDED

Git LFS Details

  • SHA256: a055afbdb270c6c319e41aa2b7d4e2893c55da18fd6102764a43ce7935d224e2
  • Pointer size: 131 Bytes
  • Size of remote file: 145 kB
dist/assets/images/pp_zerobubble_compgraph.png ADDED

Git LFS Details

  • SHA256: 58b04bbae5360ee205560670c196df82964b6ddb552f962b86889727d292ff08
  • Pointer size: 130 Bytes
  • Size of remote file: 47.6 kB
dist/assets/images/pp_zerobubble_dualpipe.png ADDED

Git LFS Details

  • SHA256: e3d4c7070550b4a76f1c39577edb9c62b467817558721a652cdc9f6e4bdcba1f
  • Pointer size: 131 Bytes
  • Size of remote file: 206 kB
dist/assets/images/pp_zerobubble_ppschedule.png ADDED

Git LFS Details

  • SHA256: 12f18a861d558fa68b8aefdcdecc8b63326ec4b56350e9a1536d45c2cc1238ef
  • Pointer size: 131 Bytes
  • Size of remote file: 123 kB
dist/assets/images/ring-attention.gif ADDED
dist/assets/images/threadcoarsening.png ADDED

Git LFS Details

  • SHA256: 007f2426210d2328df00dcd1122b056831b47a2c604f6512ff08172fcb943621
  • Pointer size: 130 Bytes
  • Size of remote file: 38.6 kB
dist/assets/images/tiling.png ADDED

Git LFS Details

  • SHA256: 8889c4317e8a78a16404af34cfb5153fc10cfefca6b25dc0a9eb7561abf012c3
  • Pointer size: 130 Bytes
  • Size of remote file: 21.6 kB
dist/assets/images/tp_diagram.png ADDED

Git LFS Details

  • SHA256: fb5ae9993740f216bfc4f8481536739c0e85853ef798fe1940f4e6b3bee0683d
  • Pointer size: 130 Bytes
  • Size of remote file: 43.3 kB
dist/assets/images/tp_diagram2.png ADDED

Git LFS Details

  • SHA256: f075304c019e12be1ac0ef8afa9241c03bc466f568dca0c66e20b1391a471bca
  • Pointer size: 131 Bytes
  • Size of remote file: 486 kB
dist/assets/images/tp_diagram3.png ADDED

Git LFS Details

  • SHA256: beff9be457b6363c370d9831f42155ae9674240d2588eac6270f62aeb58f0a70
  • Pointer size: 131 Bytes
  • Size of remote file: 486 kB
dist/assets/images/tp_diagram4.png ADDED

Git LFS Details

  • SHA256: 6885bbcb5ba13aad0a111b69eacf6e679a0ed5dd688cc0ac0b58b21318fce852
  • Pointer size: 131 Bytes
  • Size of remote file: 211 kB
dist/assets/images/tp_full_diagram.png ADDED

Git LFS Details

  • SHA256: 01f9fd3bc4b0a97b167d6ce2d47b0a207492441592d1e2a51fab1e8bfad9962e
  • Pointer size: 131 Bytes
  • Size of remote file: 113 kB
dist/assets/images/tp_sp_diagram.png ADDED

Git LFS Details

  • SHA256: d2463f346a6a3e16d447329a94eb8e9120d38effacf1230637ca25cd35d4c250
  • Pointer size: 131 Bytes
  • Size of remote file: 201 kB
dist/assets/images/tp_sp_diagram_zoomed.png ADDED

Git LFS Details

  • SHA256: f86131810347fba77e74d4972ad0115e8bdfab4d42448b8f07e1d79c3d6eef6a
  • Pointer size: 131 Bytes
  • Size of remote file: 385 kB
dist/index.html CHANGED
@@ -297,14 +297,9 @@
297
 
298
  <p>Using this snippet [TODO: link to appendix A5], we can understand how memory is allocated throughout training. We can see that memory utilization is not a static thing but varies a lot during training and during a training step:</p>
299
 
300
- <<<<<<< HEAD
301
- <!--<div class="svg-container l-body-outset" id="svg-first_steps_memory_profile"> </div>
302
- <script src="../assets/images/first_steps_memory_profile.js"></script>-->
303
- =======
304
  <div class="svg-container l-body-outset" id="svg-first_steps_memory_profile"> </div>
305
  <div class="info" id="svg-first_steps_memory_profile-info">Hover over the elements to see their details</div>
306
  <script src="../assets/images/first_steps_memory_profile.js"></script>
307
- >>>>>>> a1429a9 (update)
308
 
309
  <iframe id="plotFrame" src="assets/data/benchmarks/memory-profile.html" height="520" width="1000" scrolling="no" frameborder="0"></iframe>
310
 
@@ -421,7 +416,6 @@
421
 
422
 <p>An interesting observation here is that the memory is not static for a given model but scales linearly with both the sequence length and the batch size. This means the activation memory is the part which will blow up when we increase our batch size or train with longer sequences. We can use this equation to look at how memory usage changes for various sequence lengths, for example for Llama models (<code>bs=1</code>):</p>
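 <p>For a rough sense of the numbers involved, here is a small sketch using the commonly cited approximation from the activation recomputation literature, <d-math>L \cdot s \cdot b \cdot h \cdot (34 + \frac{5 \cdot a \cdot s}{h})</d-math> bytes per model in bf16 with no recomputation (with <d-math>s</d-math> the sequence length, <d-math>b</d-math> the batch size, <d-math>h</d-math> the hidden dimension, <d-math>a</d-math> the number of attention heads and <d-math>L</d-math> the number of layers). The exact equation used for the plot may differ, and the Llama-style configuration below is only illustrative:</p>
 <d-code block language="python">
def activation_memory_bytes(seq_len, batch_size, hidden_dim, num_heads, num_layers):
    # Per-layer approximation (bf16 activations, no recomputation): s*b*h*(34 + 5*a*s/h) bytes
    per_layer = seq_len * batch_size * hidden_dim * (34 + 5 * num_heads * seq_len / hidden_dim)
    return num_layers * per_layer

# Illustrative Llama-8B-like config: h=4096, 32 heads, 32 layers, bs=1
for seq_len in (2048, 4096, 16384):
    gb = activation_memory_bytes(seq_len, 1, 4096, 32, 32) / 1e9
    print(f"seq_len={seq_len:6d} -> ~{gb:.0f} GB of activations")
 </d-code>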
423
 
424
- <!-- <p><img alt="llama-memory-bars-no-recomp.png" src="/assets/images/placeholder.png" /></p> -->
425
  <iframe class="l-body-outset" id="plotFrame2" src="assets/data/benchmarks/memusage_activations.html" width="90%" scrolling="no" frameborder="0"></iframe>
426
  <script>
427
  window.addEventListener('load', function() {
@@ -446,7 +440,6 @@
446
  <div class="svg-container" id="svg-activation_recomputation"> </div>
447
  <div class="info" id="svg-activation_recomputation-info">Hover over the network elements to see their details</div>
448
  <script src="../assets/images/activation_recomputation.js"></script>
449
-
450
  <p>There are several strategies to select key activations to store:</p>
451
 
452
  <ul>
@@ -505,7 +498,7 @@
505
 
506
 <p>Gradient accumulation allows us to effectively increase our batch size up to infinity (and beyond!) while the memory footprint stays constant. Gradient accumulation is also compatible with activation recomputation for further memory reduction. One drawback, however, is that gradient accumulation requires multiple consecutive forward/backward passes per optimization step, thereby increasing the compute overhead and slowing down training. No free lunch! </p>
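 <p>As a minimal sketch of what this looks like in PyTorch (with a toy model and data loader standing in for the real ones), we scale the loss by the number of accumulation steps and only step the optimizer once per global batch:</p>
 <d-code block language="python">
import torch
from torch import nn

model = nn.Linear(128, 10)                      # toy stand-in for a real LLM
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
data_loader = [(torch.randn(8, 128), torch.randint(0, 10, (8,))) for _ in range(8)]
grad_accum_steps = 4                            # global batch = micro-batch size * grad_accum_steps

optimizer.zero_grad()
for step, (inputs, targets) in enumerate(data_loader):
    loss = nn.functional.cross_entropy(model(inputs), targets)
    (loss / grad_accum_steps).backward()        # gradients accumulate in the p.grad buffers
    if (step + 1) % grad_accum_steps == 0:
        optimizer.step()                        # one optimizer step per accumulated global batch
        optimizer.zero_grad()                   # the gradient buffers persist until this point
 </d-code>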
507
 
508
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
509
 
510
  <aside>Using gradient accumulation means we need to keep buffers where we accumulate gradients which persist throughout a training step. Whereas without gradient accumulation, in the backward gradients are computed while freeing the activations memory, which means a lower peak memory.</aside>
511
 
@@ -524,13 +517,13 @@
524
 
525
 <p>Using a different micro batch for each GPU means we’ll have different gradients on each GPU, so to keep the model instances in sync across GPUs, their gradients are averaged using an operation called “all-reduce”, which happens during the backward pass, before the optimizer step.</p>
526
 
527
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
528
 
529
  <aside>If you are not familiar with distributed communications patterns like broadcast, gather or all-reduce we put together a small crash course in the Appendix [TODO Link].</aside>
530
 
531
 <p>This involves our first “distributed communication” primitive: <strong><em>all-reduce</em></strong>, which handles the synchronization and communication between GPU instances and nodes.</p>
532
 
533
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
534
 
535
 <p>A naive DP implementation would just wait for the backward pass to finish so that we have all gradients, then trigger an all-reduce over all DP ranks to sync the gradients. But such a sequential pattern of computation followed by communication is <strong>A BIG NO!</strong> We don’t want our GPUs to sit idle while communication is happening.</p>
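 <p>For reference, the naive pattern we just described (and want to avoid) would look roughly like the sketch below, assuming <code>torch.distributed</code> has already been initialized (e.g. via <code>torchrun</code>):</p>
 <d-code block language="python">
import torch
import torch.distributed as dist

def naive_dp_grad_sync(model: torch.nn.Module):
    """Naive DP: wait for the full backward pass, then all-reduce every gradient."""
    world_size = dist.get_world_size()
    for p in model.parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
            p.grad /= world_size  # average the gradients across DP ranks

# usage (after loss.backward(), before optimizer.step()):
# naive_dp_grad_sync(model)
 </d-code>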
536
 
@@ -556,7 +549,7 @@
556
  if p.requires_grad is True:
557
  p.register_post_accumulate_grad_hook(hook)</d-code>
558
 
559
- <p><img alt="image.png" src="/assets/images/placeholder.png"/></p>
560
 
561
 <p>Overlapping computation and communication reduces the time spent waiting for gradient synchronization across the entire model. Gradient synchronization can occur (at least partially) in parallel with the backward pass, significantly speeding up data parallelism. Here's a full implementation of naive DP with synchronization overlap:</p>
562
 
@@ -590,7 +583,7 @@
590
  </div>
591
  </details>
592
 
593
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
594
 
595
  <h4><strong>Third optimization: </strong>Interplay with gradient accumulation</h4>
596
 
@@ -650,7 +643,7 @@
650
 
651
  <p>While data parallelism cleverly overlaps the all-reduce gradient synchronization with backward computation to save time, this benefit starts to break down at large scales. As we add more and more GPUs (hundreds or thousands), the overhead of coordinating between them grows significantly. The end result? We get less and less efficient returns from each additional GPU we add to the system:</p>
652
 
653
- <p><img alt="image.png" src="/assets/images/placeholder.png"/></p>
654
 
655
  <p>As expected, we can also see that the memory usage per GPU is not affected by adding more DP ranks for training.</p>
656
 
@@ -658,7 +651,7 @@
658
 
659
 <p>The keen reader has probably already noted however that this assumes that we can fit at least one input sample forward pass (mbs=1) into our GPU memory. This is not always the case! As we can see, larger models don’t fit into a single GPU, even with activation recomputation activated: </p>
660
 
661
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
662
 
663
 <aside>Tip: you can quickly eyeball the minimal memory required for your model’s parameters by multiplying the number of parameters by 2 (bytes per parameter in bf16), e.g. 70B → 140GB (=133GiB)</aside>
664
 
@@ -704,7 +697,7 @@
704
 
705
  <p>The idea of ZeRO is to shard these objects across the DP ranks, each node only storing a slice of the items which are reconstructed when and if needed, thereby dividing memory usage by the data parallel degree <d-math>N_d</d-math>:</p>
706
 
707
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
708
 <p>Memory consumption of DP and the three stages of ZeRO-DP. <d-math>\Psi</d-math> denotes number of parameters, <d-math>k</d-math> denotes the memory multiplier of optimizer states (<d-math>k=12</d-math> for Adam), and <d-math>N_d</d-math> denotes DP degree.</p>
709
 
710
 
@@ -730,11 +723,11 @@
730
 
731
  <p>See the figure below for all the necessary steps in one forward/backward pass cycle:</p>
732
 
733
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
734
 
735
  <p>So in practice, compared to vanilla DP, Zero-1 adds an all-gather over all parameters after the optimizer step as we can see below:</p>
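 <p>A minimal sketch of the ZeRO-1 communication pattern on a flattened parameter/gradient buffer (assuming an initialized <code>torch.distributed</code> process group, buffers padded to a multiple of the world size, and plain SGD standing in for the real sharded optimizer):</p>
 <d-code block language="python">
import torch
import torch.distributed as dist

def zero1_step(flat_params: torch.Tensor, flat_grads: torch.Tensor, lr: float = 1e-3):
    """One ZeRO-1 update: reduce-scatter the grads, update only the local 1/N_d parameter
    shard, then all-gather the updated shards so every rank has the full parameters again."""
    world_size, rank = dist.get_world_size(), dist.get_rank()
    shard_size = flat_grads.numel() // world_size

    # 1) reduce-scatter: each rank receives the summed gradient of its own shard only
    grad_shard = torch.empty(shard_size, dtype=flat_grads.dtype, device=flat_grads.device)
    dist.reduce_scatter_tensor(grad_shard, flat_grads, op=dist.ReduceOp.SUM)
    grad_shard /= world_size

    # 2) local optimizer step (plain SGD here) on the shard this rank owns
    param_shard = flat_params[rank * shard_size:(rank + 1) * shard_size]
    param_shard -= lr * grad_shard

    # 3) all-gather the updated shards: the step ZeRO-1 adds compared to vanilla DP
    dist.all_gather_into_tensor(flat_params, param_shard.clone())
 </d-code>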
736
 
737
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
738
 
739
  <p>If you've been following along, you'll recall from vanilla DP that we can overlap the all-reduce gradient communication with the backward pass computation. In ZeRO-1, we can also investigate how to efficiently overlap the newly added all-gather of bf16 parameters. There are two main strategies for this:</p>
740
 
@@ -758,13 +751,13 @@
758
 
759
  <aside>In case of FP32 gradient accumulation, we only need to keep <d-math>\frac{1}{N_d}</d-math> fp32_grads where we accumulate the bf16 grads coming from the reduce-scatter. And in the optimizer step we use the <d-math>\frac{1}{N_d}</d-math> fp32_grads.</aside>
760
 
761
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
762
 
763
 <p>It’s easy to see now that sharding the gradients leads to <d-math>2\Psi + \frac{2\Psi+k\Psi}{N_d}</d-math> and as <d-math>N_d</d-math> is increased we can save up to 8x memory over the baseline. In terms of communication the same process applies as for ZeRO-1, with the only difference that we communicate and release on the fly. In total, ZeRO-2 is thus also equivalent to vanilla DP training w.r.t. communication.</p>
764
 
765
 <p>In terms of communication ZeRO-2 is similar to ZeRO-1: both require a reduce-scatter for the gradients and an all-gather over all parameters.</p>
766
 
767
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
768
 
769
  <aside>Note: You might notice that there is no real overhead of using ZeRO-2 over ZeRO-1 and indeed ZeRO-2 is usually the best option.</aside>
770
 
@@ -783,13 +776,15 @@
783
 
784
  <p>So how do we do a forward or backward pass in practice if all parts of the model are distributed? Quite simply we gather them on-demand when we need them. In the forward pass this looks as follows:</p>
785
 
786
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
787
 
788
 <p>So as we perform the forward pass and sequentially go through the layers, we retrieve the necessary parameters on demand and immediately flush them from memory when we don’t need them anymore. The backward pass works the same way, just inverted in flow, and we produce the gradient shards: </p>
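 <p>A rough sketch of this on-demand gathering during the forward pass (assuming an initialized process group; <code>load_full_params</code> and <code>release_params</code> are hypothetical placeholders for the machinery that unflattens the gathered buffer into the layer and then drops everything except the local shard):</p>
 <d-code block language="python">
import torch
import torch.distributed as dist

def load_full_params(layer, flat):   # placeholder: unflatten `flat` into the layer's parameters
    pass

def release_params(layer):           # placeholder: keep only the local 1/N_d shard again
    pass

def zero3_forward(layers, layer_param_shards, x):
    """ZeRO-3 style forward sketch: for each layer, all-gather its parameters just before
    use and release them right after, so only one layer's full weights live at a time."""
    world_size = dist.get_world_size()
    for layer, shard in zip(layers, layer_param_shards):
        full = torch.empty(shard.numel() * world_size, dtype=shard.dtype, device=shard.device)
        dist.all_gather_into_tensor(full, shard)   # materialize the full parameters
        load_full_params(layer, full)              # placeholder: copy gathered weights into the layer
        x = layer(x)                               # compute with the full parameters
        release_params(layer)                      # placeholder: drop back to the local shard
        del full                                   # free the gathered buffer immediately
    return x
 </d-code>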
789
 
790
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
791
 
792
-
 
 
793
 
794
 <p>During the forward pass we do all-gather operations for the parameters when we need them, so a <d-math>\Psi</d-math> communication tax. Since we discard the parameters immediately after using them in the forward pass, we need one more all-gather during the backward pass as well, incurring another <d-math>\Psi</d-math> in communication tax. Finally we need the same <strong><em>reduce-scatter</em></strong> as in ZeRO-2 for the gradients, which also costs <d-math>\Psi</d-math> in communication, and we arrive at a total communication cost of <d-math>3\Psi</d-math>, compared to <d-math>2\Psi</d-math> for ZeRO-2.</p>
795
 
@@ -804,7 +799,7 @@
804
 
805
 <p>However, there is a limit here: DP only works if a layer of the model fits in a single GPU, and ZeRO can only partition the parameters, gradients, and optimizer states, but not the activation memory! Recall from the activation memory discussion that it scales with sequence length and batch size. Naturally we could just limit those, but in practice we don’t want to be limited by hardware to train with only a short sequence length. </p>
806
 
807
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
808
 
809
  <p>Now that we've efficiently used the DP axis to reduce memory through efficient communication patterns, let's explore a new, orthogonal axis of parallelism - Tensor Parallelism. Unlike ZeRO3 that relies on heavy parameter communication, TP manages to shard parameters, gradients, optimizer states AND activations across devices without requiring any model parameter movement between GPUs. What! How is this even possible?! Let's explore this seemingly magical approach together! 🙂</p>
810
 
@@ -830,13 +825,13 @@
830
 
831
  <p>In practice a small example of the operation looks like this:</p>
832
 
833
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
834
 
835
 <p>Let’s see how we can parallelise this operation! In tensor parallelism, tensors will be split into N shards along a particular dimension and distributed across N GPUs. Matrices can be split either along their columns or along their rows, leading to column and row parallelism. One thing we’ll see in the following is that choosing row or column sharding will require different communication primitives.</p>
836
 
837
  <p>Our first option is to use column-wise sharding (also called <strong><em>column-linear</em></strong>): We'll copy the complete input matrices to each worker, requiring an operation called <strong><em>broadcast</em></strong>, and split the weight matrix into columns. The inputs are then multiplied with the partial weight matrices, and the results are finally combined using an <strong><em>all-gather</em></strong> operation.</p>
838
 
839
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
840
 
841
  <p>Here's the code implementation of column wise tensor parallelism:</p>
842
 
@@ -853,7 +848,7 @@
853
 
854
  <p>We see here our fourth distributed primitive: <strong><em>scatter</em></strong>!</p>
855
 
856
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
857
 
858
  <p>Here's the implementation for row-wise tensor parallelism:</p>
859
 
@@ -874,7 +869,7 @@
874
 
875
  <p>The Feedforward part can be parallelized by having a “Column linear” followed by a “Row Linear” which amounts to a broadcast to copy the input and an all-reduce in forward. Note that the broadcast isn’t needed in actual training where we can make sure inputs are already synced across TP ranks.</p>
876
 
877
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
878
 
879
  <p>Now that we’ve found the most efficient schema for the Feedforward part of the transformer, let’s take a look at the multi-head attention block (MHA).</p>
880
 
@@ -882,17 +877,17 @@
882
 
883
 <p>It's also worth noting that the tensor parallelism degree should not exceed the number of Q/K/V heads because we need intact heads per TP rank. And in case we’re using GQA, the TP degree should be below the number of K/V heads, otherwise it requires additional comms to keep them in sync. For instance, LLaMA-3 8B has 8 Key/Value heads, so the tensor parallelism degree should be less than or equal to 8; otherwise, if TP=16 for example, we need to duplicate each K/V head and make sure they stay in sync.</p>
884
 
885
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
886
 
887
 <p>Finally note that there is a tradeoff in terms of communication, as we’ve added several distributed communication primitives directly in the computation path of our model. Unlike ZeRO, where we could prefetch, it can be harder to make these communications fully overlap with computation. </p>
888
 
889
- <p><img alt="Forward pass in Tensor Parallelism" src="/assets/images/placeholder.png" /></p>
890
 
891
  <p>Looking at the timeline of operations in tensor-parallel MLP (same applies for Attention), we can better understand the tradeoffs involved. In the forward of each decoder layer, we hit a synchronization point with the AllReduce operation that cannot be overlapped with computation. This <em>exposed communication</em> overhead is necessary to combine partial results across tensor-parallel ranks before the final LayerNorm can be applied. </p>
892
 
893
  <p>Tensor parallelism does help reduce activation memory for the matrix multiplications since the intermediate activations are sharded across GPUs. However, we still need to gather the full activations for operations like LayerNorm, which means we're not getting the full memory benefits we could. Additionally, it introduces significant communication requirements that heavily depend on the network infrastructure. The inability to hide this particular AllReduce behind computation means it directly adds to the critical path of forward propagation.</p>
894
 
895
- <p><img alt="Impact of Tensor Parallelism on model performance and batch size capacity: while increasing TP leads to reduced per-GPU throughput (left), it enables processing of larger batch sizes (right), illustrating the trade-off between computational efficiency and memory availability in distributed training." src="/assets/images/placeholder.png" /></p>
896
 
897
  <p>Impact of Tensor Parallelism on model performance and batch size capacity: while increasing TP leads to reduced per-GPU throughput (left), it enables processing of larger batch sizes (right), illustrating the trade-off between computational efficiency and memory availability in distributed training.</p>
898
 
@@ -900,7 +895,7 @@
900
 
901
  <p>However, tensor parallelism provides important benefits for memory usage by distributing model parameters, gradients, optimizer states and activations (to some extent) across GPUs. Let's examine this effect on a 70B parameter model:</p>
902
 
903
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
904
 
905
  <p>As we can see, increasing tensor parallelism reduces the memory needed for model parameters, gradients and optimizer states on each GPU. While tensor parallelism does help reduce activation memory in attention and feedforward layers by sharding the matrix multiplications across GPUs, we don't get the full memory benefits we could. This is because operations like layer normalization and dropout still require gathering the full activations on each GPU, partially negating the memory savings. We can do better by finding ways to parallelize these remaining operations as well.</p>
906
 
@@ -940,7 +935,7 @@
940
 
941
  <p><img alt=" in forward: f = no-op ; f* = all-reduce ; g = all-gather ; g* = reduce-scatter
942
  in backward: f = all-reduce ; f* = no-op ; g = reduce-scatter ; g* = all-gather
943
- SP region needs full hidden_dim" src="/assets/images/placeholder.png" /></p>
944
 
945
 <p>in forward: f = no-op ; f* = all-reduce ; g = all-gather ; g* = reduce-scatter. In backward: f = all-reduce ; f* = no-op ; g = reduce-scatter ; g* = all-gather. The SP region needs the full hidden_dim.</p>
946
 
@@ -961,7 +956,7 @@
961
 
962
  <p>For sequence parallelism (SP), we use different operations labeled "g" and "g*". Specifically, we avoid using all-reduce in the SP region since that would require gathering the full activations and increase our peak memory usage, defeating the purpose of SP.</p>
963
 
964
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
965
 
966
  <p>So what is actually happening here? As a famous LLM would say, let’s take it step-by-step:</p>
967
 
@@ -1049,13 +1044,13 @@
1049
 
1050
  <p>By using sequence parallelism, we can achieve even greater activation memory savings, allowing us to push our batch size and sequence length further than what would be possible with tensor parallelism alone. Let's see what that means for our previous 70B model example:</p>
1051
 
1052
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1053
 
1054
 <p>Does that mean that SP incurs more communication than TP? Well, yes and no. In the forward pass of vanilla TP we had two all-reduces per transformer block, and in SP we have two all-gathers and two reduce-scatters per transformer block. So SP does twice the number of communication operations as TP. But since an all-reduce operation can be broken down into an all-gather + reduce-scatter (see [TODO: Appendix link]) they’re actually equivalent in terms of communication. The same reasoning applies for the backward pass, as we just use the conjugate of each operation (no-op ↔ all-reduce and all-gather ↔ reduce-scatter).</p>
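 <p>The decomposition is easy to convince yourself of with a tiny single-process sketch, simulating the ranks with a plain list of gradient tensors:</p>
 <d-code block language="python">
import torch

world_size = 4
grads = [torch.randn(8) for _ in range(world_size)]   # one gradient tensor per simulated rank

# all-reduce: every rank would end up with the elementwise sum over all ranks
allreduce = torch.stack(grads).sum(dim=0)

# reduce-scatter: rank r only gets the sum for its own chunk of the tensor...
reduce_scatter = [sum(g.chunk(world_size)[r] for g in grads) for r in range(world_size)]
# ...and an all-gather of those chunks reconstructs the all-reduce result on every rank
allgather = torch.cat(reduce_scatter)

assert torch.allclose(allreduce, allgather)
 </d-code>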
1055
 
1056
 <p>If you’ve been paying close attention, you’ll notice that we’re talking about 4 comms ops in each layer (2 for Attention and 2 for MLP). This is how the MLP profiling looks when using Tensor + Sequence Parallelism:</p>
1057
 
1058
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1059
 
1060
  <p>Besides the fact that TP requires communications in each layer, it also can’t easily be overlapped with compute, which makes throughput heavily dependent on the communication bandwidth. This is why TP is usually done only within a node (TP≤8).</p>
1061
 
@@ -1064,7 +1059,7 @@
1064
 
1065
  <p>As you might expect, this communication overhead becomes increasingly problematic as we scale up tensor parallelism. To illustrate this, let’s check throughput as we scale TP with SP for a 3B model:</p>
1066
 
1067
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1068
  <p>Impact of combined Tensor and Sequence Parallelism (TP/SP) on a 3B model’s performance and memory utilization with 4096 seqlen: when scaling both TP and SP together, there's a trade-off between computational efficiency (left) and memory capacity (right). While higher parallelism degrees reduce per-GPU throughput, they enable processing of significantly larger batch sizes by reducing the activation memory.</p>
1069
 
1070
  <p>Let’s summarize our observations:</p>
@@ -1094,7 +1089,7 @@
1094
 
1095
  <p>Even if we use full recomputation of the activations, which comes at a heavy compute overhead (30%), we still need to hold in memory some activations at the layer boundaries which scale linearly with sequence length:</p>
1096
 
1097
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1098
 
1099
 <p>Can we apply similar ideas to our sequence parallelism approach, but inside the modules where we apply Tensor Parallelism already, thereby also reducing the effect of sequence length? Yes, it’s time to talk about Context Parallelism, which you will find quite intuitive after all we’ve already covered.</p>
1100
 
@@ -1102,7 +1097,7 @@
1102
 
1103
 <p>The idea of Context Parallelism is quite simple; just like Sequence Parallelism, we’ll split the input along the sequence dimension, but we now apply this splitting across the full model, instead of only the sequence parallel regions of the model as we did previously with Tensor + Sequence Parallelism.</p>
1104
 
1105
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1106
 
1107
  <p>Splitting the sequence doesn't affect most modules like MLP and LayerNorm, where each token is processed independently. It also doesn’t require expensive communication like TP, as only the inputs are split and not the weight matrices. Just like data parallelism, after computing the gradients, an all-reduce operation is initiated to synchronize the gradients across the context parallelism group.</p>
1108
 
@@ -1133,13 +1128,13 @@
1133
 
1134
  <p>The whole process with 4 GPUs is shown in the following animation:</p>
1135
 
1136
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1137
 
1138
  <p>With this animation, it’s also immediately clear why the authors chose to call this approach Ring Attention.</p>
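 <p>Here is a single-process sketch of the idea (no causal mask, a single head, and plain Python chunks standing in for the GPUs and the ring send/receive): each “rank” keeps its Q chunk, consumes one K/V chunk per step as if it had just arrived on the ring, and merges the partial results with a running (online) softmax:</p>
 <d-code block language="python">
import torch

def ring_attention(q, k, v, cp_size=4):
    """Single-process sketch of Ring Attention (no causal mask): the sequence is split into
    `cp_size` chunks ("GPUs"); each chunk keeps its Q local and folds in one K/V chunk per
    ring step using a running max / denominator (online softmax)."""
    q_chunks, k_chunks, v_chunks = (t.chunk(cp_size, dim=0) for t in (q, k, v))
    outputs = []
    for rank in range(cp_size):
        q_i = q_chunks[rank]
        o = torch.zeros_like(q_i)                          # running (unnormalized) output
        m = torch.full((q_i.shape[0], 1), -float("inf"))   # running row max
        l = torch.zeros(q_i.shape[0], 1)                   # running softmax denominator
        for step in range(cp_size):
            src = (rank - step) % cp_size                  # K/V chunk "arriving" at this step
            scores = q_i @ k_chunks[src].T / q_i.shape[-1] ** 0.5
            m_new = torch.maximum(m, scores.max(dim=-1, keepdim=True).values)
            p = torch.exp(scores - m_new)
            correction = torch.exp(m - m_new)              # rescale previous partial results
            l = l * correction + p.sum(dim=-1, keepdim=True)
            o = o * correction + p @ v_chunks[src]
            m = m_new
        outputs.append(o / l)
    return torch.cat(outputs, dim=0)

q = k = v = torch.randn(16, 8)
ref = torch.softmax(q @ k.T / 8 ** 0.5, dim=-1) @ v
assert torch.allclose(ring_attention(q, k, v), ref, atol=1e-5)
 </d-code>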
1139
 
1140
 <p>There is one big problem though, which is that a naive implementation of Ring Attention leads to a strong imbalance between GPUs, stemming from the shape of the causal attention matrix. Let’s take a closer look at what is happening in the SoftMax computation by considering the attention score matrix with the causal attention mask:</p>
1141
 
1142
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1143
 
1144
 <p>The SoftMax is computed row-wise, which means a row can be computed as soon as a GPU has received all the tokens of that row. We see that GPU1 can compute its rows immediately, as it starts with tokens 1-4 and doesn’t need to receive any information from the other GPUs. However, GPU2 will need to wait for the second round to also receive tokens 1-4 and thus have all values for tokens 1-8. GPU1 also seems to perform much less work than all the other GPUs.</p>
1145
 
@@ -1149,14 +1144,14 @@
1149
 
1150
 <p>We need a better way to distribute the input sequences. This can be achieved by not assigning the tokens purely sequentially to the GPUs, but by mixing the ordering a bit such that we have a good mix of early and late tokens on each GPU. This approach is called Zig-Zag attention<d-cite bibtex-key="brandon2023fasterring"></d-cite> and in this new arrangement, the attention mask shows an even distribution of computation: if you count the number of colored squares, you’ll see that the computation is now balanced across all GPUs.</p>
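 <p>One common way to build such an assignment (the exact permutation differs between implementations, so treat this as illustrative) is to cut the sequence into twice as many chunks as GPUs and give each GPU one “early” chunk and its mirror “late” chunk:</p>
 <d-code block language="python">
def zigzag_chunk_ids(cp_size: int):
    """For each context-parallel rank, return the two sequence chunks it owns when the
    sequence is cut into 2 * cp_size chunks: one early chunk and one late chunk, which
    balances the causal-attention work across ranks."""
    return {rank: (rank, 2 * cp_size - 1 - rank) for rank in range(cp_size)}

print(zigzag_chunk_ids(4))
# {0: (0, 7), 1: (1, 6), 2: (2, 5), 3: (3, 4)}
 </d-code>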
1151
 
1152
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1153
 
1154
  <p>At the same time we’ll also see that in order to complete all rows, each GPU will need information from all the other GPUs.</p>
1155
 
1156
 <p>We have two general ways to overlap computation and communication: either by performing a general all-gather, regrouping all the KV on each GPU at the same time (in a ZeRO-3 type of way), or by gathering them one-by-one from each GPU as needed:</p>
1157
 
1158
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1159
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1160
 
1161
  <p>The key difference between these two implementations lies in their communication patterns and memory usage:</p>
1162
 
@@ -1166,7 +1161,6 @@
1166
  <li>All GPUs simultaneously gather the complete key/value pairs from all other GPUs</li>
1167
  <li>Requires more temporary memory as each GPU needs to store the full KV pairs at once</li>
1168
  <li>Communication happens in one step but with larger memory overhead</li>
1169
- <li>Used in MegatronLM's implementation of context parallelism</li>
1170
  </ul>
1171
 
1172
  <p><strong>2. All-to-All (Ring) Implementation:</strong></p>
@@ -1175,7 +1169,6 @@
1175
  <li>GPUs exchange KV pairs in a ring-like pattern, one chunk at a time</li>
1176
  <li>More memory efficient as each GPU only needs to store one additional chunk temporarily</li>
1177
  <li>Communication is spread out and overlapped with computation, though with some additional base latency overhead from multiple communication steps</li>
1178
- <li>Used in DeepSpeed's implementation of context parallelism</li>
1179
  </ul>
1180
 
1181
  <p>The All-to-All approach generally offers better memory efficiency at the cost of slightly more complex communication patterns, while the AllGather approach is simpler but requires more temporary memory during the attention computation.</p>
@@ -1186,12 +1179,12 @@
1186
 
1187
 <p>In the TP section we saw that if we try to scale Tensor Parallelism past the number of GPUs per single node (typically 4 or 8) we hit a lower-bandwidth network called the “inter-node connection”, which can quite strongly impair our performance. We can see this clearly on e.g. the all-reduce operation when we perform it across several nodes:</p>
1188
 
1189
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1190
  <p>Inter-node communication bandwidth measurements across different node counts, showing median (lines) and 5th-95th percentile ranges (shaded areas) for AllReduce, AllGather and ReduceScatter operations.</p>
1191
 
1192
 <p>Sequence and context parallelism can help for long sequences but don’t help much if sequence length is not the root cause of our memory issues but rather the size of the model itself. For large models (70B+), the size of the weights alone can already push past the limits of the 4-8 GPUs on a single node. We can solve this issue by summoning the fourth (and last) parallelism dimension: “pipeline parallelism”.</p>
1193
 
1194
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1195
 
1196
 <p>Pipeline Parallelism is conceptually very simple: we’ll simply spread the layers of our model across GPUs. But the devil lies in implementing it efficiently. Let’s dive into it!</p>
1197
 
@@ -1205,7 +1198,7 @@
1205
 
1206
 <p>Indeed reader! The main challenge in pipeline parallelism will be how to efficiently circumvent the sequential nature of PP to keep our GPUs busy at all times and avoid having one GPU computing while the others are waiting. Here is how our GPU utilization looks when doing a naive and simple forward and backward pass through the model, where the numbers indicate the model layers:</p>
1207
 
1208
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1209
  <p>An example of Pipeline parallelism for a model with 16 layers distributed across 4 GPUs. The numbers correspond to the layer IDs.</p>
1210
 
1211
 <p>The remaining idle time is indicated in grey and usually called the “bubble”, and the sight of this probably breaks your heart after we spent so much time optimizing throughput.</p>
@@ -1224,7 +1217,7 @@
1224
 
1225
 <p>Let’s take a first tool out of our toolbox and think about splitting our batch into smaller bite-sized portions which can be processed (almost) in parallel, like we did before in data parallelism for instance. Now when the second GPU is busy processing micro-batch 1, the first GPU can already start processing micro-batch 2. Here is a schedule using 8 micro-batches:</p>
1226
 
1227
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1228
 
1229
 <aside>Previously the numbers in the diagram indicated layers, but in all pipeline parallel plots from now on, including this one, they indicate microbatches. You can think of each square here as containing several layers, as seen in the previous figure. </aside>
1230
 
@@ -1257,11 +1250,12 @@
1257
 
1258
 <p>This schedule is called <strong><em>one-forward-one-backward (1F1B)</em></strong> as the middle/steady state involves alternately performing one forward and one backward pass. The general idea is to start performing the backward pass as soon as possible. The schedule looks like this:</p>
1259
 
1260
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1261
 
1262
 <p>The bubble still has the same size so our training efficiency is not significantly improved. However we only need to store activations for <d-math>p</d-math> micro-batches instead of <d-math>m</d-math>, which greatly reduces the activation memory explosion we had in the AFAB schedule. As a consequence we can add more microbatches, which will then actually reduce the bubble.</p>
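 <p>To get a feel for why more microbatches help, a quick back-of-the-envelope check: if every microbatch costs roughly <d-math>t_f + t_b</d-math> of compute per stage, the idle bubble of these schedules is roughly <d-math>(p-1)(t_f+t_b)</d-math> against <d-math>m(t_f+t_b)</d-math> of useful work, i.e. a relative bubble of about <d-math>\frac{p-1}{m}</d-math>:</p>
 <d-code block language="python">
def bubble_fraction(p: int, m: int) -> float:
    """Approximate idle fraction of the pipeline: (p - 1)*(t_f + t_b) of bubble
    over m*(t_f + t_b) of useful compute, i.e. (p - 1) / m."""
    return (p - 1) / m

for m in (4, 8, 32, 128):
    print(f"p=4 stages, m={m:3d} microbatches -> bubble ≈ {bubble_fraction(4, m):.1%}")
 </d-code>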
1263
 
1264
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
 
1265
 
1266
 <p>A major complexity of this setup, visible in the above graph, is that forward and backward passes are not cleanly consecutive anymore but performed in parallel across devices. This means we will have to schedule the switch from forward to backward passes independently on each device instead of in a simple and common central training loop as usual.</p>
1267
 
@@ -1292,7 +1286,7 @@
1292
 
1293
  <p>This can be seen in general as a kind of “looping pipeline” where a micro-batch will move in circles from one GPU to the next as it goes through the forward pass through the model.</p>
1294
 
1295
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1296
 
1297
 <p>As a consequence we see additional communications happening as the model goes several times through each GPU for the same computation that previously just took one pass. However, each forward and backward pass is divided by a factor of <d-math>v</d-math>, where <d-math>v</d-math> is the number of stages or model chunks per GPU, as we are able to better interleave forward and backward passes. </p>
1298
 
@@ -1307,14 +1301,14 @@
1307
 
1308
 <p>So we can now decrease the bubble by adding microbatches and interleaved stages, but note that quantitatively, the amount of communication also increases by a factor of <d-math>v</d-math>, so it’s a trade-off. In the following plot you can see several configurations for a PP setup with <d-math>p=8</d-math>, where the special case of <d-math>m=1, v=1</d-math> corresponds to naive pipeline parallelism, the configurations with <d-math>v=1</d-math> are AFAB or 1F1B setups, and <d-math>v \neq 1</d-math> are interleaved configurations.</p>
1309
 
1310
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1311
 
1312
 
1313
 <p>Scheduling also becomes more complex here, as we need to decide on a given GPU and at a given moment whether we prioritize earlier micro-batches, meaning that we close the forward and backward loops as fast as possible (so-called “depth-first”, i.e. prioritizing getting batches out of the model as fast as possible), or whether we prioritize first completing the forward passes of all microbatches in the queue before going over to backward passes (so-called “breadth-first”, i.e. prioritizing filling in the pipeline as much as possible). This is explained in detail in the "Breadth-First Pipeline" paper<d-cite bibtex-key="lamypoirier2023breadthfirstpipelineparallelism"></d-cite>.</p>
1314
 
1315
 <p>You now have all the elements to understand the pipeline parallelism approach in Llama 3.1, which uses a one-forward-one-backward setup with interleaved stages and a priority setting tuneable between depth-first and breadth-first.</p>
1316
 
1317
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1318
 
1319
 <p>However, we haven’t reached the end of possible pipeline schedules and recently some methods have been proposed to reduce the bubble to virtually zero! Piqued your curiosity? Let’s have a look!</p>
1320
 
@@ -1323,14 +1317,15 @@
1323
 <p>There are even more sophisticated ways to reduce the bubble and reach close to a “zero bubble” regime. The secret here is to split the operations involved at an even finer-grained level in order to interleave them in the most efficient way. For instance, the pipeline implementation approach in DeepSeek V3/R1, called DualPipe, reaches close to a zero-bubble regime.</p>
1324
 
1325
 <p>Let’s very quickly see how this can work by briefly detailing the ZeroBubble<d-cite bibtex-key="qi2023zerobubblepipelineparallelism"></d-cite> work, which is a precursor to DualPipe. The base observation of ZeroBubble is that the backward pass through a matrix multiplication actually involves two separate operations: the backward for the inputs (B) and the backward for the weights (W):</p>
1326
-
1327
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
 
1328
 
1329
 <p>While the output of B, the backward pass for the input, is necessary for performing the backward pass of the lower layers, the backward pass for the weights, W, is not necessary for the rest of the backward pass and generally only needs to be performed before the optimizer step. This means W can be flexibly scheduled anywhere after the corresponding B of the same stage. This allows for strategic placement of W to fill the pipeline bubbles. The ZB-H2 schedule on the top right is an example of a (theoretical) schedule with zero bubble taking advantage of this fine-grained decomposition.</p>
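 <p>In PyTorch terms, the B/W split can be sketched on a toy layer with two separate <code>torch.autograd.grad</code> calls: one for the input gradient (which the previous pipeline stage is waiting for) and one for the weight gradients (which can be deferred, e.g. into a bubble):</p>
 <d-code block language="python">
import torch
from torch import nn

x = torch.randn(4, 16, requires_grad=True)
layer = nn.Linear(16, 16)
loss = layer(x).sum()

# B: backward w.r.t. the input only -- this is what the previous pipeline stage needs right away
(grad_x,) = torch.autograd.grad(loss, x, retain_graph=True)

# W: backward w.r.t. the weights -- can be scheduled later, as long as it happens
# before the optimizer step
grad_w, grad_b = torch.autograd.grad(loss, (layer.weight, layer.bias))
 </d-code>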
1330
 
1331
 <p>DeepSeek’s DualPipe, introduced with V3, extends this decomposed approach to the case of two streams propagating from both ends of the PP ranks, with these streams interleaved to minimize even further the idle time on the GPUs, as displayed in the following scheduling graph:</p>
1332
 
1333
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1334
 
1335
 <p>The ZeroBubble and DualPipe schedules are a bit too complex for us to give code snippets here, but you should start to have a general idea of the concepts involved. In practice, optimizing these schedules requires careful measurement of the time for each operation, followed by a scheduling algorithm able to find the most optimal allocation of time given the constraints. See for instance the ZeroBubble paper<d-cite bibtex-key="qi2023zerobubblepipelineparallelism"></d-cite> for a discussion of the heuristics and algorithms used to perform such scheduling.</p>
1336
 
@@ -1341,7 +1336,7 @@
1341
 
1342
  <p>Mixture-of-expert models have gained some traction with models such as Mixtral<d-cite bibtex-key="jiang2024mixtralexperts"></d-cite> or more recently DeepSeek-V3/R1! The basic idea is that instead of having a single feedforward module per layer we can have several and route tokens through different ones depending on their context:</p>
1343
 
1344
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1345
  <p>Source: A Survey on Mixture of Experts<d-cite bibtex-key="cai2024surveymixtureexperts"></d-cite> </p>
1346
 
1347
 <p>This design makes it very easy to add a new parallelism paradigm: Expert parallelism (EP). Since the feedforward layers are fully independent we can simply put each expert’s feedforward layer on a different worker. Compared to TP it’s much more lightweight, since we don’t need to split the matrix multiplication; we just need to route the hidden states of a token to the right expert. There are several tricks to make EP work in practice, closely tied to model design. For instance, DeepSeek-V3 enforces a constraint in the router, ensuring that each token is sent to at most M nodes (in their case, 4) to reduce communication overhead.</p>
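 <p>A toy, single-device sketch of the routing idea (top-1 routing, hypothetical layer sizes): with expert parallelism the experts below would live on different workers, and the masked indexing would be replaced by an all-to-all dispatch of hidden states:</p>
 <d-code block language="python">
import torch
from torch import nn

class ToyMoELayer(nn.Module):
    """Toy top-1 MoE layer on a single device: the router picks one expert per token and
    only that expert's feedforward processes the token."""
    def __init__(self, hidden: int, num_experts: int):
        super().__init__()
        self.router = nn.Linear(hidden, num_experts)
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(hidden, 4 * hidden), nn.GELU(), nn.Linear(4 * hidden, hidden))
            for _ in range(num_experts)
        )

    def forward(self, x):                      # x: (tokens, hidden)
        expert_ids = self.router(x).argmax(dim=-1)
        out = torch.zeros_like(x)
        for i, expert in enumerate(self.experts):
            mask = expert_ids == i             # tokens routed to expert i
            if mask.any():
                out[mask] = expert(x[mask])    # "dispatch" + expert FFN + "combine"
        return out

layer = ToyMoELayer(hidden=64, num_experts=4)
print(layer(torch.randn(10, 64)).shape)        # torch.Size([10, 64])
 </d-code>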
@@ -1474,7 +1469,7 @@
1474
 
1475
  <p>And to have an idea of the memory benefits of each parallelism:</p>
1476
 
1477
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1478
 
1479
  <h2>How to Find the Best Training Configuration</h2>
1480
 
@@ -1633,12 +1628,12 @@
1633
 
1634
  <p>On the compute side, GPUs consist of an array of compute units called <strong>Streaming Multiprocessors</strong> (SM). Each SM contains and controls a set of streaming processors, also known as cores. For example, an Nvidia H100 GPU has 132 SMs with 128 cores per SM, resulting in a total of 16,896 cores (see <a href="https://resources.nvidia.com/en-us-tensor-core">docs for tensor cores</a> for details), each capable of handling multiple threads simultaneously.</p>
1635
 
1636
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1637
  <p>TODO: Original figure from https://blog.codingconfessions.com/p/gpu-computing.</p>
1638
 
1639
 <p>The memory side is also highly hierarchical, with several layers of cache and memory: <strong>Registers</strong> are the smallest units and are private to the threads during execution; <strong>Shared Memory</strong> and the <strong>L1 cache</strong> are shared between the threads running on a single SM; higher up is the <strong>L2 cache</strong>, shared by all SMs; finally there is the <strong>Global Memory</strong>, which is the largest memory on the GPU (the advertised 80 GB for an H100, for instance) but also the slowest to access and query.</p>
1640
 
1641
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1642
  <p>TODO: Original figure from https://www.youtube.com/watch?v=ZQKMZIP3Fzg</p>
1643
 
1644
 <p>The goal of the GPU is to run as many workloads as possible, in parallel, on the GPU cores, by taking advantage of this hierarchical organization of compute/memory.</p>
@@ -1791,16 +1786,17 @@
1791
 
1792
  <p>Here’s an excellent visualization of the kernel from this <a href="https://siboehm.com/articles/22/CUDA-MMM">fantastic blogpost</a>: </p>
1793
 
1794
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1795
 
1796
  <p>However, when profiling this kernel with a tool like <code>ncu</code>, we can see issues, including low memory throughput and uncoalesced memory accesses.</p>
1797
 
1798
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
 
1799
 
1800
 
1801
  <p>The reason for this is that in this kernel, two threads in the same block with Thread IDs <code>(0, 0)</code> and <code>(1, 0)</code> (which will end up in the same warp) will both load from the same column of matrix <code>B</code> but different rows of matrix <code>A</code>. Since matrix elements are stored in row-major order (meaning each row's elements are in consecutive memory addresses, as shown in the figure below), in the first iteration with <code>i = 0</code>, thread <code>(0, 0)</code> will load <d-math>A_{0,0}</d-math>, and thread <code>(1, 0)</code> will load <d-math>A_{1,0}</d-math>. These elements are not stored close to each other in memory, and this misalignment repeats across all iterations along the shared dimension, preventing memory accesses from being coalesced.</p>
1802
 
1803
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1804
 
1805
 
1806
 <p>To improve our kernel we can change the way the coordinates x and y are calculated, as follows: </p>
@@ -1822,7 +1818,7 @@
1822
 
1823
  <p>When we profile our new kernel, we notice that the warning about uncoalesced memory accesses has disappeared, and <strong>the GPU's memory throughput has increased by approximately 10 times</strong>.</p>
1824
 
1825
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1826
 
1827
 
1828
 <p>We also notice that the execution time of the kernel <strong>decreases by 10x</strong>!</p>
@@ -1838,7 +1834,7 @@
1838
 
1839
  <p>In the tiling approach, each iteration involves all threads within a block cooperatively loading two tiles—one from matrix A and another from matrix B —into shared memory. Specifically, threads load a tile of matrix A (of size <code>BLOCK_SIZE_M</code> by <code>BLOCK_SIZE_K</code>) and a tile of matrix B (of size <code>BLOCK_SIZE_K</code> by <code>BLOCK_SIZE_N</code>). Once the tiles are in shared memory, the threads perform matrix multiplication on these tiles, enabling efficient computation since all necessary data is quickly accessible. The results of the tile multiplication are stored in an accumulation matrix that holds intermediate results. After each iteration, the results from the current tile multiplication are added to this accumulation matrix, continuing until all tiles from both matrices have been processed.</p>
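 <p>The actual kernel is CUDA, but the loop structure of tiling is easy to mirror in a few lines of Python (a conceptual sketch only, with the shared-memory tiles played by small tensor slices):</p>
 <d-code block language="python">
import torch

def tiled_matmul(A, B, block: int = 32):
    """Conceptual analogue of the tiled kernel: iterate over K in tiles, load one tile of A
    and one tile of B at a time (the role of shared memory), and accumulate their product
    into the output block instead of streaming whole rows/columns at once."""
    M, K = A.shape
    K2, N = B.shape
    assert K == K2
    C = torch.zeros(M, N, dtype=A.dtype)
    for i in range(0, M, block):
        for j in range(0, N, block):
            acc = torch.zeros(min(block, M - i), min(block, N - j), dtype=A.dtype)
            for k in range(0, K, block):
                a_tile = A[i:i + block, k:k + block]   # tile of A in "shared memory"
                b_tile = B[k:k + block, j:j + block]   # tile of B in "shared memory"
                acc += a_tile @ b_tile                 # accumulate partial tile results
            C[i:i + block, j:j + block] = acc
    return C

A, B = torch.randn(100, 70), torch.randn(70, 90)
assert torch.allclose(tiled_matmul(A, B), A @ B, atol=1e-3)
 </d-code>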
1840
 
1841
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1842
  <p>From https://cnugteren.github.io/tutorial/pages/page4.html</p>
1843
 
1844
 <p>The important parts for understanding the implementation are below (for simplicity we consider a square-shaped tile): </p>
@@ -1883,7 +1879,7 @@
1883
 
1884
  <p>The tiling technique has significantly improved the performance of our kernel. However, when analyzing the warp states which quantify how many cycles were spent in each state, we observe the following:</p>
1885
 
1886
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1887
 
1888
 
1889
  <p>The meaning of the states can be found in the <a href="https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference">Profiling Guide</a>, specifically in the <strong>Warp Stall Reasons</strong> section. There we can read that:</p>
@@ -1905,11 +1901,16 @@
1905
  <p>In several places now we’ve mentioned how GPU and CPU operation can be asynchronous. In particular, the host code on the CPU can schedule workload on the GPU in a non-blocking way.</p>
1906
 
1907
 <p>Non-blocking can be useful for overlapping communication and computation, as we saw at several points along this blog post, but it can be extended to the more general idea of trying to avoid, at all costs, going back and forth between host and GPU kernel commands. This is beautifully illustrated by <a href="https://horace.io/brrr_intro.html">Horace He</a> in these diagrams:</p>
1908
-
1909
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1910
- <p>A sequence of kernels requiring back and forth between global memory and compute units</p>
1911
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1912
- <p>Instead of sending our triangle back to global memory just to read it back again, we instead just do all of our operations in one go.</p>
 
 
 
 
 
1913
 
1914
  <p>How can we avoid this back and forth? Well the best way is to make our GPU as autonomous as possible. This is achieved by packing as many successive compute operations together in a single kernel for the GPU to run, called a “Fused Kernel”.</p>
1915
 
@@ -1926,13 +1927,13 @@
1926
 
1927
 <p>A basic implementation of the attention mechanism involves a lot of transfers between memory and workers. It requires materializing the S and P matrices in HBM, which means that the results need to be sent to HBM and then back to SRAM for the next computations:</p>
1928
 
1929
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1930
 
1931
  <p>Since bandwidth is much lower in HBM this introduces a severe bottleneck in the attention computation. Can we do better? Tri Dao says yes!</p>
1932
 
1933
 <p>The key element is to compute the S matrix in small pieces which can fit in the smaller shared memory of the SM. But we can do even better and avoid materializing the very large S matrix altogether, in favor of keeping only the necessary statistics for computing the normalization factor of the softmax. So we can compute part of <d-math>O</d-math> directly in one computation in SRAM rather than moving intermediate results back and forth. In this case, not only do we make use of the shared memory, but we also release the memory bottleneck resulting from materializing one of the largest activation matrices in the model (at long context length), the attention matrix.</p>
1934
 
1935
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1936
  <p>From the FLASH-ATTENTION paper<d-cite bibtex-key="dao2022flashattention"></d-cite></p>
1937
 
1938
  <p>The idea of flash attention resolves so many bottlenecks in model training that it has quickly become the default way to perform attention in all transformers:</p>
@@ -2018,14 +2019,14 @@
2018
 
2019
 <p>Reducing the total number of bits comes at a price (no free lunch here either), but we have some control over how to pay: we can sacrifice more bits on either the mantissa or the exponent. For this reason there also exist two float8 formats, named according to exponent and mantissa, to flexibly choose the most appropriate format. We can look at the possible range of numbers for each format:</p>
2020
 
2021
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
2022
 
2023
 
2024
 <p>We can see that float32 spans 80 orders of magnitude and float16 sacrifices a lot of range while bfloat16 maintains the full range. The two float8 formats reduce the range even further: e5m2 can maintain the float16 range while e4m3 has an even smaller range.</p>
2025
 
2026
 <p>How come some formats are able to maintain the range and others are not? Let’s investigate the resolution by plotting 10,000 points between 1 and 2. Each point will be rounded to the nearest representable number in each format:</p>
2027
 
2028
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
2029
 
2030
 <p>We can see here that bfloat16 maintained the range of float32 (unlike float16), but did this at the cost of sacrificing more precision. In the case of float8 the situation is even more dire: e4m3 can represent only 7 and e5m2 only 3 numbers on the interval 1-2.</p>
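 <p>This is easy to reproduce, assuming a recent PyTorch build that exposes the float8 dtypes (<code>torch.float8_e4m3fn</code> and <code>torch.float8_e5m2</code>): a small sketch that round-trips 10,000 points through each format and counts the distinct values landing strictly inside the interval:</p>
 <d-code block language="python">
import torch

xs = torch.linspace(1, 2, 10_000, dtype=torch.float32)
for name, dtype in [("bfloat16", torch.bfloat16),
                    ("float16", torch.float16),
                    ("float8_e4m3", torch.float8_e4m3fn),
                    ("float8_e5m2", torch.float8_e5m2)]:
    vals = xs.to(dtype).to(torch.float32).unique()       # round-trip through the low-precision format
    interior = ((vals > 1) & (vals < 2)).sum().item()    # representable values strictly between 1 and 2
    print(f"{name:>12}: {interior} representable values strictly between 1 and 2")
 </d-code>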
2031
 
@@ -2063,7 +2064,7 @@
2063
 
2064
  <p>The first, successful, very large scale training with FP8 mixed precision was publicly reported on DeepSeek-V3. The authors carefully analyzed each operation of the forward pass (Fprop) as well as the activation (Dgrad) and weight (Wgrad) backward pass. Similar to BF16 mixed precision training, some aggregation and master weights are kept in higher precision while the operations themselves are performed in FP8. </p>
2065
 
2066
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
2067
 
2068
  <p>In order to switch from high precision (e.g. FP32 or BF16) to lower precision (e.g. FP16 or FP8) with smaller range, we need to normalize the range of values by computing the absolute maximum. DeepSeek-V3 also introduces a quantization scheme, where the ranges are normalized per tile: 1x128 for inputs/activations and 128x128 for weights and scale elements. This makes the normalization less susceptible to outliers. There is a number of additional tricks they deploy to also reduce the memory and communication footprint which you can follow in section 3.3. of the DeepSeek-V3 technical report<d-cite bibtex-key="deepseekai2024deepseekv3technicalreport"></d-cite>. </p>
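 <p>As a rough sketch of what such absmax, tile-wise scaling looks like (the 128x128 tile shape follows the description above; the e4m3 cast and the immediate round-trip back to float32 for inspecting the quantization error are our own illustrative choices):</p>
 <d-code block language="python">
import torch

def absmax_quantize_per_tile(w: torch.Tensor, tile: int = 128, fp8_max: float = 448.0):
    """Tile-wise absmax scaling sketch: each tile is scaled by its absolute maximum so it
    fits the FP8 range, cast to e4m3 and (here) immediately dequantized so we can inspect
    the error. fp8_max = 448 is the largest value representable in e4m3."""
    dequant = torch.empty_like(w)
    scales = torch.zeros(w.shape[0] // tile, w.shape[1] // tile)
    for i in range(0, w.shape[0], tile):
        for j in range(0, w.shape[1], tile):
            block = w[i:i + tile, j:j + tile]
            scale = block.abs().max() / fp8_max               # normalize the tile by its absmax
            q = (block / scale).to(torch.float8_e4m3fn)       # FP8 storage format
            dequant[i:i + tile, j:j + tile] = q.to(torch.float32) * scale
            scales[i // tile, j // tile] = scale
    return dequant, scales

w = torch.randn(256, 256)
w_deq, scales = absmax_quantize_per_tile(w)
print((w - w_deq).abs().max())        # worst-case per-element quantization error
 </d-code>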
2069
 
@@ -2157,7 +2158,7 @@
2157
 
2158
  <p>Congratulations! You've completed quite a journey - from understanding how to train a simple model on a single GPU, all the way to mastering the complex techniques used to efficiently train massive language models like Llama-405B and DeepSeek-V3. By now, you should feel confident interpreting advanced parallelism diagrams like the one below, which would have seemed daunting when you first started.</p>
2159
 
2160
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
2161
 
2162
  <p>In distributed training, many concepts sound easy enough when you first hear them, like “Pipeline parallelism just distributes layers on different GPUs”, but we also worked through all the challenging details when implementing those methods. </p>
2163
 
@@ -2207,12 +2208,14 @@
2207
 
2208
  <p>First, let's examine this heatmap visualization:</p>
2209
 
2210
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
2211
  <p>Heatmap visualization showing the optimal training configurations across different model sizes and compute node counts. For each combination, the configuration details include Data Parallelism (DP), Tensor Parallelism (TP), Pipeline Parallelism (PP), Gradient Accumulation Steps (GAS), Micro Batch Size (MBS), and ZeRO optimization stage. The color intensity indicates the Model FLOPs Utilization (MFU), with brighter colors representing higher efficiency.</p>
2212
 
2213
  <p>To complement this, let's look at the relationships between different parameters:</p>
2214
 
2215
- <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
 
 
2216
  <p>Parallel coordinates plot showing the relationship between different model parallelism configurations (Data Parallel degree, Tensor Parallel degree, Pipeline Parallel degree), training hyperparameters (gradient accumulation steps, micro batch size), ZeRO stage and the resulting Model FLOPs Utilization (MFU). Each line represents a different training configuration, with colors indicating the MFU value - warmer colors show higher efficiency.</p>
2217
 
2218
  <p>From these visualizations, we can draw several important insights:
 
297
 
298
  <p>Using this snippet [TODO: link to appendix A5], we can understand how memory is allocated throughout training. We can see that memory utilization is not a static thing but varies a lot during training and during a training step:</p>
299
 
 
 
 
 
300
  <div class="svg-container l-body-outset" id="svg-first_steps_memory_profile"> </div>
301
  <div class="info" id="svg-first_steps_memory_profile-info">Hover over the elements to see their details</div>
302
  <script src="../assets/images/first_steps_memory_profile.js"></script>
 
303
 
304
  <iframe id="plotFrame" src="assets/data/benchmarks/memory-profile.html" height="520" width="1000" scrolling="no" frameborder="0"></iframe>
305
 
 
416
 
417
<p>An interesting observation here is that the memory is not static for a given model but scales linearly with both the sequence length and the batch size. This means the activation memory is the part which will blow up when we increase our batch size or train with longer sequences. We can use this equation to look at how memory usage changes for various sequence lengths, for example for Llama models (<code>bs=1</code>):</p>
418
 
 
419
  <iframe class="l-body-outset" id="plotFrame2" src="assets/data/benchmarks/memusage_activations.html" width="90%" scrolling="no" frameborder="0"></iframe>
420
  <script>
421
  window.addEventListener('load', function() {
 
440
  <div class="svg-container" id="svg-activation_recomputation"> </div>
441
  <div class="info" id="svg-activation_recomputation-info">Hover over the network elements to see their details</div>
442
  <script src="../assets/images/activation_recomputation.js"></script>
 
443
  <p>There are several strategies to select key activations to store:</p>
444
 
445
  <ul>
 
498
 
499
<p>Gradient accumulation allows us to effectively increase our batch size up to infinity (and beyond!) while the memory footprint stays constant. Gradient accumulation is also compatible with activation recomputation for further memory reduction. One drawback, however, is that gradient accumulation requires multiple consecutive forward/backward passes per optimization step, thereby increasing the compute overhead and slowing down training. No free lunch! </p>
500
 
501
+ <p><img alt="image.png" src="/assets/images/gradaccumulation_diag.png" /></p>
502
 
503
<aside>Using gradient accumulation means we need to keep buffers where we accumulate gradients which persist throughout a training step. Without gradient accumulation, gradients are computed in the backward pass while the activation memory is freed, which means a lower peak memory.</aside>
504
 
 
517
 
518
  <p>Using a different micro batch for each GPU means we’ll have different gradients in each GPU, so to keep the model instances in sync across different GPUs, the gradients from the model instances are averaged using an operation called “all-reduce”, which happens during the backward pass, before the optimizer step.</p>
519
 
520
+ <p><img alt="image.png" src="/assets/images/dp_diagram.png" /></p>
521
 
522
  <aside>If you are not familiar with distributed communications patterns like broadcast, gather or all-reduce we put together a small crash course in the Appendix [TODO Link].</aside>
523
 
524
<p>This involves our first “distributed communication” primitive: <em><strong>all-reduce</strong></em> which handles the synchronization and communication between GPU instances and nodes.</p>
525
 
526
+ <p><img alt="image.png" src="/assets/images/dp_overlap1.svg" /></p>
527
 
528
<p>A naive DP implementation would just wait for the backward pass to finish so that we have all gradients, and then trigger an all-reduce over all DP ranks to sync these gradients. But such sequential steps of computation followed by communication are <strong>A BIG NO!</strong> Because we don’t want our GPUs to stay idle while communication is happening.</p>
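<p>To make the naive version concrete, here is a small sketch (assuming <code>torch.distributed</code> has already been initialized and the model is replicated on every DP rank) that you would call after <code>loss.backward()</code> and before <code>optimizer.step()</code>:</p>
<d-code block language="python">
import torch.distributed as dist

def naive_dp_grad_sync(model, dp_group=None):
    # Naive approach: wait until the backward pass has produced *all* gradients,
    # then average each of them across the data-parallel ranks in one blocking sweep.
    world_size = dist.get_world_size(dp_group)
    for p in model.parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM, group=dp_group)
            p.grad /= world_size
</d-code>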
529
 
 
549
  if p.requires_grad is True:
550
  p.register_post_accumulate_grad_hook(hook)</d-code>
551
 
552
+ <p><img alt="image.png" src="/assets/images/dp_overlap2.svg"/></p>
553
 
554
<p>Overlapping computation and communication reduces the time spent waiting for gradient synchronization across the entire model. Gradient synchronization can occur (at least partially) in parallel with the backward pass, significantly speeding up data parallelism. Here's a full implementation of naive DP with synchronization overlap:</p>
555
 
 
583
  </div>
584
  </details>
585
 
586
+ <p><img alt="dp_overlap3.svg" src="/assets/images/dp_overlap3.svg" /></p>
587
 
588
  <h4><strong>Third optimization: </strong>Interplay with gradient accumulation</h4>
589
 
 
643
 
644
  <p>While data parallelism cleverly overlaps the all-reduce gradient synchronization with backward computation to save time, this benefit starts to break down at large scales. As we add more and more GPUs (hundreds or thousands), the overhead of coordinating between them grows significantly. The end result? We get less and less efficient returns from each additional GPU we add to the system:</p>
645
 
646
+ <p><img alt="image.png" src="/assets/images/dp_scaling.svg"/></p>
647
 
648
  <p>As expected, we can also see that the memory usage per GPU is not affected by adding more DP ranks for training.</p>
649
 
 
651
 
652
<p>The keen reader has already probably noted however that this assumes that we can fit at least one input sample forward pass (<em>mbs=1</em>) into our GPU memory. This is not always the case! As we can see, larger models don’t fit into a single GPU, even with activation recomputation activated: </p>
653
 
654
+ <p><img alt="dp_ourjourney_memoryusage.svg" src="/assets/images/dp_ourjourney_memoryusage.svg" /></p>
655
 
656
<aside>Tip: you can quickly eyeball the minimal memory required for your model’s parameters by multiplying the number of parameters by 2 bytes, e.g. 70B → 140GB (=133GiB)</aside>
657
 
 
697
 
698
  <p>The idea of ZeRO is to shard these objects across the DP ranks, each node only storing a slice of the items which are reconstructed when and if needed, thereby dividing memory usage by the data parallel degree <d-math>N_d</d-math>:</p>
699
 
700
+ <p><img alt="zero_memory.svg" src="/assets/images/zero_memory.svg" /></p>
701
  <p>Memory consumption of DP and three stages of Zero-DP. <d-math>\Psi</d-math> denotes number of parameters, <d-math>k</d-math> denotes the memory multiplier of optimizer states (<d-math>k=12</d-math> for Adam), and <d-math>N_d</d-math> denotes DP degree.</p>
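<p>To make these formulas concrete, here is a quick back-of-the-envelope calculation, a sketch assuming bf16 parameters and gradients (2 bytes each), <d-math>k=12</d-math>, a 7B-parameter model and <d-math>N_d=64</d-math>:</p>
<d-code block language="python">
PSI = 7e9   # number of parameters (7B model)
K = 12      # optimizer-state multiplier for Adam (fp32 params + momentum + variance)
ND = 64     # data-parallel degree

baseline = (2 + 2 + K) * PSI           # bf16 params + bf16 grads + optimizer states
zero1 = (2 + 2) * PSI + K * PSI / ND   # shard the optimizer states
zero2 = 2 * PSI + (2 + K) * PSI / ND   # also shard the gradients
zero3 = (2 + 2 + K) * PSI / ND         # also shard the parameters

for name, nbytes in [("baseline DP", baseline), ("ZeRO-1", zero1),
                     ("ZeRO-2", zero2), ("ZeRO-3", zero3)]:
    print(f"{name}: {nbytes / 2**30:.1f} GiB per GPU")
</d-code>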
702
 
703
 
 
723
 
724
  <p>See the figure below for all the necessary steps in one forward/backward pass cycle:</p>
725
 
726
+ <p><img alt="dp_zero1.gif" src="/assets/images/dp_zero1.gif" /></p>
727
 
728
  <p>So in practice, compared to vanilla DP, Zero-1 adds an all-gather over all parameters after the optimizer step as we can see below:</p>
729
 
730
+ <p><img alt="dp_zero1_overlap.svg" src="/assets/images/dp_zero1_overlap.svg" /></p>
731
 
732
  <p>If you've been following along, you'll recall from vanilla DP that we can overlap the all-reduce gradient communication with the backward pass computation. In ZeRO-1, we can also investigate how to efficiently overlap the newly added all-gather of bf16 parameters. There are two main strategies for this:</p>
733
 
 
751
 
752
  <aside>In case of FP32 gradient accumulation, we only need to keep <d-math>\frac{1}{N_d}</d-math> fp32_grads where we accumulate the bf16 grads coming from the reduce-scatter. And in the optimizer step we use the <d-math>\frac{1}{N_d}</d-math> fp32_grads.</aside>
753
 
754
+ <p><img alt="dp_zero2.gif" src="/assets/images/dp_zero2.gif" /></p>
755
 
756
<p>It’s easy to see now that sharding the gradients leads to <d-math>2\Psi + \frac{2\Psi+k\Psi}{N_d}</d-math>, and as <d-math>N_d</d-math> is increased we can save up to 8x memory over the baseline. In terms of communication the same process applies as for ZeRO-1, with the only difference that we communicate and release on the fly. In total, ZeRO-2 is thus also equivalent to vanilla DP training w.r.t. communication.</p>
757
 
758
<p>In terms of communication ZeRO-2 is similar to ZeRO-1: they both require a reduce-scatter for the gradients and an all-gather over all parameters.</p>
759
 
760
+ <p><img alt="dp_zero2_overlap.svg" src="/assets/images/dp_zero2_overlap.svg" /></p>
761
 
762
  <aside>Note: You might notice that there is no real overhead of using ZeRO-2 over ZeRO-1 and indeed ZeRO-2 is usually the best option.</aside>
763
 
 
776
 
777
  <p>So how do we do a forward or backward pass in practice if all parts of the model are distributed? Quite simply we gather them on-demand when we need them. In the forward pass this looks as follows:</p>
778
 
779
+ <p><img alt="dp_zero3_fwd.svg" src="/assets/images/dp_zero3_fwd.svg" /></p>
780
 
781
+ <p>So, as we perform the forward pass and sequentially go through the layers, we retrieve the necessary parameters on demand and immediately flush them from memory when we don't need them anymore. The backward pass works the same way, just inverted in flow, and we produce the gradient shards: </p>
782
 
783
+ <p><img alt="dp_zero3_bwd.svg" src="/assets/images/dp_zero3_bwd.svg" /></p>
784
 
785
+ <p>The other issue is that we need to do these all-gathers continuously throughout the forward and backward step, which amounts to <d-math>2\cdot \text{num\_layers} -1</d-math> additional all-gathers in <strong>a training step</strong> compared to Zero-2, each of which comes with a small <strong>base latency</strong> overhead as we can see in the following figure:</p>
786
+
787
+ <p><img alt="dp_zero3_overlap.svg" src="/assets/images/dp_zero3_overlap.svg" /></p>
788
 
789
<p>During the forward pass we do all-gather operations for the parameters when we need them, so a <d-math>\Psi</d-math> communication tax. Since we discard the parameters immediately after using them in the forward pass, we need one more all-gather during the backward pass as well, incurring another <d-math>\Psi</d-math> in communication tax. Finally we need the same <strong><em>reduce-scatter</em></strong> as in ZeRO-2 for the gradients which also costs <d-math>\Psi</d-math> in communication, and we arrive at a total communication cost of <d-math>3\Psi</d-math>, compared to <d-math>2\Psi</d-math> for Zero-2.</p>
790
 
 
799
 
800
<p>However, there is a limit here: DP only works if a layer of the model fits in a single GPU, and ZeRO can only partition the parameters, gradients, and optimizer states, but not the activation memory! Recall from the activation memory discussion that it scales with sequence length and batch size. Naturally we could just limit those, but in practice we don’t want to be limited by hardware to train only with a short sequence length. </p>
801
 
802
+ <p><img alt="zero3_memoryusage.svg" src="/assets/images/zero3_memoryusage.svg" /></p>
803
 
804
  <p>Now that we've efficiently used the DP axis to reduce memory through efficient communication patterns, let's explore a new, orthogonal axis of parallelism - Tensor Parallelism. Unlike ZeRO3 that relies on heavy parameter communication, TP manages to shard parameters, gradients, optimizer states AND activations across devices without requiring any model parameter movement between GPUs. What! How is this even possible?! Let's explore this seemingly magical approach together! 🙂</p>
805
 
 
825
 
826
  <p>In practice a small example of the operation looks like this:</p>
827
 
828
+ <p><img alt="image.png" src="/assets/images/tp_diagram.png" /></p>
829
 
830
<p>Let’s see how we can parallelise this operation! In tensor parallelism, tensors will be split into N shards along a particular dimension and distributed across N GPUs. Matrices can be split along either their columns or their rows, leading to column and row parallelism. One thing we’ll see in the following is that choosing row or column sharding will require different communication primitives.</p>
831
 
832
  <p>Our first option is to use column-wise sharding (also called <strong><em>column-linear</em></strong>): We'll copy the complete input matrices to each worker, requiring an operation called <strong><em>broadcast</em></strong>, and split the weight matrix into columns. The inputs are then multiplied with the partial weight matrices, and the results are finally combined using an <strong><em>all-gather</em></strong> operation.</p>
833
 
834
+ <p><img alt="image.png" src="/assets/images/tp_diagram2.png" /></p>
835
 
836
  <p>Here's the code implementation of column wise tensor parallelism:</p>
837
 
 
848
 
849
  <p>We see here our fourth distributed primitive: <strong><em>scatter</em></strong>!</p>
850
 
851
+ <p><img alt="image.png" src="/assets/images/tp_diagram3.png" /></p>
852
 
853
  <p>Here's the implementation for row-wise tensor parallelism:</p>
854
 
 
869
 
870
  <p>The Feedforward part can be parallelized by having a “Column linear” followed by a “Row Linear” which amounts to a broadcast to copy the input and an all-reduce in forward. Note that the broadcast isn’t needed in actual training where we can make sure inputs are already synced across TP ranks.</p>
871
 
872
+ <p><img alt="image.png" src="/assets/images/tp_diagram4.png" /></p>
873
 
874
  <p>Now that we’ve found the most efficient schema for the Feedforward part of the transformer, let’s take a look at the multi-head attention block (MHA).</p>
875
 
 
877
 
878
<p>It's also worth noting that the tensor parallelism degree should not exceed the number of Q/K/V heads because we need intact heads per TP rank. And in case we’re using GQA, the TP degree should be below the number of K/V heads, otherwise it requires additional comms to keep them in sync. For instance, LLaMA-3 8B has 8 Key/Value heads, so the tensor parallelism degree should be less than or equal to 8; otherwise, if TP=16 for example, we need to duplicate each K/V head and make sure they stay in sync.</p>
879
 
880
+ <p><img alt="image.png" src="/assets/images/tp_full_diagram.png" /></p>
881
 
882
<p>Finally note that there is a tradeoff in terms of communication as we’ve added several distributed communication primitives directly in the computation path of our model. Unlike ZeRO, where we could prefetch, it can be harder to make these communications fully overlap with computation. </p>
883
 
884
+ <p><img alt="Forward pass in Tensor Parallelism" src="/assets/images/tp_overlap.svg" /></p>
885
 
886
  <p>Looking at the timeline of operations in tensor-parallel MLP (same applies for Attention), we can better understand the tradeoffs involved. In the forward of each decoder layer, we hit a synchronization point with the AllReduce operation that cannot be overlapped with computation. This <em>exposed communication</em> overhead is necessary to combine partial results across tensor-parallel ranks before the final LayerNorm can be applied. </p>
887
 
888
  <p>Tensor parallelism does help reduce activation memory for the matrix multiplications since the intermediate activations are sharded across GPUs. However, we still need to gather the full activations for operations like LayerNorm, which means we're not getting the full memory benefits we could. Additionally, it introduces significant communication requirements that heavily depend on the network infrastructure. The inability to hide this particular AllReduce behind computation means it directly adds to the critical path of forward propagation.</p>
889
 
890
+ <p><img alt="Impact of Tensor Parallelism on model performance and batch size capacity: while increasing TP leads to reduced per-GPU throughput (left), it enables processing of larger batch sizes (right), illustrating the trade-off between computational efficiency and memory availability in distributed training." src="/assets/images/tp_scaling.svg" /></p>
891
 
892
  <p>Impact of Tensor Parallelism on model performance and batch size capacity: while increasing TP leads to reduced per-GPU throughput (left), it enables processing of larger batch sizes (right), illustrating the trade-off between computational efficiency and memory availability in distributed training.</p>
893
 
 
895
 
896
  <p>However, tensor parallelism provides important benefits for memory usage by distributing model parameters, gradients, optimizer states and activations (to some extent) across GPUs. Let's examine this effect on a 70B parameter model:</p>
897
 
898
+ <p><img alt="tp_memoryusage.svg" src="/assets/images/tp_memoryusage.svg" /></p>
899
 
900
  <p>As we can see, increasing tensor parallelism reduces the memory needed for model parameters, gradients and optimizer states on each GPU. While tensor parallelism does help reduce activation memory in attention and feedforward layers by sharding the matrix multiplications across GPUs, we don't get the full memory benefits we could. This is because operations like layer normalization and dropout still require gathering the full activations on each GPU, partially negating the memory savings. We can do better by finding ways to parallelize these remaining operations as well.</p>
901
 
 
935
 
936
  <p><img alt=" in forward: f = no-op ; f* = all-reduce ; g = all-gather ; g* = reduce-scatter
937
  in backward: f = all-reduce ; f* = no-op ; g = reduce-scatter ; g* = all-gather
938
+ SP region needs full hidden_dim" src="/assets/images/tp_sp_diagram.png" /></p>
939
 
940
<p>in forward: f = no-op ; f* = all-reduce ; g = all-gather ; g* = reduce-scatter. In backward: f = all-reduce ; f* = no-op ; g = reduce-scatter ; g* = all-gather. The SP region needs the full hidden_dim.</p>
941
 
 
956
 
957
  <p>For sequence parallelism (SP), we use different operations labeled "g" and "g*". Specifically, we avoid using all-reduce in the SP region since that would require gathering the full activations and increase our peak memory usage, defeating the purpose of SP.</p>
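<p>To make the “g” / “g*” pair concrete, here is a sketch of how “g” can be written as an autograd function, assuming the sequence dimension is the first dimension of the activation tensor (an illustration of the idea, not the exact implementation):</p>
<d-code block language="python">
import torch
import torch.distributed as dist

class Gather(torch.autograd.Function):
    """The 'g' operation: all-gather along the sequence dim in the forward pass,
    reduce-scatter of the gradient in the backward pass (its conjugate 'g*'
    simply swaps the two collectives)."""

    @staticmethod
    def forward(ctx, x, group=None):
        ctx.group = group
        world_size = dist.get_world_size(group)
        out = torch.empty(x.shape[0] * world_size, *x.shape[1:],
                          dtype=x.dtype, device=x.device)
        dist.all_gather_into_tensor(out, x.contiguous(), group=group)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        world_size = dist.get_world_size(ctx.group)
        grad_in = torch.empty(grad_out.shape[0] // world_size, *grad_out.shape[1:],
                              dtype=grad_out.dtype, device=grad_out.device)
        dist.reduce_scatter_tensor(grad_in, grad_out.contiguous(), group=ctx.group)
        return grad_in, None
</d-code>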
958
 
959
+ <p><img alt="image.png" src="/assets/images/tp_sp_diagram_zoomed.png" /></p>
960
 
961
  <p>So what is actually happening here? As a famous LLM would say, let’s take it step-by-step:</p>
962
 
 
1044
 
1045
  <p>By using sequence parallelism, we can achieve even greater activation memory savings, allowing us to push our batch size and sequence length further than what would be possible with tensor parallelism alone. Let's see what that means for our previous 70B model example:</p>
1046
 
1047
+ <p><img alt="tp_sp_memoryusage.svg" src="/assets/images/tp_sp_memoryusage.svg" /></p>
1048
 
1049
<p>Does that mean that SP incurs more communication than TP? Well, yes and no. In the forward of a vanilla TP we had two all-reduce per transformer block, and in SP we have two all-gather and two reduce-scatter per transformer block. So SP does twice the number of communication operations as TP. But since an all-reduce operation can be broken down into an all-gather + reduce-scatter (see in [TODO: Appendix link]) they’re actually equivalent in terms of communication. Same reasoning for backward as we just use the conjugate of each operation (no-op ↔ allreduce and allgather ↔ reducescatter).</p>
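<p>If you want to convince yourself of this equivalence, here is a small sketch that rebuilds an all-reduce from a reduce-scatter followed by an all-gather (assuming the tensor size is divisible by the world size):</p>
<d-code block language="python">
import torch
import torch.distributed as dist

def all_reduce_decomposed(x, group=None):
    world_size = dist.get_world_size(group)
    flat = x.flatten()
    shard = torch.empty(flat.numel() // world_size, dtype=x.dtype, device=x.device)
    dist.reduce_scatter_tensor(shard, flat, group=group)  # each rank sums one slice
    out = torch.empty_like(flat)
    dist.all_gather_into_tensor(out, shard, group=group)  # slices are reassembled everywhere
    return out.view_as(x)  # same result as dist.all_reduce(x)
</d-code>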
1050
 
1051
<p>If you’ve been paying close attention, you’ll notice that we’re talking about 4 comms ops in each layer (2 for Attention and 2 for MLP). This is what the MLP profiling looks like when using Tensor + Sequence Parallelism:</p>
1052
 
1053
+ <p><img alt="tp_sp_overlap.svg" src="/assets/images/tp_sp_overlap.svg" /></p>
1054
 
1055
  <p>Besides the fact that TP requires communications in each layer, it also can’t easily be overlapped with compute, which makes throughput heavily dependent on the communication bandwidth. This is why TP is usually done only within a node (TP≤8).</p>
1056
 
 
1059
 
1060
  <p>As you might expect, this communication overhead becomes increasingly problematic as we scale up tensor parallelism. To illustrate this, let’s check throughput as we scale TP with SP for a 3B model:</p>
1061
 
1062
+ <p><img alt="tp_sp_scaling.svg" src="/assets/images/tp_sp_scaling.svg" /></p>
1063
  <p>Impact of combined Tensor and Sequence Parallelism (TP/SP) on a 3B model’s performance and memory utilization with 4096 seqlen: when scaling both TP and SP together, there's a trade-off between computational efficiency (left) and memory capacity (right). While higher parallelism degrees reduce per-GPU throughput, they enable processing of significantly larger batch sizes by reducing the activation memory.</p>
1064
 
1065
  <p>Let’s summarize our observations:</p>
 
1089
 
1090
  <p>Even if we use full recomputation of the activations, which comes at a heavy compute overhead (30%), we still need to hold in memory some activations at the layer boundaries which scale linearly with sequence length:</p>
1091
 
1092
+ <p><img alt="image.png" src="/assets/images/cp_memoryusage.svg" /></p>
1093
 
1094
<p>Can we apply similar ideas to our sequence parallelism approach, but inside the modules where we already apply Tensor Parallelism, thereby also reducing the effect of sequence length? Yes, it’s time to talk about Context Parallelism, which you will find quite intuitive after all we’ve already covered.</p>
1095
 
 
1097
 
1098
<p>The idea of Context Parallelism is quite simple; just like Sequence Parallelism, we’ll split the input along the sequence dimension but we now apply this splitting along the full model, instead of only the sequence parallel regions of the model as we’ve done previously with Tensor + Sequence Parallelism.</p>
1099
 
1100
+ <p><img alt="cp_8Bmemoryusage.svg" src="/assets/images/cp_8Bmemoryusage.svg" /></p>
1101
 
1102
  <p>Splitting the sequence doesn't affect most modules like MLP and LayerNorm, where each token is processed independently. It also doesn’t require expensive communication like TP, as only the inputs are split and not the weight matrices. Just like data parallelism, after computing the gradients, an all-reduce operation is initiated to synchronize the gradients across the context parallelism group.</p>
1103
 
 
1128
 
1129
  <p>The whole process with 4 GPUs is shown in the following animation:</p>
1130
 
1131
+ <p><img alt="ring-attention.gif" src="/assets/images/ring-attention.gif" /></p>
1132
 
1133
  <p>With this animation, it’s also immediately clear why the authors chose to call this approach Ring Attention.</p>
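<p>For intuition, here is a heavily simplified single-head sketch of the ring attention loop, omitting the causal mask and using a hypothetical <code>ring_send_recv</code> helper for the point-to-point KV exchange; the important part is the online softmax update that lets us consume the KV chunks one at a time:</p>
<d-code block language="python">
import torch
import torch.distributed as dist

def ring_attention(q, k, v, group=None):
    # q, k, v: this rank's chunks of shape [seq_chunk, head_dim]
    world_size = dist.get_world_size(group)
    out = torch.zeros_like(q)
    running_max = torch.full((q.shape[0], 1), float("-inf"), device=q.device)
    running_sum = torch.zeros(q.shape[0], 1, device=q.device)

    for step in range(world_size):
        scores = q @ k.T / q.shape[-1] ** 0.5               # local Q against current KV chunk
        block_max = scores.max(dim=-1, keepdim=True).values
        new_max = torch.maximum(running_max, block_max)
        correction = torch.exp(running_max - new_max)        # rescale what we accumulated so far
        probs = torch.exp(scores - new_max)
        out = out * correction + probs @ v
        running_sum = running_sum * correction + probs.sum(dim=-1, keepdim=True)
        running_max = new_max
        if step + 1 != world_size:
            k, v = ring_send_recv(k, v, group)               # hypothetical: pass KV to the next rank
    return out / running_sum
</d-code>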
1134
 
1135
<p>There is one big problem though, which is that a naive implementation of Ring Attention leads to a strong imbalance between GPUs coming from the shape of the causal attention matrix. Let’s take a closer look at what is happening in the SoftMax computation by considering the attention score matrix with the causal attention mask:</p>
1136
 
1137
+ <p><img alt="cp_attnmask.svg" src="/assets/images/cp_attnmask.svg" /></p>
1138
 
1139
  <p>The SoftMax is computed row-wise, which means whenever a GPU has received all the tokens of a row it can be computed. We see that GPU1 can immediately compute it as it starts with tokens 1-4 and GPU1 actually doesn’t need to receive any information from any other GPUs. However, GPU2 will need to wait for the second round to also receive 1-4 and thus have all values for tokens 1-8. Also, GPU1 seems to perform much less work than all the other GPUs.</p>
1140
 
 
1144
 
1145
<p>We need a better way to distribute the input sequences. This can be achieved by assigning the tokens not purely sequentially to the GPUs but by mixing the ordering a bit, such that we have a good mix of early and late tokens on each GPU. This approach is called Zig-Zag attention<d-cite bibtex-key="brandon2023fasterring"></d-cite>. In this new arrangement, if you count the number of colored squares in the attention mask, you’ll see that the computation is now balanced across all GPUs.</p>
1146
 
1147
+ <p><img alt="cp_zigzagmask.svg" src="/assets/images/cp_zigzagmask.svg" /></p>
1148
 
1149
  <p>At the same time we’ll also see that in order to complete all rows, each GPU will need information from all the other GPUs.</p>
1150
 
1151
<p>We have two general ways to overlap computation and communication: either by performing a general all-gather that regroups all the KV on each GPU at the same time (in a ZeRO-3 type of way), or by gathering them one-by-one from each GPU as needed:</p>
1152
 
1153
+ <p><img alt="cp_overlap_allgather.svg" src="/assets/images/cp_overlap_allgather.svg" /></p>
1154
+ <p><img alt="cp_overlap_all2all.svg" src="/assets/images/cp_overlap_all2all.svg" /></p>
1155
 
1156
  <p>The key difference between these two implementations lies in their communication patterns and memory usage:</p>
1157
 
 
1161
  <li>All GPUs simultaneously gather the complete key/value pairs from all other GPUs</li>
1162
  <li>Requires more temporary memory as each GPU needs to store the full KV pairs at once</li>
1163
  <li>Communication happens in one step but with larger memory overhead</li>
 
1164
  </ul>
1165
 
1166
  <p><strong>2. All-to-All (Ring) Implementation:</strong></p>
 
1169
  <li>GPUs exchange KV pairs in a ring-like pattern, one chunk at a time</li>
1170
  <li>More memory efficient as each GPU only needs to store one additional chunk temporarily</li>
1171
  <li>Communication is spread out and overlapped with computation, though with some additional base latency overhead from multiple communication steps</li>
 
1172
  </ul>
1173
 
1174
  <p>The All-to-All approach generally offers better memory efficiency at the cost of slightly more complex communication patterns, while the AllGather approach is simpler but requires more temporary memory during the attention computation.</p>
 
1179
 
1180
<p>In the TP section we saw that if we try to scale Tensor parallelism past the number of GPUs per single node (typically 4 or 8) we hit a lower bandwidth network called “inter-node connection” which can quite strongly impair our performance. We can see this clearly on e.g. the all-reduce operation when we perform it across several nodes:</p>
1181
 
1182
+ <p><img alt="pp_comm_bandwidth.svg" src="/assets/images/pp_comm_bandwidth.svg" /></p>
1183
  <p>Inter-node communication bandwidth measurements across different node counts, showing median (lines) and 5th-95th percentile ranges (shaded areas) for AllReduce, AllGather and ReduceScatter operations.</p>
1184
 
1185
<p>Sequence and context parallelism can help for long sequences but don’t help much if sequence length is not the root cause of our memory issues but rather the size of the model itself. For large models (70B+), the size of the weights alone can already push past the limits of the 4-8 GPUs on a single node. We can solve this issue by summoning the fourth (and last) parallelism dimension: “pipeline parallelism”.</p>
1186
 
1187
+ <p><img alt="pp_memoryusage.svg" src="/assets/images/pp_memoryusage.svg" /></p>
1188
 
1189
<p>Pipeline Parallelism is conceptually very simple: we’ll simply spread the layers of our model across GPUs. But the devil lies in implementing it efficiently. Let’s dive into it!</p>
1190
 
 
1198
 
1199
<p>Indeed reader! The main challenge in pipeline parallelism will be how to efficiently circumvent the sequential nature of PP to keep our GPUs busy at all times and avoid having one GPU computing while the others are waiting. Here is how our GPU utilization looks when doing a naive and simple forward and backward pass through the model, where the numbers indicate the model layers:</p>
1200
 
1201
+ <p><img alt="image.png" src="/assets/images/pp_afab.svg" /></p>
1202
  <p>An example of Pipeline parallelism for a model with 16 layers distributed across 4 GPUs. The numbers correspond to the layer IDs.</p>
1203
 
1204
<p>The remaining idle time is indicated in grey and usually called the “bubble”, and the sight of this probably breaks your heart after we spent so much time optimizing throughput.</p>
 
1217
 
1218
<p>Let’s take a first tool out of our toolbox and think about splitting our batch into smaller bite-sized portions which can be processed in parallel, or almost, like we did before in data parallelism for instance. Now when the second GPU is busy processing micro-batch 1, the first GPU can already start processing micro-batch 2. Here is a schedule using 8 micro-batches:</p>
1219
 
1220
+ <p><img alt="pp_afab2.svg" src="/assets/images/pp_afab2.svg" /></p>
1221
 
1222
<aside>Previously, the numbers in the diagram indicated layers, but in all pipeline parallel plots from now on, including this one, they indicate microbatches. You can think of each square here as containing several layers, as seen in the previous figure. </aside>
1223
 
 
1250
 
1251
  <p>This schedule is called <strong><em>one-forward-one-backward (1F1B)</em></strong> as the middle/steady state involves alternatively performing one forward and one backward pass. The general idea is to start performing the backward pass as soon as possible. The schedule looks like this:</p>
1252
 
1253
+ <p><img alt="image.png" src="/assets/images/pp_1f1b.svg" /></p>
1254
 
1255
<p>The bubble still has the same size so our training efficiency is not significantly improved. However we only need to store activations for <d-math>p</d-math> micro-batches instead of <d-math>m</d-math>, which significantly reduces the activation memory explosion we had in the AFAB schedule. As a consequence we can add more microbatches which then will actually reduce the bubble.</p>
1256
 
1257
+ <!-- TODO: @Nouamane add this figure -->
1258
+ <p><img alt="image.png" src="/assets/images/pp_1f1b_scaling.png" /></p>
1259
 
1260
  <p>A major complexity of this setup, visible on the above graph is how forward and backward passes are not cleanly consecutive anymore but performed in parallel across devices. This means we will have to schedule the switch from forward to backward passes independently on each device instead of in a simple and common central training loop as usual.</p>
1261
 
 
1286
 
1287
  <p>This can be seen in general as a kind of “looping pipeline” where a micro-batch will move in circles from one GPU to the next as it goes through the forward pass through the model.</p>
1288
 
1289
+ <p><img alt="pp_1f1b_interleaved.svg" src="/assets/images/pp_1f1b_interleaved.svg" /></p>
1290
 
1291
<p>As a consequence we see additional communications happening as the model goes several times through each GPU for the same computation that previously just took one pass. However, each forward and backward pass is divided by a factor of <d-math>v</d-math>, where <d-math>v</d-math> is the number of stages or model chunks per GPU, as we are able to better interleave forward and backward passes. </p>
1292
 
 
1301
 
1302
<p>So we can now decrease the bubble by adding microbatches and interleaved stages, but note that quantitatively, the amount of communication also increases by <d-math>v</d-math> so it’s a trade-off. In the following plot you can see several configurations for a PP setup with <d-math>p=8</d-math>, where the special case of <d-math>m=1, v=1</d-math> corresponds to naive pipeline parallelism, the configurations with <d-math>v=1</d-math> are AFAB or 1F1B setups, and <d-math>v \neq 1</d-math> are interleaved configurations.</p>
1303
 
1304
+ <p><img alt="pp_bubblesize.png" src="/assets/images/pp_bubblesize.png" /></p>
1305
 
1306
 
1307
<p>Scheduling also becomes more complex here, as we need to decide on a given GPU whether we prioritize, at a given moment, earlier micro-batches, meaning that we close the forward and backward loops as fast as possible (so-called “depth-first”, i.e. prioritizing getting batches out of the model as fast as possible), or whether we prioritize first completing the forward passes of all microbatches in the queue before going over to backward passes (so-called “breadth-first”, i.e. prioritizing filling in the pipeline as much as possible). This is explained in detail in the “Breadth-First Pipeline” paper<d-cite bibtex-key="lamypoirier2023breadthfirstpipelineparallelism"></d-cite>.</p>
1308
 
1309
<p>You now have all the elements to understand the pipeline parallelism approach in Llama 3.1, which uses a one-forward-one-backward setup with interleaved stages and a priority setting tuneable between depth-first and breadth-first.</p>
1310
 
1311
+ <p><img alt="pp_llama3.1_schedule.png" src="/assets/images/pp_llama3.1_schedule.png" /></p>
1312
 
1313
<p>However, we haven’t reached the end of possible pipeline schedules and recently some methods have been proposed to reduce the bubble to virtually zero! Piqued your curiosity? Let’s have a look!</p>
1314
 
 
1317
<p>There are even more sophisticated ways to reduce the bubble further and reach close to a “zero bubble” regime. The secret here is to split the operations involved at an even finer-grained level in order to interleave them in the most efficient way. For instance, the pipeline implementation approach in DeepSeek V3/R1, called DualPipe, reaches close to a zero-bubble regime.</p>
1318
 
1319
<p>Let’s very quickly see how this can work by briefly detailing the ZeroBubble<d-cite bibtex-key="qi2023zerobubblepipelineparallelism"></d-cite> work, which is a precursor to DualPipe. The base observation of ZeroBubble is that a backward pass through a matrix multiplication actually involves two separate operations: the backward for the inputs (B) and the backward for the weights (W):</p>
1320
+
1321
+ <p><img alt="image.png" src="/assets/images/pp_zerobubble_compgraph.png" /></p>
1322
+ <p><img alt="image.png" src="/assets/images/pp_zerobubble_ppschedule.png" /></p>
1323
 
1324
<p>While the output of B, the backward pass for the input, is necessary for performing the backward pass of the lower layers, the backward pass of the weights, W, is not necessary for the rest of the backward pass and generally only needs to be performed before the optimizer step. This means W can be flexibly scheduled anywhere after the corresponding B of the same stage. This allows for strategic placement of W to fill the pipeline bubbles. The ZB-H2 schedule on the top right is an example of a (theoretical) schedule with zero bubble taking advantage of this fine-grained decomposition.</p>
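<p>To make the B/W split concrete, here is a two-line sketch for a linear layer <d-math>Y = XW</d-math>; the point is simply that the two gradient matmuls are independent of each other:</p>
<d-code block language="python">
import torch

X = torch.randn(4, 8)      # layer input
W = torch.randn(8, 16)     # layer weight
dY = torch.randn(4, 16)    # incoming gradient from the layers above

dX = dY @ W.T              # "B": needed right away to continue the backward pass
dW = X.T @ dY              # "W": can be deferred and scheduled to fill pipeline bubbles
</d-code>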
1325
 
1326
<p>DeepSeek’s DualPipe, introduced with V3, extends this decomposed approach to the case of two streams propagating from both sides of the PP ranks, interleaved to minimize even further the idle time on the GPUs, as displayed in the following scheduling graph:</p>
1327
 
1328
+ <p><img alt="image.png" src="/assets/images/pp_zerobubble_dualpipe.png" /></p>
1329
 
1330
<p>The ZeroBubble and DualPipe schedules are a bit too complex for us to give code snippets here, but you should start to have a general idea of the concepts involved. In practice, optimizing these schedules requires careful measurements of the time for each operation, followed by a scheduling algorithm able to find the most optimal allocation of time given the constraints. See for instance the ZeroBubble paper<d-cite bibtex-key="qi2023zerobubblepipelineparallelism"></d-cite> for a discussion of the heuristics and algorithms used to perform such a scheduling.</p>
1331
 
 
1336
 
1337
  <p>Mixture-of-expert models have gained some traction with models such as Mixtral<d-cite bibtex-key="jiang2024mixtralexperts"></d-cite> or more recently DeepSeek-V3/R1! The basic idea is that instead of having a single feedforward module per layer we can have several and route tokens through different ones depending on their context:</p>
1338
 
1339
+ <p><img alt="ep_schema.png" src="/assets/images/ep_schema.png" /></p>
1340
  <p>Source: A Survey on Mixture of Experts<d-cite bibtex-key="cai2024surveymixtureexperts"></d-cite> </p>
1341
 
1342
  <p>This design makes it very easy to add a new parallelism paradigm: Expert parallelism (EP). Since the feedforward layers are fully independent we can simply put each expert’s feedforward layer on a different worker. Compared to TP it’s much more lightweight, since we don’t need to split the matrix multiplication, we just need to route the hidden states of a token to the right expert. There are several tricks to make EP work in practice, closely tied to model design. For instance, DeepSeek-V3 enforces a constraint in the router, ensuring that each token is sent to at most M nodes (in their case, 4) to reduce communication overhead.</p>
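<p>A minimal sketch of such a router is shown below: score each token against every expert and keep the top-k experts per token (real routers add auxiliary load-balancing losses, capacity limits and, in DeepSeek-V3's case, the node-limited dispatch mentioned above):</p>
<d-code block language="python">
import torch
import torch.nn.functional as F

def route_tokens(hidden, router_weight, top_k=2):
    logits = hidden @ router_weight                   # [num_tokens, num_experts]
    probs = F.softmax(logits, dim=-1)
    weights, expert_ids = probs.topk(top_k, dim=-1)   # which experts receive each token
    weights = weights / weights.sum(dim=-1, keepdim=True)
    return expert_ids, weights

hidden = torch.randn(10, 64)        # 10 tokens with hidden size 64
router_weight = torch.randn(64, 8)  # 8 experts
expert_ids, weights = route_tokens(hidden, router_weight)
</d-code>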
 
1469
 
1470
  <p>And to have an idea of the memory benefits of each parallelism:</p>
1471
 
1472
+ <p><img alt="image.png" src="/assets/images/5Dparallelism_8Bmemoryusage.svg" /></p>
1473
 
1474
  <h2>How to Find the Best Training Configuration</h2>
1475
 
 
1628
 
1629
  <p>On the compute side, GPUs consist of an array of compute units called <strong>Streaming Multiprocessors</strong> (SM). Each SM contains and controls a set of streaming processors, also known as cores. For example, an Nvidia H100 GPU has 132 SMs with 128 cores per SM, resulting in a total of 16,896 cores (see <a href="https://resources.nvidia.com/en-us-tensor-core">docs for tensor cores</a> for details), each capable of handling multiple threads simultaneously.</p>
1630
 
1631
+ <p><img alt="image.png" src="/assets/images/diving_primergpu.svg" /></p>
1632
  <p>TODO: Original figure from https://blog.codingconfessions.com/p/gpu-computing.</p>
1633
 
1634
<p>The memory side is also highly hierarchical with several layers of cache and memory: <strong>Registers</strong> are the smallest units and are private to the threads during execution; <strong>Shared Memory</strong> and the <strong>L1 cache</strong> are shared between the threads running on a single SM; higher up is the <strong>L2 cache</strong> shared by all SMs; finally, there is the <strong>Global Memory</strong> which is the largest memory on the GPU (the advertised 80 GB for an H100 for instance) but also the slowest to access and query.</p>
1635
 
1636
+ <p><img alt="image.png" src="/assets/images/diving_primergpu2.svg" /></p>
1637
  <p>TODO: Original figure from https://www.youtube.com/watch?v=ZQKMZIP3Fzg</p>
1638
 
1639
<p>The goal of the GPU is to run as many workloads as possible, in parallel, on the GPU cores, by taking advantage of this hierarchical organization of compute and memory.</p>
 
1786
 
1787
  <p>Here’s an excellent visualization of the kernel from this <a href="https://siboehm.com/articles/22/CUDA-MMM">fantastic blogpost</a>: </p>
1788
 
1789
+ <p><img alt="image.png" src="/assets/images/memorycoalescing.png" /></p>
1790
 
1791
  <p>However, when profiling this kernel with a tool like <code>ncu</code>, we can see issues, including low memory throughput and uncoalesced memory accesses.</p>
1792
 
1793
+ <p><img alt="image.png" src="/assets/images/memorycoalescing2.png" /></p>
1794
+ <p><img alt="image.png" src="/assets/images/memorycoalescing3.png" /></p>
1795
 
1796
 
1797
  <p>The reason for this is that in this kernel, two threads in the same block with Thread IDs <code>(0, 0)</code> and <code>(1, 0)</code> (which will end up in the same warp) will both load from the same column of matrix <code>B</code> but different rows of matrix <code>A</code>. Since matrix elements are stored in row-major order (meaning each row's elements are in consecutive memory addresses, as shown in the figure below), in the first iteration with <code>i = 0</code>, thread <code>(0, 0)</code> will load <d-math>A_{0,0}</d-math>, and thread <code>(1, 0)</code> will load <d-math>A_{1,0}</d-math>. These elements are not stored close to each other in memory, and this misalignment repeats across all iterations along the shared dimension, preventing memory accesses from being coalesced.</p>
1798
 
1799
+ <p><img alt="image.png" src="/assets/images/memorycoalescing4.png" /></p>
1800
 
1801
 
1802
<p>To improve our kernel we can change the way the coordinates x and y are calculated, as follows: </p>
 
1818
 
1819
  <p>When we profile our new kernel, we notice that the warning about uncoalesced memory accesses has disappeared, and <strong>the GPU's memory throughput has increased by approximately 10 times</strong>.</p>
1820
 
1821
+ <p><img alt="image.png" src="/assets/images/memorycoalescing5.png" /></p>
1822
 
1823
 
1824
  <p>We also notice that the execution time of the kernel <strong>decreases by 10x</strong> !</p>
 
1834
 
1835
  <p>In the tiling approach, each iteration involves all threads within a block cooperatively loading two tiles—one from matrix A and another from matrix B —into shared memory. Specifically, threads load a tile of matrix A (of size <code>BLOCK_SIZE_M</code> by <code>BLOCK_SIZE_K</code>) and a tile of matrix B (of size <code>BLOCK_SIZE_K</code> by <code>BLOCK_SIZE_N</code>). Once the tiles are in shared memory, the threads perform matrix multiplication on these tiles, enabling efficient computation since all necessary data is quickly accessible. The results of the tile multiplication are stored in an accumulation matrix that holds intermediate results. After each iteration, the results from the current tile multiplication are added to this accumulation matrix, continuing until all tiles from both matrices have been processed.</p>
1836
 
1837
+ <p><img alt="image.png" src="/assets/images/tiling.png" /></p>
1838
  <p>From https://cnugteren.github.io/tutorial/pages/page4.html</p>
1839
 
1840
<p>The important parts to understand the implementation are below (for simplicity we consider a square-shaped tile): </p>
 
1879
 
1880
  <p>The tiling technique has significantly improved the performance of our kernel. However, when analyzing the warp states which quantify how many cycles were spent in each state, we observe the following:</p>
1881
 
1882
+ <p><img alt="image.png" src="/assets/images/threadcoarsening.png" /></p>
1883
 
1884
 
1885
  <p>The meaning of the states can be found in the <a href="https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference">Profiling Guide</a>, specifically in the <strong>Warp Stall Reasons</strong> section. There we can read that:</p>
 
1901
  <p>In several places now we’ve mentioned how GPU and CPU operation can be asynchronous. In particular, the host code on the CPU can schedule workload on the GPU in a non-blocking way.</p>
1902
 
1903
<p>Non-blocking can be useful for overlapping communication and computation, as we saw at several points along this blog post, but it can be extended to the more general idea of trying to avoid, at all costs, going back and forth between host and GPU kernel commands. This is beautifully illustrated by <a href="https://horace.io/brrr_intro.html">Horace He</a> in these diagrams:</p>
1904
+ <div style="display: flex; gap: 20px; align-items: flex-start;">
1905
+ <div style="width: 50%;">
1906
+ <img alt="image.png" src="/assets/images/fused_kernels1.png" style="width: 100%;" />
1907
+ <p>A sequence of kernels requiring back and forth between global memory and compute units</p>
1908
+ </div>
1909
+ <div style="width: 50%;">
1910
+ <img alt="image.png" src="/assets/images/fused_kernels2.png" style="width: 100%;" />
1911
+ <p>Instead of sending our triangle back to global memory just to read it back again, we instead just do all of our operations in one go.</p>
1912
+ </div>
1913
+ </div>
1914
 
1915
  <p>How can we avoid this back and forth? Well the best way is to make our GPU as autonomous as possible. This is achieved by packing as many successive compute operations together in a single kernel for the GPU to run, called a “Fused Kernel”.</p>
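<p>One accessible way to get this kind of fusion in PyTorch, without writing the kernel by hand, is <code>torch.compile</code>, which can fuse chains of elementwise operations into a single generated kernel (a sketch assuming a CUDA device is available):</p>
<d-code block language="python">
import torch
import torch.nn.functional as F

def bias_gelu_dropout(x, bias):
    # Run eagerly, each of these three elementwise ops reads and writes
    # the full tensor to global memory.
    return F.dropout(F.gelu(x + bias), p=0.1)

# torch.compile can fuse the chain into one kernel, so the intermediate
# results never leave on-chip memory.
fused = torch.compile(bias_gelu_dropout)
out = fused(torch.randn(4096, 4096, device="cuda"),
            torch.randn(4096, device="cuda"))
</d-code>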
1916
 
 
1927
 
1928
<p>A basic implementation of the attention mechanism involves a lot of transfers between memory and workers. It requires materializing the S and P matrices in HBM, which means that the results need to be sent to HBM and then back to SRAM for the next computations:</p>
1929
 
1930
+ <p><img alt="image.png" src="/assets/images/flashattn.png" /></p>
1931
 
1932
  <p>Since bandwidth is much lower in HBM this introduces a severe bottleneck in the attention computation. Can we do better? Tri Dao says yes!</p>
1933
 
1934
<p>The key element is to compute the S matrix in small pieces which can fit in the smaller shared memory of the SM. But we can do even better and avoid materializing the very large S matrix altogether, in favor of keeping only the necessary statistics for computing the normalization factor of the softmax. So we can compute part of <d-math>O</d-math> directly in one computation in SRAM rather than moving intermediate results back and forth. In this case, not only do we make use of the shared memory, but we also release the memory bottleneck resulting from materializing one of the largest activation matrices in the model (at long context length), the attention matrix.</p>
1935
 
1936
+ <p><img alt="image.png" src="/assets/images/flashattn2.png" /></p>
1937
  <p>From the FLASH-ATTENTION paper<d-cite bibtex-key="dao2022flashattention"></d-cite></p>
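<p>In day-to-day PyTorch code you rarely write this kernel yourself: <code>scaled_dot_product_attention</code> dispatches to a Flash-Attention style fused kernel when one is available for your hardware and dtypes, so the full S matrix is never materialized in HBM (a usage sketch):</p>
<d-code block language="python">
import torch
import torch.nn.functional as F

# q, k, v in [batch, heads, seq, head_dim] layout on a CUDA device
q = torch.randn(1, 8, 2048, 64, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Dispatches to a fused flash-attention kernel when available, so the
# seq-by-seq attention score matrix never hits global memory.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
</d-code>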
1938
 
1939
  <p>The idea of flash attention resolves so many bottlenecks in model training that it has quickly become the default way to perform attention in all transformers:</p>
 
2019
 
2020
<p>Reducing the total number of bits comes at a price (no free lunch here either), but we have some control over how to pay: we can sacrifice more bits on either the mantissa or the exponent. For this reason, there exist two float8 formats, named according to their exponent and mantissa bits, so that we can flexibly choose the most appropriate format. We can look at the possible range of numbers for each format:</p>
2021
 
2022
+ <p><img alt="image.png" src="/assets/images/mixedprecision.png" /></p>
2023
 
2024
 
2025
<p>We can see that float32 spans 80 orders of magnitude and float16 sacrifices a lot of range while bfloat16 maintains the full range. The two float8 formats reduce the range even further: e5m2 can maintain the float16 range while e4m3 has an even smaller range.</p>
2026
 
2027
<p>How come some formats are able to maintain the range and others are not? Let’s investigate the resolution by plotting 10,000 points between 1 and 2. Each point will be rounded to the nearest representable number in each format:</p>
2028
 
2029
+ <p><img alt="image.png" src="/assets/images/mixedprecision_2.png" /></p>
2030
 
2031
<p>We can see here that bfloat16 maintained the range of float32 over float16, but did this at the cost of sacrificing more precision. In the case of float8 the situation is even more dire: e4m3 can represent only 7 numbers and e5m2 only 3 numbers on the interval 1-2.</p>
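<p>You can reproduce this counting experiment with a few lines of PyTorch (assuming a recent version that ships the float8 dtypes; casting rounds each point to the nearest representable value):</p>
<d-code block language="python">
import torch

points = torch.linspace(1, 2, 10_002)[1:-1]  # 10,000 points strictly between 1 and 2
for dtype in [torch.float32, torch.float16, torch.bfloat16,
              torch.float8_e4m3fn, torch.float8_e5m2]:
    rounded = points.to(dtype).float()
    print(dtype, rounded.unique().numel(), "distinct values strictly between 1 and 2")
</d-code>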
2032
 
 
2064
 
2065
<p>The first successful, very large scale training with FP8 mixed precision was publicly reported with DeepSeek-V3. The authors carefully analyzed each operation of the forward pass (Fprop) as well as the activation (Dgrad) and weight (Wgrad) backward passes. Similar to BF16 mixed precision training, some aggregations and the master weights are kept in higher precision while the operations themselves are performed in FP8. </p>
2066
 
2067
+ <p><img alt="image.png" src="/assets/images/fp8_diagram.png" /></p>
2068
 
2069
<p>In order to switch from high precision (e.g. FP32 or BF16) to lower precision (e.g. FP16 or FP8) with a smaller range, we need to normalize the range of values by computing the absolute maximum. DeepSeek-V3 also introduces a quantization scheme where the ranges are normalized per tile: 1x128 for inputs/activations and 128x128 for weights and scale elements. This makes the normalization less susceptible to outliers. There are a number of additional tricks they deploy to also reduce the memory and communication footprint, which you can follow in section 3.3 of the DeepSeek-V3 technical report<d-cite bibtex-key="deepseekai2024deepseekv3technicalreport"></d-cite>. </p>
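<p>The tile-wise scaling idea itself fits in a few lines; the following is a sketch of the principle for 1x128 activation tiles (using e4m3 and its maximum representable value of 448), not DeepSeek's actual kernels:</p>
<d-code block language="python">
import torch

FP8_E4M3_MAX = 448.0  # largest representable magnitude in float8 e4m3

def quantize_per_tile(x, tile=128):
    # Rescale each 1 x 128 tile so that its largest magnitude maps onto the fp8
    # range; a single outlier then only affects its own tile, not the full tensor.
    rows, cols = x.shape
    x_tiles = x.view(rows, cols // tile, tile)
    scales = x_tiles.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / FP8_E4M3_MAX
    x_fp8 = (x_tiles / scales).to(torch.float8_e4m3fn)
    return x_fp8.view(rows, cols), scales.squeeze(-1)  # keep the scales to dequantize later

x = torch.randn(4, 512)
x_fp8, scales = quantize_per_tile(x)
</d-code>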
2070
 
 
2158
 
2159
  <p>Congratulations! You've completed quite a journey - from understanding how to train a simple model on a single GPU, all the way to mastering the complex techniques used to efficiently train massive language models like Llama-405B and DeepSeek-V3. By now, you should feel confident interpreting advanced parallelism diagrams like the one below, which would have seemed daunting when you first started.</p>
2160
 
2161
+ <p><img alt="image.png" src="/assets/images/conclusion_llama3_parallelism.png" /></p>
2162
 
2163
  <p>In distributed training, many concepts sound easy enough when you first hear them, like “Pipeline parallelism just distributes layers on different GPUs”, but we also worked through all the challenging details when implementing those methods. </p>
2164
 
 
2208
 
2209
  <p>First, let's examine this heatmap visualization:</p>
2210
 
2211
+ <p><img alt="image.png" src="/assets/images/what_we_learnt_heatmap.svg" /></p>
2212
  <p>Heatmap visualization showing the optimal training configurations across different model sizes and compute node counts. For each combination, the configuration details include Data Parallelism (DP), Tensor Parallelism (TP), Pipeline Parallelism (PP), Gradient Accumulation Steps (GAS), Micro Batch Size (MBS), and ZeRO optimization stage. The color intensity indicates the Model FLOPs Utilization (MFU), with brighter colors representing higher efficiency.</p>
2213
 
2214
  <p>To complement this, let's look at the relationships between different parameters:</p>
2215
 
2216
+ <!-- <p><img alt="image.png" src="/assets/images/what_we_learnt_parallel_coordinates.html" /></p> -->
2217
+ <iframe id="plotFrame" src="/assets/images/what_we_learnt_parallel_coordinates.html" height="540" width="1000" scrolling="no" frameborder="0"></iframe>
2218
+
2219
  <p>Parallel coordinates plot showing the relationship between different model parallelism configurations (Data Parallel degree, Tensor Parallel degree, Pipeline Parallel degree), training hyperparameters (gradient accumulation steps, micro batch size), ZeRO stage and the resulting Model FLOPs Utilization (MFU). Each line represents a different training configuration, with colors indicating the MFU value - warmer colors show higher efficiency.</p>
2220
 
2221
  <p>From these visualizations, we can draw several important insights:
src/index.html CHANGED
@@ -297,14 +297,9 @@
297
 
298
  <p>Using this snippet [TODO: link to appendix A5], we can understand how memory is allocated throughout training. We can see that memory utilization is not a static thing but varies a lot during training and during a training step:</p>
299
 
300
- <<<<<<< HEAD
301
- <!--<div class="svg-container l-body-outset" id="svg-first_steps_memory_profile"> </div>
302
- <script src="../assets/images/first_steps_memory_profile.js"></script>-->
303
- =======
304
  <div class="svg-container l-body-outset" id="svg-first_steps_memory_profile"> </div>
305
  <div class="info" id="svg-first_steps_memory_profile-info">Hover over the elements to see their details</div>
306
  <script src="../assets/images/first_steps_memory_profile.js"></script>
307
- >>>>>>> a1429a9 (update)
308
 
309
  <iframe id="plotFrame" src="assets/data/benchmarks/memory-profile.html" height="520" width="1000" scrolling="no" frameborder="0"></iframe>
310
 
@@ -442,14 +437,9 @@
442
 
443
<p>The general idea behind <strong><em>activation recomputation</em></strong> – also called <em>gradient checkpointing</em> or <em>rematerialization</em> – is to discard some activations during the forward pass to save memory and spend some extra compute to recompute these on the fly during the backward pass. Without recomputation, we store every hidden state between two learnable operations (e.g. FF, LayerNorm etc.), such that we can use them during the backward pass to compute gradients. When we use recomputation we typically will only store activations at a few key points along the model architecture, discard the rest of the activations and recompute them on the fly during the backward pass from the nearest saved activations, basically performing again a sub-part of the forward pass to trade off memory for compute. It generally looks like this:</p>
444
 
445
- <<<<<<< HEAD
446
- <p><img alt="image.png" src="/assets/images/activation_recomputation.png" /></p>
447
- =======
448
  <div class="svg-container" id="svg-activation_recomputation"> </div>
449
  <div class="info" id="svg-activation_recomputation-info">Hover over the network elements to see their details</div>
450
  <script src="../assets/images/activation_recomputation.js"></script>
451
- >>>>>>> a1429a9 (update)
452
-
453
  <p>There are several strategies to select key activations to store:</p>
454
 
455
  <ul>
 
297
 
298
  <p>Using this snippet [TODO: link to appendix A5], we can understand how memory is allocated throughout training. We can see that memory utilization is not a static thing but varies a lot during training and during a training step:</p>
299
 
 
 
 
 
300
  <div class="svg-container l-body-outset" id="svg-first_steps_memory_profile"> </div>
301
  <div class="info" id="svg-first_steps_memory_profile-info">Hover over the elements to see their details</div>
302
  <script src="../assets/images/first_steps_memory_profile.js"></script>
 
303
 
304
  <iframe id="plotFrame" src="assets/data/benchmarks/memory-profile.html" height="520" width="1000" scrolling="no" frameborder="0"></iframe>
305
 
 
437
 
438
<p>The general idea behind <strong><em>activation recomputation</em></strong> – also called <em>gradient checkpointing</em> or <em>rematerialization</em> – is to discard some activations during the forward pass to save memory and spend some extra compute to recompute these on the fly during the backward pass. Without recomputation, we store every hidden state between two learnable operations (e.g. FF, LayerNorm etc.), such that we can use them during the backward pass to compute gradients. When we use recomputation we typically will only store activations at a few key points along the model architecture, discard the rest of the activations and recompute them on the fly during the backward pass from the nearest saved activations, basically performing again a sub-part of the forward pass to trade off memory for compute. It generally looks like this:</p>
439
 
 
 
 
440
  <div class="svg-container" id="svg-activation_recomputation"> </div>
441
  <div class="info" id="svg-activation_recomputation-info">Hover over the network elements to see their details</div>
442
  <script src="../assets/images/activation_recomputation.js"></script>
 
 
443
  <p>There are several strategies to select key activations to store:</p>
444
 
445
  <ul>