import argparse
from math import ceil
from pathlib import Path

import torch

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt_path", "-c", type=str, required=True)
    args = parser.parse_args()

    # Mapping from original weight-name fragments to their replacements.
    # A value of None means any key containing that fragment is dropped.
    weight_name_map = {
        "model.encodec_embeddings": None,
        "encodec_embeddings": "embed_encodec",
        "encodec_mlm_head": "mcm_heads",
    }

    ckpt_path = Path(args.ckpt_path)
    weight_file = ckpt_path / "pytorch_model.bin"
    state_dict = torch.load(weight_file, map_location="cpu")

    # Rename (or drop) keys according to weight_name_map.
    new_state_dict = {}
    for key in state_dict:
        new_key = key
        for orig, repl in weight_name_map.items():
            if repl is None:
                if orig in new_key:
                    new_key = None
                    break
                continue
            new_key = new_key.replace(orig, repl)
        if new_key:
            new_state_dict[new_key] = state_dict[key]

    # Pad each encodec embedding table so its first dimension is a multiple
    # of 64: the new tensor is filled with N(0, 1) noise and the original
    # weights are copied into the leading rows.
    for key in new_state_dict:
        if "model.encoder.embed_encodec" in key:
            dim = new_state_dict[key].shape[0]
            new_weight = torch.normal(
                0, 1, (ceil(dim / 64) * 64, new_state_dict[key].shape[1])
            )
            new_weight[:dim] = new_state_dict[key]
            new_state_dict[key] = new_weight

    # Back up the original weight file, then write the converted state dict
    # in its place.
    weight_file.rename(weight_file.with_suffix(".bin.bak"))
    torch.save(new_state_dict, weight_file)
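
# A minimal usage sketch, assuming this script is saved as convert_checkpoint.py
# (the file name is an assumption) and that --ckpt_path points at a directory
# containing pytorch_model.bin:
#
#     python convert_checkpoint.py --ckpt_path /path/to/checkpoint_dir
#
# Afterwards the remapped keys can be spot-checked from a Python shell:
#
#     import torch
#     sd = torch.load("/path/to/checkpoint_dir/pytorch_model.bin", map_location="cpu")
#     print([k for k in sd if "embed_encodec" in k or "mcm_heads" in k])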