File size: 2,120 Bytes
d76abce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import argparse
import os

import huggingface_hub
import torch
from vqmodel.configuration_vqmodel import VQModelConfig
from vqmodel.image_processing_vqmodel import VQModelImageProcessor
from vqmodel.modeling_vqmodel import VQModel

# Register each custom vqmodel class with its default transformers Auto*
# class (in this order), so the pushed repository can be loaded through the
# Auto* API with custom code.
for _auto_cls in (VQModelConfig, VQModel, VQModelImageProcessor):
    _auto_cls.register_for_auto_class()
del _auto_cls  # keep the loop variable out of the module namespace


def main():
    """Convert a local VQ checkpoint and push it to the Hugging Face Hub.

    Reads ``--repo_id``, ``--yaml_path`` and ``--ckpt_path`` from the command
    line, builds a ``VQModel`` from the YAML config, loads the checkpoint
    weights into it, derives a matching ``VQModelImageProcessor`` and pushes
    both to ``--repo_id`` with ``private=True``. The YAML file, this script
    and ``requirements.txt`` are then uploaded to the same repository.
    """
    args = parse_args()
    config = VQModelConfig(yaml_path=args.yaml_path)
    model = VQModel(config)
    load_model_weights(model, args.ckpt_path)

    image_processor = _build_image_processor(model)

    # Rewrite the config so it refers to the YAML copy uploaded below as
    # "config.yaml" at the repo root, rather than the local --yaml_path.
    model.config.repo_id = args.repo_id
    model.config.yaml_path = "config.yaml"

    model.push_to_hub(args.repo_id, private=True)
    image_processor.push_to_hub(args.repo_id, private=True)
    _upload_extra_files(args.repo_id, args.yaml_path)


def _build_image_processor(model):
    """Create an image processor from the model's ``ddconfig`` settings.

    Args:
        model: The ``VQModel`` whose ``vq_cfg.model.params.ddconfig`` provides
            ``resolution`` and ``in_channels``.

    Returns:
        A ``VQModelImageProcessor`` sized to ``ddconfig.resolution``.
    """
    ddconfig = model.vq_cfg.model.params.ddconfig
    return VQModelImageProcessor(
        size=ddconfig.resolution,
        # Request RGB conversion only when the model takes 3 input channels.
        convert_rgb=ddconfig.in_channels == 3,
    )


def _upload_extra_files(repo_id, yaml_path):
    """Upload the YAML config, this script and requirements.txt to ``repo_id``.

    Each file is uploaded with a separate ``HfApi.upload_file`` call, in this
    order: ``yaml_path`` -> ``config.yaml``, this script ->
    ``push_to_hub.py``, ``requirements.txt`` -> ``requirements.txt``.
    """
    api = huggingface_hub.HfApi()
    script_dir = os.path.dirname(os.path.abspath(__file__))
    uploads = (
        (yaml_path, "config.yaml"),
        (__file__, "push_to_hub.py"),
        # Resolve next to this script (like __file__ above) instead of the
        # current working directory, so the script works from any directory.
        (os.path.join(script_dir, "requirements.txt"), "requirements.txt"),
    )
    for local_path, path_in_repo in uploads:
        api.upload_file(
            path_or_fileobj=local_path,
            path_in_repo=path_in_repo,
            repo_id=repo_id,
        )


def parse_args():
    """Parse the command-line options of this script from ``sys.argv``.

    Returns:
        An ``argparse.Namespace`` with the string attributes ``repo_id``,
        ``yaml_path`` and ``ckpt_path``. All three options are required;
        argparse exits the process if any is missing.
    """
    cli = argparse.ArgumentParser()
    option_specs = (
        ("--repo_id", "Repository ID"),
        ("--yaml_path", "Path to YAML file"),
        ("--ckpt_path", "Path to checkpoint file"),
    )
    for flag, help_text in option_specs:
        cli.add_argument(flag, type=str, required=True, help=help_text)
    return cli.parse_args()


def load_model_weights(model, ckpt_path):
    """Load checkpoint weights into ``model.model``, skipping ``loss.*`` keys.

    Args:
        model: Object whose ``.model`` attribute receives the weights via
            ``load_state_dict``.
        ckpt_path: Path to a ``torch.save`` file whose top-level mapping
            holds the weights under ``"state_dict"``.

    Entries whose key starts with ``"loss."`` are removed before loading
    (presumably training-only loss-module parameters — TODO confirm). The
    rest must match ``model.model`` exactly, because ``load_state_dict`` is
    called with its default ``strict=True``.
    """
    # NOTE(review): torch.load is pickle-based; only load trusted checkpoints.
    # Tensors are mapped to CPU regardless of the device they were saved on.
    state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]

    # Prune in place rather than building a new dict, so any metadata
    # attached to the loaded state-dict object is kept for load_state_dict.
    loss_keys = [name for name in state_dict if name.startswith("loss.")]
    for name in loss_keys:
        del state_dict[name]

    model.model.load_state_dict(state_dict)


# Run the conversion and upload only when executed as a script.
if __name__ == "__main__":
    main()