# Provenance (scraped page header): author flashvenom, commit 9b9f34f ("init").
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch
import logging
import argparse
logger = logging.getLogger(__name__)


def merge_lora(base_model_name_or_path, peft_model_path, output_dir, device='auto', push_to_hub=False):
    """Merge a PEFT (LoRA) adapter into its base causal-LM and save the result.

    The base model is loaded in float16 and the adapter at ``peft_model_path``
    is loaded on top of it. ``merge_and_unload()`` is then called to fold the
    adapter into the base weights. The merged model and the base model's
    tokenizer are written either to a local directory or to a Hub repo.

    Args:
        base_model_name_or_path: Hub id or local path of the base model
            (also used to load the tokenizer).
        peft_model_path: Hub id or local path of the PEFT adapter.
        output_dir: Local directory to save into, or the Hub repo id when
            ``push_to_hub`` is true.
        device: ``'auto'`` to pass ``device_map='auto'`` to the loaders;
            any other value is used as the single device for the whole
            model (e.g. ``'cpu'``, ``'cuda:0'``).
        push_to_hub: If true, push to the Hub repo named ``output_dir``
            instead of saving locally.

    Raises:
        ValueError: If ``output_dir`` is empty or None.  This is checked
            before any (potentially very large) model is loaded.
    """
    # Fail fast: otherwise a missing destination is only noticed after the
    # whole base model has been loaded and merged.
    if not output_dir:
        raise ValueError("output_dir must be a non-empty path or Hub repo id")

    if device == 'auto':
        device_arg = {'device_map': 'auto'}
    else:
        # The "" key maps the entire model (root module) onto one device.
        device_arg = {'device_map': {"": device}}

    logger.info("Loading base model: %s", base_model_name_or_path)
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name_or_path,
        return_dict=True,
        torch_dtype=torch.float16,
        **device_arg,
    )

    logger.info("Loading PEFT: %s", peft_model_path)
    model = PeftModel.from_pretrained(base_model, peft_model_path, torch_dtype=torch.float16, **device_arg)

    logger.info("Running merge_and_unload")
    model = model.merge_and_unload()

    tokenizer = AutoTokenizer.from_pretrained(base_model_name_or_path)

    if push_to_hub:
        logger.info("Saving to hub ...")
        # str() keeps the original f-string's coercion (e.g. for Path objects).
        repo_id = str(output_dir)
        model.push_to_hub(repo_id, use_temp_dir=False)
        tokenizer.push_to_hub(repo_id, use_temp_dir=False)
        logger.info("Model pushed to hub repo %s", repo_id)
    else:
        model.save_pretrained(output_dir)
        # Tokenizers have no dtype; the former torch_dtype kwarg was a no-op.
        tokenizer.save_pretrained(output_dir)
        logger.info("Model saved to %s", output_dir)
if __name__ == "__main__":
    # Rebinds the module-level `logger` to the root logger, so messages from
    # merge_lora are reported under the name "root" in the format below.
    logger = logging.getLogger()
    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
        level=logging.INFO,
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    parser = argparse.ArgumentParser(description="Merge a PEFT/LoRA adapter into its base model.")
    # The three paths are mandatory: without them merge_lora would receive None
    # and fail deep inside transformers with an unhelpful error.
    parser.add_argument("--base_model_name_or_path", type=str, required=True,
                        help="Hub id or local path of the base model (also used for the tokenizer).")
    parser.add_argument("--peft_model_path", type=str, required=True,
                        help="Hub id or local path of the PEFT adapter to merge.")
    parser.add_argument("--output_dir", type=str, required=True,
                        help="Local output directory, or Hub repo id when --push_to_hub is set.")
    parser.add_argument("--device", type=str, default="auto",
                        help="'auto' for device_map='auto', otherwise a single device such as 'cpu' or 'cuda:0'.")
    parser.add_argument("--push_to_hub", action="store_true",
                        help="Push the merged model and tokenizer to the Hub repo given by --output_dir.")
    args = parser.parse_args()

    merge_lora(
        base_model_name_or_path=args.base_model_name_or_path,
        peft_model_path=args.peft_model_path,
        output_dir=args.output_dir,
        device=args.device,
        push_to_hub=args.push_to_hub,
    )