doevent committed on
Commit
69269d9
·
1 Parent(s): 87abef8

Upload models/loss.py

Browse files
Files changed (1) hide show
  1. models/loss.py +222 -0
models/loss.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import division
2
+ import os, glob, shutil, math, random, json
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import torchvision
7
+ import basic
8
+ from utils import util
9
+
10
# Small constant added to denominators of the weighted losses below so that an
# all-zero weight map does not cause a division by zero.
eps = 1e-7
11
+
12
class SPixelLoss:
    """Reconstruction loss driven by a predicted assignment probability map.

    The target feature map is pooled with the predicted probabilities
    (``basic.poolfeat``), expanded back (``basic.upfeat``) and compared with
    the original. The last two channels are scored separately from the
    remaining ones (presumably xy coordinates vs. colour features, going by
    the ``labxy`` naming -- confirm against the data loader).
    """

    def __init__(self, psize=8, mpdist=False, gpu_no=0):
        """Store the configuration.

        Args:
            psize: cell size handed to ``basic.poolfeat`` / ``basic.upfeat``;
                also the divisor of the position term.
            mpdist: stored only; not read by this class.
            gpu_no: stored only; not read by this class.
        """
        self.sp_size = psize
        self.gpu_no = gpu_no
        self.mpdist = mpdist

    def __call__(self, data, epoch_no):
        """Compute the loss terms.

        Args:
            data: mapping providing ``'pred_prob'`` and ``'target_feat'``.
            epoch_no: accepted for interface uniformity; unused here.

        Returns:
            dict with the keys ``'totalLoss'``, ``'featLoss'`` and ``'posLoss'``.
        """
        cell = self.sp_size
        assignment = data['pred_prob']
        target = data['target_feat']
        # Unpacking doubles as a check that the target is a 4-D tensor.
        _n, _c, _h, _w = target.shape

        pooled = basic.poolfeat(target, assignment, cell, cell)
        rebuilt = basic.upfeat(pooled, assignment, cell, cell)
        residual = rebuilt[:, :, :, :] - target[:, :, :, :]

        # Per-pixel L2 norm over channels, averaged over batch and space.
        feat_term = residual[:, :-2, :, :].norm(p=2, dim=1).mean()
        pos_term = residual[:, -2:, :, :].norm(p=2, dim=1).mean() / cell
        total = 10 * feat_term + 0.003 * pos_term
        return {'totalLoss': total, 'featLoss': feat_term, 'posLoss': pos_term}
