SAG-ViT: A Scale-Aware, High-Fidelity Patching Approach with Graph Attention for Vision Transformers

Model Description

This repository provides an implementation of the SAG-ViT model proposed in the paper "SAG-ViT: A Scale-Aware, High-Fidelity Patching Approach with Graph Attention for Vision Transformers".

SAG-ViT is a novel transformer framework that enhances Vision Transformers (ViT) with scale awareness and refined patch-level feature embeddings. It extracts multiscale features using EfficientNetV2, organizes the resulting patches into a graph based on their spatial relationships, and refines them with a Graph Attention Network (GAT). A Transformer encoder then integrates these embeddings globally, capturing long-range dependencies for comprehensive image understanding.
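The patch-graph step can be illustrated with a small adjacency-list builder. This is a minimal sketch, not the repository's code: it assumes the patches lie on a regular rows x cols grid and that each patch is connected to its (up to eight) spatially adjacent patches. The paper's actual edge rule may differ, and `build_patch_graph` is a hypothetical name.

```python
from typing import Dict, List


def build_patch_graph(rows: int, cols: int) -> Dict[int, List[int]]:
    """Return an adjacency list for a rows x cols grid of patches.

    Assumption (illustrative only): patch (r, c) gets node id r * cols + c
    and is linked to every patch whose row and column each differ by at
    most 1, i.e. 8-neighbour connectivity without self-loops.
    """
    graph: Dict[int, List[int]] = {}
    for r in range(rows):
        for c in range(cols):
            node = r * cols + c
            neighbours = []
            for dr in (-1, 0, 1):
                for dc in (-1, 0, 1):
                    if dr == 0 and dc == 0:
                        continue  # skip the patch itself
                    nr, nc = r + dr, c + dc
                    if 0 <= nr < rows and 0 <= nc < cols:  # stay inside the grid
                        neighbours.append(nr * cols + nc)
            graph[node] = neighbours
    return graph


if __name__ == "__main__":
    g = build_patch_graph(3, 3)
    print(g[4])  # centre patch: connected to all 8 other patches
    print(g[0])  # corner patch: only 3 neighbours
```

Because adjacency is symmetric, every edge appears in both nodes' lists, which is what an undirected graph-attention layer expects.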

Model Architecture

[Figure: SAG-ViT architecture]

Image source: SAG-ViT: A Scale-Aware, High-Fidelity Patching Approach with Graph Attention for Vision Transformers
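The GAT refinement stage in the architecture can be illustrated with the standard single-head graph-attention update from the original GAT formulation: each node i scores every neighbour j with LeakyReLU(a^T [W h_i || W h_j]), normalizes the scores with a softmax over its neighbourhood, and takes the weighted sum of the projected neighbour features. The sketch below is a textbook, dependency-free version, not SAG-ViT's actual module: `gat_layer` is a hypothetical name, and the output nonlinearity, multi-head attention, and batching are omitted.

```python
import math
from typing import Dict, List

Vector = List[float]
Matrix = List[List[float]]


def matvec(w: Matrix, x: Vector) -> Vector:
    """Multiply an (out_dim x in_dim) matrix by an in_dim vector."""
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


def dot(u: Vector, v: Vector) -> float:
    return sum(p * q for p, q in zip(u, v))


def leaky_relu(x: float, slope: float = 0.2) -> float:
    return x if x > 0 else slope * x


def gat_layer(h: List[Vector], graph: Dict[int, List[int]],
              w: Matrix, a: Vector) -> List[Vector]:
    """One single-head graph-attention update (nonlinearity omitted).

    h: input feature vector of each node, indexed by node id
    graph: adjacency list mapping node id to neighbouring node ids
    w: shared linear projection applied to every node
    a: attention vector of length 2 * out_dim
    Each node attends over its neighbours and itself.
    """
    wh = [matvec(w, hi) for hi in h]            # W h_j for every node
    out_dim = len(wh[0])
    a_left, a_right = a[:out_dim], a[out_dim:]  # a^T [W h_i || W h_j], split in two
    updated = []
    for i in range(len(h)):
        nodes = [i] + list(graph.get(i, []))
        # unnormalized scores e_ij = LeakyReLU(a^T [W h_i || W h_j])
        scores = [leaky_relu(dot(a_left, wh[i]) + dot(a_right, wh[j])) for j in nodes]
        # softmax over the neighbourhood gives the attention weights alpha_ij
        top = max(scores)
        exps = [math.exp(s - top) for s in scores]
        total = sum(exps)
        alphas = [e / total for e in exps]
        # attention-weighted sum of the projected neighbour features
        updated.append([sum(alpha * wh[j][k] for alpha, j in zip(alphas, nodes))
                        for k in range(out_dim)])
    return updated


if __name__ == "__main__":
    # Path graph 0 - 1 - 2 with an identity projection and a zero attention
    # vector: all scores are equal, so each node averages itself and its neighbours.
    features = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
    path = {0: [1], 1: [0, 2], 2: [1]}
    identity = [[1.0, 0.0], [0.0, 1.0]]
    print(gat_layer(features, path, identity, [0.0, 0.0, 0.0, 0.0]))
```

With a non-zero attention vector the weights are no longer uniform, so patches whose projected features score higher contribute more to their neighbours' refined embeddings.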

Usage

SAG-ViT expects input images normalized in the standard ImageNet way, i.e. mini-batches of 3-channel RGB images of shape (N, 3, H, W), where N is the number of images and H and W are each at least 49 pixels. The images must first be loaded into the range [0, 1] and then normalized using mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225].
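The preprocessing rules above can be checked by hand with a small dependency-free sketch. `check_batch_shape` and `normalize_pixel` are hypothetical helpers, not part of the repository; they only restate the input-shape requirement and the per-channel (x - mean) / std arithmetic for a single 8-bit RGB pixel.

```python
from typing import Sequence, Tuple

# ImageNet channel statistics quoted in the Usage section.
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)
MIN_SIDE = 49  # minimum height and width stated above


def check_batch_shape(shape: Sequence[int]) -> None:
    """Raise ValueError unless shape is (N, 3, H, W) with H, W >= MIN_SIDE."""
    if len(shape) != 4:
        raise ValueError(f"expected 4 dims (N, 3, H, W), got {len(shape)}")
    _, c, h, w = shape
    if c != 3:
        raise ValueError(f"expected 3 RGB channels, got {c}")
    if h < MIN_SIDE or w < MIN_SIDE:
        raise ValueError(f"H and W must be >= {MIN_SIDE}, got {h}x{w}")


def normalize_pixel(rgb: Sequence[int]) -> Tuple[float, ...]:
    """Map an 8-bit RGB pixel to [0, 1], then standardize each channel."""
    scaled = [v / 255.0 for v in rgb]  # ToTensor() applies the same /255 to uint8 images
    return tuple((x - m) / s for x, m, s in zip(scaled, MEAN, STD))


if __name__ == "__main__":
    check_batch_shape((8, 3, 224, 224))      # passes silently
    print(normalize_pixel((255, 255, 255)))  # brightest pixel, roughly (2.25, 2.43, 2.64)
```

In practice `transforms.ToTensor()` followed by `transforms.Normalize(mean, std)` performs exactly this scaling and standardization on whole images.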

To train or run inference on our model, refer to the following steps:

Clone our repository, which contains the model pretrained on the CIFAR-10 dataset.

git clone https://huggingface.co/shravvvv/SAG-ViT
cd SAG-ViT

Install the required dependencies.

pip install -r requirements.txt

Use from_pretrained to load the model from Hugging Face Hub and run inference on a sample input image.

from transformers import AutoModel, AutoConfig
from PIL import Image
from torchvision import transforms
import torch

# Step 1: Load the model and configuration directly from Hugging Face Hub
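repo_name = "shravvvv/SAG-ViT"
# trust_remote_code=True lets transformers run custom model code shipped in the repo (needed if SAG-ViT is not a built-in architecture)
config = AutoConfig.from_pretrained(repo_name, trust_remote_code=True)  # Load config from hub
model = AutoModel.from_pretrained(repo_name, config=config, trust_remote_code=True)  # Load model from hub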

# Step 2: Define the transformation for the input image
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to match the expected input size
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet statistics expected by SAG-ViT
])

# Step 3: Load and preprocess the input image
input_image_path = "path/to/your/image.jpg"
img = Image.open(input_image_path).convert("RGB")
img = transform(img).unsqueeze(0)  # Add batch dimension

# Step 4: Ensure the model is in evaluation mode
model.eval()

# Step 5: Run inference
with torch.no_grad():
    outputs = model(img)
    logits = outputs.logits  # Accessing logits from ModelOutput

# Step 6: Post-process the predictions
predicted_class_index = torch.argmax(logits, dim=1)  # Get the predicted class index

# CIFAR-10 label mapping
class_names = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck'
]

# Get the predicted class name from the class index
predicted_class_name = class_names[predicted_class_index.item()]
print(f"Predicted class: {predicted_class_name}")
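The script above reports only the most likely class. If you also want confidence scores, the logits can be converted to probabilities with a softmax. The helper below is a hypothetical, dependency-free sketch (`top_k_predictions` is not part of the repository); it takes the logits of a single image as a plain Python list, for example `logits[0].tolist()` from the snippet above.

```python
import math
from typing import List, Tuple

CLASS_NAMES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']


def top_k_predictions(logits: List[float], k: int = 3) -> List[Tuple[str, float]]:
    """Return the k most likely CIFAR-10 classes with their softmax probabilities."""
    m = max(logits)  # subtract the max before exponentiating for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(CLASS_NAMES[i], probs[i]) for i in ranked[:k]]


if __name__ == "__main__":
    # Made-up logits for illustration only, not real model output.
    example = [0.1, 0.2, 0.0, 3.0, 0.5, 2.0, 0.0, 0.0, 0.0, 0.0]
    for name, p in top_k_predictions(example):
        print(f"{name}: {p:.3f}")
```

With the tensors from the inference script, `torch.softmax(logits, dim=1)` followed by `torch.topk(probs, k=3)` gives the same ranking without leaving PyTorch.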

Running Tests

The cloned repository includes a 'tests' folder with unit tests for each of the model's modules. With the dependencies from requirements.txt installed in your Python environment, run:

python -m unittest discover -s tests

or, if you are using pytest, you can run:

pytest tests
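Both commands collect test files by name: `unittest discover` uses the pattern `test*.py` by default, and pytest looks for `test_*.py` or `*_test.py`. As a hedged illustration of the expected layout, the module below is a hypothetical `tests/test_labels.py` (not a file from the repository) that sanity-checks the CIFAR-10 label mapping used in the inference example; pytest also collects `unittest.TestCase` subclasses, so either command would pick it up.

```python
import unittest

# Hypothetical tests/test_labels.py: the file name matches both default
# discovery patterns, so `python -m unittest discover -s tests` and
# `pytest tests` would collect it.
CIFAR10_CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck']


class TestCifar10Labels(unittest.TestCase):
    """Sanity checks for the label mapping used in the inference example."""

    def test_ten_classes(self):
        self.assertEqual(len(CIFAR10_CLASSES), 10)

    def test_labels_are_unique(self):
        self.assertEqual(len(set(CIFAR10_CLASSES)), len(CIFAR10_CLASSES))

    def test_index_lookup(self):
        # an argmax index of 3 should map to "cat", as in the inference script
        self.assertEqual(CIFAR10_CLASSES[3], "cat")
```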

Results
We evaluated SAG-ViT on diverse datasets:

  • CIFAR-10 (natural images)
  • GTSRB (traffic sign recognition)
  • NCT-CRC-HE-100K (histopathological images)
  • NWPU-RESISC45 (remote sensing imagery)
  • PlantVillage (agricultural imagery)

SAG-ViT achieves the highest F1 score among all compared backbones on every benchmark, as shown in the table below (F1 scores):

| Backbone | CIFAR-10 | GTSRB | NCT-CRC-HE-100K | NWPU-RESISC45 | PlantVillage |
| --- | --- | --- | --- | --- | --- |
| DenseNet201 | 0.5427 | 0.9862 | 0.9214 | 0.4493 | 0.8725 |
| Vgg16 | 0.5345 | 0.8180 | 0.8234 | 0.4114 | 0.7064 |
| Vgg19 | 0.5307 | 0.7551 | 0.8178 | 0.3844 | 0.6811 |
| DenseNet121 | 0.5290 | 0.9813 | 0.9247 | 0.4381 | 0.8321 |
| AlexNet | 0.6126 | 0.9059 | 0.8743 | 0.4397 | 0.7684 |
| Inception | 0.7734 | 0.8934 | 0.8707 | 0.8707 | 0.8216 |
| ResNet | 0.9172 | 0.9134 | 0.9478 | 0.9103 | 0.8905 |
| MobileNet | 0.9169 | 0.3006 | 0.4965 | 0.1667 | 0.2213 |
| ViT-S | 0.8465 | 0.8542 | 0.8234 | 0.6116 | 0.8654 |
| ViT-L | 0.8637 | 0.8613 | 0.8345 | 0.8358 | 0.8842 |
| MNASNet1_0 | 0.1032 | 0.0024 | 0.0212 | 0.0011 | 0.0049 |
| ShuffleNet_V2_x1_0 | 0.3523 | 0.4244 | 0.4598 | 0.1808 | 0.3190 |
| SqueezeNet1_0 | 0.4328 | 0.8392 | 0.7843 | 0.3913 | 0.6638 |
| GoogLeNet | 0.4954 | 0.9455 | 0.8631 | 0.3720 | 0.7726 |
| Proposed (SAG-ViT) | 0.9574 | 0.9958 | 0.9861 | 0.9549 | 0.9772 |
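The per-dataset lead can be read directly off the table. The short script below recomputes, for each dataset, the strongest baseline and SAG-ViT's F1 margin over it; the scores are copied from the table and nothing else is assumed.

```python
from typing import Dict, List, Tuple

# F1 scores copied from the results table (same column order).
DATASETS = ["CIFAR-10", "GTSRB", "NCT-CRC-HE-100K", "NWPU-RESISC45", "PlantVillage"]
BASELINES: Dict[str, List[float]] = {
    "DenseNet201": [0.5427, 0.9862, 0.9214, 0.4493, 0.8725],
    "Vgg16": [0.5345, 0.8180, 0.8234, 0.4114, 0.7064],
    "Vgg19": [0.5307, 0.7551, 0.8178, 0.3844, 0.6811],
    "DenseNet121": [0.5290, 0.9813, 0.9247, 0.4381, 0.8321],
    "AlexNet": [0.6126, 0.9059, 0.8743, 0.4397, 0.7684],
    "Inception": [0.7734, 0.8934, 0.8707, 0.8707, 0.8216],
    "ResNet": [0.9172, 0.9134, 0.9478, 0.9103, 0.8905],
    "MobileNet": [0.9169, 0.3006, 0.4965, 0.1667, 0.2213],
    "ViT-S": [0.8465, 0.8542, 0.8234, 0.6116, 0.8654],
    "ViT-L": [0.8637, 0.8613, 0.8345, 0.8358, 0.8842],
    "MNASNet1_0": [0.1032, 0.0024, 0.0212, 0.0011, 0.0049],
    "ShuffleNet_V2_x1_0": [0.3523, 0.4244, 0.4598, 0.1808, 0.3190],
    "SqueezeNet1_0": [0.4328, 0.8392, 0.7843, 0.3913, 0.6638],
    "GoogLeNet": [0.4954, 0.9455, 0.8631, 0.3720, 0.7726],
}
SAG_VIT = [0.9574, 0.9958, 0.9861, 0.9549, 0.9772]


def margins_over_best_baseline() -> Dict[str, Tuple[str, float]]:
    """Map each dataset to (strongest baseline, SAG-ViT F1 minus its F1)."""
    result = {}
    for col, dataset in enumerate(DATASETS):
        best = max(BASELINES, key=lambda model: BASELINES[model][col])
        result[dataset] = (best, round(SAG_VIT[col] - BASELINES[best][col], 4))
    return result


if __name__ == "__main__":
    for dataset, (best, margin) in margins_over_best_baseline().items():
        print(f"{dataset}: best baseline {best}, SAG-ViT ahead by {margin:+.4f}")
```

For this table, ResNet is the strongest baseline on four of the five datasets (DenseNet201 on GTSRB), and SAG-ViT's margin ranges from 0.0096 F1 on GTSRB to 0.0867 on PlantVillage.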

Citation

If you find our paper and code helpful for your research, please consider citing our work and giving the repository a star:

@misc{SAGViT,
      title={SAG-ViT: A Scale-Aware, High-Fidelity Patching Approach with Graph Attention for Vision Transformers}, 
      author={Shravan Venkatraman and Jaskaran Singh Walia and Joe Dhanith P R},
      year={2024},
      eprint={2411.09420},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2411.09420}, 
}
Model size: 6.74M parameters (Safetensors, F32 tensors)