File size: 1,393 Bytes
73baeae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import argparse
from math import ceil
from pathlib import Path

import torch

# Substring rewrites applied to every checkpoint key, in insertion order.
# A ``None`` replacement means "drop any key containing this substring". The
# drop rule is listed first on purpose: if the shorter "encodec_embeddings"
# rewrite ran first, "model.encodec_embeddings" would no longer match.
WEIGHT_NAME_MAP = {
    "model.encodec_embeddings": None,
    "encodec_embeddings": "embed_encodec",
    "encodec_mlm_head": "mcm_heads",
}

# Tensors whose (renamed) key contains this substring get their first
# dimension padded up to a multiple of ROW_MULTIPLE.
PADDED_KEY_MARKER = "model.encoder.embed_encodec"
ROW_MULTIPLE = 64


def _rename_key(key: str, weight_name_map: dict):
    """Return *key* rewritten through *weight_name_map*, or None to drop it.

    Rules are applied in the map's iteration order. A rule whose replacement
    is ``None`` drops the key as soon as its substring is found; every other
    rule is a plain ``str.replace``.
    """
    for old, new in weight_name_map.items():
        if new is None:
            if old in key:
                return None
            continue
        key = key.replace(old, new)
    return key


def _pad_rows(weight: torch.Tensor, multiple: int = ROW_MULTIPLE) -> torch.Tensor:
    """Pad dimension 0 of *weight* up to the next multiple of *multiple*.

    The added rows are drawn from N(0, 1); the existing rows are copied
    unchanged. The result keeps the dtype of *weight*. A tensor whose first
    dimension is already a multiple is returned as-is.
    """
    rows = weight.shape[0]
    padded_rows = ceil(rows / multiple) * multiple
    if padded_rows == rows:
        return weight
    # Sample in the default dtype, then cast, so half-precision checkpoints
    # are not silently promoted to float32 by the padding step.
    padded = torch.normal(0.0, 1.0, (padded_rows, *weight.shape[1:])).to(
        weight.dtype
    )
    padded[:rows] = weight
    return padded


def convert_state_dict(state_dict: dict, weight_name_map: dict = None) -> dict:
    """Return a new state dict with keys renamed/dropped and embeddings padded.

    Args:
        state_dict: Mapping from parameter name to tensor.
        weight_name_map: Substring rewrite rules (see ``WEIGHT_NAME_MAP``,
            which is used when this is ``None``).

    Returns:
        A new dict; *state_dict* itself is not modified.
    """
    if weight_name_map is None:
        weight_name_map = WEIGHT_NAME_MAP
    converted = {}
    for key, tensor in state_dict.items():
        new_key = _rename_key(key, weight_name_map)
        if new_key is None:
            continue
        if PADDED_KEY_MARKER in new_key:
            tensor = _pad_rows(tensor)
        converted[new_key] = tensor
    return converted


def main(argv=None) -> None:
    """Convert ``<ckpt_path>/pytorch_model.bin`` in place, keeping a backup.

    The original file is renamed to ``pytorch_model.bin.bak`` and the
    converted state dict is written to ``pytorch_model.bin``.

    Args:
        argv: Command-line arguments; ``None`` means ``sys.argv[1:]``.

    Raises:
        FileExistsError: If the backup file already exists. Overwriting it
            would lose the original weights (e.g. when the script is run a
            second time on the same directory).
    """
    parser = argparse.ArgumentParser(
        description="Rename checkpoint weights and pad the encodec embeddings."
    )
    parser.add_argument(
        "--ckpt_path",
        "-c",
        type=str,
        required=True,
        help="Directory containing pytorch_model.bin",
    )
    args = parser.parse_args(argv)

    weight_file = Path(args.ckpt_path) / "pytorch_model.bin"
    backup_file = weight_file.with_suffix(".bin.bak")
    if backup_file.exists():
        raise FileExistsError(
            f"Backup {backup_file} already exists; refusing to overwrite it."
        )

    # NOTE: torch.load unpickles the file; only run this on trusted checkpoints.
    state_dict = torch.load(weight_file, map_location="cpu")
    new_state_dict = convert_state_dict(state_dict)

    # Move the original aside before writing so it survives a failed save.
    weight_file.rename(backup_file)
    torch.save(new_state_dict, weight_file)


if __name__ == "__main__":
    main()