|
import unittest |
|
import torch |
|
from modeling_sagvit import SAGViTClassifier |
|
|
|
class TestSAGViTModel(unittest.TestCase):
    """Smoke tests for ``SAGViTClassifier``: output shape and malformed-input handling."""

    def test_forward_pass(self):
        """A batch of two 3x224x224 images produces a (2, 10) output in eval mode."""
        config = dict(
            patch_size=(4, 4),
            num_classes=10,
            d_model=64,
            nhead=4,
            num_layers=2,
            dim_feedforward=64,
            hidden_mlp_features=64,
            in_channels=2560,
            gcn_hidden=128,
            gcn_out=64,
        )
        # Build the model before drawing the input so the global RNG is
        # consumed in the same order (weight init first, then the batch).
        classifier = SAGViTClassifier(**config)
        classifier.eval()

        images = torch.randn(2, 3, 224, 224)
        with torch.no_grad():
            logits = classifier(images)

        self.assertEqual(logits.shape, (2, 10))

    def test_empty_input(self):
        """A zero-sized batch must make the forward pass raise (any exception type)."""
        classifier = SAGViTClassifier()
        empty_batch = torch.empty(0, 3, 224, 224)
        # NOTE(review): `Exception` is very broad; narrow it once the model's
        # actual error type for empty batches is pinned down.
        self.assertRaises(Exception, classifier, empty_batch)

    def test_invalid_input_dimensions(self):
        """A 3-D tensor (missing the channel axis) must raise RuntimeError."""
        classifier = SAGViTClassifier()
        malformed = torch.randn(2, 224, 224)
        self.assertRaises(RuntimeError, classifier, malformed)
|
|
if __name__ == '__main__':
    # Running this file directly executes every TestCase defined in it.
    unittest.main()
|
|