File size: 7,836 Bytes
d4ebf73 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 |
# 2022.06.08-Changed for implementation of TokenFusion
# Huawei Technologies Co., Ltd. <[email protected]>
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.seg_opr.loss_func import JSD
from . import mix_transformer
from mmcv.cnn import ConvModule
from .modules import num_parallel
class MLP(nn.Module):
    """Linear embedding of a feature map into a token sequence.

    The input is flattened over every dimension after the channel axis,
    turned into ``(N, L, C)`` tokens and projected from ``input_dim`` to
    ``embed_dim`` channels by a single ``nn.Linear`` (``self.proj``).
    """

    def __init__(self, input_dim=2048, embed_dim=768):
        super().__init__()
        self.proj = nn.Linear(input_dim, embed_dim)

    def forward(self, x):
        # (N, C, H, W) -> (N, C, H*W) -> (N, H*W, C)
        tokens = x.flatten(start_dim=2).permute(0, 2, 1).contiguous()
        return self.proj(tokens)
class SegFormerHead(nn.Module):
    """All-MLP decode head from "SegFormer: Simple and Efficient Design for
    Semantic Segmentation with Transformers".

    Each of the four encoder stages is projected to ``embedding_dim`` channels
    by an :class:`MLP`. The three deeper stages are bilinearly resized to the
    spatial size of the first stage. The four maps are concatenated (deepest
    first), fused by a 1x1 ``ConvModule`` with a 'BN' norm config, passed
    through ``Dropout2d(0.1)`` and classified by a 1x1 ``nn.Conv2d``.

    Args:
        feature_strides: per-stage strides; must have the same length as
            ``in_channels`` and its first entry must be the minimum.
            Required in practice: the ``None`` default fails ``len()``.
        in_channels: four per-stage channel counts (the scalar default would
            fail ``len()`` and unpacking; callers must pass a sequence).
        embedding_dim: width of the shared per-stage embedding.
        num_classes: number of output channels of the classifier.
        **kwargs: accepted and ignored.
    """

    def __init__(self, feature_strides=None, in_channels=128, embedding_dim=256, num_classes=20, **kwargs):
        super().__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes
        assert len(feature_strides) == len(self.in_channels)
        assert min(feature_strides) == feature_strides[0]
        self.feature_strides = feature_strides
        ch1, ch2, ch3, ch4 = self.in_channels

        # Created deepest stage first: creation order determines parameter
        # order (optimizer state) and the order of random initialisation.
        self.linear_c4 = MLP(input_dim=ch4, embed_dim=embedding_dim)
        self.linear_c3 = MLP(input_dim=ch3, embed_dim=embedding_dim)
        self.linear_c2 = MLP(input_dim=ch2, embed_dim=embedding_dim)
        self.linear_c1 = MLP(input_dim=ch1, embed_dim=embedding_dim)
        self.dropout = nn.Dropout2d(0.1)
        self.linear_fuse = ConvModule(
            in_channels=4 * embedding_dim,
            out_channels=embedding_dim,
            kernel_size=1,
            norm_cfg=dict(type='BN', requires_grad=True),
        )
        self.linear_pred = nn.Conv2d(embedding_dim, self.num_classes, kernel_size=1)

    def forward(self, x):
        """Decode the four-stage feature pyramid ``x = (c1, c2, c3, c4)``."""
        c1, c2, c3, c4 = x
        batch, _, _, _ = c4.shape
        target_size = c1.size()[2:]

        def to_map(layer, feat):
            # (N, H*W, E) tokens back to an (N, E, H, W) feature map.
            tokens = layer(feat)
            return tokens.permute(0, 2, 1).reshape(batch, -1, feat.shape[2], feat.shape[3]).contiguous()

        deep_maps = [
            F.interpolate(to_map(layer, feat), size=target_size, mode='bilinear', align_corners=False)
            for layer, feat in ((self.linear_c4, c4), (self.linear_c3, c3), (self.linear_c2, c2))
        ]
        # c1 already has the target resolution, so it is not resized.
        stacked = torch.cat(deep_maps + [to_map(self.linear_c1, c1)], dim=1)
        fused = self.linear_fuse(stacked)
        return self.linear_pred(self.dropout(fused))
