Model Summary

  • Architecture: Vision Transformer (ViT).
  • Backbone: Token embedding via image patches, Multi-Head Self-Attention (MHSA), and MLP blocks.
  • Dataset: CIFAR-10 (10 classes, 60k images).
  • Training Framework: PyTorch.
  • Performance: ~41.5% test accuracy from a short, demonstration-level training run (see the evaluation section below).

Training Process

  1. Dataset & Transforms

    • We used CIFAR-10 (32×32 color images).
    • Images were resized to 224×224 to match the original ViT patching approach.
    • [Optional] Normalization can be applied as needed, e.g. using mean/std of CIFAR-10.
  2. Model Architecture

    • Patches of size P × P.
    • Embedding dimension D.
    • Multi-Head Self-Attention with k heads.
    • MLP dimension of mlp_size.
    • A stack of L Transformer blocks.
  3. Optimizer & Loss

    • Optimizer: Adam (learning rate = 1e-4).
    • Loss: CrossEntropyLoss.
  4. Training Loop

    • Standard PyTorch training loop with mini-batches.
    • Multiple epochs.
    • Tracked the training loss and accuracy each epoch (a minimal sketch of this setup follows this list).
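
For concreteness, here is a minimal sketch of the setup described in steps 1-4. It assumes the ViT class exported by model_definition (the same import used in the loading example below) and CIFAR-10 downloaded via torchvision; the exact settings behind the published checkpoint may differ.

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader

from model_definition import ViT  # your model code, as in the loading example below

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1. Dataset & transforms: resize CIFAR-10 images to 224x224 for ViT patching
transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    # Optional: T.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
])
train_set = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
train_loader = DataLoader(train_set, batch_size=2, shuffle=True)  # B = 2, as in the hyperparameter cell

# 2./3. Model, optimizer, and loss
model = ViT().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

# 4. Training loop: track loss and accuracy per epoch
num_epochs = 3  # demonstration value
for epoch in range(num_epochs):
    model.train()
    running_loss, correct, total = 0.0, 0, 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * images.size(0)
        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)
    print(f"Epoch {epoch + 1}: loss={running_loss / total:.4f}, acc={correct / total:.4f}")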

How to Use the Model

1. Installation

Make sure you have the following libraries installed:

pip install torch torchvision matplotlib gradio huggingface_hub

2. Loading the Model

If you have a local vit_cifar_model.pth (the trained state dict), you can load the model like this:

import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Import or define your ViT class
from model_definition import ViT  # your model code

model_cifar = ViT().to(device)
checkpoint = torch.load("vit_cifar_model.pth", map_location=device)
model_cifar.load_state_dict(checkpoint)
model_cifar.eval()
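
Note that ViT() must be constructed with the same hyperparameters (patch size, embedding dimension, number of heads, depth, and so on) that were used when vit_cifar_model.pth was saved; otherwise load_state_dict will raise shape-mismatch errors.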

3. Inference on a Single Image

from PIL import Image
import torchvision.transforms as T

transform_cifar = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
])

img = Image.open("some_image.jpg").convert("RGB")  # Load an image and ensure 3 channels
x = transform_cifar(img).unsqueeze(0).to(device)  # shape [1, 3, 224, 224]

with torch.no_grad():
    logits = model_cifar(x)
pred = torch.argmax(logits, dim=1).item()
print("Predicted class:", pred)

Training & Evaluation Graphs

Below is a summary of the outputs produced after training. (In the accompanying code, these plots are generated with Matplotlib.)

  1. Training Loss Plot

     Shows the training loss decreasing over epochs.

  2. Training Accuracy Plot

     Tracks the percentage of correct predictions on the training set at each epoch.

  3. Test Accuracy Plot

     Evaluates the model on the test set across epochs (final test accuracy: 41.51%).

  4. Confusion Matrix

     Visual representation of true labels vs. predicted labels on the test set.

(Note: Insert the actual plot images here if you have them hosted somewhere.)


Classification Report:

              precision    recall  f1-score   support

           0     0.5618    0.4090    0.4734      1000
           1     0.5385    0.3500    0.4242      1000
           2     0.2884    0.2030    0.2383      1000
           3     0.3481    0.1570    0.2164      1000
           4     0.3686    0.5050    0.4262      1000
           5     0.3280    0.3910    0.3568      1000
           6     0.5423    0.4680    0.5024      1000
           7     0.4477    0.4110    0.4286      1000
           8     0.4668    0.5770    0.5161      1000
           9     0.3602    0.6800    0.4709      1000

    accuracy                         0.4151     10000
   macro avg     0.4250    0.4151    0.4053     10000
weighted avg     0.4250    0.4151    0.4053     10000
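
For reference, a minimal sketch of how these metrics can be produced is shown below. It assumes a test_loader DataLoader over the CIFAR-10 test split and scikit-learn, which is not included in the install command above.

import torch
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report, confusion_matrix

# Collect predictions over the test set (test_loader is an assumed DataLoader)
model_cifar.eval()
all_preds, all_labels = [], []
with torch.no_grad():
    for images, labels in test_loader:
        logits = model_cifar(images.to(device))
        all_preds.extend(logits.argmax(dim=1).cpu().tolist())
        all_labels.extend(labels.tolist())

# Per-class precision / recall / f1-score, as in the report above
print(classification_report(all_labels, all_preds, digits=4))

# Confusion matrix of true vs. predicted labels
cm = confusion_matrix(all_labels, all_preds)
plt.imshow(cm, cmap="Blues")
plt.colorbar()
plt.xlabel("Predicted label")
plt.ylabel("True label")
plt.title("Confusion Matrix")
plt.show()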

###############################################################################
# CELL: Vision Transformer Hyperparameters
# All important parameters for the ViT model.
###############################################################################

# Batch size
B = 2  # e.g., for demonstration

# Number of channels (RGB = 3)
C = 3

# Image height and width
H = 224
W = 224

# Patch size
P = 16

# Number of patches (derived from H, W, and P)
N = (H // P) * (W // P)

# Embedding dimension
D = 768

# Number of attention heads
k = 12

# Dimension per head (must be compatible with D)
Dh = D // k

# Dropout probability
p = 0.1

# Hidden layer size for the MLP inside the Transformer block
mlp_size = 3072

# Number of Transformer blocks (depth of the encoder)
L = 12

# Number of output classes (e.g., CIFAR-10 has 10 classes)
n_classes = 10

# Print them out in a structured format
print("=== Vision Transformer Parameters ===")
print(f"B (Batch Size): {B}")
print(f"C (Channels): {C}")
print(f"H (Image Height): {H}")
print(f"W (Image Width): {W}")
print(f"P (Patch Size): {P}")
print(f"N (Number of Patches): {N}")
print(f"D (Embedding Dimension): {D}")
print(f"k (Attention Heads): {k}")
print(f"Dh (Dim per Head): {Dh}")
print(f"p (Dropout Probability): {p}")
print(f"mlp_size (MLP Hidden): {mlp_size}")
print(f"L (Num Transformer Blocks): {L}")
print(f"n_classes (Output Classes): {n_classes}")
print("=====================================")

Integration with Gradio & Hugging Face Spaces

Gradio Demo

A simple Gradio demo can be created to classify uploaded images:

import gradio as gr
import torch
import torchvision.transforms as T
from PIL import Image

# model_cifar and device are assumed to be defined and loaded as shown above
model_cifar.eval()

class_names_cifar = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck"
]

def predict_cifar(img):
    img = img.convert("RGB")  # ensure 3 channels for RGBA or grayscale uploads
    x = T.Compose([T.Resize((224, 224)), T.ToTensor()])(img).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model_cifar(x)
    pred_id = torch.argmax(logits, dim=1).item()
    return f"Prediction: {class_names_cifar[pred_id]}"

gr.Interface(
    fn=predict_cifar,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="ViT on CIFAR-10"
).launch()
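
By default launch() serves the demo locally; passing share=True to launch() creates a temporary public link, which is convenient for sharing the demo before deploying it to a Space.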

Hugging Face Hub

You can push the model and code to the Hugging Face Hub:

from huggingface_hub import HfApi

api = HfApi()
repo_id = "username/my-cifar-vit"

api.create_repo(repo_id=repo_id, exist_ok=True)
api.upload_file(
    path_or_fileobj="vit_cifar_model.pth",
    path_in_repo="vit_cifar_model.pth",
    repo_id=repo_id,
    repo_type="model"
)

Then create a Space with Gradio integration if you want a hosted web app.
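
Inside such a Space (or any other environment), the uploaded checkpoint can be fetched at startup with hf_hub_download; the repo_id below is the same placeholder used above:

import torch
from huggingface_hub import hf_hub_download

from model_definition import ViT  # your model code, as above

ckpt_path = hf_hub_download(repo_id="username/my-cifar-vit", filename="vit_cifar_model.pth")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_cifar = ViT().to(device)
model_cifar.load_state_dict(torch.load(ckpt_path, map_location=device))
model_cifar.eval()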


License

MIT License or any license of your choice.


Author
