|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Transformer.""" |
|
|
|
import math |
|
from contextlib import nullcontext |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
import torch.nn as nn |
|
from pkg_resources import packaging |
|
from importlib.metadata import version |
|
|
|
from .norms import get_norm |
|
from megatron import mpu |
|
from megatron.model import megablocks_utils |
|
from megatron.model.fused_softmax import FusedScaleMaskSoftmax |
|
from megatron.model.activations import get_activation |
|
from megatron.model.utils import exists, get_fusion_type |
|
from megatron.model.positional_embeddings import ( |
|
RotaryEmbedding, |
|
apply_rotary_pos_emb_torch, |
|
apply_rotary_pos_emb, |
|
AliBi, |
|
) |
|
from megatron.model.fused_rope import ( |
|
FusedRoPEFunc, |
|
fused_apply_rotary_pos_emb_cached, |
|
) |
|
from megatron.model.fused_bias_dropout import ( |
|
get_bias_dropout_add, |
|
bias_dropout_add_fused_train, |
|
bias_dropout_add_fused_inference, |
|
) |
|
from megatron.model.utils import configure_sparse_attention |
|
from deepspeed.moe.layer import MoE |
|
|
|
try: |
|
from flash_attn.ops.activations import swiglu |
|
except ImportError: |
|
swiglu = None |
|
|
|
|
|
torch._C._jit_set_profiling_mode(False) |
|
torch._C._jit_set_profiling_executor(False) |
|
torch._C._jit_override_can_fuse_on_cpu(True) |
|
torch._C._jit_override_can_fuse_on_gpu(True) |
|
|
|
""" We use the following notation throughout this file: |
|
h: hidden size |
|
n: number of attention heads |
|
kv: number of key or value heads |
|
p: number of model parallel partitions |
|
np: n/p |
|
kvp: kv/p |
|
hp: h/p |
|
hn: h/n |
|
b: batch size |
|
s: sequence length |
|
l: number of layers |
|
Transformer takes input of size [s, b, h] and returns a |
|
tensor of the same size. We use the following arguments: |
|
hyperparameters: transformer hyperparameters |
|
attention_mask_func: a function that takes `unmasked-attention-scores` |
|
with size [b, np, s, s] and an `attention-mask` and will apply |
|
the masking. The function should return a masked score of the |
|
same size [b, np, s, s]. |
|
masked-attention-scores = attention_mask_func( |
|
unmasked-attention-scores, attention-mask) |
|
""" |
|
|
|
|
|
class ParallelMLP(nn.Module): |
|
"""MLP. |
|
|
|
    The MLP takes input with hidden size h, projects it to an intermediate
    dimension (4h by default, or as configured via intermediate_size /
    expansion_factor), applies a nonlinear transformation, and projects the
    state back to hidden size h. Dropout and the residual add are applied by
    the caller via the fused bias-dropout-add path.
    """
|
|
|
def __init__( |
|
self, |
|
neox_args, |
|
init_method, |
|
output_layer_init_method, |
|
parallel_output=False, |
|
multiple_of=256, |
|
MOE=False, |
|
MoE_mp_size=1, |
|
): |
|
super().__init__() |
|
        assert (
            neox_args.intermediate_size is None or neox_args.expansion_factor is None
        ), "Must pass either the absolute intermediate size or the relative expansion factor for the MLP projections, not both"
|
|
|
self.activation_func, self.is_gated = get_activation(neox_args) |
|
self.activation_type = neox_args.activation |
|
self.bias_gelu_fusion = neox_args.bias_gelu_fusion |
|
self.multiple_of = multiple_of |
|
|
|
if neox_args.intermediate_size: |
|
ffn_dim = neox_args.intermediate_size |
|
elif neox_args.expansion_factor: |
|
ffn_dim = int(neox_args.expansion_factor * neox_args.hidden_size) |
|
else: |
|
|
|
ffn_dim = 4 * neox_args.hidden_size |
|
ffn_dim_in = ffn_dim |
|
if self.is_gated: |
|
|
|
self.activation_func = Gated_Activation( |
|
self.activation_func, |
|
(swiglu is not None) |
|
and (neox_args.activation == "swiglu") |
|
and neox_args.use_flashattn_swiglu, |
|
) |
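            # Gated activations pack [value | gate] into linear1's output, so rescale
            # the intermediate width by 2/3 and round it up to a multiple of
            # `multiple_of` (linear1's packed output to a multiple of 2 * multiple_of).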
|
|
|
ffn_dim = int(ffn_dim * 2 / 3) |
|
ffn_dim_in = ffn_dim // 2 |
|
|
|
ffn_dim = int( |
|
(2 * self.multiple_of) |
|
* ((ffn_dim + (2 * multiple_of) - 1) // (2 * multiple_of)) |
|
) |
|
ffn_dim_in = int( |
|
self.multiple_of * ((ffn_dim_in + multiple_of - 1) // multiple_of) |
|
) |
|
|
|
self.linear1 = mpu.ColumnParallelLinear( |
|
neox_args=neox_args, |
|
input_size=neox_args.hidden_size, |
|
output_size=ffn_dim, |
|
gather_output=False, |
|
init_method=init_method, |
|
skip_bias_add=True, |
|
MOE=MOE, |
|
MoE_mp_size=MoE_mp_size, |
|
bias=neox_args.use_bias_in_mlp, |
|
) |
|
|
|
self.linear2 = mpu.RowParallelLinear( |
|
neox_args=neox_args, |
|
input_size=ffn_dim_in, |
|
output_size=neox_args.hidden_size, |
|
input_is_parallel=True, |
|
init_method=output_layer_init_method, |
|
parallel_output=parallel_output, |
|
skip_bias_add=True, |
|
MOE=MOE, |
|
MoE_mp_size=MoE_mp_size, |
|
bias=neox_args.use_bias_in_mlp, |
|
) |
|
|
|
def forward(self, hidden_states): |
|
|
|
intermediate_parallel, bias_parallel = self.linear1(hidden_states) |
|
|
|
if self.is_gated or (self.activation_type == "gelu" and self.bias_gelu_fusion): |
|
intermediate_parallel = self.activation_func( |
|
intermediate_parallel, bias_parallel |
|
) |
|
else: |
|
intermediate_parallel = self.activation_func( |
|
intermediate_parallel + bias_parallel |
|
) |
|
|
|
|
|
output, output_bias = self.linear2(intermediate_parallel) |
|
return output, output_bias |
|
|
|
|
|
class Gated_Activation(torch.nn.Module): |
|
def __init__(self, activation_func, use_swiglu=False): |
|
super().__init__() |
|
self.activation_func = activation_func |
|
self.use_swiglu = use_swiglu |
|
|
|
def forward(self, x, bias=None): |
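        # linear1 packs [value | gate] along the last dim; split them, add the
        # correspondingly split bias if present, then apply the gating.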
|
x, gate = x.chunk(2, dim=-1) |
|
if bias is not None: |
|
bias_1, bias_2 = bias.chunk(2, dim=-1) |
|
x = x + bias_1 |
|
gate = gate + bias_2 |
|
if not self.use_swiglu: |
|
intermediate_parallel = self.activation_func(gate) |
|
return intermediate_parallel * x |
|
else: |
|
return swiglu(gate, x) |
|
|
|
|
|
class ParallelLinear(nn.Module): |
|
""" |
|
A Parallel Linear Layer transforming the transformer outputs from hidden_size -> vocab_size |
|
""" |
|
|
|
def __init__( |
|
self, |
|
neox_args, |
|
parallel_output=True, |
|
init_method=nn.init.xavier_normal_, |
|
is_last_layer=False, |
|
): |
|
super().__init__() |
|
self.is_rm = neox_args.train_impl == "rm" |
|
parallelism = neox_args.output_layer_parallelism if not self.is_rm else "row" |
|
if parallelism == "column": |
|
self.final_linear = mpu.ColumnParallelLinear( |
|
neox_args=neox_args, |
|
input_size=neox_args.hidden_size, |
|
output_size=neox_args.padded_vocab_size, |
|
bias=False, |
|
init_method=init_method, |
|
gather_output=not parallel_output, |
|
skip_bias_add=False, |
|
mup_rescale_parameters=is_last_layer, |
|
seq_dim=1, |
|
) |
|
else: |
|
if not self.is_rm: |
|
print( |
|
'ERROR: Output layer parallelism over the hidden dim is currently broken (https://github.com/EleutherAI/gpt-neox/issues/905). Please run with output_layer_parallelism = "column" until this issue is fixed.' |
|
) |
|
exit() |
|
|
else: |
|
self.rm_linear = mpu.RowParallelLinear( |
|
neox_args=neox_args, |
|
input_size=neox_args.hidden_size, |
|
output_size=1, |
|
bias=False, |
|
input_is_parallel=False, |
|
init_method=init_method, |
|
parallel_output=False, |
|
skip_bias_add=False, |
|
mup_rescale_parameters=is_last_layer, |
|
) |
|
|
|
def forward(self, hidden_states): |
|
if not self.is_rm: |
|
return self.final_linear(hidden_states) |
|
else: |
|
return self.rm_linear(hidden_states) |
|
|
|
|
|
class _MegablocksAdapter(nn.Module): |
|
def __init__( |
|
self, neox_args, layer_cls, init_method, output_layer_init_method, ep_group |
|
): |
|
super().__init__() |
|
megablocks_utils.assert_megablocks_is_available() |
|
args = megablocks_utils.as_megablocks_args(neox_args) |
|
args.device = torch.cuda.current_device() |
|
args.init_method = init_method |
|
args.output_layer_init_method = output_layer_init_method |
|
|
|
|
|
|
|
|
|
|
|
if args.moe_expert_model_parallelism: |
|
args.expert_parallel_group = ep_group |
|
|
|
self.moe = layer_cls(args) |
|
|
|
def forward(self, x): |
|
return self.moe.forward(x) |
|
|
|
|
|
class MbMoE(_MegablocksAdapter): |
|
def __init__(self, neox_args, init_method, output_layer_init_method, ep_group): |
|
super().__init__( |
|
neox_args, |
|
megablocks_utils.moe.MoE, |
|
init_method, |
|
output_layer_init_method, |
|
ep_group, |
|
) |
|
|
|
|
|
class dMoE(_MegablocksAdapter): |
|
def __init__(self, neox_args, init_method, output_layer_init_method, ep_group): |
|
super().__init__( |
|
neox_args, |
|
megablocks_utils.dmoe.dMoE, |
|
init_method, |
|
output_layer_init_method, |
|
ep_group, |
|
) |
|
|
|
|
|
class ParallelSelfAttention(nn.Module): |
|
"""Parallel self-attention layer abstract class. |
|
|
|
Self-attention layer takes input with size [b, s, h] |
|
and returns output of the same size. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
neox_args, |
|
attention_mask_func, |
|
init_method, |
|
output_layer_init_method, |
|
layer_number, |
|
rpe=None, |
|
rotary=False, |
|
use_cache=False, |
|
parallel_output=False, |
|
): |
|
super().__init__() |
|
|
|
self.fp16 = neox_args.precision == "fp16" |
|
self.bf16 = neox_args.precision == "bfloat16" |
|
self.attention_mask_func = attention_mask_func |
|
self.apply_query_key_layer_scaling = neox_args.apply_query_key_layer_scaling |
|
self.use_cache = use_cache |
|
self.attention_softmax_in_fp32 = neox_args.attention_softmax_in_fp32 |
|
if self.apply_query_key_layer_scaling: |
|
self.attention_softmax_in_fp32 = True |
|
self.layer_number = layer_number |
|
|
|
world_size = mpu.get_model_parallel_world_size() |
|
self.hidden_size_per_partition = mpu.divide(neox_args.hidden_size, world_size) |
|
self.hidden_size_per_attention_head = mpu.divide( |
|
neox_args.hidden_size, neox_args.num_attention_heads |
|
) |
|
self.num_attention_heads_per_partition = mpu.divide( |
|
neox_args.num_attention_heads, world_size |
|
) |
|
self.pos_emb = neox_args.pos_emb |
|
|
|
self.use_qk_layernorm = neox_args.use_qk_layernorm |
|
if self.use_qk_layernorm: |
|
norm, eps = get_norm(neox_args) |
|
self.qk_layernorm = norm( |
|
[ |
|
self.num_attention_heads_per_partition, |
|
self.hidden_size_per_attention_head, |
|
], |
|
eps=eps, |
|
) |
|
|
|
self.sliding_window_width = neox_args.sliding_window_width |
|
|
|
if ( |
|
not neox_args.num_kv_heads |
|
or neox_args.num_kv_heads == neox_args.num_attention_heads |
|
): |
|
self.gqa = False |
|
else: |
|
self.gqa = True |
|
if self.gqa: |
|
self.num_kv_heads_per_partition = mpu.divide( |
|
neox_args.num_kv_heads, world_size |
|
) |
|
self.kv_hidden_size = ( |
|
neox_args.num_kv_heads * self.hidden_size_per_attention_head |
|
) |
|
else: |
|
self.num_kv_heads_per_partition = self.num_attention_heads_per_partition |
|
self.kv_hidden_size = neox_args.hidden_size |
|
|
|
if not self.gqa: |
|
|
|
self.query_key_value = mpu.ColumnParallelLinear( |
|
neox_args=neox_args, |
|
input_size=neox_args.hidden_size, |
|
output_size=3 * neox_args.hidden_size, |
|
gather_output=False, |
|
init_method=init_method, |
|
bias=neox_args.use_bias_in_attn_linear, |
|
) |
|
else: |
|
|
|
self.query_key_value = mpu.ColumnParallelLinear( |
|
neox_args=neox_args, |
|
input_size=neox_args.hidden_size, |
|
output_size=neox_args.hidden_size + 2 * self.kv_hidden_size, |
|
gather_output=False, |
|
init_method=init_method, |
|
bias=neox_args.use_bias_in_attn_linear, |
|
) |
|
|
|
coeff = None |
|
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) |
|
if self.apply_query_key_layer_scaling: |
|
coeff = max(1, self.layer_number) |
|
self.norm_factor *= coeff |
|
|
|
if neox_args.use_mup: |
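            # muP scales attention scores by 1/d instead of 1/sqrt(d), so the norm
            # factor becomes the per-head hidden size itself.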
|
self.norm_factor = self.hidden_size_per_attention_head |
|
|
|
self.rpe = rpe |
|
|
|
if self.pos_emb == "alibi": |
|
self.alibi_embed = AliBi( |
|
neox_args.num_attention_heads, |
|
neox_args.model_parallel_size, |
|
mpu.get_model_parallel_rank(), |
|
) |
|
|
|
|
|
if rotary: |
|
if neox_args.rotary_pct == 1: |
|
self.rotary_ndims = None |
|
else: |
|
assert neox_args.rotary_pct < 1 |
|
self.rotary_ndims = int( |
|
self.hidden_size_per_attention_head * neox_args.rotary_pct |
|
) |
|
dim = ( |
|
self.rotary_ndims |
|
if self.rotary_ndims is not None |
|
else self.hidden_size_per_attention_head |
|
) |
|
self.rotary_emb = RotaryEmbedding( |
|
dim, |
|
base=neox_args.rotary_emb_base, |
|
max_seq_len=neox_args.seq_length, |
|
precision=neox_args.params_dtype, |
|
save_inv_freqs=neox_args.rotary_save_freqs_buffer, |
|
) |
|
else: |
|
self.rotary_emb = None |
|
|
|
self.rope_fusion = neox_args.rope_fusion |
|
self.attention_type = neox_args.attention_config[layer_number] |
|
self.use_flash_attention = self.attention_type == "flash" |
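        # The Triton flash kernel is only needed for ALiBi when the installed
        # flash-attn predates 2.4.0.post1; newer versions accept alibi_slopes
        # directly (see the non-Triton path in flash_attention below).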
|
        self.use_triton = (
            self.use_flash_attention
            and self.pos_emb == "alibi"
            and (
                packaging.version.Version(version("flash-attn"))
                < packaging.version.Version("2.4.0.post1")
            )
        )
|
self.sparse = self.attention_type not in ("global", "flash") |
|
|
|
if self.gqa: |
|
assert not self.sparse |
|
|
|
if self.sparse: |
|
self.sparse_attn = configure_sparse_attention( |
|
neox_args, |
|
self.attention_type, |
|
self.num_attention_heads_per_partition, |
|
mpu=mpu, |
|
) |
|
else: |
|
if self.use_flash_attention: |
|
|
|
|
|
|
|
|
|
from flash_attn.flash_attn_interface import ( |
|
flash_attn_func, |
|
flash_attn_varlen_func, |
|
) |
|
from flash_attn.flash_attn_triton import ( |
|
flash_attn_func as flash_attn_unpadded_unpacked_func_triton, |
|
) |
|
|
|
self.flash_triton_fn = flash_attn_unpadded_unpacked_func_triton |
|
self.flash_qkv_fn = flash_attn_func |
|
self.flash_varlen_qkv_fn = flash_attn_varlen_func |
|
else: |
|
self.scale_mask_softmax = FusedScaleMaskSoftmax( |
|
input_in_fp16=self.fp16, |
|
input_in_bf16=self.bf16, |
|
fusion_type=get_fusion_type(neox_args), |
|
mask_func=self.attention_mask_func, |
|
softmax_in_fp32=self.attention_softmax_in_fp32, |
|
scale=coeff, |
|
) |
|
|
|
|
|
|
|
|
|
self.dropout_p = neox_args.attention_dropout |
|
self.attention_dropout = nn.Dropout(self.dropout_p) |
|
|
|
|
|
self.dense = mpu.RowParallelLinear( |
|
neox_args=neox_args, |
|
input_size=neox_args.hidden_size, |
|
output_size=neox_args.hidden_size, |
|
input_is_parallel=True, |
|
init_method=output_layer_init_method, |
|
skip_bias_add=True, |
|
parallel_output=parallel_output, |
|
bias=neox_args.use_bias_in_attn_linear, |
|
) |
|
|
|
def attention( |
|
self, query_layer, key_layer, value_layer, layer_past, attention_mask |
|
): |
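        # Fallback (non-flash, non-sparse) attention path.
        # output_size below is [b, np, sq, sk].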
|
|
|
|
|
|
|
|
|
|
|
output_size = ( |
|
query_layer.size(1), |
|
query_layer.size(2), |
|
query_layer.size(0), |
|
key_layer.size(0), |
|
) |
|
|
|
query_layer = query_layer.view( |
|
output_size[2], output_size[0] * output_size[1], -1 |
|
) |
|
key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) |
|
|
|
matmul_result = torch.empty( |
|
output_size[0] * output_size[1], |
|
output_size[2], |
|
output_size[3], |
|
dtype=query_layer.dtype, |
|
device=torch.cuda.current_device(), |
|
) |
|
|
|
|
|
matmul_result = torch.baddbmm( |
|
matmul_result, |
|
query_layer.transpose(0, 1), |
|
key_layer.transpose(0, 1).transpose(1, 2), |
|
beta=0.0, |
|
alpha=(1.0 / self.norm_factor), |
|
) |
|
|
|
|
|
attention_scores = matmul_result.view(*output_size) |
|
|
|
|
|
|
|
|
|
if self.use_cache: |
|
with torch.no_grad(): |
|
attention_mask = attention_mask[ |
|
..., : attention_scores.size(3), : attention_scores.size(3) |
|
] |
|
|
|
|
|
|
|
|
|
|
|
if exists(self.rpe): |
|
rpe = self.rpe(query_layer.size(0), key_layer.size(0)) |
|
attention_scores += rpe |
|
|
|
if self.pos_emb == "alibi": |
|
attention_scores = self.alibi_embed(attention_scores) |
|
|
|
|
|
attention_probs = self.scale_mask_softmax(attention_scores, attention_mask) |
|
|
|
|
|
|
|
with mpu.get_cuda_rng_tracker().fork(): |
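            # Fork the model-parallel RNG tracker so attention dropout draws from a
            # separate, reproducible random stream on each tensor-parallel rank.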
|
attention_probs = self.attention_dropout(attention_probs) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
output_size = ( |
|
value_layer.size(1), |
|
value_layer.size(2), |
|
query_layer.size(0), |
|
value_layer.size(3), |
|
) |
|
|
|
|
|
value_layer = value_layer.view( |
|
value_layer.size(0), output_size[0] * output_size[1], -1 |
|
) |
|
|
|
|
|
attention_probs = attention_probs.view( |
|
output_size[0] * output_size[1], output_size[2], -1 |
|
) |
|
|
|
|
|
context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) |
|
|
|
|
|
context_layer = context_layer.view(*output_size) |
|
return context_layer |
|
|
|
def flash_attention(self, query_layer, key_layer, value_layer): |
|
|
|
output_size = ( |
|
query_layer.size(1), |
|
query_layer.size(2), |
|
query_layer.size(0), |
|
key_layer.size(0), |
|
) |
|
|
|
if self.use_flash_attention and not self.use_triton: |
|
|
|
|
|
key_layer = key_layer.transpose(0, 1).reshape( |
|
output_size[0], output_size[3], self.num_kv_heads_per_partition, -1 |
|
) |
|
value_layer = value_layer.transpose(0, 1).reshape( |
|
output_size[0], output_size[3], self.num_kv_heads_per_partition, -1 |
|
) |
|
|
|
|
|
query_layer = query_layer.transpose(0, 1).reshape( |
|
output_size[0], output_size[2], output_size[1], -1 |
|
) |
|
|
|
|
|
|
|
|
|
|
|
extra_kwargs = ( |
|
{"window_size": (self.sliding_window_width, -1)} |
|
if self.sliding_window_width is not None |
|
else {} |
|
) |
|
if self.pos_emb == "alibi": |
|
extra_kwargs["alibi_slopes"] = self.alibi_embed.slopes.to( |
|
query_layer.device |
|
).to(torch.float32) |
|
|
|
if not self.training: |
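                # Inference path: call the varlen kernel with synthetic cu_seqlens
                # describing equal-length, fully packed sequences.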
|
batch_size = output_size[0] |
|
max_seqlen_q = output_size[2] |
|
max_seqlen_k = output_size[3] |
|
|
|
cu_seqlens_q = torch.arange( |
|
0, |
|
(batch_size + 1) * max_seqlen_q, |
|
step=max_seqlen_q, |
|
dtype=torch.int32, |
|
device=query_layer.device, |
|
) |
|
|
|
cu_seqlens_k = torch.arange( |
|
0, |
|
(batch_size + 1) * max_seqlen_k, |
|
step=max_seqlen_k, |
|
dtype=torch.int32, |
|
device=key_layer.device, |
|
) |
|
|
|
q_shape = query_layer.shape |
|
k_shape = key_layer.shape |
|
v_shape = value_layer.shape |
|
is_causal = max_seqlen_q == max_seqlen_k |
|
output = self.flash_varlen_qkv_fn( |
|
query_layer.reshape( |
|
(q_shape[0] * q_shape[1], q_shape[2], q_shape[3]) |
|
), |
|
key_layer.reshape( |
|
(k_shape[0] * k_shape[1], k_shape[2], k_shape[3]) |
|
), |
|
value_layer.reshape( |
|
(v_shape[0] * v_shape[1], v_shape[2], v_shape[3]) |
|
), |
|
cu_seqlens_q, |
|
cu_seqlens_k, |
|
max_seqlen_q, |
|
max_seqlen_k, |
|
softmax_scale=None, |
|
causal=is_causal, |
|
**extra_kwargs, |
|
) |
|
output = output.reshape(q_shape) |
|
else: |
|
output = self.flash_qkv_fn( |
|
query_layer, |
|
key_layer, |
|
value_layer, |
|
self.dropout_p if self.training else 0.0, |
|
softmax_scale=None, |
|
causal=True, |
|
**extra_kwargs, |
|
) |
|
|
|
matmul_result = output |
|
|
|
matmul_result = matmul_result.transpose(1, 2) |
|
|
|
else: |
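            # Triton flash path: ALiBi is applied as an explicit additive bias.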
|
|
|
|
|
|
|
sq = query_layer.size(0) |
|
b = query_layer.size(1) |
|
sk = key_layer.size(0) |
|
|
|
query_layer = query_layer.transpose(0, 1) |
|
key_layer = key_layer.transpose(0, 1) |
|
value_layer = value_layer.transpose(0, 1) |
|
|
|
bias = self.alibi_embed.bias(sq, sk, query_layer.device, query_layer.dtype) |
|
bias = bias.unsqueeze(0).tile((b, 1, 1, 1)) |
|
|
|
matmul_result = self.flash_triton_fn( |
|
query_layer, key_layer, value_layer, bias=bias, causal=True |
|
) |
|
matmul_result = matmul_result.transpose(1, 2) |
|
|
|
return matmul_result |
|
|
|
def sparse_attention(self, query_layer, key_layer, value_layer, attention_mask): |
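        # Reshape [sq, b, np, hn] -> [b, np, sq, hn], the layout expected by the
        # sparse attention kernels.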
|
|
|
|
|
|
|
query_layer, key_layer, value_layer = map( |
|
lambda t: t.permute(1, 2, 0, 3).contiguous(), |
|
(query_layer, key_layer, value_layer), |
|
) |
|
|
|
attn_mask = attention_mask.to(query_layer.dtype) * -10000 |
|
if exists(self.rpe): |
|
rpe = self.rpe(query_layer.size(0), key_layer.size(0)) |
|
else: |
|
rpe = None |
|
return self.sparse_attn( |
|
query_layer, key_layer, value_layer, attn_mask=attn_mask, rpe=rpe |
|
) |
|
|
|
def gqa_project(self, hidden_states, attention_mask, layer_past=None): |
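        # Grouped-query attention: one fused projection yields full-width Q plus
        # narrower K/V (num_kv_heads per partition); split per head, then reshape
        # K/V to [sq, b, kvp, hn].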
|
|
|
|
|
|
|
|
|
|
|
|
|
mixed_x_layer, _ = self.query_key_value(hidden_states) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
new_qkv_shape = ( |
|
mixed_x_layer.shape[0], |
|
mixed_x_layer.shape[1], |
|
self.num_attention_heads_per_partition, |
|
int( |
|
self.hidden_size_per_attention_head |
|
* ( |
|
1 |
|
+ 2 |
|
* ( |
|
self.num_kv_heads_per_partition |
|
/ self.num_attention_heads_per_partition |
|
) |
|
) |
|
), |
|
) |
|
mixed_x_layer = mixed_x_layer.reshape(*new_qkv_shape) |
|
|
|
|
|
|
|
split_sizes = ( |
|
self.hidden_size_per_attention_head, |
|
int( |
|
( |
|
self.num_kv_heads_per_partition |
|
/ self.num_attention_heads_per_partition |
|
) |
|
* self.hidden_size_per_attention_head |
|
), |
|
int( |
|
( |
|
self.num_kv_heads_per_partition |
|
/ self.num_attention_heads_per_partition |
|
) |
|
* self.hidden_size_per_attention_head |
|
), |
|
) |
|
|
|
|
|
(query_layer, key_layer, value_layer) = [ |
|
x.contiguous() |
|
for x in torch.split( |
|
mixed_x_layer, |
|
split_sizes, |
|
dim=mixed_x_layer.dim() - 1, |
|
) |
|
] |
|
|
|
|
|
|
|
new_kv_shape = ( |
|
key_layer.size(0), |
|
key_layer.size(1), |
|
self.num_kv_heads_per_partition, |
|
self.hidden_size_per_attention_head, |
|
) |
|
|
|
key_layer = key_layer.view(*new_kv_shape) |
|
|
|
value_layer = value_layer.view(*new_kv_shape) |
|
|
|
|
|
if not self.use_flash_attention: |
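            # Non-flash kernels expect one K/V head per query head, so repeat each
            # K/V head across its query group.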
|
key_layer = torch.repeat_interleave( |
|
key_layer, |
|
repeats=int( |
|
self.num_attention_heads_per_partition |
|
// self.num_kv_heads_per_partition |
|
), |
|
dim=2, |
|
) |
|
value_layer = torch.repeat_interleave( |
|
value_layer, |
|
repeats=int( |
|
self.num_attention_heads_per_partition |
|
// self.num_kv_heads_per_partition |
|
), |
|
dim=2, |
|
) |
|
|
|
return query_layer, key_layer, value_layer |
|
|
|
def forward(self, hidden_states, attention_mask, layer_past=None): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if not self.gqa: |
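            # Fused QKV projection, reshaped to [sq, b, np, 3 * hn] and split into
            # per-head query, key, and value tensors.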
|
|
|
|
|
|
|
mixed_x_layer, _ = self.query_key_value(hidden_states) |
|
|
|
|
|
new_tensor_shape = mixed_x_layer.size()[:-1] + ( |
|
self.num_attention_heads_per_partition, |
|
3 * self.hidden_size_per_attention_head, |
|
) |
|
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) |
|
|
|
|
|
(query_layer, key_layer, value_layer) = mpu.split_tensor_along_last_dim( |
|
mixed_x_layer, 3 |
|
) |
|
else: |
|
|
|
|
|
|
|
|
|
query_layer, key_layer, value_layer = self.gqa_project( |
|
hidden_states, attention_mask, layer_past=layer_past |
|
) |
|
|
|
|
|
if self.use_qk_layernorm: |
|
query_layer = self.qk_layernorm(query_layer) |
|
key_layer = self.qk_layernorm(key_layer) |
|
|
|
if exists(self.rotary_emb): |
|
if exists(self.rotary_ndims): |
|
|
|
query_rot, query_pass = ( |
|
query_layer[..., : self.rotary_ndims], |
|
query_layer[..., self.rotary_ndims :], |
|
) |
|
key_rot, key_pass = ( |
|
key_layer[..., : self.rotary_ndims], |
|
key_layer[..., self.rotary_ndims :], |
|
) |
|
else: |
|
|
|
query_rot, key_rot = query_layer, key_layer |
|
|
|
seq_len = key_layer.shape[0] |
|
offset = 0 |
|
if exists(layer_past) and layer_past.numel() > 0: |
|
offset = layer_past[0].shape[0] |
|
seq_len += offset |
|
cos, sin = self.rotary_emb(value_layer, seq_len=seq_len) |
|
if self.rope_fusion: |
|
query_layer, key_layer = ( |
|
fused_apply_rotary_pos_emb_cached(rot, cos, sin) |
|
for rot in [query_rot, key_rot] |
|
) |
|
else: |
|
if self.bf16: |
|
apply_rotary_fn = apply_rotary_pos_emb_torch |
|
else: |
|
apply_rotary_fn = apply_rotary_pos_emb |
|
query_layer, key_layer = apply_rotary_fn( |
|
query_rot, key_rot, cos, sin, offset=offset |
|
) |
|
|
|
if exists(self.rotary_ndims): |
|
query_layer = torch.cat((query_layer, query_pass), dim=-1) |
|
key_layer = torch.cat((key_layer, key_pass), dim=-1) |
|
|
|
|
|
|
|
|
|
|
|
if exists(layer_past) and layer_past.numel() > 0: |
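            # Decoding with a KV cache: prepend the cached keys/values along the
            # sequence dimension.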
|
past_key, past_value = layer_past |
|
key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=0) |
|
value_layer = torch.cat( |
|
(past_value.type_as(value_layer), value_layer), dim=0 |
|
) |
|
|
|
if self.use_cache: |
|
present = torch.stack((key_layer, value_layer)) |
|
|
|
if self.use_flash_attention: |
|
context_layer = self.flash_attention(query_layer, key_layer, value_layer) |
|
elif not self.sparse: |
|
context_layer = self.attention( |
|
query_layer, key_layer, value_layer, layer_past, attention_mask |
|
) |
|
else: |
|
context_layer = self.sparse_attention( |
|
query_layer, key_layer, value_layer, attention_mask |
|
) |
|
|
|
|
|
context_layer = context_layer.permute(2, 0, 1, 3).contiguous() |
|
|
|
|
|
new_context_layer_shape = context_layer.size()[:-2] + ( |
|
self.hidden_size_per_partition, |
|
) |
|
context_layer = context_layer.view(*new_context_layer_shape) |
|
|
|
|
|
|
|
|
|
|
|
output, bias = self.dense(context_layer) |
|
|
|
if self.use_cache: |
|
output = [output, present] |
|
|
|
return output, bias |
|
|
|
|
|
class ParallelTransformerLayer(nn.Module): |
|
"""A single transformer layer. |
|
|
|
    Transformer layer takes input with size [s, b, h] and returns an
|
output of the same size. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
neox_args, |
|
attention_mask_func, |
|
init_method, |
|
output_layer_init_method, |
|
layer_number, |
|
rpe=None, |
|
rotary=False, |
|
use_cache=False, |
|
): |
|
|
|
super().__init__() |
|
self.layer_number = layer_number |
|
self.neox_args = neox_args |
|
|
|
norm, eps = get_norm(neox_args) |
|
|
|
|
|
self.input_layernorm = norm(neox_args.hidden_size, eps=eps) |
|
self.use_cache = use_cache |
|
|
|
self.hidden_dropout = neox_args.hidden_dropout |
|
self.bias_dropout_fusion = neox_args.bias_dropout_fusion |
|
self.gpt_j_residual = neox_args.gpt_j_residual |
|
self.gpt_j_tied = neox_args.gpt_j_tied |
|
self.moe_type = neox_args.moe_type |
|
self.activation = neox_args.activation |
|
|
|
if self.gpt_j_residual: |
|
|
|
|
|
|
|
self.reduce = ( |
|
mpu.mappings.reduce_from_model_parallel_region |
|
if not neox_args.sequence_parallel |
|
else mpu.mappings.reduce_scatter_to_sequence_parallel_region |
|
) |
|
|
|
|
|
self.attention = ParallelSelfAttention( |
|
neox_args=neox_args, |
|
attention_mask_func=attention_mask_func, |
|
init_method=init_method, |
|
output_layer_init_method=output_layer_init_method, |
|
layer_number=layer_number, |
|
rpe=rpe, |
|
use_cache=self.use_cache, |
|
rotary=rotary, |
|
parallel_output=self.gpt_j_residual, |
|
) |
|
|
|
|
|
|
|
|
|
self.post_attention_layernorm = norm(neox_args.hidden_size, eps=eps) |
|
|
|
|
|
def get_mlp(**kw): |
|
return ParallelMLP( |
|
neox_args=neox_args, |
|
init_method=init_method, |
|
output_layer_init_method=output_layer_init_method, |
|
parallel_output=self.gpt_j_residual, |
|
multiple_of=neox_args.mlp_multiple_of, |
|
**kw, |
|
) |
|
|
|
self.num_experts = ( |
|
neox_args.moe_num_experts |
|
if layer_number % neox_args.expert_interval == 0 |
|
else 1 |
|
) |
|
args = neox_args |
|
if self.num_experts <= 1: |
|
self.mlp = get_mlp() |
|
else: |
|
from torch import distributed as dist |
|
|
|
if self.num_experts > dist.get_world_size(): |
|
moe_mp_size = 1 |
|
else: |
|
moe_mp_size = dist.get_world_size() // self.num_experts |
|
|
|
if neox_args.moe_type == "deepspeed": |
|
self.mlp = MoE( |
|
args.hidden_size, |
|
get_mlp( |
|
"regular", |
|
MOE=True, |
|
MoE_mp_size=moe_mp_size, |
|
), |
|
num_experts=self.num_experts, |
|
ep_size=args.moe_expert_parallel_size, |
|
k=args.moe_top_k, |
|
use_residual=args.moe_use_residual, |
|
capacity_factor=args.moe_train_capacity_factor, |
|
eval_capacity_factor=args.moe_eval_capacity_factor, |
|
min_capacity=args.moe_min_capacity, |
|
drop_tokens=args.moe_token_dropping, |
|
use_tutel=args.use_tutel, |
|
enable_expert_tensor_parallelism=args.enable_expert_tensor_parallelism, |
|
) |
|
elif neox_args.moe_type == "megablocks": |
|
|
|
def integrate_megablocks_with_ds_expert_parallelism(): |
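                    # Construct a throwaway DeepSpeed MoE layer solely to reuse its
                    # expert-parallel process group (and group name) for the
                    # megablocks MoE/dMoE layers built below.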
|
|
throwaway = MoE( |
|
args.hidden_size, |
|
get_mlp( |
|
"regular", |
|
MOE=True, |
|
MoE_mp_size=moe_mp_size, |
|
), |
|
num_experts=self.num_experts, |
|
ep_size=args.moe_expert_parallel_size, |
|
k=args.moe_top_k, |
|
use_residual=args.moe_use_residual, |
|
capacity_factor=args.moe_train_capacity_factor, |
|
eval_capacity_factor=args.moe_eval_capacity_factor, |
|
min_capacity=args.moe_min_capacity, |
|
drop_tokens=args.moe_token_dropping, |
|
use_tutel=args.use_tutel, |
|
enable_expert_tensor_parallelism=args.enable_expert_tensor_parallelism, |
|
) |
|
throwaway.set_deepspeed_parallelism() |
|
|
|
ep_group = throwaway.deepspeed_moe.ep_group |
|
if args.moe_token_dropping: |
|
self.mlp = MbMoE( |
|
neox_args, init_method, output_layer_init_method, ep_group |
|
) |
|
else: |
|
self.mlp = dMoE( |
|
neox_args, init_method, output_layer_init_method, ep_group |
|
) |
|
|
|
|
|
for param in self.mlp.parameters(): |
|
if getattr(param, "expert_model_parallel", None) is not None: |
|
|
|
param.allreduce = False |
|
param.group_name = throwaway.expert_group_name |
|
|
|
integrate_megablocks_with_ds_expert_parallelism() |
|
|
|
else: |
|
raise KeyError(neox_args.moe_type) |
|
|
|
self.layer_past = None |
|
|
|
def _get_bias_dropout(self): |
|
if self.bias_dropout_fusion: |
|
fn = ( |
|
bias_dropout_add_fused_train |
|
if self.training |
|
else bias_dropout_add_fused_inference |
|
) |
|
else: |
|
fn = get_bias_dropout_add(self.training) |
|
return fn |
|
|
|
def forward(self, x, attention_mask, layer_past=None): |
|
layer_past = layer_past if layer_past is not None else self.layer_past |
|
bias_dropout_fn = self._get_bias_dropout() |
|
moe_loss = torch.tensor(0.0, device=x.device, dtype=x.dtype) |
|
|
|
if self.gpt_j_residual: |
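            # GPT-J-style parallel residual, i.e. (in pseudocode):
            #   x = x + attn(ln1(x)) + mlp(ln2(x))
            # with a single reduce applied to the summed attention + MLP output.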
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
residual = x |
|
|
|
if self.gpt_j_tied: |
|
x = self.input_layernorm(x) |
|
x1, x2 = x, x |
|
else: |
|
x1, x2 = self.input_layernorm(x), self.post_attention_layernorm(x) |
|
|
|
|
|
attention_output, attention_bias = self.attention( |
|
x1, attention_mask, layer_past=layer_past |
|
) |
|
if self.use_cache: |
|
attention_output, presents = attention_output |
|
self.layer_past = presents |
|
|
|
if attention_bias is not None: |
|
                with torch.enable_grad() if self.training else nullcontext():
|
attention_output = bias_dropout_fn( |
|
attention_output, |
|
bias=attention_bias.expand_as(attention_output), |
|
residual=None, |
|
prob=self.hidden_dropout, |
|
) |
|
|
|
|
|
mlp_output, mlp_bias = self.mlp(x2) |
|
if mlp_bias is not None: |
|
                with torch.enable_grad() if self.training else nullcontext():
|
output = bias_dropout_fn( |
|
mlp_output, |
|
bias=mlp_bias.expand_as(mlp_output), |
|
residual=attention_output, |
|
prob=self.hidden_dropout, |
|
) |
|
else: |
|
output = mlp_output |
|
|
|
|
|
output = residual + self.reduce(output) |
|
else: |
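            # Standard (sequential) residual path:
            #   x = x + attn(ln1(x)); x = x + mlp(ln2(x))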
|
|
|
|
|
|
|
|
|
residual = x |
|
|
|
|
|
attention_output, attention_bias = self.attention( |
|
self.input_layernorm(x), attention_mask, layer_past=layer_past |
|
) |
|
if self.use_cache: |
|
attention_output, presents = attention_output |
|
self.layer_past = presents |
|
            with torch.enable_grad() if self.training else nullcontext():
|
if attention_bias is not None: |
|
|
|
attention_output = bias_dropout_fn( |
|
attention_output, |
|
bias=attention_bias.expand_as(residual), |
|
residual=residual, |
|
prob=self.hidden_dropout, |
|
) |
|
else: |
|
|
|
attention_output = ( |
|
torch.nn.functional.dropout( |
|
attention_output, |
|
p=self.hidden_dropout, |
|
training=self.training, |
|
) |
|
+ residual |
|
) |
|
|
|
|
|
layernorm_output = self.post_attention_layernorm(attention_output) |
|
mlp_bias = torch.tensor( |
|
0.0, device=layernorm_output.device, dtype=layernorm_output.dtype |
|
) |
|
|
|
if self.num_experts == 1: |
|
mlp_output, mlp_bias = self.mlp(layernorm_output) |
|
else: |
|
if self.moe_type == "deepspeed": |
|
mlp_output, moe_loss, _ = self.mlp(layernorm_output) |
|
                    mlp_bias = None  # deepspeed's MoE does not return a bias term
|
elif self.moe_type == "megablocks": |
|
mlp_output, mlp_bias = self.mlp(layernorm_output) |
|
else: |
|
raise KeyError(self.moe_type) |
|
|
|
            with torch.enable_grad() if self.training else nullcontext():
|
if ( |
|
self.activation == "swiglu" |
|
or self.num_experts > 1 |
|
and self.moe_type == "deepspeed" |
|
): |
|
|
|
assert mlp_bias is None |
|
output = mlp_output + attention_output |
|
else: |
|
output = bias_dropout_fn( |
|
mlp_output, |
|
bias=mlp_bias.expand_as(attention_output), |
|
residual=attention_output, |
|
prob=self.hidden_dropout, |
|
) |
|
|
|
return output, moe_loss |
|
|
|
|
|
class ParallelTransformerLayerPipe(ParallelTransformerLayer): |
|
"""Extends ParallelTransformerLayer to forward attention_mask through the pipeline.""" |
|
|
|
def forward(self, args): |
|
assert ( |
|
len(args) == 2 |
|
), "ParallelTransformerLayerPipe expects 2 arguments - hidden_states and attention_mask" |
|
hidden_states, attention_mask = args |
|
|
|
output, moe_loss = super().forward(hidden_states, attention_mask) |
|
|
|
self.last_moe_loss = moe_loss |
|
return output, attention_mask |
|
|
|
|
|
class ParallelLinearPipe(ParallelLinear): |
|
"""Another helper class to pass presents through to the output when doing inference with a Pipe Parallel model""" |
|
|
|
def forward(self, args): |
|
assert isinstance( |
|
args, torch.Tensor |
|
), "ParallelLinearPipe expects a single argument - hidden_states" |
|
hidden_state = args |
|
logits, bias = super().forward(hidden_state) |
|
return logits |
|
|
|
|
|
class NormPipe(nn.Module): |
|
"""Just a helper class to pass presents through to the output when doing inference with a Pipe Parallel model""" |
|
|
|
def __init__(self, norm_class, hidden_size, eps): |
|
super().__init__() |
|
self.norm = norm_class(hidden_size, eps=eps) |
|
|
|
def forward(self, args): |
|
assert not isinstance( |
|
args, tuple |
|
), "NormPipe should only receive a single tensor as input" |
|
return self.norm(args) |
|
|
|
|
|
def parallel_lm_logits( |
|
input_, |
|
word_embeddings_weight, |
|
parallel_output, |
|
seq_parallel=False, |
|
seq_dim=1, |
|
bias=None, |
|
): |
|
"""LM logits using word embedding weights.""" |
|
|
|
if seq_parallel: |
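        # With sequence parallelism the activations are sharded along the sequence
        # dimension; gather them before the (vocab-parallel) output projection.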
|
|
|
|
|
input_parallel = mpu.gather_from_sequence_parallel_region( |
|
input_, seq_dim=seq_dim |
|
) |
|
else: |
|
|
|
input_parallel = mpu.copy_to_model_parallel_region(input_) |
|
|
|
|
|
if bias is None: |
|
logits_parallel = F.linear(input_parallel, word_embeddings_weight) |
|
else: |
|
logits_parallel = F.linear(input_parallel, word_embeddings_weight, bias) |
|
|
|
|
|
if parallel_output: |
|
return logits_parallel |
|
|
|
return mpu.gather_from_model_parallel_region(logits_parallel) |
|
|