import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class Loss(nn.Module):
    """Base class for weighted losses that pick their inputs out of ``**kwargs``.

    ``forward`` keeps only the keyword arguments whose names are listed in
    ``keys`` (optionally renaming others through ``mapping``), hands them to
    ``_forward`` and multiplies the result by ``loss_weight``.
    """

    def __init__(self, loss_weight, keys, mapping=None) -> None:
        """
        Args:
            loss_weight: factor the value returned by ``_forward`` is
                multiplied with.
            keys: names of the keyword arguments ``_forward`` accepts.
            mapping: optional dict ``{incoming_name: desired_name}`` used to
                rename keyword arguments. Entries whose desired name is not
                in ``keys`` are dropped.
        """
        super().__init__()
        self.loss_weight = loss_weight
        self.keys = keys
        self.mapping = mapping
        if isinstance(mapping, dict):
            # Iterate over items(): iterating the dict itself yields only the
            # keys, so unpacking into (k, v) would fail or mis-split them.
            self.mapping = {k: v for k, v in mapping.items() if v in keys}

    def forward(self, **kwargs):
        """Select/rename the inputs, evaluate ``_forward`` and apply the weight."""
        params = {k: v for k, v in kwargs.items() if k in self.keys}
        if self.mapping is not None:
            for k, v in kwargs.items():
                target = self.mapping.get(k)
                if target is not None:
                    # A mapped argument overrides a directly supplied one.
                    params[target] = v
        return self._forward(**params) * self.loss_weight

    def _forward(self, **kwargs):
        """Compute the unweighted loss; subclasses must override this."""
        raise NotImplementedError(
            f"{type(self).__name__} must implement _forward()")


class CharbonnierLoss(Loss):
    """Mean of ``sqrt(diff ** 2 + 1e-6)`` between prediction and target."""

    def __init__(self, loss_weight, keys) -> None:
        super().__init__(loss_weight, keys)

    def _forward(self, imgt_pred, imgt):
        diff = imgt_pred - imgt
        # The 1e-6 offset keeps the square root differentiable at diff == 0.
        loss = ((diff ** 2 + 1e-6) ** 0.5).mean()
        return loss


class AdaCharbonnierLoss(Loss):
    """Charbonnier-style loss whose exponent and epsilon depend on ``weight``."""

    def __init__(self, loss_weight, keys) -> None:
        super().__init__(loss_weight, keys)

    def _forward(self, imgt_pred, imgt, weight):
        """Return ``mean((diff ** 2 + epsilon ** 2) ** alpha)``.

        ``alpha = weight / 2`` and ``epsilon = 10 ** (-(10 * weight - 1) / 3)``.
        ``weight`` must broadcast against ``imgt_pred - imgt``.
        """
        alpha = weight / 2
        epsilon = 10 ** (-(10 * weight - 1) / 3)
        diff = imgt_pred - imgt
        loss = ((diff ** 2 + epsilon ** 2) ** alpha).mean()
        return loss


