"""ResNet backbone with a CSRA / multi-head attention classification head."""
import logging

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from torchvision.models import ResNet
from torchvision.models.resnet import Bottleneck, BasicBlock

from .csra import CSRA, MHA

logger = logging.getLogger(__name__)

# Official torchvision ImageNet checkpoints, keyed by "resnet<depth>".
model_urls = {
    "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth",
    "resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
    "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth",
    "resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
    "resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
}


class ResNet_CSRA(ResNet):
    """torchvision ResNet whose final ``fc`` is replaced by an ``MHA`` head.

    The backbone is initialised from a pretrained checkpoint during
    construction (CutMix checkpoint if ``cutmix`` is given, otherwise the
    official torchvision weights for ``depth``).
    """

    # depth -> (residual block type, number of blocks per stage)
    arch_settings = {
        18: (BasicBlock, (2, 2, 2, 2)),
        34: (BasicBlock, (3, 4, 6, 3)),
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3)),
    }

    def __init__(
        self, num_heads, lam, num_classes, depth=101, input_dim=2048, cutmix=None
    ):
        """Build the backbone, load pretrained weights and attach the head.

        Args:
            num_heads: forwarded to ``MHA``.
            lam: forwarded to ``MHA``.
            num_classes: number of output logits of the head.
            depth: ResNet depth; must be a key of ``arch_settings``.
            input_dim: channel count passed to ``MHA``.
                NOTE(review): 2048 matches layer4 of Bottleneck ResNets
                (50/101/152); BasicBlock depths (18/34) presumably need 512.
            cutmix: optional path to a CutMix-pretrained checkpoint; when
                given it takes precedence over the official weights.

        Raises:
            KeyError: if ``depth`` is not supported.
        """
        if depth not in self.arch_settings:
            raise KeyError(
                "unsupported ResNet depth {}; choose one of {}".format(
                    depth, sorted(self.arch_settings)
                )
            )
        self.block, self.layers = self.arch_settings[depth]
        self.depth = depth
        super().__init__(self.block, self.layers)
        self.init_weights(pretrained=True, cutmix=cutmix)

        self.classifier = MHA(num_heads, lam, input_dim, num_classes)
        # Functional form of nn.BCEWithLogitsLoss: sigmoid + BCE in one
        # numerically stable op, suited to multi-label targets.
        self.loss_func = F.binary_cross_entropy_with_logits

    def backbone(self, x):
        """Run the ResNet stem and the four stages; no pooling, no ``fc``.

        Returns the layer4 feature map (spatial dimensions are kept).
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x

    def forward_train(self, x, target):
        """Return ``(logit, loss)``; loss is mean BCE-with-logits.

        ``target`` must have the same shape as ``logit`` (a requirement of
        ``binary_cross_entropy_with_logits``).
        """
        x = self.backbone(x)
        logit = self.classifier(x)
        loss = self.loss_func(logit, target, reduction="mean")
        return logit, loss

    def forward_test(self, x):
        """Return the classifier logits for ``x``."""
        x = self.backbone(x)
        return self.classifier(x)

    def forward(self, x, target=None):
        """Dispatch to ``forward_train`` when ``target`` is given, else ``forward_test``."""
        if target is not None:
            return self.forward_train(x, target)
        return self.forward_test(x)

    def init_weights(self, pretrained=True, cutmix=None):
        """Load backbone weights, then drop the original 1000-class ``fc``.

        Args:
            pretrained: load the official torchvision checkpoint for
                ``self.depth`` (ignored when ``cutmix`` is given).
            cutmix: optional path to a CutMix-pretrained checkpoint.

        With ``cutmix=None`` and ``pretrained=False`` no weights are loaded
        (the default ResNet initialisation is kept); ``fc`` is still removed.
        """
        state_dict = None
        if cutmix is not None:
            print("backbone params inited by CutMix pretrained model")
            # map_location: allow GPU-saved checkpoints on CPU-only hosts;
            # load_state_dict copies onto the parameters' devices anyway.
            # NOTE(review): torch.load unpickles the file; only pass trusted paths.
            state_dict = torch.load(cutmix, map_location="cpu")
        elif pretrained:
            print("backbone params inited by Pytorch official model")
            model_url = model_urls["resnet{}".format(self.depth)]
            state_dict = model_zoo.load_url(model_url)

        if state_dict is not None:
            model_dict = self.state_dict()
            pretrained_dict = {k: v for k, v in state_dict.items() if k in model_dict}
            try:
                # strict load: raises RuntimeError on missing keys or shape mismatch
                self.load_state_dict(pretrained_dict)
            except RuntimeError:
                logger.info(
                    "the keys in pretrained model is not equal to the keys in the ResNet you choose, trying to fix..."
                )
                self.load_state_dict(self._keysFix(model_dict, state_dict))

        # remove the original 1000-class fc
        self.fc = nn.Sequential()

    def _keysFix(self, model_dict, state_dict):
        """Best-effort remap of a checkpoint's entries onto this model's keys.

        Handles two common checkpoint layouts (assumed typical, not verified
        against every checkpoint source):
          * weights nested under a ``"state_dict"`` entry of a larger
            training checkpoint;
          * keys carrying the ``"module."`` prefix added by (Distributed)
            DataParallel wrappers.
        Entries with an unknown key or a different shape are dropped; model
        tensors the checkpoint does not provide keep their current values, so
        the result always loads with ``strict=True``.

        Args:
            model_dict: ``self.state_dict()`` captured before loading.
            state_dict: the raw mapping read from the checkpoint.

        Returns:
            A dict with exactly the keys of ``model_dict``.

        Raises:
            RuntimeError: if no checkpoint tensor matches any model tensor,
                rather than silently keeping a randomly initialised backbone.
        """
        nested = state_dict.get("state_dict")
        if isinstance(nested, dict):
            state_dict = nested

        prefix = "module."
        fixed = dict(model_dict)
        matched = 0
        for key, value in state_dict.items():
            if isinstance(key, str) and key.startswith(prefix):
                key = key[len(prefix):]
            current = model_dict.get(key)
            if current is not None and getattr(value, "shape", None) == current.shape:
                fixed[key] = value
                matched += 1

        if matched == 0:
            raise RuntimeError(
                "no tensor in the pretrained checkpoint matches the ResNet-{} backbone".format(
                    self.depth
                )
            )
        logger.info(
            "loaded %d of %d backbone tensors from the pretrained checkpoint",
            matched,
            len(model_dict),
        )
        return fixed