"""Upload script for the SAG-ViT model.

Registers the SAG-ViT config/model classes with the ``transformers`` Auto*
factories, builds a model from a default ``SAGViTConfig``, loads weights from
the local ``SAG-ViT.pth`` checkpoint, saves the model into the current
directory and pushes it to the ``shravvvv/SAG-ViT`` Hub repository.
"""

import torch
from transformers import AutoConfig, AutoModel

from modeling_sagvit import SAGViTClassifier, SAGViTConfig

# Make the "sagvit" model type resolvable through AutoConfig / AutoModel.
# NOTE(review): the step below is labelled "Push model and code", but this
# registration style may not upload the modeling source alongside the weights;
# `register_for_auto_class` might be what is wanted here -- confirm against the
# transformers custom-model documentation.
print("Registering model...")
AutoConfig.register("sagvit", SAGViTConfig)
AutoModel.register(SAGViTConfig, SAGViTClassifier)
print("Registered model...")

# Load config and model
config = SAGViTConfig()
model = SAGViTClassifier(config)

# Load the state dict into the model
print("Loading model weights...")
# map_location="cpu" lets a checkpoint that was saved from a GPU be loaded on a
# machine without one; load_state_dict copies the values into the model's own
# parameters afterwards.
# SECURITY: torch.load unpickles the file -- only use a checkpoint you trust
# (consider weights_only=True if the installed torch version supports it).
state_dict = torch.load('SAG-ViT.pth', map_location="cpu")
model.load_state_dict(state_dict)
print("Loaded model weights...")

# Push model and code
model.save_pretrained('.')
model.push_to_hub("shravvvv/SAG-ViT")
print("Pushed model to Hugging Face hub...")