theblackcat102 committed
Create merging/merge.py

merging/merge.py  ADDED  (+279 lines)
import json
import re

import torch
from safetensors import safe_open
from safetensors.torch import save_file, load_file
from kernel import weight_dequant  # FP8 block-dequant kernel from the DeepSeek-V3 inference code


TGT_PATH = "PATH_TO_Deepseek-v3-dense-model"  # dense target checkpoint, modified in place
SRC_PATH = "PATH_TO_Deepseek-v3-BASE"         # DeepSeek-V3 MoE base checkpoint

with open(f'{TGT_PATH}/model.safetensors.index.json', 'r') as f:
    dense_large_index = json.load(f)

with open(f'{SRC_PATH}/model.safetensors.index.json', 'r') as f:
    model_index = json.load(f)


def init_non_experts_weights():
    """Copy every weight the two checkpoints share (attention, norms,
    embeddings, dense-layer MLPs) from the MoE base into the dense target,
    dequantizing FP8 weights to bfloat16."""
    updated_cnt = 0
    updated = []
    for k, filename in dense_large_index['weight_map'].items():
        if k in model_index['weight_map']:
            tgt_safe_tensors = f'{TGT_PATH}/{filename}'
            tensors = load_file(tgt_safe_tensors)
            initial_size = len(tensors)
            src_safe_tensors = f"{SRC_PATH}/{model_index['weight_map'][k]}"
            if k + '_scale_inv' not in model_index['weight_map']:
                # Not FP8-quantized in the source: cast to bf16 and copy.
                with safe_open(src_safe_tensors, framework="pt", device='cpu') as f:
                    tensors[k] = f.get_tensor(k).bfloat16()
            else:
                # FP8 weight: fetch its per-block scale and dequantize on GPU.
                print(k, 'scale_inv')
                with safe_open(src_safe_tensors, framework="pt", device='cuda') as f:
                    weight = f.get_tensor(k)
                src_scale_inv_safe_tensors = f"{SRC_PATH}/{model_index['weight_map'][k + '_scale_inv']}"
                with safe_open(src_scale_inv_safe_tensors, framework="pt", device='cuda') as f:
                    scale_inv = f.get_tensor(k + '_scale_inv')
                dequant_tensor = weight_dequant(weight.bfloat16(), scale_inv)
                tensors[k] = dequant_tensor
            updated_cnt += 1
            assert initial_size == len(tensors)
            save_file(tensors, tgt_safe_tensors, metadata={'format': 'pt'})
            updated.append(k)


def get_adjacent_filenames(filename):
    """Return the previous and next shard filenames, wrapping around at the
    ends. Shard numbers are 1-based (e.g. model-00001-of-000163...)."""
    # Extract the current shard number and the (zero-padded) total
    current_num = int(re.search(r'(\d+)-of-', filename).group(1))
    total_str = re.search(r'-of-(\d+)', filename).group(1)
    total_files = int(total_str)

    # Padding length of the shard number (5 for 'model-00001-of-000163')
    padding = len(re.search(r'(\d+)-of-', filename).group(1))

    # Previous number (wrap around to the last shard if at the first)
    prev_num = total_files if current_num == 1 else current_num - 1

    # Next number (wrap around to the first shard if at the last)
    next_num = 1 if current_num == total_files else current_num + 1

    # Split the filename into its base and extension
    base = filename.split('-of-')[0].rsplit('-', 1)[0]
    ext = filename.split('.')[-1]

    # Format the neighbours with the same padding and total
    prev_file = f"{base}-{str(prev_num).zfill(padding)}-of-{total_str}.{ext}"
    next_file = f"{base}-{str(next_num).zfill(padding)}-of-{total_str}.{ext}"

    return prev_file, next_file

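# Quick sanity check of the shard-name arithmetic (hypothetical shard names
# following the naming scheme above):
#   get_adjacent_filenames("model-00005-of-000163.safetensors")
#     -> ("model-00004-of-000163.safetensors", "model-00006-of-000163.safetensors")
#   get_adjacent_filenames("model-00001-of-000163.safetensors")
#     -> ("model-00163-of-000163.safetensors", "model-00002-of-000163.safetensors")
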
def get_safetensors_mapping(target_layer):
    """Map every tensor of `target_layer` in the MoE index to the shard file
    holding it, grouped by role (attention/norms, router gate, shared experts,
    routed experts)."""
    gate_name = {}
    share_experts_name = None
    experts = {}
    pre_mlp_norm = {}
    for key_name, filename in model_index['weight_map'].items():
        try:
            layer_idx = int(key_name.split('.')[2])
            if layer_idx == target_layer:
                if 'self_attn' in key_name:
                    pre_mlp_norm[key_name] = filename
                elif 'input_layernorm' in key_name or 'post_attention_layernorm' in key_name:
                    pre_mlp_norm[key_name] = filename
                elif '.gate.' in key_name:
                    gate_name[key_name] = filename
                elif 'shared_experts' in key_name:
                    share_experts_name = filename
                elif 'experts' in key_name:
                    expert_num = int(key_name.split('.')[5])
                    experts[expert_num] = filename
        except (ValueError, IndexError):
            # Keys without a numeric layer index (embeddings, lm_head, ...)
            continue
    return {
        'pre_mlp_keys': pre_mlp_norm,
        'gate_safetensors': gate_name,
        'share_expert_safetensors': share_experts_name,
        'experts_safetensors': experts
    }

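# Illustration of the grouping above on DeepSeek-V3-style weight-map keys
# (hypothetical entries for layer 3):
#   "model.layers.3.input_layernorm.weight"             -> pre_mlp_keys
#   "model.layers.3.mlp.gate.weight"                    -> gate_safetensors
#   "model.layers.3.mlp.shared_experts.up_proj.weight"  -> share_expert_safetensors
#   "model.layers.3.mlp.experts.17.up_proj.weight"      -> experts_safetensors[17]
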
def load_related_tensors(mapping):
    tensors = {}
    for key, filename in mapping.items():
        with safe_open(f'{SRC_PATH}/{filename}', framework="pt", device='cpu') as f:
            tensors[key] = f.get_tensor(key)
    return tensors


def load_experts_weights(experts_safetensors_map, expert_range=[], target_layer=-1):
    """Collect all routed-expert tensors of `target_layer` whose expert id is
    in `expert_range`, probing neighbouring shards for experts whose tensors
    are split across file boundaries."""
    tensors = {}
    expert_ids_matched = {}
    for expert_id, safe_tensor_file in experts_safetensors_map.items():
        if expert_id not in expert_range:
            continue
        tgt_safe_tensors = f'{SRC_PATH}/{safe_tensor_file}'
        matched = 0
        with safe_open(tgt_safe_tensors, framework="pt", device='cpu') as f:
            for k in f.keys():
                if 'experts' not in k or 'shared_experts' in k:
                    continue
                layer_idx = int(k.split('.')[2])
                if layer_idx != target_layer:
                    # A shard can also hold experts of other layers; skip them.
                    continue
                expert_idx = int(k.split('.')[5])

                if expert_idx in expert_range:
                    if expert_idx not in expert_ids_matched:
                        expert_ids_matched[expert_idx] = {}
                    tensors[k] = f.get_tensor(k)
                    matched += 1
                    postfix = '.'.join(k.split('.')[6:])
                    expert_ids_matched[expert_idx][postfix] = 1
    for expert_id, keys in expert_ids_matched.items():
        # A complete FP8 expert contributes 6 tensors: up/gate/down projection
        # weights plus their _scale_inv companions. Fewer means the expert is
        # split across shards, so look in the neighbouring files as well.
        if len(keys) != 6:
            original_src = experts_safetensors_map[expert_id]
            prev_filename, next_filename = get_adjacent_filenames(original_src)
            prev_prev_filename, _ = get_adjacent_filenames(prev_filename)
            for _filename in [prev_filename, next_filename, prev_prev_filename]:
                with safe_open(f'{SRC_PATH}/{_filename}', framework="pt", device='cpu') as f:
                    for k in f.keys():
                        if 'experts' not in k or 'shared_experts' in k:
                            continue
                        layer_idx = int(k.split('.')[2])
                        if layer_idx != target_layer:
                            continue
                        expert_idx = int(k.split('.')[5])
                        if expert_idx == expert_id:
                            tensors[k] = f.get_tensor(k)
                            matched += 1
                            postfix = '.'.join(k.split('.')[6:])
                            expert_ids_matched[expert_idx][postfix] = 1

    return tensors

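# `weight_dequant` comes from the DeepSeek-V3 inference kernel. As a rough
# reference for what it computes, each scale in `scale_inv` covers one
# 128x128 block of the FP8 weight (a minimal pure-PyTorch sketch under that
# assumption, not the actual Triton kernel):
#
#   def weight_dequant_ref(weight, scale_inv, block_size=128):
#       w = weight.to(torch.float32)
#       s = scale_inv.repeat_interleave(block_size, 0).repeat_interleave(block_size, 1)
#       return (w * s[:w.shape[0], :w.shape[1]]).to(torch.bfloat16)
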
def load_shared_experts_weights(safe_tensor_file, target_layer=-1):
    """Collect the shared-expert tensors of `target_layer`, probing the
    neighbouring shards when the mapped file only holds part of them."""
    tgt_safe_tensors = f'{SRC_PATH}/{safe_tensor_file}'
    tensors = {}
    with safe_open(tgt_safe_tensors, framework="pt", device='cpu') as f:
        for k in f.keys():
            # Only keep the target layer's tensors; a shard may hold the
            # shared experts of several layers.
            if 'shared_experts' in k and int(k.split('.')[2]) == target_layer:
                tensors[k] = f.get_tensor(k)
    if len(tensors) <= 1:
        prev_filename, next_filename = get_adjacent_filenames(safe_tensor_file)
        prev_prev_filename, _ = get_adjacent_filenames(prev_filename)
        for _filename in [prev_filename, next_filename, prev_prev_filename]:
            with safe_open(f'{SRC_PATH}/{_filename}', framework="pt", device='cpu') as f:
                for k in f.keys():
                    if 'shared_experts' not in k:
                        continue
                    layer_idx = int(k.split('.')[2])
                    if target_layer == layer_idx:
                        tensors[k] = f.get_tensor(k)
    return tensors

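# Merge geometry assumed below (DeepSeek-V3 sizes: hidden 7168, per-expert
# intermediate 2048, 256 routed experts plus 1 shared expert per MoE layer):
# every group of 32 routed experts is mean-pooled into one [2048, 7168] slice,
# and the 8 slices plus the shared expert are concatenated into a dense FFN
# of width 9 * 2048 = 18432:
#   up/gate: 9 x [2048, 7168] -> [18432, 7168]
#   down:    9 x [7168, 2048] -> [7168, 18432]
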
if __name__ == "__main__":
    init_non_experts_weights()
    # Split the 256 routed experts into 8 contiguous groups of 32.
    expert_ranges = [
        list(range(0, 256 // 8)),        # 0-31
        list(range(32, 2 * 256 // 8)),   # 32-63
        list(range(64, 3 * 256 // 8)),   # 64-95
        list(range(96, 4 * 256 // 8)),   # 96-127
        list(range(128, 5 * 256 // 8)),  # 128-159
        list(range(160, 6 * 256 // 8)),  # 160-191
        list(range(192, 7 * 256 // 8)),  # 192-223
        list(range(224, 256)),           # 224-255
    ]
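    # Equivalent, more compact construction of the same 8 groups
    # (32 = 256 routed experts / 8 groups):
    #   expert_ranges = [list(range(g * 32, (g + 1) * 32)) for g in range(8)]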

    # Layers 0-2 of DeepSeek-V3 are dense; MoE layers start at 3. The range
    # runs to 61 because the base checkpoint also stores the MTP module as
    # model.layers.61.
    for target_layer in range(3, 62):
        result = get_safetensors_mapping(target_layer)
        if len(result['experts_safetensors']) == 0:
            print('empty at', target_layer)
        final_up_proj = []
        final_gate_proj = []
        final_down_proj = []
        for expert_range in expert_ranges:
            experts_weights = load_experts_weights(result['experts_safetensors'], expert_range, target_layer)
            new_state_dict = {}
            for weight_name, weight in experts_weights.items():
                if weight_name.endswith("_scale_inv"):
                    continue
                elif weight.element_size() == 1:  # FP8 weight
                    scale_inv_name = f"{weight_name}_scale_inv"
                    try:
                        # Get scale_inv from the correct file
                        scale_inv = experts_weights[scale_inv_name]
                        new_state_dict[weight_name] = weight_dequant(weight.bfloat16().cuda(), scale_inv.cuda()).cpu()
                    except KeyError:
                        print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion")
                else:
                    new_state_dict[weight_name] = weight
            # Mean-pool this group of 32 experts into one expert-sized slice.
            up_proj, gate_proj, down_proj = [], [], []
            for expert_id in expert_range:
                key = f'model.layers.{target_layer}.mlp.experts.{expert_id}.up_proj.weight'
                up_proj.append(new_state_dict[key])
                key = f'model.layers.{target_layer}.mlp.experts.{expert_id}.gate_proj.weight'
                gate_proj.append(new_state_dict[key])
                key = f'model.layers.{target_layer}.mlp.experts.{expert_id}.down_proj.weight'
                down_proj.append(new_state_dict[key])
            avg_up_proj = torch.mean(torch.stack(up_proj, dim=0), dim=0)
            avg_gate_proj = torch.mean(torch.stack(gate_proj, dim=0), dim=0)
            avg_down_proj = torch.mean(torch.stack(down_proj, dim=0), dim=0)
            final_up_proj.append(avg_up_proj)
            final_gate_proj.append(avg_gate_proj)
            final_down_proj.append(avg_down_proj)
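        # final_{up,gate,down}_proj now hold 8 mean-pooled slices, one per
        # expert group; the shared expert is appended below as the 9th slice.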
        # append the final shared experts
        shared_experts_weight = load_shared_experts_weights(result['share_expert_safetensors'], target_layer)
        new_state_dict = {}
        for weight_name, weight in shared_experts_weight.items():
            if weight_name.endswith("_scale_inv"):
                continue
            elif weight.element_size() == 1:  # FP8 weight
                scale_inv_name = f"{weight_name}_scale_inv"
                try:
                    # Get scale_inv from the correct file
                    scale_inv = shared_experts_weight[scale_inv_name]
                    new_state_dict[weight_name] = weight_dequant(weight.bfloat16().cuda(), scale_inv.cuda()).cpu()
                except KeyError:
                    print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion")
            else:
                new_state_dict[weight_name] = weight
        key = f'model.layers.{target_layer}.mlp.shared_experts.up_proj.weight'
        final_up_proj.append(new_state_dict[key])
        key = f'model.layers.{target_layer}.mlp.shared_experts.gate_proj.weight'
        final_gate_proj.append(new_state_dict[key])
        key = f'model.layers.{target_layer}.mlp.shared_experts.down_proj.weight'
        final_down_proj.append(new_state_dict[key])

        # Concatenate the 9 slices into one dense FFN: up/gate along the output
        # dimension; down along the input dimension (equivalent to cat(dim=1),
        # done here via a transpose round-trip).
        dense_up_proj = torch.concatenate(final_up_proj, dim=0)
        dense_gate_proj = torch.concatenate(final_gate_proj, dim=0)
        dense_down_proj = torch.concatenate([t.T for t in final_down_proj], dim=0).T.contiguous()

        # 9 slices x 2048 = 18432, the dense model's intermediate size
        assert dense_down_proj.shape[1] == 18432
        assert dense_gate_proj.shape[0] == 18432
        assert dense_up_proj.shape[0] == 18432

        # GATE PROJ
        key = f"model.layers.{target_layer}.mlp.gate_proj.weight"
        target_safetensors = dense_large_index['weight_map'][key]
        tensors = load_file(f'{TGT_PATH}/{target_safetensors}')
        assert tensors[key].shape == dense_gate_proj.shape
        tensors[key] = dense_gate_proj.bfloat16()
        print(len(tensors), f'{TGT_PATH}/{target_safetensors}')
        save_file(tensors, f'{TGT_PATH}/{target_safetensors}', metadata={'format': 'pt'})

        # UP PROJ
        key = f"model.layers.{target_layer}.mlp.up_proj.weight"
        target_safetensors = dense_large_index['weight_map'][key]
        tensors = load_file(f'{TGT_PATH}/{target_safetensors}')
        assert tensors[key].shape == dense_up_proj.shape
        tensors[key] = dense_up_proj.bfloat16()
        print(len(tensors), f'{TGT_PATH}/{target_safetensors}')
        save_file(tensors, f'{TGT_PATH}/{target_safetensors}', metadata={'format': 'pt'})

        # DOWN PROJ
        key = f"model.layers.{target_layer}.mlp.down_proj.weight"
        target_safetensors = dense_large_index['weight_map'][key]
        tensors = load_file(f'{TGT_PATH}/{target_safetensors}')
        assert tensors[key].shape == dense_down_proj.shape
        tensors[key] = dense_down_proj.bfloat16()
        print(len(tensors), f'{TGT_PATH}/{target_safetensors}')
        save_file(tensors, f'{TGT_PATH}/{target_safetensors}', metadata={'format': 'pt'})
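
# A quick post-merge smoke test (hypothetical; assumes the dense target's
# config.json already declares intermediate_size=18432):
#   from transformers import AutoModelForCausalLM
#   model = AutoModelForCausalLM.from_pretrained(
#       TGT_PATH, torch_dtype=torch.bfloat16, trust_remote_code=True)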