import torch

from transformers import AutoModelForCausalLM, AutoTokenizer


def average_two_model(model_path_1, model_path_2, update_num,
                      base_path='/dccstor/obsidian_llm/yiduo/h100_data/llama-3-8b'):
    # Build the output path by concatenating the first checkpoint path with the
    # second checkpoint's directory name, then stripping run-specific tags
    # ('00', 'random', 'naive_3k', 'shuffle', 'average') so repeated merges do
    # not keep growing the name.
    averaged_model_path = (model_path_1 + model_path_2.split('/')[-1]) \
        .replace('00', '').replace('random', '').replace('naive_3k', '') \
        .replace('shuffle', '').replace('average', '')

    # Load both checkpoints to be averaged.
    model_paths = [model_path_1, model_path_2]
    models = [AutoModelForCausalLM.from_pretrained(path) for path in model_paths]

    # Incremental (running) average: model 1 is the aggregate of `update_num`
    # earlier checkpoints, so it keeps weight update_num / (update_num + 1),
    # and the new checkpoint (model 2) gets weight 1 / (update_num + 1).
    state_dicts = [model.state_dict() for model in models]
    avg_state_dict = {}
    with torch.no_grad():
        for key in state_dicts[0].keys():
            avg_state_dict[key] = (
                (update_num / (update_num + 1)) * state_dicts[0][key]
                + (1.0 / (update_num + 1)) * state_dicts[1][key]
            )

    # Load the averaged weights into a fresh copy of the base architecture
    # and save the merged checkpoint.
    base_model = AutoModelForCausalLM.from_pretrained(base_path)
    base_model.load_state_dict(avg_state_dict)
    base_model.save_pretrained(averaged_model_path)

    # Save the tokenizer from the first checkpoint alongside the merged weights.
    tokenizer = AutoTokenizer.from_pretrained(model_path_1)
    tokenizer.save_pretrained(averaged_model_path)
    return averaged_model_path
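

# Example usage (a minimal sketch; the checkpoint paths and update_num below
# are hypothetical placeholders, not values from the original code). Note that
# folding checkpoints in one at a time with update_num = 1, 2, ..., k reduces
# to the uniform average of all k + 1 models.
if __name__ == "__main__":
    merged_path = average_two_model(
        "/path/to/running_average_ckpt",  # aggregate of `update_num` checkpoints
        "/path/to/new_ckpt",              # new checkpoint to fold in
        update_num=3,
    )
    print(f"Averaged model saved to {merged_path}")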