File size: 2,691 Bytes
cb9e677
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import argparse
from typing import Dict, Optional, Union

import safetensors.torch
import torch
import tqdm


def merge_checkpoints(
    model_checkpoint: Dict[str, torch.Tensor],
    lora_checkpoint: Dict[str, torch.Tensor],
    scaling: float,
    save_dtype: Optional[torch.dtype] = None,
    device: Optional[Union[str, torch.device]] = None,
):
    """Merge the weights of ``lora_checkpoint`` into ``model_checkpoint`` in place.

    For every key of ``lora_checkpoint`` containing ``"lora_A"``, the matching
    base weight (the key with ``"lora_A.weight"`` replaced by ``"weight"``) is
    replaced by ``base + scaling * (lora_B @ lora_A)``, where ``lora_B`` is the
    key with ``"lora_A"`` replaced by ``"lora_B"``. Keys containing ``"norm"``
    are copied from ``lora_checkpoint`` verbatim. Finally every tensor in
    ``model_checkpoint`` is cast to ``save_dtype``.

    Args:
        model_checkpoint: Base state dict; mutated in place.
        lora_checkpoint: State dict holding ``lora_A``/``lora_B`` pairs and,
            optionally, ``norm`` weights.
        scaling: Multiplier applied to ``lora_B @ lora_A`` before it is added.
        save_dtype: Dtype of all output tensors. Defaults to the dtype of the
            first tensor in ``lora_checkpoint``.
        device: Device on which the matmuls run. Defaults to ``"cuda"`` when
            CUDA is available, otherwise ``"cpu"``.

    Raises:
        AssertionError: if ``lora_checkpoint`` contains no LoRA keys.
        KeyError: if a ``lora_A`` key has no matching ``lora_B`` key or no
            matching base weight in ``model_checkpoint``.
    """
    keys_to_update = [
        key for key in lora_checkpoint if "norm" in key or "lora_A" in key
    ]
    # Explicit raise rather than ``assert`` so the check survives ``python -O``.
    # It also runs before the dtype lookup below, so an empty LoRA dict
    # reports this message instead of an opaque StopIteration.
    if not any("lora_A" in k or "lora_B" in k for k in keys_to_update):
        raise AssertionError(
            "No `lora` keys found in your checkpoint. Check that `lora_ckpt` is correct."
        )

    if save_dtype is None:
        save_dtype = next(iter(lora_checkpoint.values())).dtype
    if device is None:
        # Previously hard-coded to "cuda", which crashed on CPU-only hosts.
        device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Merging to {save_dtype} precision...")

    for key in tqdm.tqdm(keys_to_update):
        if "norm" in key:
            # Norm weights are stored in full in the LoRA checkpoint; take them as-is.
            model_checkpoint[key] = lora_checkpoint[key].to("cpu")
            continue

        weight_name = key.replace("lora_A.weight", "weight")
        lora_b_name = key.replace("lora_A", "lora_B")
        if lora_b_name not in lora_checkpoint:
            raise KeyError(f"`{key}` has no matching `{lora_b_name}` in the LoRA checkpoint.")
        if weight_name not in model_checkpoint:
            raise KeyError(f"`{key}` has no matching `{weight_name}` in the model checkpoint.")

        # Accumulate in float32: computing B @ A and the add in a low-precision
        # dtype (e.g. bf16) rounds twice and can silently narrow an fp32 base.
        lora_a = lora_checkpoint[key].to(device=device, dtype=torch.float32)
        lora_b = lora_checkpoint[lora_b_name].to(device=device, dtype=torch.float32)
        base = model_checkpoint[weight_name].to(device=device, dtype=torch.float32)

        merged = base + scaling * lora_b.mm(lora_a)
        model_checkpoint[weight_name] = merged.to(save_dtype).to("cpu")

    # Cast every tensor (including untouched ones) to the save dtype.
    # Only existing keys are reassigned, so iterating the dict here is safe.
    for key in tqdm.tqdm(model_checkpoint):
        model_checkpoint[key] = model_checkpoint[key].to(save_dtype)


def load(filename: str) -> Dict[str, torch.Tensor]:
    """Load a checkpoint from ``filename`` onto the CPU.

    Files ending in ``.safetensors`` are read with
    ``safetensors.torch.load_file`` (CPU is its default device). Any other
    file is read with ``torch.load``.
    """
    if filename.endswith(".safetensors"):
        return safetensors.torch.load_file(filename)
    # map_location="cpu": without it, tensors saved from a GPU are restored
    # onto their original CUDA device. That fails on CPU-only machines and
    # keeps both full checkpoints resident in GPU memory.
    # NOTE(review): torch.load unpickles the file; only load trusted
    # checkpoints (consider weights_only=True on torch >= 1.13).
    return torch.load(filename, map_location="cpu")


def main(args):
    """Merge the LoRA checkpoint into the base checkpoint and save it as safetensors.

    ``args`` must provide ``initial_model_ckpt``, ``lora_ckpt``, ``scaling``
    and ``dump_ckpt`` (see the argparse setup at the bottom of the file).
    """
    # Loaded in this order: base model first, then the LoRA adapter.
    base_state, adapter_state = load(args.initial_model_ckpt), load(args.lora_ckpt)

    # merge_checkpoints updates base_state in place.
    merge_checkpoints(base_state, adapter_state, args.scaling)

    output_path = args.dump_ckpt
    safetensors.torch.save_file(base_state, output_path)
    print(f"Merged checkpoint saved to {output_path}")


if __name__ == "__main__":
    cli = argparse.ArgumentParser(
        description="Merge a LoRA checkpoint into a model checkpoint."
    )
    # The three file paths (two inputs, one output) are mandatory strings.
    for flag, help_text in (
        ("--initial_model_ckpt", "Path to the model checkpoint."),
        ("--lora_ckpt", "Path to the LoRA checkpoint."),
        ("--dump_ckpt", "Path to save the merged checkpoint."),
    ):
        cli.add_argument(flag, type=str, required=True, help=help_text)
    # Multiplier applied to the LoRA update before it is added to the base weights.
    cli.add_argument(
        "--scaling",
        type=float,
        default=2.0,
        help="Scaling factor for the LoRA checkpoint. Default is 2.0.",
    )

    main(cli.parse_args())