Hilbertmeng committed
Commit fbde94d · 1 Parent(s): 7907b9c

add model code
README.md CHANGED
@@ -1,3 +1,75 @@
  ---
+ language:
+ - en
+ tags:
+ - pytorch
+ - causal-lm
+ - muddformer
  license: mit
  ---
+ MUDDPythia-2.8B is a language model pretrained on the Pile with 300B tokens, designed as a direct comparison to Pythia-2.8B. It uses a simple yet effective method, Multiway Dynamic Dense (MUDD) connections, to address the limitations of residual connections and enhance cross-layer information flow in Transformers. Please see downstream evaluations and more details in the paper [MUDDFormer: Breaking Residual Bottlenecks in Transformers via Multiway Dynamic Dense Connections](https://arxiv.org). In addition, we open-source the JAX training code on [GitHub](https://github.com/Caiyun-AI/MUDDFormer/).
+
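+ As a rough illustration of the idea, here is a minimal PyTorch sketch (illustrative only; the class and attribute names below are invented for this example). The actual implementation is `MultiwayDynamicDenseBlock` in `modeling_muddpythia.py` in this repo, which additionally applies a pre-normalization and learned static biases. After each layer, a small MLP predicts per-token weights that remix the hidden states of the embedding and all layers so far into separate inputs for the next layer's query, key, value and residual streams.
+
+ ```
+ import torch
+ import torch.nn as nn
+
+ class DynamicDenseSketch(nn.Module):
+     """Remix the hiddens of the embedding and layers 0..lidx into C streams (e.g. Q/K/V/residual)."""
+     def __init__(self, dim, lidx, n_streams=4):
+         super().__init__()
+         self.n_streams, self.n_hiddens = n_streams, lidx + 2  # embedding + layers 0..lidx
+         width = self.n_streams * self.n_hiddens
+         self.to_weights = nn.Sequential(
+             nn.Linear(dim, width, bias=False),
+             nn.GELU(),
+             nn.Linear(width, width, bias=False),
+         )
+
+     def forward(self, x, hiddens):
+         # x: [B, T, D] current layer output; hiddens: list of lidx+2 tensors, each [B, T, D]
+         w = self.to_weights(x).view(*x.shape[:-1], self.n_streams, self.n_hiddens)  # [B, T, C, L]
+         stacked = torch.stack(hiddens, dim=-1)                                      # [B, T, D, L]
+         return torch.einsum('btdl,btcl->cbtd', stacked, w)  # one mixed input per stream
+ ```
+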
+ We recommend the <strong>compiled version</strong> of MUDDPythia with *torch.compile* for inference acceleration. Please refer to the Generation section for the compile implementation.
+
+ # Usage
+
+ ## Env
+
+ ```
+ pip install transformers==4.35.0 torch==2.5.1
+ ```
+
+ ## Generation
+
+ ```
+ import os
+ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
+
+ import time
+
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ device = torch.device('cuda:0')
+ dtype = torch.bfloat16
+ MAX_BATCH_SIZE = 1
+ MAX_SEQ_LENGTH = 2048
+ NUM_TOKENS_TO_GENERATE = 10
+ COMPILE = True
+ OPTIMIZED_COMPILE = False
+
+ if OPTIMIZED_COMPILE:
+     # optional dynamo/inductor tuning for faster compiled decoding
+     import torch._dynamo.config
+     import torch._inductor.config
+     torch._dynamo.config.cache_size_limit = 64
+     torch._inductor.config.coordinate_descent_tuning = True
+     torch._inductor.config.triton.unique_kernel_names = True
+     torch._inductor.config.fx_graph_cache = True
+
+ tokenizer = AutoTokenizer.from_pretrained("Caiyun-AI/MUDDPythia-2.8B")
+ model = AutoModelForCausalLM.from_pretrained("Caiyun-AI/MUDDPythia-2.8B", trust_remote_code=True)
+
+ _ = model.to(device=device, dtype=dtype)
+ with torch.device(device):
+     # allocate the static KV caches and layer cache used during decoding
+     model.setup_caches(max_batch_size=MAX_BATCH_SIZE, max_seq_length=MAX_SEQ_LENGTH, dtype=dtype)
+
+ def decode_one_token(model, cur_token, input_pos):
+     # greedy single-token decoding step (the function that gets compiled)
+     logits = model(cur_token, input_pos=input_pos, return_tensor=True)
+     new_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
+     return new_token
+
+ prompt = "Beijing is the capital of China. London is the capital of"
+ input_ids = tokenizer.encode(prompt, return_tensors='pt')
+
+ compiled_decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True) if COMPILE else None
+
+ print('Start generating tokens; the first iteration takes a few minutes to compile.')
+ for i in range(10):
+     t0 = time.time()
+     with torch.no_grad():
+         generated_ids = model.generate(input_ids.to(device), num_tokens_to_generate=NUM_TOKENS_TO_GENERATE, compiled_decode_one_token=compiled_decode_one_token)
+         text = tokenizer.decode(generated_ids[0])
+         if i == 0:
+             print(f'Generated text: {text}')
+     t1 = time.time()
+     print(f'Time consumed at iteration {i}: {t1 - t0}s')
+ ```
config.json ADDED
@@ -0,0 +1,39 @@
+ {
+   "architectures": [
+     "MUDDPythia"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_muddpythia.MUDDPythiaConfig",
+     "AutoModelForCausalLM": "modeling_muddpythia.MUDDPythia"
+   },
+   "block_size": 2048,
+   "bos_token_id": 0,
+   "dense": true,
+   "dense_type": "qkvr",
+   "dim": 2560,
+   "dynamic_dense": true,
+   "eos_token_id": 0,
+   "expand_last": true,
+   "head_dim": 80,
+   "intermediate_size": 10240,
+   "is_training": false,
+   "model_type": "muddpythia",
+   "n_head": 32,
+   "n_layer": 32,
+   "n_local_heads": 32,
+   "norm_eps": 1e-05,
+   "rope_base": 10000,
+   "rotary_pct": 0.25,
+   "round64": true,
+   "sepln": true,
+   "stack_hidden": false,
+   "tie_word_embeddings": false,
+   "torch_dtype": "bfloat16",
+   "transformers_version": "4.35.0",
+   "use_gradient_checkpointing": false,
+   "use_layer_cache": true,
+   "use_linear_bias": true,
+   "use_parallel_residual": true,
+   "use_qk_norm": true,
+   "vocab_size": 50432
+ }
configuration_muddpythia.py ADDED
@@ -0,0 +1,81 @@
+ from transformers.configuration_utils import PretrainedConfig
+ from typing import Optional
+
+
+ class MUDDPythiaConfig(PretrainedConfig):
+     '''
+     MUDDPythiaConfig is a config class for MUDDPythia, adapted from
+     https://github.com/pytorch-labs/gpt-fast/blob/main/model.py#L21
+     '''
+     model_type = "muddpythia"
+
+     def __init__(
+         self,
+         block_size: int = 2048,
+         vocab_size: int = 32000,
+         n_layer: int = 32,
+         n_head: int = 32,
+         dim: int = 2560,
+         intermediate_size: int = None,
+         n_local_heads: int = -1,
+         head_dim: int = 64,
+         rope_base: float = 10000,
+         norm_eps: float = 1e-5,
+         use_gradient_checkpointing: bool = False,
+         is_training: bool = False,
+         use_qk_norm: bool = False,
+         pad_token_id: Optional[int] = None,
+         use_parallel_residual: bool = True,
+         use_linear_bias: bool = True,
+         rotary_pct: float = 0.25,
+         bos_token_id: int = 1,
+         eos_token_id: int = 2,
+         tie_word_embeddings: bool = False,
+         use_layer_cache: bool = True,
+         stack_hidden: bool = False,
+         dense: bool = True,
+         dynamic_dense: bool = True,
+         sepln: bool = True,
+         dense_type: str = 'qkvr',
+         expand_last: bool = False,
+         round64: bool = False,
+         **kwargs,
+     ):
+         self.block_size = block_size
+         self.vocab_size = vocab_size
+         self.n_layer = n_layer
+         self.n_head = n_head
+         self.dim = dim
+         self.intermediate_size = intermediate_size
+         self.n_local_heads = n_local_heads
+         self.head_dim = head_dim
+         self.rope_base = rope_base
+         self.norm_eps = norm_eps
+         self.use_gradient_checkpointing = use_gradient_checkpointing
+         self.is_training = is_training
+         self.use_qk_norm = use_qk_norm
+         self.use_parallel_residual = use_parallel_residual
+         self.use_linear_bias = use_linear_bias
+         self.rotary_pct = rotary_pct
+
+         self.use_layer_cache = use_layer_cache
+         self.stack_hidden = stack_hidden
+         self.dense = dense
+         self.dynamic_dense = dynamic_dense
+         self.sepln = sepln
+         self.dense_type = dense_type
+         self.expand_last = expand_last
+         self.round64 = round64
+         # post init: derive fields that depend on other fields
+         if self.n_local_heads == -1:
+             self.n_local_heads = self.n_head
+         if self.intermediate_size is None:
+             self.intermediate_size = 4 * self.dim
+         self.head_dim = self.dim // self.n_head
+
+         super().__init__(
+             pad_token_id=pad_token_id,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             tie_word_embeddings=tie_word_embeddings,
+             **kwargs,
+         )
generation_demo.py ADDED
@@ -0,0 +1,50 @@
+ import os
+ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
+
+ import time
+
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ device = torch.device('cuda:0')
+ dtype = torch.bfloat16
+ MAX_BATCH_SIZE = 1
+ MAX_SEQ_LENGTH = 2048
+ NUM_TOKENS_TO_GENERATE = 10
+ COMPILE = True
+ OPTIMIZED_COMPILE = False
+
+ if OPTIMIZED_COMPILE:
+     # optional dynamo/inductor tuning for faster compiled decoding
+     import torch._dynamo.config
+     import torch._inductor.config
+     torch._dynamo.config.cache_size_limit = 64
+     torch._inductor.config.coordinate_descent_tuning = True
+     torch._inductor.config.triton.unique_kernel_names = True
+     torch._inductor.config.fx_graph_cache = True
+
+ tokenizer = AutoTokenizer.from_pretrained("Caiyun-AI/MUDDPythia-2.8B")
+ model = AutoModelForCausalLM.from_pretrained("Caiyun-AI/MUDDPythia-2.8B", trust_remote_code=True)
+
+ _ = model.to(device=device, dtype=dtype)
+ with torch.device(device):
+     # allocate the static KV caches and layer cache used during decoding
+     model.setup_caches(max_batch_size=MAX_BATCH_SIZE, max_seq_length=MAX_SEQ_LENGTH, dtype=dtype)
+
+ def decode_one_token(model, cur_token, input_pos):
+     # greedy single-token decoding step (the function that gets compiled)
+     logits = model(cur_token, input_pos=input_pos, return_tensor=True)
+     new_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
+     return new_token
+
+ prompt = "Beijing is the capital of China. London is the capital of"
+ input_ids = tokenizer.encode(prompt, return_tensors='pt')
+
+ compiled_decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True) if COMPILE else None
+
+ print('Start generating tokens; the first iteration takes a few minutes to compile.')
+ for i in range(10):
+     t0 = time.time()
+     with torch.no_grad():
+         generated_ids = model.generate(input_ids.to(device), num_tokens_to_generate=NUM_TOKENS_TO_GENERATE, compiled_decode_one_token=compiled_decode_one_token)
+         text = tokenizer.decode(generated_ids[0])
+         if i == 0:
+             print(f'Generated text: {text}')
+     t1 = time.time()
+     print(f'Time consumed at iteration {i}: {t1 - t0}s')
modeling_muddpythia.py ADDED
@@ -0,0 +1,454 @@
+ from typing import Optional, Tuple, Union
+ from collections import namedtuple
+ from einops import rearrange
+
+ import math
+ import torch
+ import torch.nn as nn
+ from torch import Tensor
+ from torch.nn import functional as F
+ from torch.utils.checkpoint import checkpoint
+
+ try:
+     from .configuration_muddpythia import MUDDPythiaConfig
+ except ImportError:
+     from configuration_muddpythia import MUDDPythiaConfig
+
+ from transformers.modeling_utils import PreTrainedModel
+
+
+ class KVCache(nn.Module):
+     # static key/value cache, updated in place at the given positions
+     def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.float16):
+         super().__init__()
+         self.seq_length = max_seq_length
+         cache_shape = (max_batch_size, n_heads, self.seq_length, head_dim)
+         self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype))
+         self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype))
+
+     def update(self, input_pos, k_val, v_val):
+         # input_pos: [S], k_val: [B, H, S, D]
+         assert input_pos.shape[0] == k_val.shape[2]
+         k_out = self.k_cache
+         v_out = self.v_cache
+         k_out[:, :, input_pos] = k_val
+         v_out[:, :, input_pos] = v_val
+         return k_out, v_out
+
+ class LayerCache(nn.Module):
+     # caches the per-layer hidden states during single-token decoding
+     def __init__(self, max_batch_size, num_layers, model_dim, dtype=torch.float16):
+         super().__init__()
+         cache_shape = (num_layers + 1, max_batch_size, 1, model_dim)  # LBTD
+         self.register_buffer('layer_cache', torch.zeros(cache_shape, dtype=dtype))
+
+     def update(self, x, lidx):
+         self.layer_cache[lidx] = x
+         return self.layer_cache[:lidx + 1]
+
+ class MultiwayDynamicDenseBlock(nn.Module):
+     # predicts per-token mixing weights over all previous layers' hiddens, one set per stream (q/k/v/r)
+     def __init__(self, config: MUDDPythiaConfig, lidx: int, last_layer=False) -> None:
+         super().__init__()
+         self.norm = RMSnormNoscale(epsilon=config.norm_eps)
+         self.C = len(config.dense_type) if not last_layer else 1
+         self.lidx = lidx
+         l = lidx + 2
+         hid_dim, out_dim = l * self.C, l * self.C
+         if last_layer and config.expand_last: hid_dim *= 4
+         if config.round64: hid_dim = (hid_dim // 64 + 1) * 64
+         self.w1 = nn.Linear(config.dim, hid_dim, bias=False)
+         self.act = nn.GELU()
+         self.w2 = nn.Linear(hid_dim, out_dim, bias=False)
+
+     def forward(self, x: Tensor) -> Tensor:
+         x = self.norm(x)
+         dw = self.w2(self.act(self.w1(x)))  # BTD -> BTL
+         dw = rearrange(dw, 'B T (C L) -> C B T L', C=self.C)
+         return dw
+
+     def layer_mix(self, hids, dw) -> Tensor:
+         x = tuple([sum(dw[cidx, :, :, j, None] * hids[j] for j in range(self.lidx + 2)) for cidx in range(self.C)])  # BTL, LBTD -> BTD
+         return x
+
+ class MUDDPythia(PreTrainedModel):
+     config_class = MUDDPythiaConfig
+
+     def __init__(self, config: MUDDPythiaConfig) -> None:
+         super().__init__(config)
+         self.config = config
+         self.use_gradient_checkpointing = config.use_gradient_checkpointing
+         self.is_training = config.is_training
+         self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
+         self.layers = nn.ModuleList(TransformerBlock(config, lidx) for lidx in range(config.n_layer))
+         self.norm = nn.LayerNorm(config.dim, eps=config.norm_eps)
+         self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
+         C = len(self.config.dense_type)
+         # static (learned) dense-connection biases, one per layer
+         self.dense_bs = nn.ParameterList([nn.Parameter(data=torch.randn(C if lidx != config.n_layer - 1 else 1, lidx + 2)) for lidx in range(config.n_layer)])
+
+         self.layer_cache = None
+         self.use_layer_cache = False if self.is_training else self.config.use_layer_cache
+         self.stack_hidden = self.config.stack_hidden
+         self.dynamic = self.config.dynamic_dense
+         self.dense = self.config.dense
+         if self.dynamic:
+             self.dynamic_dense = nn.ModuleList([MultiwayDynamicDenseBlock(config, lidx, last_layer=lidx == config.n_layer - 1) for lidx in range(config.n_layer)])
+
+         self.rotary_ndims = int(config.head_dim * config.rotary_pct)
+         self.freqs_cis: Optional[Tensor] = None
+         self.mask_cache: Optional[Tensor] = None
+         self.max_batch_size = -1
+         self.max_seq_length = -1
+
+     def tie_weights(self):  # placeholder
+         return
+
+     def setup_caches(self, max_batch_size, max_seq_length, dtype=torch.float16):
+         if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
+             return
+         head_dim = self.config.dim // self.config.n_head
+         max_seq_length = find_multiple(max_seq_length, 8)
+         self.max_seq_length = max_seq_length
+         self.max_batch_size = max_batch_size
+         if not self.config.is_training:
+             if self.use_layer_cache:
+                 self.layer_cache = LayerCache(max_batch_size, self.config.n_layer, self.config.dim, dtype=dtype)
+             for b in self.layers:
+                 b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype=dtype)
+
+         self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.rotary_ndims, self.config.rope_base).to(self.tok_embeddings.weight.device)
+         self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool, device=self.tok_embeddings.weight.device))
+
+     def generate(self, input_ids, num_tokens_to_generate=10, compiled_decode_one_token=None):
+         batch_size, seq_length = input_ids.shape
+         input_pos = torch.arange(seq_length, device=self.device)
+         generated_ids = torch.zeros(batch_size, seq_length + num_tokens_to_generate, dtype=torch.int, device=self.device)
+         generated_ids[:, :seq_length] = input_ids.to(self.device).to(torch.int)
+         logits = self.forward(input_ids, input_pos=input_pos, return_tensor=True)
+         _next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
+         next_token = torch.zeros(self.max_batch_size, 1, device=self.device, dtype=torch.int)
+         next_token[:batch_size] = _next_token
+         generated_ids[:, seq_length] = next_token[:batch_size, 0]
+         input_pos = torch.tensor([seq_length], device=self.device)
+         for _ in range(1, num_tokens_to_generate):
+             if compiled_decode_one_token is not None:
+                 next_token = compiled_decode_one_token(self, next_token.clone(), input_pos)
+             else:
+                 next_token = self.decode_one_token(next_token.clone(), input_pos)
+             generated_ids[:, input_pos + 1] = next_token.int()[:batch_size]
+             input_pos += 1
+         return generated_ids
+
+     def decode_one_token(self, cur_token, input_pos):
+         logits = self.forward(
+             cur_token,
+             input_pos=input_pos,
+             return_tensor=True
+         )
+         new_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
+         return new_token
+
+     def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None, return_tensor=False) -> Tensor:
+         assert self.freqs_cis is not None, "Caches must be initialized first"
+         if input_pos is None:
+             input_pos = torch.arange(idx.shape[-1], device=idx.device, dtype=torch.int)
+         mask = self.causal_mask[None, None, input_pos]
+         freqs_cis = self.freqs_cis[input_pos]
+         x = self.tok_embeddings(idx)
+         _, seqlen, _ = x.shape
+         use_layer_cache = self.use_layer_cache and seqlen == 1
+         if use_layer_cache:
+             self.layer_cache.update(x, 0)
+         else:
+             hiddens = [x]
+         for i, layer in enumerate(self.layers):
+             if self.use_gradient_checkpointing:
+                 x = checkpoint(layer, x, input_pos, freqs_cis, mask)
+             else:
+                 x = layer(x, input_pos, freqs_cis, mask)
+             if use_layer_cache:
+                 _hidden = self.layer_cache.update(x, i + 1)  # LBTD
+             else:
+                 hiddens.append(x)
+                 _hidden = hiddens if not self.stack_hidden else hiddens
+             if self.dynamic and self.dense:
+                 dw = self.dynamic_dense[i](x)  # BTD -> CBTL
+                 dw = dw + self.dense_bs[i][:, None, None, :]  # CBTL
+                 if self.stack_hidden:
+                     x = torch.einsum('LBTD, CBTL -> CBTD', _hidden, dw)
+                 else:
+                     x = self.dynamic_dense[i].layer_mix(_hidden, dw)
+         if self.config.dense_type == 'qkvr' and self.config.dense and self.config.dynamic_dense:
+             x = x[0]
+         x = self.norm(x)
+         logits = self.output(x)
+         if return_tensor:
+             return logits
+         else:
+             CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
+             return CausalLMOutput(logits=logits)
+
+ class TransformerBlock(nn.Module):
+     def __init__(self, config: MUDDPythiaConfig, lidx) -> None:
+         super().__init__()
+         self.lidx = lidx
+         self.config = config
+         self.attention = Attention(config, lidx)
+
+         self.feed_forward = FeedForward(config, lidx)
+         self.ffn_norm = nn.LayerNorm(config.dim, eps=config.norm_eps)
+         self.use_parallel_residual = config.use_parallel_residual
+
+         if self.config.sepln and self.lidx > 0:
+             self.attention_norms = torch.nn.ModuleList([nn.LayerNorm(config.dim, eps=config.norm_eps) for _ in range(3)])
+         else:
+             self.attention_norm = nn.LayerNorm(config.dim, eps=config.norm_eps)
+
+     def forward(self, x: Union[Tuple[Tensor], Tensor], input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor:
+         if self.config.dense_type == 'l' or self.lidx == 0 or not self.config.dense:
+             res = x
+             normed_x = self.attention_norm(x)
+         elif self.config.dense_type == 'qkvr':
+             res = x[-1]  # for mlp
+             if self.config.stack_hidden or not self.config.sepln:
+                 normed_x = self.attention_norm(x[:3])
+             else:
+                 normed_x = tuple([norm_fn(_x) for norm_fn, _x in zip(self.attention_norms, x[:3])])
+         h = res + self.attention(normed_x, freqs_cis, mask, input_pos)
+         out = h + self.feed_forward(self.ffn_norm(res if self.use_parallel_residual else h))
+         return out
+
+ class Attention(nn.Module):
+     def __init__(self, config: MUDDPythiaConfig, lidx):
+         super().__init__()
+         assert config.dim % config.n_head == 0
+
+         total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
+         # key, query, value projections for all heads, but in a batch
+         self.config = config
+         if self.config.dense_type == 'l' or not self.config.dense:
+             self.wqkv = nn.Linear(config.dim, total_head_dim, bias=True)
+         elif self.config.dense_type == 'qkvr':
+             self.wq = nn.Linear(config.dim, config.n_head * config.head_dim, bias=True)
+             self.wk = nn.Linear(config.dim, config.n_local_heads * config.head_dim, bias=True)
+             self.wv = nn.Linear(config.dim, config.n_local_heads * config.head_dim, bias=True)
+
+         self.wo = nn.Linear(config.dim, config.dim, bias=True)
+         self.lidx = lidx
+         self.kv_cache = None
+
+         self.n_head = config.n_head
+         self.head_dim = config.head_dim
+         self.scale_factor = 1 / math.sqrt(self.head_dim)
+         self.n_local_heads = config.n_local_heads
+         self.dim = config.dim
+         self.use_qk_norm = config.use_qk_norm
+         if self.use_qk_norm:
+             self.q_norm = RMSNorm(self.head_dim, config.norm_eps)
+             self.k_norm = RMSNorm(self.head_dim, config.norm_eps)
+
+         self.rotary_ndims = int(self.head_dim * config.rotary_pct)
+
+         self._register_load_state_dict_pre_hook(self.load_hook)
+
+     def load_hook(self, state_dict, prefix, *args):
+         if prefix + "wq.weight" in state_dict and (self.config.dense_type == 'l' or not self.config.dense):
+             wq = state_dict.pop(prefix + "wq.weight")
+             wk = state_dict.pop(prefix + "wk.weight")
+             wv = state_dict.pop(prefix + "wv.weight")
+             state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
+
+     def forward(self, x: Union[Tuple[Tensor], Tensor], freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
+         if self.lidx == 0 or self.config.dense_type == 'l' or not self.config.dense:
+             bsz, seqlen, _ = x.shape
+         else:
+             if self.config.stack_hidden:
+                 C, bsz, seqlen, _ = x.shape
+             else:
+                 C, (bsz, seqlen, _) = len(x), x[0].shape
+         kv_size = self.n_local_heads * self.head_dim
+
+         if self.config.dense_type == 'l' or not self.config.dense:
+             q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
+
+             q = q.view(bsz, seqlen, self.n_head, self.head_dim)
+             k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+             v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+         elif self.config.dense_type == 'qkvr':
+             if self.lidx == 0:
+                 xq, xk, xv = x, x, x
+             else:
+                 xq, xk, xv = x[0], x[1], x[2]
+             q = self.wq(xq).view(bsz, seqlen, self.n_head, self.head_dim)
+             k = self.wk(xk).view(bsz, seqlen, self.n_local_heads, self.head_dim)
+             v = self.wv(xv).view(bsz, seqlen, self.n_local_heads, self.head_dim)
+
+         if self.use_qk_norm:
+             q, k = self.q_norm(q), self.k_norm(k)
+
+         if self.rotary_ndims == self.head_dim:
+             q = apply_rotary_emb(q, freqs_cis)  # BTND
+             k = apply_rotary_emb(k, freqs_cis)
+         else:
+             q_rot = q[..., :self.rotary_ndims]
+             q_pass = q[..., self.rotary_ndims:]
+             k_rot = k[..., :self.rotary_ndims]
+             k_pass = k[..., self.rotary_ndims:]
+             q_rot = apply_rotary_emb(q_rot, freqs_cis, mode='half')  # BTND
+             k_rot = apply_rotary_emb(k_rot, freqs_cis, mode='half')
+             q = torch.cat((q_rot, q_pass), dim=-1)
+             k = torch.cat((k_rot, k_pass), dim=-1)
+
+         q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
+
+         if self.kv_cache is not None:
+             if seqlen == 1:
+                 k, v = self.kv_cache.update(input_pos, k, v)
+             else:
+                 _, _ = self.kv_cache.update(input_pos, k, v)
+
+         if seqlen == 1:  # one-token generation
+             k_mask = mask[:, :, :, :self.kv_cache.seq_length]
+         else:  # prefill
+             k_mask = mask[:, :, :, :k.shape[-2]]
+
+         logits = q @ k.transpose(-2, -1) * self.scale_factor
+         min_value = torch.finfo(torch.float16).min
+         logits = torch.where(k_mask, logits, min_value)
+         probs = logits.softmax(-1)
+         y = probs @ v
+         y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
+         y = self.wo(y)
+         return y
+
+ class RMSnormNoscale(nn.Module):
+
+     def __init__(self, epsilon=1e-6, dim=-1):
+         super().__init__()
+         self.dim = dim
+         self.epsilon = epsilon
+
+     def forward(self, inputs):
+         var = inputs.pow(2).mean(dim=self.dim, keepdim=True)
+         normed_inputs = inputs * torch.rsqrt(var + self.epsilon)
+         return normed_inputs
+
+ class RMSNorm(nn.Module):
+     def __init__(self, dim: int, eps: float = 1e-5):
+         super().__init__()
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(dim))
+
+     def _norm(self, x):
+         return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
+
+     def forward(self, x: Tensor) -> Tensor:
+         output = self._norm(x.float()).type_as(x)
+         return output * self.weight
+
+ class FeedForward(nn.Module):
+     def __init__(self, config: MUDDPythiaConfig, lidx, round128=True, scale_with_layer=True) -> None:
+         super().__init__()
+         hid_dim = config.intermediate_size
+         if scale_with_layer:
+             hid_dim = hid_dim * (lidx / (config.n_layer - 1) + 0.5)
+         if round128:
+             hid_dim = round(hid_dim / 128) * 128
+         self.w1 = nn.Linear(config.dim, hid_dim, bias=config.use_linear_bias)
+         self.w2 = nn.Linear(hid_dim, config.dim, bias=config.use_linear_bias)
+
+     def forward(self, x: Tensor) -> Tensor:
+         return self.w2(F.gelu(self.w1(x)))
+
+ def find_multiple(n: int, k: int) -> int:
+     if n % k == 0:
+         return n
+     return n + k - (n % k)
+
+ def precompute_freqs_cis(
+     seq_len: int, n_elem: int, base: int = 10000
+ ) -> Tensor:
+     freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
+     t = torch.arange(seq_len, device=freqs.device)
+     freqs = torch.outer(t, freqs)
+     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
+     cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
+     return cache.to(dtype=torch.float16)
+
+ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor, mode='half') -> Tensor:
+     if mode == 'half':
+         xshaped = x.float().reshape(*x.shape[:-1], 2, -1).transpose(-1, -2)
+     elif mode == 'alternative':
+         xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
+     freqs_cis = freqs_cis.view(-1, xshaped.size(1), 1, xshaped.size(3), 2)
+     x_out2 = torch.stack(
+         [
+             xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
+             xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
+         ],
+         -1,
+     )
+     x_out2 = x_out2.flatten(3)
+     return x_out2.type_as(x)
+
+ def match_weight_muddpythia(model, w, strict=False, pythia=True):
+     # maps a Pax/JAX checkpoint dict `w` onto this PyTorch module's parameter names
+     if pythia:
+         map_dict = {'wq': 'query', 'wk': 'key', 'wv': 'value', 'wo': 'post', 'w1': 'ffn_layer1', 'w2': 'ffn_layer2',
+                     'weight': 'w', 'bias': 'b'}
+     else:
+         map_dict = {'wq': 'query', 'wk': 'key', 'wv': 'value', 'wo': 'post', 'w1': 'ffn_layer1_gate', 'w3': 'ffn_layer1', 'w2': 'ffn_layer2',
+                     'weight': 'w', 'bias': 'b'}
+     ln_dict = {'weight': 'scale', 'bias': 'bias'}
+     E, H, D = model.config.dim, model.config.n_head, model.config.head_dim
+     N = model.config.vocab_size
+     state_dict = {}
+     for k, v in model.named_parameters():
+         if k == 'tok_embeddings.weight':
+             v = w['state.mdl_vars.params.lm.embedding_lookup.emb_var'][:N, :]
+         elif k == 'norm.weight':
+             v = w['state.mdl_vars.params.lm.final_ln.scale']
+         elif k == 'norm.bias':
+             v = w['state.mdl_vars.params.lm.final_ln.bias']
+         elif k == 'output.weight':
+             v = w['state.mdl_vars.params.lm.softmax.logits_ffn.linear.w'].T[:N, :]  # E,N -> N,E
+         elif 'dense_bs' in k:  # static dense w
+             lidx = int(k.split('.')[-1])
+             v = w[f'state.mdl_vars.params.lm.transformer.dense_conn_{lidx}']
+         elif 'dynamic_dense' in k:
+             lidx = int(k.split('.')[1])
+             widx = int(k.split('.')[2][-1])  # 1 or 2 in w1, w2
+             v = w[f'state.mdl_vars.params.lm.transformer.x_layers_{lidx}.dynamic_dense_conn{widx}_{lidx}'].T
+         else:
+             assert 'layers' in k
+             lidx = int(k.split('.')[1])
+             if '.attention.' in k:
+                 _, _, _, ptype, wtype = k.split('.')
+                 if ptype in ['wq', 'wk', 'wv', 'wo']:
+                     v = w[f'state.mdl_vars.params.lm.transformer.x_layers_{lidx}.self_attention.{map_dict.get(ptype, ptype)}.{map_dict.get(wtype, wtype)}']  # .reshape(E,E)
+                     if wtype == 'weight':
+                         v = v.reshape(E, E)
+                     elif wtype == 'bias':
+                         v = v.reshape(E)
+                     if ptype != 'wo' and wtype == 'weight':
+                         v = v.T
+                 elif ptype in ['q_norm', 'k_norm']:
+                     v = w[f'state.mdl_vars.params.lm.transformer.x_layers_{lidx}.self_attention.{map_dict.get(ptype, ptype)}.scale']
+             elif 'feed_forward' in k:
+                 ptype = k.split('.')[3]  # w1, w3, w2
+                 wtype = k.split('.')[4]  # weight or bias
+                 if wtype == 'weight':
+                     v = w[f'state.mdl_vars.params.lm.transformer.x_layers_{lidx}.ff_layer.{map_dict[ptype]}.linear.{map_dict[wtype]}']
+                     v = v.T
+                 elif wtype == 'bias':
+                     v = w[f'state.mdl_vars.params.lm.transformer.x_layers_{lidx}.ff_layer.{map_dict[ptype]}.bias.{map_dict[wtype]}']
+             elif 'ffn_norm' in k:  # mlp layernorm
+                 wtype = k.split('.')[-1]  # weight or bias
+                 v = w[f'state.mdl_vars.params.lm.transformer.x_layers_{lidx}.ff_layer.layer_norm.{ln_dict[wtype]}']
+             elif 'attention_norm' in k:  # attention layernorm
+                 wtype = k.split('.')[-1]  # weight or bias
+                 if 'attention_norms' in k:
+                     ln_idx = int(k.split('.')[3])
+                     v = w[f'state.mdl_vars.params.lm.transformer.x_layers_{lidx}.layer_norms_{ln_idx}.{ln_dict[wtype]}']
+                 else:
+                     v = w[f'state.mdl_vars.params.lm.transformer.x_layers_{lidx}.layer_norm.{ln_dict[wtype]}']
+         if pythia and 'weight' in k and 'norm' in k and 'q_norm' not in k and 'k_norm' not in k:
+             v = v + 1
+         state_dict[k] = torch.tensor(v)
+     model.load_state_dict(state_dict, strict=strict)
+     return model
requirements.txt ADDED
@@ -0,0 +1,2 @@
+ torch==2.5.1
+ transformers==4.35.0
tokenizer.json ADDED
The diff for this file is too large to render.
 
tokenizer_config.json ADDED
@@ -0,0 +1,9 @@
+ {
+   "add_prefix_space": false,
+   "bos_token": "<|endoftext|>",
+   "clean_up_tokenization_spaces": true,
+   "eos_token": "<|endoftext|>",
+   "model_max_length": 1000000000000000019884624838656,
+   "tokenizer_class": "GPTNeoXTokenizer",
+   "unk_token": "<|endoftext|>"
+ }