# megatron/model/gmlp.py
# Copyright (c) 2024, EleutherAI
# This file is based on code by the authors denoted below and has been modified from its original version.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
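
"""gMLP / aMLP building blocks ("Pay Attention to MLPs", Liu et al., 2021) for GPT-NeoX."""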

import torch
import torch.nn as nn
import torch.nn.functional as F

from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.activations import get_activation
from megatron.model.norms import get_norm
from megatron.model.utils import get_fusion_type
from megatron import mpu


class TinyAttention(nn.Module):
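    """Single-head "tiny" attention used by the aMLP variant of the gMLP block.

    Projects the 2 * d_ff gating input to q, k, v of width d_attn, applies the
    fused masked softmax, and maps the result back to d_ff so it can be added
    to the spatial gate.
    """
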
def __init__(self, neox_args, d_attn, d_ff, mask_fn):
super().__init__()
self.proj_qkv = nn.Linear(d_ff * 2, 3 * d_attn)
self.scale = d_attn**-0.5
self.proj_ffn = nn.Linear(d_attn, d_ff)
self.softmax = FusedScaleMaskSoftmax(
input_in_fp16=neox_args.precision == "fp16",
input_in_bf16=neox_args.precision == "bfloat16",
fusion_type=get_fusion_type(neox_args),
mask_func=mask_fn,
softmax_in_fp32=neox_args.attention_softmax_in_fp32,
scale=None,
)

    def forward(self, x, attention_mask):
        # x: [b, n, 2 * d_ff] (the full SGU input); project to one head of q, k, v of width d_attn
        q, k, v = torch.chunk(self.proj_qkv(x), 3, dim=-1)
        # [b, n, n] attention scores, unsqueezed to [b, 1, n, n] for the fused mask + softmax
        w = torch.einsum("bnd,bmd->bnm", q, k).unsqueeze(1) * self.scale
a = self.softmax(
w, mask=attention_mask[..., : w.size(-2), : w.size(-1)]
).squeeze(1)
x = torch.einsum("bnm,bmd->bnd", a, v)
return self.proj_ffn(x)


class SpatialGatingUnit(nn.Module):
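    """Spatial Gating Unit from gMLP ("Pay Attention to MLPs", Liu et al., 2021).

    Splits the channel dimension into a residual half and a gate half, normalizes
    the gate, mixes it across sequence positions with a learned
    seq_length x seq_length projection (lower-triangular when causal), optionally
    adds TinyAttention on top, and returns gate * residual.
    """
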
def __init__(self, neox_args, d_ff, d_attn=None, causal=True, mask_fn=None):
super().__init__()
self.causal = causal
self.use_attn = d_attn is not None
norm, eps = get_norm(neox_args)
self.norm = norm(d_ff, eps=eps)
self.proj = nn.Linear(neox_args.seq_length, neox_args.seq_length)
if self.use_attn:
assert mask_fn is not None
self.attn = TinyAttention(
neox_args=neox_args, d_attn=d_attn, d_ff=d_ff, mask_fn=mask_fn
)
        # near-identity initialization: zero spatial weights and unit bias mean
        # the gate starts out passing the residual half through unchanged
        nn.init.zeros_(self.proj.weight)
        nn.init.constant_(self.proj.bias, 1.0)

    def forward(self, x, attention_mask):
        # n = current sequence length (x arrives as [s, b, d])
        device, n = x.device, x.shape[0]
        x = x.transpose(0, 1)  # [s, b, d] -> [b, s, d]
        res, gate = x.chunk(2, dim=-1)  # split channels into residual and gate halves
gate = self.norm(gate)
        weight, bias = self.proj.weight, self.proj.bias
        if self.causal:
            # restrict the spatial projection to the current sequence length and
            # zero the upper triangle so position i only mixes positions j <= i
            weight, bias = weight[:n, :n], bias[:n]
            mask = torch.ones(weight.shape[:2], device=device).triu_(1).bool()
            weight = weight.masked_fill(mask, 0.0)
        # token-mixing projection over the sequence dimension; use the sliced bias
        # so shapes stay consistent when n < seq_length
        gate = F.linear(gate.transpose(2, 1), weight, bias).transpose(2, 1)
if self.use_attn:
gate = gate + self.attn(x, attention_mask)
return (gate * res).transpose(0, 1) # [b, s, d] -> [s, b, d]


class GMLPBlock(nn.Module):
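    """One gMLP block, wired for megatron-style tensor parallelism.

    Flow: norm -> column-parallel expansion to 2 * ff_dim -> activation ->
    SpatialGatingUnit -> row-parallel projection back to hidden_size. Layers
    whose attention_config entry is "amlp" also run TinyAttention inside the
    SGU.
    """
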
def __init__(
self,
neox_args,
init_method,
output_layer_init_method,
layer_number,
ff_mult=4,
mask_fn=None,
):
super().__init__()
self.layer_number = layer_number
ff_dim = neox_args.hidden_size * ff_mult
norm, eps = get_norm(neox_args)
self.norm = norm(neox_args.hidden_size, eps=eps)
self.input_linear = mpu.ColumnParallelLinear(
neox_args=neox_args,
input_size=neox_args.hidden_size,
output_size=ff_dim * 2,
gather_output=False,
init_method=init_method,
skip_bias_add=True,
)
self.activation_func, _ = get_activation(neox_args)
        # the SGU operates on this rank's partition of the feed-forward dimension
        ff_dim_parallel = mpu.divide(ff_dim, mpu.get_model_parallel_world_size())
        if neox_args.attention_config[layer_number] == "amlp":
            # "amlp" layers augment the gMLP block with tiny attention
            d_attn = neox_args.gmlp_attn_dim
else:
d_attn = None
self.sgu = SpatialGatingUnit(
neox_args, ff_dim_parallel, d_attn, causal=True, mask_fn=mask_fn
)
self.output_linear = mpu.RowParallelLinear(
neox_args=neox_args,
input_size=ff_dim,
output_size=neox_args.hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method,
skip_bias_add=True,
)

    def forward(self, args):
assert len(args) == 2, "GMLPBlock expects 2 arguments"
x, attention_mask = args
x = self.norm(x)
x, _ = self.input_linear(x)
x = self.activation_func(x)
x = self.sgu(x, attention_mask)
x, _ = self.output_linear(x)
return x, attention_mask
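

if __name__ == "__main__":
    # Minimal, dependency-free sketch of the spatial-gating step above, run as a
    # quick sanity check. It mirrors SpatialGatingUnit.forward on a toy
    # [batch, seq, 2 * d_ff] tensor, but without the megatron norm, tensor
    # parallelism, or TinyAttention; the sizes and names here are illustrative
    # only, not part of the module.
    torch.manual_seed(0)
    b, s, d_ff = 2, 8, 16
    x = torch.randn(b, s, 2 * d_ff)  # stand-in for the expanded activations
    res, gate = x.chunk(2, dim=-1)  # residual half / gate half

    proj = nn.Linear(s, s)  # token-mixing projection across positions
    nn.init.zeros_(proj.weight)  # same near-identity init as SpatialGatingUnit
    nn.init.constant_(proj.bias, 1.0)

    # zero the upper triangle so position i only mixes positions j <= i (causal)
    causal_mask = torch.ones(s, s).triu_(1).bool()
    weight = proj.weight.masked_fill(causal_mask, 0.0)

    gate = F.linear(gate.transpose(2, 1), weight, proj.bias).transpose(2, 1)
    out = gate * res  # [b, s, d_ff]

    # with the near-identity init the projected gate is all ones, so the unit
    # initially passes the residual half through unchanged
    print(out.shape, torch.allclose(out, res))  # torch.Size([2, 8, 16]) True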