|
from torchvision.models import ResNet |
|
from torchvision.models.resnet import Bottleneck, BasicBlock |
|
from .csra import CSRA, MHA |
|
import torch.utils.model_zoo as model_zoo |
|
import logging |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
|
|
# Download URLs of the official PyTorch (download.pytorch.org) ResNet
# checkpoints, keyed "resnet<depth>"; ResNet_CSRA.init_weights looks them up
# with its configured depth when no CutMix checkpoint is given.
model_urls = dict(
    resnet18="https://download.pytorch.org/models/resnet18-5c106cde.pth",
    resnet34="https://download.pytorch.org/models/resnet34-333f7ec4.pth",
    resnet50="https://download.pytorch.org/models/resnet50-19c8e357.pth",
    resnet101="https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
    resnet152="https://download.pytorch.org/models/resnet152-b121ed2d.pth",
)
|
|
|
|
|
class ResNet_CSRA(ResNet):
    """torchvision ResNet backbone followed by a multi-head CSRA (``MHA``) classifier.

    The backbone is initialised from either a user-supplied checkpoint
    (``cutmix``) or the official PyTorch weights listed in ``model_urls``. Its
    torchvision classification ``fc`` is then replaced by an empty
    ``nn.Sequential``, and the ``MHA`` head turns the layer4 feature map into
    logits.
    """

    # depth -> (residual block class, number of blocks in layer1..layer4)
    arch_settings = {
        18: (BasicBlock, (2, 2, 2, 2)),
        34: (BasicBlock, (3, 4, 6, 3)),
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3)),
    }

    def __init__(
        self, num_heads, lam, num_classes, depth=101, input_dim=None, cutmix=None
    ):
        """Build the backbone, load pretrained weights and attach the CSRA head.

        Args:
            num_heads: number of heads, forwarded to ``MHA``.
            lam: forwarded unchanged to ``MHA``.
            num_classes: number of output logits produced by ``MHA``.
            depth: ResNet depth; must be a key of ``arch_settings``.
            input_dim: channel count of the feature map given to ``MHA``.
                ``None`` (default) infers it from ``depth`` as
                ``512 * block.expansion``: 2048 for depths 50/101/152 (the
                previous hard-coded default) and 512 for depths 18/34.
            cutmix: optional checkpoint path, used instead of the official
                PyTorch weights.

        Raises:
            KeyError: if ``depth`` is not a key of ``arch_settings``.
        """
        if depth not in self.arch_settings:
            raise KeyError(
                f"unsupported ResNet depth {depth!r}; "
                f"expected one of {sorted(self.arch_settings)}"
            )
        self.block, self.layers = self.arch_settings[depth]
        self.depth = depth
        super().__init__(self.block, self.layers)
        self.init_weights(pretrained=True, cutmix=cutmix)

        if input_dim is None:
            # torchvision builds layer4 with 512 planes, so it emits
            # 512 * expansion channels (1 for BasicBlock, 4 for Bottleneck).
            input_dim = 512 * self.block.expansion
        self.classifier = MHA(num_heads, lam, input_dim, num_classes)
        self.loss_func = F.binary_cross_entropy_with_logits

    def backbone(self, x):
        """Run the ResNet stem and layer1-layer4 and return the final feature map.

        Unlike ``ResNet.forward`` this stops before global average pooling and
        ``fc``, so the spatial map is kept for the CSRA head.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x

    def forward_train(self, x, target):
        """Return ``(logit, loss)``.

        ``loss`` is the mean binary cross-entropy-with-logits of the
        classifier output against ``target``. ``target`` presumably is a float
        multi-hot tensor with the same shape as the logits (TODO confirm
        against the data pipeline).
        """
        x = self.backbone(x)
        logit = self.classifier(x)
        loss = self.loss_func(logit, target, reduction="mean")
        return logit, loss

    def forward_test(self, x):
        """Return the classifier logits for ``x`` (no loss is computed)."""
        x = self.backbone(x)
        x = self.classifier(x)
        return x

    def forward(self, x, target=None):
        """Dispatch to ``forward_train`` when ``target`` is given, else ``forward_test``."""
        if target is not None:
            return self.forward_train(x, target)
        return self.forward_test(x)

    def init_weights(self, pretrained=True, cutmix=None):
        """Load backbone weights, then drop the torchvision classification ``fc``.

        Args:
            pretrained: when true and ``cutmix`` is None, download the official
                PyTorch weights for ``self.depth`` from ``model_urls``.
            cutmix: optional checkpoint path; takes precedence over ``pretrained``.

        If neither source is selected, the randomly initialised weights are kept.
        Only checkpoint entries whose keys exist in this model are loaded. If
        that strict load raises, ``_keysFix`` reconciles the two state dicts
        before a second load.
        """
        state_dict = None
        if cutmix is not None:
            print("backbone params inited by CutMix pretrained model")
            # map_location="cpu" lets GPU-saved checkpoints load on CPU-only hosts.
            # NOTE(review): torch.load unpickles arbitrary objects; only pass
            # checkpoints from a trusted source.
            state_dict = torch.load(cutmix, map_location="cpu")
        elif pretrained:
            print("backbone params inited by Pytorch official model")
            model_url = model_urls[f"resnet{self.depth}"]
            state_dict = model_zoo.load_url(model_url)

        # Previously state_dict was unbound here when pretrained=False and
        # cutmix=None, raising UnboundLocalError.
        if state_dict is not None:
            model_dict = self.state_dict()
            pretrained_dict = {k: v for k, v in state_dict.items() if k in model_dict}
            try:
                self.load_state_dict(pretrained_dict)
            except RuntimeError:
                # load_state_dict(strict=True) raises RuntimeError on missing
                # keys or shape mismatches.
                logging.getLogger(__name__).warning(
                    "the keys in pretrained model is not equal to the keys in the ResNet you choose, trying to fix..."
                )
                state_dict = self._keysFix(model_dict, state_dict)
                self.load_state_dict(state_dict)

        # Replace the classification fc that torchvision's ResNet builds; the
        # MHA head produces the logits instead.
        self.fc = nn.Sequential()

    @staticmethod
    def _keysFix(model_dict, state_dict):
        """Reconcile a checkpoint with ``model_dict`` and return a complete state dict.

        The steps are:

        * unwrap checkpoints stored as ``{"state_dict": {...}, ...}``;
        * strip a leading ``"module."`` (the prefix added when saving a model
          wrapped in ``nn.DataParallel`` / ``DistributedDataParallel``);
        * take every checkpoint tensor whose key and shape match ``model_dict``;
        * keep the current value from ``model_dict`` for every other key.

        Because the result has exactly the keys of ``model_dict``, it can be
        loaded with ``strict=True``.
        """
        nested = state_dict.get("state_dict") if hasattr(state_dict, "get") else None
        if isinstance(nested, dict):
            state_dict = nested

        prefix = "module."
        fixed = dict(model_dict)
        for key, value in state_dict.items():
            if not isinstance(key, str):
                continue
            if key.startswith(prefix):
                key = key[len(prefix):]
            current = model_dict.get(key)
            if current is not None and getattr(value, "shape", None) == current.shape:
                fixed[key] = value
        return fixed
|
|