File size: 5,963 Bytes
2cd560a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from ..builder import LOSSES
@LOSSES.register_module()
class JointsMSELoss(nn.Module):
    """MSE loss for heatmaps.

    Args:
        use_target_weight (bool): Option to use weighted MSE loss.
            Different joint types may have different target weights.
        loss_weight (float): Weight of the loss. Default: 1.0.
    """

    def __init__(self, use_target_weight=False, loss_weight=1.):
        super().__init__()
        self.criterion = nn.MSELoss()
        self.use_target_weight = use_target_weight
        self.loss_weight = loss_weight

    def forward(self, output, target, target_weight):
        """Forward function.

        Note:
            - batch_size: N
            - num_joints: K

        Args:
            output (torch.Tensor): Predicted heatmaps. Dim 0 is the batch,
                dim 1 the joints; all trailing dims are flattened.
            target (torch.Tensor): Ground-truth heatmaps laid out like
                ``output``.
            target_weight (torch.Tensor): Per-joint weights of shape
                ``[N, K]`` or ``[N, K, 1]`` (any shape that reshapes to
                ``[N, K, -1]`` and broadcasts against the flattened
                heatmaps). Only read when ``use_target_weight`` is True.

        Returns:
            torch.Tensor: Scalar loss, scaled by ``loss_weight``.
        """
        batch_size = output.size(0)
        num_joints = output.size(1)

        heatmaps_pred = output.reshape((batch_size, num_joints, -1))
        heatmaps_gt = target.reshape((batch_size, num_joints, -1))

        if self.use_target_weight:
            # Give the weights an explicit trailing axis so that both
            # [N, K] and [N, K, 1] inputs broadcast over the pixels of the
            # matching joint instead of being aligned to the wrong axis.
            weight = target_weight.reshape((batch_size, num_joints, -1))
            heatmaps_pred = heatmaps_pred * weight
            heatmaps_gt = heatmaps_gt * weight

        # Every joint contributes the same number of elements, so a single
        # mean over the whole tensor equals the average of per-joint means.
        return self.criterion(heatmaps_pred, heatmaps_gt) * self.loss_weight
@LOSSES.register_module()
class CombinedTargetMSELoss(nn.Module):
    """MSE loss for combined target.

    CombinedTarget: The combination of classification target
    (response map) and regression target (offset map).
    Paper ref: Huang et al. The Devil is in the Details: Delving into
    Unbiased Data Processing for Human Pose Estimation (CVPR 2020).

    The channel axis is read in groups of three per joint:
    ``[heatmap, offset_x, offset_y]``.

    Args:
        use_target_weight (bool): Option to use weighted MSE loss.
            Different joint types may have different target weights.
            Default: False.
        loss_weight (float): Weight of the loss. Default: 1.0.
    """

    def __init__(self, use_target_weight=False, loss_weight=1.):
        super().__init__()
        self.criterion = nn.MSELoss(reduction='mean')
        self.use_target_weight = use_target_weight
        self.loss_weight = loss_weight

    def forward(self, output, target, target_weight):
        """Forward function.

        Note:
            - batch_size: N
            - num_channels: C
            - num_joints: K = C // 3

        Args:
            output (torch.Tensor): Predictions. Dim 0 is the batch, dim 1
                holds the C channels; all trailing dims are flattened.
            target (torch.Tensor): Ground truth laid out like ``output``.
            target_weight (torch.Tensor): Per-joint weights of shape
                ``[N, K]`` or ``[N, K, 1]``. Only read when
                ``use_target_weight`` is True.

        Returns:
            torch.Tensor: Scalar loss, scaled by ``loss_weight``.
        """
        batch_size = output.size(0)
        num_channels = output.size(1)
        num_joints = num_channels // 3
        # Channels beyond the last complete triple are ignored.
        used = num_joints * 3

        pred = output.reshape((batch_size, num_channels, -1))
        gt = target.reshape((batch_size, num_channels, -1))

        # Strided slices pick one channel out of every triple and always
        # keep the [N, K, H*W] layout. A dimensionless squeeze() would also
        # drop a size-1 batch or pixel axis and corrupt the broadcasting
        # of the weights below.
        heatmap_pred = pred[:, 0:used:3]
        heatmap_gt = gt[:, 0:used:3]
        offset_x_pred = pred[:, 1:used:3]
        offset_x_gt = gt[:, 1:used:3]
        offset_y_pred = pred[:, 2:used:3]
        offset_y_gt = gt[:, 2:used:3]

        if self.use_target_weight:
            # Explicit trailing axis: both [N, K] and [N, K, 1] broadcast
            # over the pixels of the matching joint.
            weight = target_weight.reshape((batch_size, num_joints, -1))
            heatmap_pred = heatmap_pred * weight
            heatmap_gt = heatmap_gt * weight

        # classification loss
        loss = 0.5 * self.criterion(heatmap_pred, heatmap_gt)
        # regression loss: offsets are gated by the (weighted) ground-truth
        # heatmap, so only pixels with a non-zero response contribute.
        loss = loss + 0.5 * self.criterion(heatmap_gt * offset_x_pred,
                                           heatmap_gt * offset_x_gt)
        loss = loss + 0.5 * self.criterion(heatmap_gt * offset_y_pred,
                                           heatmap_gt * offset_y_gt)

        # Each term is already a mean over all joints (every joint has the
        # same number of elements), which equals sum-over-joints / K.
        return loss * self.loss_weight