31
+
32
+
33
class AnchorColorProbLoss:
    """Combined loss over two class-probability predictions plus optional
    reconstruction terms.

    Terms produced by ``__call__``:
      * ``palLoss``: cross-entropy of ``data['pal_prob']`` against the labels.
      * ``refLoss``: either an MSE regression (``hint2regress=True``) or a
        second cross-entropy on ``data['ref_prob']``.
      * ``recLoss``: VGG perceptual loss (``enhanced=True``), optionally plus a
        Laplacian-gradient L1 term (``with_grad=True``).
    """

    def __init__(self, hint2regress=False, enhanced=False, with_grad=False, mpdist=False, gpu_no=0):
        """Configure which terms are active.

        Args:
            hint2regress: use the MSE regression form of ``refLoss``.
            enhanced: add the perceptual reconstruction term (builds a
                ``VGG19Loss``).
            with_grad: add the Laplacian-gradient term to ``recLoss``.
            mpdist: stored, and forwarded to ``VGG19Loss`` as ``is_ddp``.
            gpu_no: stored, and forwarded to ``VGG19Loss``.
        """
        self.mpdist = mpdist
        self.gpu_no = gpu_no
        self.hint2regress = hint2regress
        self.enhanced = enhanced
        self.with_grad = with_grad
        self.rebalance_gradient = basic.RebalanceLoss.apply
        # ignore_index=-1 is configured, but nothing in this class currently
        # writes -1 into the labels (the masking code was disabled upstream).
        self.entropy_loss = nn.CrossEntropyLoss(ignore_index=-1)
        if self.enhanced:
            self.VGGLoss = VGG19Loss(gpu_no=gpu_no, is_ddp=mpdist)

    def _perceptual_loss(self, input_grays, input_colors, pred_colors):
        """Return the VGG loss between two images sharing one gray channel.

        Each colour tensor is concatenated with ``input_grays`` along dim 1,
        converted with ``basic.lab2rgb`` and the two results are passed to
        ``self.VGGLoss`` in the order (``input_colors`` image,
        ``pred_colors`` image).
        """
        # The output of "lab2rgb" matches the input range of "VGGLoss": [0,1].
        input_RGBs = basic.lab2rgb(torch.cat([input_grays, input_colors], dim=1))
        pred_RGBs = basic.lab2rgb(torch.cat([input_grays, pred_colors], dim=1))
        return self.VGGLoss(input_RGBs, pred_RGBs)

    def _laplace_gradient(self, pred_AB, target_AB):
        """Return the mean L1 distance between Laplacian responses.

        A 3x3 Laplacian kernel (centre -8, neighbours 1) is applied to every
        channel independently, without padding, to both tensors.

        Args:
            pred_AB: 4-D tensor (N, C, H, W).
            target_AB: tensor with the same channel count as ``pred_AB``.
        """
        _, channels, _, _ = pred_AB.shape
        # Build the kernel on the input's own device. The previous
        # ``device=pred_AB.get_device()`` is -1 for CPU tensors, which
        # torch.tensor rejects, so the method only worked on CUDA inputs.
        laplacian = torch.tensor(
            [[1, 1, 1], [1, -8, 1], [1, 1, 1]],
            dtype=torch.float32,
            device=pred_AB.device,
        )
        # One copy of the kernel per channel; groups=C makes it depthwise.
        kernel = laplacian.view(1, 1, 3, 3).repeat(channels, 1, 1, 1)
        grad_pred = F.conv2d(pred_AB, kernel, groups=channels)
        grad_trg = F.conv2d(target_AB, kernel, groups=channels)
        # Plain mean absolute error (same result as the module-level
        # ``l1_loss`` called without a weight map).
        return F.l1_loss(grad_trg, grad_pred)

    def __call__(self, data, epoch_no):
        """Compute all loss terms.

        Args:
            data: mapping read for ``'target_label'``, ``'pal_prob'``,
                ``'class_weight'`` and ``'ref_prob'``; additionally
                ``'spix_color'`` when ``hint2regress`` is set, and
                ``'input_gray'``, ``'pred_color'``, ``'input_color'`` when
                ``enhanced`` is set.
            epoch_no: accepted for interface uniformity; unused here.

        Returns:
            dict with the keys ``'totalLoss'``, ``'palLoss'``, ``'refLoss'``
            and ``'recLoss'``.
        """
        N, C, H, W = data['target_label'].shape
        num_pixels = N * H * W

        # Flatten to one row per pixel for CrossEntropyLoss.
        gt_labels = data['target_label'].permute(0, 2, 3, 1).contiguous().view(num_pixels, -1)
        gt_labels = gt_labels.squeeze(dim=1)

        pal_probs = self.rebalance_gradient(data['pal_prob'], data['class_weight'])
        pal_probs = pal_probs.permute(0, 2, 3, 1).contiguous().view(num_pixels, -1)
        palLoss_idx = self.entropy_loss(pal_probs, gt_labels)

        if self.hint2regress:
            ref_probs = data['ref_prob']
            refLoss_idx = 50 * l2_loss(data['spix_color'], ref_probs)
        else:
            ref_probs = self.rebalance_gradient(data['ref_prob'], data['class_weight'])
            ref_probs = ref_probs.permute(0, 2, 3, 1).contiguous().view(num_pixels, -1)
            refLoss_idx = self.entropy_loss(ref_probs, gt_labels)

        reconLoss_idx = torch.zeros_like(palLoss_idx)
        if self.enhanced:
            scalar = 1.0 if self.hint2regress else 5.0
            # 'pred_color' deliberately goes into the ``input_colors`` slot:
            # that image becomes VGG19Loss's first argument, and only the
            # second argument is detached there, so gradients flow through
            # the prediction.
            reconLoss_idx = scalar * self._perceptual_loss(
                data['input_gray'], data['pred_color'], data['input_color'])
            # NOTE(review): the gradient term is applied only together with
            # the perceptual term -- confirm this nesting is intended.
            if self.with_grad:
                gradient_loss = self._laplace_gradient(data['pred_color'], data['input_color'])
                reconLoss_idx = reconLoss_idx + gradient_loss

        totalLoss_idx = palLoss_idx + refLoss_idx + reconLoss_idx
        return {'totalLoss': totalLoss_idx, 'palLoss': palLoss_idx,
                'refLoss': refLoss_idx, 'recLoss': reconLoss_idx}
