|
|
|
|
|
from functools import partial |
|
from typing import Optional |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch import Tensor |
|
from torchvision.ops import StochasticDepth |
|
|
|
from flash_attn.modules.mha import MHA |
|
from flash_attn.modules.mlp import Mlp |
|
|
|
try: |
|
from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm |
|
except ImportError: |
|
layer_norm_fn, RMSNorm = None, None |
|
|
|
|
|
class Block(nn.Module):
    """Transformer block made of a sequence mixer followed by an optional MLP.

    Each sub-layer is paired with a residual dropout, a stochastic-depth layer and a
    norm layer. The second set (dropout2 / drop_path2 / norm2) is only created when the
    MLP is not an ``nn.Identity``.
    """

    def __init__(
        self,
        dim,
        mixer_cls=None,
        mlp_cls=None,
        norm_cls=nn.LayerNorm,
        dropout_cls=nn.Dropout,
        prenorm=True,
        resid_dropout1=0.0,
        resid_dropout2=0.0,
        drop_path1=0.0,
        drop_path2=0.0,
        fused_dropout_add_ln=False,
        return_residual=False,
        residual_in_fp32=False,
        sequence_parallel=False,
        mark_shared_params=False,
    ):
        """
        For prenorm=True, this Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both
        the hidden_states (output of the MLP) and the residual.
        This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
        The residual needs to be provided (except for the very first block).

        For prenorm=False, this Block has the same structure as a regular postnorm Transformer
        block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN.

        return_residual: whether each of the sub-layers (mixer and mlp) will return the residual.
        This is for performance reason: for post-norm architecture, returning the input allows us
        to fuse the backward of nn.Linear with the residual connection.

        Args:
            dim: model width; passed as the single positional argument to mixer_cls,
                mlp_cls and norm_cls.
            mixer_cls: factory called as mixer_cls(dim). Defaults to MHA with dim // 64 heads.
            mlp_cls: factory called as mlp_cls(dim). Defaults to Mlp with 4 * dim hidden
                features. If it builds an nn.Identity, the MLP half of the block is skipped.
            norm_cls: factory called as norm_cls(dim) for norm1 (and norm2).
            dropout_cls: factory called with the dropout probability.
            prenorm: selects the prenorm (True) or postnorm (False) forward path.
            resid_dropout1, resid_dropout2: probabilities handed to dropout_cls.
            drop_path1, drop_path2: probabilities handed to StochasticDepth(mode="row").
            fused_dropout_add_ln: use layer_norm_fn instead of separate dropout/add/norm ops.
            residual_in_fp32: keep the returned residual in float32 (prenorm only).
            sequence_parallel: tag the norm parameters with ``_sequence_parallel = True``.
            mark_shared_params: tag the norm parameters with ``_shared_params = True``.
        """
        super().__init__()
        self.prenorm = prenorm
        self.fused_dropout_add_ln = fused_dropout_add_ln
        self.return_residual = return_residual
        self.residual_in_fp32 = residual_in_fp32
        if self.residual_in_fp32:
            assert self.prenorm, "residual_in_fp32 is only compatible with prenorm=True"
        if mixer_cls is None:
            mixer_cls = partial(MHA, num_heads=dim // 64)
        if mlp_cls is None:
            mlp_cls = partial(Mlp, hidden_features=4 * dim)
        # NOTE: sub-modules are created in a fixed order; it determines the parameter /
        # state_dict ordering and the order in which initialisers consume the RNG.
        self.mixer = mixer_cls(dim)
        self.dropout1 = dropout_cls(resid_dropout1)
        self.drop_path1 = StochasticDepth(drop_path1, mode="row")
        self.norm1 = norm_cls(dim)
        self.mlp = mlp_cls(dim)
        if not isinstance(self.mlp, nn.Identity):
            self.dropout2 = dropout_cls(resid_dropout2)
            self.drop_path2 = StochasticDepth(drop_path2, mode="row")
            self.norm2 = norm_cls(dim)

        if self.fused_dropout_add_ln:
            assert layer_norm_fn is not None, "Triton is not installed"
            assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
                self.dropout1, nn.Dropout
            )

        # The two flags below only tag the norm parameters; whatever consumes the tags
        # lives outside this class (presumably the distributed-training utilities -
        # TODO confirm).
        if sequence_parallel:
            for p in self.norm1.parameters():
                p._sequence_parallel = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._sequence_parallel = True

        if mark_shared_params:
            for p in self.norm1.parameters():
                p._shared_params = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._shared_params = True

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Forward the inference-cache allocation request to the mixer."""
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

    def _drop_path_rowscale(self, drop_path, ref):
        """Return the per-row stochastic-depth scale for the fused kernel, or None.

        None is returned when drop-path is inactive (p == 0 or eval mode). Otherwise the
        drop-path layer is applied to a tensor of ones shaped like ``ref`` without its
        last dimension, on the same device and with the same dtype as ``ref``.
        """
        if drop_path.p == 0 or not self.training:
            return None
        return drop_path(torch.ones(ref.shape[:-1], device=ref.device, dtype=ref.dtype))

    def forward(
        self,
        hidden_states: Tensor,
        residual: Optional[Tensor] = None,
        mixer_subset=None,
        mixer_kwargs=None,
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states: the sequence to the encoder layer (required).
            residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual))
            mixer_subset: for cross-attention only. If not None, will take a subset of x
                before applying the query projection. Useful for e.g., ViT where we only care
                about the CLS token in the last layer. Only used on the prenorm path.
            mixer_kwargs: optional extra keyword arguments forwarded to the mixer. The
                mapping passed in is never modified.

        Returns:
            ``(hidden_states, residual)`` when prenorm=True, otherwise ``hidden_states``.
        """
        if self.prenorm:
            if not self.fused_dropout_add_ln:
                dropped = self.drop_path1(self.dropout1(hidden_states))
                residual = (dropped + residual) if residual is not None else dropped
                hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
                if self.residual_in_fp32:
                    residual = residual.to(torch.float32)
            else:
                hidden_states, residual = layer_norm_fn(
                    hidden_states,
                    self.norm1.weight,
                    self.norm1.bias,
                    residual=residual,
                    eps=self.norm1.eps,
                    dropout_p=self.dropout1.p if self.training else 0.0,
                    rowscale=self._drop_path_rowscale(self.drop_path1, hidden_states),
                    prenorm=True,
                    residual_in_fp32=self.residual_in_fp32,
                    is_rms_norm=isinstance(self.norm1, RMSNorm),
                )
            if mixer_kwargs is None:
                mixer_kwargs = {}
            if mixer_subset is not None:
                # Build a new dict rather than writing into the caller's: the same
                # kwargs object may be reused for other layers or later calls.
                mixer_kwargs = {**mixer_kwargs, "mixer_subset": mixer_subset}
            hidden_states = self.mixer(hidden_states, **mixer_kwargs)
            if mixer_subset is not None:
                # Keep the residual aligned with the subset the mixer was applied to.
                residual = residual[:, mixer_subset]
            if not isinstance(self.mlp, nn.Identity):
                if not self.fused_dropout_add_ln:
                    dropped = self.drop_path2(self.dropout2(hidden_states))
                    residual = (dropped + residual) if residual is not None else dropped
                    hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
                    if self.residual_in_fp32:
                        residual = residual.to(torch.float32)
                else:
                    hidden_states, residual = layer_norm_fn(
                        hidden_states,
                        self.norm2.weight,
                        self.norm2.bias,
                        residual=residual,
                        eps=self.norm2.eps,
                        dropout_p=self.dropout2.p if self.training else 0.0,
                        rowscale=self._drop_path_rowscale(self.drop_path2, hidden_states),
                        prenorm=True,
                        residual_in_fp32=self.residual_in_fp32,
                        is_rms_norm=isinstance(self.norm2, RMSNorm),
                    )
                hidden_states = self.mlp(hidden_states)
            return hidden_states, residual
        else:
            assert residual is None
            mixer_out = self.mixer(
                hidden_states, **(mixer_kwargs if mixer_kwargs is not None else {})
            )
            if self.return_residual:  # mixer is expected to return (out, input)
                mixer_out, hidden_states = mixer_out
            if not self.fused_dropout_add_ln:
                hidden_states = self.norm1(
                    (self.drop_path1(self.dropout1(mixer_out)) + hidden_states).to(
                        dtype=self.norm1.weight.dtype
                    )
                )
            else:
                hidden_states = layer_norm_fn(
                    mixer_out,
                    self.norm1.weight,
                    self.norm1.bias,
                    residual=hidden_states,
                    eps=self.norm1.eps,
                    dropout_p=self.dropout1.p if self.training else 0.0,
                    rowscale=self._drop_path_rowscale(self.drop_path1, mixer_out),
                    prenorm=False,
                    is_rms_norm=isinstance(self.norm1, RMSNorm),
                )
            if not isinstance(self.mlp, nn.Identity):
                mlp_out = self.mlp(hidden_states)
                if self.return_residual:  # mlp is expected to return (out, input)
                    mlp_out, hidden_states = mlp_out
                if not self.fused_dropout_add_ln:
                    hidden_states = self.norm2(
                        (self.drop_path2(self.dropout2(mlp_out)) + hidden_states).to(
                            dtype=self.norm2.weight.dtype
                        )
                    )
                else:
                    hidden_states = layer_norm_fn(
                        mlp_out,
                        self.norm2.weight,
                        self.norm2.bias,
                        residual=hidden_states,
                        eps=self.norm2.eps,
                        dropout_p=self.dropout2.p if self.training else 0.0,
                        rowscale=self._drop_path_rowscale(self.drop_path2, mlp_out),
                        prenorm=False,
                        is_rms_norm=isinstance(self.norm2, RMSNorm),
                    )
            return hidden_states
