theblackcat102 committed
Commit 465eb1b · verified · 1 Parent(s): 1ec39db

Create merging/merge.py

Files changed (1):
  merging/merge.py +279 -0
merging/merge.py ADDED
@@ -0,0 +1,279 @@
import json
import torch
from safetensors import safe_open
from safetensors.torch import save_file, load_file
from kernel import weight_dequant  # FP8 block-dequant kernel (kernel.py from the DeepSeek-V3 repo)


TGT_PATH = "PATH_TO_Deepseek-v3-dense-model"
SRC_PATH = "PATH_TO_Deepseek-v3-BASE"

with open(f'{TGT_PATH}/model.safetensors.index.json', 'r') as f:
    dense_large_index = json.load(f)

with open(f'{SRC_PATH}/model.safetensors.index.json', 'r') as f:
    model_index = json.load(f)
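# For reference, each index's 'weight_map' maps parameter names to shard files,
# e.g. (illustrative entry, not taken from a real index):
#   {"weight_map": {"model.layers.3.mlp.experts.0.up_proj.weight":
#                     "model-00011-of-000163.safetensors", ...}}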


def init_non_experts_weights():
    """Copy every weight shared between the MoE base and the dense checkpoint
    (attention, norms, embeddings), dequantizing FP8 tensors to bfloat16."""
    updated_cnt = 0
    updated = []
    for k, filename in dense_large_index['weight_map'].items():
        if k in model_index['weight_map']:
            tgt_safe_tensors = f"{TGT_PATH}/{filename}"
            tensors = load_file(tgt_safe_tensors)
            initial_size = len(tensors)
            src_safe_tensors = f"{SRC_PATH}/{model_index['weight_map'][k]}"
            if k + '_scale_inv' not in model_index['weight_map']:
                # Tensor is already stored in high precision in the source checkpoint
                with safe_open(src_safe_tensors, framework="pt", device='cpu') as f:
                    tensors[k] = f.get_tensor(k).bfloat16()
            else:
                # FP8 tensor: fetch its per-block inverse scales and dequantize
                print(k, 'scale_inv')
                with safe_open(src_safe_tensors, framework="pt", device='cuda') as f:
                    weight = f.get_tensor(k)
                src_scale_inv_safe_tensors = f"{SRC_PATH}/{model_index['weight_map'][k + '_scale_inv']}"
                with safe_open(src_scale_inv_safe_tensors, framework="pt", device='cuda') as f:
                    scale_inv = f.get_tensor(k + '_scale_inv')
                dequant_tensor = weight_dequant(weight.bfloat16(), scale_inv)
                tensors[k] = dequant_tensor
            updated_cnt += 1
            assert initial_size == len(tensors)
            save_file(tensors, tgt_safe_tensors, metadata={'format': 'pt'})
            updated.append(k)


def get_adjacent_filenames(filename):
    """Given a shard name like 'model-00042-of-000163.safetensors', return the
    previous and next shard names, wrapping around at either end."""
    import re
    num_str = re.search(r'(\d+)-of-', filename).group(1)
    total_str = re.search(r'-of-(\d+)', filename).group(1)
    current_num = int(num_str)
    total_files = int(total_str)

    # Zero-pad to the same width as the original shard number
    padding = len(num_str)

    # Previous number (wrap around to the end if at the start)
    prev_num = total_files if current_num == 0 else current_num - 1

    # Next number (wrap around to 0 if at the end)
    next_num = 0 if current_num == total_files else current_num + 1

    # Recover the filename pattern
    base = filename.split('-of-')[0].rsplit('-', 1)[0]
    ext = filename.split('.')[-1]

    # Format the neighbours using the same padding and shard count
    prev_file = f"{base}-{str(prev_num).zfill(padding)}-of-{total_str}.{ext}"
    next_file = f"{base}-{str(next_num).zfill(padding)}-of-{total_str}.{ext}"

    return prev_file, next_file
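
# Illustrative call (hypothetical shard name following the 163-shard scheme):
#   get_adjacent_filenames("model-00042-of-000163.safetensors")
#   -> ('model-00041-of-000163.safetensors', 'model-00043-of-000163.safetensors')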


def get_safetensors_mapping(target_layer):
    """Map each MoE component of `target_layer` to the shard file that holds it."""
    gate_name = {}
    share_experts_name = None
    experts = {}
    pre_mlp_norm = {}
    for key_name, filename in model_index['weight_map'].items():
        try:
            layer_idx = int(key_name.split('.')[2])
            if layer_idx == target_layer:
                if 'self_attn' in key_name:
                    pre_mlp_norm[key_name] = filename
                elif 'input_layernorm' in key_name or 'post_attention_layernorm' in key_name:
                    pre_mlp_norm[key_name] = filename
                elif '.gate.' in key_name:
                    gate_name[key_name] = filename
                elif 'shared_experts' in key_name:
                    share_experts_name = filename
                elif 'experts' in key_name:
                    expert_num = int(key_name.split('.')[5])
                    experts[expert_num] = filename
        except (ValueError, IndexError):
            # Keys without a per-layer index (e.g. embeddings) don't parse; skip
            continue
    return {
        'pre_mlp_keys': pre_mlp_norm,
        'gate_safetensors': gate_name,
        'share_expert_safetensors': share_experts_name,
        'experts_safetensors': experts,
    }
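
# The parsing above relies on the checkpoint's key layout, e.g.
#   model.layers.<layer_idx>.mlp.experts.<expert_idx>.up_proj.weight
# so split('.')[2] is the layer index and split('.')[5] is the expert id.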


def load_related_tensors(mapping):
    tensors = {}
    for key, filename in mapping.items():
        with safe_open(f"{SRC_PATH}/{filename}", framework="pt", device='cpu') as f:
            tensors[key] = f.get_tensor(key)
    return tensors


def load_experts_weights(experts_safetensors_map, expert_range=(), target_layer=-1):
    """Load the routed-expert tensors of `target_layer` whose expert ids fall in
    `expert_range`, scanning adjacent shards when an expert's tensors are split
    across shard boundaries."""
    tensors = {}
    expert_ids_matched = {}
    for expert_id, safe_tensor_file in experts_safetensors_map.items():
        if expert_id not in expert_range:
            continue
        tgt_safe_tensors = f"{SRC_PATH}/{safe_tensor_file}"
        matched = 0
        with safe_open(tgt_safe_tensors, framework="pt", device='cpu') as f:
            for k in f.keys():
                if 'experts' not in k or 'shared_experts' in k:
                    continue
                layer_idx = int(k.split('.')[2])
                if layer_idx != target_layer:
                    continue
                expert_idx = int(k.split('.')[5])
                if expert_idx in expert_range:
                    if expert_idx not in expert_ids_matched:
                        expert_ids_matched[expert_idx] = {}
                    tensors[k] = f.get_tensor(k)
                    matched += 1
                    postfix = '.'.join(k.split('.')[6:])
                    expert_ids_matched[expert_idx][postfix] = 1
    # Each expert should contribute 6 tensors (up/gate/down x weight/scale_inv);
    # anything less means the expert is split across shards, so scan neighbours.
    for expert_id, keys in expert_ids_matched.items():
        if len(keys) != 6:
            original_src = experts_safetensors_map[expert_id]
            prev_filename, next_filename = get_adjacent_filenames(original_src)
            prev_prev_filename, _ = get_adjacent_filenames(prev_filename)
            for _filename in [prev_filename, next_filename, prev_prev_filename]:
                with safe_open(f"{SRC_PATH}/{_filename}", framework="pt", device='cpu') as f:
                    for k in f.keys():
                        if 'experts' not in k or 'shared_experts' in k:
                            continue
                        layer_idx = int(k.split('.')[2])
                        if layer_idx != target_layer:
                            continue
                        expert_idx = int(k.split('.')[5])
                        if expert_idx == expert_id:
                            tensors[k] = f.get_tensor(k)
                            postfix = '.'.join(k.split('.')[6:])
                            expert_ids_matched[expert_idx][postfix] = 1

    return tensors


def load_shared_experts_weights(safe_tensor_file, target_layer=-1):
    """Load `target_layer`'s shared-expert tensors, scanning adjacent shards if
    the primary shard does not hold all of them."""
    tgt_safe_tensors = f"{SRC_PATH}/{safe_tensor_file}"
    tensors = {}
    with safe_open(tgt_safe_tensors, framework="pt", device='cpu') as f:
        for k in f.keys():
            if 'shared_experts' in k and int(k.split('.')[2]) == target_layer:
                tensors[k] = f.get_tensor(k)
    # Expect up to 6 tensors (3 weights + 3 scale_inv); otherwise scan neighbours
    if len(tensors) <= 1:
        prev_filename, next_filename = get_adjacent_filenames(safe_tensor_file)
        prev_prev_filename, _ = get_adjacent_filenames(prev_filename)
        for _filename in [prev_filename, next_filename, prev_prev_filename]:
            with safe_open(f"{SRC_PATH}/{_filename}", framework="pt", device='cpu') as f:
                for k in f.keys():
                    if 'shared_experts' not in k:
                        continue
                    layer_idx = int(k.split('.')[2])
                    if target_layer == layer_idx:
                        tensors[k] = f.get_tensor(k)
    return tensors


if __name__ == "__main__":
    init_non_experts_weights()
    # 256 routed experts split into 8 contiguous groups of 32
    expert_ranges = [
        list(range(0, 256 // 8)),        # 0-31
        list(range(32, 2 * 256 // 8)),   # 32-63
        list(range(64, 3 * 256 // 8)),   # 64-95
        list(range(96, 4 * 256 // 8)),   # 96-127
        list(range(128, 5 * 256 // 8)),  # 128-159
        list(range(160, 6 * 256 // 8)),  # 160-191
        list(range(192, 7 * 256 // 8)),  # 192-223
        list(range(224, 256)),           # 224-255
    ]
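
    # Each group of 32 routed experts is mean-merged into one 2048-wide slice
    # (DeepSeek-V3's moe_intermediate_size); 8 slices plus the shared expert
    # give 9 x 2048 = 18432, the dense intermediate size asserted below.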

    # Layers 0-2 are dense in DeepSeek-V3; routed experts start at layer 3
    for target_layer in range(3, 62):
        result = get_safetensors_mapping(target_layer)
        if len(result['experts_safetensors']) == 0:
            print('empty at ', target_layer)
            continue
        final_up_proj = []
        final_gate_proj = []
        final_down_proj = []
        for expert_range in expert_ranges:
            experts_weights = load_experts_weights(result['experts_safetensors'], expert_range, target_layer)
            new_state_dict = {}
            for weight_name, weight in experts_weights.items():
                if weight_name.endswith("_scale_inv"):
                    continue
                elif weight.element_size() == 1:  # FP8 weight
                    scale_inv_name = f"{weight_name}_scale_inv"
                    try:
                        # Get scale_inv from the matching key
                        scale_inv = experts_weights[scale_inv_name]
                        new_state_dict[weight_name] = weight_dequant(weight.bfloat16().cuda(), scale_inv.cuda()).cpu()
                    except KeyError:
                        print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion")
                else:
                    new_state_dict[weight_name] = weight
            up_proj, gate_proj, down_proj = [], [], []
            for expert_id in expert_range:
                key = f'model.layers.{target_layer}.mlp.experts.{expert_id}.up_proj.weight'
                up_proj.append(new_state_dict[key])
                key = f'model.layers.{target_layer}.mlp.experts.{expert_id}.gate_proj.weight'
                gate_proj.append(new_state_dict[key])
                key = f'model.layers.{target_layer}.mlp.experts.{expert_id}.down_proj.weight'
                down_proj.append(new_state_dict[key])
            # Mean-merge the group into a single expert-sized MLP slice
            avg_up_proj = torch.mean(torch.stack(up_proj, dim=0), dim=0)
            avg_gate_proj = torch.mean(torch.stack(gate_proj, dim=0), dim=0)
            avg_down_proj = torch.mean(torch.stack(down_proj, dim=0), dim=0)
            final_up_proj.append(avg_up_proj)
            final_gate_proj.append(avg_gate_proj)
            final_down_proj.append(avg_down_proj)

        # Append the shared expert as the final slice
        shared_experts_weight = load_shared_experts_weights(result['share_expert_safetensors'], target_layer)
        new_state_dict = {}
        for weight_name, weight in shared_experts_weight.items():
            if weight_name.endswith("_scale_inv"):
                continue
            elif weight.element_size() == 1:  # FP8 weight
                scale_inv_name = f"{weight_name}_scale_inv"
                try:
                    # Get scale_inv from the matching key
                    scale_inv = shared_experts_weight[scale_inv_name]
                    new_state_dict[weight_name] = weight_dequant(weight.bfloat16().cuda(), scale_inv.cuda()).cpu()
                except KeyError:
                    print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion")
            else:
                new_state_dict[weight_name] = weight
        key = f'model.layers.{target_layer}.mlp.shared_experts.up_proj.weight'
        final_up_proj.append(new_state_dict[key])
        key = f'model.layers.{target_layer}.mlp.shared_experts.gate_proj.weight'
        final_gate_proj.append(new_state_dict[key])
        key = f'model.layers.{target_layer}.mlp.shared_experts.down_proj.weight'
        final_down_proj.append(new_state_dict[key])

        # up/gate have the intermediate dim first, so concatenate on dim 0;
        # down has it second, so transpose, concatenate, and transpose back
        dense_up_proj = torch.cat(final_up_proj, dim=0)
        dense_gate_proj = torch.cat(final_gate_proj, dim=0)
        dense_down_proj = torch.cat([t.T for t in final_down_proj], dim=0).T.contiguous()

        assert dense_down_proj.shape[1] == 18432
        assert dense_gate_proj.shape[0] == 18432
        assert dense_up_proj.shape[0] == 18432

        # GATE PROJ
        key = f"model.layers.{target_layer}.mlp.gate_proj.weight"
        target_safetensors = dense_large_index['weight_map'][key]
        tensors = load_file(f"{TGT_PATH}/{target_safetensors}")
        print(len(tensors))
        assert tensors[key].shape == dense_gate_proj.shape
        tensors[key] = dense_gate_proj.bfloat16()
        print(len(tensors), f"{TGT_PATH}/{target_safetensors}")
        save_file(tensors, f"{TGT_PATH}/{target_safetensors}", metadata={'format': 'pt'})

        # UP PROJ
        key = f"model.layers.{target_layer}.mlp.up_proj.weight"
        target_safetensors = dense_large_index['weight_map'][key]
        tensors = load_file(f"{TGT_PATH}/{target_safetensors}")
        assert tensors[key].shape == dense_up_proj.shape
        tensors[key] = dense_up_proj.bfloat16()
        print(len(tensors), f"{TGT_PATH}/{target_safetensors}")
        save_file(tensors, f"{TGT_PATH}/{target_safetensors}", metadata={'format': 'pt'})

        # DOWN PROJ
        key = f"model.layers.{target_layer}.mlp.down_proj.weight"
        target_safetensors = dense_large_index['weight_map'][key]
        tensors = load_file(f"{TGT_PATH}/{target_safetensors}")
        assert tensors[key].shape == dense_down_proj.shape
        tensors[key] = dense_down_proj.bfloat16()
        print(len(tensors), f"{TGT_PATH}/{target_safetensors}")
        save_file(tensors, f"{TGT_PATH}/{target_safetensors}", metadata={'format': 'pt'})
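
# Optional post-merge sanity check (illustrative; same placeholder paths as above):
#   shard = dense_large_index['weight_map']['model.layers.3.mlp.up_proj.weight']
#   t = load_file(f"{TGT_PATH}/{shard}")
#   print(t['model.layers.3.mlp.up_proj.weight'].shape)  # expect (18432, 7168)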