|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
from peft import PeftModel |
|
import torch |
|
import logging |
|
import argparse |
|
|
|
# Module-level logger; the script entry point configures handlers via basicConfig.
logger = logging.getLogger(__name__)
|
|
|
def merge_lora(base_model_name_or_path, peft_model_path, output_dir, device='auto', push_to_hub=False):
    """Merge a PEFT (LoRA) adapter into its base causal LM and save or upload the result.

    The base model is loaded in float16, the adapter is applied on top of it,
    and PEFT's ``merge_and_unload()`` folds the adapter into the base weights.
    The tokenizer is always taken from the base model, not from the adapter
    directory.

    Args:
        base_model_name_or_path: Hub id or local path of the base model.
        peft_model_path: Hub id or local path of the PEFT adapter.
        output_dir: Local directory to save into or, when ``push_to_hub`` is
            true, the hub repository id to push to.
        device: ``"auto"`` to let ``device_map="auto"`` place the model.
            Any other value pins the whole model to that single device
            (``device_map={"": device}``).
        push_to_hub: Push model and tokenizer to the hub instead of saving
            them locally.

    Raises:
        ValueError: If ``output_dir`` is empty or ``None``. This is checked
            before any model is loaded so the mistake surfaces immediately.
    """
    # Validate before the (potentially very large) model download/load. Without
    # this, a missing output_dir only fails after loading, or in the hub path it
    # is formatted into the literal repo id "None".
    if not output_dir:
        raise ValueError("output_dir must be a non-empty local path or hub repo id")

    if device == 'auto':
        device_arg = {'device_map': 'auto'}
    else:
        # The empty-string key maps the entire model onto one device.
        device_arg = {'device_map': {"": device}}

    logger.info("Loading base model: %s", base_model_name_or_path)
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name_or_path,
        return_dict=True,
        torch_dtype=torch.float16,
        **device_arg,
    )

    logger.info("Loading PEFT: %s", peft_model_path)
    model = PeftModel.from_pretrained(base_model, peft_model_path, torch_dtype=torch.float16, **device_arg)

    logger.info("Running merge_and_unload")
    model = model.merge_and_unload()

    tokenizer = AutoTokenizer.from_pretrained(base_model_name_or_path)

    if push_to_hub:
        # Hub APIs expect a string repo id; accept path-like values too.
        repo_id = str(output_dir)
        logger.info("Saving to hub ...")
        model.push_to_hub(repo_id, use_temp_dir=False)
        tokenizer.push_to_hub(repo_id, use_temp_dir=False)
        logger.info("Model pushed to hub repo %s", repo_id)
    else:
        model.save_pretrained(output_dir)
        tokenizer.save_pretrained(output_dir)
        logger.info("Model saved to %s", output_dir)
|
|
|
if __name__ == "__main__" : |
|
logger = logging.getLogger() |
|
logging.basicConfig( |
|
format="%(asctime)s %(levelname)s [%(name)s] %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S" |
|
) |
|
|
|
parser = argparse.ArgumentParser() |
|
parser.add_argument("--base_model_name_or_path", type=str) |
|
parser.add_argument("--peft_model_path", type=str) |
|
parser.add_argument("--output_dir", type=str) |
|
parser.add_argument("--device", type=str, default="auto") |
|
parser.add_argument("--push_to_hub", action="store_true") |
|
|
|
args = parser.parse_args() |
|
|
|
merge_lora(base_model_name_or_path = args.base_model_name_or_path, peft_model_path = args.peft_model_path, |
|
output_dir = args.output_dir, device = args.device, push_to_hub = args.push_to_hub) |
|
|