---
datasets:
- Matthijs/snacks
---

# Vision Transformer fine-tuned on `Matthijs/snacks` dataset

Vision Transformer (ViT) model pre-trained on ImageNet-21k and fine-tuned on [**Matthijs/snacks**](https://huggingface.co/datasets/Matthijs/snacks) for 5 epochs using various data augmentation transformations from `torchvision`.

The model achieves **94.97%** accuracy on the validation set and **94.43%** on the test set.

## Data augmentation pipeline

The code block below shows the transformations applied during pre-processing to augment the original dataset.
The augmented images were generated on the fly with the `set_transform` method (a usage sketch follows the code block).

```python
from transformers import ViTFeatureExtractor
from torchvision.transforms import (
    Compose,
    Normalize,
    Resize,
    RandomResizedCrop,
    RandomHorizontalFlip,
    RandomAdjustSharpness,
    ToTensor
)

# the feature extractor carries the model's expected input size
# and the normalization statistics used during pre-training
checkpoint = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTFeatureExtractor.from_pretrained(checkpoint)

# transformations on the training set
train_aug_transforms = Compose([
    RandomResizedCrop(size=feature_extractor.size),
    RandomHorizontalFlip(p=0.5),
    RandomAdjustSharpness(sharpness_factor=5, p=0.5),
    ToTensor(),
    Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std),
])

# transformations on the validation/test set
valid_aug_transforms = Compose([
    Resize(size=(feature_extractor.size, feature_extractor.size)),
    ToTensor(),
    Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std),
])
```
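For reference, below is a minimal sketch of how these pipelines can be attached to the dataset with `set_transform`. The helper names (`apply_train_transforms`, `apply_valid_transforms`) are illustrative rather than taken from the original training code, and the sketch assumes the dataset's default `image` column and `train`/`validation`/`test` splits.

```python
from datasets import load_dataset

dataset = load_dataset('Matthijs/snacks')

# illustrative helpers (names are not from the original training code):
# convert each PIL image to RGB and run it through the matching pipeline
def apply_train_transforms(batch):
    batch['pixel_values'] = [train_aug_transforms(img.convert('RGB')) for img in batch['image']]
    return batch

def apply_valid_transforms(batch):
    batch['pixel_values'] = [valid_aug_transforms(img.convert('RGB')) for img in batch['image']]
    return batch

# set_transform applies the function lazily, so a fresh random
# augmentation is drawn each time a training example is accessed
dataset['train'].set_transform(apply_train_transforms)
dataset['validation'].set_transform(apply_valid_transforms)
dataset['test'].set_transform(apply_valid_transforms)
```

Because the transforms run on access rather than being materialized up front, every epoch sees a different random crop, flip, and sharpness adjustment of each training image.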