from typing import Optional, Tuple, Union
from collections import namedtuple
from einops import rearrange
import math
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F
from torch.utils.checkpoint import checkpoint

from .configuration_muddpythia import MUDDPythiaConfig
# try:
#     from .configuration_muddpythia import MUDDPythiaConfig
# except:
#     from configuration_muddpythia import MUDDPythiaConfig
from transformers.modeling_utils import PreTrainedModel


class KVCache(nn.Module):
    def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.float16):
        super().__init__()
        self.seq_length = max_seq_length
        cache_shape = (max_batch_size, n_heads, self.seq_length, head_dim)
        self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        # input_pos: [S], k_val: [B, H, S, D]
        assert input_pos.shape[0] == k_val.shape[2]
        k_out = self.k_cache
        v_out = self.v_cache
        k_out[:, :, input_pos] = k_val
        v_out[:, :, input_pos] = v_val
        return k_out, v_out


class LayerCache(nn.Module):
    def __init__(self, max_batch_size, num_layers, model_dim, dtype=torch.float16):
        super().__init__()
        cache_shape = (num_layers + 1, max_batch_size, 1, model_dim)  # LBTD
        self.register_buffer('layer_cache', torch.zeros(cache_shape, dtype=dtype))

    def update(self, x, lidx):
        self.layer_cache[lidx] = x
        return self.layer_cache[:lidx + 1]


class MultiwayDynamicDenseBlock(nn.Module):
    def __init__(self, config: MUDDPythiaConfig, lidx: int, last_layer=False) -> None:
        super().__init__()
        self.norm = RMSnormNoscale(epsilon=config.norm_eps)
        self.C = len(config.dense_type) if not last_layer else 1
        self.lidx = lidx
        l = lidx + 2
        hid_dim, out_dim = l * self.C, l * self.C
        if last_layer and config.expand_last:
            hid_dim *= 4
        if config.round64:
            hid_dim = (hid_dim // 64 + 1) * 64
        self.w1 = nn.Linear(config.dim, hid_dim, bias=False)
        self.act = nn.GELU()
        self.w2 = nn.Linear(hid_dim, out_dim, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        x = self.norm(x)
        dw = self.w2(self.act(self.w1(x)))  # BTD->BTL
        dw = rearrange(dw, 'B T (C L) -> C B T L', C=self.C)
        return dw

    def layer_mix(self, hids, dw) -> Tensor:
        x = tuple([sum(dw[cidx, :, :, j, None] * hids[j] for j in range(self.lidx + 2)) for cidx in range(self.C)])  # BTL, LBTD -> BTD
        return x
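
# Editor's sketch (not part of the original model code): a minimal, self-contained
# illustration of the mixing performed by MultiwayDynamicDenseBlock.layer_mix above.
# All sizes here are hypothetical; the block produces C sets of per-source-layer
# weights (one set per "way", e.g. the q/k/v/residual inputs when dense_type == 'qkvr'),
# and each way forms its own weighted sum over the lidx + 2 cached hidden states.
def _dynamic_dense_mix_sketch():
    B, T, D = 2, 4, 8          # batch, time, model dim (hypothetical)
    C, L = 4, 3                # C ways, L = lidx + 2 cached hidden states
    hids = [torch.randn(B, T, D) for _ in range(L)]   # per-layer hidden states
    dw = torch.randn(C, B, T, L)                       # dynamic mixing weights, CBTL
    # Way c mixes every cached hidden state j with its own weight dw[c, :, :, j].
    mixed = tuple(
        sum(dw[c, :, :, j, None] * hids[j] for j in range(L))
        for c in range(C)
    )
    return mixed  # C tensors of shape (B, T, D)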


class MUDDPythia(PreTrainedModel):
    config_class = MUDDPythiaConfig

    def __init__(self, config: MUDDPythiaConfig) -> None:
        super().__init__(config)
        self.config = config
        self.use_gradient_checkpointing = config.use_gradient_checkpointing
        self.is_training = config.is_training
        self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
        self.layers = nn.ModuleList(TransformerBlock(config, lidx) for lidx in range(config.n_layer))
        self.norm = nn.LayerNorm(config.dim, eps=config.norm_eps)
        self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
        C = len(self.config.dense_type)
        self.dense_bs = nn.ParameterList([nn.Parameter(data=torch.randn(C if lidx != config.n_layer - 1 else 1, lidx + 2)) for lidx in range(config.n_layer)])
        self.layer_cache = None
        self.use_layer_cache = False if self.is_training else self.config.use_layer_cache
        self.stack_hidden = self.config.stack_hidden
        self.dynamic = self.config.dynamic_dense
        self.dense = self.config.dense
        if self.dynamic:
            self.dynamic_dense = nn.ModuleList([MultiwayDynamicDenseBlock(config, lidx, last_layer=lidx == config.n_layer - 1) for lidx in range(config.n_layer)])
        self.rotary_ndims = int(config.head_dim * config.rotary_pct)
        self.freqs_cis: Optional[Tensor] = None
        self.mask_cache: Optional[Tensor] = None
        self.max_batch_size = -1
        self.max_seq_length = -1

    def tie_weights(self):
        # placeholder
        return

    def setup_caches(self, max_batch_size, max_seq_length, dtype=torch.float16):
        if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
            return
        head_dim = self.config.dim // self.config.n_head
        max_seq_length = find_multiple(max_seq_length, 8)
        self.max_seq_length = max_seq_length
        self.max_batch_size = max_batch_size
        if not self.config.is_training:
            if self.use_layer_cache:
                self.layer_cache = LayerCache(max_batch_size, self.config.n_layer, self.config.dim, dtype=dtype)
            for b in self.layers:
                b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype=dtype)
        self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.rotary_ndims, self.config.rope_base).to(self.tok_embeddings.weight.device)
        self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool, device=self.tok_embeddings.weight.device))

    def generate(self, input_ids, num_tokens_to_generate=10, compiled_decode_one_token=None):
        batch_size, seq_length = input_ids.shape
        input_pos = torch.arange(seq_length, device=self.device)
        generated_ids = torch.zeros(batch_size, seq_length + num_tokens_to_generate, dtype=torch.int, device=self.device)
        generated_ids[:, :seq_length] = input_ids.to(self.device).to(torch.int)
        logits = self.forward(input_ids, input_pos=input_pos, return_tensor=True)
        _next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
        next_token = torch.zeros(self.max_batch_size, 1, device=self.device, dtype=torch.int)
        next_token[:batch_size] = _next_token
        generated_ids[:, seq_length] = next_token[:batch_size, 0]
        input_pos = torch.tensor([seq_length], device=self.device)
        for _ in range(1, num_tokens_to_generate):
            if compiled_decode_one_token is not None:
                next_token = compiled_decode_one_token(self, next_token.clone(), input_pos)
            else:
                next_token = self.decode_one_token(next_token.clone(), input_pos)
            generated_ids[:, input_pos + 1] = next_token.int()[:batch_size]
            input_pos += 1
        return generated_ids

    def decode_one_token(self, cur_token, input_pos):
        logits = self.forward(cur_token, input_pos=input_pos, return_tensor=True)
        new_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
        return new_token

    def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None, return_tensor=False) -> Tensor:
        assert self.freqs_cis is not None, "Caches must be initialized first"
        if input_pos is None:
            input_pos = torch.arange(idx.shape[-1], device=idx.device, dtype=torch.int)
        mask = self.causal_mask[None, None, input_pos]
        freqs_cis = self.freqs_cis[input_pos]
        x = self.tok_embeddings(idx)
        _, seqlen, _ = x.shape
        use_layer_cache = self.use_layer_cache and seqlen == 1
        if use_layer_cache:
            self.layer_cache.update(x, 0)
        else:
            hiddens = [x]
        for i, layer in enumerate(self.layers):
            if self.use_gradient_checkpointing:
                x = checkpoint(layer, x, input_pos, freqs_cis, mask)
            else:
                x = layer(x, input_pos, freqs_cis, mask)
            if use_layer_cache:
                _hidden = self.layer_cache.update(x, i + 1)  # LBTD
            else:
                hiddens.append(x)
                _hidden = hiddens  # list of per-layer BTD hidden states
            if self.dynamic and self.dense:
                dw = self.dynamic_dense[i](x)  # BTD -> CBTL
                dw = dw + self.dense_bs[i][:, None, None, :]  # CBTL
                if self.stack_hidden:
                    x = torch.einsum('LBTD, CBTL -> CBTD', _hidden, dw)
                else:
                    x = self.dynamic_dense[i].layer_mix(_hidden, dw)
        if self.config.dense_type == 'qkvr' and self.config.dense and self.config.dynamic_dense:
            x = x[0]
        x = self.norm(x)
        logits = self.output(x)
        if return_tensor:
            return logits
        else:
            CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
            return CausalLMOutput(logits=logits)
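
# Note (added for clarity): with dense_type == 'qkvr' and dynamic dense enabled,
# MUDDPythia.forward passes every block after the first a tuple of C mixed hidden
# states rather than a single tensor; the block below reads x[0], x[1], x[2] as the
# query/key/value inputs and x[-1] as the residual stream for the MLP path.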


class TransformerBlock(nn.Module):
    def __init__(self, config: MUDDPythiaConfig, lidx) -> None:
        super().__init__()
        self.lidx = lidx
        self.config = config
        self.attention = Attention(config, lidx)
        self.feed_forward = FeedForward(config, lidx)
        self.ffn_norm = nn.LayerNorm(config.dim, eps=config.norm_eps)
        self.use_parallel_residual = config.use_parallel_residual
        if self.config.sepln and self.lidx > 0:
            self.attention_norms = torch.nn.ModuleList([nn.LayerNorm(config.dim, eps=config.norm_eps) for _ in range(3)])
        else:
            self.attention_norm = nn.LayerNorm(config.dim, eps=config.norm_eps)

    def forward(self, x: Union[Tuple[Tensor], Tensor], input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor:
        if self.config.dense_type == 'l' or self.lidx == 0 or not self.config.dense:
            res = x
            normed_x = self.attention_norm(x)
        elif self.config.dense_type == 'qkvr':
            res = x[-1]  # for mlp
            if self.config.stack_hidden or not self.config.sepln:
                normed_x = self.attention_norm(x[:3])
            else:
                normed_x = tuple([norm_fn(_x) for norm_fn, _x in zip(self.attention_norms, x[:3])])
        h = res + self.attention(normed_x, freqs_cis, mask, input_pos)
        out = h + self.feed_forward(self.ffn_norm(res if self.use_parallel_residual else h))
        return out


class Attention(nn.Module):
    def __init__(self, config: MUDDPythiaConfig, lidx):
        super().__init__()
        assert config.dim % config.n_head == 0
        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
        # key, query, value projections for all heads, but in a batch
        self.config = config
        if self.config.dense_type == 'l' or not self.config.dense:
            self.wqkv = nn.Linear(config.dim, total_head_dim, bias=True)
        elif self.config.dense_type == 'qkvr':
            self.wq = nn.Linear(config.dim, config.n_head * config.head_dim, bias=True)
            self.wk = nn.Linear(config.dim, config.n_local_heads * config.head_dim, bias=True)
            self.wv = nn.Linear(config.dim, config.n_local_heads * config.head_dim, bias=True)
        self.wo = nn.Linear(config.dim, config.dim, bias=True)
        self.lidx = lidx
        self.kv_cache = None
        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.scale_factor = 1 / math.sqrt(self.head_dim)
        self.n_local_heads = config.n_local_heads
        self.dim = config.dim
        self.use_qk_norm = config.use_qk_norm
        if self.use_qk_norm:
            self.q_norm = RMSNorm(self.head_dim, config.norm_eps)
            self.k_norm = RMSNorm(self.head_dim, config.norm_eps)
        self.rotary_ndims = int(self.head_dim * config.rotary_pct)
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict, prefix, *args):
        if prefix + "wq.weight" in state_dict and (self.config.dense_type == 'l' or not self.config.dense):
            wq = state_dict.pop(prefix + "wq.weight")
            wk = state_dict.pop(prefix + "wk.weight")
            wv = state_dict.pop(prefix + "wv.weight")
            state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

    def forward(self, x: Union[Tuple[Tensor], Tensor], freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
        if self.lidx == 0 or self.config.dense_type == 'l' or not self.config.dense:
            bsz, seqlen, _ = x.shape
        else:
            if self.config.stack_hidden:
                C, bsz, seqlen, _ = x.shape
            else:
                C, (bsz, seqlen, _) = len(x), x[0].shape
        kv_size = self.n_local_heads * self.head_dim
        if self.config.dense_type == 'l' or not self.config.dense:
            q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
            q = q.view(bsz, seqlen, self.n_head, self.head_dim)
            k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
            v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        elif self.config.dense_type == 'qkvr':
            if self.lidx == 0:
                xq, xk, xv = x, x, x
            else:
                xq, xk, xv = x[0], x[1], x[2]
            q = self.wq(xq).view(bsz, seqlen, self.n_head, self.head_dim)
            k = self.wk(xk).view(bsz, seqlen, self.n_local_heads, self.head_dim)
            v = self.wv(xv).view(bsz, seqlen, self.n_local_heads, self.head_dim)
        if self.use_qk_norm:
            q, k = self.q_norm(q), self.k_norm(k)
        if self.rotary_ndims == self.head_dim:
            q = apply_rotary_emb(q, freqs_cis)  # BTND
            k = apply_rotary_emb(k, freqs_cis)
        else:
            q_rot = q[..., :self.rotary_ndims]
            q_pass = q[..., self.rotary_ndims:]
            k_rot = k[..., :self.rotary_ndims]
            k_pass = k[..., self.rotary_ndims:]
            q_rot = apply_rotary_emb(q_rot, freqs_cis, mode='half')  # BTND
            k_rot = apply_rotary_emb(k_rot, freqs_cis, mode='half')
            q = torch.cat((q_rot, q_pass), dim=-1)
            k = torch.cat((k_rot, k_pass), dim=-1)
        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
        if self.kv_cache is not None:
            if seqlen == 1:
                k, v = self.kv_cache.update(input_pos, k, v)
            else:
                _, _ = self.kv_cache.update(input_pos, k, v)
        if seqlen == 1:  # one-token generation
            k_mask = mask[:, :, :, :self.kv_cache.seq_length]
        else:  # prefill
            k_mask = mask[:, :, :, :k.shape[-2]]
        logits = q @ k.transpose(-2, -1) * self.scale_factor
        min_value = torch.finfo(torch.float16).min
        logits = torch.where(k_mask, logits, min_value)
        probs = logits.softmax(-1)
        y = probs @ v
        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
        y = self.wo(y)
        return y
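
# Editor's sketch (not part of the original file): how the KVCache defined earlier is
# exercised by Attention.forward during decoding. Sizes are hypothetical; update()
# writes the new key/value at the given positions and returns the full fixed-size
# buffers, which the attention then restricts with the causal-mask slice k_mask.
def _kv_cache_sketch():
    cache = KVCache(max_batch_size=1, max_seq_length=8, n_heads=2, head_dim=4, dtype=torch.float32)
    k_new = torch.randn(1, 2, 1, 4)    # one decoded step: B, H, S=1, D
    v_new = torch.randn(1, 2, 1, 4)
    input_pos = torch.tensor([3])      # write into slot 3 of the cache
    k_all, v_all = cache.update(input_pos, k_new, v_new)
    return k_all.shape, v_all.shape    # both (1, 2, 8, 4): the full buffers are returned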


class RMSnormNoscale(nn.Module):
    def __init__(self, epsilon=1e-6, dim=-1):
        super().__init__()
        self.dim = dim
        self.epsilon = epsilon

    def forward(self, inputs):
        var = inputs.pow(2).mean(dim=self.dim, keepdim=True)
        normed_inputs = inputs * torch.rsqrt(var + self.epsilon)
        return normed_inputs


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)

    def forward(self, x: Tensor) -> Tensor:
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


class FeedForward(nn.Module):
    def __init__(self, config: MUDDPythiaConfig, lidx, round128=True, scale_with_layer=True) -> None:
        super().__init__()
        hid_dim = config.intermediate_size
        if scale_with_layer:
            hid_dim = hid_dim * (lidx / (config.n_layer - 1) + 0.5)
        if round128:
            hid_dim = round(hid_dim / 128) * 128
        self.w1 = nn.Linear(config.dim, hid_dim, bias=config.use_linear_bias)
        self.w2 = nn.Linear(hid_dim, config.dim, bias=config.use_linear_bias)

    def forward(self, x: Tensor) -> Tensor:
        return self.w2(F.gelu(self.w1(x)))


def find_multiple(n: int, k: int) -> int:
    if n % k == 0:
        return n
    return n + k - (n % k)


def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
    freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
    t = torch.arange(seq_len, device=freqs.device)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
    return cache.to(dtype=torch.float16)
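
# Note (added for clarity): precompute_freqs_cis above returns a float16 table of
# shape (seq_len, n_elem // 2, 2) holding the cos/sin (real/imag) pair for every
# position and rotary frequency; MUDDPythia.forward indexes it with the current
# positions, and apply_rotary_emb below reshapes it to broadcast over heads.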

def apply_rotary_emb(x: Tensor, freqs_cis: Tensor, mode='half') -> Tensor:
    if mode == 'half':
        xshaped = x.float().reshape(*x.shape[:-1], 2, -1).transpose(-1, -2)
    elif mode == 'alternative':
        xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
    freqs_cis = freqs_cis.view(-1, xshaped.size(1), 1, xshaped.size(3), 2)
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
            xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
        ],
        -1,
    )
    x_out2 = x_out2.flatten(3)
    return x_out2.type_as(x)


def match_weight_muddpythia(model, w, strict=False, pythia=True):
    if pythia:
        map_dict = {'wq': 'query', 'wk': 'key', 'wv': 'value', 'wo': 'post', 'w1': 'ffn_layer1', 'w2': 'ffn_layer2', 'weight': 'w', 'bias': 'b'}
    else:
        map_dict = {'wq': 'query', 'wk': 'key', 'wv': 'value', 'wo': 'post', 'w1': 'ffn_layer1_gate', 'w3': 'ffn_layer1', 'w2': 'ffn_layer2', 'weight': 'w', 'bias': 'b'}
    ln_dict = {'weight': 'scale', 'bias': 'bias'}
    E, H, D = model.config.dim, model.config.n_head, model.config.head_dim
    N = model.config.vocab_size
    state_dict = {}
    for k, v in model.named_parameters():
        if k == 'tok_embeddings.weight':
            v = w['state.mdl_vars.params.lm.embedding_lookup.emb_var'][:N, :]
        elif k == 'norm.weight':
            v = w['state.mdl_vars.params.lm.final_ln.scale']
        elif k == 'norm.bias':
            v = w['state.mdl_vars.params.lm.final_ln.bias']
        elif k == 'output.weight':
            v = w['state.mdl_vars.params.lm.softmax.logits_ffn.linear.w'].T[:N, :]  # E,N -> N,E
        elif 'dense_bs' in k:  # static dense w
            lidx = int(k.split('.')[-1])
            v = w[f'state.mdl_vars.params.lm.transformer.dense_conn_{lidx}']
        elif 'dynamic_dense' in k:
            lidx = int(k.split('.')[1])
            widx = int(k.split('.')[2][-1])  # 1 or 2 in w1, w2
            v = w[f'state.mdl_vars.params.lm.transformer.x_layers_{lidx}.dynamic_dense_conn{widx}_{lidx}'].T
        else:
            assert 'layers' in k
            lidx = int(k.split('.')[1])
            if '.attention.' in k:
                _, _, _, ptype, wtype = k.split('.')
                if ptype in ['wq', 'wk', 'wv', 'wo']:
                    v = w[f'state.mdl_vars.params.lm.transformer.x_layers_{lidx}.self_attention.{map_dict.get(ptype, ptype)}.{map_dict.get(wtype, wtype)}']  # .reshape(E,E)
                    if wtype == 'weight':
                        v = v.reshape(E, E)
                    elif wtype == 'bias':
                        v = v.reshape(E)
                    if ptype != 'wo' and wtype == 'weight':
                        v = v.T
                elif ptype in ['q_norm', 'k_norm']:
                    v = w[f'state.mdl_vars.params.lm.transformer.x_layers_{lidx}.self_attention.{map_dict.get(ptype, ptype)}.scale']
            elif 'feed_forward' in k:
                ptype = k.split('.')[3]  # w1, w3, w2
                wtype = k.split('.')[4]  # weight or bias
                if wtype == 'weight':
                    v = w[f'state.mdl_vars.params.lm.transformer.x_layers_{lidx}.ff_layer.{map_dict[ptype]}.linear.{map_dict[wtype]}']
                    v = v.T
                elif wtype == 'bias':
                    v = w[f'state.mdl_vars.params.lm.transformer.x_layers_{lidx}.ff_layer.{map_dict[ptype]}.bias.{map_dict[wtype]}']
            elif 'ffn_norm' in k:  # mlp layernorm
                wtype = k.split('.')[-1]  # weight or bias
                v = w[f'state.mdl_vars.params.lm.transformer.x_layers_{lidx}.ff_layer.layer_norm.{ln_dict[wtype]}']
            elif 'attention_norm' in k:  # attention layernorm
                wtype = k.split('.')[-1]  # weight or bias
                if 'attention_norms' in k:
                    ln_idx = int(k.split('.')[3])
                    v = w[f'state.mdl_vars.params.lm.transformer.x_layers_{lidx}.layer_norms_{ln_idx}.{ln_dict[wtype]}']
                else:
                    v = w[f'state.mdl_vars.params.lm.transformer.x_layers_{lidx}.layer_norm.{ln_dict[wtype]}']
        if pythia and 'weight' in k and 'norm' in k and 'q_norm' not in k and 'k_norm' not in k:
            v = v + 1
        state_dict[k] = torch.tensor(v)
    model.load_state_dict(state_dict, strict=strict)
    return model
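
# Usage sketch (illustrative, added by the editor; the checkpoint path below is
# hypothetical). setup_caches must be called before forward/generate, since the
# rotary table, causal mask and per-layer KV caches are allocated there:
#
#     model = MUDDPythia.from_pretrained("path/to/muddpythia-checkpoint")
#     model.setup_caches(max_batch_size=1, max_seq_length=2048)
#     input_ids = torch.randint(0, model.config.vocab_size, (1, 16))
#     generated = model.generate(input_ids, num_tokens_to_generate=10)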