shravvvv committed
Commit 039647a · 1 Parent(s): 9d8e0f4

Added model files and updated config.json
.gitignore ADDED
@@ -0,0 +1,3 @@
+ data/
+ __pycache__
+ tests/__pycache__
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2024 Shravan Venkatraman
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
config.json CHANGED
@@ -1,4 +1,5 @@
  {
+     "model_type": "sag-vit",
      "d_model": 64,
      "dim_feedforward": 64,
      "gcn_hidden": 128,
data_loader.py ADDED
@@ -0,0 +1,47 @@
+ import os
+ from torch.utils.data import DataLoader, random_split
+ from torchvision import datasets, transforms
+
+ def get_dataloaders(data_dir="path/to/data/dir", batch_size=512, train_split=0.8, img_size=224, num_workers=4):
+     """
+     Returns training and validation dataloaders for an image classification dataset.
+
+     Parameters:
+     - data_dir (str): Path to the directory containing image data in a folder structure compatible with ImageFolder.
+     - batch_size (int): Number of samples per batch.
+     - train_split (float): Fraction of data to use for training. The remainder is used for validation.
+     - img_size (int): Target size to which all images are resized (after the size check below).
+     - num_workers (int): Number of worker processes for data loading.
+
+     Image size validation:
+     - Minimum allowed image size: 49x49 pixels.
+     - If img_size is smaller than 49 pixels, a ValueError is raised.
+
+     Returns:
+     - train_dataloader (DataLoader): DataLoader for the training split.
+     - val_dataloader (DataLoader): DataLoader for the validation split.
+     """
+
+     # Check if the provided image size is valid
+     if img_size < 49:
+         raise ValueError(f"Image size must be at least 49x49 pixels, but got {img_size}x{img_size}.")
+
+     transform = transforms.Compose([
+         transforms.Resize((img_size, img_size)),
+         transforms.ToTensor(),
+         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+     ])
+
+     # Load full dataset
+     full_dataset = datasets.ImageFolder(root=data_dir, transform=transform)
+
+     # Split into training and validation sets
+     train_size = int(train_split * len(full_dataset))
+     val_size = len(full_dataset) - train_size
+     train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])
+
+     # Create dataloaders
+     train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
+     val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
+
+     return train_dataloader, val_dataloader
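A short usage sketch for get_dataloaders; the directory path is a placeholder, and any ImageFolder-style layout (data_dir/<class_name>/<image>) works:

from data_loader import get_dataloaders

# Hypothetical dataset path; mirrors the placeholder used in train.py.
train_dl, val_dl = get_dataloaders(data_dir="data/PlantVillage", batch_size=32, img_size=224)

images, labels = next(iter(train_dl))
print(images.shape)  # torch.Size([32, 3, 224, 224]) for a full batch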
graph_construction.py ADDED
@@ -0,0 +1,138 @@
+ import torch
+ import networkx as nx
+ from torch_geometric.utils import from_networkx
+
+ ####################################################################
+ # These functions reflect the methods described in Sections 3.1 and 3.2
+ # of the SAG-ViT paper, where high-fidelity feature patches are extracted
+ # from the CNN feature maps and organized into a graph structure.
+ ####################################################################
+
+ def extract_patches(feature_map, patch_size=(4, 4)):
+     """
+     Extracts non-overlapping patches from a feature map to form nodes in a graph.
+
+     Parameters:
+     - feature_map (Tensor): The feature map from the CNN of shape (B, C, H', W').
+       H' and W' are reduced spatial dimensions after CNN feature extraction.
+     - patch_size (tuple): Spatial size (height, width) of each patch.
+
+     Returns:
+     - patches (Tensor): Tensor of shape (B, N, C, patch_h, patch_w), where N is the number of patches per image.
+     """
+     b, c, h, w = feature_map.size()
+     patch_h, patch_w = patch_size
+
+     # Unfold extracts sliding patches; the step equals the patch size, so they are non-overlapping
+     patches = feature_map.unfold(2, patch_h, patch_h).unfold(3, patch_w, patch_w)
+
+     # Rearrange to have patches as separate units
+     patches = patches.permute(0, 2, 3, 1, 4, 5).contiguous()
+     patches = patches.view(b, -1, c, patch_h, patch_w)
+     return patches
+
+ def construct_graph_from_patch(patch_index, patch_shape, image_shape):
+     """
+     Constructs edges between patch nodes based on spatial adjacency (k-connectivity).
+     This follows the approach described in Section 3.2 of SAG-ViT, where patches
+     are arranged in a grid and connected to their spatial neighbors.
+
+     Parameters:
+     - patch_index (int): Index of the current patch node.
+     - patch_shape (tuple): (patch_height, patch_width).
+     - image_shape (tuple): (height, width) of the feature map.
+
+     Returns:
+     - G (nx.Graph): A graph with a single node and edges to its neighbors (to be composed globally).
+     """
+     G = nx.Graph()
+
+     # Compute grid dimensions (how many patches along height and width)
+     grid_height = image_shape[0] // patch_shape[0]
+     grid_width = image_shape[1] // patch_shape[1]
+
+     # Current node index in a flattened grid
+     current_node = patch_index
+
+     G.add_node(current_node)
+
+     # 8-neighborhood connectivity (up, down, left, right, diagonals)
+     neighbor_offsets = [(-1, 0), (1, 0), (0, -1), (0, 1),
+                         (-1, -1), (-1, 1), (1, -1), (1, 1)]
+
+     # Recover row, col from patch_index
+     row = current_node // grid_width
+     col = current_node % grid_width
+
+     for dr, dc in neighbor_offsets:
+         neighbor_row = row + dr
+         neighbor_col = col + dc
+         if 0 <= neighbor_row < grid_height and 0 <= neighbor_col < grid_width:
+             neighbor_node = neighbor_row * grid_width + neighbor_col
+             G.add_edge(current_node, neighbor_node)
+
+     return G
+
+ def build_graph_from_patches(feature_map, patch_size=(4,4)):
+     """
+     Builds a global graph for each image in the batch, where each node corresponds
+     to a patch, and edges represent spatial adjacency. This graph captures local
+     spatial relationships of the patches, as outlined in Sections 3.1 and 3.2 of SAG-ViT.
+
+     Parameters:
+     - feature_map (Tensor): CNN output (B, C, H', W').
+     - patch_size (tuple): Size of each patch (patch_h, patch_w).
+
+     Returns:
+     - G_global_batch (list): A list of NetworkX graphs, one per image in the batch.
+     - patches (Tensor): The extracted patches (B, N, C, patch_h, patch_w).
+     """
+     patches = extract_patches(feature_map, patch_size)
+     batch_size = patches.size(0)
+
+     grid_height = feature_map.size(2) // patch_size[0]
+     grid_width = feature_map.size(3) // patch_size[1]
+     num_patches = grid_height * grid_width
+
+     G_global_batch = []
+     for batch_idx in range(batch_size):
+         G_global = nx.Graph()
+         # Construct a global graph by composing individual patch-based graphs
+         for patch_idx in range(num_patches):
+             G_patch = construct_graph_from_patch(
+                 patch_index=patch_idx,
+                 patch_shape=patch_size,
+                 image_shape=(feature_map.size(2), feature_map.size(3))
+             )
+             G_global = nx.compose(G_global, G_patch)
+         G_global_batch.append(G_global)
+
+     return G_global_batch, patches
+
+ def build_graph_data_from_patches(G_global_batch, patches):
+     """
+     Converts NetworkX graphs and associated patches into PyTorch Geometric Data objects.
+     Each node corresponds to a patch vectorized into a feature node embedding.
+
+     Parameters:
+     - G_global_batch (list): List of global graphs (one per image) in NetworkX form.
+     - patches (Tensor): (B, N, C, patch_h, patch_w) patch tensor.
+
+     Returns:
+     - data_list (list): List of PyTorch Geometric Data objects, where data.x are node features,
+       and data.edge_index is the adjacency from the constructed graph.
+     """
+     from_networkx_ = from_networkx  # local alias to avoid confusion
+
+     data_list = []
+     batch_size, num_patches, channels, patch_h, patch_w = patches.size()
+
+     for batch_idx, G_global in enumerate(G_global_batch):
+         # Flatten each patch into a feature vector
+         node_features = patches[batch_idx].view(num_patches, -1)
+
+         G_pygeom = from_networkx_(G_global)
+         G_pygeom.x = node_features
+         data_list.append(G_pygeom)
+
+     return data_list
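A small sketch of the intermediate shapes these helpers produce, using the same dummy sizes as the unit tests further down (16-channel feature map, 4x4 patches):

import torch
from graph_construction import build_graph_from_patches, build_graph_data_from_patches

feature_map = torch.randn(2, 16, 32, 32)               # (B, C, H', W')
graphs, patches = build_graph_from_patches(feature_map, patch_size=(4, 4))
print(patches.shape)                                   # (2, 64, 16, 4, 4): an 8x8 grid of patches per image
print(len(graphs[0].nodes))                            # 64 nodes in an 8-connected grid

data_list = build_graph_data_from_patches(graphs, patches)
print(data_list[0].x.shape)                            # (64, 256): each patch flattened to C * 4 * 4 features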
hubconf.py ADDED
@@ -0,0 +1,23 @@
+ dependencies = ['torch']
+
+ from sag_vit_model import SAGViTClassifier
+ import torch
+
+ def SAGViT(pretrained=False, **kwargs):
+     """
+     SAG-ViT model endpoint.
+     Args:
+         pretrained (bool): If True, loads pretrained weights.
+         **kwargs: Additional arguments for the model.
+     Returns:
+         model (nn.Module): The SAG-ViT model as proposed in the
+         paper: SAG-ViT: A Scale-Aware, High-Fidelity Patching
+         Approach with Graph Attention for Vision Transformers.
+         https://doi.org/10.48550/arXiv.2411.09420
+     """
+     model = SAGViTClassifier(**kwargs)
+     if pretrained:
+         checkpoint = ''
+         state_dict = torch.hub.load_state_dict_from_url(checkpoint, progress=True)
+         model.load_state_dict(state_dict)
+     return model
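A hedged loading sketch via torch.hub. Using source="local" avoids assuming where (or whether) a GitHub mirror of this repository exists, and pretrained=True is not usable yet because the checkpoint URL above is still empty:

import torch

# Assumes a local clone of this repository in ./SAG-ViT (path is illustrative).
model = torch.hub.load("./SAG-ViT", "SAGViT", source="local", pretrained=False, num_classes=10)
model.eval()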
model_components.py ADDED
@@ -0,0 +1,119 @@
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+ from torch_geometric.nn import GATConv, global_mean_pool
+
+ from torchvision import models
+
+ ###############################################################
+ # These modules correspond to core building blocks of SAG-ViT:
+ # 1. A CNN feature extractor for high-fidelity multi-scale feature maps.
+ # 2. A Graph Attention Network (GAT) to refine patch embeddings.
+ # 3. A Transformer Encoder to capture global long-range dependencies.
+ # 4. An MLP classifier head.
+ ###############################################################
+
+ class EfficientNetV2FeatureExtractor(nn.Module):
+     """
+     Extracts multi-scale, spatially-rich, and semantically-meaningful feature maps
+     from images using a pre-trained EfficientNetV2-S model. This corresponds
+     to Section 3.1, where a CNN backbone (EfficientNetV2-S) is used to produce rich
+     feature maps that preserve semantic information at multiple scales.
+     """
+     def __init__(self, pretrained=False):
+         super(EfficientNetV2FeatureExtractor, self).__init__()
+
+         # Load EfficientNetV2-S with pretrained weights
+         efficientnet = models.efficientnet_v2_s(
+             weights="IMAGENET1K_V1" if pretrained else None
+         )
+
+         # Extract layers up to the last block before downsampling below 16x16
+         self.extractor = nn.Sequential(*list(efficientnet.features.children())[:-2])
+
+         # Freeze the backbone parameters so it acts as a fixed feature extractor
+         # (the classifier comment in sag_vit_model.py and tests/test_model_components.py assume this).
+         for param in self.extractor.parameters():
+             param.requires_grad = False
+
+     def forward(self, x):
+         """
+         Forward pass through the CNN backbone.
+
+         Input:
+         - x (Tensor): Input images of shape (B, 3, H, W)
+
+         Output:
+         - features (Tensor): Extracted feature map of shape (B, C, H', W'),
+           where H' and W' are reduced spatial dimensions.
+         """
+         features = self.extractor(x)
+         return features
+
+ class GATGNN(nn.Module):
+     """
+     A Graph Attention Network (GAT) that processes patch-graph embeddings.
+     This module corresponds to the Graph Attention stage (Section 3.3),
+     refining local relationships between patches in a learned manner.
+     """
+     def __init__(self, in_channels, hidden_channels, out_channels, heads=8):
+         super(GATGNN, self).__init__()
+         # GAT layers:
+         # First layer maps raw patch embeddings to a higher-level representation.
+         self.conv1 = GATConv(in_channels, hidden_channels, heads=heads)
+         # Second layer produces final node embeddings with a single head.
+         self.conv2 = GATConv(hidden_channels * heads, out_channels, heads=1)
+         self.pool = global_mean_pool
+
+     def forward(self, data):
+         """
+         Input:
+         - data (PyG Data): Contains x (node features), edge_index (graph edges), and batch indexing.
+
+         Output:
+         - x (Tensor): Aggregated graph-level embedding after mean pooling.
+         """
+         x, edge_index, batch = data.x, data.edge_index, data.batch
+         x = F.elu(self.conv1(x, edge_index))
+         x = self.conv2(x, edge_index)
+         x = self.pool(x, batch)
+         return x
+
+ class TransformerEncoder(nn.Module):
+     """
+     A Transformer encoder to capture long-range dependencies among patch embeddings.
+     Integrates global dependencies after GAT processing, as per Section 3.3.
+     """
+     def __init__(self, d_model, nhead, num_layers, dim_feedforward):
+         super(TransformerEncoder, self).__init__()
+         encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward)
+         self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
+
+     def forward(self, x):
+         """
+         Input:
+         - x (Tensor): Sequence of patch embeddings with shape (B, N, D).
+
+         Output:
+         - (Tensor): Transformed embeddings with global relationships integrated (B, N, D).
+         """
+         # The Transformer expects (N, B, D), so transpose first
+         x = x.transpose(0, 1)  # (N, B, D)
+         x = self.transformer_encoder(x)
+         x = x.transpose(0, 1)  # (B, N, D)
+         return x
+
+ class MLPBlock(nn.Module):
+     """
+     An MLP classification head to map final global embeddings to classification logits.
+     """
+     def __init__(self, in_features, hidden_features, out_features):
+         super(MLPBlock, self).__init__()
+         self.mlp = nn.Sequential(
+             nn.Linear(in_features, hidden_features),
+             nn.ReLU(),
+             nn.Linear(hidden_features, out_features)
+         )
+
+     def forward(self, x):
+         return self.mlp(x)
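Quick shape checks for the individual blocks, mirroring tests/test_model_components.py; the exact feature-map size depends on the torchvision version, roughly (2, 160, 14, 14) for the truncated EfficientNetV2-S used here:

import torch
from torch_geometric.data import Data
from model_components import EfficientNetV2FeatureExtractor, GATGNN, TransformerEncoder, MLPBlock

cnn = EfficientNetV2FeatureExtractor().eval()
with torch.no_grad():
    fmap = cnn(torch.randn(2, 3, 224, 224))
print(fmap.shape)                                       # e.g. torch.Size([2, 160, 14, 14])

gnn = GATGNN(in_channels=256, hidden_channels=64, out_channels=32)
data = Data(x=torch.randn(4, 256),
            edge_index=torch.tensor([[0, 1, 1, 2], [1, 0, 2, 3]]),
            batch=torch.zeros(4, dtype=torch.long))
print(gnn(data).shape)                                  # (1, 32) after global mean pooling

enc = TransformerEncoder(d_model=64, nhead=4, num_layers=2, dim_feedforward=64)
print(enc(torch.randn(2, 10, 64)).shape)                # (2, 10, 64), shape preserved

print(MLPBlock(64, 128, 10)(torch.randn(2, 64)).shape)  # (2, 10) logits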
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ numpy==1.26.4
+ pandas==2.2.3
+ matplotlib==3.7.5
+ seaborn==0.12.2
+ tqdm==4.66.4
+ scikit-learn==1.2.2
+ torch==2.4.0
+ torch-geometric==2.6.1
+ torchvision==0.19.0
+ networkx==3.3
sag_vit_model.py ADDED
@@ -0,0 +1,107 @@
+ import torch
+ from torch import nn
+ from huggingface_hub import PyTorchModelHubMixin
+
+ from torch_geometric.data import Batch
+ from model_components import EfficientNetV2FeatureExtractor, GATGNN, TransformerEncoder, MLPBlock
+ from graph_construction import build_graph_from_patches, build_graph_data_from_patches
+
+ ###############################################################################
+ # SAG-ViT Model:
+ # This class combines:
+ # 1) A CNN backbone to produce high-fidelity feature maps (Section 3.1),
+ # 2) Graph construction and a GAT to refine local patch embeddings (Sections 3.2 and 3.3),
+ # 3) A Transformer encoder to capture global relationships (Section 3.3),
+ # 4) A final MLP classifier.
+ ###############################################################################
+
+ class SAGViTClassifier(nn.Module, PyTorchModelHubMixin):
+     """
+     SAG-ViT: Scale-Aware Graph Attention Vision Transformer
+
+     This model integrates the following steps:
+     - Extract multi-scale features from images using a CNN backbone (EfficientNetV2-S here).
+     - Partition the feature map into patches and build a graph where each node is a patch.
+     - Use a Graph Attention Network (GAT) to refine patch embeddings based on local spatial relationships.
+     - Utilize a Transformer encoder to model long-range dependencies and integrate multi-scale information.
+     - Finally, classify the resulting representation into the desired classes.
+
+     Inputs:
+     - x (Tensor): Input images (B, 3, H, W)
+
+     Outputs:
+     - out (Tensor): Classification logits (B, num_classes)
+     """
+     def __init__(
+         self,
+         patch_size=(4,4),
+         num_classes=10,
+         d_model=64,
+         nhead=4,
+         num_layers=2,
+         dim_feedforward=64,
+         hidden_mlp_features=64,
+         in_channels=2560,  # Derived from patch dimensions and CNN output channels
+         gcn_hidden=128,
+         gcn_out=64
+     ):
+         super(SAGViTClassifier, self).__init__()
+
+         # CNN feature extractor (frozen EfficientNetV2-S backbone, see model_components.py)
+         self.cnn = EfficientNetV2FeatureExtractor()
+
+         # Graph Attention Network to process patch embeddings
+         self.gcn = GATGNN(in_channels=in_channels, hidden_channels=gcn_hidden, out_channels=gcn_out)
+
+         # Learnable positional embedding for Transformer input
+         self.positional_embedding = nn.Parameter(torch.randn(1, 1, d_model))
+         # Extra embedding token (similar to a class token) to summarize global info
+         self.extra_embedding = nn.Parameter(torch.randn(1, d_model))
+
+         # Transformer encoder to capture long-range global dependencies
+         self.transformer_encoder = TransformerEncoder(d_model, nhead, num_layers, dim_feedforward)
+
+         # MLP classification head
+         self.mlp = MLPBlock(d_model, hidden_mlp_features, num_classes)
+
+         self.patch_size = patch_size
+
+     def forward(self, x):
+         # Step 1: High-fidelity feature extraction from the CNN
+         feature_map = self.cnn(x)
+
+         # Step 2: Build graphs from patches
+         G_global_batch, patches = build_graph_from_patches(feature_map, self.patch_size)
+
+         # Step 3: Convert to PyG Data format and batch
+         data_list = build_graph_data_from_patches(G_global_batch, patches)
+         device = x.device
+         batch = Batch.from_data_list(data_list).to(device)
+
+         # Step 4: GAT stage
+         x_gcn = self.gcn(batch)
+
+         # Step 5: The GAT's global mean pooling collapses each image's patch graph
+         # into a single embedding, so x_gcn has shape (B, D) rather than (B, N, D).
+         B = x.size(0)
+         D = x_gcn.size(-1)  # embedding dimension
+         # The Transformer expects a sequence dimension, so treat this image-level
+         # embedding as a single "patch token"; an extra learnable token (similar to
+         # a CLS token) is appended below.
+         patch_embeddings = x_gcn.unsqueeze(1)  # (B, 1, D)
+
+         # Add positional embedding
+         patch_embeddings = patch_embeddings + self.positional_embedding  # (B, 1, D)
+
+         # Add an extra learnable embedding (like a CLS token)
+         patch_embeddings = torch.cat([patch_embeddings, self.extra_embedding.unsqueeze(0).expand(B, -1, -1)], dim=1)  # (B, 2, D)
+
+         # Step 6: Transformer encoder
+         x_trans = self.transformer_encoder(patch_embeddings)
+
+         # Step 7: Global pooling (here we just take the mean over the tokens)
+         x_pooled = x_trans.mean(dim=1)  # (B, D)
+
+         # Classification
+         out = self.mlp(x_pooled)
+         return out
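An end-to-end smoke test of the classifier with the default hyperparameters, matching tests/test_sag_vit_model.py:

import torch
from sag_vit_model import SAGViTClassifier

model = SAGViTClassifier(num_classes=10)
model.eval()
with torch.no_grad():
    logits = model(torch.randn(2, 3, 224, 224))
print(logits.shape)  # torch.Size([2, 10])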
tests/test_graph_construction.py ADDED
@@ -0,0 +1,39 @@
+ import unittest
+ import torch
+ import networkx as nx
+ from graph_construction import extract_patches, build_graph_from_patches, build_graph_data_from_patches
+
+ class TestGraphConstruction(unittest.TestCase):
+     def test_extract_patches_shape(self):
+         # Create a dummy feature map: B=2, C=16, H=32, W=32
+         feature_map = torch.randn(2, 16, 32, 32)
+         patches = extract_patches(feature_map, patch_size=(4,4))
+         # Check dimensions: after extraction,
+         # number_of_patches = (H/4)*(W/4) = 8*8 = 64 per image, total 2*64 = 128
+         self.assertEqual(patches.shape, (2, 64, 16, 4, 4))
+
+     def test_build_graph_from_patches_graph_structure(self):
+         feature_map = torch.randn(1, 16, 32, 32)
+         G_batch, patches = build_graph_from_patches(feature_map, patch_size=(4,4))
+         # 1 image => G_batch[0] is the graph
+         G = G_batch[0]
+         # We have 64 patches
+         self.assertEqual(len(G.nodes), 64)
+         # Check if edges exist (8-neighborhood).
+         # Interior nodes should have edges to neighbors.
+         # Just check a node in the interior of the grid
+         node_index = 9  # row=1, col=1 in an 8x8 grid
+         self.assertTrue(len(list(G.neighbors(node_index))) > 0)
+
+     def test_build_graph_data_from_patches_conversion(self):
+         feature_map = torch.randn(2, 16, 32, 32)
+         G_batch, patches = build_graph_from_patches(feature_map, patch_size=(4,4))
+         data_list = build_graph_data_from_patches(G_batch, patches)
+         self.assertEqual(len(data_list), 2)
+         # Check node feature shape
+         self.assertEqual(data_list[0].x.shape[1], 16*4*4)  # C * patch_h * patch_w = 16*4*4 = 256
+         # Check edges are present
+         self.assertTrue(data_list[0].edge_index.shape[1] > 0)
+
+ if __name__ == '__main__':
+     unittest.main()
tests/test_model_components.py ADDED
@@ -0,0 +1,53 @@
+ import unittest
+ import torch
+ from model_components import EfficientNetV2FeatureExtractor, GATGNN, TransformerEncoder, MLPBlock
+ from torch_geometric.data import Data
+
+ class TestModelComponents(unittest.TestCase):
+     def test_efficientnetv2_extractor_output_shape(self):
+         model = EfficientNetV2FeatureExtractor()
+         model.eval()
+         x = torch.randn(2, 3, 224, 224)
+         with torch.no_grad():
+             features = model(x)
+         # Check output shape - depends on where the EfficientNetV2 backbone is truncated
+         # (e.g. roughly (2, 160, 14, 14) for the layers kept here)
+         self.assertEqual(features.size(0), 2)
+         self.assertTrue(features.size(1) > 0)
+         self.assertTrue(features.size(2) > 0)
+         self.assertTrue(features.size(3) > 0)
+
+     def test_gatgnn_forward(self):
+         # Graph with 4 nodes, each node feature dim=256
+         x = torch.randn(4, 256)
+         edge_index = torch.tensor([[0,1,1,2],[1,0,2,3]], dtype=torch.long)
+         batch = torch.tensor([0,0,0,0])
+         data = Data(x=x, edge_index=edge_index, batch=batch)
+
+         gnn = GATGNN(in_channels=256, hidden_channels=64, out_channels=32)
+         output = gnn(data)
+         # After pooling: should be (batch_size, out_channels) = (1, 32)
+         self.assertEqual(output.shape, (1, 32))
+
+     def test_transformer_encoder(self):
+         # (B, N, D) = (2, 10, 64)
+         x = torch.randn(2, 10, 64)
+         encoder = TransformerEncoder(d_model=64, nhead=4, num_layers=2, dim_feedforward=64)
+         out = encoder(x)
+         # Same shape as input
+         self.assertEqual(out.shape, (2, 10, 64))
+
+     def test_mlp_block(self):
+         mlp = MLPBlock(in_features=64, hidden_features=128, out_features=10)
+         x = torch.randn(2, 64)
+         out = mlp(x)
+         self.assertEqual(out.shape, (2, 10))
+
+     def test_efficientnetv2_freeze(self):
+         # Ensure backbone params are frozen
+         model = EfficientNetV2FeatureExtractor()
+         for param in model.parameters():
+             self.assertFalse(param.requires_grad)
+
+ if __name__ == '__main__':
+     unittest.main()
tests/test_sag_vit_model.py ADDED
@@ -0,0 +1,39 @@
+ import unittest
+ import torch
+ from sag_vit_model import SAGViTClassifier
+
+ class TestSAGViTModel(unittest.TestCase):
+     def test_forward_pass(self):
+         model = SAGViTClassifier(
+             patch_size=(4,4),
+             num_classes=10,  # smaller num classes for test
+             d_model=64,
+             nhead=4,
+             num_layers=2,
+             dim_feedforward=64,
+             hidden_mlp_features=64,
+             in_channels=2560,  # from patch dimension example
+             gcn_hidden=128,
+             gcn_out=64
+         )
+         model.eval()
+         x = torch.randn(2, 3, 224, 224)
+         with torch.no_grad():
+             out = model(x)
+         # Check output shape: (B, num_classes) = (2, 10)
+         self.assertEqual(out.shape, (2, 10))
+
+     def test_empty_input(self):
+         model = SAGViTClassifier()
+         # Passing an empty tensor should fail gracefully
+         with self.assertRaises(Exception):
+             model(torch.empty(0, 3, 224, 224))
+
+     def test_invalid_input_dimensions(self):
+         model = SAGViTClassifier()
+         # Incorrect dimension (e.g., missing channel)
+         with self.assertRaises(RuntimeError):
+             model(torch.randn(2, 224, 224))  # no channel dimension
+
+ if __name__ == '__main__':
+     unittest.main()
tests/test_train.py ADDED
@@ -0,0 +1,54 @@
+ import unittest
+ from unittest.mock import MagicMock, patch
+ import torch
+ import torch.nn as nn
+ from train import train_model
+ from sag_vit_model import SAGViTClassifier
+
+ class TestTrain(unittest.TestCase):
+     @patch("train.optim.Adam")
+     def test_train_model_loop(self, mock_adam):
+         # Mock the optimizer
+         mock_optimizer = MagicMock()
+         mock_adam.return_value = mock_optimizer
+
+         # Mock dataloaders with a small dummy dataset:
+         # just one batch with a couple of samples
+         train_dataloader = [(torch.randn(2, 3, 224, 224), torch.tensor([0, 1]))]
+         val_dataloader = [(torch.randn(2, 3, 224, 224), torch.tensor([0, 1]))]
+
+         model = SAGViTClassifier(num_classes=2)
+
+         criterion = nn.CrossEntropyLoss()
+         device = torch.device("cpu")
+
+         # Test a single training epoch
+         history = train_model(model, "TestModel", train_dataloader, val_dataloader,
+                               num_epochs=1, criterion=criterion, optimizer=mock_optimizer, device=device, patience=2, verbose=False)
+
+         # Check if history is properly recorded
+         self.assertIn("train_loss", history)
+         self.assertIn("val_loss", history)
+         self.assertGreaterEqual(len(history["train_loss"]), 1)
+         self.assertGreaterEqual(len(history["val_loss"]), 1)
+
+     def test_early_stopping(self):
+         # Scenario where validation loss may stop improving
+         model = SAGViTClassifier(num_classes=2)
+         criterion = nn.CrossEntropyLoss()
+         optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
+         device = torch.device("cpu")
+
+         # Reuse the same single batch every epoch; with patience=1 the run should
+         # stop as soon as the validation loss fails to improve.
+         train_dataloader = [(torch.randn(2, 3, 224, 224), torch.tensor([0, 1]))]
+         val_dataloader = [(torch.randn(2, 3, 224, 224), torch.tensor([0, 1]))]
+
+         history = train_model(model, "TestModelEarlyStop", train_dataloader, val_dataloader,
+                               num_epochs=5, criterion=criterion, optimizer=optimizer, device=device, patience=1, verbose=False)
+
+         # Early stopping should keep the run within the 5-epoch budget
+         self.assertLessEqual(len(history["train_loss"]), 5)
+
+ if __name__ == '__main__':
+     unittest.main()
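The test modules import the top-level files directly, so a small sketch for running the whole suite, assuming the repository root is the working directory:

import unittest

# Discover and run all tests/test_*.py files.
suite = unittest.defaultTestLoader.discover("tests", pattern="test_*.py")
unittest.TextTestRunner(verbosity=2).run(suite)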
train.py ADDED
@@ -0,0 +1,196 @@
+ import os
+ import torch
+ from torch import nn, optim
+ from tqdm import tqdm
+ from huggingface_hub import HfApi
+ import numpy as np
+ from sklearn.metrics import (precision_score, recall_score, f1_score,
+                              roc_auc_score, cohen_kappa_score, matthews_corrcoef,
+                              confusion_matrix)
+
+ from sag_vit_model import SAGViTClassifier
+ from data_loader import get_dataloaders
+
+ #####################################################################
+ # This file provides the training loop and metric computation. It uses
+ # the SAG-ViT model defined in sag_vit_model.py and the data from data_loader.py.
+ # The training loop implements early stopping and tracks various metrics.
+ #####################################################################
+
+ def train_model(model, model_name, train_loader, val_loader, num_epochs, criterion, optimizer, device, patience=8, verbose=True):
+     """
+     Trains the SAG-ViT model and evaluates it on the validation set.
+     Implements early stopping based on validation loss.
+
+     Parameters:
+     - model (nn.Module): The SAG-ViT model.
+     - model_name (str): A name to identify the model (used for saving checkpoints).
+     - train_loader, val_loader: DataLoaders for training and validation.
+     - num_epochs (int): Maximum number of epochs.
+     - criterion (nn.Module): Loss function.
+     - optimizer (torch.optim.Optimizer): Optimization algorithm.
+     - device (torch.device): Device to run the computations on (CPU/GPU).
+     - patience (int): Early stopping patience.
+
+     Returns:
+     - history (dict): Dictionary containing training and validation metrics per epoch.
+     """
+
+     history = {
+         'train_loss': [], 'train_acc': [], 'train_prec': [], 'train_rec': [], 'train_f1': [],
+         'train_auc': [], 'train_mcc': [], 'train_cohen_kappa': [], 'train_confusion_matrix': [],
+         'val_loss': [], 'val_acc': [], 'val_prec': [], 'val_rec': [], 'val_f1': [],
+         'val_auc': [], 'val_mcc': [], 'val_cohen_kappa': [], 'val_confusion_matrix': []
+     }
+
+     best_val_loss = float('inf')
+     patience_counter = 0
+     best_model_state = None
+
+     for epoch in range(num_epochs):
+         print(f'Epoch {epoch+1}/{num_epochs}')
+         model.train()
+
+         train_loss_total, correct, total = 0, 0, 0
+         all_preds, all_labels, all_probs = [], [], []
+
+         # Training loop
+         for batch_idx, (X, y) in enumerate(tqdm(train_loader)):
+             inputs, labels = X.to(device), y.to(device)
+             optimizer.zero_grad()
+
+             outputs = model(inputs)
+             loss = criterion(outputs, labels)
+             loss.backward()
+             optimizer.step()
+
+             train_loss_total += loss.item()
+
+             probs = torch.softmax(outputs, dim=1)
+             _, preds = torch.max(outputs, 1)
+             correct += (preds == labels).sum().item()
+             total += labels.size(0)
+
+             all_preds.extend(preds.cpu().numpy())
+             all_labels.extend(labels.cpu().numpy())
+             all_probs.extend(probs.detach().cpu().numpy())
+
+         # Compute training metrics
+         train_acc = correct / total
+         train_prec = precision_score(all_labels, all_preds, average='macro', zero_division=0)
+         train_rec = recall_score(all_labels, all_preds, average='macro')
+         train_f1 = f1_score(all_labels, all_preds, average='macro')
+         train_cohen_kappa = cohen_kappa_score(all_labels, all_preds)
+         train_mcc = matthews_corrcoef(all_labels, all_preds)
+         train_confusion = confusion_matrix(all_labels, all_preds)
+
+         history['train_loss'].append(train_loss_total / len(train_loader))
+         history['train_acc'].append(train_acc)
+         history['train_prec'].append(train_prec)
+         history['train_rec'].append(train_rec)
+         history['train_f1'].append(train_f1)
+         history['train_cohen_kappa'].append(train_cohen_kappa)
+         history['train_mcc'].append(train_mcc)
+         history['train_confusion_matrix'].append(train_confusion)
+
+         # Validation
+         model.eval()
+         val_loss_total, correct, total = 0, 0, 0
+         all_preds, all_labels, all_probs = [], [], []
+
+         with torch.no_grad():
+             for batch_idx, (X, y) in enumerate(tqdm(val_loader)):
+                 inputs, labels = X.to(device), y.to(device)
+                 outputs = model(inputs)
+                 loss = criterion(outputs, labels)
+
+                 val_loss_total += loss.item()
+                 probs = torch.softmax(outputs, dim=1)
+                 _, preds = torch.max(outputs, 1)
+                 correct += (preds == labels).sum().item()
+                 total += labels.size(0)
+
+                 all_preds.extend(preds.cpu().numpy())
+                 all_labels.extend(labels.cpu().numpy())
+                 all_probs.extend(probs.detach().cpu().numpy())
+
+         # Compute validation metrics
+         val_acc = correct / total
+         val_prec = precision_score(all_labels, all_preds, average='macro', zero_division=0)
+         val_rec = recall_score(all_labels, all_preds, average='macro')
+         val_f1 = f1_score(all_labels, all_preds, average='macro')
+         val_cohen_kappa = cohen_kappa_score(all_labels, all_preds)
+         val_mcc = matthews_corrcoef(all_labels, all_preds)
+         val_confusion = confusion_matrix(all_labels, all_preds)
+
+         history['val_loss'].append(val_loss_total / len(val_loader))
+         history['val_acc'].append(val_acc)
+         history['val_prec'].append(val_prec)
+         history['val_rec'].append(val_rec)
+         history['val_f1'].append(val_f1)
+         history['val_cohen_kappa'].append(val_cohen_kappa)
+         history['val_mcc'].append(val_mcc)
+         history['val_confusion_matrix'].append(val_confusion)
+
+         # Print epoch summary
+         if verbose:
+             print(f"Train Loss: {history['train_loss'][-1]:.4f}, Train Acc: {history['train_acc'][-1]:.4f}, "
+                   f"Val Loss: {history['val_loss'][-1]:.4f}, Val Acc: {history['val_acc'][-1]:.4f}")
+
+         # Early stopping
+         current_val_loss = history['val_loss'][-1]
+         if current_val_loss < best_val_loss:
+             best_val_loss = current_val_loss
+             best_model_state = model.state_dict()
+             patience_counter = 0
+         else:
+             patience_counter += 1
+             print(f"Patience counter: {patience_counter}/{patience}")
+             if patience_counter >= patience:
+                 print("Early stopping triggered.")
+                 model.load_state_dict(best_model_state)
+                 torch.save(model.state_dict(), f'{model_name}.pth')
+                 return history
+
+     model.load_state_dict(best_model_state)
+     torch.save(model.state_dict(), f'{model_name}.pth')
+
+     return history
+
+
+ if __name__ == "__main__":
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     print(f"Training on device: {device}")
+     data_dir = "data/PlantVillage"  # "path/to/data/dir"
+     num_classes = len(os.listdir(data_dir))
+     train_loader, val_loader = get_dataloaders(data_dir=data_dir, img_size=224, batch_size=32)  # Minimum image size should be at least (49, 49)
+
+     model = SAGViTClassifier(num_classes=num_classes).to(device)
+
+     criterion = nn.CrossEntropyLoss()
+     optimizer = optim.Adam(model.parameters(), lr=0.0001)
+     num_epochs = 100
+
+     history = train_model(
+         model,
+         'SAG-ViT',
+         train_loader,
+         val_loader,
+         num_epochs,
+         criterion,
+         optimizer,
+         device
+     )
+
+     # You may save history to a CSV or analyze it further as needed.
+     # Example:
+     # import pandas as pd
+     # history_df = pd.DataFrame(history)
+     # history_df.to_csv("training_history.csv", index=False)
+
+     # Load the saved model back (best practice before pushing)
+     model.load_state_dict(torch.load("SAG-ViT.pth"))
+     model.eval()
+
+     # Push the model to the Hugging Face Hub
+     model.push_to_hub("shravvvv/SAG-ViT", commit_message="Initial model push", private=True, trust_remote_code=True)
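After training, a minimal inference sketch with the saved checkpoint; the dataset path mirrors the placeholder above, and the class count is derived the same way, so both are illustrative only:

import os
import torch
from sag_vit_model import SAGViTClassifier
from data_loader import get_dataloaders

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data_dir = "data/PlantVillage"  # same placeholder as the training script

model = SAGViTClassifier(num_classes=len(os.listdir(data_dir))).to(device)
model.load_state_dict(torch.load("SAG-ViT.pth", map_location=device))
model.eval()

_, val_loader = get_dataloaders(data_dir=data_dir, batch_size=32, img_size=224)
images, labels = next(iter(val_loader))
with torch.no_grad():
    preds = model(images.to(device)).argmax(dim=1)
print((preds.cpu() == labels).float().mean().item())  # accuracy on one validation batch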