import torch
import torch.nn as nn


def _make_elementwise_loss(loss_type: str) -> nn.Module:
    """
    Build an unreduced (element-wise) regression loss.
    Args:
        loss_type (str): 'l1' for nn.L1Loss or 'l2' for nn.MSELoss.
    Returns:
        nn.Module: Loss module constructed with reduction='none'.
    Raises:
        NotImplementedError: If loss_type is neither 'l1' nor 'l2'.
    """
    if loss_type == 'l1':
        return nn.L1Loss(reduction='none')
    if loss_type == 'l2':
        return nn.MSELoss(reduction='none')
    raise NotImplementedError('Unsupported loss function')


class Keypoint2DLoss(nn.Module):

    def __init__(self, loss_type: str = 'l1'):
        """
        2D keypoint loss module.
        Args:
            loss_type (str): Choose between l1 and l2 losses.
        Raises:
            NotImplementedError: If loss_type is neither 'l1' nor 'l2'.
        """
        super().__init__()
        self.loss_fn = _make_elementwise_loss(loss_type)

    def forward(self, pred_keypoints_2d: torch.Tensor, gt_keypoints_2d: torch.Tensor) -> torch.Tensor:
        """
        Compute the confidence-weighted 2D reprojection loss on the keypoints.
        Args:
            pred_keypoints_2d (torch.Tensor): Tensor of shape [..., N, 2] (e.g. [B, N, 2] or
                [B, S, N, 2]) containing projected 2D keypoints
                (B: batch_size, S: num_samples, N: num_keypoints).
            gt_keypoints_2d (torch.Tensor): Tensor of shape [..., N, 3] with the same leading
                dimensions, containing the ground truth 2D keypoints and confidence
                (confidence is the last channel).
        Returns:
            torch.Tensor: Scalar 2D keypoint loss, summed over all elements.
        """
        # Slice with `-1:` so the confidence keeps a trailing singleton dim and broadcasts
        # over the (x, y) coordinates. Ellipsis indexing works with or without the S dim.
        conf = gt_keypoints_2d[..., -1:]
        loss = conf * self.loss_fn(pred_keypoints_2d, gt_keypoints_2d[..., :-1])
        return loss.sum()


class Keypoint3DLoss(nn.Module):

    def __init__(self, loss_type: str = 'l1'):
        """
        3D keypoint loss module.
        Args:
            loss_type (str): Choose between l1 and l2 losses.
        Raises:
            NotImplementedError: If loss_type is neither 'l1' nor 'l2'.
        """
        super().__init__()
        self.loss_fn = _make_elementwise_loss(loss_type)

    def forward(self, pred_keypoints_3d: torch.Tensor, gt_keypoints_3d: torch.Tensor,
                pelvis_id: int = 39) -> torch.Tensor:
        """
        Compute the confidence-weighted, root-relative 3D keypoint loss.
        Args:
            pred_keypoints_3d (torch.Tensor): Tensor of shape [..., N, 3] (e.g. [B, N, 3] or
                [B, S, N, 3]) containing the predicted 3D keypoints
                (B: batch_size, S: num_samples, N: num_keypoints).
            gt_keypoints_3d (torch.Tensor): Tensor of shape [..., N, 4] with the same leading
                dimensions, containing the ground truth 3D keypoints and confidence
                (confidence is the last channel). Not modified.
            pelvis_id (int): Index of the keypoint subtracted from every keypoint in both
                prediction and ground truth before comparing them.
        Returns:
            torch.Tensor: Scalar 3D keypoint loss, summed over all elements.
        """
        gt_xyz = gt_keypoints_3d[..., :-1]
        conf = gt_keypoints_3d[..., -1:]
        # Center both skeletons on the root keypoint so a global translation offset does not
        # contribute to the loss. unsqueeze(-2) restores the keypoint dim for broadcasting.
        pred_rel = pred_keypoints_3d - pred_keypoints_3d[..., pelvis_id, :].unsqueeze(-2)
        gt_rel = gt_xyz - gt_xyz[..., pelvis_id, :].unsqueeze(-2)
        loss = conf * self.loss_fn(pred_rel, gt_rel)
        return loss.sum()


class ParameterLoss(nn.Module):

    def __init__(self):
        """
        SMPL parameter loss module.
        """
        super().__init__()
        self.loss_fn = nn.MSELoss(reduction='none')

    def forward(self, pred_param: torch.Tensor, gt_param: torch.Tensor, has_param: torch.Tensor) -> torch.Tensor:
        """
        Compute the masked SMPL parameter loss.
        Args:
            pred_param (torch.Tensor): Tensor of shape [B, ...] containing the predicted
                parameters (body pose / global orientation / betas).
            gt_param (torch.Tensor): Tensor with the same shape as pred_param containing the
                ground truth SMPL parameters.
            has_param (torch.Tensor): Tensor with B elements (any numeric or bool dtype), used
                as a per-sample weight on the loss; 0 / False excludes that sample.
        Returns:
            torch.Tensor: Scalar L2 parameter loss, summed over all elements.
        """
        batch_size = pred_param.shape[0]
        # Shape [B, 1, ..., 1] so the per-sample weight broadcasts over the parameter dims.
        mask_shape = (batch_size,) + (1,) * (pred_param.dim() - 1)
        # Match both dtype and the exact device of pred_param (the legacy `.type(str)` form
        # drops the device index and lands on the current CUDA device instead).
        mask = has_param.to(device=pred_param.device, dtype=pred_param.dtype).reshape(mask_shape)
        return (mask * self.loss_fn(pred_param, gt_param)).sum()