# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmpose.models.backbones import AlexNet


def test_alexnet_backbone():
    """Check AlexNet output shapes for two constructor settings.

    For each case, the backbone is built with the given positional argument,
    switched to train mode, and run on a single random 3-channel image. The
    resulting tensor's shape is then compared with the expected shape.
    """
    # Each row: (AlexNet positional arg, input height, input width,
    # expected output shape).
    # NOTE(review): the first arg is presumably the number of classes, with -1
    # disabling the classification head (hence the 4-D feature map) — confirm
    # against the AlexNet constructor.
    shape_cases = [
        (-1, 256, 192, (1, 256, 7, 5)),
        (1, 224, 224, (1, 1)),
    ]

    for init_arg, height, width, expected_shape in shape_cases:
        backbone = AlexNet(init_arg)
        backbone.train()

        batch = torch.randn(1, 3, height, width)
        output = backbone(batch)

        assert output.shape == expected_shape