# File size: 777 Bytes
# 036a350
# from transformers import AutoModel
from huggingface_hub import hf_hub_download
from vision_transformer import vit_large_patch16_224_in21k
import torch
import numpy as np

# Hugging Face Hub locations of the checkpoints this script fetches. Both repos
# publish their weights under the same file name.
AESTHETICS_REPO_ID = "ethz-mtc/aesthetics_vit"
SHOT_SCALE_REPO_ID = "ethz-mtc/shot_scale_classifier-resnet50"
FILENAME = "pytorch_model.bin"

# Download (or reuse from the local ".models" cache) the ViT aesthetics weights.
# This path must stay in its own variable: it is the checkpoint loaded into the
# ViT below.
path = hf_hub_download(
    repo_id=AESTHETICS_REPO_ID, filename=FILENAME, cache_dir=".models"
)
print(path)

# Also fetch the shot-scale checkpoint into the cache. It is NOT loaded here:
# its repo name indicates a ResNet-50, whose state dict cannot be loaded into
# the ViT, so it must not overwrite `path` above.
shot_scale_path = hf_hub_download(
    repo_id=SHOT_SCALE_REPO_ID, filename=FILENAME, cache_dir=".models"
)
print(shot_scale_path)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# ViT-L/16 backbone with a single-output head to match the aesthetics
# checkpoint (presumably a scalar score -- confirm against the model card).
model = vit_large_patch16_224_in21k()
model.reset_classifier(num_classes=1)

# NOTE(review): torch.load unpickles the file, which can execute arbitrary code.
# The checkpoint comes from a remote repo; pin a revision and/or pass
# weights_only=True (torch >= 1.13) if that repo is not fully trusted.
state_dict = torch.load(path, map_location="cpu")
model.load_state_dict(state_dict)
# load_state_dict copies values into the model's existing (CPU) parameters, so
# the model itself must be moved to the selected device explicitly.
model.to(device)
model.eval()  # inference mode (disables dropout etc.)

# numel() is an exact int for every parameter, including 0-d ones (np.prod of
# an empty shape would return a float and turn the total into a float).
print(f"Model has {sum(p.numel() for p in model.parameters()):,} parameters.")