import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import bitsandbytes as bnb

from models.norm import RMSNorm
from models.rope import precompute_freqs_cis, apply_rotary_emb


# The linear layers below override reset_parameters with a no-op: weights are not
# randomly initialized here but loaded later from a pretrained checkpoint.
class NormalLinear(nn.Linear):
    def reset_parameters(self) -> None:
        pass


class BnbInt8Linear(bnb.nn.Linear8bitLt):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, has_fp16_weights=False, threshold=6.0, **kwargs)

    def reset_parameters(self) -> None:
        pass


def get_linear_layer(use_int8):
    # Select the 8-bit (bitsandbytes) or full-precision linear implementation.
    if use_int8:
        return BnbInt8Linear
    return NormalLinear


class WordEmbedding(nn.Module):
    def __init__(self, args):
        super(WordEmbedding, self).__init__()
        self.embedding = nn.Embedding(args.vocab_size, args.emb_size)

    def forward(self, src):
        emb = self.embedding(src)
        return emb


class MultiHeadedAttention(nn.Module):
    def __init__(self, args, hidden_size, heads_num, attention_head_size, has_bias=True, use_int8=True):
        super(MultiHeadedAttention, self).__init__()
        self.heads_num = heads_num
        self.per_head_size = attention_head_size
        self.inner_hidden_size = heads_num * attention_head_size

        Linear = get_linear_layer(use_int8)
        self.linear_layers = nn.ModuleList(
            [Linear(hidden_size, self.inner_hidden_size, bias=has_bias) for _ in range(3)]
        )
        self.final_linear = Linear(self.inner_hidden_size, hidden_size, bias=has_bias)

        # KV cache: keep past keys/values so each decoding step only computes
        # projections for the new tokens.
        self.cache_k = torch.zeros(
            (args.batch_size, args.seq_length, self.heads_num, self.per_head_size)
        )
        self.cache_v = torch.zeros(
            (args.batch_size, args.seq_length, self.heads_num, self.per_head_size)
        )

    def forward(self, key, value, query, start_pos, continue_exsample, mask, freqs_cis):
        batch_size, seq_length, _ = query.size()
        heads_num = self.heads_num
        per_head_size = self.per_head_size

        query, key, value = [
            l(x).view(batch_size, -1, heads_num, per_head_size)
            for l, x in zip(self.linear_layers, (query, key, value))
        ]
        query, key = apply_rotary_emb(query, key, freqs_cis=freqs_cis)

        # Keep the cache on the same device/dtype as the current activations.
        if self.cache_k.device != key.device:
            self.cache_k = self.cache_k.to(key)
        if self.cache_v.device != value.device:
            self.cache_v = self.cache_v.to(value)

        # Write the new positions into the cache, then attend over all cached positions.
        self.cache_k[continue_exsample, start_pos: start_pos + seq_length] = key
        self.cache_v[continue_exsample, start_pos: start_pos + seq_length] = value
        key = self.cache_k[continue_exsample, : start_pos + seq_length]
        value = self.cache_v[continue_exsample, : start_pos + seq_length]

        query, key, value = [x.transpose(1, 2) for x in (query, key, value)]

        scores = torch.matmul(query, key.transpose(-2, -1))
        scores = scores / math.sqrt(float(per_head_size))
        if mask is not None:
            scores += mask
        # probs = nn.Softmax(dim=-1)(scores)
        probs = F.softmax(scores.float(), dim=-1).type_as(query)
        output = torch.matmul(probs, value).transpose(1, 2).contiguous().view(batch_size, seq_length, -1)
        return self.final_linear(output)


