import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.nn import SyncBatchNorm as BatchNorm2d
from functools import partial
import re

from models.base_models.resnet import resnet101, resnet18, resnet50
from utils.seg_opr.conv_2_5d import Conv2_5D_depth, Conv2_5D_disp
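# DeepLab v3+ segmentation models built on ResNet-18/50/101 backbones.
# The last backbone stage is dilated instead of strided, an ASPP + low-level
# feature head produces the decoder features, and get_params() splits the
# parameters into groups so biases and norm-layer weights can skip weight decay.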

class DeepLabV3p_r18(nn.Module):
    def __init__(self, num_classes, config):
        super(DeepLabV3p_r18, self).__init__()
        self.norm_layer = BatchNorm2d
        self.backbone = resnet18(config.pretrained_model_r18, norm_layer=self.norm_layer,
                                 bn_eps=config.bn_eps,
                                 bn_momentum=config.bn_momentum,
                                 deep_stem=False, stem_width=64)

        # Replace the stride-2 convolutions of the last backbone stage with
        # dilated convolutions so the feature map keeps a higher resolution
        # (the usual DeepLab output-stride trick).
        self.dilate = 2
        for m in self.backbone.layer4.children():
            m.apply(partial(self._nostride_dilate, dilate=self.dilate))
            self.dilate *= 2

        self.head = Head('r18', num_classes, self.norm_layer, config.bn_momentum)
        self.business_layer = []
        self.business_layer.append(self.head)

        self.classifier = nn.Conv2d(256, num_classes, kernel_size=1, bias=True)
        self.business_layer.append(self.classifier)
        init_weight(self.business_layer, nn.init.kaiming_normal_,
                    BatchNorm2d, config.bn_eps, config.bn_momentum,
                    mode='fan_in', nonlinearity='relu')
        init_weight(self.classifier, nn.init.kaiming_normal_,
                    BatchNorm2d, config.bn_eps, config.bn_momentum,
                    mode='fan_in', nonlinearity='relu')

    def forward(self, data, get_sup_loss=False, gt=None, criterion=None):
        # data is a list/tuple whose first element is the image batch.
        data = data[0]
        blocks = self.backbone(data)
        v3plus_feature = self.head(blocks)

        pred = self.classifier(v3plus_feature)

        # Upsample the logits back to the input resolution.
        b, c, h, w = data.shape
        pred = F.interpolate(pred, size=(h, w), mode='bilinear', align_corners=True)
        if not self.training:
            return pred
        else:
            if get_sup_loss:
                return pred, self.get_sup_loss(pred, gt, criterion)
            else:
                return pred

    def get_sup_loss(self, pred, gt, criterion):
        # gt may cover only the first part of the batch (e.g. the labelled
        # samples), so crop pred to match before computing the loss.
        pred = pred[:gt.shape[0]]
        return criterion(pred, gt)

    def _nostride_dilate(self, m, dilate):
        if isinstance(m, nn.Conv2d):
            if m.stride == (2, 2):
                m.stride = (1, 1)
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate, dilate)
                    m.padding = (dilate, dilate)
            else:
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate, dilate)
                    m.padding = (dilate, dilate)

    def get_params(self):
        # Three groups: backbone weights, all no-decay parameters (biases and
        # norm-layer affine parameters), and decoder/classifier weights.
        param_groups = [[], [], []]
        enc, enc_no_decay = group_weight(self.backbone, self.norm_layer)
        param_groups[0].extend(enc)
        param_groups[1].extend(enc_no_decay)
        dec, dec_no_decay = group_weight(self.head, self.norm_layer)
        param_groups[2].extend(dec)
        param_groups[1].extend(dec_no_decay)
        classifier, classifier_no_decay = group_weight(self.classifier, self.norm_layer)
        param_groups[2].extend(classifier)
        param_groups[1].extend(classifier_no_decay)
        return param_groups


class DeepLabV3p_r50(nn.Module):
    def __init__(self, num_classes, config):
        super(DeepLabV3p_r50, self).__init__()
        self.norm_layer = BatchNorm2d
        self.backbone = resnet50(config.pretrained_model_r50, norm_layer=self.norm_layer,
                                 bn_eps=config.bn_eps,
                                 bn_momentum=config.bn_momentum,
                                 deep_stem=True, stem_width=64)
        self.dilate = 2
        for m in self.backbone.layer4.children():
            m.apply(partial(self._nostride_dilate, dilate=self.dilate))
            self.dilate *= 2

        self.head = Head('r50', num_classes, self.norm_layer, config.bn_momentum)
        self.business_layer = []
        self.business_layer.append(self.head)

        self.classifier = nn.Conv2d(256, num_classes, kernel_size=1, bias=True)
        self.business_layer.append(self.classifier)
        init_weight(self.business_layer, nn.init.kaiming_normal_,
                    BatchNorm2d, config.bn_eps, config.bn_momentum,
                    mode='fan_in', nonlinearity='relu')
        init_weight(self.classifier, nn.init.kaiming_normal_,
                    BatchNorm2d, config.bn_eps, config.bn_momentum,
                    mode='fan_in', nonlinearity='relu')

    def forward(self, data, get_sup_loss=False, gt=None, criterion=None):
        data = data[0]
        blocks = self.backbone(data)
        v3plus_feature = self.head(blocks)

        pred = self.classifier(v3plus_feature)

        b, c, h, w = data.shape
        pred = F.interpolate(pred, size=(h, w), mode='bilinear', align_corners=True)
        if not self.training:
            return pred
        else:
            if get_sup_loss:
                return pred, self.get_sup_loss(pred, gt, criterion)
            else:
                return pred

    def get_sup_loss(self, pred, gt, criterion):
        pred = pred[:gt.shape[0]]
        return criterion(pred, gt)

    def get_params(self):
        param_groups = [[], [], []]
        enc, enc_no_decay = group_weight(self.backbone, self.norm_layer)
        param_groups[0].extend(enc)
        param_groups[1].extend(enc_no_decay)
        dec, dec_no_decay = group_weight(self.head, self.norm_layer)
        param_groups[2].extend(dec)
        param_groups[1].extend(dec_no_decay)
        classifier, classifier_no_decay = group_weight(self.classifier, self.norm_layer)
        param_groups[2].extend(classifier)
        param_groups[1].extend(classifier_no_decay)
        return param_groups

    def _nostride_dilate(self, m, dilate):
        if isinstance(m, nn.Conv2d):
            if m.stride == (2, 2):
                m.stride = (1, 1)
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate, dilate)
                    m.padding = (dilate, dilate)
            else:
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate, dilate)
                    m.padding = (dilate, dilate)


class DeepLabV3p_r101(nn.Module):
    def __init__(self, num_classes, config):
        super(DeepLabV3p_r101, self).__init__()
        self.norm_layer = BatchNorm2d
        self.backbone = resnet101(config.pretrained_model_r101, norm_layer=self.norm_layer,
                                  bn_eps=config.bn_eps,
                                  bn_momentum=config.bn_momentum,
                                  deep_stem=True, stem_width=64)
        self.dilate = 2
        for m in self.backbone.layer4.children():
            m.apply(partial(self._nostride_dilate, dilate=self.dilate))
            self.dilate *= 2

        # ResNet-101 produces the same feature dimensions as ResNet-50
        # (256-channel low-level and 2048-channel high-level features),
        # so the 'r50' head is reused here.
        self.head = Head('r50', num_classes, self.norm_layer, config.bn_momentum)
        self.business_layer = []
        self.business_layer.append(self.head)

        self.classifier = nn.Conv2d(256, num_classes, kernel_size=1, bias=True)
        self.business_layer.append(self.classifier)
        init_weight(self.business_layer, nn.init.kaiming_normal_,
                    BatchNorm2d, config.bn_eps, config.bn_momentum,
                    mode='fan_in', nonlinearity='relu')
        init_weight(self.classifier, nn.init.kaiming_normal_,
                    BatchNorm2d, config.bn_eps, config.bn_momentum,
                    mode='fan_in', nonlinearity='relu')

    def forward(self, data, get_sup_loss=False, gt=None, criterion=None):
        data = data[0]
        blocks = self.backbone(data)
        v3plus_feature = self.head(blocks)

        pred = self.classifier(v3plus_feature)

        b, c, h, w = data.shape
        pred = F.interpolate(pred, size=(h, w), mode='bilinear', align_corners=True)
        if not self.training:
            return pred
        else:
            if get_sup_loss:
                return pred, self.get_sup_loss(pred, gt, criterion)
            else:
                return pred

    def get_sup_loss(self, pred, gt, criterion):
        pred = pred[:gt.shape[0]]
        return criterion(pred, gt)

    def get_params(self):
        param_groups = [[], [], []]
        enc, enc_no_decay = group_weight(self.backbone, self.norm_layer)
        param_groups[0].extend(enc)
        param_groups[1].extend(enc_no_decay)
        dec, dec_no_decay = group_weight(self.head, self.norm_layer)
        param_groups[2].extend(dec)
        param_groups[1].extend(dec_no_decay)
        classifier, classifier_no_decay = group_weight(self.classifier, self.norm_layer)
        param_groups[2].extend(classifier)
        param_groups[1].extend(classifier_no_decay)
        return param_groups

    def _nostride_dilate(self, m, dilate):
        if isinstance(m, nn.Conv2d):
            if m.stride == (2, 2):
                m.stride = (1, 1)
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate, dilate)
                    m.padding = (dilate, dilate)
            else:
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate, dilate)
                    m.padding = (dilate, dilate)


class ASPP(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 dilation_rates=(12, 24, 36),
                 hidden_channels=256,
                 norm_act=nn.BatchNorm2d,
                 pooling_size=None):
        super(ASPP, self).__init__()
        self.pooling_size = pooling_size

        # Parallel branches: a 1x1 convolution plus three 3x3 atrous convolutions.
        self.map_convs = nn.ModuleList([
            nn.Conv2d(in_channels, hidden_channels, 1, bias=False),
            nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilation_rates[0],
                      padding=dilation_rates[0]),
            nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilation_rates[1],
                      padding=dilation_rates[1]),
            nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilation_rates[2],
                      padding=dilation_rates[2])
        ])
        self.map_bn = norm_act(hidden_channels * 4)

        self.global_pooling_conv = nn.Conv2d(in_channels, hidden_channels, 1, bias=False)
        self.global_pooling_bn = norm_act(hidden_channels)

        self.red_conv = nn.Conv2d(hidden_channels * 4, out_channels, 1, bias=False)
        self.pool_red_conv = nn.Conv2d(hidden_channels, out_channels, 1, bias=False)
        self.red_bn = norm_act(out_channels)

        self.leak_relu = nn.LeakyReLU()

    def forward(self, x):
        # Atrous branches.
        out = torch.cat([m(x) for m in self.map_convs], dim=1)
        out = self.map_bn(out)
        out = self.leak_relu(out)
        out = self.red_conv(out)

        # Image-level branch: global average pooling followed by a 1x1 convolution.
        pool = self._global_pooling(x)
        pool = self.global_pooling_conv(pool)
        pool = self.global_pooling_bn(pool)
        pool = self.leak_relu(pool)
        pool = self.pool_red_conv(pool)
        if self.training or self.pooling_size is None:
            pool = pool.repeat(1, 1, x.size(2), x.size(3))

        out += pool
        out = self.red_bn(out)
        out = self.leak_relu(out)
        return out

    def _global_pooling(self, x):
        pool = x.view(x.size(0), x.size(1), -1).mean(dim=-1)
        pool = pool.view(x.size(0), x.size(1), 1, 1)
        return pool


class Head(nn.Module):
    def __init__(self, base_model, classify_classes, norm_act=nn.BatchNorm2d, bn_momentum=0.0003):
        super(Head, self).__init__()

        self.classify_classes = classify_classes
        if base_model == 'r18':
            self.aspp = ASPP(512, 256, [6, 12, 18], norm_act=norm_act)

            self.reduce = nn.Sequential(
                nn.Conv2d(64, 48, 1, bias=False),
                norm_act(48, momentum=bn_momentum),
                nn.ReLU(),
            )
        elif base_model == 'r50':
            self.aspp = ASPP(2048, 256, [6, 12, 18], norm_act=norm_act)
            self.reduce = nn.Sequential(
                nn.Conv2d(256, 48, 1, bias=False),
                norm_act(48, momentum=bn_momentum),
                nn.ReLU(),
            )
        else:
            raise Exception(f"Head not implemented for {base_model}")

        # 304 = 256 ASPP channels + 48 reduced low-level channels.
        self.last_conv = nn.Sequential(nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False),
                                       norm_act(256, momentum=bn_momentum),
                                       nn.ReLU(),
                                       nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
                                       norm_act(256, momentum=bn_momentum),
                                       nn.ReLU(),
                                       )

    def forward(self, f_list):
        f = f_list[-1]
        f = self.aspp(f)

        # Fuse the ASPP output with the low-level features from the first backbone stage.
        low_level_features = f_list[0]
        low_h, low_w = low_level_features.size(2), low_level_features.size(3)
        low_level_features = self.reduce(low_level_features)

        f = F.interpolate(f, size=(low_h, low_w), mode='bilinear', align_corners=True)
        f = torch.cat((f, low_level_features), dim=1)
        f = self.last_conv(f)

        return f


def group_weight(module, norm_layer):
    """Split a module's parameters into weight-decay and no-decay groups.

    Convolution/linear/embedding weights receive weight decay; biases and the
    affine parameters of normalization layers do not.
    """
    group_decay = []
    group_no_decay = []
    for m in module.modules():
        if isinstance(m, nn.Linear):
            group_decay.append(m.weight)
            if m.bias is not None:
                group_no_decay.append(m.bias)
        elif isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
            group_decay.append(m.weight)
            if m.bias is not None:
                group_no_decay.append(m.bias)
        elif isinstance(m, Conv2_5D_depth):
            group_decay.append(m.weight_0)
            group_decay.append(m.weight_1)
            group_decay.append(m.weight_2)
            if m.bias is not None:
                group_no_decay.append(m.bias)
        elif isinstance(m, Conv2_5D_disp):
            group_decay.append(m.weight_0)
            group_decay.append(m.weight_1)
            group_decay.append(m.weight_2)
            if m.bias is not None:
                group_no_decay.append(m.bias)
        elif isinstance(m, norm_layer) or isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.BatchNorm2d) \
                or isinstance(m, nn.BatchNorm3d) or isinstance(m, nn.GroupNorm):
            if m.weight is not None:
                group_no_decay.append(m.weight)
            if m.bias is not None:
                group_no_decay.append(m.bias)
        elif isinstance(m, nn.Parameter):
            group_decay.append(m)
        elif isinstance(m, nn.Embedding):
            group_decay.append(m.weight)
    # Every parameter of the module must land in exactly one of the two groups.
    assert len(list(module.parameters())) == len(group_decay) + len(group_no_decay)
    return group_decay, group_no_decay


def __init_weight(feature, conv_init, norm_layer, bn_eps, bn_momentum,
                  **kwargs):
    for name, m in feature.named_modules():
        if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
            conv_init(m.weight, **kwargs)
        elif isinstance(m, Conv2_5D_depth):
            conv_init(m.weight_0, **kwargs)
            conv_init(m.weight_1, **kwargs)
            conv_init(m.weight_2, **kwargs)
        elif isinstance(m, Conv2_5D_disp):
            conv_init(m.weight_0, **kwargs)
            conv_init(m.weight_1, **kwargs)
            conv_init(m.weight_2, **kwargs)
        elif isinstance(m, norm_layer):
            m.eps = bn_eps
            m.momentum = bn_momentum
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)


def init_weight(module_list, conv_init, norm_layer, bn_eps, bn_momentum,
                **kwargs):
    if isinstance(module_list, list):
        for feature in module_list:
            __init_weight(feature, conv_init, norm_layer, bn_eps, bn_momentum,
                          **kwargs)
    else:
        __init_weight(module_list, conv_init, norm_layer, bn_eps, bn_momentum,
                      **kwargs)
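
# Minimal usage sketch (kept as comments; not part of the training pipeline).
# It assumes a config object exposing pretrained_model_r50, bn_eps and
# bn_momentum (as used above), and that the backbone constructor accepts None
# to skip loading pretrained weights. SyncBatchNorm is intended for distributed
# training, so eval mode is used here to allow a single-process forward pass.
#
#   from types import SimpleNamespace
#
#   config = SimpleNamespace(pretrained_model_r50=None, bn_eps=1e-5, bn_momentum=0.1)
#   model = DeepLabV3p_r50(num_classes=19, config=config).eval()
#
#   # forward() expects a list/tuple whose first element is the image batch.
#   images = torch.randn(2, 3, 512, 512)
#   with torch.no_grad():
#       logits = model([images])   # shape (2, 19, 512, 512)
#
#   # get_params() returns [backbone, no-decay, decoder+classifier] groups,
#   # e.g. for an optimizer with per-group learning rate and weight decay:
#   groups = model.get_params()
#   optim = torch.optim.SGD([
#       {'params': groups[0], 'lr': 1e-3},
#       {'params': groups[1], 'lr': 1e-3, 'weight_decay': 0.0},
#       {'params': groups[2], 'lr': 1e-2},
#   ], lr=1e-3, momentum=0.9, weight_decay=1e-4)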