|
import torch |
|
from torch import nn |
|
from .LMConfig import LMConfig |
|
import math |
|
import torch.nn.functional as F |
|
from typing import Optional |
|
from transformers import PreTrainedModel |
|
from transformers.modeling_outputs import CausalLMOutputWithPast |
|
|
|
class RMSNorm(nn.Module): |
|
def __init__(self, dim: int, eps: float) -> None: |
|
super().__init__() |
|
self.weight = nn.Parameter(torch.ones(dim)) |
|
self.eps = eps |
|
|
|
def _norm(self, x): |
|
return x * torch.rsqrt(self.eps + x.pow(2).mean(-1, keepdim = True)) |
|
|
|
def forward(self, x): |
|
x = self._norm(x.float()).type_as(x) |
|
x = x * self.weight |
|
return x |
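
# Illustrative sketch (kept as a comment so the module stays import-safe). RMSNorm
# rescales each vector by the reciprocal of its root-mean-square and applies a learned
# per-channel gain; unlike LayerNorm there is no mean subtraction.
#     norm = RMSNorm(dim = 8, eps = 1e-6)
#     x = torch.randn(2, 3, 8)
#     y = norm(x)    # same shape; equals x / sqrt(mean(x ** 2, dim = -1, keepdim = True) + eps) * norm.weight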
|
|
|
def repeat_kv(x: torch.Tensor, n_rep: int): |
|
    '''
    x is a key or value tensor of shape (batch_size, seq_len, kv_heads, head_dim).
    Repeat it n_rep times along the head dimension so that it becomes
    (batch_size, seq_len, kv_heads * n_rep, head_dim).
    '''
|
if n_rep == 1: |
|
return x |
|
else: |
|
bs, seq_len, kv_heads, head_dim = x.shape |
|
        return (x[:, :, :, None, :]
                .expand(bs, seq_len, kv_heads, n_rep, head_dim)
                .reshape(bs, seq_len, kv_heads * n_rep, head_dim))
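
# Shape sketch for repeat_kv (illustrative comment only):
#     x = torch.randn(2, 16, 2, 64)        # (bsz, seq_len, kv_heads, head_dim)
#     repeat_kv(x, n_rep = 4).shape        # -> torch.Size([2, 16, 8, 64])
# Each kv head is duplicated n_rep times so grouped-query attention can line up
# with the larger number of query heads without extra parameters.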
|
|
|
|
|
|
|
|
|
def get_rotation(dim: int, seq_len: int, base: float = 10000.0): |
|
    '''
    Build the RoPE rotation matrix: a (seq_len, dim // 2) complex matrix W with
    W[a][b] = cos(a * θ_b) + i * sin(a * θ_b), i.e. a unit-length complex number
    representing a rotation by the angle a * θ_b.
    Note that dim here is not the model width but the per-head dimension,
    i.e. args.dim // args.n_heads.
    '''
|
angles = 1.0 / (base ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim)) |
|
seq = torch.arange(0, seq_len, device = angles.device) |
|
angle_matrix = torch.outer(seq, angles).float() |
|
weight = torch.polar(torch.ones_like(angle_matrix), angle_matrix) |
|
return weight |
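
# Illustrative check (comment only): the rotation matrix is complex-valued with unit
# magnitude, one rotation angle per (position, channel-pair).
#     w = get_rotation(dim = 64, seq_len = 512)
#     w.shape      # -> torch.Size([512, 32]), dtype torch.complex64
#     w.abs()      # all entries are 1 (up to floating-point error)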
|
|
|
def position_encoding(xq, xk, weight):
    '''
    Apply rotary position embedding (RoPE) to the query and key tensors.
    xq: (batch_size, seq_len, n_heads, head_dim)
    xk: (batch_size, seq_len, n_kv_heads, head_dim)
    weight: complex rotation matrix of shape (seq_len, head_dim // 2) from get_rotation.
    '''
    # Pair adjacent channels and view each pair as one complex number:
    # (..., head_dim) -> (..., head_dim // 2) complex.
    xq = xq.float()
    xk = xk.float()
    xq = torch.view_as_complex(xq.reshape(*xq.shape[:-1], -1, 2))
    xk = torch.view_as_complex(xk.reshape(*xk.shape[:-1], -1, 2))

    # Rotate every pair by its position-dependent angle (complex multiplication),
    # then flatten back to real values of shape (..., head_dim).
    xq = torch.view_as_real(weight[None, :, None, :] * xq).flatten(3)
    xk = torch.view_as_real(weight[None, :, None, :] * xk).flatten(3)

    return xq, xk
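
# Shape sketch for position_encoding (comment only): RoPE rotates channel pairs in place,
# so the outputs keep the input shapes.
#     w = get_rotation(dim = 64, seq_len = 16)      # (16, 32), complex
#     xq = torch.randn(2, 16, 8, 64)                # (bsz, seq_len, n_heads, head_dim)
#     xk = torch.randn(2, 16, 2, 64)                # (bsz, seq_len, n_kv_heads, head_dim)
#     q, k = position_encoding(xq, xk, w)           # q, k have the same shapes as xq, xk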
|
|
|
class Attention(nn.Module): |
|
def __init__(self, args: LMConfig) -> None: |
|
super().__init__() |
|
self.dim = args.dim |
|
self.n_heads = args.n_heads |
|
self.n_kv_heads = args.n_kv_heads |
|
|
|
assert self.n_heads % self.n_kv_heads == 0 |
|
self.n_rep = self.n_heads // self.n_kv_heads |
|
|
|
assert self.dim % self.n_heads == 0 |
|
self.head_dim = self.dim // self.n_heads |
|
|
|
self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias = False) |
|
self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias = False) |
|
self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias = False) |
|
self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias = False) |
|
|
|
|
|
self.resid_dropout = nn.Dropout(args.dropout) |
|
|
|
self.dropout = args.dropout |
|
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn |
|
|
|
if not self.flash: |
|
mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf")) |
|
mask = torch.triu(mask, diagonal = 1) |
|
self.register_buffer("mask", mask) |
|
|
|
|
|
self.k_cache, self.v_cache = None, None |
|
|
|
def forward(self, x: torch.Tensor, weight: torch.Tensor, use_kv_cache = False): |
|
|
|
|
|
bsz, seq_len, _ = x.shape |
|
|
|
        # Use the KV cache only at inference time.
        if use_kv_cache and not self.training:
|
|
|
            # Recompute all projections unless the cache is exactly one token behind the input.
            if self.k_cache is None or self.k_cache.shape[1] != x.shape[1] - 1:
|
|
|
|
|
|
|
|
|
|
|
|
|
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) |
|
else: |
|
token = x[:, -1:, :] |
|
|
|
                # Only the newest token's query is computed; earlier positions are zero placeholders.
                xq = torch.concat((torch.zeros_like(x[:, : -1, :]), self.wq(token)), dim = 1)
|
|
|
|
|
xk = torch.concat((self.k_cache, self.wk(token)), dim = 1) |
|
|
|
xv = torch.concat((self.v_cache, self.wv(token)), dim = 1) |
|
|
|
self.k_cache, self.v_cache = xk, xv |
|
else: |
|
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) |
|
|
|
xq = xq.reshape(bsz, seq_len, self.n_heads, self.head_dim) |
|
xk = xk.reshape(bsz, seq_len, self.n_kv_heads, self.head_dim) |
|
xv = xv.reshape(bsz, seq_len, self.n_kv_heads, self.head_dim) |
|
|
|
xq, xk = position_encoding(xq, xk, weight) |
|
xk, xv = repeat_kv(xk, self.n_rep), repeat_kv(xv, self.n_rep) |
|
|
|
xq = xq.transpose(1, 2) |
|
xk = xk.transpose(1, 2) |
|
xv = xv.transpose(1, 2) |
|
|
|
if self.flash: |
|
|
|
|
|
|
|
x = F.scaled_dot_product_attention(xq, xk, xv, attn_mask = None, |
|
dropout_p = self.dropout if self.training else 0.0, |
|
is_causal = True) |
|
|
|
else: |
|
|
|
|
|
|
|
|
|
            # Vanilla scaled-dot-product attention with the explicit causal mask.
            x = xq @ xk.transpose(2, 3) / math.sqrt(self.head_dim)
            assert hasattr(self, "mask")
            x = x + self.mask[:, :, : seq_len, : seq_len]
            x = F.softmax(x, dim = -1)
            # Match the flash path: apply attention dropout only during training.
            x = F.dropout(x, p = self.dropout, training = self.training)
            x = x @ xv
|
|
|
x = x.transpose(1, 2).contiguous().view(bsz, seq_len, -1) |
|
x = self.resid_dropout(self.wo(x)) |
|
return x |
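
# Minimal usage sketch for Attention (comment only; assumes LMConfig accepts these fields
# as keyword arguments and exposes the attributes used above: dim, n_heads, n_kv_heads,
# dropout, flash_attn, max_seq_len):
#     args = LMConfig(dim = 512, n_heads = 8, n_kv_heads = 2, max_seq_len = 512)
#     attn = Attention(args)
#     x = torch.randn(2, 16, args.dim)
#     w = get_rotation(dim = args.dim // args.n_heads, seq_len = 16)
#     out = attn(x, w)      # -> (2, 16, args.dim)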
|
|
|
class FeedForward(nn.Module): |
|
    def __init__(self, dim: int, hidden_dim: Optional[int], multi: int, dropout: float) -> None:
|
|
|
|
|
|
|
super().__init__() |
|
        if hidden_dim is None:
            # LLaMA-style sizing: start from 4 * dim, take 2/3 of it for the SwiGLU gate,
            # then round up to the nearest multiple of `multi`.
            hidden_dim = 4 * dim
            hidden_dim = int(2 * hidden_dim / 3)
            hidden_dim = multi * ((hidden_dim + multi - 1) // multi)
|
|
|
self.w1 = nn.Linear(dim, hidden_dim, bias = False) |
|
self.w2 = nn.Linear(dim, hidden_dim, bias = False) |
|
self.w3 = nn.Linear(hidden_dim, dim, bias = False) |
|
self.dropout = nn.Dropout(dropout) |
|
|
|
    def forward(self, x: torch.Tensor):
        # SwiGLU: silu(w1(x)) * w2(x), projected back to `dim` by w3, then dropout.
        x_2 = self.w2(x)
        x = self.w1(x)
        x = F.silu(x)
        x = x * x_2
        x = self.w3(x)
        x = self.dropout(x)
        return x
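
# Worked example of the hidden_dim sizing above (comment only):
#     dim = 512, multi = 64, hidden_dim = None
#     4 * 512 = 2048  ->  int(2 * 2048 / 3) = 1365  ->  rounded up to a multiple of 64 = 1408
# so this SwiGLU block uses two 512 -> 1408 projections (w1, w2) and one 1408 -> 512 projection (w3).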
|
|
|
|
|
class MoEGate(nn.Module): |
|
def __init__(self, args: LMConfig) -> None: |
|
super().__init__() |
|
self.topk = args.num_experts_per_tok |
|
self.gating_dim = args.dim |
|
        self.n_routed_experts = args.n_routed_experts
|
self.scoring_func = args.scoring_func |
|
self.norm_topk_prob = args.norm_topk_prob |
|
self.alpha = args.aux_loss_alpha |
|
self.seq_aux = args.seq_aux |
|
self.w = nn.Linear(self.gating_dim, self.n_routed_experts, bias = False) |
|
self.reset_parameters() |
|
|
|
def reset_parameters(self) -> None: |
|
import torch.nn.init as init |
|
init.kaiming_normal_(self.w.weight) |
|
|
|
def forward(self, x: torch.Tensor): |
|
bsz, seq_len, dim = x.shape |
|
|
|
hidden_states = x.view(-1, dim) |
|
scores = self.w(hidden_states) |
|
|
|
if self.scoring_func == "softmax": |
|
scores = F.softmax(scores, dim = -1) |
|
else: |
|
            raise NotImplementedError(f'unsupported scoring function for MoE gating: {self.scoring_func}')
|
|
|
|
|
topk_weight, topk_idx = torch.topk(scores, self.topk, dim = -1, sorted = False) |
|
|
|
|
|
if self.norm_topk_prob: |
|
            denominator = topk_weight.sum(dim = -1, keepdim = True) + 1e-20
|
topk_weight = topk_weight / denominator |
|
|
|
if self.training and self.alpha > 0: |
|
|
|
scores_for_aux = scores |
|
aux_topk = self.topk |
|
topk_idx_for_aux_loss = topk_idx.view(bsz, -1) |
|
|
|
if self.seq_aux: |
|
scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1).mean(dim = 1) |
|
|
|
                # ce[b, e]: fraction of routing slots in sequence b that went to expert e,
                # scaled so that perfectly uniform routing gives 1 for every expert.
                ce = torch.zeros(bsz, self.n_routed_experts, device = hidden_states.device)
                ce.scatter_add_(1, topk_idx_for_aux_loss, torch.ones(bsz, seq_len * aux_topk,
                                                                     device = hidden_states.device).div_(
                    seq_len * aux_topk / self.n_routed_experts
                ))
|
|
|
|
|
|
|
|
|
|
|
aux_loss = (ce * scores_for_seq_aux).sum(dim = -1).mean() * self.alpha |
|
else: |
|
|
|
|
|
                # Token-level variant: per-expert routing frequency, scaled so uniform routing gives 1.
                ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes = self.n_routed_experts).float().mean(dim = 0)
                ce = ce * self.n_routed_experts
|
|
|
|
|
|
|
|
|
|
|
aux_loss = (ce * scores_for_aux.mean(dim = 0)).sum() * self.alpha |
|
else: |
|
aux_loss = None |
|
return topk_weight, topk_idx, aux_loss |
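
# Shape sketch for MoEGate.forward (comment only): the gate flattens the batch, so both
# routing outputs are indexed per token.
#     x:           (bsz, seq_len, dim)
#     topk_weight: (bsz * seq_len, topk)   softmax scores of the selected experts
#     topk_idx:    (bsz * seq_len, topk)   indices of the selected experts
#     aux_loss:    scalar load-balancing loss when training with aux_loss_alpha > 0, else None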
|
|
|
class MOEFeedForward(nn.Module): |
|
def __init__(self, args: LMConfig) -> None: |
|
super().__init__() |
|
        self.args = args
        self.topk = args.num_experts_per_tok
        self.n_routed_experts = args.n_routed_experts
|
self.experts = nn.ModuleList([ |
|
FeedForward(dim = args.dim, |
|
hidden_dim = args.hidden_dim, |
|
multi = args.multiple_of, |
|
dropout = args.dropout) |
|
for _ in range(self.n_routed_experts) |
|
]) |
|
self.gate = MoEGate(args) |
|
|
|
if args.n_shared_experts is not None: |
|
self.shared_experts = FeedForward( |
|
dim = args.dim, |
|
hidden_dim = args.hidden_dim, |
|
multi = args.multiple_of, |
|
dropout = args.dropout |
|
) |
|
|
|
def work(self, x, topk_weight, topk_idx): |
|
bsz, seq_len, dim = x.shape |
|
|
|
x = x.view(-1, dim) |
|
x = x.repeat_interleave(self.topk, dim = 0) |
|
|
|
flat_topk_idx = topk_idx.view(-1) |
|
|
|
        # Output buffer must match x's dtype for the masked assignments below.
        y = torch.empty_like(x)
|
for i in range(self.n_routed_experts): |
|
y[flat_topk_idx == i] = self.experts[i](x[flat_topk_idx == i]) |
|
|
|
        # Weight each expert output by its gate score and sum over the top-k experts per token.
        y = y.view(bsz, seq_len, self.topk, -1)
        y = (y * topk_weight.view(bsz, seq_len, self.topk, 1)).sum(dim = 2)
|
return y |
|
|
|
def forward(self, x): |
|
|
|
topk_weight, topk_idx, _ = self.gate(x) |
|
|
|
if self.training: |
|
y = self.work(x, topk_weight, topk_idx) |
|
else: |
|
            with torch.no_grad():
|
y = self.work(x, topk_weight, topk_idx) |
|
|
|
        if self.args.n_shared_experts is not None:
            # Shared experts always process the layer input, alongside the routed experts.
            y = y + self.shared_experts(x)
|
|
|
return y |
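
# Routing sketch for MOEFeedForward.work (comment only): x is flattened to
# (bsz * seq_len, dim) and repeated top-k times so every token/expert pair becomes one row;
# each expert processes only the rows routed to it, and the results are recombined as
#     y[token] = sum_k topk_weight[token, k] * expert_{topk_idx[token, k]}(x[token])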
|
|
|
class TransformerBlock(nn.Module): |
|
def __init__(self, layer_id: int, args: LMConfig) -> None: |
|
|
|
super().__init__() |
|
|
|
self.attn_norm = RMSNorm(dim = args.dim, eps = args.norm_eps) |
|
self.attn = Attention(args) |
|
|
|
self.ffn_norm = RMSNorm(dim = args.dim, eps = args.norm_eps) |
|
if args.use_moe: |
|
self.feed_forward = MOEFeedForward(args) |
|
else: |
|
self.feed_forward = FeedForward(dim = args.dim, |
|
hidden_dim = args.hidden_dim, |
|
multi = args.multiple_of, |
|
dropout = args.dropout) |
|
|
|
|
|
def forward(self, x, weight, use_kv_cache = False): |
|
|
|
|
|
|
|
x = x + self.attn(self.attn_norm(x), weight, use_kv_cache) |
|
|
|
x = x + self.feed_forward(self.ffn_norm(x)) |
|
|
|
return x |
|
|
|
|
|
class Transformer(PreTrainedModel): |
|
config_class = LMConfig |
|
last_loss: Optional[torch.Tensor] |
|
    def __init__(self, args: LMConfig = None) -> None:
        # Fall back to a default config before handing it to PreTrainedModel.
        if args is None:
            args = LMConfig()
        super().__init__(args)
|
|
|
self.args = args |
|
|
|
self.embedding = nn.Embedding(args.vocab_size, args.dim) |
|
self.dropout = nn.Dropout(args.dropout) |
|
|
|
self.layers = nn.ModuleList() |
|
for i in range(args.n_layers): |
|
self.layers.append(TransformerBlock(i, args)) |
|
|
|
|
|
rotation_weight = get_rotation(dim = args.dim // args.n_heads, seq_len = args.max_seq_len) |
|
self.register_buffer('rotation_weight', rotation_weight, persistent = False) |
|
|
|
self.norm = RMSNorm(dim = args.dim, eps = args.norm_eps) |
|
self.output = nn.Linear(args.dim, args.vocab_size, bias = False) |
|
|
|
self.embedding.weight = self.output.weight |
|
|
|
self.OUT = CausalLMOutputWithPast() |
|
|
|
def forward(self, tokens: Optional[torch.Tensor] = None, targets: Optional[torch.Tensor] = None, |
|
use_kv_cache = False, **key_args): |
|
|
|
if 'input_ids' in key_args: |
|
tokens = key_args['input_ids'] |
|
if 'attention_mask' in key_args: |
|
            targets = key_args['attention_mask']
|
|
|
_, seq_len = tokens.shape |
|
x = self.embedding(tokens) |
|
x = self.dropout(x) |
|
|
|
|
|
|
|
r_w = self.rotation_weight[: seq_len] |
|
|
|
|
|
for layer in self.layers: |
|
x = layer(x, r_w, use_kv_cache) |
|
|
|
|
|
|
|
x = self.norm(x) |
|
|
|
if targets is not None: |
|
logits = self.output(x) |
|
|
|
|
|
last_loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), targets.view(-1), ignore_index = -1) |
|
|
|
|
|
|
|
|
|
else: |
|
logits = self.output(x[:, [-1], :]) |
|
|
|
last_loss = None |
|
|
|
self.OUT.__setitem__('logits', logits) |
|
self.OUT.__setitem__('last_loss', last_loss) |
|
|
|
return self.OUT |
|
|
|
@torch.inference_mode() |
|
|
|
def generate(self, idx, eos, max_new_tokens, temperature = 0.7, top_k = None, |
|
stream = True, repetition_penalty = 1., use_kv_cache = True): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bsz, seq_len = idx.shape |
|
|
|
        # Note: the loop bounds the *total* length (prompt + generated tokens) by max_new_tokens.
        while idx.shape[1] < max_new_tokens - 1:
|
res = self(idx, use_kv_cache = use_kv_cache) |
|
logits = res.logits |
|
logits = logits[:, -1, :] |
|
|
|
|
|
            # Repetition penalty: divide the logits of tokens already present in the context.
            for b in range(bsz):
                for token in set(idx.tolist()[b]):
                    logits[b, token] /= repetition_penalty
|
|
|
|
|
if temperature == 0.0: |
|
|
|
_, idx_nxt = torch.topk(logits, k = 1, dim = -1) |
|
else: |
|
logits = logits / temperature |
|
if top_k is not None: |
|
|
|
v, _ = torch.topk(logits, k = min(top_k, logits.shape[-1]), dim = -1) |
|
|
|
logits[logits < v[:, [-1]]] = -float("Inf") |
|
|
|
|
|
|
|
|
|
probs = F.softmax(logits, dim = -1) |
|
idx_nxt = torch.multinomial(probs, num_samples = 1, generator = None) |
|
|
|
|
|
            # Stop at the end-of-sequence token (this check assumes batch size 1).
            if idx_nxt == eos:
|
break |
|
|
|
idx = torch.concat((idx, idx_nxt), dim = -1) |
|
|
|
if stream: |
|
yield idx[:, seq_len:] |
|
|
|
if not stream: |
|
yield idx[:, seq_len:] |
|
|
|
@torch.inference_mode() |
|
def eval_answer(self, idx): |
|
idx_cond = idx if idx.shape[1] < self.args.max_seq_len else idx[:, -self.args.max_seq_len:] |
|
res = self(idx_cond) |
|
logits = res.logits[:, -1, :] |
|
return logits |
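

# A minimal smoke-test sketch. It assumes LMConfig() provides usable defaults
# (vocab_size, dim, n_layers, ...); adjust the config to your setup before running.
if __name__ == "__main__":
    cfg = LMConfig()
    model = Transformer(cfg)
    model.eval()
    dummy_ids = torch.randint(0, cfg.vocab_size, (1, 8))
    out = model(dummy_ids)
    # Without targets, only the last position's logits are returned.
    print(out.logits.shape)      # expected: (1, 1, cfg.vocab_size)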