import torch
import torchvision


def create_vit_b_16_model(num_classes: int = 3):
    """Build a ViT-B/16 feature extractor with a new classification head.

    Loads torchvision's default pretrained ViT-B/16 weights, freezes every
    pretrained parameter, and replaces ``model.heads`` with a single
    trainable linear layer that produces ``num_classes`` outputs.

    Args:
        num_classes: Number of output units of the new head.

    Returns:
        A ``(model, transform)`` tuple: the modified ViT-B/16 model and the
        preprocessing transform bundled with the same pretrained weights.
    """
    weights = torchvision.models.ViT_B_16_Weights.DEFAULT
    # Take the transform from the same weights object so preprocessing
    # matches the weights actually loaded below.
    transform = weights.transforms()
    model = torchvision.models.vit_b_16(weights=weights)

    # Freeze the pretrained backbone; the head created below is built after
    # this loop, so its parameters remain trainable.
    for param in model.parameters():
        param.requires_grad = False

    # Size the head from the model's own embedding width rather than a
    # hard-coded 768, so it stays correct if the backbone/weights change.
    # The Sequential wrapper (child key "0") is kept so state-dict keys such
    # as "heads.0.weight" remain compatible with previously saved checkpoints.
    model.heads = torch.nn.Sequential(
        torch.nn.Linear(in_features=model.hidden_dim, out_features=num_classes)
    )
    return model, transform