# Spaces: Runtime error
import torch
import torchvision
def create_vit_b_16_model(num_classes=3):
    """Build a ViT-B/16 model with a frozen backbone and a fresh classifier head.

    Loads torchvision's ``vit_b_16`` with its ``DEFAULT`` pretrained weights,
    freezes every pretrained parameter, and replaces ``model.heads`` with a
    single newly initialised linear layer producing ``num_classes`` outputs.

    Args:
        num_classes: Number of output classes for the new head. Defaults to 3.

    Returns:
        A ``(model, transform)`` tuple: the modified vision transformer and the
        preprocessing transform provided by the same pretrained weights.
    """
    weights = torchvision.models.ViT_B_16_Weights.DEFAULT
    # Take the transform from the same weights object used to build the model
    # so the preprocessing and the pretrained weights can never drift apart.
    transform = weights.transforms()
    model = torchvision.models.vit_b_16(weights=weights)

    # Freeze the pretrained parameters; only the head created below will have
    # requires_grad=True (new nn.Linear parameters default to trainable).
    for param in model.parameters():
        param.requires_grad = False

    # Replace the classifier. The input width is read from the model instead
    # of hard-coding 768. The single-element Sequential is kept on purpose:
    # its parameters are named ``heads.0.weight`` / ``heads.0.bias``, matching
    # the previous layout, so state dicts saved from it still load.
    model.heads = torch.nn.Sequential(
        torch.nn.Linear(in_features=model.hidden_dim, out_features=num_classes),
    )
    return model, transform