File size: 5,387 Bytes
92d683c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
# Copyright (c) Together
# This software is distributed under the terms of the Apache License, Version 2.0
# Author: Michael Poli
import torch
from torch import Tensor
import torch.nn.functional as F
import torch.nn as nn
from .utils import grab_first_if_tuple
def grab_first_if_tuple(x):
    """Return the first element of ``x`` if it is a tuple, otherwise ``x``.

    In this file it is applied to layer outputs so that a layer returning
    either a bare value or a tuple can be consumed uniformly.

    Args:
        x: Any object.

    Returns:
        ``x[0]`` when ``x`` is a ``tuple`` (tuple subclasses such as named
        tuples included); otherwise ``x`` unchanged.

    Raises:
        IndexError: If ``x`` is an empty tuple.
    """
    # isinstance rather than a class-name string comparison: the string check
    # misses tuple subclasses and would match any unrelated class that merely
    # happens to be called "tuple".
    if isinstance(x, tuple):
        return x[0]
    return x
class RMSNorm(torch.nn.Module):
    """Root-mean-square style normalisation over the last dimension.

    The input is divided by ``||x||_2 * hidden_size ** -0.5 + eps`` (computed
    along the last dimension) and multiplied by a learned per-feature
    ``scale`` initialised to ones. When ``use_flash_rmsnorm`` is enabled the
    computation is delegated to ``flash_attn.ops.rms_norm.rms_norm`` instead.

    Args:
        config: Object exposing ``eps`` and ``hidden_size`` as attributes and
            a dict-style ``get`` method (used for the optional
            ``"use_flash_rmsnorm"`` key, default ``False``).

    Raises:
        ImportError: If ``use_flash_rmsnorm`` is truthy and the flash-attn
            RMSNorm op cannot be imported.
    """

    def __init__(self, config):
        super().__init__()
        self.eps, self.hidden_size = config.eps, config.hidden_size
        # Assigning an nn.Parameter as an attribute registers it on the
        # module, so no separate register_parameter call is required.
        self.scale = torch.nn.Parameter(torch.ones(self.hidden_size))
        self.use_flash_rmsnorm = config.get("use_flash_rmsnorm", False)

        if self.use_flash_rmsnorm:
            try:
                from flash_attn.ops.rms_norm import rms_norm as rmsnorm_func
            except Exception as err:
                # Any failure while loading the optional op is still surfaced
                # as ImportError with install instructions, but
                # KeyboardInterrupt/SystemExit are no longer swallowed and the
                # root cause is chained for debugging.
                raise ImportError(
                    "For `use_flash_rmsnorm`: `pip install git+https://github.com/HazyResearch/flash-attention.git#subdirectory=csrc/layer_norm`"
                ) from err
            self.rmsnorm_func = rmsnorm_func

    def forward(self, x):
        """Normalise ``x`` along its last dimension and apply ``scale``.

        Args:
            x: Input tensor; ``scale`` (length ``hidden_size``) is broadcast
                against it.

        Returns:
            The normalised, scaled tensor.
        """
        if self.use_flash_rmsnorm:
            return self.rmsnorm_func(x, self.scale, self.eps)

        # ||x||_2 / sqrt(hidden_size) equals the RMS of x when the last
        # dimension has hidden_size entries; eps is added outside the root.
        rms = x.norm(2, dim=-1, keepdim=True) * self.hidden_size ** (-1.0 / 2)
        return self.scale * (x / (rms + self.eps))
class ParallelGatedMLP(nn.Module):
    """Gated MLP computing ``l3(act(l1(z)) * l2(z))`` with bias-free linears.

    Config keys read (``config`` must support attribute access and ``get``):
        * ``hidden_size`` - input/output width.
        * ``model_parallel_size`` - multiplies the rounding granularity.
        * ``inner_size_multiple_of`` (optional, default 64).
        * ``mlp_activation`` (optional, ``"silu"`` by default, or ``"gelu"``).
        * ``inner_mlp_size`` (optional) - explicit inner width; overrides the
          computed one when not ``None``.

    Raises:
        NotImplementedError: If ``mlp_activation`` is neither ``"gelu"`` nor
            ``"silu"``.
    """

    def __init__(
        self,
        config,
    ):
        super().__init__()
        base_multiple = config.get("inner_size_multiple_of", 64)

        # Resolve the activation by name; unknown names are rejected.
        self.act_type = config.get("mlp_activation", "silu")
        for known_name, fn in (("gelu", F.gelu), ("silu", F.silu)):
            if self.act_type == known_name:
                self.act = fn
                break
        else:
            raise NotImplementedError

        self.multiple_of = base_multiple * config.model_parallel_size

        # Default inner width: int(8/3 * hidden_size), rounded up to the next
        # multiple of self.multiple_of.
        target = int(2 * config.hidden_size * 4 / 3)
        n_chunks = (target + self.multiple_of - 1) // self.multiple_of
        inner_size = self.multiple_of * n_chunks
        if config.get("inner_mlp_size", None) is not None:
            inner_size = config.inner_mlp_size

        hidden = config.hidden_size
        # l1 feeds the activation, l2 is the multiplicative branch, l3
        # projects back to the hidden width.
        self.l1 = nn.Linear(hidden, inner_size, bias=False)
        self.l2 = nn.Linear(hidden, inner_size, bias=False)
        self.l3 = nn.Linear(inner_size, hidden, bias=False)

    def forward(self, z):
        """Apply the gated MLP to ``z`` and return the projected result."""
        projections = (self.l1(z), self.l2(z))
        gate, value = (grab_first_if_tuple(p) for p in projections)
        return grab_first_if_tuple(self.l3(self.act(gate) * value))
