File size: 2,120 Bytes
d76abce |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 |
import argparse
import huggingface_hub
import torch
from vqmodel.configuration_vqmodel import VQModelConfig
from vqmodel.image_processing_vqmodel import VQModelImageProcessor
from vqmodel.modeling_vqmodel import VQModel
# Register the three custom classes at import time, before main() runs.
# NOTE(review): presumably this is the transformers auto-class hook that makes
# push_to_hub ship the custom code with the weights -- confirm against the
# base classes used in the vqmodel package.
VQModelConfig.register_for_auto_class()
VQModel.register_for_auto_class()
VQModelImageProcessor.register_for_auto_class()
def main():
    """Build the VQ model from the CLI inputs and publish it to the Hub.

    Steps, in order: parse the command line, construct the model from the
    YAML file, load the checkpoint weights, derive an image processor from
    the model's ``ddconfig``, push model and processor to ``--repo_id``, and
    finally upload the YAML, this script, and ``requirements.txt``.
    """
    cli_args = parse_args()
    target_repo = cli_args.repo_id

    # Build the model from its YAML description, then fill in the weights.
    model = VQModel(VQModelConfig(yaml_path=cli_args.yaml_path))
    load_model_weights(model, cli_args.ckpt_path)

    # Image processor settings are read from the model's ddconfig.
    dd = model.vq_cfg.model.params.ddconfig
    processor = VQModelImageProcessor(
        size=dd.resolution,
        convert_rgb=dd.in_channels == 3,
    )

    # Point the stored config at the in-repo YAML copy ("config.yaml"),
    # which is uploaded further down.
    model.config.repo_id = target_repo
    model.config.yaml_path = "config.yaml"

    # Push to hub
    model.push_to_hub(target_repo, private=True)
    processor.push_to_hub(target_repo, private=True)

    # (local source, name inside the repo), uploaded in this order.
    # "requirements.txt" is a relative path, so it is looked up in the
    # current working directory.
    companion_files = (
        (cli_args.yaml_path, "config.yaml"),
        (__file__, "push_to_hub.py"),
        ("requirements.txt", "requirements.txt"),
    )
    hub = huggingface_hub.HfApi()
    for local_source, name_in_repo in companion_files:
        hub.upload_file(
            path_or_fileobj=local_source,
            path_in_repo=name_in_repo,
            repo_id=target_repo,
        )
def parse_args():
    """Parse the command line and return the resulting namespace.

    All three options are required strings:
    ``--repo_id``, ``--yaml_path`` and ``--ckpt_path``.
    """
    # (flag, help text) pairs, registered in this order.
    required_options = (
        ("--repo_id", "Repository ID"),
        ("--yaml_path", "Path to YAML file"),
        ("--ckpt_path", "Path to checkpoint file"),
    )
    cli = argparse.ArgumentParser()
    for flag, description in required_options:
        cli.add_argument(flag, type=str, required=True, help=description)
    return cli.parse_args()
def load_model_weights(model, ckpt_path):
    """Load checkpoint weights from ``ckpt_path`` into ``model.model``.

    Args:
        model: Object whose ``model`` attribute provides ``load_state_dict``.
        ckpt_path: Path to a file readable by ``torch.load``. The file may
            hold either a mapping with the weights under ``"state_dict"``
            or a bare state dict.

    Entries whose key starts with ``"loss."`` are removed before loading.
    """
    # NOTE: torch.load unpickles the file -- only use checkpoints you trust.
    ckpt = torch.load(ckpt_path, map_location="cpu")

    # Accept a bare state dict as well as one wrapped under "state_dict",
    # instead of failing with KeyError on the former.
    state_dict = ckpt["state_dict"] if "state_dict" in ckpt else ckpt

    # Remove loss related states. Delete in place (rather than building a new
    # dict) so the original mapping object is what reaches load_state_dict.
    for key in [k for k in state_dict if k.startswith("loss.")]:
        del state_dict[key]

    model.model.load_state_dict(state_dict)
# Run the upload only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|