"""GPT-2 model.""" |
|
|
|
import math |
|
import torch |
|
import torch.nn as nn |
|
from collections import defaultdict |
|
|
|
from functools import partial |
|
from megatron.model.utils import Lambda, SequentialWrapper, recursive_setattr |
|
from megatron.model.norms import get_norm |
|
from megatron.model.init_functions import get_init_methods |
|
|
|
from megatron import mpu |
|
from megatron.mpu import ParallelRelativePositionBias |
|
from megatron.model.transformer import ( |
|
ParallelTransformerLayerPipe, |
|
NormPipe, |
|
ParallelLinearPipe, |
|
parallel_lm_logits, |
|
ParallelLinear, |
|
) |
|
from megatron.model.gmlp import GMLPBlock |
|
from megatron.model.rwkv.v6 import RWKVResidualLayerPipe |
|
from megatron.model.mamba import ParallelMambaResidualLayerPipe |
|
from megatron.model.word_embeddings import EmbeddingPipe, SoftEmbedding |
|
|
|
|
|
from deepspeed.pipe import PipelineModule, LayerSpec, TiedLayerSpec |
|
from typing import Union, List |
|
|
|
|
|
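# Attention-mask helper handed to the transformer layers: positions where ltor_mask is
# True are filled with the most negative value representable in the scores' dtype, so
# they contribute (effectively) zero probability after the softmax.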
def gpt2_attention_mask_func(attention_scores, ltor_mask):
    mask_value = torch.finfo(attention_scores.dtype).min
    # The fill value must be a tensor on the same device and with the same dtype as the
    # scores, otherwise masked_fill_ can fail with a dtype mismatch.
    mask_value = torch.tensor(
        mask_value, dtype=attention_scores.dtype, device=attention_scores.device
    )
    attention_scores.masked_fill_(ltor_mask, mask_value)
    return attention_scores

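# Pipeline loss function. `labels` arrives from the data iterator as a
# (labels, loss_mask) pair; the mask zeroes out positions that should not
# contribute to the loss.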
def cross_entropy(output, labels, _fp16=False):
    """From pretrain_gpt2:forward_step()"""
    """
    if self.fp16_lm_cross_entropy:
        assert output.dtype == torch.half
        loss = mpu.vocab_parallel_cross_entropy(output, labels)
    else:
        loss = mpu.vocab_parallel_cross_entropy(output.float(), labels)
    return loss
    """
    labels, loss_mask = labels[0], labels[1]
    if _fp16:
        assert output.dtype == torch.half and loss_mask.dtype == torch.half
        losses = mpu.vocab_parallel_cross_entropy(output.contiguous(), labels)
    else:
        losses = mpu.vocab_parallel_cross_entropy(output.float().contiguous(), labels)
    loss_mask = loss_mask.view(-1)
    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
    return loss

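# Layout change ahead of the transformer stack: hidden_states go from
# [batch, seq, hidden] to [seq, batch, hidden] so the layers avoid explicit transposes.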
def _pre_transformer_block(args):
    assert len(args) == 2, "Incorrect number of arguments to _pre_transformer_block"
    fn = lambda _args: (_args[0].transpose(0, 1).contiguous(), *_args[1:])
    return fn(args)

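# From (hidden_states, attention_mask) back to just hidden_states, transposed
# to [batch, seq, hidden] for the final norm and output head.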
def _post_transformer_block(args):
    assert len(args) == 2, "Incorrect number of arguments to _post_transformer_block"
    fn = lambda _args: _args[0].transpose(0, 1).contiguous()
    return fn(args)

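# Construction sketch (illustrative, not part of this module's API): assuming a fully
# populated NeoX args object and a DeepSpeed pipe/model/data topology built elsewhere,
#
#   model = GPT2ModelPipe(neox_args, topology=topo, parallel_output=True, use_cache=False)
#
# yields a PipelineModule whose layer specs are partitioned across stages per `topology`.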
class GPT2ModelPipe(PipelineModule, torch.nn.Module):
    """GPT2Model adapted for pipeline parallelism.

    The largest change is flattening the GPTModel class so we can express it as a
    sequence of layers including embedding, transformer layers, and output.

    :param neox_args: NeoX arguments object (configuration)
    :param num_tokentypes: number of token types (TODO: deprecated, remove)
    :param parallel_output: if true, don't gather the output logits, and calculate loss in parallel. Set to true by default in training for efficiency, but set to false for inference.
    :param topology: deepspeed topology object specifying pipe / model parallelism topology.
    :param use_cache: if true, cache key/value pairs for each layer during inference.
    """

    def __init__(
        self,
        neox_args,
        num_tokentypes=0,
        parallel_output=True,
        topology=None,
        use_cache=False,
    ):
        self.neox_args = neox_args

        self.use_cache = use_cache
        self.parallel_output = parallel_output
        self.hidden_size = self.neox_args.hidden_size
        self.num_tokentypes = num_tokentypes
        self.init_method, self.output_layer_init_method = get_init_methods(
            self.neox_args
        )
        self.__topology__ = topology

        self.specs = []
        self.init_specs()  # initializes the layer specs (basically a fancy nn.Sequential)

        super().__init__(
            layers=self.specs,
            loss_fn=partial(cross_entropy, _fp16=self.neox_args.fp16_lm_cross_entropy),
            topology=topology,
            activation_checkpoint_interval=self.neox_args.checkpoint_num_layers
            if self.neox_args.checkpoint_activations
            else 0,
            partition_method=neox_args.pipe_partition_method,
            checkpointable_layers=[
                "GMLPBlock",
                "ParallelTransformerLayerPipe",
                "ParallelMambaResidualLayerPipe",
            ],
        )

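    # NOTE: insert_layers() re-runs PipelineModule.__init__, so it is intended to be
    # called before the DeepSpeed engine/optimizer are built (e.g. to splice extra
    # embedding layers into the front of the pipe).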
    def insert_layers(
        self, layers: Union[nn.Module, nn.ModuleList, nn.Sequential, List], idx
    ):
        """
        Inserts the layers in `layers` into the pipe model at `idx`.
        """
        if isinstance(layers, nn.Module):
            self.specs.insert(idx, layers)
        elif any(
            [isinstance(layers, nn.ModuleList), isinstance(layers, nn.Sequential)]
        ):
            self.specs[idx:idx] = layers
        elif isinstance(layers, list):
            assert all(
                [hasattr(l, "__call__") for l in layers]
            ), "all items in `layers` must be Callables"
            self.specs[idx:idx] = layers
        else:
            raise ValueError(
                f"layer passed into {self.__class__.__name__}.insert_layers() should be either an nn.Module, an nn.ModuleList, an nn.Sequential object, or a list of callables, not a {type(layers)}"
            )

        # Re-initialize the parent PipelineModule so the updated specs are rebuilt and
        # partitioned across pipeline stages.
        super().__init__(
            layers=self.specs,
            loss_fn=self.loss_fn,
            topology=self.__topology__,
            activation_checkpoint_interval=self.activation_checkpoint_interval,
            partition_method=self.neox_args.pipe_partition_method,
            checkpointable_layers=[
                "GMLPBlock",
                "ParallelTransformerLayerPipe",
                "ParallelMambaResidualLayerPipe",
                "RWKVResidualLayerPipe",
            ],
        )

    def init_specs(self):

        weight_tying = not self.neox_args.no_weight_tying
        self.specs = []

        # Embedding layer: input is (input_ids, position_ids, attention_mask).
        if weight_tying:
            self.specs.append(
                TiedLayerSpec(
                    "embed",
                    EmbeddingPipe,
                    self.neox_args,
                    self.hidden_size,
                    self.neox_args.padded_vocab_size,
                    self.neox_args.max_position_embeddings,
                    self.neox_args.hidden_dropout,
                    self.init_method,
                    self.num_tokentypes,
                    tied_weight_attr="word_embeddings_weight",
                )
            )
        else:
            self.specs.append(
                LayerSpec(
                    EmbeddingPipe,
                    self.neox_args,
                    self.hidden_size,
                    self.neox_args.padded_vocab_size,
                    self.neox_args.max_position_embeddings,
                    self.neox_args.hidden_dropout,
                    self.init_method,
                    self.num_tokentypes,
                )
            )

        # NB: in deepspeed's pipeline engine the attention mask must stay the *last*
        # item in the tuple passed between stages.
        # Outputs are now (hidden_states, attention_mask).
        self.specs.append(_pre_transformer_block)

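        # T5-style relative position bias, shared across all transformer layers.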
        if self.neox_args.pos_emb == "rpe":
            hidden_size_per_attention_head = mpu.divide(
                self.neox_args.hidden_size, self.neox_args.num_attention_heads
            )
            rpe_scale = math.sqrt(hidden_size_per_attention_head)
            rpe_emb = ParallelRelativePositionBias(
                neox_args=self.neox_args,
                scale=rpe_scale,
                causal=True,
                num_buckets=self.neox_args.rpe_num_buckets,
                max_distance=self.neox_args.rpe_max_distance,
                heads=self.neox_args.num_attention_heads,
            )

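        # Transformer layers: attention_config selects the block type for each layer
        # (gmlp/amlp, rwkv, mamba, or the default parallel transformer layer).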
        for i in range(self.neox_args.num_layers):
            layer_type = self.neox_args.attention_config[i]
            if layer_type in ["gmlp", "amlp"]:
                self.specs.append(
                    LayerSpec(
                        GMLPBlock,
                        init_method=self.init_method,
                        layer_number=i,
                        output_layer_init_method=self.output_layer_init_method,
                        neox_args=self.neox_args,
                        mask_fn=gpt2_attention_mask_func,
                    )
                )
            elif layer_type == "rwkv":
                self.specs.append(
                    LayerSpec(
                        RWKVResidualLayerPipe,
                        neox_args=self.neox_args,
                        layer_number=i,
                    )
                )
            elif layer_type in ["mamba"]:
                self.specs.append(
                    LayerSpec(
                        ParallelMambaResidualLayerPipe,
                        neox_args=self.neox_args,
                        init_method=self.init_method,
                        output_layer_init_method=self.output_layer_init_method,
                        layer_number=i,
                    )
                )
            else:
                self.specs.append(
                    LayerSpec(
                        ParallelTransformerLayerPipe,
                        neox_args=self.neox_args,
                        attention_mask_func=gpt2_attention_mask_func,
                        init_method=self.init_method,
                        output_layer_init_method=self.output_layer_init_method,
                        layer_number=i,
                        rpe=rpe_emb if self.neox_args.pos_emb == "rpe" else None,
                        rotary=self.neox_args.pos_emb == "rotary",
                        use_cache=self.use_cache,
                    )
                )

        # Drop the attention mask and revert hidden_states to [batch, seq, hidden].
        self.specs.append(_post_transformer_block)

        # Final layer norm.
        norm, eps = get_norm(self.neox_args)
        self.specs.append(
            LayerSpec(NormPipe, norm, self.neox_args.hidden_size, eps=eps)
        )

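        # Outputs are now just the final hidden states; project them back onto the
        # (possibly tied) embedding matrix to get vocabulary logits.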
        def _logits_helper(embedding, lm_output):
            """Just a wrapper to massage inputs/outputs from pipeline."""
            if self.neox_args.use_mup:
                # muP readout scaling: divide the output by the width multiplier of the
                # tied word-embedding weight before computing logits.
                lm_output = (
                    lm_output
                    / self.tied_modules.embed.word_embeddings.weight.infshape.width_mult()
                )

            logits = parallel_lm_logits(
                lm_output,
                embedding.word_embeddings_weight,
                self.parallel_output,
                seq_parallel=self.neox_args.sequence_parallel,
            )
            return logits

        if weight_tying:
            self.specs.append(
                TiedLayerSpec(
                    "embed",
                    EmbeddingPipe,
                    self.neox_args,
                    self.hidden_size,
                    self.neox_args.padded_vocab_size,
                    self.neox_args.max_position_embeddings,
                    self.neox_args.hidden_dropout,
                    self.init_method,
                    self.num_tokentypes,
                    forward_fn=_logits_helper,
                    tied_weight_attr="word_embeddings_weight",
                )
            )
        else:
            self.specs.append(
                LayerSpec(
                    ParallelLinearPipe,
                    neox_args=self.neox_args,
                    init_method=self.init_method,
                    parallel_output=self.parallel_output,
                    is_last_layer=True,
                )
            )

    def _set_parallel_output(self, value):
        # Set the parallel-output flag on the final linear layer (the LM head), if present.
        final_layer = list(self.forward_funcs)[-1]
        if isinstance(final_layer, (ParallelLinearPipe, ParallelLinear)):
            final_layer.final_linear.set_parallel_output(value)

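    # Typical generation-time toggling (sketch):
    #   model.inference_mode(use_cache=True)  # gather full logits, enable k/v caching
    #   ... run generation ...
    #   model.clear_cache()
    #   model.train_mode()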
    def inference_mode(self, use_cache=True):
        """
        Sets up the model for inference by turning on k/v caching (if specified) and setting
        `parallel_output` of the final layer to False, so logits are gathered across model
        parallel ranks.

        :param use_cache: (bool) True if you want to use caching during inference, False otherwise
        """
        # first set caching to the requested value
        recursive_setattr(self.forward_funcs, "use_cache", use_cache, assert_type=bool)
        # then set parallel output of the final layer to False so we don't have to gather the output manually
        self._set_parallel_output(False)
        recursive_setattr(self.forward_funcs, "training", False)

    def train_mode(self):
        """
        Sets up the model for training by turning off k/v caching and setting `parallel_output`
        of the final layer to True, so logits are not gathered across model parallel ranks and
        the loss is computed in parallel (more efficient).
        """
        # set caching to false
        recursive_setattr(self.forward_funcs, "use_cache", False)
        # then set parallel output of the final layer back to True
        self._set_parallel_output(True)
        recursive_setattr(self.forward_funcs, "training", True)

    def clear_cache(self):
        """
        Recursively clears the k/v cache on all layers.
        """
        recursive_setattr(self.forward_funcs, "layer_past", None)

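    # to_sequential() is used when running without pipeline parallelism: the layer specs
    # are built eagerly and wrapped so the model behaves like a plain nn.Sequential.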
    def to_sequential(self):
        """
        Transforms the PipelineModule into a plain nn.Sequential module.
        """
        layers = []
        tied_layers = defaultdict(list)
        for n, spec in enumerate(self.specs):
            if isinstance(spec, TiedLayerSpec):
                if spec.key in tied_layers:
                    # This tied layer has already been built: reuse the existing module
                    # by applying its forward_fn instead of building a second copy.
                    layers.append(
                        Lambda(lambda x: spec.forward_fn(tied_layers[spec.key][0], x))
                    )
                else:
                    # First occurrence of this tied layer: build it and remember it.
                    module = spec.build(log=False)
                    layers.append(module)
                    tied_layers[spec.key].append(module)
            elif isinstance(spec, LayerSpec):
                layers.append(spec.build(log=False))
            elif hasattr(spec, "__call__"):
                # Plain callables (e.g. _pre/_post_transformer_block) are wrapped in Lambda.
                layers.append(Lambda(spec))
            else:
                raise ValueError(f"Layer number {n} ({spec}) not recognized")
        model = SequentialWrapper(
            layers,
            self.activation_checkpoint_interval,
            self.activation_checkpoint_func,
            parent_class_name=self.__class__.__name__,
        )
        return model