# SAG-ViT / model_components.py
# (uploaded by shravvvv — commit 986f758, "Updated stuff")
import torch
from torch import nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv, global_mean_pool, GCNConv
from torchvision import models
###############################################################
# These modules correspond to core building blocks of SAG-ViT:
# 1. A CNN feature extractor for high-fidelity multi-scale feature maps.
# 2. A Graph Attention Network (GAT) to refine patch embeddings.
# 3. A Transformer Encoder to capture global long-range dependencies.
# 4. An MLP classifier head.
###############################################################
class EfficientNetV2FeatureExtractor(nn.Module):
    """
    CNN backbone producing multi-scale, semantically rich feature maps
    from input images, as described in Section 3.1 of SAG-ViT.

    Wraps a (optionally ImageNet-pretrained) EfficientNetV2-S and keeps
    only the feature stages whose output resolution stays at or above
    16x16 — the last two stages are dropped.
    """

    def __init__(self, pretrained=False):
        super(EfficientNetV2FeatureExtractor, self).__init__()
        # Select ImageNet weights only when explicitly requested.
        weight_spec = "IMAGENET1K_V1" if pretrained else None
        backbone = models.efficientnet_v2_s(weights=weight_spec)
        # Keep every feature stage except the final two (which would
        # downsample the map below 16x16).
        stages = list(backbone.features.children())[:-2]
        self.extractor = nn.Sequential(*stages)

    def forward(self, x):
        """
        Run the truncated backbone.

        Args:
            x (Tensor): Batch of images, shape (B, 3, H, W).

        Returns:
            Tensor: Feature map of shape (B, C, H', W') with reduced
            spatial dimensions H', W'.
        """
        return self.extractor(x)
class GATGNN(nn.Module):
    """
    Graph Attention Network (GAT) that refines patch-graph embeddings
    (Section 3.3): a multi-head GAT layer learns attention-weighted local
    relationships between patches, a GCN layer aggregates them, and
    global mean pooling produces one embedding per graph.

    Fix: the original class defined ``forward`` twice; Python silently
    kept only the second definition (the ELU variant), leaving the ReLU
    variant as dead code. The duplicate has been removed and the
    effective (ELU) behavior retained.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, heads=4):
        super(GATGNN, self).__init__()
        # Multi-head GAT: maps raw patch embeddings to a higher-level
        # representation; output dimensionality is hidden_channels * heads
        # because head outputs are concatenated by default.
        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads)
        # Final GCN layer producing the refined node representation.
        self.conv2 = GCNConv(hidden_channels * heads, out_channels)
        # Mean-pools node features into a single graph-level vector.
        self.pool = global_mean_pool

    def forward(self, data):
        """
        Args:
            data (PyG Data): Holds ``x`` (node features), ``edge_index``
                (graph connectivity) and ``batch`` (graph membership of
                each node).

        Returns:
            Tensor: Graph-level embeddings, shape (num_graphs, out_channels).
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch
        # GAT layer with ELU activation (the activation that was in
        # effect at runtime in the original code).
        x = F.elu(self.conv1(x, edge_index))
        # GCN layer for further aggregation.
        x = self.conv2(x, edge_index)
        # Global mean pooling to obtain the graph-level representation.
        return self.pool(x, batch)
class TransformerEncoder(nn.Module):
    """
    Transformer encoder capturing long-range dependencies among patch
    embeddings after GAT processing (Section 3.3).

    Improvement: uses ``batch_first=True`` so the encoder consumes
    (B, N, D) tensors directly, removing the two ``transpose(0, 1)``
    calls the original needed for the default (N, B, D) layout. This is
    the idiomatic modern form and also enables PyTorch's fused fastpath
    for batch-first inputs. Inputs, outputs, and parameters are unchanged.
    """

    def __init__(self, d_model, nhead, num_layers, dim_feedforward):
        super(TransformerEncoder, self).__init__()
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            batch_first=True,  # accept (B, N, D) directly; no transposes
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=num_layers
        )

    def forward(self, x):
        """
        Args:
            x (Tensor): Patch-embedding sequence, shape (B, N, D).

        Returns:
            Tensor: Globally contextualized embeddings, shape (B, N, D).
        """
        return self.transformer_encoder(x)
class MLPBlock(nn.Module):
    """
    Two-layer MLP classification head: projects the final global
    embedding to class logits via a hidden layer with ReLU.
    """

    def __init__(self, in_features, hidden_features, out_features):
        super(MLPBlock, self).__init__()
        # Assemble the layers first, then unpack into a Sequential so
        # parameter names (mlp.0 / mlp.1 / mlp.2) match the original.
        layers = [
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, out_features),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        # Map (B, in_features) embeddings to (B, out_features) logits.
        return self.mlp(x)