Markus28 committed on
Commit
eec6c0e
·
1 Parent(s): 5944ec8

reference the flash attention GitHub

Browse files
Files changed (5) hide show
  1. bert_padding.py +5 -0
  2. block.py +5 -0
  3. embedding.py +5 -0
  4. mha.py +9 -0
  5. mlp.py +5 -0
bert_padding.py CHANGED
@@ -1,5 +1,10 @@
1
  # Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
2
 
 
 
 
 
 
3
  import torch
4
  import torch.nn.functional as F
5
  from einops import rearrange, repeat
 
1
  # Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
2
 
3
+ """
4
+ The implementation was further adapted from
5
+ https://github.com/Dao-AILab/flash-attention/blob/43950dda456e095969d842fca7a73c5bfe3cecd0
6
+ """
7
+
8
  import torch
9
  import torch.nn.functional as F
10
  from einops import rearrange, repeat
block.py CHANGED
@@ -1,5 +1,10 @@
1
  # Copyright (c) 2024, Tri Dao.
2
 
 
 
 
 
 
3
  from functools import partial
4
  from typing import Optional
5
 
 
1
  # Copyright (c) 2024, Tri Dao.
2
 
3
+ """
4
+ The implementation was adopted from
5
+ https://github.com/Dao-AILab/flash-attention/blob/43950dda456e095969d842fca7a73c5bfe3cecd0
6
+ """
7
+
8
  from functools import partial
9
  from typing import Optional
10
 
embedding.py CHANGED
@@ -1,5 +1,10 @@
1
  # Copyright (c) 2022, Tri Dao.
2
 
 
 
 
 
 
3
  import torch
4
  import torch.nn as nn
5
  from torch import Tensor
 
1
  # Copyright (c) 2022, Tri Dao.
2
 
3
+ """
4
+ The implementation was adopted from
5
+ https://github.com/Dao-AILab/flash-attention/blob/43950dda456e095969d842fca7a73c5bfe3cecd0/flash_attn/models/bert.py
6
+ """
7
+
8
  import torch
9
  import torch.nn as nn
10
  from torch import Tensor
mha.py CHANGED
@@ -1,5 +1,14 @@
1
  # Copyright (c) 2023, Tri Dao.
2
 
 
 
 
 
 
 
 
 
 
3
  import math
4
  from functools import partial
5
 
 
1
  # Copyright (c) 2023, Tri Dao.
2
 
3
+ """
4
+ The implementation was adopted from
5
+ https://github.com/Dao-AILab/flash-attention/blob/43950dda456e095969d842fca7a73c5bfe3cecd0
6
+ with modifications to
7
+ - support QK normalization
8
+ - make ALiBi run with MHA (needed to cast alibi slopes to fp32)
9
+ - make ALiBi run on CPU
10
+ """
11
+
12
  import math
13
  from functools import partial
14
 
mlp.py CHANGED
@@ -1,5 +1,10 @@
1
  # Copyright (c) 2023, Tri Dao.
2
 
 
 
 
 
 
3
  import torch
4
  import torch.nn as nn
5
  import torch.nn.functional as F
 
1
  # Copyright (c) 2023, Tri Dao.
2
 
3
+ """
4
+ The implementation was adopted from
5
+ https://github.com/Dao-AILab/flash-attention/blob/43950dda456e095969d842fca7a73c5bfe3cecd0
6
+ """
7
+
8
  import torch
9
  import torch.nn as nn
10
  import torch.nn.functional as F