Added model files and updated config.json
- .gitignore +3 -0
- LICENSE +21 -0
- config.json +1 -0
- data_loader.py +47 -0
- graph_construction.py +138 -0
- hubconf.py +23 -0
- model_components.py +115 -0
- requirements.txt +10 -0
- sag_vit_model.py +107 -0
- tests/test_graph_construction.py +39 -0
- tests/test_model_components.py +53 -0
- tests/test_sag_vit_model.py +39 -0
- tests/test_train.py +54 -0
- train.py +196 -0
.gitignore
ADDED
@@ -0,0 +1,3 @@
+data/
+__pycache__
+tests/__pycache__
LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2024 Shravan Venkatraman
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
config.json
CHANGED
@@ -1,4 +1,5 @@
 {
+  "model_type": "sag-vit",
   "d_model": 64,
   "dim_feedforward": 64,
   "gcn_hidden": 128,
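Usage note: the keys in config.json mirror the keyword arguments of SAGViTClassifier, and "model_type" is Hub metadata rather than a constructor argument. A minimal sketch of config-driven instantiation, assuming the keys not shown in this hunk also match constructor names:

import json
from sag_vit_model import SAGViTClassifier

with open("config.json") as f:
    cfg = json.load(f)
cfg.pop("model_type", None)  # metadata only, not a constructor argument
model = SAGViTClassifier(**cfg)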
data_loader.py
ADDED
@@ -0,0 +1,47 @@
+import os
+from torch.utils.data import DataLoader, random_split
+from torchvision import datasets, transforms
+
+def get_dataloaders(data_dir="path/to/data/dir", batch_size=512, train_split=0.8, img_size=224, num_workers=4):
+    """
+    Returns training and validation dataloaders for an image classification dataset.
+
+    Parameters:
+    - data_dir (str): Path to the directory containing image data in a folder structure compatible with ImageFolder.
+    - batch_size (int): Number of samples per batch.
+    - train_split (float): Fraction of data to use for training. The remainder is used for validation.
+    - img_size (int): Target size to which all images are resized after validation.
+    - num_workers (int): Number of worker processes for data loading.
+
+    Image Size Validation:
+    - Minimum allowed image size: 49x49 pixels.
+    - If img_size is less than 49 pixels, a ValueError is raised.
+
+    Returns:
+    - train_dataloader (DataLoader): DataLoader for the training split.
+    - val_dataloader (DataLoader): DataLoader for the validation split.
+    """
+
+    # Check if the provided image size is valid
+    if img_size < 49:
+        raise ValueError(f"Image size must be at least 49x49 pixels, but got {img_size}x{img_size}.")
+
+    transform = transforms.Compose([
+        transforms.Resize((img_size, img_size)),
+        transforms.ToTensor(),
+        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+    ])
+
+    # Load full dataset
+    full_dataset = datasets.ImageFolder(root=data_dir, transform=transform)
+
+    # Split into training and validation sets
+    train_size = int(train_split * len(full_dataset))
+    val_size = len(full_dataset) - train_size
+    train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])
+
+    # Create dataloaders
+    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
+    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
+
+    return train_dataloader, val_dataloader
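Usage note: a minimal sketch of get_dataloaders; the directory path below is a placeholder and must follow the ImageFolder layout (one subfolder per class).

from data_loader import get_dataloaders

train_loader, val_loader = get_dataloaders(
    data_dir="data/my_dataset",  # placeholder path, ImageFolder-style layout
    batch_size=32,
    train_split=0.8,
    img_size=224,
    num_workers=2,
)
images, labels = next(iter(train_loader))
print(images.shape)  # e.g. torch.Size([32, 3, 224, 224])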
graph_construction.py
ADDED
@@ -0,0 +1,138 @@
+import torch
+import networkx as nx
+from torch_geometric.utils import from_networkx
+
+####################################################################
+# These functions reflect the methods described in Section 3.1 and 3.2
+# of the SAG-ViT paper, where high-fidelity feature patches are extracted
+# from the CNN feature maps and organized into a graph structure.
+####################################################################
+
+def extract_patches(feature_map, patch_size=(4, 4)):
+    """
+    Extracts non-overlapping patches from a feature map to form nodes in a graph.
+
+    Parameters:
+    - feature_map (Tensor): The feature map from the CNN of shape (B, C, H', W').
+      H' and W' are reduced spatial dimensions after CNN feature extraction.
+    - patch_size (tuple): Spatial size (height, width) of each patch.
+
+    Returns:
+    - patches (Tensor): Tensor of shape (B, N, C, patch_h, patch_w), where N is the number of patches per image.
+    """
+    b, c, h, w = feature_map.size()
+    patch_h, patch_w = patch_size
+
+    # Unfold extracts sliding patches; here we align so that they are non-overlapping
+    patches = feature_map.unfold(2, patch_h, patch_h).unfold(3, patch_w, patch_w)
+
+    # Rearrange to have patches as separate units
+    patches = patches.permute(0, 2, 3, 1, 4, 5).contiguous()
+    patches = patches.view(b, -1, c, patch_h, patch_w)
+    return patches
+
+def construct_graph_from_patch(patch_index, patch_shape, image_shape):
+    """
+    Constructs edges between patch nodes based on spatial adjacency (k-connectivity).
+    This follows the approach described in Section 3.2 of SAG-ViT, where patches
+    are arranged in a grid and connected to their spatial neighbors.
+
+    Parameters:
+    - patch_index (int): Index of the current patch node.
+    - patch_shape (tuple): (patch_height, patch_width).
+    - image_shape (tuple): (height, width) of the feature map.
+
+    Returns:
+    - G (nx.Graph): A graph with a single node and edges to its neighbors (to be composed globally).
+    """
+    G = nx.Graph()
+
+    # Compute grid dimensions (how many patches along height and width)
+    grid_height = image_shape[0] // patch_shape[0]
+    grid_width = image_shape[1] // patch_shape[1]
+
+    # Current node index in a flattened grid
+    current_node = patch_index
+
+    G.add_node(current_node)
+
+    # 8-neighborhood connectivity (up, down, left, right, diagonals)
+    neighbor_offsets = [(-1, 0), (1, 0), (0, -1), (0, 1),
+                        (-1, -1), (-1, 1), (1, -1), (1, 1)]
+
+    # Recover row, col from patch_index
+    row = current_node // grid_width
+    col = current_node % grid_width
+
+    for dr, dc in neighbor_offsets:
+        neighbor_row = row + dr
+        neighbor_col = col + dc
+        if 0 <= neighbor_row < grid_height and 0 <= neighbor_col < grid_width:
+            neighbor_node = neighbor_row * grid_width + neighbor_col
+            G.add_edge(current_node, neighbor_node)
+
+    return G
+
+def build_graph_from_patches(feature_map, patch_size=(4,4)):
+    """
+    Builds a global graph for each image in the batch, where each node corresponds
+    to a patch, and edges represent spatial adjacency. This graph captures local
+    spatial relationships of the patches, as outlined in Sections 3.1 and 3.2 of SAG-ViT.
+
+    Parameters:
+    - feature_map (Tensor): CNN output (B, C, H', W').
+    - patch_size (tuple): Size of each patch (patch_h, patch_w).
+
+    Returns:
+    - G_global_batch (list): A list of NetworkX graphs, one per image in the batch.
+    - patches (Tensor): The extracted patches (B, N, C, patch_h, patch_w).
+    """
+    patches = extract_patches(feature_map, patch_size)
+    batch_size = patches.size(0)
+
+    grid_height = feature_map.size(2) // patch_size[0]
+    grid_width = feature_map.size(3) // patch_size[1]
+    num_patches = grid_height * grid_width
+
+    G_global_batch = []
+    for batch_idx in range(batch_size):
+        G_global = nx.Graph()
+        # Construct a global graph by composing individual patch-based graphs
+        for patch_idx in range(num_patches):
+            G_patch = construct_graph_from_patch(
+                patch_index=patch_idx,
+                patch_shape=patch_size,
+                image_shape=(feature_map.size(2), feature_map.size(3))
+            )
+            G_global = nx.compose(G_global, G_patch)
+        G_global_batch.append(G_global)
+
+    return G_global_batch, patches
+
+def build_graph_data_from_patches(G_global_batch, patches):
+    """
+    Converts NetworkX graphs and associated patches into PyTorch Geometric Data objects.
+    Each node corresponds to a patch vectorized into a feature node embedding.
+
+    Parameters:
+    - G_global_batch (list): List of global graphs (one per image) in NetworkX form.
+    - patches (Tensor): (B, N, C, patch_h, patch_w) patch tensor.
+
+    Returns:
+    - data_list (list): List of PyTorch Geometric Data objects, where data.x are node features,
+      and data.edge_index is the adjacency from the constructed graph.
+    """
+    from_networkx_ = from_networkx  # local alias to avoid confusion
+
+    data_list = []
+    batch_size, num_patches, channels, patch_h, patch_w = patches.size()
+
+    for batch_idx, G_global in enumerate(G_global_batch):
+        # Flatten each patch into a feature vector
+        node_features = patches[batch_idx].view(num_patches, -1)
+
+        G_pygeom = from_networkx_(G_global)
+        G_pygeom.x = node_features
+        data_list.append(G_pygeom)
+
+    return data_list
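Usage note: the three functions chain together; a short sketch on a dummy feature map (same sizes as the unit tests use):

import torch
from graph_construction import build_graph_from_patches, build_graph_data_from_patches

feature_map = torch.randn(2, 16, 32, 32)   # (B, C, H', W')
graphs, patches = build_graph_from_patches(feature_map, patch_size=(4, 4))
data_list = build_graph_data_from_patches(graphs, patches)

print(patches.shape)         # (2, 64, 16, 4, 4): an 8x8 grid of 4x4 patches per image
print(data_list[0].x.shape)  # (64, 256): each node is a flattened 16*4*4 patch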
hubconf.py
ADDED
@@ -0,0 +1,23 @@
+dependencies = ['torch']
+
+from sag_vit_model import SAGViTClassifier
+import torch
+
+def SAGViT(pretrained=False, **kwargs):
+    """
+    SAG-ViT model endpoint.
+    Args:
+        pretrained (bool): If True, loads pretrained weights.
+        **kwargs: Additional arguments for the model.
+    Returns:
+        model (nn.Module): The SAG-ViT model as proposed in the
+        paper: SAG-ViT: A Scale-Aware, High-Fidelity Patching
+        Approach with Graph Attention for Vision Transformers.
+        https://doi.org/10.48550/arXiv.2411.09420
+    """
+    model = SAGViTClassifier(**kwargs)
+    if pretrained:
+        checkpoint = ''
+        state_dict = torch.hub.load_state_dict_from_url(checkpoint, progress=True)
+        model.load_state_dict(state_dict)
+    return model
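Usage note: with hubconf.py at the repository root, the entry point can be loaded through torch.hub. The GitHub path below is a placeholder, and pretrained=True is not usable yet because the checkpoint URL in hubconf.py is still empty.

import torch

# "<github-user>/SAG-ViT" is a hypothetical repo path, not confirmed by this commit.
model = torch.hub.load("<github-user>/SAG-ViT", "SAGViT", pretrained=False, num_classes=10)
model.eval()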
model_components.py
ADDED
@@ -0,0 +1,115 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+from torch_geometric.nn import GATConv, global_mean_pool
+
+from torchvision import models
+
+###############################################################
+# These modules correspond to core building blocks of SAG-ViT:
+# 1. A CNN feature extractor for high-fidelity multi-scale feature maps.
+# 2. A Graph Attention Network (GAT) to refine patch embeddings.
+# 3. A Transformer Encoder to capture global long-range dependencies.
+# 4. An MLP classifier head.
+###############################################################
+
+class EfficientNetV2FeatureExtractor(nn.Module):
+    """
+    Extracts multi-scale, spatially-rich, and semantically-meaningful feature maps
+    from images using a pre-trained EfficientNetV2-S model. This corresponds
+    to Section 3.1, where a CNN backbone (EfficientNetV2-S) is used to produce rich
+    feature maps that preserve semantic information at multiple scales.
+    """
+    def __init__(self, pretrained=False):
+        super(EfficientNetV2FeatureExtractor, self).__init__()
+
+        # Load EfficientNetV2-S with pretrained weights
+        efficientnet = models.efficientnet_v2_s(
+            weights="IMAGENET1K_V1" if pretrained else None
+        )
+
+        # Extract layers up to the last block before downsampling below 16x16
+        self.extractor = nn.Sequential(*list(efficientnet.features.children())[:-2])
+
+
+    def forward(self, x):
+        """
+        Forward pass through the CNN backbone.
+
+        Input:
+        - x (Tensor): Input images of shape (B, 3, H, W)
+
+        Output:
+        - features (Tensor): Extracted feature map of shape (B, C, H', W'),
+          where H' and W' are reduced spatial dimensions.
+        """
+        features = self.extractor(x)
+        return features
+
+class GATGNN(nn.Module):
+    """
+    A Graph Attention Network (GAT) that processes patch-graph embeddings.
+    This module corresponds to the Graph Attention stage (Section 3.3),
+    refining local relationships between patches in a learned manner.
+    """
+    def __init__(self, in_channels, hidden_channels, out_channels, heads=8):
+        super(GATGNN, self).__init__()
+        # GAT layers:
+        # First layer maps raw patch embeddings to a higher-level representation.
+        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads)
+        # Second layer produces final node embeddings with a single head.
+        self.conv2 = GATConv(hidden_channels * heads, out_channels, heads=1)
+        self.pool = global_mean_pool
+
+    def forward(self, data):
+        """
+        Input:
+        - data (PyG Data): Contains x (node features), edge_index (graph edges), and batch indexing.
+
+        Output:
+        - x (Tensor): Aggregated graph-level embedding after mean pooling.
+        """
+        x, edge_index, batch = data.x, data.edge_index, data.batch
+        x = F.elu(self.conv1(x, edge_index))
+        x = self.conv2(x, edge_index)
+        x = self.pool(x, batch)
+        return x
+
+class TransformerEncoder(nn.Module):
+    """
+    A Transformer encoder to capture long-range dependencies among patch embeddings.
+    Integrates global dependencies after GAT processing, as per Section 3.3.
+    """
+    def __init__(self, d_model, nhead, num_layers, dim_feedforward):
+        super(TransformerEncoder, self).__init__()
+        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward)
+        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
+
+    def forward(self, x):
+        """
+        Input:
+        - x (Tensor): Sequence of patch embeddings with shape (B, N, D).
+
+        Output:
+        - (Tensor): Transformed embeddings with global relationships integrated (B, N, D).
+        """
+        # The Transformer expects (N, B, D), so transpose first
+        x = x.transpose(0, 1)  # (N, B, D)
+        x = self.transformer_encoder(x)
+        x = x.transpose(0, 1)  # (B, N, D)
+        return x
+
+class MLPBlock(nn.Module):
+    """
+    An MLP classification head to map final global embeddings to classification logits.
+    """
+    def __init__(self, in_features, hidden_features, out_features):
+        super(MLPBlock, self).__init__()
+        self.mlp = nn.Sequential(
+            nn.Linear(in_features, hidden_features),
+            nn.ReLU(),
+            nn.Linear(hidden_features, out_features)
+        )
+
+    def forward(self, x):
+        return self.mlp(x)
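Shape note: a sketch, under the assumption that torchvision's EfficientNetV2-S keeps its current stage layout. For a 224x224 input the truncated extractor should emit roughly a (B, 160, 14, 14) map, and flattening one 4x4 patch then gives 160 * 4 * 4 = 2560 features, which is where the in_channels=2560 default in sag_vit_model.py comes from.

import torch
from model_components import EfficientNetV2FeatureExtractor

extractor = EfficientNetV2FeatureExtractor()
extractor.eval()
with torch.no_grad():
    fmap = extractor(torch.randn(1, 3, 224, 224))
print(fmap.shape)            # expected around torch.Size([1, 160, 14, 14])
print(fmap.size(1) * 4 * 4)  # expected 2560, matching the GAT in_channels default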
requirements.txt
ADDED
@@ -0,0 +1,10 @@
+numpy==1.26.4
+pandas==2.2.3
+matplotlib==3.7.5
+seaborn==0.12.2
+tqdm==4.66.4
+scikit-learn==1.2.2
+torch==2.4.0
+torch-geometric==2.6.1
+torchvision==0.19.0
+networkx==3.3
sag_vit_model.py
ADDED
@@ -0,0 +1,107 @@
+import torch
+from torch import nn
+from huggingface_hub import PyTorchModelHubMixin
+
+from torch_geometric.data import Batch
+from model_components import EfficientNetV2FeatureExtractor, GATGNN, TransformerEncoder, MLPBlock
+from graph_construction import build_graph_from_patches, build_graph_data_from_patches
+
+###############################################################################
+# SAG-ViT Model:
+# This class combines:
+# 1) CNN backbone to produce high-fidelity feature maps (Section 3.1),
+# 2) Graph construction and GAT to refine local patch embeddings (Section 3.2 and 3.3),
+# 3) A Transformer encoder to capture global relationships (Section 3.3),
+# 4) A final MLP classifier.
+###############################################################################
+
+class SAGViTClassifier(nn.Module, PyTorchModelHubMixin):
+    """
+    SAG-ViT: Scale-Aware Graph Attention Vision Transformer
+
+    This model integrates the following steps:
+    - Extract multi-scale features from images using a CNN backbone (EfficientNetv2 here).
+    - Partition the feature map into patches and build a graph where each node is a patch.
+    - Use a Graph Attention Network (GAT) to refine patch embeddings based on local spatial relationships.
+    - Utilize a Transformer encoder to model long-range dependencies and integrate multi-scale information.
+    - Finally, classify the resulting representation into desired classes.
+
+    Inputs:
+    - x (Tensor): Input images (B, 3, H, W)
+
+    Outputs:
+    - out (Tensor): Classification logits (B, num_classes)
+    """
+    def __init__(
+        self,
+        patch_size=(4,4),
+        num_classes=10,
+        d_model=64,
+        nhead=4,
+        num_layers=2,
+        dim_feedforward=64,
+        hidden_mlp_features=64,
+        in_channels=2560,  # Derived from patch dimensions and CNN output channels
+        gcn_hidden=128,
+        gcn_out=64
+    ):
+        super(SAGViTClassifier, self).__init__()
+
+        # CNN feature extractor (frozen pre-trained EfficientNetv2)
+        self.cnn = EfficientNetV2FeatureExtractor()
+
+        # Graph Attention Network to process patch embeddings
+        self.gcn = GATGNN(in_channels=in_channels, hidden_channels=gcn_hidden, out_channels=gcn_out)
+
+        # Learnable positional embedding for Transformer input
+        self.positional_embedding = nn.Parameter(torch.randn(1, 1, d_model))
+        # Extra embedding token (similar to a class token) to summarize global info
+        self.extra_embedding = nn.Parameter(torch.randn(1, d_model))
+
+        # Transformer encoder to capture long-range global dependencies
+        self.transformer_encoder = TransformerEncoder(d_model, nhead, num_layers, dim_feedforward)
+
+        # MLP classification head
+        self.mlp = MLPBlock(d_model, hidden_mlp_features, num_classes)
+
+        self.patch_size = patch_size
+
+    def forward(self, x):
+        # Step 1: High-fidelity feature extraction from CNN
+        feature_map = self.cnn(x)
+
+        # Step 2: Build graphs from patches
+        G_global_batch, patches = build_graph_from_patches(feature_map, self.patch_size)
+
+        # Step 3: Convert to PyG Data format and batch
+        data_list = build_graph_data_from_patches(G_global_batch, patches)
+        device = x.device
+        batch = Batch.from_data_list(data_list).to(device)
+
+        # Step 4: GAT stage
+        x_gcn = self.gcn(batch)
+
+        # Step 5: Prepare the GAT output for the Transformer.
+        # The GAT's global mean pooling has already aggregated the patch nodes per image.
+        B = x.size(0)
+        D = x_gcn.size(-1)
+        # So x_gcn is (B, D) rather than (B, N, D).
+        # We need a sequence dimension for the Transformer, so we treat each
+        # image-level embedding as one "patch token" plus an extra token below:
+        patch_embeddings = x_gcn.unsqueeze(1)  # (B, 1, D)
+
+        # Add positional embedding
+        patch_embeddings = patch_embeddings + self.positional_embedding  # (B, 1, D)
+
+        # Add an extra learnable embedding (like a CLS token)
+        patch_embeddings = torch.cat([patch_embeddings, self.extra_embedding.unsqueeze(0).expand(B, -1, -1)], dim=1)  # (B, 2, D)
+
+        # Step 6: Transformer encoder
+        x_trans = self.transformer_encoder(patch_embeddings)
+
+        # Step 7: Global pooling (here we just take the mean)
+        x_pooled = x_trans.mean(dim=1)  # (B, D)
+
+        # Classification
+        out = self.mlp(x_pooled)
+        return out
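Usage note: an end-to-end sketch with random weights and a dummy batch; because the class mixes in PyTorchModelHubMixin, from_pretrained also becomes available once a checkpoint is published.

import torch
from sag_vit_model import SAGViTClassifier

model = SAGViTClassifier(num_classes=10)
model.eval()
with torch.no_grad():
    logits = model(torch.randn(2, 3, 224, 224))
print(logits.shape)  # torch.Size([2, 10])

# Once a checkpoint exists on the Hub (see train.py), loading would look like:
# model = SAGViTClassifier.from_pretrained("shravvvv/SAG-ViT")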
tests/test_graph_construction.py
ADDED
@@ -0,0 +1,39 @@
+import unittest
+import torch
+import networkx as nx
+from graph_construction import extract_patches, build_graph_from_patches, build_graph_data_from_patches
+
+class TestGraphConstruction(unittest.TestCase):
+    def test_extract_patches_shape(self):
+        # Create a dummy feature map: B=2, C=16, H=32, W=32
+        feature_map = torch.randn(2, 16, 32, 32)
+        patches = extract_patches(feature_map, patch_size=(4,4))
+        # Check dimensions: after extraction,
+        # number_of_patches = (H/4)*(W/4) = 8*8=64 per image, total 2*64=128
+        self.assertEqual(patches.shape, (2, 64, 16, 4, 4))
+
+    def test_build_graph_from_patches_graph_structure(self):
+        feature_map = torch.randn(1, 16, 32, 32)
+        G_batch, patches = build_graph_from_patches(feature_map, patch_size=(4,4))
+        # 1 image => G_batch[0] is the graph
+        G = G_batch[0]
+        # We have 64 patches
+        self.assertEqual(len(G.nodes), 64)
+        # Check if edges exist (8-neighborhood).
+        # Interior nodes should have edges to neighbors.
+        # Just check a random node in the middle
+        node_index = 9  # assuming row=1, col=1 in an 8x8 grid
+        self.assertTrue(len(list(G.neighbors(node_index))) > 0)
+
+    def test_build_graph_data_from_patches_conversion(self):
+        feature_map = torch.randn(2, 16, 32, 32)
+        G_batch, patches = build_graph_from_patches(feature_map, patch_size=(4,4))
+        data_list = build_graph_data_from_patches(G_batch, patches)
+        self.assertEqual(len(data_list), 2)
+        # Check node feature shape
+        self.assertEqual(data_list[0].x.shape[1], 16*4*4)  # C * patch_h * patch_w = 16*4*4=256
+        # Check edges are present
+        self.assertTrue(data_list[0].edge_index.shape[1] > 0)
+
+if __name__ == '__main__':
+    unittest.main()
tests/test_model_components.py
ADDED
@@ -0,0 +1,53 @@
+import unittest
+import torch
+from model_components import EfficientNetV2FeatureExtractor, GATGNN, TransformerEncoder, MLPBlock
+from torch_geometric.data import Data
+
+class TestModelComponents(unittest.TestCase):
+    def test_efficientnetv2_extractor_output_shape(self):
+        model = EfficientNetV2FeatureExtractor()
+        model.eval()
+        x = torch.randn(2, 3, 224, 224)
+        with torch.no_grad():
+            features = model(x)
+        # Check output shape - depends on the chosen EfficientNetV2 intermediate layer
+        # Example: roughly (2, 160, 14, 14) for the truncated EfficientNetV2-S backbone
+        self.assertEqual(features.size(0), 2)
+        self.assertTrue(features.size(1) > 0)
+        self.assertTrue(features.size(2) > 0)
+        self.assertTrue(features.size(3) > 0)
+
+    def test_gatgnn_forward(self):
+        # Graph with 4 nodes, each node feature dim=256
+        x = torch.randn(4, 256)
+        edge_index = torch.tensor([[0,1,1,2],[1,0,2,3]], dtype=torch.long)
+        batch = torch.tensor([0,0,0,0])
+        data = Data(x=x, edge_index=edge_index, batch=batch)
+
+        gnn = GATGNN(in_channels=256, hidden_channels=64, out_channels=32)
+        output = gnn(data)
+        # After pooling: should be (batch_size, out_channels) = (1,32)
+        self.assertEqual(output.shape, (1, 32))
+
+    def test_transformer_encoder(self):
+        # (B, N, D) = (2, 10, 64)
+        x = torch.randn(2, 10, 64)
+        encoder = TransformerEncoder(d_model=64, nhead=4, num_layers=2, dim_feedforward=64)
+        out = encoder(x)
+        # same shape as input
+        self.assertEqual(out.shape, (2, 10, 64))
+
+    def test_mlp_block(self):
+        mlp = MLPBlock(in_features=64, hidden_features=128, out_features=10)
+        x = torch.randn(2, 64)
+        out = mlp(x)
+        self.assertEqual(out.shape, (2,10))
+
+    def test_efficientnetv2_freeze(self):
+        # Ensure params are frozen
+        model = EfficientNetV2FeatureExtractor()
+        for param in model.parameters():
+            self.assertFalse(param.requires_grad)
+
+if __name__ == '__main__':
+    unittest.main()
tests/test_sag_vit_model.py
ADDED
@@ -0,0 +1,39 @@
+import unittest
+import torch
+from sag_vit_model import SAGViTClassifier
+
+class TestSAGViTModel(unittest.TestCase):
+    def test_forward_pass(self):
+        model = SAGViTClassifier(
+            patch_size=(4,4),
+            num_classes=10,  # smaller num classes for test
+            d_model=64,
+            nhead=4,
+            num_layers=2,
+            dim_feedforward=64,
+            hidden_mlp_features=64,
+            in_channels=2560,  # from patch dimension example
+            gcn_hidden=128,
+            gcn_out=64
+        )
+        model.eval()
+        x = torch.randn(2, 3, 224, 224)
+        with torch.no_grad():
+            out = model(x)
+        # Check output shape: (B, num_classes) = (2,10)
+        self.assertEqual(out.shape, (2,10))
+
+    def test_empty_input(self):
+        model = SAGViTClassifier()
+        # Passing an empty tensor should fail gracefully
+        with self.assertRaises(Exception):
+            model(torch.empty(0,3,224,224))
+
+    def test_invalid_input_dimensions(self):
+        model = SAGViTClassifier()
+        # Incorrect dimension (e.g., missing channel)
+        with self.assertRaises(RuntimeError):
+            model(torch.randn(2, 224, 224))  # no channel dimension
+
+if __name__ == '__main__':
+    unittest.main()
tests/test_train.py
ADDED
@@ -0,0 +1,54 @@
+import unittest
+from unittest.mock import MagicMock, patch
+import torch
+import torch.nn as nn
+from train import train_model
+from sag_vit_model import SAGViTClassifier
+
+class TestTrain(unittest.TestCase):
+    @patch("train.optim.Adam")
+    def test_train_model_loop(self, mock_adam):
+        # Mock the optimizer
+        mock_optimizer = MagicMock()
+        mock_adam.return_value = mock_optimizer
+
+        # Mock dataloaders with a small dummy dataset
+        # Just one batch with a couple of samples
+        train_dataloader = [ (torch.randn(2,3,224,224), torch.tensor([0,1])) ]
+        val_dataloader = [ (torch.randn(2,3,224,224), torch.tensor([0,1])) ]
+
+        model = SAGViTClassifier(num_classes=2)
+
+        criterion = nn.CrossEntropyLoss()
+        device = torch.device("cpu")
+
+        # Test a single epoch of training
+        history = train_model(model, "TestModel", train_dataloader, val_dataloader,
+                              num_epochs=1, criterion=criterion, optimizer=mock_optimizer, device=device, patience=2, verbose=False)
+
+        # Check if history is properly recorded
+        self.assertIn("train_loss", history)
+        self.assertIn("val_loss", history)
+        self.assertGreaterEqual(len(history["train_loss"]), 1)
+        self.assertGreaterEqual(len(history["val_loss"]), 1)
+
+    def test_early_stopping(self):
+        # Mocking dataloaders where validation loss doesn't improve
+        model = SAGViTClassifier(num_classes=2)
+        criterion = nn.CrossEntropyLoss()
+        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
+        device = torch.device("cpu")
+
+        # create a scenario where val loss won't improve:
+        # first epoch normal, second epoch slightly worse
+        train_dataloader = [ (torch.randn(2,3,224,224), torch.tensor([0,1])) ]
+        val_dataloader = [ (torch.randn(2,3,224,224), torch.tensor([0,1])) ]
+
+        history = train_model(model, "TestModelEarlyStop", train_dataloader, val_dataloader,
+                              num_epochs=5, criterion=criterion, optimizer=optimizer, device=device, patience=1, verbose=False)
+
+        # Should have triggered early stopping before all 5 epochs
+        self.assertLessEqual(len(history["train_loss"]), 5)
+
+if __name__ == '__main__':
+    unittest.main()
train.py
ADDED
@@ -0,0 +1,196 @@
+import os
+import torch
+from torch import nn, optim
+from tqdm import tqdm
+from huggingface_hub import HfApi
+import numpy as np
+from sklearn.metrics import (precision_score, recall_score, f1_score,
+                             roc_auc_score, cohen_kappa_score, matthews_corrcoef,
+                             confusion_matrix)
+
+from sag_vit_model import SAGViTClassifier
+from data_loader import get_dataloaders
+
+#####################################################################
+# This file provides the training loop and metric computation. It uses
+# the SAG-ViT model defined in sag_vit_model.py, and the data from data_loader.py.
+# The training loop is adapted to implement early stopping and track various metrics.
+#####################################################################
+
+def train_model(model, model_name, train_loader, val_loader, num_epochs, criterion, optimizer, device, patience=8, verbose=True):
+    """
+    Trains the SAG-ViT model and evaluates it on the validation set.
+    Implements early stopping based on validation loss.
+
+    Parameters:
+    - model (nn.Module): The SAG-ViT model.
+    - model_name (str): A name to identify the model (used for saving checkpoints).
+    - train_loader, val_loader: DataLoaders for training and validation.
+    - num_epochs (int): Maximum number of epochs.
+    - criterion (nn.Module): Loss function.
+    - optimizer (torch.optim.Optimizer): Optimization algorithm.
+    - device (torch.device): Device to run the computations on (CPU/GPU).
+    - patience (int): Early stopping patience.
+
+    Returns:
+    - history (dict): Dictionary containing training and validation metrics per epoch.
+    """
+
+    history = {
+        'train_loss': [], 'train_acc': [], 'train_prec': [], 'train_rec': [], 'train_f1': [],
+        'train_auc': [], 'train_mcc': [], 'train_cohen_kappa': [], 'train_confusion_matrix': [],
+        'val_loss': [], 'val_acc': [], 'val_prec': [], 'val_rec': [], 'val_f1': [],
+        'val_auc': [], 'val_mcc': [], 'val_cohen_kappa': [], 'val_confusion_matrix': []
+    }
+
+    best_val_loss = float('inf')
+    patience_counter = 0
+    best_model_state = None
+
+    for epoch in range(num_epochs):
+        print(f'Epoch {epoch+1}/{num_epochs}')
+        model.train()
+
+        train_loss_total, correct, total = 0, 0, 0
+        all_preds, all_labels, all_probs = [], [], []
+
+        # Training loop
+        for batch_idx, (X, y) in enumerate(tqdm(train_loader)):
+            inputs, labels = X.to(device), y.to(device)
+            optimizer.zero_grad()
+
+            outputs = model(inputs)
+            loss = criterion(outputs, labels)
+            loss.backward()
+            optimizer.step()
+
+            train_loss_total += loss.item()
+
+            probs = torch.softmax(outputs, dim=1)
+            _, preds = torch.max(outputs, 1)
+            correct += (preds == labels).sum().item()
+            total += labels.size(0)
+
+            all_preds.extend(preds.cpu().numpy())
+            all_labels.extend(labels.cpu().numpy())
+            all_probs.extend(probs.detach().cpu().numpy())
+
+        # Compute training metrics
+        train_acc = correct / total
+        train_prec = precision_score(all_labels, all_preds, average='macro', zero_division=0)
+        train_rec = recall_score(all_labels, all_preds, average='macro')
+        train_f1 = f1_score(all_labels, all_preds, average='macro')
+        train_cohen_kappa = cohen_kappa_score(all_labels, all_preds)
+        train_mcc = matthews_corrcoef(all_labels, all_preds)
+        train_confusion = confusion_matrix(all_labels, all_preds)
+
+        history['train_loss'].append(train_loss_total / len(train_loader))
+        history['train_acc'].append(train_acc)
+        history['train_prec'].append(train_prec)
+        history['train_rec'].append(train_rec)
+        history['train_f1'].append(train_f1)
+        history['train_cohen_kappa'].append(train_cohen_kappa)
+        history['train_mcc'].append(train_mcc)
+        history['train_confusion_matrix'].append(train_confusion)
+
+        # Validation
+        model.eval()
+        val_loss_total, correct, total = 0, 0, 0
+        all_preds, all_labels, all_probs = [], [], []
+
+        with torch.no_grad():
+            for batch_idx, (X, y) in enumerate(tqdm(val_loader)):
+                inputs, labels = X.to(device), y.to(device)
+                outputs = model(inputs)
+                loss = criterion(outputs, labels)
+
+                val_loss_total += loss.item()
+                probs = torch.softmax(outputs, dim=1)
+                _, preds = torch.max(outputs, 1)
+                correct += (preds == labels).sum().item()
+                total += labels.size(0)
+
+                all_preds.extend(preds.cpu().numpy())
+                all_labels.extend(labels.cpu().numpy())
+                all_probs.extend(probs.detach().cpu().numpy())
+
+        # Compute validation metrics
+        val_acc = correct / total
+        val_prec = precision_score(all_labels, all_preds, average='macro', zero_division=0)
+        val_rec = recall_score(all_labels, all_preds, average='macro')
+        val_f1 = f1_score(all_labels, all_preds, average='macro')
+        val_cohen_kappa = cohen_kappa_score(all_labels, all_preds)
+        val_mcc = matthews_corrcoef(all_labels, all_preds)
+        val_confusion = confusion_matrix(all_labels, all_preds)
+
+        history['val_loss'].append(val_loss_total / len(val_loader))
+        history['val_acc'].append(val_acc)
+        history['val_prec'].append(val_prec)
+        history['val_rec'].append(val_rec)
+        history['val_f1'].append(val_f1)
+        history['val_cohen_kappa'].append(val_cohen_kappa)
+        history['val_mcc'].append(val_mcc)
+        history['val_confusion_matrix'].append(val_confusion)
+
+        # Print epoch summary
+        if verbose:
+            print(f"Train Loss: {history['train_loss'][-1]:.4f}, Train Acc: {history['train_acc'][-1]:.4f}, "
+                  f"Val Loss: {history['val_loss'][-1]:.4f}, Val Acc: {history['val_acc'][-1]:.4f}")
+
+        # Early stopping
+        current_val_loss = history['val_loss'][-1]
+        if current_val_loss < best_val_loss:
+            best_val_loss = current_val_loss
+            best_model_state = model.state_dict()
+            patience_counter = 0
+        else:
+            patience_counter += 1
+            print(f"Patience counter: {patience_counter}/{patience}")
+            if patience_counter >= patience:
+                print("Early stopping triggered.")
+                model.load_state_dict(best_model_state)
+                torch.save(model.state_dict(), f'{model_name}.pth')
+                return history
+
+    model.load_state_dict(best_model_state)
+    torch.save(model.state_dict(), f'{model_name}.pth')
+
+    return history
+
+
+if __name__ == "__main__":
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    print(f"Training on device: {device}")
+    data_dir = "data/PlantVillage"  # "path/to/data/dir"
+    num_classes = len(os.listdir(data_dir))
+    train_loader, val_loader = get_dataloaders(data_dir=data_dir, img_size=224, batch_size=32)  # Minimum image size should be at least (49, 49)
+
+    model = SAGViTClassifier(num_classes=num_classes).to(device)
+
+    criterion = nn.CrossEntropyLoss()
+    optimizer = optim.Adam(model.parameters(), lr=0.0001)
+    num_epochs = 100
+
+    history = train_model(
+        model,
+        'SAG-ViT',
+        train_loader,
+        val_loader,
+        num_epochs,
+        criterion,
+        optimizer,
+        device
+    )
+
+    # You may save history to a CSV or analyze it further as needed.
+    # Example:
+    # import pandas as pd
+    # history_df = pd.DataFrame(history)
+    # history_df.to_csv("training_history.csv", index=False)
+
+    # Load the saved model back (best practice before pushing)
+    model.load_state_dict(torch.load("SAG-ViT.pth"))
+    model.eval()
+
+    # Push the model to the Hugging Face Hub
+    model.push_to_hub("shravvvv/SAG-ViT", commit_message="Initial model push", private=True, trust_remote_code=True)
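Note: the history dict reserves 'train_auc' and 'val_auc' keys and the loop already collects class probabilities in all_probs, but AUC is never computed in this version. A hedged sketch of how those keys could be populated with scikit-learn's one-vs-rest multi-class AUC:

import numpy as np
from sklearn.metrics import roc_auc_score

def compute_auc(all_labels, all_probs):
    # Macro-averaged one-vs-rest AUC over the collected softmax probabilities.
    try:
        return roc_auc_score(all_labels, np.asarray(all_probs), multi_class="ovr", average="macro")
    except ValueError:
        return float("nan")  # e.g. a class absent from this epoch's labels

# Inside train_model, after the existing metric blocks, one could then append:
# history['train_auc'].append(compute_auc(all_labels, all_probs))
# history['val_auc'].append(compute_auc(all_labels, all_probs))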