88
+
89
+
90
def compute_affinity_pos_loss(prob_in, labxy_feat, pos_weight=0.003, kernel_size=16):
    """Return feature + weighted position reconstruction loss as one scalar.

    ``labxy_feat`` is pooled with a copy of ``prob_in`` (``basic.poolfeat``),
    expanded back (``basic.upfeat``) and compared with itself. The last two
    channels form the position term, scaled by ``pos_weight / kernel_size``;
    the remaining channels form the feature term.

    Args:
        prob_in: probability map; cloned before use.
        labxy_feat: 4-D feature tensor (N, C, H, W).
        pos_weight: weight of the position term.
        kernel_size: cell size passed to the pooling helpers.
    """
    assignment = prob_in.clone()
    # Unpacking doubles as a check that the features are a 4-D tensor.
    _n, _c, _h, _w = labxy_feat.shape

    pooled = basic.poolfeat(labxy_feat, assignment, kernel_size, kernel_size)
    rebuilt = basic.upfeat(pooled, assignment, kernel_size, kernel_size)
    residual = rebuilt[:, :, :, :] - labxy_feat[:, :, :, :]

    feat_term = residual[:, :-2, :, :].norm(p=2, dim=1).mean()
    pos_term = residual[:, -2:, :, :].norm(p=2, dim=1).mean() * pos_weight / kernel_size
    return feat_term + pos_term
102
+
103
+
104
def l2_loss(y_input, y_target, weight_map=None):
    """Return a (optionally spatially weighted) squared-error loss.

    Without ``weight_map`` this is ``F.mse_loss``. With it, the absolute
    error is first averaged over dim 1, squared, multiplied by the weights
    and summed over dims (1, 2, 3), then divided by the per-sample weight sum
    (plus ``eps``) and averaged over the batch.

    NOTE(review): the weighted branch squares the channel-mean absolute error
    rather than averaging squared errors, so it is not the weighted analogue
    of ``F.mse_loss`` when there is more than one channel -- confirm intent.
    """
    if weight_map is not None:
        abs_err = (y_input - y_target).abs().mean(dim=1, keepdim=True)
        weighted_sq = (abs_err * abs_err * weight_map).sum(dim=(1, 2, 3))
        normaliser = eps + weight_map.sum(dim=(1, 2, 3))
        return (weighted_sq / normaliser).mean()
    return F.mse_loss(y_input, y_target)
111
+
112
+
113
def l1_loss(y_input, y_target, weight_map=None):
    """Return a (optionally spatially weighted) absolute-error loss.

    Without ``weight_map`` this is ``F.l1_loss``. With it, the absolute error
    is averaged over dim 1, multiplied by the weights and summed over dims
    (1, 2, 3), then divided by the per-sample weight sum (plus ``eps``) and
    averaged over the batch.
    """
    if weight_map is not None:
        abs_err = (y_input - y_target).abs().mean(dim=1, keepdim=True)
        weighted = (abs_err * weight_map).sum(dim=(1, 2, 3))
        normaliser = eps + weight_map.sum(dim=(1, 2, 3))
        return (weighted / normaliser).mean()
    return F.l1_loss(y_input, y_target)
120
+
121
+
122
def masked_l1_loss(y_input, y_target, outlier_mask):
    """Return the weighted L1 loss with outlier positions excluded.

    Args:
        y_input: predicted tensor.
        y_target: reference tensor.
        outlier_mask: boolean tensor; positions that are True receive weight
            0, all others weight 1.

    Returns:
        The result of ``l1_loss`` called with that 0/1 weight map.
    """
    # Create the constant on the input's own device. The previous
    # ``.cuda(y_input.get_device())`` failed for CPU tensors (get_device()
    # is -1 there) and on machines without CUDA.
    one = torch.tensor([1.0], device=y_input.device)
    weight_map = torch.where(outlier_mask, one * 0.0, one * 1.0)
    return l1_loss(y_input, y_target, weight_map)
126
+
127
+
128
def huber_loss(y_input, y_target, delta=0.01):
    """Return the mean Huber-style loss between two tensors.

    For an absolute error ``d``: ``0.5 * d**2 / delta`` where ``d < delta``,
    and ``d - 0.5 * delta`` elsewhere. The two branches are blended with a
    0/1 mask, and the result is averaged over all elements.
    """
    abs_err = torch.abs(y_input - y_target)
    half_sq = 0.5 * (abs_err ** 2)

    # 1 where the error is in the quadratic zone, 0 in the linear zone.
    in_quadratic = torch.zeros_like(y_input)
    in_quadratic[...] = abs_err < delta

    quadratic_part = half_sq * in_quadratic / delta
    linear_part = (abs_err - 0.5 * delta) * (1 - in_quadratic)
    return torch.mean(quadratic_part + linear_part)
