File size: 2,121 Bytes
2e3da55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch
import logging
import argparse

# Module-level logger used by merge_lora. Its messages only appear once the
# application configures logging (the command-line path calls logging.basicConfig).
logger = logging.getLogger(__name__)

def merge_lora(base_model_name_or_path, peft_model_path, output_dir, device='auto', push_to_hub=False,
               torch_dtype=torch.float16):
    """Merge a PEFT adapter into its base causal-LM and save or upload the result.

    Args:
        base_model_name_or_path: Hub id or local path of the base model. The
            tokenizer written alongside the merged weights is also loaded from here.
        peft_model_path: Hub id or local path of the PEFT adapter to merge.
        output_dir: Local directory to save into, or the Hub repo id when
            ``push_to_hub`` is true.
        device: ``'auto'`` to load with ``device_map='auto'``; any other value is
            used as the single device for the whole model (e.g. ``'cpu'``, ``'cuda:0'``).
        push_to_hub: Upload model and tokenizer to the Hub instead of saving locally.
        torch_dtype: dtype used when loading the base model and the adapter.
            Defaults to ``torch.float16``, the value previously hard-coded.

    Raises:
        ValueError: If any of the three required locations is empty or None.
            This is checked up front so a bad argument fails before the slow model load.
    """
    for arg_name, value in (
        ("base_model_name_or_path", base_model_name_or_path),
        ("peft_model_path", peft_model_path),
        ("output_dir", output_dir),
    ):
        if not value:
            raise ValueError(f"{arg_name} must be a non-empty path or repo id, got {value!r}")

    # A device_map of {"": device} places the entire model on that one device.
    device_arg = {'device_map': 'auto' if device == 'auto' else {"": device}}

    logger.info("Loading base model: %s", base_model_name_or_path)
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name_or_path,
        return_dict=True,
        torch_dtype=torch_dtype,
        **device_arg,
    )

    logger.info("Loading PEFT: %s", peft_model_path)
    model = PeftModel.from_pretrained(base_model, peft_model_path, torch_dtype=torch_dtype, **device_arg)

    logger.info("Running merge_and_unload")
    # Fold the adapter weights into the base layers and drop the PEFT wrapper.
    model = model.merge_and_unload()

    tokenizer = AutoTokenizer.from_pretrained(base_model_name_or_path)

    if push_to_hub:
        repo_id = str(output_dir)
        logger.info("Saving to hub ...")
        model.push_to_hub(repo_id, use_temp_dir=False)
        tokenizer.push_to_hub(repo_id, use_temp_dir=False)
        logger.info("Model pushed to hub repo %s", repo_id)
    else:
        model.save_pretrained(output_dir)
        # Tokenizers have no dtype. The torch_dtype kwarg previously passed here did nothing.
        tokenizer.save_pretrained(output_dir)
        logger.info("Model saved to %s", output_dir)

def build_arg_parser():
    """Return the command-line parser for the LoRA merge script."""
    parser = argparse.ArgumentParser(
        description="Merge a PEFT/LoRA adapter into its base model and save or push the result."
    )
    # These three locations are mandatory. Without required=True, a missing flag
    # passed None into merge_lora and failed late with an unhelpful error.
    parser.add_argument("--base_model_name_or_path", type=str, required=True,
                        help="Hub id or local path of the base model (also the tokenizer source).")
    parser.add_argument("--peft_model_path", type=str, required=True,
                        help="Hub id or local path of the PEFT adapter to merge.")
    parser.add_argument("--output_dir", type=str, required=True,
                        help="Local output directory, or the Hub repo id when --push_to_hub is set.")
    parser.add_argument("--device", type=str, default="auto",
                        help="'auto' for automatic placement, or one device such as 'cpu' or 'cuda:0'.")
    parser.add_argument("--push_to_hub", action="store_true",
                        help="Upload the merged model and tokenizer to the Hub instead of saving locally.")
    return parser


def main(argv=None):
    """Configure logging, parse *argv* (default: sys.argv[1:]) and run merge_lora."""
    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(name)s] %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S"
    )
    args = build_arg_parser().parse_args(argv)
    merge_lora(
        base_model_name_or_path=args.base_model_name_or_path,
        peft_model_path=args.peft_model_path,
        output_dir=args.output_dir,
        device=args.device,
        push_to_hub=args.push_to_hub,
    )


if __name__ == "__main__":
    main()