import torch
import torch.nn as nn
import torch.nn.functional as F
# from torch.nn import BatchNorm2d
from torch.nn import SyncBatchNorm as BatchNorm2d
from functools import partial
import re

from models.base_models.resnet import resnet101, resnet18, resnet50
from utils.seg_opr.conv_2_5d import Conv2_5D_depth, Conv2_5D_disp


class DeepLabV3p_r18(nn.Module):
    def __init__(self, num_classes, config):
        super(DeepLabV3p_r18, self).__init__()
        self.norm_layer = BatchNorm2d

        self.backbone = resnet18(config.pretrained_model_r18,
                                 norm_layer=self.norm_layer,
                                 bn_eps=config.bn_eps,
                                 bn_momentum=config.bn_momentum,
                                 deep_stem=False, stem_width=64)

        # Replace the strided convolutions in layer4 with dilated ones so the
        # backbone keeps a higher output resolution.
        self.dilate = 2
        for m in self.backbone.layer4.children():
            m.apply(partial(self._nostride_dilate, dilate=self.dilate))
            self.dilate *= 2

        self.head = Head('r18', num_classes, self.norm_layer, config.bn_momentum)

        self.business_layer = []
        self.business_layer.append(self.head)

        self.classifier = nn.Conv2d(256, num_classes, kernel_size=1, bias=True)
        self.business_layer.append(self.classifier)

        init_weight(self.business_layer, nn.init.kaiming_normal_,
                    BatchNorm2d, config.bn_eps, config.bn_momentum,
                    mode='fan_in', nonlinearity='relu')
        init_weight(self.classifier, nn.init.kaiming_normal_,
                    BatchNorm2d, config.bn_eps, config.bn_momentum,
                    mode='fan_in', nonlinearity='relu')

    def forward(self, data, get_sup_loss=False, gt=None, criterion=None):
        data = data[0]  # rgb is the first element in the list
        blocks = self.backbone(data)
        v3plus_feature = self.head(blocks)  # (b, c, h, w)
        b, c, h, w = v3plus_feature.shape

        pred = self.classifier(v3plus_feature)

        b, c, h, w = data.shape
        pred = F.interpolate(pred, size=(h, w), mode='bilinear', align_corners=True)

        if not self.training:
            # return pred for evaluation
            return pred
        else:
            if get_sup_loss:
                return pred, self.get_sup_loss(pred, gt, criterion)
            else:
                return pred

    def get_sup_loss(self, pred, gt, criterion):
        # Compute the supervised loss only for the examples in the batch that
        # have ground truth; unlabeled data contributes no supervised loss.
        pred = pred[:gt.shape[0]]
        return criterion(pred, gt)

    # @staticmethod
    def _nostride_dilate(self, m, dilate):
        # Turn strided 3x3 convolutions into dilated, stride-1 convolutions so
        # the spatial resolution of the last stage is preserved.
        if isinstance(m, nn.Conv2d):
            if m.stride == (2, 2):
                m.stride = (1, 1)
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate, dilate)
                    m.padding = (dilate, dilate)
            else:
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate, dilate)
                    m.padding = (dilate, dilate)

    def get_params(self):
        param_groups = [[], [], []]

        enc, enc_no_decay = group_weight(self.backbone, self.norm_layer)
        param_groups[0].extend(enc)
        param_groups[1].extend(enc_no_decay)

        dec, dec_no_decay = group_weight(self.head, self.norm_layer)
        param_groups[2].extend(dec)
        param_groups[1].extend(dec_no_decay)

        classifier, classifier_no_decay = group_weight(self.classifier, self.norm_layer)
        param_groups[2].extend(classifier)
        param_groups[1].extend(classifier_no_decay)

        return param_groups


class DeepLabV3p_r50(nn.Module):
    def __init__(self, num_classes, config):
        super(DeepLabV3p_r50, self).__init__()
        self.norm_layer = BatchNorm2d

        self.backbone = resnet50(config.pretrained_model_r50,
                                 norm_layer=self.norm_layer,
                                 bn_eps=config.bn_eps,
                                 bn_momentum=config.bn_momentum,
                                 deep_stem=True, stem_width=64)

        self.dilate = 2
        for m in self.backbone.layer4.children():
            m.apply(partial(self._nostride_dilate, dilate=self.dilate))
            self.dilate *= 2

        self.head = Head('r50', num_classes, self.norm_layer, config.bn_momentum)

        self.business_layer = []
        self.business_layer.append(self.head)

        self.classifier = nn.Conv2d(256, num_classes, kernel_size=1, bias=True)
        self.business_layer.append(self.classifier)

        init_weight(self.business_layer, nn.init.kaiming_normal_,
                    BatchNorm2d, config.bn_eps, config.bn_momentum,
                    mode='fan_in', nonlinearity='relu')
        init_weight(self.classifier, nn.init.kaiming_normal_,
                    BatchNorm2d, config.bn_eps, config.bn_momentum,
                    mode='fan_in', nonlinearity='relu')

    def forward(self, data, get_sup_loss=False, gt=None, criterion=None):
        data = data[0]  # rgb is the first element in the list
        blocks = self.backbone(data)
        v3plus_feature = self.head(blocks)  # (b, c, h, w)
        b, c, h, w = v3plus_feature.shape

        pred = self.classifier(v3plus_feature)

        b, c, h, w = data.shape
        pred = F.interpolate(pred, size=(h, w), mode='bilinear', align_corners=True)

        if not self.training:
            # return pred for evaluation
            return pred
        else:
            if get_sup_loss:
                return pred, self.get_sup_loss(pred, gt, criterion)
            else:
                return pred

    def get_sup_loss(self, pred, gt, criterion):
        # Compute the supervised loss only for the examples in the batch that
        # have ground truth; unlabeled data contributes no supervised loss.
        pred = pred[:gt.shape[0]]
        return criterion(pred, gt)

    def get_params(self):
        param_groups = [[], [], []]

        enc, enc_no_decay = group_weight(self.backbone, self.norm_layer)
        param_groups[0].extend(enc)
        param_groups[1].extend(enc_no_decay)

        dec, dec_no_decay = group_weight(self.head, self.norm_layer)
        param_groups[2].extend(dec)
        param_groups[1].extend(dec_no_decay)

        classifier, classifier_no_decay = group_weight(self.classifier, self.norm_layer)
        param_groups[2].extend(classifier)
        param_groups[1].extend(classifier_no_decay)

        return param_groups

    # @staticmethod
    def _nostride_dilate(self, m, dilate):
        if isinstance(m, nn.Conv2d):
            if m.stride == (2, 2):
                m.stride = (1, 1)
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate, dilate)
                    m.padding = (dilate, dilate)
            else:
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate, dilate)
                    m.padding = (dilate, dilate)


class DeepLabV3p_r101(nn.Module):
    def __init__(self, num_classes, config):
        super(DeepLabV3p_r101, self).__init__()
        self.norm_layer = BatchNorm2d

        self.backbone = resnet101(config.pretrained_model_r101,
                                  norm_layer=self.norm_layer,
                                  bn_eps=config.bn_eps,
                                  bn_momentum=config.bn_momentum,
                                  deep_stem=True, stem_width=64)

        self.dilate = 2
        for m in self.backbone.layer4.children():
            m.apply(partial(self._nostride_dilate, dilate=self.dilate))
            self.dilate *= 2

        # ResNet-101 has the same output channel layout as ResNet-50, so the
        # 'r50' head configuration is reused here.
        self.head = Head('r50', num_classes, self.norm_layer, config.bn_momentum)

        self.business_layer = []
        self.business_layer.append(self.head)

        self.classifier = nn.Conv2d(256, num_classes, kernel_size=1, bias=True)
        self.business_layer.append(self.classifier)

        init_weight(self.business_layer, nn.init.kaiming_normal_,
                    BatchNorm2d, config.bn_eps, config.bn_momentum,
                    mode='fan_in', nonlinearity='relu')
        init_weight(self.classifier, nn.init.kaiming_normal_,
                    BatchNorm2d, config.bn_eps, config.bn_momentum,
                    mode='fan_in', nonlinearity='relu')

    def forward(self, data, get_sup_loss=False, gt=None, criterion=None):
        data = data[0]  # rgb is the first element in the list
        blocks = self.backbone(data)
        v3plus_feature = self.head(blocks)  # (b, c, h, w)
        b, c, h, w = v3plus_feature.shape

        pred = self.classifier(v3plus_feature)

        b, c, h, w = data.shape
        pred = F.interpolate(pred, size=(h, w), mode='bilinear', align_corners=True)

        if not self.training:
            # return pred for evaluation
            return pred
        else:
            if get_sup_loss:
                return pred, self.get_sup_loss(pred, gt, criterion)
            else:
                return pred

    def get_sup_loss(self, pred, gt, criterion):
        # Compute the supervised loss only for the examples in the batch that
        # have ground truth; unlabeled data contributes no supervised loss.
        pred = pred[:gt.shape[0]]
        return criterion(pred, gt)

    def get_params(self):
        param_groups = [[], [], []]

        enc, enc_no_decay = group_weight(self.backbone, self.norm_layer)
        param_groups[0].extend(enc)
        param_groups[1].extend(enc_no_decay)

        dec, dec_no_decay = group_weight(self.head, self.norm_layer)
        param_groups[2].extend(dec)
        param_groups[1].extend(dec_no_decay)

        classifier, classifier_no_decay = group_weight(self.classifier, self.norm_layer)
        param_groups[2].extend(classifier)
        param_groups[1].extend(classifier_no_decay)

        return param_groups

    # @staticmethod
    def _nostride_dilate(self, m, dilate):
        if isinstance(m, nn.Conv2d):
            if m.stride == (2, 2):
                m.stride = (1, 1)
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate, dilate)
                    m.padding = (dilate, dilate)
            else:
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate, dilate)
                    m.padding = (dilate, dilate)


class ASPP(nn.Module):
    def __init__(self, in_channels, out_channels,
                 dilation_rates=(12, 24, 36),
                 hidden_channels=256,
                 norm_act=nn.BatchNorm2d,
                 pooling_size=None):
        super(ASPP, self).__init__()
        self.pooling_size = pooling_size

        self.map_convs = nn.ModuleList([
            nn.Conv2d(in_channels, hidden_channels, 1, bias=False),
            nn.Conv2d(in_channels, hidden_channels, 3, bias=False,
                      dilation=dilation_rates[0], padding=dilation_rates[0]),
            nn.Conv2d(in_channels, hidden_channels, 3, bias=False,
                      dilation=dilation_rates[1], padding=dilation_rates[1]),
            nn.Conv2d(in_channels, hidden_channels, 3, bias=False,
                      dilation=dilation_rates[2], padding=dilation_rates[2])
        ])
        self.map_bn = norm_act(hidden_channels * 4)

        self.global_pooling_conv = nn.Conv2d(in_channels, hidden_channels, 1, bias=False)
        self.global_pooling_bn = norm_act(hidden_channels)

        self.red_conv = nn.Conv2d(hidden_channels * 4, out_channels, 1, bias=False)
        self.pool_red_conv = nn.Conv2d(hidden_channels, out_channels, 1, bias=False)
        self.red_bn = norm_act(out_channels)

        self.leak_relu = nn.LeakyReLU()

    def forward(self, x):
        # Map convolutions
        out = torch.cat([m(x) for m in self.map_convs], dim=1)
        out = self.map_bn(out)
        out = self.leak_relu(out)       # add activation layer
        out = self.red_conv(out)

        # Global pooling
        pool = self._global_pooling(x)
        pool = self.global_pooling_conv(pool)
        pool = self.global_pooling_bn(pool)
        pool = self.leak_relu(pool)     # add activation layer
        pool = self.pool_red_conv(pool)
        if self.training or self.pooling_size is None:
            pool = pool.repeat(1, 1, x.size(2), x.size(3))

        out += pool
        out = self.red_bn(out)
        out = self.leak_relu(out)       # add activation layer
        return out

    def _global_pooling(self, x):
        pool = x.view(x.size(0), x.size(1), -1).mean(dim=-1)
        pool = pool.view(x.size(0), x.size(1), 1, 1)
        return pool


class Head(nn.Module):
    def __init__(self, base_model, classify_classes, norm_act=nn.BatchNorm2d, bn_momentum=0.0003):
        super(Head, self).__init__()

        self.classify_classes = classify_classes

        if base_model == 'r18':
            self.aspp = ASPP(512, 256, [6, 12, 18], norm_act=norm_act)
            self.reduce = nn.Sequential(
                nn.Conv2d(64, 48, 1, bias=False),
                norm_act(48, momentum=bn_momentum),
                nn.ReLU(),
            )
        elif base_model == 'r50':
            self.aspp = ASPP(2048, 256, [6, 12, 18], norm_act=norm_act)
            self.reduce = nn.Sequential(
                nn.Conv2d(256, 48, 1, bias=False),
                norm_act(48, momentum=bn_momentum),
                nn.ReLU(),
            )
        else:
            raise Exception(f"Head not implemented for {base_model}")

        self.last_conv = nn.Sequential(
            nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False),
            norm_act(256, momentum=bn_momentum),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
            norm_act(256, momentum=bn_momentum),
            nn.ReLU(),
        )

    def forward(self, f_list):
        f = f_list[-1]
        f = self.aspp(f)

        low_level_features = f_list[0]
        low_h, low_w = low_level_features.size(2), low_level_features.size(3)
        low_level_features = self.reduce(low_level_features)

        f = F.interpolate(f, size=(low_h, low_w), mode='bilinear', align_corners=True)
        f = torch.cat((f, low_level_features), dim=1)
        f = self.last_conv(f)

        return f


def group_weight(module, norm_layer):
    group_decay = []
    group_no_decay = []
    for m in module.modules():
        if isinstance(m, nn.Linear):
            group_decay.append(m.weight)
            if m.bias is not None:
                group_no_decay.append(m.bias)
        elif isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
            group_decay.append(m.weight)
            if m.bias is not None:
                group_no_decay.append(m.bias)
        elif isinstance(m, Conv2_5D_depth):
            group_decay.append(m.weight_0)
            group_decay.append(m.weight_1)
            group_decay.append(m.weight_2)
            if m.bias is not None:
                group_no_decay.append(m.bias)
        elif isinstance(m, Conv2_5D_disp):
            group_decay.append(m.weight_0)
            group_decay.append(m.weight_1)
            group_decay.append(m.weight_2)
            if m.bias is not None:
                group_no_decay.append(m.bias)
        elif isinstance(m, norm_layer) or isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.BatchNorm2d) \
                or isinstance(m, nn.BatchNorm3d) or isinstance(m, nn.GroupNorm):
            if m.weight is not None:
                group_no_decay.append(m.weight)
            if m.bias is not None:
                group_no_decay.append(m.bias)
        elif isinstance(m, nn.Parameter):
            group_decay.append(m)
        elif isinstance(m, nn.Embedding):
            group_decay.append(m)

    assert len(list(module.parameters())) == len(group_decay) + len(group_no_decay)
    return group_decay, group_no_decay


def __init_weight(feature, conv_init, norm_layer, bn_eps, bn_momentum, **kwargs):
    for name, m in feature.named_modules():
        if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
            conv_init(m.weight, **kwargs)
        elif isinstance(m, Conv2_5D_depth):
            conv_init(m.weight_0, **kwargs)
            conv_init(m.weight_1, **kwargs)
            conv_init(m.weight_2, **kwargs)
        elif isinstance(m, Conv2_5D_disp):
            conv_init(m.weight_0, **kwargs)
            conv_init(m.weight_1, **kwargs)
            conv_init(m.weight_2, **kwargs)
        elif isinstance(m, norm_layer):
            m.eps = bn_eps
            m.momentum = bn_momentum
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)


def init_weight(module_list, conv_init, norm_layer, bn_eps, bn_momentum, **kwargs):
    if isinstance(module_list, list):
        for feature in module_list:
            __init_weight(feature, conv_init, norm_layer, bn_eps, bn_momentum, **kwargs)
    else:
        __init_weight(module_list, conv_init, norm_layer, bn_eps, bn_momentum, **kwargs)
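
# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original training pipeline).
# It illustrates how the model is wired together. The config fields below are
# assumptions inferred from the attribute accesses in this file, and passing
# None as the pretrained path assumes the project's resnet constructors skip
# checkpoint loading in that case. The model is put in eval mode so
# SyncBatchNorm does not require an initialized process group.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    # Hypothetical config values; the real project supplies these through its
    # config module.
    cfg = SimpleNamespace(pretrained_model_r18=None, bn_eps=1e-5, bn_momentum=0.1)

    model = DeepLabV3p_r18(num_classes=19, config=cfg)
    model.eval()

    dummy_rgb = torch.randn(1, 3, 512, 512)
    with torch.no_grad():
        pred = model([dummy_rgb])  # rgb is passed as the first list element
    print(pred.shape)  # expected: (1, 19, 512, 512)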