# megatron/model/transformer_engine.py
import torch

try:
    import transformer_engine as te
except ImportError:
    raise ImportError(
        "Unable to import transformer-engine. Please refer to "
        "https://github.com/NVIDIA/TransformerEngine for installation instructions."
    )


class TERMSNorm(torch.nn.Module):
    def __init__(self, dim, eps=1e-8, **kwargs):
        """
        Conditional wrapper that initializes an instance of Transformer-Engine's
        `RMSNorm` based on the input arguments.

        :param dim: model size
        :param eps: epsilon value, default 1e-8
        """
        super(TERMSNorm, self).__init__()
        self.d = dim
        self.eps = eps
        self.norm = te.pytorch.RMSNorm(
            hidden_size=self.d,
            eps=self.eps,
            **kwargs,
        )

    def forward(self, x):
        return self.norm(x)


class TELayerNorm(torch.nn.Module):
    def __init__(self, dim, eps=1.0e-5, **kwargs):
        """
        Conditional wrapper that initializes an instance of Transformer-Engine's
        `LayerNorm` based on the input arguments.

        :param dim: model size
        :param eps: epsilon value, default 1.0e-5
        """
        super(TELayerNorm, self).__init__()
        self.d = dim
        self.eps = eps
        self.norm = te.pytorch.LayerNorm(
            hidden_size=self.d,
            eps=self.eps,
            **kwargs,
        )

    def forward(self, x):
        return self.norm(x)
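
# Usage sketch for the norm wrappers (hypothetical sizes, not part of the
# original module):
#   norm = TELayerNorm(dim=4096)          # or TERMSNorm(dim=4096)
#   y = norm(torch.randn(2048, 2, 4096))  # [seq, batch, hidden]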


class TELinear(te.pytorch.Linear):
    """
    Wrapper for the Transformer-Engine's `Linear` layer.

    NOTE: the original stub was marked TODO; the constructor below is a
    hedged sketch that simply forwards its arguments to TE's `Linear`.
    """

    def __init__(self, in_features, out_features, bias=True, **kwargs):
        # Assumption: TE-specific options (init_method, parallel_mode,
        # tp_group, ...) are passed through unchanged via **kwargs.
        super().__init__(in_features, out_features, bias=bias, **kwargs)

    def forward(self, x):
        return super().forward(x)
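
# Usage sketch (hypothetical sizes; single GPU, no tensor parallelism):
#   proj = TELinear(4096, 4096, bias=True)
#   y = proj(torch.randn(2048, 2, 4096))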


class TELayerNormColumnParallelLinear(te.pytorch.LayerNormLinear):
    """
    Wrapper for the Transformer-Engine's `LayerNormLinear` layer, which fuses
    the layernorm and linear layers.

    NOTE: the original stub was marked TODO; the sketch below assumes the
    column-parallel mapping is expressed via TE's `parallel_mode="column"`.
    """

    def __init__(self, in_features, out_features, bias=True, **kwargs):
        super().__init__(
            in_features, out_features, bias=bias, parallel_mode="column", **kwargs
        )

    def forward(self, x):
        return super().forward(x)
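
# Usage sketch (hypothetical sizes): the fused layer applies layernorm and
# then the matmul in one module, e.g. for an attention QKV projection:
#   qkv = TELayerNormColumnParallelLinear(4096, 3 * 4096)
#   y = qkv(torch.randn(2048, 2, 4096))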


class TEColumnParallelLinear(TELinear):
    """
    Wrapper for the Transformer-Engine's `Linear` layer, specialized similarly
    to megatron's `ColumnParallelLinear` layer (output dimension sharded
    across tensor-parallel ranks).

    NOTE: sketch; assumes TE's `parallel_mode="column"` provides the
    column-parallel behaviour.
    """

    def __init__(self, in_features, out_features, bias=True, **kwargs):
        super().__init__(
            in_features, out_features, bias=bias, parallel_mode="column", **kwargs
        )

    def forward(self, x):
        return super().forward(x)
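
# Usage sketch: with tensor parallelism, each rank holds
# out_features // tp_size columns of the weight (tp_group wiring is an
# assumption and is omitted here):
#   w_in = TEColumnParallelLinear(4096, 16384)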


class TERowParallelLinear(TELinear):
    """
    Wrapper for the Transformer-Engine's `Linear` layer, specialized similarly
    to megatron's `RowParallelLinear` layer (input dimension sharded across
    tensor-parallel ranks).

    NOTE: sketch; assumes TE's `parallel_mode="row"` provides the
    row-parallel behaviour.
    """

    def __init__(self, in_features, out_features, bias=True, **kwargs):
        super().__init__(
            in_features, out_features, bias=bias, parallel_mode="row", **kwargs
        )

    def forward(self, x):
        return super().forward(x)
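
# Usage sketch: the row-parallel layer shards the input dimension, so each
# rank holds in_features // tp_size rows and TE reduces the partial outputs
# across the tensor-parallel group (tp_group wiring omitted):
#   w_out = TERowParallelLinear(16384, 4096)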


class TEDotProductAttention(te.pytorch.DotProductAttention):
    """
    Wrapper for the Transformer-Engine's `DotProductAttention` layer, which
    dispatches to flash-attention kernels when they are available.

    NOTE: the original stub was marked TODO; the sketch below forwards the
    query/key/value tensors straight to TE (an assumed call signature).
    """

    def __init__(self, num_attention_heads, kv_channels, **kwargs):
        super().__init__(num_attention_heads, kv_channels, **kwargs)

    def forward(self, query, key, value, **kwargs):
        return super().forward(query, key, value, **kwargs)
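
# Usage sketch (hypothetical sizes): q/k/v in TE's default "sbhd" layout,
# i.e. [seq, batch, heads, head_dim]:
#   attn = TEDotProductAttention(num_attention_heads=32, kv_channels=128)
#   out = attn(q, k, v)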


class TEDelayedScaling(te.common.recipe.DelayedScaling):
    """
    Wrapper for the Transformer-Engine's `DelayedScaling` FP8 recipe.

    NOTE: the original stub was marked TODO; the sketch simply forwards
    keyword arguments to the underlying recipe class.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
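
# Usage sketch: the recipe is intended to be passed to TE's fp8_autocast
# context manager (assumed wiring, not part of the original module):
#   recipe = TEDelayedScaling(margin=0, amax_history_len=1024)
#   with te.pytorch.fp8_autocast(enabled=True, fp8_recipe=recipe):
#       y = model(x)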