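"""
Merge a LoRA checkpoint into a base model checkpoint and save the result as safetensors.

Example invocation (the script name and paths below are illustrative placeholders,
not taken from the original source):

    python merge_lora.py \
        --initial_model_ckpt /path/to/consolidated.safetensors \
        --lora_ckpt /path/to/lora.safetensors \
        --dump_ckpt /path/to/merged.safetensors \
        --scaling 2.0
"""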
import argparse
from typing import Dict, Optional
import safetensors.torch
import torch
import tqdm


def merge_checkpoints(
    model_checkpoint: Dict[str, torch.Tensor],
    lora_checkpoint: Dict[str, torch.Tensor],
    scaling: float,
    save_dtype: Optional[torch.dtype] = None,
):
    save_dtype = save_dtype or next(iter(lora_checkpoint.values())).dtype
    print(f"Merging to {save_dtype} precision...")

    # Norm layers are copied over as-is; each lora_A/lora_B pair is merged into the
    # corresponding base weight.
    keys_to_update = [
        key for key in lora_checkpoint.keys() if "norm" in key or "lora_A" in key
    ]
    assert any(
        "lora_A" in k or "lora_B" in k for k in keys_to_update
    ), "No `lora` keys found in your checkpoint. Check that `lora_ckpt` is correct."

    for key in tqdm.tqdm(keys_to_update):
        if "norm" in key:
            model_checkpoint[key] = lora_checkpoint[key].to("cpu")
        else:
            weight_name = key.replace("lora_A.weight", "weight")

            lora_A_weight = lora_checkpoint[key].to("cuda")
            lora_B_weight = lora_checkpoint[key.replace("lora_A", "lora_B")].to("cuda")

            # W_merged = W + scaling * (B @ A)
            weight = lora_B_weight.mm(lora_A_weight) * scaling
            weight += model_checkpoint[weight_name].to("cuda")
            weight = weight.to(save_dtype)

            model_checkpoint[weight_name] = weight.to("cpu")

    # cast all tensors to save dtype
    for key in tqdm.tqdm(model_checkpoint.keys()):
        model_checkpoint[key] = model_checkpoint[key].to(save_dtype)


def load(filename: str):
    if filename.endswith(".safetensors"):
        return safetensors.torch.load_file(filename)
    else:
        return torch.load(filename)


def main(args):
    model_checkpoint = load(args.initial_model_ckpt)
    lora_checkpoint = load(args.lora_ckpt)

    merge_checkpoints(model_checkpoint, lora_checkpoint, args.scaling)

    safetensors.torch.save_file(model_checkpoint, args.dump_ckpt)
    print(f"Merged checkpoint saved to {args.dump_ckpt}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Merge a LoRA checkpoint into a model checkpoint."
    )
    parser.add_argument(
        "--initial_model_ckpt",
        type=str,
        required=True,
        help="Path to the model checkpoint.",
    )
    parser.add_argument(
        "--lora_ckpt", type=str, required=True, help="Path to the LoRA checkpoint."
    )
    parser.add_argument(
        "--dump_ckpt",
        type=str,
        required=True,
        help="Path to save the merged checkpoint.",
    )
    parser.add_argument(
        "--scaling",
        type=float,
        default=2.0,
        help="Scaling factor for the LoRA checkpoint. Default is 2.0.",
    )
    args = parser.parse_args()
    main(args)