#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Author: Donny You(youansheng@gmail.com)
"""ResNet backbones that return the feature maps of all four residual stages."""

from functools import partial

import torch.nn as nn

from networks.resnet_models import *


class _ResnetBackboneBase(nn.Module):
    """Shared plumbing for the ResNet backbone wrappers in this module.

    Holds the stem (``prefix``), the max-pool and the four residual stages of
    an existing ResNet, and exposes the output of every stage.
    """

    def __init__(self):
        super(_ResnetBackboneBase, self).__init__()
        # Default value; the ``ResNetBackbone`` factory overrides it with 512
        # for the resnet18 / resnet34 variants.
        self.num_features = 2048

    def _take_stages(self, orig_resnet):
        """Adopt every sub-module of ``orig_resnet`` except AvgPool and FC."""
        self.prefix = orig_resnet.prefix
        self.maxpool = orig_resnet.maxpool
        self.layer1 = orig_resnet.layer1
        self.layer2 = orig_resnet.layer2
        self.layer3 = orig_resnet.layer3
        self.layer4 = orig_resnet.layer4

    def get_num_features(self):
        """Return the ``num_features`` attribute of this backbone."""
        return self.num_features

    def forward(self, x):
        """Run the backbone and collect the output of each residual stage.

        Args:
            x: input passed to ``prefix``.

        Returns:
            A list ``[layer1_out, layer2_out, layer3_out, layer4_out]``.
        """
        x = self.maxpool(self.prefix(x))
        stage_features = []
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
            stage_features.append(x)
        return stage_features


class NormalResnetBackbone(_ResnetBackboneBase):
    """Backbone that reuses the stages of ``orig_resnet`` unmodified."""

    def __init__(self, orig_resnet):
        super(NormalResnetBackbone, self).__init__()
        self._take_stages(orig_resnet)


class DilatedResnetBackbone(_ResnetBackboneBase):
    """Backbone whose last stage(s) trade stride for dilation.

    The convolutions of ``orig_resnet`` are modified IN PLACE.

    Args:
        orig_resnet: network providing ``prefix``, ``maxpool`` and
            ``layer1`` .. ``layer4``.
        dilate_scale: ``8`` dilates ``layer3`` (base rate 2) and ``layer4``
            (base rate 4); ``16`` dilates only ``layer4`` (base rate 2).
            Any other value leaves the network untouched.
        multi_grid: optional per-block multipliers for the ``layer4`` base
            rate. ``None`` applies the base rate to the whole stage. Blocks
            and multipliers are paired in order; whichever sequence is longer
            has its surplus entries ignored.
    """

    def __init__(self, orig_resnet, dilate_scale=8, multi_grid=(1, 2, 4)):
        super(DilatedResnetBackbone, self).__init__()

        if dilate_scale == 8:
            orig_resnet.layer3.apply(partial(self._nostride_dilate, dilate=2))
            self._dilate_last_stage(orig_resnet.layer4, 4, multi_grid)
        elif dilate_scale == 16:
            self._dilate_last_stage(orig_resnet.layer4, 2, multi_grid)

        # Take pretrained resnet, except AvgPool and FC
        self._take_stages(orig_resnet)

    def _dilate_last_stage(self, layer4, base_dilate, multi_grid):
        """Dilate ``layer4`` with ``base_dilate``, optionally scaled per block."""
        if multi_grid is None:
            layer4.apply(partial(self._nostride_dilate, dilate=base_dilate))
            return
        # zip() rather than indexing: a ``multi_grid`` with more entries than
        # the stage has blocks (e.g. three rates for a two-block stage) used
        # to raise IndexError; the extra rates are now simply unused.
        for block, rate in zip(layer4, multi_grid):
            block.apply(partial(self._nostride_dilate,
                                dilate=int(base_dilate * rate)))

    def _nostride_dilate(self, m, dilate):
        """Callback for ``Module.apply``: remove stride and add dilation.

        Only modules whose class name contains ``'Conv'`` are touched.
        """
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            # the convolution with stride
            if m.stride == (2, 2):
                m.stride = (1, 1)
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate // 2, dilate // 2)
                    m.padding = (dilate // 2, dilate // 2)
            # other convolutions
            else:
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate, dilate)
                    m.padding = (dilate, dilate)


