# aesthetics_vit / test.py
# Last commit: 036a350 "Added usage helpers" (dveranieto)
# from transformers import AutoModel
from huggingface_hub import hf_hub_download
from vision_transformer import vit_large_patch16_224_in21k
import torch
import numpy as np
# Usage example: download checkpoints from the Hugging Face Hub into a local
# ".models" cache, then load the aesthetics ViT checkpoint into the model built
# by vit_large_patch16_224_in21k() and report its parameter count.

FILENAME = "pytorch_model.bin"

# Aesthetics model checkpoint; this is the one loaded into `model` below.
VIT_REPO_ID = "ethz-mtc/aesthetics_vit"
vit_path = hf_hub_download(
    repo_id=VIT_REPO_ID, filename=FILENAME, cache_dir=".models"
)
print(vit_path)

# Shot-scale classifier checkpoint (a ResNet-50, judging by the repo name).
# It is only downloaded here. It must NOT be loaded into the ViT below: the
# original script reused `path` for it, which fed these weights to the ViT.
SHOT_SCALE_REPO_ID = "ethz-mtc/shot_scale_classifier-resnet50"
shot_scale_path = hf_hub_download(
    repo_id=SHOT_SCALE_REPO_ID, filename=FILENAME, cache_dir=".models"
)
print(shot_scale_path)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Build the backbone and replace its classifier with a single-output head so
# the parameter shapes match the aesthetics checkpoint.
model = vit_large_patch16_224_in21k()
model.reset_classifier(num_classes=1)
state_dict = torch.load(vit_path, map_location=device)
model.load_state_dict(state_dict)
# load_state_dict copies values into the existing parameters without moving
# the module, so move it explicitly to the selected device.
model.to(device)

n_params = sum(p.numel() for p in model.parameters())
print(f"Model has {n_params:,} parameters.")