import torch
import torch.nn.functional as F

from .config import LOSS_TYPES

__all__ = ['contextual_loss', 'contextual_bilateral_loss']


def contextual_loss(x: torch.Tensor,
                    y: torch.Tensor,
                    band_width: float = 0.5,
                    loss_type: str = 'cosine',
                    all_dist: bool = False):
    """
    Computes contextual loss between x and y.
    Most of this code is copied from
        https://gist.github.com/yunjey/3105146c736f9c1055463c33b4c989da.

    Parameters
    ---
    x : torch.Tensor
        features of shape (N, C, H, W).
    y : torch.Tensor
        features of shape (N, C, H, W).
    band_width : float, optional
        a band-width parameter used to convert distance to similarity.
        in the paper, this is described as :math:`h`.
    loss_type : str, optional
        a loss type to measure the distance between features.
        Note: `l1` and `l2` frequently raise OOM.
    all_dist : bool, optional
        if True, skip the reduction and return the full contextual
        similarity matrix instead of the scalar loss.

    Returns
    ---
    cx_loss : torch.Tensor
        contextual loss between x and y (Eq (1) in the paper), a scalar.
        When ``all_dist`` is True, the similarity matrix ``CX_ij`` of shape
        (N, H*W, H*W) is returned instead (rows index x, columns index y).
    """
    assert x.size() == y.size(), 'input tensor must have the same size.'
    assert loss_type in LOSS_TYPES, f'select a loss type from {LOSS_TYPES}.'

    dist_raw = _compute_feature_distance(x, y, loss_type)  # (N, HW_x, HW_y)
    dist_tilde = compute_relative_distance(dist_raw)
    cx = compute_cx(dist_tilde, band_width)
    if all_dist:
        return cx

    # For each y_j keep its best-matching x_i, then average over j.
    cx = torch.mean(torch.max(cx, dim=1)[0], dim=1)  # Eq(1)
    cx_loss = torch.mean(-torch.log(cx + 1e-5))  # Eq(5)
    return cx_loss


# TODO: Operation check
def contextual_bilateral_loss(x: torch.Tensor,
                              y: torch.Tensor,
                              weight_sp: float = 0.1,
                              band_width: float = 1.,
                              loss_type: str = 'cosine'):
    """
    Computes Contextual Bilateral (CoBi) Loss between x and y,
    proposed in https://arxiv.org/pdf/1905.05169.pdf.

    Parameters
    ---
    x : torch.Tensor
        features of shape (N, C, H, W).
    y : torch.Tensor
        features of shape (N, C, H, W).
    weight_sp : float, optional
        weight of the spatial similarity term; the feature similarity is
        weighted by ``1 - weight_sp``.
    band_width : float, optional
        a band-width parameter used to convert distance to similarity.
        in the paper, this is described as :math:`h`.
    loss_type : str, optional
        a loss type to measure the distance between features.
        Note: `l1` and `l2` frequently raise OOM.

    Returns
    ---
    cx_loss : torch.Tensor
        CoBi loss between x and y, a scalar.
    """
    assert x.size() == y.size(), 'input tensor must have the same size.'
    assert loss_type in LOSS_TYPES, f'select a loss type from {LOSS_TYPES}.'

    # spatial similarity between pixel coordinates (same grid for x and y)
    grid = compute_meshgrid(x.shape).to(x.device)
    dist_raw = compute_l2_distance(grid, grid)
    dist_tilde = compute_relative_distance(dist_raw)
    # NOTE(review): the grid distance is 0 on the diagonal, so the relative
    # distance divides by 1e-5 and cx_sp appears to be close to an identity
    # matrix — confirm this is the intended spatial term.
    cx_sp = compute_cx(dist_tilde, band_width)

    # feature similarity
    dist_raw = _compute_feature_distance(x, y, loss_type)
    dist_tilde = compute_relative_distance(dist_raw)
    cx_feat = compute_cx(dist_tilde, band_width)

    # combined similarity; for each x_i keep its best-matching y_j
    cx_combine = (1. - weight_sp) * cx_feat + weight_sp * cx_sp
    k_max_NC, _ = torch.max(cx_combine, dim=2, keepdim=True)
    cx = k_max_NC.mean(dim=1)
    cx_loss = torch.mean(-torch.log(cx + 1e-5))
    return cx_loss