def ResNetBackbone(backbone=None, width_multiplier=1.0, pretrained=None, multi_grid=None, norm_type='batchnorm'):
    """Build a backbone wrapper for the architecture named by ``backbone``.

    Supported names: ``resnet18``, ``resnet18_dilated8``, ``resnet34``,
    ``resnet34_dilated8``, ``resnet34_dilated16``, and for each of
    ``resnet50``, ``deepbase_resnet50``, ``resnet101``, ``deepbase_resnet101``
    the plain name plus the ``_dilated8`` and ``_dilated16`` suffixes.

    Args:
        backbone: architecture name (see above).
        width_multiplier: forwarded to ``resnet50`` for the three
            non-deepbase resnet50 variants; unused by the other branches.
        pretrained: forwarded to the underlying constructor. For
            ``deepbase_resnet50``, ``deepbase_resnet50_dilated8`` and
            ``deepbase_resnet101_dilated8`` any truthy value is replaced by a
            hard-coded checkpoint path.
        multi_grid: forwarded to ``DilatedResnetBackbone`` for the dilated
            variants.
        norm_type: accepted for interface compatibility; not used here.

    Returns:
        A ``NormalResnetBackbone`` or ``DilatedResnetBackbone``.

    Raises:
        Exception: if ``backbone`` is not one of the supported names.
    """
    arch = backbone

    def wrap(orig_resnet, dilate_scale=None, num_features=None):
        # dilate_scale=None selects the plain (non-dilated) wrapper.
        if dilate_scale is None:
            wrapped = NormalResnetBackbone(orig_resnet)
        else:
            wrapped = DilatedResnetBackbone(orig_resnet, dilate_scale=dilate_scale, multi_grid=multi_grid)
        if num_features is not None:
            wrapped.num_features = num_features
        return wrapped

    if arch == 'resnet18':
        arch_net = wrap(resnet18(pretrained=pretrained), num_features=512)

    elif arch == 'resnet18_dilated8':
        arch_net = wrap(resnet18(pretrained=pretrained), dilate_scale=8, num_features=512)

    elif arch == 'resnet34':
        arch_net = wrap(resnet34(pretrained=pretrained), num_features=512)

    elif arch == 'resnet34_dilated8':
        arch_net = wrap(resnet34(pretrained=pretrained), dilate_scale=8, num_features=512)

    elif arch == 'resnet34_dilated16':
        arch_net = wrap(resnet34(pretrained=pretrained), dilate_scale=16, num_features=512)

    elif arch == 'resnet50':
        arch_net = wrap(resnet50(pretrained=pretrained, width_multiplier=width_multiplier))

    elif arch == 'resnet50_dilated8':
        arch_net = wrap(resnet50(pretrained=pretrained, width_multiplier=width_multiplier), dilate_scale=8)

    elif arch == 'resnet50_dilated16':
        # width_multiplier was previously dropped here, unlike the other two
        # resnet50 branches; forward it so all three variants agree.
        arch_net = wrap(resnet50(pretrained=pretrained, width_multiplier=width_multiplier), dilate_scale=16)

    elif arch == 'deepbase_resnet50':
        if pretrained:
            pretrained = 'models/backbones/pretrained/3x3resnet50-imagenet.pth'
        arch_net = wrap(deepbase_resnet50(pretrained=pretrained))

    elif arch == 'deepbase_resnet50_dilated8':
        if pretrained:
            pretrained = 'models/backbones/pretrained/3x3resnet50-imagenet.pth'
        arch_net = wrap(deepbase_resnet50(pretrained=pretrained), dilate_scale=8)

    elif arch == 'deepbase_resnet50_dilated16':
        arch_net = wrap(deepbase_resnet50(pretrained=pretrained), dilate_scale=16)

    elif arch == 'resnet101':
        arch_net = wrap(resnet101(pretrained=pretrained))

    elif arch == 'resnet101_dilated8':
        arch_net = wrap(resnet101(pretrained=pretrained), dilate_scale=8)

    elif arch == 'resnet101_dilated16':
        arch_net = wrap(resnet101(pretrained=pretrained), dilate_scale=16)

    elif arch == 'deepbase_resnet101':
        arch_net = wrap(deepbase_resnet101(pretrained=pretrained))

    elif arch == 'deepbase_resnet101_dilated8':
        if pretrained:
            # NOTE(review): this path starts with 'backbones/backbones/' while
            # the resnet50 checkpoints use 'models/backbones/' - confirm which
            # location is intended.
            pretrained = 'backbones/backbones/pretrained/3x3resnet101-imagenet.pth'
        arch_net = wrap(deepbase_resnet101(pretrained=pretrained), dilate_scale=8)

    elif arch == 'deepbase_resnet101_dilated16':
        arch_net = wrap(deepbase_resnet101(pretrained=pretrained), dilate_scale=16)

    else:
        raise Exception('Architecture undefined!')

    return arch_net