Hilbertmeng committed on
Commit 78bf0cf · 1 Parent(s): c925b93
config.json ADDED
@@ -0,0 +1,36 @@
+ {
+   "architectures": [
+     "MUDDFormer"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_muddformer.MUDDFormerConfig",
+     "AutoModelForCausalLM": "modeling_muddformer.MUDDFormer"
+   },
+   "block_size": 2048,
+   "bos_token_id": 1,
+   "dense": true,
+   "dense_type": "qkvr",
+   "dim": 2560,
+   "dynamic_dense": true,
+   "eos_token_id": 2,
+   "expand_last": true,
+   "head_dim": 80,
+   "intermediate_size": 6912,
+   "is_training": false,
+   "model_type": "muddformer",
+   "n_head": 32,
+   "n_layer": 32,
+   "n_local_heads": 32,
+   "norm_eps": 1e-06,
+   "rope_base": 10000,
+   "round64": true,
+   "sepln": true,
+   "stack_hidden": false,
+   "tie_word_embeddings": false,
+   "torch_dtype": "bfloat16",
+   "transformers_version": "4.35.0",
+   "use_gradient_checkpointing": false,
+   "use_layer_cache": true,
+   "use_qk_norm": true,
+   "vocab_size": 50432
+ }
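The auto_map block above is what routes AutoConfig and AutoModelForCausalLM to the custom classes added below. A minimal sketch (not part of the commit; it assumes the published repo id Caiyun-AI/MUDDFormer-2.8B used in generation_demo.py) of loading the config alone through that mapping:

from transformers import AutoConfig

# trust_remote_code lets auto_map resolve to configuration_muddformer.MUDDFormerConfig
config = AutoConfig.from_pretrained("Caiyun-AI/MUDDFormer-2.8B", trust_remote_code=True)
print(config.model_type, config.dim, config.n_layer)  # muddformer 2560 32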
configuration_muddformer.py ADDED
@@ -0,0 +1,82 @@
+ from transformers.configuration_utils import PretrainedConfig
+ from typing import Optional
+
+
+ def find_multiple(n: int, k: int) -> int:
+     if n % k == 0:
+         return n
+     return n + k - (n % k)
+
+ class MUDDFormerConfig(PretrainedConfig):
+     model_type = "muddformer"
+
+     '''
+     MUDDFormerConfig is a config class for MUDDFormer, which is adapted from https://github.com/pytorch-labs/gpt-fast/blob/main/model.py#L21
+     '''
+     def __init__(
+         self,
+         block_size: int = 2048,
+         vocab_size: int = 50432,
+         n_layer: int = 32,
+         n_head: int = 32,
+         dim: int = 2560,
+         intermediate_size: int = None,
+         n_local_heads: int = -1,
+         head_dim: int = 64,
+         rope_base: float = 10000,
+         norm_eps: float = 1e-6,
+         use_gradient_checkpointing: bool = False,
+         is_training: bool = False,
+         use_qk_norm: bool = False,
+         pad_token_id: Optional[int] = None,
+         bos_token_id: int = 1,
+         eos_token_id: int = 2,
+         tie_word_embeddings: bool = False,
+         use_layer_cache: bool = True,
+         stack_hidden: bool = False,
+         dense: bool = True,
+         dynamic_dense: bool = True,
+         sepln: bool = True,
+         dense_type: str = 'qkvr',
+         expand_last: bool = False,
+         round64: bool = False,
+         **kwargs,
+     ):
+         self.block_size = block_size
+         self.vocab_size = vocab_size
+         self.n_layer = n_layer
+         self.n_head = n_head
+         self.dim = dim
+         self.intermediate_size = intermediate_size
+         self.n_local_heads = n_local_heads
+         self.head_dim = head_dim
+         self.rope_base = rope_base
+         self.norm_eps = norm_eps
+         self.use_gradient_checkpointing = use_gradient_checkpointing
+         self.is_training = is_training
+         self.use_qk_norm = use_qk_norm
+
+         self.use_layer_cache = use_layer_cache
+         self.stack_hidden = stack_hidden
+         self.dense = dense
+         self.dynamic_dense = dynamic_dense
+         self.sepln = sepln
+         self.dense_type = dense_type
+         self.expand_last = expand_last
+         self.round64 = round64
+         # post init
+         if self.n_local_heads == -1:
+             self.n_local_heads = self.n_head
+         if self.intermediate_size is None:
+             hidden_dim = 4 * self.dim
+             n_hidden = int(2 * hidden_dim / 3)
+             self.intermediate_size = find_multiple(n_hidden, 256)
+         self.head_dim = self.dim // self.n_head
+
+         super().__init__(
+             pad_token_id=pad_token_id,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             tie_word_embeddings=tie_word_embeddings,
+             **kwargs,
+         )
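The post-init block above recomputes head_dim and, when intermediate_size is left at None, derives it via find_multiple. A minimal sketch (not part of the commit; it assumes configuration_muddformer.py is importable from the working directory) showing that the defaults reproduce the values recorded in config.json:

from configuration_muddformer import MUDDFormerConfig

cfg = MUDDFormerConfig()       # defaults: dim=2560, n_head=32, intermediate_size=None
print(cfg.head_dim)            # 80   -> dim // n_head = 2560 // 32
print(cfg.intermediate_size)   # 6912 -> find_multiple(int(2 * 4 * 2560 / 3), 256)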
generation_demo.py ADDED
@@ -0,0 +1,49 @@
+ import time
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ import torch
+
+ import os
+ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
+
+ device = torch.device('cuda:0')
+ MAX_BATCH_SIZE = 1
+ MAX_SEQ_LENGTH = 2048
+ NUM_TOKENS_TO_GENERATE = 10
+ COMPILE = True
+ OPTIMIZED_COMPILE = False
+
+ if OPTIMIZED_COMPILE:
+     import torch._dynamo.config
+     import torch._inductor.config
+     torch._dynamo.config.cache_size_limit = 64
+     torch._inductor.config.coordinate_descent_tuning = True
+     torch._inductor.config.triton.unique_kernel_names = True
+     torch._inductor.config.fx_graph_cache = True
+
+ tokenizer = AutoTokenizer.from_pretrained("Caiyun-AI/MUDDFormer-2.8B")
+ model = AutoModelForCausalLM.from_pretrained("Caiyun-AI/MUDDFormer-2.8B", trust_remote_code=True)
+
+ _ = model.to(device=device, dtype=torch.bfloat16)
+ with torch.device(device):
+     model.setup_caches(max_batch_size=MAX_BATCH_SIZE, max_seq_length=MAX_SEQ_LENGTH)
+
+ def decode_one_token(model, cur_token, input_pos):
+     logits = model(cur_token, input_pos=input_pos, return_tensor=True)
+     new_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
+     return new_token
+
+ prompt = "Beijing is the capital of China. London is the capital of"
+ input_ids = tokenizer.encode(prompt, return_tensors='pt')
+
+ compiled_decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True) if COMPILE else None
+
+ print('Start generating tokens; the first iteration may take a few minutes to compile.')
+ for i in range(10):
+     t0 = time.time()
+     with torch.no_grad():
+         generated_ids = model.generate(input_ids.to(device), num_tokens_to_generate=NUM_TOKENS_TO_GENERATE, compiled_decode_one_token=compiled_decode_one_token)
+     text = tokenizer.decode(generated_ids[0])
+     if i == 0:
+         print(f'Generated text: {text}')
+     t1 = time.time()
+     print(f'Time consumed at iteration {i}: {t1-t0}s')
modeling_muddformer.py ADDED
@@ -0,0 +1,432 @@
+ from typing import Optional
+
+ import math
+ import torch
+ import torch.nn as nn
+ from torch import Tensor
+ from torch.nn import functional as F
+ from einops import rearrange
+ from collections import namedtuple
+ from torch.utils.checkpoint import checkpoint
+ from typing import Optional, Tuple, Union
+
+ try:
+     from .configuration_muddformer import MUDDFormerConfig
+ except ImportError:
+     from configuration_muddformer import MUDDFormerConfig
+
+ from transformers.modeling_utils import PreTrainedModel
+
+
+ def find_multiple(n: int, k: int) -> int:
+     if n % k == 0:
+         return n
+     return n + k - (n % k)
+
+ class KVCache(nn.Module):
+     def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16):
+         super().__init__()
+         self.seq_length = max_seq_length
+         cache_shape = (max_batch_size, n_heads, self.seq_length, head_dim)
+         self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype))
+         self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype))
+
+     def update(self, input_pos, k_val, v_val):
+         # input_pos: [S], k_val: [B, H, S, D]
+         assert input_pos.shape[0] == k_val.shape[2]
+         B, N, S, D = v_val.shape
+         k_out = self.k_cache
+         v_out = self.v_cache
+         k_out[:, :, input_pos] = k_val
+         v_out[:, :, input_pos] = v_val
+         return k_out, v_out
+
+ class LayerCache(nn.Module):
+     def __init__(self, max_batch_size, num_layers, model_dim, dtype=torch.bfloat16):
+         super().__init__()
+         cache_shape = (num_layers+1, max_batch_size, 1, model_dim) # LBTD
+         self.register_buffer('layer_cache', torch.zeros(cache_shape, dtype=dtype))
+
+     def update(self, x, lidx):
+         self.layer_cache[lidx] = x
+         return self.layer_cache[:lidx+1]
+
+ class MultiwayDynamicDenseBlock(nn.Module):
+     def __init__(self, config: MUDDFormerConfig, lidx: int, last_layer=False) -> None:
+         super().__init__()
+         self.norm = RMSnormNoscale(epsilon=config.norm_eps)
+         self.C = len(config.dense_type) if not last_layer else 1
+         self.lidx = lidx
+         l = lidx + 2
+         hid_dim, out_dim = l * self.C, l * self.C
+         if last_layer and config.expand_last: hid_dim *= 4
+         if config.round64: hid_dim = (hid_dim // 64 + 1) * 64
+         self.w1 = nn.Linear(config.dim, hid_dim, bias=False)
+         self.act = nn.GELU()
+         self.w2 = nn.Linear(hid_dim, out_dim, bias=False)
+
+     def forward(self, x: Tensor) -> Tensor:
+         x = self.norm(x)
+         dw = self.w2(self.act(self.w1(x))) # BTD->BTL
+         dw = rearrange(dw, 'B T (C L) -> C B T L', C=self.C)
+         return dw
+
+     def layer_mix(self, hids, dw) -> Tensor:
+         x = tuple([sum(dw[cidx,:,:,j,None] * hids[j] for j in range(self.lidx+2)) for cidx in range(self.C)]) # BTL, LBTD-> BTD
+         return x
+
+ class MUDDFormer(PreTrainedModel):
+     config_class = MUDDFormerConfig
+     '''
+     MUDDFormer's implementation is adapted from https://github.com/pytorch-labs/gpt-fast/blob/main/model.py#L89
+     '''
+     def __init__(self, config: MUDDFormerConfig) -> None:
+         super().__init__(config)
+         self.config = config
+         self.use_gradient_checkpointing = config.use_gradient_checkpointing
+         self.is_training = config.is_training
+
+         self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
+         self.layers = nn.ModuleList(TransformerBlock(config, lidx) for lidx in range(config.n_layer))
+         self.norm = RMSNorm(config.dim, eps=config.norm_eps)
+         self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
+         C = len(self.config.dense_type)
+         self.dense_bs = nn.ParameterList([nn.Parameter(data=torch.randn(C if lidx != config.n_layer-1 else 1, lidx+2)) for lidx in range(config.n_layer)])
+
+         self.layer_cache = None
+         self.use_layer_cache = False if self.is_training else self.config.use_layer_cache
+         self.stack_hidden = self.config.stack_hidden
+
+         self.dynamic = self.config.dynamic_dense
+         self.dense = self.config.dense
+         if self.dynamic:
+             self.dynamic_dense = nn.ModuleList([MultiwayDynamicDenseBlock(config, lidx, last_layer=lidx==config.n_layer-1) for lidx in range(config.n_layer)])
+
+         self.freqs_cis: Optional[Tensor] = None
+         self.mask_cache: Optional[Tensor] = None
+         self.max_batch_size = -1
+         self.max_seq_length = -1
+
+     def tie_weights(self): # placeholder
+         return
+
+     def setup_caches(self, max_batch_size, max_seq_length, dtype=torch.bfloat16):
+         if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
+             return
+         head_dim = self.config.dim // self.config.n_head
+         max_seq_length = find_multiple(max_seq_length, 8)
+         self.max_seq_length = max_seq_length
+         self.max_batch_size = max_batch_size
+         if not self.config.is_training:
+             if self.use_layer_cache:
+                 self.layer_cache = LayerCache(max_batch_size, self.config.n_layer, self.config.dim)
+             for b in self.layers:
+                 b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype=dtype)
+
+         self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base).to(self.tok_embeddings.weight.device)
+         self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool, device=self.tok_embeddings.weight.device))
+
+     def generate(self, input_ids, num_tokens_to_generate=10, compiled_decode_one_token=None):
+         batch_size, seq_length = input_ids.shape
+         input_pos = torch.arange(seq_length, device=self.device)
+         generated_ids = torch.zeros(batch_size, seq_length + num_tokens_to_generate, dtype=torch.int, device=self.device)
+         generated_ids[:, :seq_length] = input_ids.to(self.device).to(torch.int)
+         logits = self.forward(input_ids, input_pos=input_pos, return_tensor=True)
+         _next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
+         next_token = torch.zeros(self.max_batch_size, 1, device=self.device, dtype=torch.int)
+         next_token[:batch_size] = _next_token
+         generated_ids[:, seq_length] = next_token[:batch_size, 0]
+         input_pos = torch.tensor([seq_length], device=self.device)
+         for _ in range(1, num_tokens_to_generate):
+             if compiled_decode_one_token is not None:
+                 next_token = compiled_decode_one_token(self, next_token.clone(), input_pos)
+             else:
+                 next_token = self.decode_one_token(next_token.clone(), input_pos)
+             generated_ids[:, input_pos+1] = next_token.int()[:batch_size]
+             input_pos += 1
+         return generated_ids
+
+     def decode_one_token(self, cur_token, input_pos):
+         logits = self.forward(
+             cur_token,
+             input_pos=input_pos,
+             return_tensor=True
+         )
+         new_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
+         return new_token
+
+     def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None, return_tensor=False) -> Tensor:
+         assert self.freqs_cis is not None, "Caches must be initialized first"
+         if input_pos is None:
+             input_pos = torch.arange(idx.shape[-1], device=idx.device, dtype=torch.int)
+         mask = self.causal_mask[None, None, input_pos]
+         freqs_cis = self.freqs_cis[input_pos]
+         x = self.tok_embeddings(idx)
+         _, seqlen, _ = x.shape
+         use_layer_cache = self.use_layer_cache and seqlen == 1
+         if use_layer_cache:
+             self.layer_cache.update(x, 0)
+         else:
+             hiddens = [x]
+         for i, layer in enumerate(self.layers):
+             if self.use_gradient_checkpointing:
+                 x = checkpoint(layer, x, input_pos, freqs_cis, mask)
+             else:
+                 x = layer(x, input_pos, freqs_cis, mask)
+             if use_layer_cache:
+                 _hidden = self.layer_cache.update(x, i+1) # LBTD
+             else:
+                 hiddens.append(x)
+                 _hidden = hiddens if not self.stack_hidden else hiddens
+             if self.dynamic and self.dense:
+                 dw = self.dynamic_dense[i](x) # BTD -> CBTL
+                 dw = dw + self.dense_bs[i][:,None,None,:] # CBTL
+                 if self.stack_hidden:
+                     x = torch.einsum('LBTD, CBTL -> CBTD', _hidden, dw)
+                 else:
+                     x = self.dynamic_dense[i].layer_mix(_hidden, dw)
+
+         if self.config.dense_type == 'qkvr' and self.config.dense and self.config.dynamic_dense:
+             x = x[0]
+         x = self.norm(x)
+         logits = self.output(x)
+         if return_tensor:
+             return logits
+         else:
+             CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
+             return CausalLMOutput(logits=logits)
+
+ class TransformerBlock(nn.Module):
+     def __init__(self, config: MUDDFormerConfig, lidx) -> None:
+         super().__init__()
+         self.lidx = lidx
+         self.config = config
+         self.attention = Attention(config, lidx)
+         self.feed_forward = FeedForward(config, lidx)
+         self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
+         if self.config.sepln and self.lidx > 0:
+             self.attention_norms = torch.nn.ModuleList([RMSNorm(config.dim, config.norm_eps) for _ in range(3)])
+         else:
+             self.attention_norm = RMSNorm(config.dim, config.norm_eps)
+
+     def forward(self, x: Union[Tuple[Tensor], Tensor], input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor:
+         if self.lidx == 0 or self.config.dense_type == 'l' or not self.config.dense:
+             res = x
+             normed_x = self.attention_norm(x)
+         elif self.config.dense_type == 'qkvr':
+             res = x[-1] # for mlp
+             if self.config.stack_hidden or not self.config.sepln:
+                 normed_x = self.attention_norm(x[:3])
+             else:
+                 normed_x = tuple([norm_fn(_x) for norm_fn, _x in zip(self.attention_norms, x[:3])])
+         attn_out = self.attention(normed_x, freqs_cis, mask, input_pos)
+         h = res + attn_out
+         out = h + self.feed_forward(self.ffn_norm(h))
+         return out
+
+ class Attention(nn.Module):
+     def __init__(self, config: MUDDFormerConfig, lidx):
+         super().__init__()
+         assert config.dim % config.n_head == 0
+         total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
+         self.config = config
+         if self.config.dense_type == 'l' or not self.config.dense:
+             self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
+         elif self.config.dense_type == 'qkvr':
+             self.wq = nn.Linear(config.dim, config.n_head * config.head_dim, bias=False)
+             self.wk = nn.Linear(config.dim, config.n_local_heads * config.head_dim, bias=False)
+             self.wv = nn.Linear(config.dim, config.n_local_heads * config.head_dim, bias=False)
+
+         self.wo = nn.Linear(config.dim, config.dim, bias=False)
+         self.lidx = lidx
+         self.kv_cache = None
+
+         self.n_head = config.n_head
+         self.head_dim = config.head_dim
+         self.scale_factor = 1 / math.sqrt(self.head_dim)
+         self.n_local_heads = config.n_local_heads
+         self.dim = config.dim
+
+         self.use_qk_norm = config.use_qk_norm
+         if self.use_qk_norm:
+             self.q_norm = RMSNorm(self.head_dim, config.norm_eps)
+             self.k_norm = RMSNorm(self.head_dim, config.norm_eps)
+
+         self._register_load_state_dict_pre_hook(self.load_hook)
+
+     def load_hook(self, state_dict, prefix, *args):
+         if prefix + "wq.weight" in state_dict and (self.config.dense_type == 'l' or not self.config.dense):
+             wq = state_dict.pop(prefix + "wq.weight")
+             wk = state_dict.pop(prefix + "wk.weight")
+             wv = state_dict.pop(prefix + "wv.weight")
+             state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
+
+     def forward(self, x: Union[Tuple[Tensor], Tensor], freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
+         if self.lidx == 0 or self.config.dense_type == 'l' or not self.config.dense:
+             bsz, seqlen, _ = x.shape
+         else:
+             if self.config.stack_hidden:
+                 C, bsz, seqlen, _ = x.shape
+             else:
+                 C, (bsz, seqlen, _) = len(x), x[0].shape
+         kv_size = self.n_local_heads * self.head_dim
+
+         if self.config.dense_type == 'l' or not self.config.dense:
+             q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
+
+             q = q.view(bsz, seqlen, self.n_head, self.head_dim)
+             k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+             v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+         elif self.config.dense_type == 'qkvr':
+             if self.lidx == 0:
+                 xq, xk, xv = x, x, x
+             else:
+                 xq, xk, xv = x[0], x[1], x[2]
+             q = self.wq(xq).view(bsz, seqlen, self.n_head, self.head_dim)
+             k = self.wk(xk).view(bsz, seqlen, self.n_local_heads, self.head_dim)
+             v = self.wv(xv).view(bsz, seqlen, self.n_local_heads, self.head_dim)
+
+         if self.use_qk_norm:
+             q, k = self.q_norm(q), self.k_norm(k)
+
+         q = apply_rotary_emb(q, freqs_cis)
+         k = apply_rotary_emb(k, freqs_cis)
+
+         q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
+
+         if self.kv_cache is not None:
+             if seqlen == 1:
+                 k, v = self.kv_cache.update(input_pos, k, v)
+             else:
+                 _, _ = self.kv_cache.update(input_pos, k, v)
+
+         if seqlen == 1: # one-token generation
+             k_mask = mask[:,:,:,:self.kv_cache.seq_length]
+         else: # prefill
+             k_mask = mask[:,:,:,:k.shape[-2]]
+
+         logits = q @ k.transpose(-2, -1) * self.scale_factor
+         dtype = logits.dtype
+         min_value = torch.finfo(torch.float32).min
+         logits = logits.to(dtype=torch.float32)
+         logits = torch.where(k_mask, logits, min_value)
+         probs = logits.softmax(-1)
+         probs = probs.to(dtype=dtype)
+         y = probs @ v
+
+         y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
+
+         y = self.wo(y)
+         return y
+
+ class FeedForward(nn.Module):
+     def __init__(self, config: MUDDFormerConfig, lidx, round128=True, scale_with_layer=True) -> None:
+         super().__init__()
+         hid_dim = config.intermediate_size
+         if scale_with_layer:
+             hid_dim = hid_dim * (lidx / (config.n_layer - 1) + 0.5)
+         if round128:
+             hid_dim = round(hid_dim / 128) * 128
+         self.w1 = nn.Linear(config.dim, hid_dim, bias=False)
+         self.w3 = nn.Linear(config.dim, hid_dim, bias=False)
+         self.w2 = nn.Linear(hid_dim, config.dim, bias=False)
+
+     def forward(self, x: Tensor) -> Tensor:
+         return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+ class RMSNorm(nn.Module):
+     def __init__(self, dim: int, eps: float = 1e-5):
+         super().__init__()
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(dim))
+
+     def _norm(self, x):
+         return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
+
+     def forward(self, x: Tensor) -> Tensor:
+         output = self._norm(x.float()).type_as(x)
+         return output * self.weight
+
+ class RMSnormNoscale(nn.Module):
+
+     def __init__(self, epsilon=1e-6, dim=-1):
+         super().__init__()
+         self.dim = dim
+         self.epsilon = epsilon
+
+     def forward(self, inputs):
+         var = inputs.pow(2).mean(dim=self.dim, keepdim=True)
+         normed_inputs = inputs * torch.rsqrt(var + self.epsilon)
+         return normed_inputs
+
+ def precompute_freqs_cis(
+     seq_len: int, n_elem: int, base: int = 10000
+ ) -> Tensor:
+     freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
+     t = torch.arange(seq_len, device=freqs.device)
+     freqs = torch.outer(t, freqs)
+     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
+     cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
+     return cache.to(dtype=torch.bfloat16)
+
+ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor, mode='half') -> Tensor:
+     if mode == 'half':
+         xshaped = x.float().reshape(*x.shape[:-1], 2, -1).transpose(-1, -2)
+     elif mode == 'alternative':
+         xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
+     freqs_cis = freqs_cis.view(-1, xshaped.size(1), 1, xshaped.size(3), 2)
+     x_out2 = torch.stack(
+         [
+             xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
+             xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
+         ],
+         -1,
+     )
+     x_out2 = x_out2.flatten(3)
+     return x_out2.type_as(x)
+
+ def match_weight_muddformer(model, w, strict=False):
+     map_dict = {'wq': 'query', 'wk': 'key', 'wv': 'value', 'wo': 'post', 'w1': 'ffn_layer1_gate', 'w3': 'ffn_layer1', 'w2': 'ffn_layer2',
+                 'weight': 'w'}
+     E, H, D = model.config.dim, model.config.n_head, model.config.head_dim
+     N = model.config.vocab_size
+     state_dict = {}
+     for k, v in model.named_parameters():
+         if k == 'tok_embeddings.weight':
+             v = w['state.mdl_vars.params.lm.embedding_lookup.emb_var']#[:50257,:]
+         elif k == 'norm.weight':
+             v = w['state.mdl_vars.params.lm.final_ln.scale']
+         elif k == 'output.weight':
+             v = w['state.mdl_vars.params.lm.softmax.logits_ffn.linear.w'].T#[:50257,:] # E,N -> N,E
+         elif 'dense_bs' in k: # static dense w
+             lidx = int(k.split('.')[-1])
+             v = w[f'state.mdl_vars.params.lm.transformer.dense_conn_{lidx}']
+         elif 'dynamic_dense' in k:
+             lidx = int(k.split('.')[1])
+             widx = int(k.split('.')[2][-1]) # 1 or 2 in w1, w2
+             v = w[f'state.mdl_vars.params.lm.transformer.x_layers_{lidx}.dynamic_dense_conn{widx}_{lidx}'].T
+         else:
+             assert 'layers' in k
+             lidx = int(k.split('.')[1])
+             if '.attention.' in k:
+                 _, _, _, ptype, wtype = k.split('.')
+                 if ptype in ['wq', 'wk', 'wv', 'wo']:
+                     v = w[f'state.mdl_vars.params.lm.transformer.x_layers_{lidx}.self_attention.{map_dict.get(ptype, ptype)}.{map_dict.get(wtype, wtype)}'].reshape(E, E)
+                     if ptype != 'wo':
+                         v = v.T
+                 elif ptype in ['q_norm', 'k_norm']:
+                     v = w[f'state.mdl_vars.params.lm.transformer.x_layers_{lidx}.self_attention.{map_dict.get(ptype, ptype)}.scale']
+             elif 'feed_forward' in k:
+                 ptype = k.split('.')[3] # w1, w3, w2
+                 v = w[f'state.mdl_vars.params.lm.transformer.x_layers_{lidx}.ff_layer.{map_dict[ptype]}.linear.w'].T
+             elif 'ffn_norm' in k: # mlp layernorm
+                 v = w[f'state.mdl_vars.params.lm.transformer.x_layers_{lidx}.ff_layer.layer_norm.scale']
+             elif 'attention_norm' in k: # attention layernorm
+                 if 'attention_norms' in k:
+                     ln_idx = int(k.split('.')[3])
+                     v = w[f'state.mdl_vars.params.lm.transformer.x_layers_{lidx}.layer_norms_{ln_idx}.scale']
+                 else:
+                     v = w[f'state.mdl_vars.params.lm.transformer.x_layers_{lidx}.layer_norm.scale']
+         state_dict[k] = torch.tensor(v)
+     model.load_state_dict(state_dict, strict=strict)
+     return model
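The dense connections are easiest to see in isolation: for layer lidx, MultiwayDynamicDenseBlock emits C = len('qkvr') = 4 weight vectors over the lidx + 2 hidden states collected so far, and layer_mix turns those weights into one mixed input per stream (q, k, v and the residual). A small shape-check sketch, not part of the commit (it assumes the two files above are importable locally and that einops, which modeling_muddformer.py imports but requirements.txt does not pin, is installed):

import torch
from configuration_muddformer import MUDDFormerConfig
from modeling_muddformer import MultiwayDynamicDenseBlock

cfg = MUDDFormerConfig(n_layer=4, dim=64, n_head=4)    # tiny config just for the check
blk = MultiwayDynamicDenseBlock(cfg, lidx=1)           # C = 4 streams, L = lidx + 2 = 3 hiddens
x = torch.randn(2, 5, 64)                              # (batch, time, dim): output of layer 1
dw = blk(x)                                            # dynamic weights, shape (C, B, T, L) = (4, 2, 5, 3)
hiddens = [torch.randn(2, 5, 64) for _ in range(3)]    # embeddings + outputs of layers 0 and 1
mixed = blk.layer_mix(hiddens, dw)                     # tuple of 4 tensors, each (2, 5, 64)
print(dw.shape, len(mixed), mixed[0].shape)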
requirements.txt ADDED
@@ -0,0 +1,2 @@
+ torch==2.5.1
+ transformers==4.35.0
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,9 @@
+ {
+   "add_prefix_space": false,
+   "bos_token": "<|endoftext|>",
+   "clean_up_tokenization_spaces": true,
+   "eos_token": "<|endoftext|>",
+   "model_max_length": 1000000000000000019884624838656,
+   "tokenizer_class": "GPTNeoXTokenizer",
+   "unk_token": "<|endoftext|>"
+ }