File size: 2,082 Bytes
039647a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
import unittest
import torch
from model_components import EfficientNetV2FeatureExtractor, GATGNN, TransformerEncoder, MLPBlock
from torch_geometric.data import Data
class TestModelComponents(unittest.TestCase):
    """Smoke tests for the components exported by ``model_components``.

    Each test builds one component with small fixed hyper-parameters, feeds it
    a random input and checks the output shape. The freeze test instead checks
    that the feature extractor's weights do not require gradients.
    """

    def test_efficientnetv2_extractor_output_shape(self):
        """The extractor maps a (2, 3, 224, 224) batch to a 4-D feature map."""
        model = EfficientNetV2FeatureExtractor()
        model.eval()
        x = torch.randn(2, 3, 224, 224)
        with torch.no_grad():
            features = model(x)
        # Expected layout is (batch, C, H, W). C, H and W depend on which
        # backbone stage the extractor returns, so only the batch size is
        # pinned; the other dimensions just have to be non-empty.
        # Checking the rank first turns a wrong-rank output into a clear
        # assertion failure instead of an IndexError from size(3).
        self.assertEqual(features.dim(), 4)
        self.assertEqual(features.size(0), 2)
        for dim in (1, 2, 3):
            self.assertGreater(features.size(dim), 0)

    def test_gatgnn_forward(self):
        """A single 4-node graph is pooled to a (1, out_channels) output."""
        # Graph with 4 nodes, each node feature dim=256
        x = torch.randn(4, 256)
        edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 3]], dtype=torch.long)
        # Every node is assigned to graph 0, so the batch holds exactly one graph.
        batch = torch.zeros(4, dtype=torch.long)
        data = Data(x=x, edge_index=edge_index, batch=batch)
        gnn = GATGNN(in_channels=256, hidden_channels=64, out_channels=32)
        with torch.no_grad():
            output = gnn(data)
        # After pooling: should be (batch_size, out_channels) = (1, 32)
        self.assertEqual(output.shape, (1, 32))

    def test_transformer_encoder(self):
        """The encoder returns a tensor with the same shape as its input."""
        # (B, N, D) = (2, 10, 64)
        x = torch.randn(2, 10, 64)
        encoder = TransformerEncoder(
            d_model=64, nhead=4, num_layers=2, dim_feedforward=64
        )
        with torch.no_grad():
            out = encoder(x)
        # same shape as input
        self.assertEqual(out.shape, (2, 10, 64))

    def test_mlp_block(self):
        """The MLP maps a (2, 64) batch to (2, out_features=10)."""
        mlp = MLPBlock(in_features=64, hidden_features=128, out_features=10)
        x = torch.randn(2, 64)
        with torch.no_grad():
            out = mlp(x)
        self.assertEqual(out.shape, (2, 10))

    def test_efficientnetv2_freeze(self):
        """Every parameter of the feature extractor has requires_grad=False."""
        model = EfficientNetV2FeatureExtractor()
        named_params = list(model.named_parameters())
        # Guard against a vacuous pass: with no parameters the loop below
        # would never assert anything.
        self.assertTrue(named_params, "feature extractor exposes no parameters")
        for name, param in named_params:
            # subTest names the offending parameter if one is still trainable.
            with self.subTest(param=name):
                self.assertFalse(param.requires_grad)
if __name__ == '__main__':
    # Executed directly (not imported by a test runner): run this module's tests.
    unittest.main()
|