def _compute_feature_distance(x: torch.Tensor,
                              y: torch.Tensor,
                              loss_type: str) -> torch.Tensor:
    """Dispatch to the pairwise distance selected by ``loss_type``.

    Returns a tensor of shape (N, H*W, H*W) whose [n, i, j] entry is the
    distance between position i of x and position j of y.
    """
    if loss_type == 'cosine':
        return compute_cosine_distance(x, y)
    if loss_type == 'l1':
        return compute_l1_distance(x, y)
    if loss_type == 'l2':
        return compute_l2_distance(x, y)
    raise ValueError(f'unsupported loss type: {loss_type!r}.')


def compute_cx(dist_tilde: torch.Tensor, band_width: float) -> torch.Tensor:
    """Convert relative distances to normalised similarities (Eq(3), Eq(4)).

    ``w = exp((1 - d~) / h)`` normalised over the last dim. This is exactly a
    softmax, which subtracts the row max before exponentiating. That avoids
    overflow to inf (and NaN after normalising) for small ``band_width`` or
    reduced-precision inputs.
    """
    return torch.softmax((1 - dist_tilde) / band_width, dim=2)


def compute_relative_distance(dist_raw: torch.Tensor) -> torch.Tensor:
    """Normalise each row of distances by its minimum (Eq(2))."""
    dist_min, _ = torch.min(dist_raw, dim=2, keepdim=True)
    dist_tilde = dist_raw / (dist_min + 1e-5)
    return dist_tilde


def compute_cosine_distance(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Pairwise cosine distance ``1 - cos`` between positions of x and y."""
    # mean shifting by channel-wise mean of `y`.
    y_mu = y.mean(dim=(0, 2, 3), keepdim=True)
    x_centered = x - y_mu
    y_centered = y - y_mu

    # L2 normalization
    x_normalized = F.normalize(x_centered, p=2, dim=1)
    y_normalized = F.normalize(y_centered, p=2, dim=1)

    # channel-wise vectorization
    N, C, *_ = x.size()
    x_normalized = x_normalized.reshape(N, C, -1)  # (N, C, H*W)
    y_normalized = y_normalized.reshape(N, C, -1)  # (N, C, H*W)

    # cosine similarity
    cosine_sim = torch.bmm(x_normalized.transpose(1, 2),
                           y_normalized)  # (N, H*W, H*W)

    # convert to distance
    dist = 1 - cosine_sim

    return dist


# TODO: Consider avoiding OOM.
def compute_l1_distance(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Pairwise L1 distance; result[n, i, j] = sum_c |x[n, c, i] - y[n, c, j]|."""
    N, C, H, W = x.size()
    x_vec = x.reshape(N, C, -1)
    y_vec = y.reshape(N, C, -1)

    # (N, C, H*W, 1) - (N, C, 1, H*W) -> (N, C, H*W, H*W); this C-sized
    # intermediate is what makes this loss type memory hungry.
    dist = x_vec.unsqueeze(3) - y_vec.unsqueeze(2)
    dist = dist.abs().sum(dim=1)  # (N, H*W, H*W)

    return dist


# TODO: Consider avoiding OOM.
def compute_l2_distance(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Pairwise squared L2 distance; result[n, i, j] = |x[n, :, i] - y[n, :, j]|^2."""
    N, C, H, W = x.size()
    x_vec = x.reshape(N, C, -1)  # (N, C, H*W)
    y_vec = y.reshape(N, C, -1)  # (N, C, H*W)
    x_s = torch.sum(x_vec ** 2, dim=1)  # (N, H*W)
    y_s = torch.sum(y_vec ** 2, dim=1)  # (N, H*W)

    # |x_i - y_j|^2 = |x_i|^2 - 2 <x_i, y_j> + |y_j|^2, expanded per batch.
    A = torch.bmm(x_vec.transpose(1, 2), y_vec)  # (N, H*W, H*W)
    dist = x_s.unsqueeze(2) - 2 * A + y_s.unsqueeze(1)
    # floating-point cancellation can yield tiny negatives
    dist = dist.clamp(min=0.)

    return dist


def compute_meshgrid(shape):
    """Normalised (row, col) coordinate grid of shape (N, 2, H, W)."""
    N, C, H, W = shape
    rows = torch.arange(0, H, dtype=torch.float32) / (H + 1)
    cols = torch.arange(0, W, dtype=torch.float32) / (W + 1)

    # Same values as torch.meshgrid(rows, cols) with 'ij' indexing.
    grid_rows = rows.view(H, 1).expand(H, W)
    grid_cols = cols.view(1, W).expand(H, W)
    feature_grid = torch.stack([grid_rows, grid_cols]).unsqueeze(0)  # (1, 2, H, W)
    return feature_grid.repeat(N, 1, 1, 1)