import torch
import torch.nn as nn
import torchvision
from . import resnet, resnext
try:
from lib.nn import SynchronizedBatchNorm2d
except ImportError:
from torch.nn import BatchNorm2d as SynchronizedBatchNorm2d
class SegmentationModuleBase(nn.Module):
def __init__(self):
super(SegmentationModuleBase, self).__init__()
def pixel_acc(self, pred, label):
_, preds = torch.max(pred, dim=1)
valid = (label >= 0).long()
acc_sum = torch.sum(valid * (preds == label).long())
pixel_sum = torch.sum(valid)
acc = acc_sum.float() / (pixel_sum.float() + 1e-10)
return acc
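# Illustrative sketch (not part of the original file): pixel_acc treats a pixel as valid
# only when its label is >= 0, so ignore regions can be marked with a negative label.
#
#   base = SegmentationModuleBase()
#   pred = torch.randn(1, 150, 4, 4)              # per-class scores (N, C, H, W)
#   label = torch.randint(0, 150, (1, 4, 4))      # ground-truth class index per pixel
#   label[0, 0, 0] = -1                           # hypothetical ignore pixel
#   acc = base.pixel_acc(pred, label)             # accuracy over the 15 valid pixels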
class SegmentationModule(SegmentationModuleBase):
def __init__(self, net_enc, net_dec, crit, deep_sup_scale=None):
super(SegmentationModule, self).__init__()
self.encoder = net_enc
self.decoder = net_dec
self.crit = crit
self.deep_sup_scale = deep_sup_scale
def forward(self, feed_dict, *, segSize=None):
if segSize is None: # training
if self.deep_sup_scale is not None: # use deep supervision technique
(pred, pred_deepsup) = self.decoder(self.encoder(feed_dict['img_data'], return_feature_maps=True))
else:
pred = self.decoder(self.encoder(feed_dict['img_data'], return_feature_maps=True))
loss = self.crit(pred, feed_dict['seg_label'])
if self.deep_sup_scale is not None:
loss_deepsup = self.crit(pred_deepsup, feed_dict['seg_label'])
loss = loss + loss_deepsup * self.deep_sup_scale
acc = self.pixel_acc(pred, feed_dict['seg_label'])
return loss, acc
else: # inference
pred = self.decoder(self.encoder(feed_dict['img_data'], return_feature_maps=True), segSize=segSize)
return pred
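# Illustrative sketch (variable names are assumptions): during training the module needs
# both 'img_data' and 'seg_label' and returns (loss, acc); at inference it only needs
# 'img_data' plus a target segSize, assuming the decoder was built with use_softmax=True.
# The decoders emit log-probabilities during training, so an NLLLoss-style criterion fits.
#
#   crit = nn.NLLLoss(ignore_index=-1)
#   seg_module = SegmentationModule(net_encoder, net_decoder, crit, deep_sup_scale=0.4)
#   loss, acc = seg_module({'img_data': imgs, 'seg_label': labels})       # training
#   probs = seg_module({'img_data': imgs}, segSize=(height, width))       # inference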
def conv3x3(in_planes, out_planes, stride=1, has_bias=False):
"3x3 convolution with padding"
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=1, bias=has_bias)
def conv3x3_bn_relu(in_planes, out_planes, stride=1):
return nn.Sequential(
conv3x3(in_planes, out_planes, stride),
SynchronizedBatchNorm2d(out_planes),
nn.ReLU(inplace=True),
)
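# Illustrative sketch: with the default stride of 1 the block keeps the spatial size and
# only changes the channel count, e.g. for a hypothetical (1, 2048, 32, 32) feature map:
#
#   block = conv3x3_bn_relu(2048, 512)
#   out = block(torch.randn(1, 2048, 32, 32))   # -> torch.Size([1, 512, 32, 32])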
class ModelBuilder():
# custom weights initialization
def weights_init(self, m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
nn.init.kaiming_normal_(m.weight.data)
elif classname.find('BatchNorm') != -1:
m.weight.data.fill_(1.)
m.bias.data.fill_(1e-4)
#elif classname.find('Linear') != -1:
# m.weight.data.normal_(0.0, 0.0001)
def build_encoder(self, arch='resnet50_dilated8', fc_dim=512, weights=''):
        pretrained = len(weights) == 0
if arch == 'resnet34':
raise NotImplementedError
orig_resnet = resnet.__dict__['resnet34'](pretrained=pretrained)
net_encoder = Resnet(orig_resnet)
elif arch == 'resnet34_dilated8':
raise NotImplementedError
orig_resnet = resnet.__dict__['resnet34'](pretrained=pretrained)
net_encoder = ResnetDilated(orig_resnet,
dilate_scale=8)
elif arch == 'resnet34_dilated16':
raise NotImplementedError
orig_resnet = resnet.__dict__['resnet34'](pretrained=pretrained)
net_encoder = ResnetDilated(orig_resnet,
dilate_scale=16)
elif arch == 'resnet50':
orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained)
net_encoder = Resnet(orig_resnet)
elif arch == 'resnet50_dilated8':
orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained)
net_encoder = ResnetDilated(orig_resnet,
dilate_scale=8)
elif arch == 'resnet50_dilated16':
orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained)
net_encoder = ResnetDilated(orig_resnet,
dilate_scale=16)
elif arch == 'resnet101':
orig_resnet = resnet.__dict__['resnet101'](pretrained=pretrained)
net_encoder = Resnet(orig_resnet)
elif arch == 'resnet101_dilated8':
orig_resnet = resnet.__dict__['resnet101'](pretrained=pretrained)
net_encoder = ResnetDilated(orig_resnet,
dilate_scale=8)
elif arch == 'resnet101_dilated16':
orig_resnet = resnet.__dict__['resnet101'](pretrained=pretrained)
net_encoder = ResnetDilated(orig_resnet,
dilate_scale=16)
elif arch == 'resnext101':
orig_resnext = resnext.__dict__['resnext101'](pretrained=pretrained)
net_encoder = Resnet(orig_resnext) # we can still use class Resnet
else:
raise Exception('Architecture undefined!')
# net_encoder.apply(self.weights_init)
if len(weights) > 0:
# print('Loading weights for net_encoder')
net_encoder.load_state_dict(
torch.load(weights, map_location=lambda storage, loc: storage), strict=False)
return net_encoder
def build_decoder(self, arch='ppm_bilinear_deepsup',
fc_dim=512, num_class=150,
weights='', inference=False, use_softmax=False):
if arch == 'c1_bilinear_deepsup':
net_decoder = C1BilinearDeepSup(
num_class=num_class,
fc_dim=fc_dim,
inference=inference,
use_softmax=use_softmax)
elif arch == 'c1_bilinear':
net_decoder = C1Bilinear(
num_class=num_class,
fc_dim=fc_dim,
inference=inference,
use_softmax=use_softmax)
elif arch == 'ppm_bilinear':
net_decoder = PPMBilinear(
num_class=num_class,
fc_dim=fc_dim,
inference=inference,
use_softmax=use_softmax)
elif arch == 'ppm_bilinear_deepsup':
net_decoder = PPMBilinearDeepsup(
num_class=num_class,
fc_dim=fc_dim,
inference=inference,
use_softmax=use_softmax)
elif arch == 'upernet_lite':
net_decoder = UPerNet(
num_class=num_class,
fc_dim=fc_dim,
inference=inference,
use_softmax=use_softmax,
fpn_dim=256)
elif arch == 'upernet':
net_decoder = UPerNet(
num_class=num_class,
fc_dim=fc_dim,
inference=inference,
use_softmax=use_softmax,
fpn_dim=512)
elif arch == 'upernet_tmp':
net_decoder = UPerNetTmp(
num_class=num_class,
fc_dim=fc_dim,
inference=inference,
use_softmax=use_softmax,
fpn_dim=512)
else:
raise Exception('Architecture undefined!')
net_decoder.apply(self.weights_init)
if len(weights) > 0:
# print('Loading weights for net_decoder')
net_decoder.load_state_dict(
torch.load(weights, map_location=lambda storage, loc: storage), strict=False)
return net_decoder
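# Illustrative sketch of how the builder is typically used (argument values are
# assumptions): with an empty `weights` string the encoder falls back to an
# ImageNet-pretrained backbone, and fc_dim should match the encoder's output channels
# (2048 for the resnet50/resnet101 variants here).
#
#   builder = ModelBuilder()
#   net_encoder = builder.build_encoder(arch='resnet50_dilated8', fc_dim=2048)
#   net_decoder = builder.build_decoder(arch='ppm_bilinear_deepsup',
#                                       fc_dim=2048, num_class=150)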
class Resnet(nn.Module):
def __init__(self, orig_resnet):
super(Resnet, self).__init__()
# take pretrained resnet, except AvgPool and FC
self.conv1 = orig_resnet.conv1
self.bn1 = orig_resnet.bn1
self.relu1 = orig_resnet.relu1
self.conv2 = orig_resnet.conv2
self.bn2 = orig_resnet.bn2
self.relu2 = orig_resnet.relu2
self.conv3 = orig_resnet.conv3
self.bn3 = orig_resnet.bn3
self.relu3 = orig_resnet.relu3
self.maxpool = orig_resnet.maxpool
self.layer1 = orig_resnet.layer1
self.layer2 = orig_resnet.layer2
self.layer3 = orig_resnet.layer3
self.layer4 = orig_resnet.layer4
def forward(self, x, return_feature_maps=False):
conv_out = []
x = self.relu1(self.bn1(self.conv1(x)))
x = self.relu2(self.bn2(self.conv2(x)))
x = self.relu3(self.bn3(self.conv3(x)))
x = self.maxpool(x)
x = self.layer1(x); conv_out.append(x);
x = self.layer2(x); conv_out.append(x);
x = self.layer3(x); conv_out.append(x);
x = self.layer4(x); conv_out.append(x);
if return_feature_maps:
return conv_out
return [x]
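# Illustrative sketch: with return_feature_maps=True the wrapper returns the outputs of
# layer1..layer4. For a resnet50 backbone and a hypothetical 224x224 input that gives
# 256, 512, 1024 and 2048 channels at strides 4, 8, 16 and 32:
#
#   enc = Resnet(resnet.__dict__['resnet50'](pretrained=False))
#   feats = enc(torch.randn(1, 3, 224, 224), return_feature_maps=True)
#   [tuple(f.shape) for f in feats]
#   # -> [(1, 256, 56, 56), (1, 512, 28, 28), (1, 1024, 14, 14), (1, 2048, 7, 7)]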
class ResnetDilated(nn.Module):
def __init__(self, orig_resnet, dilate_scale=8):
super(ResnetDilated, self).__init__()
from functools import partial
if dilate_scale == 8:
orig_resnet.layer3.apply(
partial(self._nostride_dilate, dilate=2))
orig_resnet.layer4.apply(
partial(self._nostride_dilate, dilate=4))
elif dilate_scale == 16:
orig_resnet.layer4.apply(
partial(self._nostride_dilate, dilate=2))
# take pretrained resnet, except AvgPool and FC
self.conv1 = orig_resnet.conv1
self.bn1 = orig_resnet.bn1
self.relu1 = orig_resnet.relu1
self.conv2 = orig_resnet.conv2
self.bn2 = orig_resnet.bn2
self.relu2 = orig_resnet.relu2
self.conv3 = orig_resnet.conv3
self.bn3 = orig_resnet.bn3
self.relu3 = orig_resnet.relu3
self.maxpool = orig_resnet.maxpool
self.layer1 = orig_resnet.layer1
self.layer2 = orig_resnet.layer2
self.layer3 = orig_resnet.layer3
self.layer4 = orig_resnet.layer4
def _nostride_dilate(self, m, dilate):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
# the convolution with stride
if m.stride == (2, 2):
m.stride = (1, 1)
if m.kernel_size == (3, 3):
m.dilation = (dilate//2, dilate//2)
m.padding = (dilate//2, dilate//2)
            # other convolutions
else:
if m.kernel_size == (3, 3):
m.dilation = (dilate, dilate)
m.padding = (dilate, dilate)
def forward(self, x, return_feature_maps=False):
conv_out = []
x = self.relu1(self.bn1(self.conv1(x)))
x = self.relu2(self.bn2(self.conv2(x)))
x = self.relu3(self.bn3(self.conv3(x)))
x = self.maxpool(x)
x = self.layer1(x); conv_out.append(x);
x = self.layer2(x); conv_out.append(x);
x = self.layer3(x); conv_out.append(x);
x = self.layer4(x); conv_out.append(x);
if return_feature_maps:
return conv_out
return [x]
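# Illustrative sketch: replacing strides with dilations in the later stages keeps more
# spatial resolution, so with dilate_scale=8 the deepest feature map of a hypothetical
# 224x224 input stays at 28x28 (stride 8) instead of 7x7 (stride 32):
#
#   enc = ResnetDilated(resnet.__dict__['resnet50'](pretrained=False), dilate_scale=8)
#   feats = enc(torch.randn(1, 3, 224, 224), return_feature_maps=True)
#   feats[-1].shape   # -> torch.Size([1, 2048, 28, 28])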
# last conv, bilinear upsample
class C1BilinearDeepSup(nn.Module):
def __init__(self, num_class=150, fc_dim=2048, inference=False, use_softmax=False):
super(C1BilinearDeepSup, self).__init__()
self.use_softmax = use_softmax
self.inference = inference
self.cbr = conv3x3_bn_relu(fc_dim, fc_dim // 4, 1)
self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1)
# last conv
self.conv_last = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)
self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)
def forward(self, conv_out, segSize=None):
conv5 = conv_out[-1]
x = self.cbr(conv5)
x = self.conv_last(x)
if self.inference or self.use_softmax: # is True during inference
x = nn.functional.interpolate(
x, size=segSize, mode='bilinear', align_corners=False)
if self.use_softmax:
x = nn.functional.softmax(x, dim=1)
return x
# deep sup
conv4 = conv_out[-2]
_ = self.cbr_deepsup(conv4)
_ = self.conv_last_deepsup(_)
x = nn.functional.log_softmax(x, dim=1)
_ = nn.functional.log_softmax(_, dim=1)
return (x, _)
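# Illustrative sketch: in training mode (segSize=None, use_softmax=False) the decoder
# returns a pair of log-probability maps at the encoder's output resolution, the main
# prediction plus the deep-supervision prediction from the second-to-last stage:
#
#   decoder = C1BilinearDeepSup(num_class=150, fc_dim=2048)
#   feats = [torch.randn(1, 1024, 28, 28), torch.randn(1, 2048, 28, 28)]  # conv4, conv5
#   pred, pred_deepsup = decoder(feats)   # each of shape (1, 150, 28, 28)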
# last conv, bilinear upsample
class C1Bilinear(nn.Module):
def __init__(self, num_class=150, fc_dim=2048, inference=False, use_softmax=False):
super(C1Bilinear, self).__init__()
self.use_softmax = use_softmax
self.inference = inference
self.cbr = conv3x3_bn_relu(fc_dim, fc_dim // 4, 1)
# last conv
self.conv_last = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)
def forward(self, conv_out, segSize=None):
conv5 = conv_out[-1]
x = self.cbr(conv5)
x = self.conv_last(x)
if self.inference or self.use_softmax: # is True during inference
x = nn.functional.interpolate(
x, size=segSize, mode='bilinear', align_corners=False)
if self.use_softmax:
x = nn.functional.softmax(x, dim=1)
else:
x = nn.functional.log_softmax(x, dim=1)
return x
# pyramid pooling, bilinear upsample
class PPMBilinear(nn.Module):
def __init__(self, num_class=150, fc_dim=4096,
inference=False, use_softmax=False, pool_scales=(1, 2, 3, 6)):
super(PPMBilinear, self).__init__()
self.use_softmax = use_softmax
self.inference = inference
self.ppm = []
for scale in pool_scales:
self.ppm.append(nn.Sequential(
nn.AdaptiveAvgPool2d(scale),
nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False),
SynchronizedBatchNorm2d(512),
nn.ReLU(inplace=True)
))
self.ppm = nn.ModuleList(self.ppm)
self.conv_last = nn.Sequential(
nn.Conv2d(fc_dim+len(pool_scales)*512, 512,
kernel_size=3, padding=1, bias=False),
SynchronizedBatchNorm2d(512),
nn.ReLU(inplace=True),
nn.Dropout2d(0.1),
nn.Conv2d(512, num_class, kernel_size=1)
)
def forward(self, conv_out, segSize=None):
conv5 = conv_out[-1]
input_size = conv5.size()
ppm_out = [conv5]
for pool_scale in self.ppm:
ppm_out.append(nn.functional.interpolate(
pool_scale(conv5),
(input_size[2], input_size[3]),
mode='bilinear', align_corners=False))
ppm_out = torch.cat(ppm_out, 1)
x = self.conv_last(ppm_out)
if self.inference or self.use_softmax: # is True during inference
x = nn.functional.interpolate(
x, size=segSize, mode='bilinear', align_corners=False)
if self.use_softmax:
x = nn.functional.softmax(x, dim=1)
else:
x = nn.functional.log_softmax(x, dim=1)
return x
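# Illustrative sketch: the PPM branch pools conv5 to 1x1, 2x2, 3x3 and 6x6, projects each
# scale to 512 channels, upsamples back and concatenates with conv5, so conv_last sees
# fc_dim + 4*512 input channels. A batch size > 1 is assumed so the 1x1-pooled branch
# passes batch norm in training mode:
#
#   decoder = PPMBilinear(num_class=150, fc_dim=2048)
#   x = decoder([torch.randn(2, 2048, 28, 28)])   # log-probs of shape (2, 150, 28, 28)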
# pyramid pooling, bilinear upsample
class PPMBilinearDeepsup(nn.Module):
def __init__(self, num_class=150, fc_dim=4096,
inference=False, use_softmax=False, pool_scales=(1, 2, 3, 6)):
super(PPMBilinearDeepsup, self).__init__()
self.use_softmax = use_softmax
self.inference = inference
self.ppm = []
for scale in pool_scales:
self.ppm.append(nn.Sequential(
nn.AdaptiveAvgPool2d(scale),
nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False),
SynchronizedBatchNorm2d(512),
nn.ReLU(inplace=True)
))
self.ppm = nn.ModuleList(self.ppm)
self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1)
self.conv_last = nn.Sequential(
nn.Conv2d(fc_dim+len(pool_scales)*512, 512,
kernel_size=3, padding=1, bias=False),
SynchronizedBatchNorm2d(512),
nn.ReLU(inplace=True),
nn.Dropout2d(0.1),
nn.Conv2d(512, num_class, kernel_size=1)
)
self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)
self.dropout_deepsup = nn.Dropout2d(0.1)
def forward(self, conv_out, segSize=None):
conv5 = conv_out[-1]
input_size = conv5.size()
ppm_out = [conv5]
for pool_scale in self.ppm:
ppm_out.append(nn.functional.interpolate(
pool_scale(conv5),
(input_size[2], input_size[3]),
mode='bilinear', align_corners=False))
ppm_out = torch.cat(ppm_out, 1)
x = self.conv_last(ppm_out)
if self.inference or self.use_softmax: # is True during inference
x = nn.functional.interpolate(
x, size=segSize, mode='bilinear', align_corners=False)
if self.use_softmax:
x = nn.functional.softmax(x, dim=1)
return x
# deep sup
conv4 = conv_out[-2]
_ = self.cbr_deepsup(conv4)
_ = self.dropout_deepsup(_)
_ = self.conv_last_deepsup(_)
x = nn.functional.log_softmax(x, dim=1)
_ = nn.functional.log_softmax(_, dim=1)
return (x, _)
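# Illustrative sketch of the inference path: with use_softmax=True the prediction is
# upsampled to segSize and turned into per-class probabilities (the deep-supervision
# branch is skipped at inference):
#
#   decoder = PPMBilinearDeepsup(num_class=150, fc_dim=2048, use_softmax=True)
#   feats = [torch.randn(2, 1024, 28, 28), torch.randn(2, 2048, 28, 28)]
#   probs = decoder(feats, segSize=(224, 224))   # (2, 150, 224, 224), sums to 1 over dim 1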
# upernet
class UPerNet(nn.Module):
def __init__(self, num_class=150, fc_dim=4096,
inference=False, use_softmax=False, pool_scales=(1, 2, 3, 6),
fpn_inplanes=(256,512,1024,2048), fpn_dim=256):
super(UPerNet, self).__init__()
self.use_softmax = use_softmax
self.inference = inference
# PPM Module
self.ppm_pooling = []
self.ppm_conv = []
for scale in pool_scales:
self.ppm_pooling.append(nn.AdaptiveAvgPool2d(scale))
self.ppm_conv.append(nn.Sequential(
nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False),
SynchronizedBatchNorm2d(512),
nn.ReLU(inplace=True)
))
self.ppm_pooling = nn.ModuleList(self.ppm_pooling)
self.ppm_conv = nn.ModuleList(self.ppm_conv)
self.ppm_last_conv = conv3x3_bn_relu(fc_dim + len(pool_scales)*512, fpn_dim, 1)
# FPN Module
self.fpn_in = []
for fpn_inplane in fpn_inplanes[:-1]: # skip the top layer
self.fpn_in.append(nn.Sequential(
nn.Conv2d(fpn_inplane, fpn_dim, kernel_size=1, bias=False),
SynchronizedBatchNorm2d(fpn_dim),
nn.ReLU(inplace=True)
))
self.fpn_in = nn.ModuleList(self.fpn_in)
self.fpn_out = []
for i in range(len(fpn_inplanes) - 1): # skip the top layer
self.fpn_out.append(nn.Sequential(
conv3x3_bn_relu(fpn_dim, fpn_dim, 1),
))
self.fpn_out = nn.ModuleList(self.fpn_out)
self.conv_last = nn.Sequential(
conv3x3_bn_relu(len(fpn_inplanes) * fpn_dim, fpn_dim, 1),
nn.Conv2d(fpn_dim, num_class, kernel_size=1)
)
def forward(self, conv_out, segSize=None):
conv5 = conv_out[-1]
input_size = conv5.size()
ppm_out = [conv5]
for pool_scale, pool_conv in zip(self.ppm_pooling, self.ppm_conv):
            ppm_out.append(pool_conv(nn.functional.interpolate(
pool_scale(conv5),
(input_size[2], input_size[3]),
mode='bilinear', align_corners=False)))
ppm_out = torch.cat(ppm_out, 1)
f = self.ppm_last_conv(ppm_out)
fpn_feature_list = [f]
for i in reversed(range(len(conv_out) - 1)):
conv_x = conv_out[i]
conv_x = self.fpn_in[i](conv_x) # lateral branch
f = nn.functional.interpolate(
f, size=conv_x.size()[2:], mode='bilinear', align_corners=False) # top-down branch
f = conv_x + f
fpn_feature_list.append(self.fpn_out[i](f))
fpn_feature_list.reverse() # [P2 - P5]
output_size = fpn_feature_list[0].size()[2:]
fusion_list = [fpn_feature_list[0]]
for i in range(1, len(fpn_feature_list)):
fusion_list.append(nn.functional.interpolate(
fpn_feature_list[i],
output_size,
mode='bilinear', align_corners=False))
fusion_out = torch.cat(fusion_list, 1)
x = self.conv_last(fusion_out)
if self.inference or self.use_softmax: # is True during inference
x = nn.functional.interpolate(
x, size=segSize, mode='bilinear', align_corners=False)
if self.use_softmax:
x = nn.functional.softmax(x, dim=1)
return x
x = nn.functional.log_softmax(x, dim=1)
return x
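# Illustrative sketch: UPerNet consumes all four encoder stages (matching the default
# fpn_inplanes of 256/512/1024/2048), fuses them through the FPN and, during training,
# predicts log-probabilities at the stride-4 resolution:
#
#   decoder = UPerNet(num_class=150, fc_dim=2048, fpn_dim=256)
#   feats = [torch.randn(2, c, 224 // s, 224 // s)
#            for c, s in zip((256, 512, 1024, 2048), (4, 8, 16, 32))]
#   x = decoder(feats)   # shape (2, 150, 56, 56)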