class TernaryLoss(Loss):
    """Loss on normalized local patch differences of channel-averaged images."""

    def __init__(self, loss_weight, keys, patch_size=7):
        super().__init__(loss_weight, keys)
        self.patch_size = patch_size
        out_channels = patch_size * patch_size
        # One-hot kernels of shape (patch_size**2, 1, patch_size, patch_size):
        # output channel i copies the i-th pixel of the surrounding patch.
        self.w = np.eye(out_channels).reshape((patch_size, patch_size, 1, out_channels))
        self.w = np.transpose(self.w, (3, 2, 0, 1))
        self.w = torch.tensor(self.w, dtype=torch.float32)

    def transform(self, tensor):
        """Return soft-normalized differences between each pixel and its patch.

        The input is averaged over dim 1 first; the result has
        ``patch_size ** 2`` channels.
        """
        # self.w is a plain attribute (not a registered buffer), so it is
        # moved to the input's device on every call.
        self.w = self.w.to(tensor.device)
        tensor_ = tensor.mean(dim=1, keepdim=True)
        patches = F.conv2d(tensor_, self.w, padding=self.patch_size // 2, bias=None)
        loc_diff = patches - tensor_
        loc_diff_norm = loc_diff / torch.sqrt(0.81 + loc_diff ** 2)
        return loc_diff_norm

    def valid_mask(self, tensor):
        """Return a (b, 1, h, w) mask: 1 in the interior, 0 in a border of
        ``patch_size // 2`` pixels (where the convolution saw zero padding)."""
        padding = self.patch_size // 2
        b, c, h, w = tensor.size()
        inner = torch.ones(b, 1, h - 2 * padding, w - 2 * padding).type_as(tensor)
        mask = F.pad(inner, [padding] * 4)
        return mask

    def _forward(self, imgt_pred, imgt):
        loc_diff_x = self.transform(imgt_pred)
        loc_diff_y = self.transform(imgt)
        # No gradient flows through the target branch.
        diff = loc_diff_x - loc_diff_y.detach()
        dist = (diff ** 2 / (0.1 + diff ** 2)).mean(dim=1, keepdim=True)
        mask = self.valid_mask(imgt_pred)
        loss = (dist * mask).mean()
        return loss


class GeometryLoss(Loss):
    """Patch-difference loss applied per channel and summed over feature pairs."""

    def __init__(self, loss_weight, keys, patch_size=3):
        super().__init__(loss_weight, keys)
        self.patch_size = patch_size
        out_channels = patch_size * patch_size
        # One-hot kernels of shape (patch_size**2, 1, patch_size, patch_size).
        self.w = np.eye(out_channels).reshape((patch_size, patch_size, 1, out_channels))
        self.w = np.transpose(self.w, (3, 2, 0, 1))
        self.w = torch.tensor(self.w).float()

    def transform(self, tensor):
        """Return soft-normalized patch differences for every channel.

        Input ``(b, c, h, w)`` -> output ``(b, c * patch_size ** 2, h, w)``.
        """
        b, c, h, w = tensor.size()
        self.w = self.w.to(tensor.device)
        # Fold channels into the batch so each channel is processed alone.
        tensor_ = tensor.reshape(b * c, 1, h, w)
        patches = F.conv2d(tensor_, self.w, padding=self.patch_size // 2, bias=None)
        loc_diff = patches - tensor_
        loc_diff_ = loc_diff.reshape(b, c * (self.patch_size ** 2), h, w)
        loc_diff_norm = loc_diff_ / torch.sqrt(0.81 + loc_diff_ ** 2)
        return loc_diff_norm

    def valid_mask(self, tensor):
        """Return a (b, 1, h, w) mask: 1 in the interior, 0 in a border of
        ``patch_size // 2`` pixels."""
        padding = self.patch_size // 2
        b, c, h, w = tensor.size()
        inner = torch.ones(b, 1, h - 2 * padding, w - 2 * padding).type_as(tensor)
        mask = F.pad(inner, [padding] * 4)
        return mask

    def _forward(self, ft_pred, ft_gt):
        """Sum the masked patch-difference loss over paired 4-D tensors.

        ``ft_pred`` and ``ft_gt`` are iterated in lockstep with ``zip``.
        Unlike ``TernaryLoss``, the ``ft_gt`` branch is not detached.
        """
        loss = 0.
        for pred, gt in zip(ft_pred, ft_gt):
            loc_diff_x = self.transform(pred)
            loc_diff_y = self.transform(gt)
            diff = loc_diff_x - loc_diff_y
            dist = (diff ** 2 / (0.1 + diff ** 2)).mean(dim=1, keepdim=True)
            mask = self.valid_mask(pred)
            loss = loss + (dist * mask).mean()
        return loss


class IFRFlowLoss(Loss):
    """Multi-level flow loss weighted by the level-0 end-point error."""

    def __init__(self, loss_weight, keys, beta=0.3) -> None:
        super().__init__(loss_weight, keys)
        self.beta = beta
        self.ada_cb_loss = AdaCharbonnierLoss(1.0, ['imgt_pred', 'imgt', 'weight'])

    def _forward(self, flow0_pred, flow1_pred, flow):
        """
        Args:
            flow0_pred, flow1_pred: sequences of flow predictions indexed by
                level. Level 0 is only used to derive the robust weights;
                level ``lvl >= 1`` is upsampled by ``2 ** lvl`` before being
                compared (presumably level 0 is at the resolution of
                ``flow`` - confirm against the caller).
            flow: target tensor; channels 0:2 are compared with
                ``flow0_pred`` and channels 2:4 with ``flow1_pred``.
        """
        robust_weight0 = self.get_robust_weight(flow0_pred[0], flow[:, 0:2])
        robust_weight1 = self.get_robust_weight(flow1_pred[0], flow[:, 2:4])
        loss = 0
        for lvl in range(1, len(flow0_pred)):
            scale_factor = 2 ** lvl
            loss = loss + self.ada_cb_loss(**{
                'imgt_pred': self.resize(flow0_pred[lvl], scale_factor),
                'imgt': flow[:, 0:2],
                'weight': robust_weight0
            })
            loss = loss + self.ada_cb_loss(**{
                'imgt_pred': self.resize(flow1_pred[lvl], scale_factor),
                'imgt': flow[:, 2:4],
                'weight': robust_weight1
            })
        return loss

    def resize(self, x, scale_factor):
        """Bilinearly upsample ``x`` and scale its values by the same factor."""
        return scale_factor * F.interpolate(
            x, scale_factor=scale_factor, mode="bilinear", align_corners=False)

    def get_robust_weight(self, flow_pred, flow_gt):
        """Return ``exp(-beta * epe)``, where ``epe`` is the L2 norm over dim 1
        of ``flow_pred - flow_gt``. The prediction is detached."""
        epe = ((flow_pred.detach() - flow_gt) ** 2).sum(dim=1, keepdim=True) ** 0.5
        robust_weight = torch.exp(-self.beta * epe)
        return robust_weight


class MultipleFlowLoss(Loss):
    """Variant of ``IFRFlowLoss`` whose level-0 prediction holds several flows."""

    def __init__(self, loss_weight, keys, beta=0.3) -> None:
        super().__init__(loss_weight, keys)
        self.beta = beta
        self.ada_cb_loss = AdaCharbonnierLoss(1.0, ['imgt_pred', 'imgt', 'weight'])

    def _forward(self, flow0_pred, flow1_pred, flow):
        """
        Args:
            flow0_pred, flow1_pred: sequences of flow predictions indexed by
                level. Element 0 must be 5-D ``(b, num_flows, c, h, w)`` and
                is only used for the robust weights; levels ``>= 1`` are
                passed to bilinear ``F.interpolate`` (assumed 4-D - confirm
                against the caller) and upsampled by ``2 ** lvl``.
            flow: target tensor; channels 0:2 are compared with
                ``flow0_pred`` and channels 2:4 with ``flow1_pred``.
        """
        robust_weight0 = self.get_mutli_flow_robust_weight(flow0_pred[0], flow[:, 0:2])
        robust_weight1 = self.get_mutli_flow_robust_weight(flow1_pred[0], flow[:, 2:4])
        loss = 0
        for lvl in range(1, len(flow0_pred)):
            scale_factor = 2 ** lvl
            loss = loss + self.ada_cb_loss(**{
                'imgt_pred': self.resize(flow0_pred[lvl], scale_factor),
                'imgt': flow[:, 0:2],
                'weight': robust_weight0
            })
            loss = loss + self.ada_cb_loss(**{
                'imgt_pred': self.resize(flow1_pred[lvl], scale_factor),
                'imgt': flow[:, 2:4],
                'weight': robust_weight1
            })
        return loss

    def resize(self, x, scale_factor):
        """Bilinearly upsample ``x`` and scale its values by the same factor."""
        return scale_factor * F.interpolate(
            x, scale_factor=scale_factor, mode="bilinear", align_corners=False)

    def get_mutli_flow_robust_weight(self, flow_pred, flow_gt):
        """Return ``exp(-beta * epe)`` using the worst flow at each pixel.

        ``flow_pred`` is ``(b, num_flows, c, h, w)``; ``flow_gt`` is tiled
        ``num_flows`` times along dim 1 and viewed to the same shape. The
        squared error is summed over dim 2, the maximum over the
        ``num_flows`` axis is taken and then square-rooted, giving a weight
        of shape ``(b, 1, h, w)``. The prediction is detached.
        """
        b, num_flows, c, h, w = flow_pred.shape
        flow_gt = flow_gt.repeat(1, num_flows, 1, 1).view(b, num_flows, c, h, w)
        epe = ((flow_pred.detach() - flow_gt) ** 2).sum(dim=2, keepdim=True).max(1)[0] ** 0.5
        robust_weight = torch.exp(-self.beta * epe)
        return robust_weight