Spaces:
Sleeping
Sleeping
from torchvision.models import ResNet | |
from torchvision.models.resnet import Bottleneck, BasicBlock | |
from .csra import CSRA, MHA | |
import torch.utils.model_zoo as model_zoo | |
import logging | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
# Official PyTorch ResNet checkpoint URLs, keyed by "resnet<depth>" so that
# ResNet_CSRA.init_weights can look them up via "resnet{}".format(depth).
model_urls = dict(
    resnet18="https://download.pytorch.org/models/resnet18-5c106cde.pth",
    resnet34="https://download.pytorch.org/models/resnet34-333f7ec4.pth",
    resnet50="https://download.pytorch.org/models/resnet50-19c8e357.pth",
    resnet101="https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
    resnet152="https://download.pytorch.org/models/resnet152-b121ed2d.pth",
)
class ResNet_CSRA(ResNet): | |
arch_settings = { | |
18: (BasicBlock, (2, 2, 2, 2)), | |
34: (BasicBlock, (3, 4, 6, 3)), | |
50: (Bottleneck, (3, 4, 6, 3)), | |
101: (Bottleneck, (3, 4, 23, 3)), | |
152: (Bottleneck, (3, 8, 36, 3)), | |
} | |
def __init__( | |
self, num_heads, lam, num_classes, depth=101, input_dim=2048, cutmix=None | |
): | |
self.block, self.layers = self.arch_settings[depth] | |
self.depth = depth | |
super(ResNet_CSRA, self).__init__(self.block, self.layers) | |
self.init_weights(pretrained=True, cutmix=cutmix) | |
self.classifier = MHA(num_heads, lam, input_dim, num_classes) | |
self.loss_func = F.binary_cross_entropy_with_logits | |
# todo | |
# criterion = nn.BCEWithLogitsLoss() # loss combines a Sigmoid layer and the BCELoss in one single class | |
def backbone(self, x): | |
x = self.conv1(x) | |
x = self.bn1(x) | |
x = self.relu(x) | |
x = self.maxpool(x) | |
x = self.layer1(x) | |
x = self.layer2(x) | |
x = self.layer3(x) | |
x = self.layer4(x) | |
return x | |
def forward_train(self, x, target): | |
x = self.backbone(x) | |
logit = self.classifier(x) | |
loss = self.loss_func(logit, target, reduction="mean") | |
return logit, loss | |
def forward_test(self, x): | |
x = self.backbone(x) | |
x = self.classifier(x) | |
return x | |
def forward(self, x, target=None): | |
if target is not None: | |
return self.forward_train(x, target) | |
else: | |
return self.forward_test(x) | |
def init_weights(self, pretrained=True, cutmix=None): | |
if cutmix is not None: | |
print("backbone params inited by CutMix pretrained model") | |
state_dict = torch.load(cutmix) | |
elif pretrained: | |
print("backbone params inited by Pytorch official model") | |
model_url = model_urls["resnet{}".format(self.depth)] | |
state_dict = model_zoo.load_url(model_url) | |
model_dict = self.state_dict() | |
try: | |
pretrained_dict = {k: v for k, v in state_dict.items() if k in model_dict} | |
self.load_state_dict(pretrained_dict) | |
except: | |
logger = logging.getLogger() | |
logger.info( | |
"the keys in pretrained model is not equal to the keys in the ResNet you choose, trying to fix..." | |
) | |
state_dict = self._keysFix(model_dict, state_dict) | |
self.load_state_dict(state_dict) | |
# remove the original 1000-class fc | |
self.fc = nn.Sequential() | |