import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

print("Loading checkpoint...")

# Define the paths to your checkpoint files
checkpoint_paths = [
    './llama3-5b/model-00001-of-00003.pt',
    './llama3-5b/model-00002-of-00003.pt',
    './llama3-5b/model-00003-of-00003.pt'
]

# Initialize an empty state dictionary
merged_state_dict = {}

# Load each checkpoint shard and merge it into the single state dict
for checkpoint_path in checkpoint_paths:
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    merged_state_dict.update(checkpoint)

print("Loading original model...")

# Define the original model name or path
original_model_name = "../../slice_with_mergekit/merged/"

# Instantiate the model from the original configuration, overriding its
# weights with the merged state dict
model = AutoModelForCausalLM.from_pretrained(original_model_name, state_dict=merged_state_dict)

print("Converting to fp16...")

# Convert model parameters to float16
model.half()

print("Saving model...")

# Save the model in the safetensors format
output_dir = './llama3-5b/hf/'
model.save_pretrained(output_dir, safe_serialization=True)

print("Saving tokenizer...")

# Save the tokenizer as well
tokenizer = AutoTokenizer.from_pretrained(original_model_name)
tokenizer.save_pretrained(output_dir)

print(f"Merged model saved to {output_dir}")
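
# Optional sanity check: a minimal sketch assuming the shards above are plain
# state dicts and the save succeeded. The prompt and generation settings are
# arbitrary placeholders, not part of the conversion pipeline itself.
print("Verifying saved model...")
verify_model = AutoModelForCausalLM.from_pretrained(output_dir, torch_dtype=torch.float16)
verify_tokenizer = AutoTokenizer.from_pretrained(output_dir)

# Run a short generation as a smoke test to confirm the merged weights load cleanly
inputs = verify_tokenizer("The capital of France is", return_tensors="pt")
outputs = verify_model.generate(**inputs, max_new_tokens=10)
print(verify_tokenizer.decode(outputs[0], skip_special_tokens=True))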