|
import unittest |
|
from unittest.mock import MagicMock, patch |
|
import torch |
|
import torch.nn as nn |
|
from train import train_model |
|
from modeling_sagvit import SAGViTClassifier |
|
|
|
class TestTrain(unittest.TestCase):
    """Smoke tests for ``train.train_model`` driving a ``SAGViTClassifier``.

    Each test runs on CPU against a single randomly generated batch of two
    3x224x224 images labelled ``[0, 1]``, for both training and validation.
    """

    def setUp(self):
        # Seed torch so the random inputs and the model initialisation are
        # reproducible; without this a failure may not recur on re-run.
        torch.manual_seed(0)

    @staticmethod
    def _make_loader():
        """Return a one-batch "dataloader": a list holding one (images, labels) pair."""
        return [(torch.randn(2, 3, 224, 224), torch.tensor([0, 1]))]

    @patch("train.optim.Adam")
    def test_train_model_loop(self, mock_adam):
        """One epoch of train_model records losses and steps the given optimizer."""
        # NOTE(review): train_model receives the optimizer explicitly below; the
        # patch presumably guards against train.py building its own Adam — confirm.
        mock_optimizer = MagicMock()
        mock_adam.return_value = mock_optimizer

        train_dataloader = self._make_loader()
        val_dataloader = self._make_loader()

        model = SAGViTClassifier(num_classes=2)

        history = train_model(
            model,
            "TestModel",
            train_dataloader,
            val_dataloader,
            num_epochs=1,
            criterion=nn.CrossEntropyLoss(),
            optimizer=mock_optimizer,
            device=torch.device("cpu"),
            patience=2,
            verbose=False,
        )

        self.assertIn("train_loss", history)
        self.assertIn("val_loss", history)
        self.assertGreaterEqual(len(history["train_loss"]), 1)
        self.assertGreaterEqual(len(history["val_loss"]), 1)
        # The optimizer handed to train_model must actually be stepped;
        # otherwise the training loop never applied an update.
        mock_optimizer.step.assert_called()

    def test_early_stopping(self):
        """With a small patience, train_model runs between 1 and num_epochs epochs."""
        num_epochs = 5

        model = SAGViTClassifier(num_classes=2)
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        device = torch.device("cpu")

        train_dataloader = self._make_loader()
        val_dataloader = self._make_loader()

        history = train_model(
            model,
            "TestModelEarlyStop",
            train_dataloader,
            val_dataloader,
            num_epochs=num_epochs,
            criterion=criterion,
            optimizer=optimizer,
            device=device,
            patience=1,
            verbose=False,
        )

        epochs_run = len(history["train_loss"])
        # At least one epoch must run, and never more than requested.
        self.assertGreaterEqual(epochs_run, 1)
        self.assertLessEqual(epochs_run, num_epochs)
        # Training and validation histories should describe the same epochs.
        self.assertEqual(epochs_run, len(history["val_loss"]))
        # NOTE(review): with random data the validation loss may keep improving,
        # so this test cannot require a strict early stop (< num_epochs). That
        # would need a deterministic, non-improving validation signal.
|
|
|
if __name__ == "__main__":
    # Allow running this test module directly with the standard unittest runner.
    unittest.main()
|
|