myminimind / model.py
import torch
from torch import nn
from .LMConfig import LMConfig
import math
import torch.nn.functional as F
from typing import Optional
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(dim))
self.eps = eps
def _norm(self, x):
return x * torch.rsqrt(self.eps + x.pow(2).mean(-1, keepdim = True))
def forward(self, x):
        x = self._norm(x.float()).type_as(x)  # compute in float for better precision and to avoid overflow
x = x * self.weight
return x
def repeat_kv(x: torch.Tensor, n_rep: int):
    '''
    x is the key or value tensor of shape (batch_size, seq_len, kv_heads, head_dim).
    Repeat it n_rep times along the head dimension, yielding (batch_size, seq_len, kv_heads * n_rep, head_dim).
    '''
if n_rep == 1:
return x
else:
bs, seq_len, kv_heads, head_dim = x.shape
return x[:,:,:,None,:].expand(bs, seq_len, kv_heads, n_rep, head_dim).reshape(bs, seq_len, kv_heads * n_rep, head_dim)
# Note on expand: it can only broadcast size-1 dimensions (or add new ones) and allocates no new
# memory -- it is just a broadcast view.
# view cannot be used here, otherwise it raises:
# RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
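# Illustrative shape check (a sketch, not part of the module):
#   k = torch.randn(2, 10, 8, 64)              # (bs, seq_len, kv_heads, head_dim)
#   repeat_kv(k, 2).shape == (2, 10, 16, 64)   # kv_heads * n_rep = 16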
def get_rotation(dim: int, seq_len: int, base: float = 10000.0):
    '''
    Build the rotation matrix: a (seq_len, dim // 2) complex matrix W with
    W[a][b] = cos(a * theta_b) + i * sin(a * theta_b), i.e. a unit-length complex number rotated by angle a * theta_b.
    Note that dim here is not the model dimension but the per-head tensor size, i.e. args.dim // args.n_heads.
    '''
angles = 1.0 / (base ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
seq = torch.arange(0, seq_len, device = angles.device)
angle_matrix = torch.outer(seq, angles).float()
weight = torch.polar(torch.ones_like(angle_matrix), angle_matrix)
return weight
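# Illustrative usage (a sketch; head_dim = 64 is just an example value):
#   w = get_rotation(dim=64, seq_len=512)
#   w.shape == (512, 32) and w.dtype == torch.complex64   # matches the (seq_len, dim // 2) description above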
def position_encoding(xq, xk, weight):
    # First convert xq and xk to complex numbers
    # xq.shape = [bsz, seq_len, n_heads, head_dim]
    xq = xq.float()
    xk = xk.float()
    xq = torch.view_as_complex(xq.reshape(*xq.shape[:-1], -1, 2))  # reshape handles non-contiguous memory; view cannot
    xk = torch.view_as_complex(xk.reshape(*xk.shape[:-1], -1, 2))
    # Multiply by the rotation weights, then convert back to real numbers
    # xq is now [bsz, seq_len, n_heads, head_dim // 2]; reshape weight to [1, seq_len, 1, head_dim // 2]
    xq = torch.view_as_real(weight[None, :, None, :] * xq).flatten(3)
    xk = torch.view_as_real(weight[None, :, None, :] * xk).flatten(3)
    # flatten(3) merges everything after dim 3 into one dimension, because converting complex back to real
    # yields shape (b, s, n_h, h // 2, 2)
    return xq, xk
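# Illustrative check (a sketch): the rotary encoding does not change tensor shapes.
#   xq = torch.randn(2, 10, 16, 64); xk = torch.randn(2, 10, 8, 64)
#   q2, k2 = position_encoding(xq, xk, get_rotation(64, 10))
#   q2.shape == (2, 10, 16, 64) and k2.shape == (2, 10, 8, 64)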
class Attention(nn.Module):
def __init__(self, args: LMConfig) -> None:
super().__init__()
        self.dim = args.dim  # model dimension, e.g. 512
        self.n_heads = args.n_heads  # number of attention heads, e.g. 16
        self.n_kv_heads = args.n_kv_heads  # number of key/value heads, e.g. 8
        assert self.n_heads % self.n_kv_heads == 0
        self.n_rep = self.n_heads // self.n_kv_heads  # how many times each kv head is repeated
        assert self.dim % self.n_heads == 0
        self.head_dim = self.dim // self.n_heads  # tensor dimension inside each attention head
self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias = False)
self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias = False)
self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias = False)
self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias = False)
        # self.attn_dropout = nn.Dropout(args.dropout)  # attention dropout (left disabled here)
        self.resid_dropout = nn.Dropout(args.dropout)  # residual dropout
        self.dropout = args.dropout  # dropout probability handed to flash attention
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
        # Use Flash Attention when available; with is_causal=True it applies the causal mask for us.
        if not self.flash:
            mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
            mask = torch.triu(mask, diagonal = 1)  # -inf strictly above the diagonal, 0 on and below it
            self.register_buffer("mask", mask)  # registered as a buffer, so it is never updated by backprop
        # kv cache: at inference time the weights are frozen, so the xk / xv produced for earlier
        # tokens never change and can be reused directly.
        self.k_cache, self.v_cache = None, None
    def forward(self, x: torch.Tensor, weight: torch.Tensor, use_kv_cache = False):
        # x has shape (bsz, seq_len, dim); weight is the rotation matrix
        bsz, seq_len, _ = x.shape
        if use_kv_cache and not self.training:  # only use the kv cache in evaluation mode
            if self.k_cache is None or self.k_cache.shape[1] != x.shape[1] - 1:
                # This `!=` check is essential: a single model instance serves many prompts, so when we
                # switch to a new context the kv cache has to be rebuilt. If the cache is not exactly one
                # token shorter than the input, we are in a new context (or the cache is stale), and reusing
                # it would cause reshape size mismatches.
                xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
            else:
                token = x[:, -1:, :]
                xq = torch.concat((torch.zeros_like(x[:, : -1, :]), self.wq(token)), dim = 1)
                # Only the last token's query is newly computed; the earlier positions are filled with zeros,
                # since at inference time only the last position's output is ever used.
                xk = torch.concat((self.k_cache, self.wk(token)), dim = 1)  # reuse the cached keys
                xv = torch.concat((self.v_cache, self.wv(token)), dim = 1)  # reuse the cached values
            self.k_cache, self.v_cache = xk, xv
else:
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.reshape(bsz, seq_len, self.n_heads, self.head_dim)
xk = xk.reshape(bsz, seq_len, self.n_kv_heads, self.head_dim)
xv = xv.reshape(bsz, seq_len, self.n_kv_heads, self.head_dim)
        xq, xk = position_encoding(xq, xk, weight)  # add rotary position encoding to q and k
        xk, xv = repeat_kv(xk, self.n_rep), repeat_kv(xv, self.n_rep)  # repeat k and v n_rep times
xq = xq.transpose(1, 2)
xk = xk.transpose(1, 2)
xv = xv.transpose(1, 2)
        if self.flash:  # fused kernel computes softmax(QK^T / sqrt(d)) @ V directly
            x = F.scaled_dot_product_attention(xq, xk, xv, attn_mask = None,
                                               dropout_p = self.dropout if self.training else 0.0,
                                               is_causal = True)  # is_causal = True applies the causal mask
            # self.training indicates whether the module is currently in training mode
        else:
            x = xq @ xk.transpose(2, 3) / math.sqrt(self.head_dim)  # (bs, n_heads, seq_len, seq_len)
            assert hasattr(self, "mask")
            x = x + self.mask[:, :, : seq_len, : seq_len]  # causal mask: hide the future tokens
            x = F.softmax(x, dim = -1) @ xv  # (bs, n_heads, seq_len, head_dim)
x = x.transpose(1, 2).contiguous().view(bsz, seq_len, -1) # (bs, seq_len, dim)
x = self.resid_dropout(self.wo(x))
return x
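# Illustrative head layout (a sketch, assuming dim = 512, n_heads = 16, n_kv_heads = 8):
#   wq projects to 16 query heads of size 32; wk / wv project to 8 kv heads of size 32,
#   and repeat_kv duplicates each kv head twice so every query head has a matching k / v.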
class FeedForward(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, multi: int, dropout: float) -> None:
        # hidden_dim defaults to None
        # multi: the hidden dimension is rounded up to a multiple of this value, default 64
        # dropout defaults to 0.0
super().__init__()
if hidden_dim is None:
hidden_dim = 4 * dim
hidden_dim = int(2 * hidden_dim / 3)
            hidden_dim = multi * ((hidden_dim + multi - 1) // multi)
            # round hidden_dim up to the nearest multiple of `multi`, keeping the layer sizes hardware-friendly;
            # with the default configuration this comes out to 1408
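            # Worked example (assuming dim = 512 and multi = 64):
            #   4 * 512 = 2048 -> int(2 * 2048 / 3) = 1365 -> rounded up to the next multiple of 64 = 1408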
self.w1 = nn.Linear(dim, hidden_dim, bias = False)
self.w2 = nn.Linear(dim, hidden_dim, bias = False)
self.w3 = nn.Linear(hidden_dim, dim, bias = False)
self.dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor):
        x_2 = self.w2(x)  # gate branch: w2 projects to the hidden dimension
        x = self.w1(x)
        x = F.silu(x)  # SwiGLU: silu(w1(x)) * w2(x)
        x = x * x_2
        x = self.w3(x)  # project back down to the model dimension
        x = self.dropout(x)
        return x
class MoEGate(nn.Module):
def __init__(self, args: LMConfig) -> None:
super().__init__()
        self.topk = args.num_experts_per_tok  # the k in top-k: how many experts each token is routed to
        self.gating_dim = args.dim  # gating dimension, same as the model dimension
        self.n_routed_experts = args.n_routed_experts  # total number of routed experts
        self.scoring_func = args.scoring_func  # scoring function
        self.norm_topk_prob = args.norm_topk_prob  # whether to normalise the top-k probabilities
        self.alpha = args.aux_loss_alpha  # weight of the auxiliary load-balancing loss
        self.seq_aux = args.seq_aux  # compute the auxiliary loss at sequence level (default True)
        self.w = nn.Linear(self.gating_dim, self.n_routed_experts, bias = False)
        self.reset_parameters()
def reset_parameters(self) -> None:
import torch.nn.init as init
        init.kaiming_normal_(self.w.weight)  # initialise the gating weights
def forward(self, x: torch.Tensor):
bsz, seq_len, dim = x.shape
hidden_states = x.view(-1, dim)
scores = self.w(hidden_states) # (bsz * seq_len, n_routed_experts)
if self.scoring_func == "softmax":
scores = F.softmax(scores, dim = -1)
else:
            raise NotImplementedError(f'unsupported scoring function for MoE gating: {self.scoring_func}')
        # scores: (bsz * seq_len, n_routed_experts); scores[i][j] is the gate score of expert j for token i
        topk_weight, topk_idx = torch.topk(scores, self.topk, dim = -1, sorted = False)
        # keep the k largest scores and the corresponding expert indices, each of shape (bsz * seq_len, k)
        if self.norm_topk_prob:  # (the reference code also checks self.topk > 1, which is not strictly necessary)
            denominator = topk_weight.sum(dim = -1, keepdim = True) + 1e-20
            topk_weight = topk_weight / denominator  # normalise the kept weights so they sum to 1
        if self.training and self.alpha > 0:
            # only during training and with alpha > 0; with alpha <= 0 the auxiliary loss would be meaningless
            scores_for_aux = scores  # (bsz * seq_len, n_routed_experts)
            aux_topk = self.topk
            topk_idx_for_aux_loss = topk_idx.view(bsz, -1)  # (bsz, seq_len * k)
            if self.seq_aux:  # compute the auxiliary loss at sequence level
                scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1).mean(dim = 1)
                # Step 1: build ce, shape (bsz, n_routed_experts): how often each expert was picked per batch entry
                ce = torch.zeros(bsz, self.n_routed_experts, device = hidden_states.device)
                ce.scatter_add_(1, topk_idx_for_aux_loss, torch.ones(bsz, seq_len * aux_topk,
                                                                     device = hidden_states.device).div_(
                    seq_len * aux_topk / self.n_routed_experts
                ))
                # scatter_add_ accumulates along dim 1 using the chosen expert indices. Each batch entry uses
                # seq_len * k expert slots in total, hence the division by seq_len * aux_topk; the extra factor
                # n_routed_experts is the same scaling used in the token-level branch below.
                # Step 2: multiply ce and scores_for_seq_aux element-wise.
                # Step 3: sum over experts (a length-bsz vector), take the mean, and scale by alpha.
                aux_loss = (ce * scores_for_seq_aux).sum(dim = -1).mean() * self.alpha
            else:
                # Step 1: build ce, shape (n_routed_experts,): flatten the chosen indices, one-hot encode them,
                # average over all tokens, then scale by the number of experts.
                ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes = self.n_routed_experts).float().mean(dim = 0)
                ce = ce * self.n_routed_experts
                # num_classes is passed explicitly: not every expert is necessarily selected, so without it the
                # one-hot width could end up smaller than the number of experts.
                # The one-hot tensor has shape (bsz * seq_len * k, n_routed_experts).
                # Step 2: take the per-expert mean of scores_for_aux, shape (n_routed_experts,).
                # Step 3: multiply the two element-wise, sum, and scale by alpha.
                aux_loss = (ce * scores_for_aux.mean(dim = 0)).sum() * self.alpha
else:
aux_loss = None
return topk_weight, topk_idx, aux_loss
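# Illustrative gate output (a sketch, assuming n_routed_experts = 8 and num_experts_per_tok = 2):
#   for x of shape (bsz, seq_len, dim) the gate returns topk_weight and topk_idx of shape
#   (bsz * seq_len, 2), plus a scalar aux_loss during training (None otherwise).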
class MOEFeedForward(nn.Module):
def __init__(self, args: LMConfig) -> None:
super().__init__()
        self.topk = args.num_experts_per_tok  # the k in top-k: experts selected per token
        self.n_routed_experts = args.n_routed_experts  # total number of routed experts
self.experts = nn.ModuleList([
FeedForward(dim = args.dim,
hidden_dim = args.hidden_dim,
multi = args.multiple_of,
dropout = args.dropout)
for _ in range(self.n_routed_experts)
])
self.gate = MoEGate(args)
        self.args = args
        if args.n_shared_experts is not None:
self.shared_experts = FeedForward(
dim = args.dim,
hidden_dim = args.hidden_dim,
multi = args.multiple_of,
dropout = args.dropout
)
def work(self, x, topk_weight, topk_idx):
bsz, seq_len, dim = x.shape
        # replicate every token k times so each copy can be routed to one of its selected experts
        x = x.view(-1, dim)
        x = x.repeat_interleave(self.topk, dim = 0)  # (bsz * seq_len * k, dim)
        # flatten the chosen expert indices
        flat_topk_idx = topk_idx.view(-1)  # (bsz * seq_len * k,)
        # run every expert on the tokens that were routed to it
        y = torch.empty_like(x, dtype = torch.float16)  # (bsz * seq_len * k, dim)
        for i in range(self.n_routed_experts):
            y[flat_topk_idx == i] = self.experts[i](x[flat_topk_idx == i])
        # weight each expert output by its gate weight and sum over the k experts chosen for each token
        y = y.view(bsz * seq_len, self.topk, -1)
        y = (y * topk_weight.unsqueeze(-1)).sum(dim = 1)  # (bsz * seq_len, dim)
        # restore the input shape
        y = y.view(bsz, seq_len, -1)
        return y
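    # Illustrative dispatch (a sketch): with topk = 2, every token's hidden vector appears twice in the
    # repeat_interleave'd batch, once per selected expert, and the boolean mask `flat_topk_idx == i`
    # picks out exactly the copies routed to expert i.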
def forward(self, x):
        # x has shape (bsz, seq_len, dim)
        topk_weight, topk_idx, _ = self.gate(x)  # decide which experts to use and with what weights
        # topk_weight / topk_idx have shape (bsz * seq_len, k)
        if self.training:
            y = self.work(x, topk_weight, topk_idx)
        else:
            with torch.no_grad():
                y = self.work(x, topk_weight, topk_idx)
if self.args.n_shared_experts is not None:
y = y + self.shared_experts(y)
return y
class TransformerBlock(nn.Module):
def __init__(self, layer_id: int, args: LMConfig) -> None:
        # layer_id is the index of this block in the stack
super().__init__()
self.attn_norm = RMSNorm(dim = args.dim, eps = args.norm_eps)
self.attn = Attention(args)
self.ffn_norm = RMSNorm(dim = args.dim, eps = args.norm_eps)
if args.use_moe:
self.feed_forward = MOEFeedForward(args)
else:
self.feed_forward = FeedForward(dim = args.dim,
hidden_dim = args.hidden_dim,
multi = args.multiple_of,
dropout = args.dropout)
def forward(self, x, weight, use_kv_cache = False):
# print("吾来也!!")
# print("第一步")
# def forward(self, x: torch.Tensor, weight: torch.Tensor, use_kv_cache = False)
x = x + self.attn(self.attn_norm(x), weight, use_kv_cache)
# print("第二步")
x = x + self.feed_forward(self.ffn_norm(x))
# print("第三步")
return x
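# Pre-norm residual structure of each block (a sketch of what forward computes):
#   h   = x + Attention(RMSNorm(x))
#   out = h + FeedForward(RMSNorm(h))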
class Transformer(PreTrainedModel):
config_class = LMConfig
last_loss: Optional[torch.Tensor]
    def __init__(self, args: LMConfig = None) -> None:
        if not args:
            args = LMConfig()
        super().__init__(args)
        self.args = args
self.embedding = nn.Embedding(args.vocab_size, args.dim)
        self.dropout = nn.Dropout(args.dropout)  # dropout applied right after the embedding
        self.layers = nn.ModuleList()  # the stack of Transformer blocks
for i in range(args.n_layers):
self.layers.append(TransformerBlock(i, args))
        # rotary position embedding weights, shape (max_seq_len, head_dim // 2) with head_dim = args.dim // args.n_heads
rotation_weight = get_rotation(dim = args.dim // args.n_heads, seq_len = args.max_seq_len)
self.register_buffer('rotation_weight', rotation_weight, persistent = False)
self.norm = RMSNorm(dim = args.dim, eps = args.norm_eps)
        self.output = nn.Linear(args.dim, args.vocab_size, bias = False)  # final output projection
        self.embedding.weight = self.output.weight  # tie the input embedding and output projection weights
self.OUT = CausalLMOutputWithPast()
def forward(self, tokens: Optional[torch.Tensor] = None, targets: Optional[torch.Tensor] = None,
use_kv_cache = False, **key_args):
        # Optional[torch.Tensor] means the argument can be a tensor or None
        if 'input_ids' in key_args:
            tokens = key_args['input_ids']
        if 'attention_mask' in key_args:
            targets = key_args['attention_mask']
        _, seq_len = tokens.shape  # the input holds bsz sequences, each of length seq_len
        x = self.embedding(tokens)  # (bsz, seq_len, dim)
x = self.dropout(x)
# print("embedding完成!")
# 下面就是获得位置编码,然后过 transformer block
r_w = self.rotation_weight[: seq_len]
# print("旋转编码完成!")
for layer in self.layers:
x = layer(x, r_w, use_kv_cache)
# print("正在训练......")
# print("Transformer块完成!")
x = self.norm(x) # 过归一化
        if targets is not None:  # training: compute the loss against the target tokens
            logits = self.output(x)  # (bsz, seq_len, vocab_size)
            last_loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), targets.view(-1), ignore_index = -1)
            # targets holds the next token for every input position
            # ignore_index = -1 marks padding: positions whose target is -1 are skipped in the loss
            # F.cross_entropy applies softmax internally
        else:  # inference: only the last position's logits are needed
            logits = self.output(x[:, [-1], :])  # (bsz, 1, vocab_size), i.e. only the last position of each sequence
            last_loss = None
            # last_loss is kept as a local here, even though the class annotates it as an attribute
self.OUT.__setitem__('logits', logits)
self.OUT.__setitem__('last_loss', last_loss)
# print("返回!")
return self.OUT
@torch.inference_mode()
    # see https://zhuanlan.zhihu.com/p/667025336
def generate(self, idx, eos, max_new_tokens, temperature = 0.7, top_k = None,
stream = True, repetition_penalty = 1., use_kv_cache = True):
        # idx: (bsz, seq_len) token ids of the prompt
        # eos: end-of-sequence token id; generation stops once it is produced
        # max_new_tokens: upper bound on the number of tokens to generate
        # temperature: smooths the distribution; dividing the logits by it before softmax narrows or widens
        #   the gaps between token probabilities, so low-probability tokens become more or less likely
        # top_k: Top-K sampling parameter; None disables Top-K sampling, otherwise only the k most likely tokens are kept
        # stream: if True, yield the generated tokens as soon as each new token is produced; otherwise yield once at the end
        # repetition_penalty: down-weights tokens that already appear in the context, to avoid repetitive loops
bsz, seq_len = idx.shape
        while idx.shape[1] < max_new_tokens - 1:  # keep generating while the sequence stays within the length budget
            res = self(idx, use_kv_cache = use_kv_cache)
            logits = res.logits  # (bsz, 1, vocab_size) at inference time
            logits = logits[:, -1, :]  # (bsz, vocab_size): logits for the last position
            # down-weight tokens that already appear in the context
            for b in range(bsz):  # iterate over the batch
                for token in set(idx.tolist()[b]):  # the distinct tokens seen so far in this sequence
                    logits[b, token] /= repetition_penalty
            # apply temperature scaling
            if temperature == 0.0:
                # greedy decoding: pick the most likely token; idx_nxt has shape (bsz, 1)
                _, idx_nxt = torch.topk(logits, k = 1, dim = -1)
            else:
                logits = logits / temperature
                if top_k is not None:
                    # mask out (via -inf) every token outside the top k, so very unlikely tokens cannot be sampled
                    v, _ = torch.topk(logits, k = min(top_k, logits.shape[-1]), dim = -1)
                    # v is sorted in descending order within each batch entry, shape (bsz, top_k)
                    logits[logits < v[:, [-1]]] = -float("Inf")
                    # v[:, [-1]] is the smallest kept logit per batch entry, shape (bsz, 1)
                    # logits < v[:, [-1]] is a boolean mask with the same shape as logits:
                    # True wherever logits[i][j] < v[i][0]
                    # setting those entries to -inf makes their softmax probability 0
                probs = F.softmax(logits, dim = -1)
                idx_nxt = torch.multinomial(probs, num_samples = 1, generator = None)
                # sample one token per batch entry according to probs
            if idx_nxt == eos:
                break  # note: this comparison only really makes sense for bsz == 1
            idx = torch.concat((idx, idx_nxt), dim = -1)  # append the newly generated token
            if stream:
                yield idx[:, seq_len:]  # stream: yield everything generated so far after each new token
        if not stream:
            yield idx[:, seq_len:]
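    # Illustrative usage (a sketch; `tok` is a hypothetical tokenizer, not defined in this file):
    #   ids = torch.tensor([tok.encode(prompt)])
    #   for out in model.generate(ids, eos=tok.eos_token_id, max_new_tokens=100, stream=True):
    #       print(tok.decode(out[0].tolist()))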
@torch.inference_mode()
    def eval_answer(self, idx):  # returns the logits of the last position, truncating the input to max_seq_len
idx_cond = idx if idx.shape[1] < self.args.max_seq_len else idx[:, -self.args.max_seq_len:]
res = self(idx_cond)
logits = res.logits[:, -1, :]
return logits
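# Minimal smoke test (a sketch, not part of the library; run as a module, e.g. `python -m myminimind.model`,
# so the relative import of LMConfig resolves, and assuming LMConfig's defaults are small enough for CPU):
if __name__ == "__main__":
    cfg = LMConfig()
    model = Transformer(cfg)
    tokens = torch.randint(0, cfg.vocab_size, (2, 16))
    out = model(tokens, targets = tokens)
    print(out.logits.shape)  # expected: (2, 16, cfg.vocab_size)
    print(out.last_loss)     # scalar cross-entropy loss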