File size: 2,121 Bytes
2e3da55 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch
import logging
import argparse
logger = logging.getLogger(__name__)


def merge_lora(
    base_model_name_or_path,
    peft_model_path,
    output_dir,
    device="auto",
    push_to_hub=False,
    *,
    torch_dtype=torch.float16,
):
    """Merge a PEFT adapter into its base causal-LM and save or push the result.

    The base model and the adapter are loaded with ``torch_dtype``, the adapter
    weights are folded into the base weights via ``merge_and_unload()``, and the
    merged model together with the base model's tokenizer is either saved to
    ``output_dir`` or pushed to the Hub repo named ``output_dir``.

    Args:
        base_model_name_or_path: Hub id or local path of the base model. The
            tokenizer is loaded from the same location.
        peft_model_path: Hub id or local path of the PEFT adapter.
        output_dir: Local directory to save into, or the Hub repo id when
            ``push_to_hub`` is True.
        device: ``"auto"`` to load with ``device_map="auto"``; any other value
            is used as a single device that the whole model is mapped to
            (e.g. ``"cpu"`` or ``"cuda:0"``).
        push_to_hub: If True, push to the Hub instead of saving locally.
        torch_dtype: dtype used when loading the base model and the adapter.
            Defaults to ``torch.float16``, the value previously hard-coded.
    """
    if device == "auto":
        device_arg = {"device_map": "auto"}
    else:
        # The "" key maps the entire model onto the single given device.
        device_arg = {"device_map": {"": device}}

    logger.info("Loading base model: %s", base_model_name_or_path)
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name_or_path,
        return_dict=True,
        torch_dtype=torch_dtype,
        **device_arg,
    )

    logger.info("Loading PEFT: %s", peft_model_path)
    model = PeftModel.from_pretrained(
        base_model, peft_model_path, torch_dtype=torch_dtype, **device_arg
    )

    logger.info("Running merge_and_unload")
    # Replaces the PeftModel wrapper with the base model whose weights now
    # include the adapter deltas.
    model = model.merge_and_unload()

    tokenizer = AutoTokenizer.from_pretrained(base_model_name_or_path)

    if push_to_hub:
        logger.info("Saving to hub ...")
        model.push_to_hub(output_dir, use_temp_dir=False)
        tokenizer.push_to_hub(output_dir, use_temp_dir=False)
        logger.info("Model pushed to hub: %s", output_dir)
    else:
        model.save_pretrained(output_dir)
        # Tokenizers hold no tensors, so no dtype argument is passed here.
        tokenizer.save_pretrained(output_dir)
        logger.info("Model saved to %s", output_dir)
if __name__ == "__main__":
    # Configure the root handler only. The module-level `logger` used by
    # merge_lora propagates to it, so it must not be rebound to the root logger.
    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
        level=logging.INFO,
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    parser = argparse.ArgumentParser(
        description="Merge a PEFT/LoRA adapter into its base model and save or push the result."
    )
    # These three are mandatory: without them merge_lora would receive None
    # and fail later with an obscure loading error.
    parser.add_argument(
        "--base_model_name_or_path",
        type=str,
        required=True,
        help="Hub id or local path of the base model (also used for the tokenizer).",
    )
    parser.add_argument(
        "--peft_model_path",
        type=str,
        required=True,
        help="Hub id or local path of the PEFT adapter.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Local output directory, or the Hub repo id when --push_to_hub is set.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="auto",
        help='"auto" to use device_map="auto", or a single device such as "cpu" or "cuda:0".',
    )
    parser.add_argument(
        "--push_to_hub",
        action="store_true",
        help="Push the merged model and tokenizer to the Hub instead of saving locally.",
    )
    args = parser.parse_args()

    merge_lora(
        base_model_name_or_path=args.base_model_name_or_path,
        peft_model_path=args.peft_model_path,
        output_dir=args.output_dir,
        device=args.device,
        push_to_hub=args.push_to_hub,
    )
|