Caleb Spradlin
initial commit
ab687e7
raw
history blame
1.79 kB
from pytorch_caney.models.build import build_model
from pytorch_caney.config import get_config
import unittest
import argparse
import logging
class TestBuildModel(unittest.TestCase):
    """Smoke tests for ``build_model`` driven by a YAML test configuration.

    Each test loads the configuration in ``setUp``, optionally overrides
    ``MODEL.DECODER``, and then calls ``build_model``.
    """

    # Relative path, so it is resolved against the current working directory.
    # NOTE(review): the suite presumably has to be launched from the
    # repository root; confirm, or anchor the path to this file's location.
    CONFIG_PATH = 'pytorch_caney/tests/config/test_config.yaml'

    def setUp(self):
        """Load the test config and a DEBUG-level logger before each test."""
        args = argparse.Namespace(cfg=self.CONFIG_PATH)
        self.config = get_config(args)
        self.logger = logging.getLogger("TestLogger")
        self.logger.setLevel(logging.DEBUG)

    def _set_decoder(self, decoder):
        """Override ``MODEL.DECODER`` on the config, then re-freeze it.

        Args:
            decoder: Decoder name to store in ``self.config.MODEL.DECODER``.
        """
        # The config must be unlocked with defrost() before it can be edited
        # and is locked again with freeze(); presumably a yacs-style CfgNode
        # - TODO confirm.
        self.config.defrost()
        self.config.MODEL.DECODER = decoder
        self.config.freeze()

    def test_build_mim_model(self):
        """Build the pretraining model with ``pretrain_method='mim'``."""
        model = build_model(self.config,
                            pretrain=True,
                            pretrain_method='mim',
                            logger=self.logger)
        self.assertIsNotNone(model)
        # TODO: tighten to self.assertIsInstance(model, <MiM model class>).

    def test_build_swinv2_encoder(self):
        """Build a model from the unmodified test config."""
        model = build_model(self.config, logger=self.logger)
        self.assertIsNotNone(model)
        # Per the test name this is expected to be a SwinV2 encoder; not yet
        # checked. TODO: self.assertIsInstance(model, SwinTransformerV2).

    def test_build_unet_decoder(self):
        """Build a model with ``MODEL.DECODER`` set to ``'unet'``."""
        self._set_decoder('unet')
        model = build_model(self.config, logger=self.logger)
        self.assertIsNotNone(model)
        # TODO: tighten to self.assertIsInstance(model, <UNet-Swin class>).

    def test_unknown_decoder_architecture(self):
        """An unrecognised decoder name must raise NotImplementedError."""
        self._set_decoder('unknown_decoder')
        with self.assertRaises(NotImplementedError):
            build_model(self.config, logger=self.logger)
if __name__ == '__main__':
    # Executed only when this module is run directly (not when imported);
    # hands control to unittest's command-line runner at its default
    # verbosity level.
    unittest.main(verbosity=1)