# Copyright (C) 2021 * Ltd. All rights reserved. # author : Sanghyeon Jo import math import torch import torch.nn as nn import torch.nn.functional as F from torchvision import models import torch.utils.model_zoo as model_zoo from .arch_resnet import resnet from .arch_resnest import resnest from .abc_modules import ABC_Model from .deeplab_utils import ASPP, Decoder from .aff_utils import PathIndex from .puzzle_utils import tile_features, merge_features from tools.ai.torch_utils import resize_for_tensors ####################################################################### # Normalization ####################################################################### from .sync_batchnorm.batchnorm import SynchronizedBatchNorm2d class FixedBatchNorm(nn.BatchNorm2d): def forward(self, x): return F.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias, training=False, eps=self.eps) def group_norm(features): return nn.GroupNorm(4, features) ####################################################################### class Backbone(nn.Module, ABC_Model): def __init__(self, model_name, state_path, num_classes=20, mode='fix', segmentation=False): super().__init__() self.mode = mode if self.mode == 'fix': self.norm_fn = FixedBatchNorm else: self.norm_fn = nn.BatchNorm2d if 'resnet' in model_name: self.model = resnet.ResNet(resnet.Bottleneck, resnet.layers_dic[model_name], strides=(2, 2, 2, 1), batch_norm_fn=self.norm_fn) state_dict = torch.load(state_path) self.model.load_state_dict(state_dict, strict=False) else: if segmentation: dilation, dilated = 4, True else: dilation, dilated = 2, False self.model = eval("resnest." + model_name)(pretrained=True, dilated=dilated, dilation=dilation, norm_layer=self.norm_fn) del self.model.avgpool del self.model.fc self.stage1 = nn.Sequential(self.model.conv1, self.model.bn1, self.model.relu, self.model.maxpool) self.stage2 = nn.Sequential(self.model.layer1) self.stage3 = nn.Sequential(self.model.layer2) self.stage4 = nn.Sequential(self.model.layer3) self.stage5 = nn.Sequential(self.model.layer4) class Classifier(Backbone): def __init__(self, model_name, state_path, num_classes=20, mode='fix'): super().__init__(model_name, state_path, num_classes, mode) self.classifier = nn.Conv2d(2048, num_classes, 1, bias=False) self.num_classes = num_classes self.initialize([self.classifier]) def forward(self, x, with_cam=False): x = self.stage1(x) x = self.stage2(x) x = self.stage3(x) x = self.stage4(x) x = self.stage5(x) if with_cam: features = self.classifier(x) logits = self.global_average_pooling_2d(features) return logits, features else: x = self.global_average_pooling_2d(x, keepdims=True) logits = self.classifier(x).view(-1, self.num_classes) return logits class Classifier_For_Positive_Pooling(Backbone): def __init__(self, model_name, num_classes=20, mode='fix'): super().__init__(model_name, num_classes, mode) self.classifier = nn.Conv2d(2048, num_classes, 1, bias=False) self.num_classes = num_classes self.initialize([self.classifier]) def forward(self, x, with_cam=False): x = self.stage1(x) x = self.stage2(x) x = self.stage3(x) x = self.stage4(x) x = self.stage5(x) if with_cam: features = self.classifier(x) logits = self.global_average_pooling_2d(features) return logits, features else: x = self.global_average_pooling_2d(x, keepdims=True) logits = self.classifier(x).view(-1, self.num_classes) return logits class Classifier_For_Puzzle(Classifier): def __init__(self, model_name, num_classes=20, mode='fix'): super().__init__(model_name, num_classes, mode) def forward(self, x, num_pieces=1, level=-1): batch_size = x.size()[0] output_dic = {} layers = [self.stage1, self.stage2, self.stage3, self.stage4, self.stage5, self.classifier] for l, layer in enumerate(layers): l += 1 if level == l: x = tile_features(x, num_pieces) x = layer(x) output_dic['stage%d'%l] = x output_dic['logits'] = self.global_average_pooling_2d(output_dic['stage6']) for l in range(len(layers)): l += 1 if l >= level: output_dic['stage%d'%l] = merge_features(output_dic['stage%d'%l], num_pieces, batch_size) if level is not None: output_dic['merged_logits'] = self.global_average_pooling_2d(output_dic['stage6']) return output_dic class AffinityNet(Backbone): def __init__(self, model_name, path_index=None): super().__init__(model_name, None, 'fix') if '50' in model_name: fc_edge1_features = 64 else: fc_edge1_features = 128 self.fc_edge1 = nn.Sequential( nn.Conv2d(fc_edge1_features, 32, 1, bias=False), nn.GroupNorm(4, 32), nn.ReLU(inplace=True), ) self.fc_edge2 = nn.Sequential( nn.Conv2d(256, 32, 1, bias=False), nn.GroupNorm(4, 32), nn.ReLU(inplace=True), ) self.fc_edge3 = nn.Sequential( nn.Conv2d(512, 32, 1, bias=False), nn.GroupNorm(4, 32), nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), nn.ReLU(inplace=True), ) self.fc_edge4 = nn.Sequential( nn.Conv2d(1024, 32, 1, bias=False), nn.GroupNorm(4, 32), nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False), nn.ReLU(inplace=True), ) self.fc_edge5 = nn.Sequential( nn.Conv2d(2048, 32, 1, bias=False), nn.GroupNorm(4, 32), nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False), nn.ReLU(inplace=True), ) self.fc_edge6 = nn.Conv2d(160, 1, 1, bias=True) self.backbone = nn.ModuleList([self.stage1, self.stage2, self.stage3, self.stage4, self.stage5]) self.edge_layers = nn.ModuleList([self.fc_edge1, self.fc_edge2, self.fc_edge3, self.fc_edge4, self.fc_edge5, self.fc_edge6]) if path_index is not None: self.path_index = path_index self.n_path_lengths = len(self.path_index.path_indices) for i, pi in enumerate(self.path_index.path_indices): self.register_buffer("path_indices_" + str(i), torch.from_numpy(pi)) def train(self, mode=True): super().train(mode) self.backbone.eval() def forward(self, x, with_affinity=False): x1 = self.stage1(x).detach() x2 = self.stage2(x1).detach() x3 = self.stage3(x2).detach() x4 = self.stage4(x3).detach() x5 = self.stage5(x4).detach() edge1 = self.fc_edge1(x1) edge2 = self.fc_edge2(x2) edge3 = self.fc_edge3(x3)[..., :edge2.size(2), :edge2.size(3)] edge4 = self.fc_edge4(x4)[..., :edge2.size(2), :edge2.size(3)] edge5 = self.fc_edge5(x5)[..., :edge2.size(2), :edge2.size(3)] edge = self.fc_edge6(torch.cat([edge1, edge2, edge3, edge4, edge5], dim=1)) if with_affinity: return edge, self.to_affinity(torch.sigmoid(edge)) else: return edge def get_edge(self, x, image_size=512, stride=4): feat_size = (x.size(2)-1)//stride+1, (x.size(3)-1)//stride+1 x = F.pad(x, [0, image_size-x.size(3), 0, image_size-x.size(2)]) edge_out = self.forward(x) edge_out = edge_out[..., :feat_size[0], :feat_size[1]] edge_out = torch.sigmoid(edge_out[0]/2 + edge_out[1].flip(-1)/2) return edge_out """ aff = self.to_affinity(torch.sigmoid(edge_out)) pos_aff_loss = (-1) * torch.log(aff + 1e-5) neg_aff_loss = (-1) * torch.log(1. + 1e-5 - aff) """ def to_affinity(self, edge): aff_list = [] edge = edge.view(edge.size(0), -1) for i in range(self.n_path_lengths): ind = self._buffers["path_indices_" + str(i)] ind_flat = ind.view(-1) dist = torch.index_select(edge, dim=-1, index=ind_flat) dist = dist.view(dist.size(0), ind.size(0), ind.size(1), ind.size(2)) aff = torch.squeeze(1 - F.max_pool2d(dist, (dist.size(2), 1)), dim=2) aff_list.append(aff) aff_cat = torch.cat(aff_list, dim=1) return aff_cat class DeepLabv3_Plus(Backbone): def __init__(self, model_name, num_classes=21, mode='fix', use_group_norm=False): super().__init__(model_name, num_classes, mode, segmentation=False) if use_group_norm: norm_fn_for_extra_modules = group_norm else: norm_fn_for_extra_modules = self.norm_fn self.aspp = ASPP(output_stride=16, norm_fn=norm_fn_for_extra_modules) self.decoder = Decoder(num_classes, 256, norm_fn_for_extra_modules) def forward(self, x, with_cam=False): inputs = x x = self.stage1(x) x = self.stage2(x) x_low_level = x x = self.stage3(x) x = self.stage4(x) x = self.stage5(x) x = self.aspp(x) x = self.decoder(x, x_low_level) x = resize_for_tensors(x, inputs.size()[2:], align_corners=True) return x class Seg_Model(Backbone): def __init__(self, model_name, num_classes=21): super().__init__(model_name, num_classes, mode='fix', segmentation=False) self.classifier = nn.Conv2d(2048, num_classes, 1, bias=False) def forward(self, inputs): x = self.stage1(inputs) x = self.stage2(x) x = self.stage3(x) x = self.stage4(x) x = self.stage5(x) logits = self.classifier(x) # logits = resize_for_tensors(logits, inputs.size()[2:], align_corners=False) return logits class CSeg_Model(Backbone): def __init__(self, model_name, num_classes=21): super().__init__(model_name, num_classes, 'fix') if '50' in model_name: fc_edge1_features = 64 else: fc_edge1_features = 128 self.fc_edge1 = nn.Sequential( nn.Conv2d(fc_edge1_features, 32, 1, bias=False), nn.GroupNorm(4, 32), nn.ReLU(inplace=True), ) self.fc_edge2 = nn.Sequential( nn.Conv2d(256, 32, 1, bias=False), nn.GroupNorm(4, 32), nn.ReLU(inplace=True), ) self.fc_edge3 = nn.Sequential( nn.Conv2d(512, 32, 1, bias=False), nn.GroupNorm(4, 32), nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), nn.ReLU(inplace=True), ) self.fc_edge4 = nn.Sequential( nn.Conv2d(1024, 32, 1, bias=False), nn.GroupNorm(4, 32), nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False), nn.ReLU(inplace=True), ) self.fc_edge5 = nn.Sequential( nn.Conv2d(2048, 32, 1, bias=False), nn.GroupNorm(4, 32), nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False), nn.ReLU(inplace=True), ) self.fc_edge6 = nn.Conv2d(160, num_classes, 1, bias=True) def forward(self, x): x1 = self.stage1(x) x2 = self.stage2(x1) x3 = self.stage3(x2) x4 = self.stage4(x3) x5 = self.stage5(x4) edge1 = self.fc_edge1(x1) edge2 = self.fc_edge2(x2) edge3 = self.fc_edge3(x3)[..., :edge2.size(2), :edge2.size(3)] edge4 = self.fc_edge4(x4)[..., :edge2.size(2), :edge2.size(3)] edge5 = self.fc_edge5(x5)[..., :edge2.size(2), :edge2.size(3)] logits = self.fc_edge6(torch.cat([edge1, edge2, edge3, edge4, edge5], dim=1)) # logits = resize_for_tensors(logits, x.size()[2:], align_corners=True) return logits