# selfmask/networks/resnet_backbone.py
# noelshin, commit 35188e4 ("Add application file")
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Author: Donny You([email protected])
import torch.nn as nn
from networks.resnet_models import *
class NormalResnetBackbone(nn.Module):
    """ResNet trunk without its average-pool / FC head.

    The stem (``prefix`` followed by ``maxpool``) and the four residual stages
    ``layer1`` .. ``layer4`` are taken from ``orig_resnet`` by reference (the
    same module objects, not copies). ``forward`` returns every stage output.
    """

    def __init__(self, orig_resnet):
        super(NormalResnetBackbone, self).__init__()
        # Default feature count; ResNetBackbone overrides it (512) for resnet18/34.
        self.num_features = 2048
        # Borrow each pretrained component except AvgPool and FC, in stem-to-head order.
        for attr_name in ('prefix', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4'):
            setattr(self, attr_name, getattr(orig_resnet, attr_name))

    def get_num_features(self):
        """Return ``self.num_features`` (presumably the channel count of ``layer4``'s output)."""
        return self.num_features

    def forward(self, x):
        """Run the stem and all four stages.

        Returns:
            A list of the four stage outputs, from ``layer1`` to ``layer4``.
        """
        out = self.maxpool(self.prefix(x))
        stage_outputs = []
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
            stage_outputs.append(out)
        return stage_outputs
class DilatedResnetBackbone(nn.Module):
    """ResNet trunk whose late stages trade stride for dilation.

    ``orig_resnet`` is modified in place. In the selected stage(s), stride-2
    convolutions become stride 1 and 3x3 convolutions are dilated (see
    ``_nostride_dilate``). The stem and the four stages are then reused
    without the AvgPool / FC head.

    Args:
        orig_resnet: object exposing ``prefix``, ``maxpool`` and ``layer1`` ..
            ``layer4``. ``layer4`` must be indexable when ``multi_grid`` is
            given (presumably an ``nn.Sequential`` of residual blocks).
        dilate_scale: 8 dilates ``layer3`` with rate 2 and ``layer4`` with
            rate 4. 16 dilates only ``layer4``, with rate 2. Any other value
            leaves the network unmodified. Presumably the target output
            stride.
        multi_grid: optional per-block multipliers of the ``layer4`` rate.
            Entry ``i`` is applied to ``layer4[i]``; blocks without an entry
            are left unmodified. ``None`` applies the base rate to the whole
            of ``layer4``.
    """

    def __init__(self, orig_resnet, dilate_scale=8, multi_grid=(1, 2, 4)):
        super(DilatedResnetBackbone, self).__init__()
        # Default feature count; ResNetBackbone overrides it (512) for resnet18/34.
        self.num_features = 2048

        from functools import partial

        def dilate_stage(module, rate):
            # Module.apply calls the function on every submodule and on the module itself.
            module.apply(partial(self._nostride_dilate, dilate=rate))

        if dilate_scale in (8, 16):
            if dilate_scale == 8:
                dilate_stage(orig_resnet.layer3, 2)
                base_rate = 4
            else:
                base_rate = 2
            if multi_grid is None:
                dilate_stage(orig_resnet.layer4, base_rate)
            else:
                for block_idx, grid in enumerate(multi_grid):
                    dilate_stage(orig_resnet.layer4[block_idx], int(base_rate * grid))

        # Reuse the (now modified) pretrained trunk; the AvgPool/FC head is dropped.
        for attr_name in ('prefix', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4'):
            setattr(self, attr_name, getattr(orig_resnet, attr_name))

    def _nostride_dilate(self, m, dilate):
        """Rewrite a single module's stride/dilation in place; used via ``Module.apply``.

        Only modules whose class name contains ``'Conv'`` are touched:
          * a stride-(2, 2) conv gets stride (1, 1) and, if 3x3, dilation and
            padding of ``dilate // 2`` (half the rate);
          * any other 3x3 conv gets dilation and padding of ``dilate``.
        Padding equal to dilation keeps a stride-1 3x3 conv size-preserving.
        """
        if 'Conv' not in m.__class__.__name__:
            return
        is_3x3 = m.kernel_size == (3, 3)
        if m.stride == (2, 2):
            m.stride = (1, 1)
            if is_3x3:
                half_rate = dilate // 2
                m.dilation = (half_rate, half_rate)
                m.padding = (half_rate, half_rate)
        elif is_3x3:
            m.dilation = (dilate, dilate)
            m.padding = (dilate, dilate)

    def get_num_features(self):
        """Return ``self.num_features`` (presumably the channel count of ``layer4``'s output)."""
        return self.num_features

    def forward(self, x):
        """Run the stem and all four stages.

        Returns:
            A list of the four stage outputs, from ``layer1`` to ``layer4``.
        """
        out = self.maxpool(self.prefix(x))
        stage_outputs = []
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
            stage_outputs.append(out)
        return stage_outputs
def ResNetBackbone(backbone=None, width_multiplier=1.0, pretrained=None, multi_grid=None, norm_type='batchnorm'):
    """Build a ResNet feature-extraction backbone from an architecture name.

    Args:
        backbone: architecture name. Supported: 'resnet18', 'resnet18_dilated8',
            'resnet34', 'resnet34_dilated8', 'resnet34_dilated16', and, for each
            of 'resnet50', 'deepbase_resnet50', 'resnet101' and
            'deepbase_resnet101', the plain name plus its '_dilated8' and
            '_dilated16' variants. A '_dilatedN' suffix wraps the network in
            DilatedResnetBackbone with dilate_scale=N; otherwise
            NormalResnetBackbone is used.
        width_multiplier: forwarded to the resnet50 constructor (all three
            resnet50 variants). Ignored for the other architectures.
        pretrained: forwarded to the underlying constructor. For
            'deepbase_resnet50', 'deepbase_resnet50_dilated8' and
            'deepbase_resnet101_dilated8', a truthy value is replaced by a
            bundled local checkpoint path.
        multi_grid: forwarded to DilatedResnetBackbone for '_dilatedN' variants.
        norm_type: currently unused; kept for interface compatibility.

    Returns:
        The wrapped backbone. Its num_features is set to 512 for resnet18/34 and
        otherwise keeps the wrapper default of 2048.

    Raises:
        ValueError: if ``backbone`` is not one of the supported names.
    """
    arch = backbone
    num_features = None  # None keeps the wrapper's default (2048)

    if arch in ('resnet18', 'resnet18_dilated8'):
        orig_resnet = resnet18(pretrained=pretrained)
        num_features = 512
    elif arch in ('resnet34', 'resnet34_dilated8', 'resnet34_dilated16'):
        orig_resnet = resnet34(pretrained=pretrained)
        num_features = 512
    elif arch in ('resnet50', 'resnet50_dilated8', 'resnet50_dilated16'):
        # Every resnet50 variant honours width_multiplier (previously
        # 'resnet50_dilated16' silently dropped it).
        # NOTE(review): num_features stays 2048 regardless of width_multiplier;
        # confirm against networks.resnet_models whether layer4 widens with it.
        orig_resnet = resnet50(pretrained=pretrained, width_multiplier=width_multiplier)
    elif arch in ('deepbase_resnet50', 'deepbase_resnet50_dilated8', 'deepbase_resnet50_dilated16'):
        # Only the plain and dilated8 variants swap a truthy `pretrained` for the
        # bundled checkpoint; dilated16 forwards `pretrained` unchanged.
        if pretrained and arch != 'deepbase_resnet50_dilated16':
            pretrained = 'models/backbones/pretrained/3x3resnet50-imagenet.pth'
        orig_resnet = deepbase_resnet50(pretrained=pretrained)
    elif arch in ('resnet101', 'resnet101_dilated8', 'resnet101_dilated16'):
        orig_resnet = resnet101(pretrained=pretrained)
    elif arch in ('deepbase_resnet101', 'deepbase_resnet101_dilated8', 'deepbase_resnet101_dilated16'):
        # Only the dilated8 variant swaps a truthy `pretrained` for the bundled
        # checkpoint. The path lives in the same directory as the resnet50
        # checkpoints (it used to read 'backbones/backbones/...').
        if pretrained and arch == 'deepbase_resnet101_dilated8':
            pretrained = 'models/backbones/pretrained/3x3resnet101-imagenet.pth'
        orig_resnet = deepbase_resnet101(pretrained=pretrained)
    else:
        raise ValueError('Architecture undefined! backbone={!r}'.format(backbone))

    # Past the whitelist above, `arch` is a known string; its suffix picks the wrapper.
    if arch.endswith('_dilated8'):
        arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=8, multi_grid=multi_grid)
    elif arch.endswith('_dilated16'):
        arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=16, multi_grid=multi_grid)
    else:
        arch_net = NormalResnetBackbone(orig_resnet)

    if num_features is not None:
        arch_net.num_features = num_features
    return arch_net