# -*- coding: utf-8 -*-
"""
@author: XuMing([email protected])
@description: Merge a trained PEFT (LoRA) adapter into its base model.
Usage:
python merge_peft_adapter.py \
    --model_type llama \
    --base_model_name_or_path path/to/llama/model \
    --tokenizer_path path/to/llama/tokenizer \
    --peft_model_path path/to/lora/model \
    --output_dir path/to/output/dir

After merging, ChatGLM and Baichuan models still need their custom Python
files copied into the output dir.
"""
import argparse

import torch
from peft import PeftModel, PeftConfig
from transformers import (
    AutoModel,
    AutoTokenizer,
    BloomForCausalLM,
    BloomTokenizerFast,
    AutoModelForCausalLM,
    LlamaTokenizer,
    LlamaForCausalLM,
    AutoModelForSequenceClassification,
)
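
# Supported model families, mapped to their (model class, tokenizer class) pair.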
MODEL_CLASSES = {
    "bloom": (BloomForCausalLM, BloomTokenizerFast),
    "chatglm": (AutoModel, AutoTokenizer),
    "llama": (LlamaForCausalLM, LlamaTokenizer),
    "baichuan": (AutoModelForCausalLM, AutoTokenizer),
    "auto": (AutoModelForCausalLM, AutoTokenizer),
}


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_type', default=None, type=str, required=True,
                        help="Model type, one of: " + ", ".join(MODEL_CLASSES.keys()))
    parser.add_argument('--base_model_name_or_path', default=None, required=True, type=str,
                        help="Base model name or path")
    parser.add_argument('--tokenizer_path', default=None, type=str,
                        help="Tokenizer path; defaults to the LoRA model path")
    parser.add_argument('--peft_model_path', default=None, required=True, type=str,
                        help="Path of the LoRA model to be merged")
    parser.add_argument('--resize_emb', action='store_true',
                        help="Whether to resize model token embeddings")
    parser.add_argument('--output_dir', default='./merged', type=str,
                        help="Directory to save the merged model")
    args = parser.parse_args()
    print(args)

    base_model_path = args.base_model_name_or_path
    peft_model_path = args.peft_model_path
    output_dir = args.output_dir
    print(f"Base model: {base_model_path}")
    print(f"LoRA model: {peft_model_path}")
    peft_config = PeftConfig.from_pretrained(peft_model_path)
    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
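
    # Sequence-classification adapters need a classification head; everything
    # else loads as a causal LM.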
    if peft_config.task_type == "SEQ_CLS":
        print("Loading LoRA for sequence classification model")
        if args.model_type == "chatglm":
            raise ValueError("chatglm does not support sequence classification")
        base_model = AutoModelForSequenceClassification.from_pretrained(
            base_model_path,
            load_in_8bit=False,
            torch_dtype=torch.float16,
            trust_remote_code=True,
            device_map="auto",
        )
    else:
        print("Loading LoRA for causal language model")
        base_model = model_class.from_pretrained(
            base_model_path,
            load_in_8bit=False,
            torch_dtype=torch.float16,
            trust_remote_code=True,
            device_map="auto",
        )
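
    # Load the tokenizer from the given path, falling back to the adapter dir
    # (which may ship an extended vocabulary from training).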
    if args.tokenizer_path:
        tokenizer = tokenizer_class.from_pretrained(args.tokenizer_path, trust_remote_code=True)
    else:
        tokenizer = tokenizer_class.from_pretrained(peft_model_path, trust_remote_code=True)
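
    # If training added new tokens, grow the base model's embedding matrix to
    # match the tokenizer before attaching the adapter.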
    if args.resize_emb:
        base_model_token_size = base_model.get_input_embeddings().weight.size(0)
        if base_model_token_size != len(tokenizer):
            base_model.resize_token_embeddings(len(tokenizer))
            print(f"Resized vocabulary from {base_model_token_size} to {len(tokenizer)}")
    lora_model = PeftModel.from_pretrained(
        base_model,
        peft_model_path,
        device_map="auto",
        torch_dtype=torch.float16,
    )
    lora_model.eval()
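
    # merge_and_unload() folds the LoRA deltas into the base weights and
    # returns a plain transformers model with no adapter wrappers.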
print(f"Merging with merge_and_unload...") | |
base_model = lora_model.merge_and_unload() | |
print("Saving to Hugging Face format...") | |
tokenizer.save_pretrained(output_dir) | |
base_model.save_pretrained(output_dir) | |
print(f"Done! model saved to {output_dir}") | |


if __name__ == '__main__':
    main()