class LinearFusionConsistency(nn.Module):
    """Multi-branch segmentation model with a shared decoder and a consistency loss.

    A ``mix_transformer`` backbone (chosen by name) encodes the inputs. One
    :class:`SegFormerHead` decodes the features of each of the
    ``num_parallel`` branches. A learnable weight vector ``alpha``
    (softmax-normalised) combines the *detached* branch predictions into an
    ensemble prediction. A consistency loss (``cons_lambda`` times ``JSD``)
    ties the first two branch predictions together.

    Args:
        backbone: name of a constructor in ``mix_transformer``; also the stem
            of the pretrained weight file.
        config: object providing ``ratio`` (passed to the backbone) and, when
            ``pretrained`` is true, ``root_dir``.
        cons_lambda: weight applied to the consistency loss.
        num_classes: number of segmentation classes.
        embedding_dim: decoder embedding width.
        pretrained: load ``<root_dir>/data/pytorch-weight/<backbone>.pth`` into
            the encoder (via ``expand_state_dict``).
    """

    def __init__(self, backbone, config, cons_lambda, num_classes=20, embedding_dim=256, pretrained=True):
        super().__init__()
        self.num_classes = num_classes
        self.embedding_dim = embedding_dim
        self.feature_strides = [4, 8, 16, 32]
        self.num_parallel = num_parallel
        self.cons_lambda = cons_lambda
        self.cons_loss = JSD()
        self.encoder = getattr(mix_transformer, backbone)(ratio=config.ratio)
        self.in_channels = self.encoder.embed_dims
        if pretrained:
            weight_path = config.root_dir + '/data/pytorch-weight/' + backbone + '.pth'
            # load_state_dict copies values into the encoder's existing tensors,
            # so mapping to CPU does not change the result; it does avoid
            # failing on machines without the device the checkpoint was saved on.
            state_dict = torch.load(weight_path, map_location='cpu')
            # The classification head has no counterpart in the encoder;
            # tolerate checkpoints saved without it.
            state_dict.pop('head.weight', None)
            state_dict.pop('head.bias', None)
            state_dict = expand_state_dict(self.encoder.state_dict(), state_dict, self.num_parallel)
            self.encoder.load_state_dict(state_dict, strict=True)
        self.decoder = SegFormerHead(feature_strides=self.feature_strides, in_channels=self.in_channels,
                                     embedding_dim=self.embedding_dim, num_classes=self.num_classes)
        # Assigning an nn.Parameter attribute registers it as 'alpha';
        # no explicit register_parameter call is needed.
        self.alpha = nn.Parameter(torch.ones(self.num_parallel))
        self.ratio = config.ratio

    def get_params(self):
        """Return three parameter groups.

        The groups are: encoder params whose name does not contain 'norm',
        encoder params whose name contains 'norm', and decoder params.
        """
        # NOTE(review): self.alpha is in none of these groups; if the optimizer
        # is built only from them, the ensemble weights are never updated.
        # Confirm this is intended.
        param_groups = [[], [], []]
        for name, param in self.encoder.named_parameters():
            if "norm" in name:
                param_groups[1].append(param)
            else:
                param_groups[0].append(param)
        param_groups[2].extend(self.decoder.parameters())
        return param_groups

    def forward(self, data, get_sup_loss=False, gt=None, criterion=None):
        """Decode every branch with the shared head and build the ensemble.

        Args:
            data: inputs passed to the encoder; ``data[0]`` (the RGB input) must
                be 4-D and its last two dims set the training output size.
            get_sup_loss: in training mode, also return ``[sup_loss, cons_loss]``.
            gt: labels for the first ``gt.shape[0]`` samples of the batch.
            criterion: callable ``criterion(pred, gt)`` giving the supervised loss.

        Returns:
            Eval mode: ``[branch_0, ..., branch_{P-1}, ensemble]`` at the
            decoder's output resolution.
            Training mode: the same list bilinearly upsampled to ``(h, w)``.
            With ``get_sup_loss``, the tuple ``(largepred, [sup_loss, cons_loss])``.
        """
        _, _, h, w = data[0].shape  # rgb is the 0th element
        x = self.encoder(data)
        pred = [self.decoder(x[i]) for i in range(self.num_parallel)]
        # Branch predictions are detached, so the ensemble term only trains alpha.
        alpha_soft = F.softmax(self.alpha, dim=0)
        ens = sum(weight * branch.detach() for weight, branch in zip(alpha_soft, pred))
        pred.append(ens)
        if not self.training:
            return pred
        largepred = [F.interpolate(p, size=(h, w), mode='bilinear', align_corners=True) for p in pred]
        if not get_sup_loss:
            return largepred
        # Consistency only between the first two branches, not with the ensemble.
        cons_loss = self.cons_lambda * self.get_cons_loss(pred[0], pred[1])
        sup_loss = self.get_sup_loss(largepred, gt, criterion)
        return largepred, [sup_loss, cons_loss]

    def get_cons_loss(self, b1, b2):
        """Consistency loss between two branch predictions.

        ``b1`` and ``b2`` are ``[batch, num_classes, ...]`` score maps. The
        class axis is moved last before flattening, so every row handed to
        ``self.cons_loss`` (which expects ``[N, num_classes]``) is the class
        vector of one pixel. Reshaping the NCHW tensor directly would instead
        pack neighbouring pixels of a single class into a row.
        """
        assert b1.shape[1] == self.num_classes
        b1 = b1.movedim(1, -1).reshape(-1, self.num_classes)
        b2 = b2.movedim(1, -1).reshape(-1, self.num_classes)
        return self.cons_loss(b1, b2)

    def get_sup_loss(self, pred, gt, criterion):
        """Mean of ``criterion`` over all predictions.

        Each prediction is sliced to its first ``gt.shape[0]`` samples, so only
        the examples that have labels contribute.
        """
        sup_loss = 0
        for p in pred:
            sup_loss += criterion(p[:gt.shape[0]], gt)
        return sup_loss / len(pred)
def expand_state_dict(model_dict, state_dict, num_parallel):
    """Fill ``model_dict`` in place from ``state_dict`` and return it.

    For every key of ``model_dict``, the first occurrence of the ``'module.'``
    prefix and any other occurrences are removed. If the result exists in
    ``state_dict`` it is copied over. Then, for each ``i`` in
    ``range(num_parallel)``, a ``'.ln_<i>'`` tag is stripped from the lookup
    key (cumulatively). Whenever a tag was present and the stripped key
    exists in ``state_dict``, that value overwrites the earlier one. This
    lets one checkpoint tensor initialise every per-branch copy of a layer.
    Keys with no match keep their original value.
    """
    for target_key in model_dict:
        # Only existing keys are reassigned, so iterating while writing is safe.
        source_key = target_key.replace('module.', '')
        if source_key in state_dict:
            model_dict[target_key] = state_dict[source_key]
        for idx in range(num_parallel):
            tag = f'.ln_{idx}'
            had_tag = tag in source_key
            source_key = source_key.replace(tag, '')
            if had_tag and source_key in state_dict:
                model_dict[target_key] = state_dict[source_key]
    return model_dict
|