import torch
import torchvision
from torch import nn
from torchvision import transforms
from transformers import ViTForImageClassification
from transformers import ViTImageProcessor
from typing import List

# Prefer GPU when available; all returned models are moved to this device.
device = "cuda" if torch.cuda.is_available() else "cpu"


def create_vit(output_shape: int, classes: List, device: torch.device = device):
    """Creates a frozen HuggingFace ViT (google/vit-base-patch16-224) with a new classifier head.

    The pretrained backbone is frozen (feature extraction); only the freshly
    initialized linear classifier is trainable.

    Args:
        output_shape: Number of output units of the classifier head.
            Must equal ``len(classes)`` so the head matches the label mapping.
        classes: A list of class names; index order defines the label ids.
        device: A torch.device (or device string) to move the model to.

    Returns:
        A tuple of (model, train_transforms, val_transforms, test_transforms).

    Raises:
        ValueError: If ``output_shape`` does not match ``len(classes)``,
            which would leave ``model.config`` inconsistent with the head.
    """
    # Guard: the config below is built with num_labels=len(classes); a head of
    # a different size would silently desynchronize config and logits.
    if output_shape != len(classes):
        raise ValueError(
            f"output_shape ({output_shape}) must equal len(classes) ({len(classes)})"
        )

    id2label = {id: label for id, label in enumerate(classes)}
    label2id = {label: id for id, label in id2label.items()}

    # ignore_mismatched_sizes: the checkpoint's 1000-class head is discarded
    # and replaced by a randomly initialized one sized for num_labels.
    model = ViTForImageClassification.from_pretrained(
        'google/vit-base-patch16-224',
        num_labels=len(classes),
        id2label=id2label,
        label2id=label2id,
        ignore_mismatched_sizes=True,
    )

    # Freeze the backbone; only the classifier (re-created below, which has
    # requires_grad=True by default) will be trained.
    for param in model.parameters():
        param.requires_grad = False

    # Can add dropout here if needed.
    # Read the hidden width from the config rather than hard-coding 768, so
    # this keeps working if the checkpoint name/width is ever changed.
    model.classifier = nn.Linear(
        in_features=model.config.hidden_size, out_features=output_shape
    )

    # Transform recipe adapted from:
    # https://github.com/NielsRogge/Transformers-Tutorials/blob/master/VisionTransformer/Fine_tuning_the_Vision_Transformer_on_CIFAR_10_with_PyTorch_Lightning.ipynb
    processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
    image_mean = processor.image_mean
    image_std = processor.image_std
    size = processor.size["height"]

    normalize = transforms.Normalize(mean=image_mean, std=image_std)

    train_transforms = transforms.Compose([
        # transforms.RandomResizedCrop(size),
        transforms.Resize(size),
        transforms.CenterCrop(size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    val_transforms = transforms.Compose([
        transforms.Resize(size),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        normalize,
    ])
    # Evaluation and test share the same deterministic pipeline.
    test_transforms = val_transforms

    return model.to(device), train_transforms, val_transforms, test_transforms