SAG-ViT: A Scale-Aware, High-Fidelity Patching Approach with Graph Attention for Vision Transformers
Model Description
Implementation of the SAG-ViT model as proposed in the paper SAG-ViT: A Scale-Aware, High-Fidelity Patching Approach with Graph Attention for Vision Transformers.
SAG-ViT is a novel transformer framework designed to enhance Vision Transformers (ViTs) with scale-awareness and refined patch-level feature embeddings. It extracts multiscale features with EfficientNetV2, organizes the resulting patches into a graph based on their spatial relationships, and refines them with a Graph Attention Network (GAT). A Transformer encoder then integrates these embeddings globally, capturing long-range dependencies for comprehensive image understanding.
Model Architecture
Image source: SAG-ViT: A Scale-Aware, High-Fidelity Patching Approach with Graph Attention for Vision Transformers
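As a rough, non-authoritative sketch of this pipeline, the forward pass can be pictured along the following lines. The module names, dimensions, the EfficientNetV2-S variant, and the use of plain multi-head attention as a stand-in for the spatial-graph GAT are illustrative assumptions, not the repository's implementation:

```python
import torch
import torch.nn as nn
from torchvision.models import efficientnet_v2_s


class SAGViTSketch(nn.Module):
    """Illustrative sketch of the SAG-ViT data flow (not the official implementation)."""

    def __init__(self, num_classes=10, embed_dim=256):
        super().__init__()
        # 1) Multiscale CNN backbone (EfficientNetV2; the "-S" variant here is an assumption)
        self.backbone = efficientnet_v2_s(weights=None).features
        self.proj = nn.Conv2d(1280, embed_dim, kernel_size=1)
        # 2)-3) Feature-map patches become graph nodes refined by a GAT; a plain
        #       multi-head attention layer stands in for the spatial-graph GAT here
        self.gat_standin = nn.MultiheadAttention(embed_dim, num_heads=4, batch_first=True)
        # 4) Transformer encoder integrates the refined patch embeddings globally
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=8, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):                                     # x: (N, 3, H, W)
        fmap = self.proj(self.backbone(x))                    # (N, D, h, w) multiscale feature map
        tokens = fmap.flatten(2).transpose(1, 2)              # (N, h*w, D) patch embeddings
        tokens, _ = self.gat_standin(tokens, tokens, tokens)  # graph-attention-style refinement
        tokens = self.encoder(tokens)                         # long-range global dependencies
        return self.head(tokens.mean(dim=1))                  # classification logits


logits = SAGViTSketch()(torch.rand(2, 3, 224, 224))           # -> shape (2, 10)
```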
Usage
SAG-ViT expects input images to be normalized in the standard way: mini-batches of 3-channel RGB images of shape (N, 3, H, W), where N is the number of images and H and W are at least 49 pixels. The images have to be loaded into a range of [0, 1] and then normalized using mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225].
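As a minimal sketch of the expected input format (the batch size and image size below are example values, not requirements beyond the minimum of 49 pixels):

```python
import torch

# Example mini-batch of shape (N, 3, H, W) with values in [0, 1];
# N=4 and H=W=224 are illustrative choices (H and W must be at least 49).
batch = torch.rand(4, 3, 224, 224)

mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
batch = (batch - mean) / std  # normalized mini-batch, ready for the model
```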
To train or run inference with our model, follow these steps:
Clone our repository and load the model pretrained on the CIFAR-10 dataset.
```bash
git clone https://huggingface.co/shravvvv/SAG-ViT
cd SAG-ViT
```
Install required dependencies.
```bash
pip install -r requirements.txt
```
Use from_pretrained to load the model from the Hugging Face Hub and run inference on a sample input image.
```python
from transformers import AutoModel, AutoConfig
from PIL import Image
from torchvision import transforms
import torch

# Step 1: Load the model and configuration directly from the Hugging Face Hub
repo_name = "shravvvv/SAG-ViT"
config = AutoConfig.from_pretrained(repo_name)                # Load config from hub
model = AutoModel.from_pretrained(repo_name, config=config)   # Load model from hub

# Step 2: Define the transformation for the input image
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to match the expected input size
    transforms.ToTensor(),          # Convert to a tensor in [0, 1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Step 3: Load and preprocess the input image
input_image_path = "path/to/your/image.jpg"
img = Image.open(input_image_path).convert("RGB")
img = transform(img).unsqueeze(0)  # Add batch dimension

# Step 4: Ensure the model is in evaluation mode
model.eval()

# Step 5: Run inference
with torch.no_grad():
    outputs = model(img)
    logits = outputs.logits  # Access logits from the ModelOutput

# Step 6: Post-process the predictions
predicted_class_index = torch.argmax(logits, dim=1)  # Get the predicted class index

# CIFAR-10 label mapping
class_names = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck'
]

# Get the predicted class name from the class index
predicted_class_name = class_names[predicted_class_index.item()]
print(f"Predicted class: {predicted_class_name}")
```
Running Tests
If you clone our repository, the tests folder contains unit tests for each of the model's modules. Make sure you have a Python environment with the required dependencies installed, then run:
```bash
python -m unittest discover -s tests
```
or, if you are using pytest, you can run:
```bash
pytest tests
```
Results
We evaluated SAG-ViT on diverse datasets:
- CIFAR-10 (natural images)
- GTSRB (traffic sign recognition)
- NCT-CRC-HE-100K (histopathological images)
- NWPU-RESISC45 (remote sensing imagery)
- PlantVillage (agricultural imagery)
SAG-ViT achieves state-of-the-art results across all benchmarks, as shown in the table below (F1 scores):
| Backbone | CIFAR-10 | GTSRB | NCT-CRC-HE-100K | NWPU-RESISC45 | PlantVillage |
|---|---|---|---|---|---|
| DenseNet201 | 0.5427 | 0.9862 | 0.9214 | 0.4493 | 0.8725 |
| Vgg16 | 0.5345 | 0.8180 | 0.8234 | 0.4114 | 0.7064 |
| Vgg19 | 0.5307 | 0.7551 | 0.8178 | 0.3844 | 0.6811 |
| DenseNet121 | 0.5290 | 0.9813 | 0.9247 | 0.4381 | 0.8321 |
| AlexNet | 0.6126 | 0.9059 | 0.8743 | 0.4397 | 0.7684 |
| Inception | 0.7734 | 0.8934 | 0.8707 | 0.8707 | 0.8216 |
| ResNet | 0.9172 | 0.9134 | 0.9478 | 0.9103 | 0.8905 |
| MobileNet | 0.9169 | 0.3006 | 0.4965 | 0.1667 | 0.2213 |
| ViT-S | 0.8465 | 0.8542 | 0.8234 | 0.6116 | 0.8654 |
| ViT-L | 0.8637 | 0.8613 | 0.8345 | 0.8358 | 0.8842 |
| MNASNet1_0 | 0.1032 | 0.0024 | 0.0212 | 0.0011 | 0.0049 |
| ShuffleNet_V2_x1_0 | 0.3523 | 0.4244 | 0.4598 | 0.1808 | 0.3190 |
| SqueezeNet1_0 | 0.4328 | 0.8392 | 0.7843 | 0.3913 | 0.6638 |
| GoogLeNet | 0.4954 | 0.9455 | 0.8631 | 0.3720 | 0.7726 |
| Proposed (SAG-ViT) | 0.9574 | 0.9958 | 0.9861 | 0.9549 | 0.9772 |
Citation
If you find our paper and code helpful for your research, please consider citing our work and giving the repository a star:
```bibtex
@misc{SAGViT,
  title={SAG-ViT: A Scale-Aware, High-Fidelity Patching Approach with Graph Attention for Vision Transformers},
  author={Shravan Venkatraman and Jaskaran Singh Walia and Joe Dhanith P R},
  year={2024},
  eprint={2411.09420},
  archivePrefix={arXiv},
  primaryClass={cs.CV},
  url={https://arxiv.org/abs/2411.09420},
}
```