# Copyright (c) 2024, EleutherAI
# This file is based on code by the authors denoted below and has been modified from its original version.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F

from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.activations import get_activation
from megatron.model.norms import get_norm
from megatron.model.utils import get_fusion_type
from megatron import mpu


class TinyAttention(nn.Module):
    """Single-head "tiny" attention applied alongside the spatial gating unit (the aMLP variant of gMLP)."""

    def __init__(self, neox_args, d_attn, d_ff, mask_fn):
        super().__init__()
        self.proj_qkv = nn.Linear(d_ff * 2, 3 * d_attn)
        self.scale = d_attn**-0.5
        self.proj_ffn = nn.Linear(d_attn, d_ff)
        self.softmax = FusedScaleMaskSoftmax(
            input_in_fp16=neox_args.precision == "fp16",
            input_in_bf16=neox_args.precision == "bfloat16",
            fusion_type=get_fusion_type(neox_args),
            mask_func=mask_fn,
            softmax_in_fp32=neox_args.attention_softmax_in_fp32,
            scale=None,
        )

    def forward(self, x, attention_mask):
        q, k, v = torch.chunk(self.proj_qkv(x), 3, dim=-1)
        w = torch.einsum("bnd,bmd->bnm", q, k).unsqueeze(1) * self.scale
        a = self.softmax(
            w, mask=attention_mask[..., : w.size(-2), : w.size(-1)]
        ).squeeze(1)
        x = torch.einsum("bnm,bmd->bnd", a, v)
        return self.proj_ffn(x)


class SpatialGatingUnit(nn.Module):
    """Spatial gating unit (SGU) from gMLP: gates half of the channels with a learned projection over the sequence dimension."""

    def __init__(self, neox_args, d_ff, d_attn=None, causal=True, mask_fn=None):
        super().__init__()
        self.causal = causal
        self.use_attn = d_attn is not None

        norm, eps = get_norm(neox_args)
        self.norm = norm(d_ff, eps=eps)
        self.proj = nn.Linear(neox_args.seq_length, neox_args.seq_length)
        if self.use_attn:
            assert mask_fn is not None
            self.attn = TinyAttention(
                neox_args=neox_args, d_attn=d_attn, d_ff=d_ff, mask_fn=mask_fn
            )
        # Initialize the spatial projection close to identity-gating:
        # zero weights and unit bias so the gate starts as a pass-through.
        nn.init.zeros_(self.proj.weight)
        nn.init.constant_(self.proj.bias, 1.0)

    def forward(self, x, attention_mask):
        # x enters as [s, b, d], so the sequence length is dim 0.
        device, n = x.device, x.shape[0]
        x = x.transpose(0, 1)  # [s, b, d] -> [b, s, d]

        res, gate = x.chunk(2, dim=-1)  # split along the channel dim
        gate = self.norm(gate)

        weight, bias = self.proj.weight, self.proj.bias

        if self.causal:
            # Restrict the projection to the current sequence length and
            # zero out future positions so each token only mixes with earlier ones.
            weight, bias = weight[:n, :n], bias[:n]
            mask = torch.ones(weight.shape[:2], device=device).triu_(1).bool()
            weight = weight.masked_fill(mask, 0.0)

        # Apply the (possibly causally masked) projection along the sequence dimension.
        gate = F.linear(gate.transpose(2, 1), weight, bias).transpose(2, 1)

        if self.use_attn:
            gate = gate + self.attn(x, attention_mask)

        return (gate * res).transpose(0, 1)  # [b, s, d] -> [s, b, d]


class GMLPBlock(nn.Module):
    """gMLP block: norm -> channel projection -> activation -> spatial gating unit -> output projection."""

    def __init__(
        self,
        neox_args,
        init_method,
        output_layer_init_method,
        layer_number,
        ff_mult=4,
        mask_fn=None,
    ):
        super().__init__()
        self.layer_number = layer_number

        ff_dim = neox_args.hidden_size * ff_mult
        norm, eps = get_norm(neox_args)
        self.norm = norm(neox_args.hidden_size, eps=eps)
        self.input_linear = mpu.ColumnParallelLinear(
            neox_args=neox_args,
            input_size=neox_args.hidden_size,
            output_size=ff_dim * 2,
            gather_output=False,
            init_method=init_method,
            skip_bias_add=True,
        )
        self.activation_func, _ = get_activation(neox_args)
        ff_dim_parallel = mpu.divide(ff_dim, mpu.get_model_parallel_world_size())
        if neox_args.attention_config[layer_number] == "amlp":
            d_attn = neox_args.gmlp_attn_dim
        else:
            d_attn = None
        self.sgu = SpatialGatingUnit(
            neox_args, ff_dim_parallel, d_attn, causal=True, mask_fn=mask_fn
        )
        self.output_linear = mpu.RowParallelLinear(
            neox_args=neox_args,
            input_size=ff_dim,
            output_size=neox_args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True,
        )

    def forward(self, args):
        assert len(args) == 2, "GMLPBlock expects 2 arguments"
        x, attention_mask = args
        x = self.norm(x)
        x, _ = self.input_linear(x)
        x = self.activation_func(x)
        x = self.sgu(x, attention_mask)
        x, _ = self.output_linear(x)
        return x, attention_mask
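
# Usage sketch (illustrative only, not part of the original module): GMLPBlock is
# consumed by the GPT-NeoX transformer stack, which threads a
# (hidden_states, attention_mask) tuple through each layer. The argument values
# below (neox_args, init methods, mask_fn) are assumed to come from the
# surrounding training setup and are shown only to document the expected call
# signature and tensor layout:
#
#   block = GMLPBlock(
#       neox_args=neox_args,
#       init_method=init_method,
#       output_layer_init_method=output_layer_init_method,
#       layer_number=0,
#       mask_fn=attention_mask_func,  # hypothetical causal mask helper
#   )
#   hidden_states, attention_mask = block((hidden_states, attention_mask))
#
# hidden_states is expected in [seq_len, batch, hidden_size] layout;
# attention_mask is passed through unchanged.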