Upload models/loss.py
models/loss.py ADDED (+222, -0)
@@ -0,0 +1,222 @@
from __future__ import division
import os, glob, shutil, math, random, json
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import basic
from utils import util

eps = 0.0000001


## superpixel loss: feature reconstruction term plus a spatial (xy) compactness term
class SPixelLoss:
    def __init__(self, psize=8, mpdist=False, gpu_no=0):
        self.mpdist = mpdist
        self.gpu_no = gpu_no
        self.sp_size = psize

    def __call__(self, data, epoch_no):
        kernel_size = self.sp_size
        #pos_weight = 0.003
        prob = data['pred_prob']
        labxy_feat = data['target_feat']
        N, C, H, W = labxy_feat.shape
        ## pool the lab+xy features into superpixels, then reconstruct them per pixel
        pooled_labxy = basic.poolfeat(labxy_feat, prob, kernel_size, kernel_size)
        reconstr_feat = basic.upfeat(pooled_labxy, prob, kernel_size, kernel_size)
        loss_map = reconstr_feat - labxy_feat
        featLoss_idx = torch.norm(loss_map[:, :-2, :, :], p=2, dim=1).mean()
        posLoss_idx = torch.norm(loss_map[:, -2:, :, :], p=2, dim=1).mean() / kernel_size
        totalLoss_idx = 10 * featLoss_idx + 0.003 * posLoss_idx
        return {'totalLoss': totalLoss_idx, 'featLoss': featLoss_idx, 'posLoss': posLoss_idx}


## anchor-color loss: palette classification term, a hint term (regression or
## classification), and optional perceptual / Laplacian-gradient reconstruction terms
class AnchorColorProbLoss:
    def __init__(self, hint2regress=False, enhanced=False, with_grad=False, mpdist=False, gpu_no=0):
        self.mpdist = mpdist
        self.gpu_no = gpu_no
        self.hint2regress = hint2regress
        self.enhanced = enhanced
        self.with_grad = with_grad
        self.rebalance_gradient = basic.RebalanceLoss.apply
        self.entropy_loss = nn.CrossEntropyLoss(ignore_index=-1)
        if self.enhanced:
            self.VGGLoss = VGG19Loss(gpu_no=gpu_no, is_ddp=mpdist)

    def _perceptual_loss(self, input_grays, input_colors, pred_colors):
        input_RGBs = basic.lab2rgb(torch.cat([input_grays, input_colors], dim=1))
        pred_RGBs = basic.lab2rgb(torch.cat([input_grays, pred_colors], dim=1))
        ## the output range of "lab2rgb" already matches the expected input of "VGGLoss": [0,1]
        return self.VGGLoss(input_RGBs, pred_RGBs)

    def _laplace_gradient(self, pred_AB, target_AB):
        N, C, H, W = pred_AB.shape
        kernel = torch.tensor([[1, 1, 1], [1, -8, 1], [1, 1, 1]], device=pred_AB.get_device()).float()
        kernel = kernel.view(1, 1, *kernel.size()).repeat(C, 1, 1, 1)
        grad_pred = F.conv2d(pred_AB, kernel, groups=C)
        grad_trg = F.conv2d(target_AB, kernel, groups=C)
        return l1_loss(grad_trg, grad_pred)

    def __call__(self, data, epoch_no):
        N, C, H, W = data['target_label'].shape
        pal_probs = self.rebalance_gradient(data['pal_prob'], data['class_weight'])
        #ref_probs = data['ref_prob']
        pal_probs = pal_probs.permute(0, 2, 3, 1).contiguous().view(N * H * W, -1)
        gt_labels = data['target_label'].permute(0, 2, 3, 1).contiguous().view(N * H * W, -1)
        '''
        ignored_mask = data['empty_entries'].permute(0,2,3,1).contiguous().view(N*H*W, -1)
        gt_labels[ignored_mask] = -1
        gt_labels = gt_probs.squeeze()
        '''
        palLoss_idx = self.entropy_loss(pal_probs, gt_labels.squeeze(dim=1))
        if self.hint2regress:
            ref_probs = data['ref_prob']
            refLoss_idx = 50 * l2_loss(data['spix_color'], ref_probs)
        else:
            ref_probs = self.rebalance_gradient(data['ref_prob'], data['class_weight'])
            ref_probs = ref_probs.permute(0, 2, 3, 1).contiguous().view(N * H * W, -1)
            refLoss_idx = self.entropy_loss(ref_probs, gt_labels.squeeze(dim=1))
        reconLoss_idx = torch.zeros_like(palLoss_idx)
        if self.enhanced:
            scalar = 1.0 if self.hint2regress else 5.0
            reconLoss_idx = scalar * self._perceptual_loss(data['input_gray'], data['pred_color'], data['input_color'])
            if self.with_grad:
                gradient_loss = self._laplace_gradient(data['pred_color'], data['input_color'])
                reconLoss_idx += gradient_loss
        totalLoss_idx = palLoss_idx + refLoss_idx + reconLoss_idx
        #print("loss terms:", palLoss_idx.item(), refLoss_idx.item(), reconLoss_idx.item())
        return {'totalLoss': totalLoss_idx, 'palLoss': palLoss_idx, 'refLoss': refLoss_idx, 'recLoss': reconLoss_idx}


def compute_affinity_pos_loss(prob_in, labxy_feat, pos_weight=0.003, kernel_size=16):
    ## same superpixel reconstruction objective as SPixelLoss, exposed as a function
    S = kernel_size
    m = pos_weight
    prob = prob_in.clone()
    N, C, H, W = labxy_feat.shape
    pooled_labxy = basic.poolfeat(labxy_feat, prob, kernel_size, kernel_size)
    reconstr_feat = basic.upfeat(pooled_labxy, prob, kernel_size, kernel_size)
    loss_map = reconstr_feat - labxy_feat
    loss_feat = torch.norm(loss_map[:, :-2, :, :], p=2, dim=1).mean()
    loss_pos = torch.norm(loss_map[:, -2:, :, :], p=2, dim=1).mean() * m / S
    loss_affinity = loss_feat + loss_pos
    return loss_affinity


def l2_loss(y_input, y_target, weight_map=None):
    if weight_map is None:
        return F.mse_loss(y_input, y_target)
    else:
        diff_map = torch.mean(torch.abs(y_input - y_target), dim=1, keepdim=True)
        batch_dev = torch.sum(diff_map * diff_map * weight_map, dim=(1, 2, 3)) / (eps + torch.sum(weight_map, dim=(1, 2, 3)))
        return batch_dev.mean()


def l1_loss(y_input, y_target, weight_map=None):
    if weight_map is None:
        return F.l1_loss(y_input, y_target)
    else:
        diff_map = torch.mean(torch.abs(y_input - y_target), dim=1, keepdim=True)
        batch_dev = torch.sum(diff_map * weight_map, dim=(1, 2, 3)) / (eps + torch.sum(weight_map, dim=(1, 2, 3)))
        return batch_dev.mean()


def masked_l1_loss(y_input, y_target, outlier_mask):
    one = torch.tensor([1.0]).cuda(y_input.get_device())
    ## zero weight on outlier pixels, unit weight elsewhere
    weight_map = torch.where(outlier_mask, one * 0.0, one * 1.0)
    return l1_loss(y_input, y_target, weight_map)


def huber_loss(y_input, y_target, delta=0.01):
    mask = torch.zeros_like(y_input)
    mann = torch.abs(y_input - y_target)
    eucl = 0.5 * (mann ** 2)
    mask[...] = mann < delta
    ## quadratic where |diff| < delta, linear elsewhere
    loss = eucl * mask / delta + (mann - 0.5 * delta) * (1 - mask)
    return torch.mean(loss)


## Perceptual loss that uses a pretrained VGG network
class VGG19Loss(nn.Module):
    def __init__(self, feat_type='liu', gpu_no=0, is_ddp=False, requires_grad=False):
        super(VGG19Loss, self).__init__()
        os.environ['TORCH_HOME'] = '/apdcephfs/share_1290939/richardxia/Saved/Checkpoints/VGG19'
        ## data requirement: (N,C,H,W) in RGB format, [0,1] range, and resolution >= 224x224
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]
        self.feat_type = feat_type

        vgg_model = torchvision.models.vgg19(pretrained=True)
        ## AssertionError: DistributedDataParallel is not needed when a module doesn't have any parameter that requires a gradient
        '''
        if is_ddp:
            vgg_model = vgg_model.cuda(gpu_no)
            vgg_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(vgg_model)
            vgg_model = torch.nn.parallel.DistributedDataParallel(vgg_model, device_ids=[gpu_no], find_unused_parameters=True)
        else:
            vgg_model = vgg_model.cuda(gpu_no)
        '''
        vgg_model = vgg_model.cuda(gpu_no)
        if self.feat_type == 'liu':
            ## conv1_1, conv2_1, conv3_1, conv4_1, conv5_1
            self.slice1 = nn.Sequential(*list(vgg_model.features)[:2]).eval()
            self.slice2 = nn.Sequential(*list(vgg_model.features)[2:7]).eval()
            self.slice3 = nn.Sequential(*list(vgg_model.features)[7:12]).eval()
            self.slice4 = nn.Sequential(*list(vgg_model.features)[12:21]).eval()
            self.slice5 = nn.Sequential(*list(vgg_model.features)[21:30]).eval()
            self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]
        elif self.feat_type == 'lei':
            ## conv1_2, conv2_2, conv3_2, conv4_2, conv5_2
            self.slice1 = nn.Sequential(*list(vgg_model.features)[:4]).eval()
            self.slice2 = nn.Sequential(*list(vgg_model.features)[4:9]).eval()
            self.slice3 = nn.Sequential(*list(vgg_model.features)[9:14]).eval()
            self.slice4 = nn.Sequential(*list(vgg_model.features)[14:23]).eval()
            self.slice5 = nn.Sequential(*list(vgg_model.features)[23:32]).eval()
            self.weights = [1.0/2.6, 1.0/4.8, 1.0/3.7, 1.0/5.6, 10.0/1.5]
        else:
            ## maxpool after conv4_4
            self.featureExactor = nn.Sequential(*list(vgg_model.features)[:28]).eval()
            '''
            for x in range(2):
                self.slice1.add_module(str(x), pretrained_features[x])
            for x in range(2, 7):
                self.slice2.add_module(str(x), pretrained_features[x])
            for x in range(7, 12):
                self.slice3.add_module(str(x), pretrained_features[x])
            for x in range(12, 21):
                self.slice4.add_module(str(x), pretrained_features[x])
            for x in range(21, 30):
                self.slice5.add_module(str(x), pretrained_features[x])
            '''
        self.criterion = nn.L1Loss()

        ## fixed parameters (the VGG features are used as a frozen feature extractor)
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False
            self.eval()
        print('[*] VGG19Loss init!')

    def normalize(self, tensor):
        tensor = tensor.clone()
        mean = torch.as_tensor(self.mean, dtype=torch.float32, device=tensor.device)
        std = torch.as_tensor(self.std, dtype=torch.float32, device=tensor.device)
        tensor.sub_(mean[None, :, None, None]).div_(std[None, :, None, None])
        return tensor

    def forward(self, x, y):
        norm_x, norm_y = self.normalize(x), self.normalize(y)
        ## feature extraction
        if self.feat_type == 'liu' or self.feat_type == 'lei':
            x_relu1, y_relu1 = self.slice1(norm_x), self.slice1(norm_y)
            x_relu2, y_relu2 = self.slice2(x_relu1), self.slice2(y_relu1)
            x_relu3, y_relu3 = self.slice3(x_relu2), self.slice3(y_relu2)
            x_relu4, y_relu4 = self.slice4(x_relu3), self.slice4(y_relu3)
            x_relu5, y_relu5 = self.slice5(x_relu4), self.slice5(y_relu4)
            x_vgg = [x_relu1, x_relu2, x_relu3, x_relu4, x_relu5]
            y_vgg = [y_relu1, y_relu2, y_relu3, y_relu4, y_relu5]
            loss = 0
            for i in range(len(x_vgg)):
                loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
        else:
            x_vgg, y_vgg = self.featureExactor(norm_x), self.featureExactor(norm_y)
            loss = self.criterion(x_vgg, y_vgg.detach())
        return loss
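
For a quick sanity check of the standalone helpers above, a minimal usage sketch follows. It is illustrative only and not part of the commit: it assumes the full repository (including its basic and utils modules) is importable from the current path, and the tensor shapes and the confidence map are made-up placeholders rather than values from the training pipeline.

    import torch
    from models.loss import l1_loss, l2_loss, huber_loss  # assumes the repo root is on PYTHONPATH

    pred   = torch.rand(2, 2, 64, 64)   # hypothetical predicted ab channels, batch of 2
    target = torch.rand(2, 2, 64, 64)   # hypothetical ground-truth ab channels
    conf   = torch.rand(2, 1, 64, 64)   # hypothetical per-pixel weight map

    print(l1_loss(pred, target).item())          # plain L1 (F.l1_loss branch)
    print(l1_loss(pred, target, conf).item())    # weighted branch: channel-mean |diff|, weighted per pixel
    print(l2_loss(pred, target, conf).item())    # weighted branch: squares the channel-mean |diff|
    print(huber_loss(pred, target, delta=0.01).item())

The SPixelLoss and AnchorColorProbLoss classes are called the same way as functions, but they additionally require the repo's basic module and a data dict carrying the keys used in their __call__ methods (e.g. 'pred_prob', 'target_feat', 'pal_prob'), so they are not exercised in this sketch.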