class GatedFeedForward(nn.Module):
    def __init__(self, hidden_size, feedforward_size, has_bias=True, use_int8=True):
        super(GatedFeedForward, self).__init__()
        Linear = get_linear_layer(use_int8)
        self.linear_gate = Linear(hidden_size, feedforward_size, bias=has_bias)
        self.linear_1 = Linear(hidden_size, feedforward_size, bias=has_bias)
        self.linear_2 = Linear(feedforward_size, hidden_size, bias=has_bias)
        self.act = F.silu

    def forward(self, x):
        # SwiGLU: silu(W_gate x) * (W_1 x), then project back with W_2.
        # gate = self.act(self.linear_gate(x))
        gate = self.act(self.linear_gate(x)).type_as(x)
        inter_linear = self.linear_1(x)
        inter = gate * inter_linear
        output = self.linear_2(inter)
        return output


class TransformerLayer(nn.Module):
    def __init__(self, args):
        super(TransformerLayer, self).__init__()

        if hasattr(args, "attention_head_size"):
            attention_head_size = args.attention_head_size
        else:
            attention_head_size = args.hidden_size // args.heads_num

        has_bias = bool(1 - args.remove_transformer_bias)

        # Multi-head attention
        self.self_attn = MultiHeadedAttention(
            args, args.hidden_size, args.heads_num, attention_head_size,
            has_bias=has_bias, use_int8=args.use_int8
        )
        # FFN
        self.feed_forward = GatedFeedForward(
            args.hidden_size, args.feedforward_size, has_bias, use_int8=args.use_int8
        )
        self.layer_norm_1 = RMSNorm(args.hidden_size)
        self.layer_norm_2 = RMSNorm(args.hidden_size)

    def forward(self, hidden, start_pos, continue_exsample, mask, freqs_cis=None):
        # Pre-norm residual blocks: norm -> attention -> add, then norm -> FFN -> add.
        inter = self.layer_norm_1(hidden)
        inter = self.self_attn(inter, inter, inter, start_pos, continue_exsample, mask, freqs_cis)
        hidden = hidden + inter
        output = self.layer_norm_2(hidden)
        output = self.feed_forward(output) + hidden
        return output


class TransformerEncoder(nn.Module):
    def __init__(self, args):
        super(TransformerEncoder, self).__init__()
        self.mask = args.mask
        self.layers_num = args.layers_num
        self.transformer = nn.ModuleList(
            [TransformerLayer(args) for _ in range(self.layers_num)]
        )
        self.layer_norm = RMSNorm(args.hidden_size)
        # Rotary position frequencies are precomputed once, up to twice the maximum sequence length.
        self.freqs_cis = precompute_freqs_cis(args.hidden_size // args.heads_num, args.max_seq_length * 2)

    def forward(self, emb, start_pos, continue_exsample):
        batch_size, seq_length, _ = emb.size()

        # A causal mask is only needed when more than one token is fed at once (prefill);
        # single-token decoding steps attend to the whole cache.
        mask = None
        if seq_length > 1:
            mask = torch.ones(seq_length, seq_length, device=emb.device)
            mask = torch.tril(mask)
            mask = (1.0 - mask) * -10000
            mask = mask.repeat(batch_size, 1, 1, 1)

        hidden = emb
        freqs_cis = self.freqs_cis[start_pos: start_pos + seq_length].to(hidden.device)

        for i in range(self.layers_num):
            hidden = self.transformer[i](hidden, start_pos, continue_exsample, mask, freqs_cis=freqs_cis)
        return self.layer_norm(hidden)


class LmOutput(nn.Module):
    def __init__(self, args):
        super(LmOutput, self).__init__()
        # Note: the LM output head does not use int8.
        Linear = get_linear_layer(False)
        self.lm = Linear(args.hidden_size, args.vocab_size, bias=False)

    def forward(self, x):
        # Only the last position is projected to vocabulary logits.
        return self.lm(x[:, -1, :])


class LLaMa(nn.Module):
    def __init__(self, args):
        super(LLaMa, self).__init__()
        self.embedding = WordEmbedding(args)
        self.encoder = TransformerEncoder(args)
        self.target = LmOutput(args)

    # @torch.inference_mode()
    def forward(self, src, start_pos, continue_exsample):
        emb = self.embedding(src)
        output = self.encoder(emb, start_pos, continue_exsample)
        output = self.target(output)
        return output