class Embedding(nn.Module):
    """Token embedding table whose weights are reused for un-embedding.

    Args:
        config: Object exposing ``vocab_size`` and ``hidden_size`` as
            attributes. Index 0 is used as the ``padding_idx`` of the
            underlying ``nn.Embedding``.
    """

    _train_dtype = "bf16"

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)

    def embed(self, input_ids, position_ids=None, tokentype_ids=None):
        """Look up embeddings for ``input_ids``.

        ``position_ids`` and ``tokentype_ids`` are accepted but not used.
        """
        return self.word_embeddings(input_ids)

    def unembed(self, u):
        """Project hidden states onto the vocabulary using the tied weights.

        Args:
            u: Tensor whose last dimension is ``hidden_size``.

        Returns:
            ``u @ weight.T``, i.e. a tensor whose last dimension is
            ``vocab_size``.
        """
        weight = self.word_embeddings.weight
        # The embedding weight is (vocab_size, hidden_size); it must be
        # transposed so that hidden-sized inputs map to vocab-sized outputs
        # (same convention as VocabParallelEmbedding.unembed).
        return torch.matmul(u, weight.T)
class VocabParallelEmbedding(nn.Embedding):
    """Embedding table optionally sharded along the vocabulary dimension.

    Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/embedding.py

    Config keys read (``config`` must support attribute access and ``get``):
        * ``vocab_size`` - full vocabulary size.
        * ``hidden_size`` - embedding dimension.
        * ``process_group`` (optional) - ``torch.distributed`` group; when
          given, this module holds ``vocab_size // world_size`` rows.
        * ``padding_idx`` (optional) - forwarded to ``nn.Embedding``.

    Raises:
        ValueError: If ``vocab_size`` is not divisible by the world size.
        RuntimeError: If ``padding_idx`` is set while the world size is > 1.
    """

    def __init__(self, config):
        total_vocab = config.vocab_size
        group = config.get("process_group", None)
        pad_idx = config.get("padding_idx", None)

        # Stored before the parent initialiser runs, as in the original.
        self.process_group = group

        num_shards = 1
        if group is not None:
            num_shards = torch.distributed.get_world_size(group)
            if total_vocab % num_shards != 0:
                raise ValueError(
                    f"vocab_size ({total_vocab}) must be divisible by world_size ({num_shards})"
                )
            if num_shards > 1 and pad_idx is not None:
                raise RuntimeError("ParallelEmbedding does not support padding_idx")

        super().__init__(
            total_vocab // num_shards,
            embedding_dim=config.hidden_size,
            padding_idx=pad_idx,
        )

    def embed(self, x: Tensor) -> Tensor:
        """Embed token ids, combining shards across the process group if set.

        Without a process group this is a plain embedding lookup. With one,
        ids outside this rank's ``[rank * n, (rank + 1) * n)`` slice (``n`` =
        local ``num_embeddings``) contribute zeros locally and the result is
        all-reduced over the group.
        """
        group = self.process_group
        if group is None:
            return self.forward(x)

        shard_size = self.num_embeddings
        first_id = torch.distributed.get_rank(group) * shard_size
        last_id = first_id + shard_size

        # True where the id does not belong to this rank's slice.
        outside_shard = ~((x >= first_id) & (x < last_id))

        # Shift into local index space; out-of-slice ids are pointed at row 0
        # so the lookup stays in range, then their rows are zeroed.
        local_ids = x - first_id
        local_ids[outside_shard] = 0
        shard_embeddings = self.forward(local_ids)
        shard_embeddings[outside_shard] = 0.0

        # Combine the per-rank partial results across the process group.
        torch.distributed.all_reduce(shard_embeddings, group=group)
        return shard_embeddings

    def unembed(self, u: Tensor) -> Tensor:
        """Project ``u`` onto the vocabulary via ``u @ weight.T``.

        Raises:
            NotImplementedError: If a process group is configured.
        """
        if self.process_group is not None:
            raise NotImplementedError
        return u @ self.weight.T
|