# model_soups_utils.py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
def average_two_model(model_path_1, model_path_2, update_num, base_path='/dccstor/obsidian_llm/yiduo/h100_data/llama-3-8b'):
    """Fold the model at model_path_2 into the running average stored at model_path_1.

    update_num is the number of checkpoints already averaged into model_path_1,
    so the existing average gets weight update_num/(update_num+1) and the new
    model gets weight 1/(update_num+1), keeping the result a uniform average.
    """
    # Path to save the averaged model and tokenizer: concatenate the two model
    # names, then strip run-specific tags so repeated merges keep the path short.
    averaged_model_path = (model_path_1 + model_path_2.split('/')[-1]) \
        .replace('00', '').replace('random', '').replace('naive_3k', '') \
        .replace('shuffle', '').replace('average', '')
    # Load both models and cache their state dicts once (calling .state_dict()
    # inside the loop would rebuild the dict on every iteration).
    state_dicts = []
    for model_path in (model_path_1, model_path_2):
        model = AutoModelForCausalLM.from_pretrained(model_path)
        state_dicts.append(model.state_dict())
    avg_state_dict = {}
    with torch.no_grad():
        for key in state_dicts[0].keys():
            # Running average: the first model already holds the mean of
            # update_num checkpoints; the second contributes the new one.
            avg_state_dict[key] = (
                (update_num / (update_num + 1)) * state_dicts[0][key]
                + (1.0 / (update_num + 1)) * state_dicts[1][key]
            )
    # Load the base model for its architecture/config, install the averaged
    # weights, and save the result.
    base_model = AutoModelForCausalLM.from_pretrained(base_path)
    base_model.load_state_dict(avg_state_dict)
    base_model.save_pretrained(averaged_model_path)
    # Save the tokenizer alongside the weights (assumes both source models
    # share the same tokenizer, so the first one's copy is reused).
    tokenizer = AutoTokenizer.from_pretrained(model_path_1)
    tokenizer.save_pretrained(averaged_model_path)
    return averaged_model_path
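

# Usage sketch (the checkpoint paths below are hypothetical placeholders, not
# paths shipped with this repo). Starting from the first checkpoint, each call
# folds one more model into the running average, so after the loop the final
# path holds the uniform soup of all listed checkpoints.
if __name__ == "__main__":
    checkpoints = [
        "/path/to/checkpoint_seed_0",  # hypothetical
        "/path/to/checkpoint_seed_1",  # hypothetical
        "/path/to/checkpoint_seed_2",  # hypothetical
    ]
    soup_path = checkpoints[0]
    for update_num, next_path in enumerate(checkpoints[1:], start=1):
        soup_path = average_two_model(soup_path, next_path, update_num)
    print("Final averaged model saved at:", soup_path)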