---
datasets:
- Matthijs/snacks
model-index:
- name: matteopilotto/vit-base-patch16-224-in21k-snacks
  results:
  - task:
      type: image-classification
      name: Image Classification
    dataset:
      name: Matthijs/snacks
      type: Matthijs/snacks
      config: default
      split: test
    metrics:
    - type: accuracy
      value: 0.8928571428571429
      name: Accuracy
      verified: true
      verifyToken: eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJoYXNoIjoiN2IzY2JhYjg2MWI5NjQ3NTAyN2ZiNmI0ZGQ2YTBlZmI3MTk2MjEzMTk2ZWRiZjc3MTQ3Y2NmNzE3YTE0OWVkMiIsInZlcnNpb24iOjF9.TUt1-MR0dGTqzzQwxIzRBJ6J5jIPrGRwJo2wfdBnL3iEkmn-nKjIJm9omEIYfBMGgJa1CXnGdULRHk16DeiHBg
    - type: precision
      value: 0.8990033704680036
      name: Precision Macro
      verified: true
      verifyToken: eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJoYXNoIjoiZjZjOGM3Y2IyODczNzhhNGEzOTUxNjAzNzhlOWFiNGFhMjU1ZTBkOGVlMTQzZWI0ZDM1NGZjYWIyYThlYzNiMCIsInZlcnNpb24iOjF9.yo8EHikUrpF-MAP1eJpKCWc7nOersQjSq07JqX_zbbqM1YSAFhGacEwjavfMY4sa1VcY6NU1dqeP3KbTlNtBDg
    - type: precision
      value: 0.8928571428571429
      name: Precision Micro
      verified: true
      verifyToken: eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJoYXNoIjoiYWM5MjhkZjc1YzIzZDFmMDFkYzVjMDdhYTBlMGU2YmExMDQyNzQzZWFlMWNmNDIzNjUwMTdiYjNjYWJmNmE3OSIsInZlcnNpb24iOjF9.DpPgzQXykudTcwa_shu0h9FeZfuhPBqbKCpAAx-QYHyx2B9MEcKpdrsN8HcczqYZ5x3XIJ7ZeKPzXpfAz3ySAA
    - type: precision
      value: 0.8972398709051788
      name: Precision Weighted
      verified: true
      verifyToken: eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJoYXNoIjoiMDRkMDAwN2ViNGE4OWZhZWUyYWUxNWM2NDM1ODE3MGNkN2NjNzY3NjU5YzU1YzAyMDE2MDQ2YjEyY2IxNjJhNyIsInZlcnNpb24iOjF9.4ezCcJOFrjn4J3-GW3FDapCVzOk9rvl2u-Hhtuae2JdUQwksT9eeMRm2532el4q6wRbFIzZ2hPcPdwYEyLZbCA
    - type: recall
      value: 0.8914608843537415
      name: Recall Macro
      verified: true
      verifyToken: eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJoYXNoIjoiMWI5M2RkMzQ2ODdjZmU3YTRmZjZmNzFkNzIwMGE4ZGZiZWE2ZjA2N2Y5NWI2NjJmZmI4MjUzODY0NjZkZDM0OSIsInZlcnNpb24iOjF9.2dlij3z_6tRc8_UW-bMcflibboU25wQqP-zIaMAJuI-xmQmYhYkWM1RRxcxITj5TGSrROGfOUKYIA7Xqt_1nBw
    - type: recall
      value: 0.8928571428571429
      name: Recall Micro
      verified: true
      verifyToken: eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJoYXNoIjoiOGE0Yzc1MDljZjhlMmQ3ODMzMmZjYmQxYmUzOWNjM2I0MzIwNTNkM2M0ZmEzYzE3YTgxMGJkMDEyNjY0ZGM2MyIsInZlcnNpb24iOjF9.laI8vntC4coo_DhE46nNe-DHpeNlC9VKxqO-vp7Qmn6UknL1BfiHMAdAfbHE8AYap9AZ82MIWN5pxghrRNcxDg
    - type: recall
      value: 0.8928571428571429
      name: Recall Weighted
      verified: true
      verifyToken: eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJoYXNoIjoiZWY0MmUyMDUxYjMzYTQyNGM1ZmYyOWQ0OTI2ODk1ZTFiZmRkYTU3OTFlN2M2ZjcyMzQwYjA0MTE0OTUzMWI0NSIsInZlcnNpb24iOjF9.LbXSrvdULHxq5EzSLLCgla7-ZOBX5qtqr5MSdeRRsP2Bv0VZ91AhmN-ko8YM54_8Grs6hzOrDhDYA8KQfqBfCg
    - type: f1
      value: 0.892544821273258
      name: F1 Macro
      verified: true
      verifyToken: eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJoYXNoIjoiM2U5NThhNjQxNWE4MzFkZmRiYWNkMzY5NTMzOGJiMmNkMDM1NWFiNzZmNTI2MDc0NWVkZjMyNDM1MmUxNjJlNSIsInZlcnNpb24iOjF9.koeCiZOOAkASs9W8013N3DYysLnkxfnpcjHHEvwD5hnFXczRzgnmKVJ5WN8pKA13jeemmdxjPnNS8itbzVMDCw
    - type: f1
      value: 0.8928571428571429
      name: F1 Micro
      verified: true
      verifyToken: eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJoYXNoIjoiNTc1OTVlNDQyZTM1OTg2NjBmZTdkOWM0YzQxNDIyMTQzZmMwNTY4MWIwYzJkYmRhODU2YjQ3Yzg3N2I4ODdjYiIsInZlcnNpb24iOjF9.apLx2WthXu6hDi3NW-jOlwEQlqWw9TJYUYnu8fD0uZwmd4SOep3F9DWydroDNCkKZPawooRz2Fsr0lo9DW3mAg
    - type: f1
      value: 0.8924168605019522
      name: F1 Weighted
      verified: true
      verifyToken: eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJoYXNoIjoiMWNiMzBhYzM3N2Q4NjYwZjkxNjQyYTRiZmU1MTAyMDhmMjRiOTc1ZjYwMzE2NjY5NzZlMTllNWFmNmQ1YjM2ZSIsInZlcnNpb24iOjF9.VX8tiKQHyI9NUm8FLJSSi-02pGgtGoO5tFFbw-Cwhx_IPo68NKZqENYdaZ9RA7T5ezHOHWxvzPXQM__fYcu0CQ
    - type: loss
      value: 0.479541540145874
      name: loss
      verified: true
      verifyToken: eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJoYXNoIjoiMDQxYzdlYTU1NGIyY2E5ODNkMGU0MzgxMDljMTE3NzE2MDcxMTExZmI0ZjQ0MTZjZDMyMzQxYmIyODg0Y2U2ZSIsInZlcnNpb24iOjF9.r88Ba2B3tj3DHxDgBDRM0_33X8nPxN3zqEhGb1ZesA9qerEGoRkP7bqBhj5YDFS1jIIrVOiIt98UX6s7vnEYAg
---

# Vision Transformer fine-tuned on `Matthijs/snacks` dataset

Vision Transformer (ViT) model pre-trained on ImageNet-21k and fine-tuned on [**Matthijs/snacks**](https://huggingface.co/datasets/Matthijs/snacks) for 5 epochs using various data augmentation transformations from `torchvision`.

The model achieves **94.97%** accuracy on the validation set and **94.43%** on the test set.

## Data augmentation pipeline

The code block below shows the transformations applied during pre-processing to augment the original dataset.
The augmented images were generated on-the-fly with the `set_transform` method.

```python
from transformers import ViTFeatureExtractor
from torchvision.transforms import (
    Compose,
    Normalize,
    Resize,
    RandomResizedCrop,
    RandomHorizontalFlip,
    RandomAdjustSharpness,
    ToTensor
)

checkpoint = 'google/vit-base-patch16-224-in21k'
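feature_extractor = ViTFeatureExtractor.from_pretrained(checkpoint)
# note: the code below assumes `feature_extractor.size` is an int (224);
# newer `transformers` releases may return a dict such as {'height': 224, 'width': 224}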

# transformations on the training set
train_aug_transforms = Compose([
    RandomResizedCrop(size=feature_extractor.size),
    RandomHorizontalFlip(p=0.5),
    RandomAdjustSharpness(sharpness_factor=5, p=0.5),
    ToTensor(),
    Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std),
])

# transformations on the validation/test set
valid_aug_transforms = Compose([
    Resize(size=(feature_extractor.size, feature_extractor.size)),
    ToTensor(),
    Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std),
])
```
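
As a minimal sketch of how these pipelines might be attached with `set_transform`, the helper below wraps a per-image transform into a function that operates on a batch. The helper name `make_batch_transform`, the column names `image` and `pixel_values`, and the split names in the usage comment are assumptions for illustration, not taken from the original training code.

```python
def make_batch_transform(transforms, image_key='image', output_key='pixel_values'):
    """Wrap a per-image transform so it can be applied to a batch (hypothetical helper)."""
    def apply(batch):
        # `set_transform` calls this function with a dict of lists (one list per column)
        # each time rows are accessed, so augmentations are re-sampled on every read
        batch[output_key] = [transforms(image.convert('RGB')) for image in batch[image_key]]
        return batch
    return apply

# usage sketch, assuming `dataset = datasets.load_dataset('Matthijs/snacks')`
# and the two pipelines defined above:
# dataset['train'].set_transform(make_batch_transform(train_aug_transforms))
# dataset['validation'].set_transform(make_batch_transform(valid_aug_transforms))
# dataset['test'].set_transform(make_batch_transform(valid_aug_transforms))
```

Because the transform runs lazily at access time, the stored dataset is left unchanged and each epoch can see a different random crop, flip, and sharpness adjustment of the same image.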