|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
class ReConsLoss(nn.Module): |
|
def __init__(self, recons_loss, nb_joints): |
|
super(ReConsLoss, self).__init__() |
|
|
|
if recons_loss == 'l1': |
|
self.Loss = torch.nn.L1Loss() |
|
elif recons_loss == 'l2' : |
|
self.Loss = torch.nn.MSELoss() |
|
elif recons_loss == 'l1_smooth' : |
|
self.Loss = torch.nn.SmoothL1Loss() |
|
|
|
|
|
|
|
|
|
|
|
self.nb_joints = nb_joints |
|
self.motion_dim = (nb_joints - 1) * 12 + 4 + 3 + 4 |
|
|
|
def forward(self, motion_pred, motion_gt) : |
|
loss = self.Loss(motion_pred[..., : self.motion_dim], motion_gt[..., :self.motion_dim]) |
|
return loss |
|
|
|
def forward_vel(self, motion_pred, motion_gt) : |
|
loss = self.Loss(motion_pred[..., 4 : (self.nb_joints - 1) * 3 + 4], motion_gt[..., 4 : (self.nb_joints - 1) * 3 + 4]) |
|
return loss |
|
|
|
def loss_robust(origin_prediction, perturbation_prediction, loss='l2'):
    """Measure the discrepancy between an original and a perturbed prediction.

    Args:
        origin_prediction: Tensor predicted from the unperturbed input.
        perturbation_prediction: Tensor predicted from the perturbed input.
        loss: Name of the discrepancy measure: ``'l1'``, ``'l2'``,
            ``'l1_smooth'`` or ``'jsd'``.

    Returns:
        A scalar tensor. The first three options use the functional PyTorch
        losses with their default (mean) reduction; ``'jsd'`` delegates to
        ``calculate_jsd_loss`` from this module.

    Raises:
        ValueError: If ``loss`` is not one of the supported names.
    """
    # The functional forms are required here: ``torch.nn.L1Loss(a, b)`` and
    # friends would call the module *constructor* with the tensors as
    # configuration arguments instead of computing a loss.
    if loss == 'l1':
        return F.l1_loss(origin_prediction, perturbation_prediction)
    if loss == 'l2':
        return F.mse_loss(origin_prediction, perturbation_prediction)
    if loss == 'l1_smooth':
        return F.smooth_l1_loss(origin_prediction, perturbation_prediction)
    if loss == 'jsd':
        # 'jsd' is the Jensen-Shannon option, so use the symmetric JSD helper
        # rather than the raw one-directional KL term.
        return calculate_jsd_loss(origin_prediction, perturbation_prediction)
    raise ValueError(
        "Unsupported loss {!r}; expected 'l1', 'l2', 'l1_smooth' or "
        "'jsd'".format(loss)
    )
|
|
|
|
|
def calculate_kl_divergence(p, q):
    """Return ``sum(p * (log2(p) - log2(q)))`` over all elements, in bits.

    Entries where ``p == 0`` contribute exactly zero (the usual
    ``0 * log 0 = 0`` convention) instead of turning the whole sum into NaN.
    The inputs are used as given; no normalisation is applied here.

    Args:
        p: Tensor of weights for the first distribution.
        q: Tensor of weights for the second distribution, broadcastable
            with ``p``.

    Returns:
        A scalar tensor holding the summed divergence terms.
    """
    support = p != 0
    # Swap masked-out entries for 1 *before* taking the logarithm so that
    # neither the forward value nor the gradient ever sees log2(0) at
    # positions that must contribute nothing.
    p_safe = torch.where(support, p, torch.ones_like(p))
    q_safe = torch.where(support, q, torch.ones_like(q))
    return torch.sum(p * (torch.log2(p_safe) - torch.log2(q_safe)))
|
|
|
def calculate_jsd_loss(p, q):
    """Return the Jensen-Shannon style divergence between ``p`` and ``q``.

    Both inputs are first scaled by their L1 norm along the last dimension.
    The result is half the sum of the two ``calculate_kl_divergence`` terms,
    each measured against the midpoint of the scaled inputs.

    NOTE(review): the scaled inputs are proper probability vectors only if
    ``p`` and ``q`` are non-negative -- confirm at the call sites.
    """
    p_dist, q_dist = (F.normalize(t, p=1, dim=-1) for t in (p, q))

    # Midpoint that both divergence terms are measured against.
    mixture = 0.5 * (p_dist + q_dist)

    divergence_sum = (
        calculate_kl_divergence(p_dist, mixture)
        + calculate_kl_divergence(q_dist, mixture)
    )
    return 0.5 * divergence_sum