|
|
|
|
|
class ParallelBlock(nn.Module):
    """The attention (mixer) and MLP blocks are done in parallel, similar to GPT-J, GPT-NeoX,
    and PaLM.
    """

    def __init__(
        self,
        dim,
        mixer_cls=None,
        mlp_cls=None,
        norm_cls=nn.LayerNorm,
        dropout_cls=nn.Dropout,
        resid_dropout1=0.0,
        resid_dropout2=0.0,
        tied_norm=False,
        fused_dropout_add_ln=False,
        residual_in_fp32=False,
        sequence_parallel=False,
        mark_shared_params=False,
    ):
        """
        This Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA / MLP -> Dropout -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Dropout -> Add -> LN -> MHA / MLP, returning both
        the hidden_states (output1 of the MHA / MLP) and the residual.
        This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
        The residual needs to be provided (except for the very first block).

        When tied_norm is True no second norm layer is created and the mixer and the MLP
        both consume the output of norm1.
        """
        super().__init__()
        self.tied_norm = tied_norm
        self.fused_dropout_add_ln = fused_dropout_add_ln
        self.residual_in_fp32 = residual_in_fp32

        # Fall back to the default attention / MLP factories when none are supplied.
        build_mixer = mixer_cls if mixer_cls is not None else partial(MHA, num_heads=dim // 64)
        build_mlp = mlp_cls if mlp_cls is not None else partial(Mlp, hidden_features=4 * dim)

        # Sub-modules are registered in this exact order.
        self.mixer = build_mixer(dim)
        self.dropout1 = dropout_cls(resid_dropout1)
        self.norm1 = norm_cls(dim)
        self.mlp = build_mlp(dim)
        self.dropout2 = dropout_cls(resid_dropout2)
        if not self.tied_norm:
            self.norm2 = norm_cls(dim)

        if self.fused_dropout_add_ln:
            assert layer_norm_fn is not None, "Triton is not installed"
            assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
                self.dropout1, nn.Dropout
            )

        # Tag every norm parameter with the requested marker attributes; norm2 only
        # exists when the norms are not tied.
        norm_layers = [self.norm1]
        if hasattr(self, "norm2"):
            norm_layers.append(self.norm2)
        for enabled, marker in (
            (sequence_parallel, "_sequence_parallel"),
            (mark_shared_params, "_shared_params"),
        ):
            if not enabled:
                continue
            for layer in norm_layers:
                for param in layer.parameters():
                    setattr(param, marker, True)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Delegate inference-cache allocation to the mixer."""
        mixer = self.mixer
        return mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

    def forward(
        self,
        hidden_states1: Tensor,
        hidden_states2: Optional[Tensor] = None,
        residual: Optional[Tensor] = None,
        mixer_kwargs=None,
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states1: the output of the previous attention (mixer) or embedding layer.
            hidden_states2: the output of the previous MLP layer (if None, will use hidden_states1).
            residual: the running residual stream, or None when there is none yet.
            mixer_kwargs: optional extra keyword arguments forwarded to the mixer.

        Returns:
            The tuple ``(mixer_output, mlp_output, residual)``.
        """
        if self.fused_dropout_add_ln:
            # One fused call performs dropout + add + norm for both branches.
            if self.tied_norm:
                second_weight, second_bias = None, None
            else:
                second_weight, second_bias = self.norm2.weight, self.norm2.bias
            hidden_states1, *extra, residual = layer_norm_fn(
                hidden_states1,
                self.norm1.weight,
                self.norm1.bias,
                residual=residual,
                x1=hidden_states2,
                weight1=second_weight,
                bias1=second_bias,
                eps=self.norm1.eps,
                dropout_p=self.dropout1.p if self.training else 0.0,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                is_rms_norm=isinstance(self.norm1, RMSNorm),
            )
            if self.tied_norm:
                hidden_states2 = hidden_states1
            else:
                (hidden_states2,) = extra
        else:
            # Unfused path: dropout each incoming stream, fold into the residual, normalise.
            attn_in = self.dropout1(hidden_states1)
            if hidden_states2 is None:
                if residual is None:
                    residual = attn_in
                else:
                    residual = residual + attn_in
            else:
                mlp_in = self.dropout2(hidden_states2)
                if residual is None:
                    residual = attn_in + mlp_in
                else:
                    residual = residual + attn_in + mlp_in
            hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
            if self.tied_norm:
                hidden_states2 = hidden_states1
            else:
                hidden_states2 = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)

        extra_mixer_args = {} if mixer_kwargs is None else mixer_kwargs
        hidden_states1 = self.mixer(hidden_states1, **extra_mixer_args)
        hidden_states2 = self.mlp(hidden_states2)
        return hidden_states1, hidden_states2, residual
|
|