135
+
136
+
137
+ ## Perceptual loss that uses a pretrained VGG network
138
class VGG19Loss(nn.Module):
    """Perceptual loss: L1 distance between pretrained VGG-19 feature maps.

    Data requirement: (N,C,H,W) in RGB format, [0,1] range, and
    resolution >= 224x224.

    ``feat_type`` selects the features that are compared:
      * ``'liu'``: conv1_1, conv2_1, conv3_1, conv4_1, conv5_1 (weighted sum)
      * ``'lei'``: conv1_2, conv2_2, conv3_2, conv4_2, conv5_2 (weighted sum)
      * anything else: a single feature map taken at the max-pool after
        conv4_4.
    """

    def __init__(self, feat_type='liu', gpu_no=0, is_ddp=False, requires_grad=False):
        """Load VGG-19 and slice its feature extractor.

        Args:
            feat_type: feature selection, see the class docstring.
            gpu_no: CUDA device index the network is moved to when CUDA is
                available.
            is_ddp: accepted for interface compatibility; not used.
            requires_grad: when False (default) all parameters are frozen and
                the module is put in eval mode.
        """
        super(VGG19Loss, self).__init__()
        # Provide the original checkpoint cache location only as a default;
        # a TORCH_HOME already configured by the caller is no longer
        # overwritten with this cluster-specific path.
        os.environ.setdefault(
            'TORCH_HOME',
            '/apdcephfs/share_1290939/richardxia/Saved/Checkpoints/VGG19')
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]
        self.feat_type = feat_type

        vgg_model = torchvision.models.vgg19(pretrained=True)
        # Not wrapped in DistributedDataParallel: DDP asserts that it "is not
        # needed when a module doesn't have any parameter that requires a
        # gradient", which is the case for this frozen network.
        # Move to the requested GPU when there is one; otherwise stay on CPU
        # instead of crashing on CUDA-less hosts.
        if torch.cuda.is_available():
            vgg_model = vgg_model.cuda(gpu_no)

        layers = list(vgg_model.features)

        def cut(start, stop):
            # Consecutive VGG layers [start, stop) as a frozen-mode stage.
            return nn.Sequential(*layers[start:stop]).eval()

        if self.feat_type == 'liu':
            ## conv1_1, conv2_1, conv3_1, conv4_1, conv5_1
            self.slice1 = cut(0, 2)
            self.slice2 = cut(2, 7)
            self.slice3 = cut(7, 12)
            self.slice4 = cut(12, 21)
            self.slice5 = cut(21, 30)
            self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]
        elif self.feat_type == 'lei':
            ## conv1_2, conv2_2, conv3_2, conv4_2, conv5_2
            self.slice1 = cut(0, 4)
            self.slice2 = cut(4, 9)
            self.slice3 = cut(9, 14)
            self.slice4 = cut(14, 23)
            self.slice5 = cut(23, 32)
            self.weights = [1.0/2.6, 1.0/4.8, 1.0/3.7, 1.0/5.6, 10.0/1.5]
        else:
            ## maxpool after conv4_4
            self.featureExactor = cut(0, 28)
        self.criterion = nn.L1Loss()

        ## fixed parameters
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False
            self.eval()
        print('[*] VGG19Loss init!')

    def normalize(self, tensor):
        """Return a copy of ``tensor`` standardised per channel (dim 1) with
        ``self.mean`` / ``self.std``. The input is not modified."""
        tensor = tensor.clone()
        mean = torch.as_tensor(self.mean, dtype=torch.float32, device=tensor.device)
        std = torch.as_tensor(self.std, dtype=torch.float32, device=tensor.device)
        tensor.sub_(mean[None, :, None, None]).div_(std[None, :, None, None])
        return tensor

    def forward(self, x, y):
        """Return the perceptual loss between ``x`` and ``y``.

        Both inputs are normalised and passed through the same feature
        stages. The features of ``y`` are detached inside the criterion, so
        no gradient reaches ``y`` through this loss.
        """
        feat_x, feat_y = self.normalize(x), self.normalize(y)
        if self.feat_type == 'liu' or self.feat_type == 'lei':
            stages = (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5)
            loss = 0
            # Each stage consumes the previous stage's (undetached) output.
            for weight, stage in zip(self.weights, stages):
                feat_x, feat_y = stage(feat_x), stage(feat_y)
                loss += weight * self.criterion(feat_x, feat_y.detach())
        else:
            feat_x, feat_y = self.featureExactor(feat_x), self.featureExactor(feat_y)
            loss = self.criterion(feat_x, feat_y.detach())
        return loss