import torch
try:
    import transformer_engine as te
except ImportError:
    raise ImportError(
        "Unable to import transformer-engine. Please refer to "
        "https://github.com/NVIDIA/TransformerEngine for installation instructions."
    )


class TERMSNorm(torch.nn.Module):
    def __init__(self, dim, eps=1e-8, **kwargs):
        """
        A conditional wrapper that initializes an instance of Transformer-Engine's
        `RMSNorm` based on the input arguments.

        :param dim: model size
        :param eps: epsilon value, default 1e-8
        """
        super(TERMSNorm, self).__init__()
        self.d = dim
        self.eps = eps
        self.norm = te.pytorch.RMSNorm(
            hidden_size=self.d,
            eps=self.eps,
            **kwargs,
        )

    def forward(self, x):
        return self.norm(x)


class TELayerNorm(torch.nn.Module):
    def __init__(self, dim, eps=1.0e-5, **kwargs):
        """
        A conditional wrapper that initializes an instance of Transformer-Engine's
        `LayerNorm` based on the input arguments.

        :param dim: model size
        :param eps: epsilon value, default 1.0e-5
        """
        super(TELayerNorm, self).__init__()
        self.d = dim
        self.eps = eps
        self.norm = te.pytorch.LayerNorm(
            hidden_size=self.d,
            eps=self.eps,
            **kwargs,
        )

    def forward(self, x):
        return self.norm(x)
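
# NOTE: a minimal usage sketch for the two norm wrappers above (not part of this
# module); the tensor shape and device placement are illustrative assumptions:
#
#     hidden = torch.randn(4, 128, 1024, device="cuda")  # [batch, seq, hidden]
#     rmsnorm = TERMSNorm(dim=1024).cuda()
#     layernorm = TELayerNorm(dim=1024).cuda()
#     out = layernorm(rmsnorm(hidden))  # both wrappers are drop-in nn.Modules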


class TELinear(te.pytorch.Linear):
    """
    Wrapper for the Transformer-Engine's `Linear` layer.
    """

    def __init__(self):
        # TODO
        return

    def forward(self, x):
        # TODO
        return
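
# NOTE: a hedged sketch (not part of this module) of how the TODO above might
# delegate to te.pytorch.Linear; the constructor arguments shown are assumptions
# about the eventual interface, not the final implementation:
#
#     class TELinear(te.pytorch.Linear):
#         def __init__(self, in_features, out_features, bias=True, **kwargs):
#             # Pass straight through to Transformer-Engine's Linear layer.
#             super().__init__(in_features, out_features, bias=bias, **kwargs)
#
#         def forward(self, x):
#             return super().forward(x)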


class TELayerNormColumnParallelLinear(te.pytorch.LayerNormLinear):
    """
    Wrapper for the Transformer-Engine's `LayerNormLinear` layer that combines
    layernorm and linear layers.
    """

    def __init__(self):
        # TODO
        return

    def forward(self, x):
        # TODO
        return
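
# NOTE: a hedged sketch (not part of this module) of how the fused layernorm +
# column-parallel wrapper might be wired up; `parallel_mode="column"` is an
# existing te.pytorch.LayerNormLinear option, but the argument plumbing here is
# an assumption:
#
#     class TELayerNormColumnParallelLinear(te.pytorch.LayerNormLinear):
#         def __init__(self, in_features, out_features, eps=1e-5, **kwargs):
#             # "column" splits the output features across tensor-parallel ranks.
#             super().__init__(
#                 in_features,
#                 out_features,
#                 eps=eps,
#                 parallel_mode="column",
#                 **kwargs,
#             )
#
#         def forward(self, x):
#             return super().forward(x)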


class TEColumnParallelLinear(TELinear):
    """
    Wrapper for the Transformer-Engine's `Linear` layer, specialized to behave
    like Megatron's `ColumnParallelLinear` layer.
    """

    def __init__(self):
        # TODO
        return

    def forward(self, x):
        # TODO
        return
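
# NOTE: a hedged sketch (not part of this module) of the column-parallel
# specialization; it assumes TELinear eventually accepts and forwards
# te.pytorch.Linear's `parallel_mode` keyword:
#
#     class TEColumnParallelLinear(TELinear):
#         def __init__(self, in_features, out_features, **kwargs):
#             # Column parallelism partitions the output dimension across ranks.
#             super().__init__(
#                 in_features, out_features, parallel_mode="column", **kwargs
#             )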


class TERowParallelLinear(TELinear):
    """
    Wrapper for the Transformer-Engine's `Linear` layer, specialized to behave
    like Megatron's `RowParallelLinear` layer.
    """

    def __init__(self):
        # TODO
        return

    def forward(self, x):
        # TODO
        return
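
# NOTE: the row-parallel counterpart to the sketch above (not part of this
# module); `parallel_mode="row"` partitions the input dimension and reduces the
# partial outputs across ranks:
#
#     class TERowParallelLinear(TELinear):
#         def __init__(self, in_features, out_features, **kwargs):
#             super().__init__(
#                 in_features, out_features, parallel_mode="row", **kwargs
#             )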


class TEDotProductAttention(te.pytorch.DotProductAttention):
    """
    Wrapper for the Transformer-Engine's `DotProductAttention` layer that also
    has "flash attention" enabled.
    """

    def __init__(self):
        # TODO
        return

    def forward(self, x):
        # TODO
        return
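
# NOTE: a hedged sketch (not part of this module) of how the attention wrapper
# might pass through to Transformer-Engine; `num_attention_heads`, `kv_channels`,
# and `attention_dropout` are existing te.pytorch.DotProductAttention parameters,
# and TE selects its flash-attention backend automatically when one is available:
#
#     class TEDotProductAttention(te.pytorch.DotProductAttention):
#         def __init__(self, num_attention_heads, kv_channels, dropout=0.0, **kwargs):
#             super().__init__(
#                 num_attention_heads,
#                 kv_channels,
#                 attention_dropout=dropout,
#                 **kwargs,
#             )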


class TEDelayedScaling(te.common.recipe.DelayedScaling):
    """
    Wrapper for the Transformer-Engine's `DelayedScaling` FP8 recipe.
    """

    def __init__(self):
        # TODO
        return
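
# NOTE: a hedged sketch (not part of this module) of how the FP8 recipe wrapper
# might be built and used; `margin`, `fp8_format`, and `fp8_autocast` exist in
# Transformer-Engine, but the defaults chosen here are assumptions:
#
#     from transformer_engine.common.recipe import Format
#
#     class TEDelayedScaling(te.common.recipe.DelayedScaling):
#         def __init__(self, margin=0, fp8_format=Format.HYBRID, **kwargs):
#             super().__init__(margin=margin, fp8_format=fp8_format, **kwargs)
#
#     # FP8 execution is then enabled around the forward pass:
#     recipe = TEDelayedScaling()
#     with te.pytorch.fp8_autocast(enabled=True, fp8_recipe=recipe):
#         output = model(inputs)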