@LOSSES.register_module()
class JointsOHKMMSELoss(nn.Module):
    """MSE loss with online hard keypoint mining.

    Args:
        use_target_weight (bool): Option to use weighted MSE loss.
            Different joint types may have different target weights.
        topk (int): Only top k joint losses are kept.
        loss_weight (float): Weight of the loss. Default: 1.0.
    """

    def __init__(self, use_target_weight=False, topk=8, loss_weight=1.):
        super().__init__()
        assert topk > 0
        # reduction='none' keeps per-element errors so that a per-joint
        # loss can be formed before the hardest joints are selected.
        self.criterion = nn.MSELoss(reduction='none')
        self.use_target_weight = use_target_weight
        self.topk = topk
        self.loss_weight = loss_weight

    def _ohkm(self, loss):
        """Online hard keypoint mining.

        Args:
            loss (torch.Tensor[N, K]): Per-sample, per-joint losses.

        Returns:
            torch.Tensor: Scalar; for every sample the ``topk`` largest
            joint losses are averaged, then the result is averaged over
            the batch.
        """
        # topk already returns the selected values, so no gather is needed,
        # and operating on dim=1 handles the whole batch in one call.
        topk_loss, _ = torch.topk(loss, k=self.topk, dim=1, sorted=False)
        return (topk_loss.sum(dim=1) / self.topk).mean()

    def forward(self, output, target, target_weight):
        """Forward function.

        Note:
            - batch_size: N
            - num_joints: K

        Args:
            output (torch.Tensor): Predicted heatmaps. Dim 0 is the batch,
                dim 1 the joints; all trailing dims are flattened.
            target (torch.Tensor): Ground-truth heatmaps laid out like
                ``output``.
            target_weight (torch.Tensor): Per-joint weights of shape
                ``[N, K]`` or ``[N, K, 1]``. Only read when
                ``use_target_weight`` is True.

        Returns:
            torch.Tensor: Scalar loss, scaled by ``loss_weight``.

        Raises:
            ValueError: If ``K`` is smaller than ``topk``.
        """
        batch_size = output.size(0)
        num_joints = output.size(1)
        if num_joints < self.topk:
            raise ValueError(f'topk ({self.topk}) should not '
                             f'larger than num_joints ({num_joints}).')

        heatmaps_pred = output.reshape((batch_size, num_joints, -1))
        heatmaps_gt = target.reshape((batch_size, num_joints, -1))

        if self.use_target_weight:
            # Explicit trailing axis: both [N, K] and [N, K, 1] broadcast
            # over the pixels of the matching joint.
            weight = target_weight.reshape((batch_size, num_joints, -1))
            heatmaps_pred = heatmaps_pred * weight
            heatmaps_gt = heatmaps_gt * weight

        # Mean over the flattened pixels -> one loss per (sample, joint).
        losses = self.criterion(heatmaps_pred, heatmaps_gt).mean(dim=2)

        return self._ohkm(losses) * self.loss_weight
|