File size: 2,082 Bytes
039647a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import unittest
import torch
from model_components import EfficientNetV2FeatureExtractor, GATGNN, TransformerEncoder, MLPBlock
from torch_geometric.data import Data

class TestModelComponents(unittest.TestCase):
    """Smoke tests for the components exported by ``model_components``.

    Each test builds one component, feeds it small random inputs and checks
    the output shape (or, for the backbone, that its parameters are frozen).
    """

    def test_efficientnetv2_extractor_output_shape(self):
        """The EfficientNetV2 extractor keeps the batch size and returns a non-empty map."""
        model = EfficientNetV2FeatureExtractor()
        model.eval()
        x = torch.randn(2, 3, 224, 224)
        with torch.no_grad():
            features = model(x)
        # The exact channel/spatial sizes depend on which EfficientNetV2 stage
        # the extractor taps, so only check that the batch dimension is
        # preserved and the remaining dims (presumably C, H, W) are non-empty.
        self.assertEqual(features.size(0), 2)
        for dim in (1, 2, 3):
            with self.subTest(dim=dim):
                self.assertGreater(features.size(dim), 0)

    def test_gatgnn_forward(self):
        """GATGNN reduces a single 4-node graph to one ``out_channels`` vector."""
        # Graph with 4 nodes, each with a 256-dim feature vector.
        x = torch.randn(4, 256)
        # Directed edges 0->1, 1->0, 1->2, 2->3 (row 0 = source, row 1 = target).
        edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 3]], dtype=torch.long)
        # Every node belongs to graph 0, i.e. a batch containing one graph.
        batch = torch.tensor([0, 0, 0, 0])
        data = Data(x=x, edge_index=edge_index, batch=batch)

        gnn = GATGNN(in_channels=256, hidden_channels=64, out_channels=32)
        output = gnn(data)
        # Assumes GATGNN pools node embeddings per graph, so the output is
        # (num_graphs, out_channels) = (1, 32).
        self.assertEqual(output.shape, (1, 32))

    def test_transformer_encoder(self):
        """The transformer encoder preserves the (B, N, D) input shape."""
        # (B, N, D) = (2, 10, 64)
        x = torch.randn(2, 10, 64)
        encoder = TransformerEncoder(d_model=64, nhead=4, num_layers=2, dim_feedforward=64)
        out = encoder(x)
        self.assertEqual(out.shape, (2, 10, 64))

    def test_mlp_block(self):
        """MLPBlock maps ``in_features`` to ``out_features`` per sample."""
        mlp = MLPBlock(in_features=64, hidden_features=128, out_features=10)
        x = torch.randn(2, 64)
        out = mlp(x)
        self.assertEqual(out.shape, (2, 10))

    def test_efficientnetv2_freeze(self):
        """Every parameter of the extractor has ``requires_grad`` disabled."""
        model = EfficientNetV2FeatureExtractor()
        named_params = list(model.named_parameters())
        # Guard against a vacuous pass: with no parameters the loop below
        # would assert nothing at all.
        self.assertTrue(named_params, "extractor exposes no parameters to check")
        for name, param in named_params:
            # subTest reports exactly which parameter is still trainable.
            with self.subTest(param=name):
                self.assertFalse(param.requires_grad)

if __name__ == "__main__":
    # Run the whole suite when this file is executed directly.
    unittest.main()