import torch
from safetensors.torch import load_file, save_file
# Shard files that are loaded, in this order, and merged into one state dict.
model_files = [
    'model-00001-of-00003.safetensors',
    'model-00002-of-00003.safetensors',
    'model-00003-of-00003.safetensors',
]


def merge_state_dicts(named_state_dicts):
    """Combine several state dicts into one, rejecting repeated keys.

    Args:
        named_state_dicts: Iterable of ``(source_name, state_dict)`` pairs.
            ``source_name`` is used only to build error messages.

    Returns:
        A new dict containing every key/value pair of every input dict.
        Values are stored as-is (not copied).

    Raises:
        ValueError: If the same key occurs in more than one input dict.
    """
    merged = {}
    first_seen_in = {}
    for source_name, state_dict in named_state_dicts:
        for key, value in state_dict.items():
            if key in merged:
                # Adding the values together (as `+=` would) silently alters
                # the stored weights and mutates the earlier tensor in place,
                # so a repeated key is reported instead of being summed.
                raise ValueError(
                    f"duplicate key {key!r} in {source_name!r} "
                    f"(already loaded from {first_seen_in[key]!r})"
                )
            merged[key] = value
            first_seen_in[key] = source_name
    return merged


if __name__ == '__main__':
    # The generator loads one shard at a time, right before it is merged.
    merged_state_dict = merge_state_dicts(
        (model_file, load_file(model_file)) for model_file in model_files
    )
    torch.save(merged_state_dict, 'pytorch_model.bin')