# Contextual / Contextual-Bilateral loss implementations.
# Provenance: Hugging Face blob "losses" by feng2022, revision 4bf9ab0 (~6 kB).
import torch
import torch.nn.functional as F
from .config import LOSS_TYPES
__all__ = ['contextual_loss', 'contextual_bilateral_loss']
def contextual_loss(x: torch.Tensor,
                    y: torch.Tensor,
                    band_width: float = 0.5,
                    loss_type: str = 'cosine',
                    all_dist: bool = False):
    """
    Computes contextual loss between x and y.
    The most of this code is copied from
    https://gist.github.com/yunjey/3105146c736f9c1055463c33b4c989da.

    Parameters
    ---
    x : torch.Tensor
        features of shape (N, C, H, W).
    y : torch.Tensor
        features of shape (N, C, H, W).
    band_width : float, optional
        a band-width parameter used to convert distance to similarity.
        in the paper, this is described as :math:`h`.
    loss_type : str, optional
        a loss type to measure the distance between features.
        Note: `l1` and `l2` frequently raises OOM.
    all_dist : bool, optional
        if True, return the full contextual-similarity map of shape
        (N, H*W, H*W) instead of reducing it to a scalar loss.

    Returns
    ---
    cx_loss : torch.Tensor
        contextual loss between x and y (Eq (1) in the paper), or the
        full similarity map when ``all_dist`` is True.
    """
    assert x.size() == y.size(), 'input tensor must have the same size.'
    assert loss_type in LOSS_TYPES, f'select a loss type from {LOSS_TYPES}.'

    if loss_type == 'cosine':
        dist_raw = compute_cosine_distance(x, y)
    elif loss_type == 'l1':
        dist_raw = compute_l1_distance(x, y)
    elif loss_type == 'l2':
        dist_raw = compute_l2_distance(x, y)

    dist_tilde = compute_relative_distance(dist_raw)
    cx = compute_cx(dist_tilde, band_width)
    if all_dist:
        return cx

    # max over dim=1 picks the best-matching source feature for each
    # target feature, then the mean aggregates over targets (Eq (1)).
    cx = torch.mean(torch.max(cx, dim=1)[0], dim=1)  # Eq(1)
    # 1e-5 guards against log(0)
    cx_loss = torch.mean(-torch.log(cx + 1e-5))  # Eq(5)
    return cx_loss
# TODO: Operation check
def contextual_bilateral_loss(x: torch.Tensor,
                              y: torch.Tensor,
                              weight_sp: float = 0.1,
                              band_width: float = 1.,
                              loss_type: str = 'cosine'):
    """
    Computes Contextual Bilateral (CoBi) Loss between x and y,
    proposed in https://arxiv.org/pdf/1905.05169.pdf.

    Parameters
    ---
    x : torch.Tensor
        features of shape (N, C, H, W).
    y : torch.Tensor
        features of shape (N, C, H, W).
    weight_sp : float, optional
        weight of the spatial similarity term when it is mixed with the
        feature similarity term (``w_s`` in the paper).
    band_width : float, optional
        a band-width parameter used to convert distance to similarity.
        in the paper, this is described as :math:`h`.
    loss_type : str, optional
        a loss type to measure the distance between features.
        Note: `l1` and `l2` frequently raises OOM.

    Returns
    ---
    cx_loss : torch.Tensor
        scalar CoBi loss between x and y.
    """
    assert x.size() == y.size(), 'input tensor must have the same size.'
    assert loss_type in LOSS_TYPES, f'select a loss type from {LOSS_TYPES}.'

    # spatial term: similarity between pixel coordinates only
    grid = compute_meshgrid(x.shape).to(x.device)
    dist_raw = compute_l2_distance(grid, grid)
    dist_tilde = compute_relative_distance(dist_raw)
    cx_sp = compute_cx(dist_tilde, band_width)

    # feature term
    if loss_type == 'cosine':
        dist_raw = compute_cosine_distance(x, y)
    elif loss_type == 'l1':
        dist_raw = compute_l1_distance(x, y)
    elif loss_type == 'l2':
        dist_raw = compute_l2_distance(x, y)

    dist_tilde = compute_relative_distance(dist_raw)
    cx_feat = compute_cx(dist_tilde, band_width)

    # combined loss: convex mix of feature and spatial similarity maps
    cx_combine = (1. - weight_sp) * cx_feat + weight_sp * cx_sp

    k_max_NC, _ = torch.max(cx_combine, dim=2, keepdim=True)
    cx = k_max_NC.mean(dim=1)
    # 1e-5 guards against log(0)
    cx_loss = torch.mean(-torch.log(cx + 1e-5))
    return cx_loss
def compute_cx(dist_tilde, band_width):
    """Convert relative distances into row-normalized contextual similarity.

    Applies the exponential kernel of Eq (3) and normalizes each row over
    its last dimension as in Eq (4).
    """
    similarity = torch.exp((1 - dist_tilde) / band_width)      # Eq(3)
    return similarity / similarity.sum(dim=2, keepdim=True)    # Eq(4)
def compute_relative_distance(dist_raw):
    """Normalize each row of raw distances by that row's minimum (Eq (2)).

    The 1e-5 term keeps a zero row-minimum from causing division by zero.
    """
    row_min = dist_raw.min(dim=2, keepdim=True)[0]
    return dist_raw / (row_min + 1e-5)
def compute_cosine_distance(x, y):
    """Pairwise cosine distance between the spatial features of x and y.

    Both inputs are (N, C, H, W). The result is (N, H*W, H*W) where entry
    (i, j) equals 1 - cosine_similarity(x_i, y_j) after both inputs have
    been mean-shifted by the channel-wise mean of ``y``.
    """
    batch, channels = x.size(0), x.size(1)

    # shift both feature maps by the channel-wise mean of `y`
    mu = y.mean(dim=(0, 2, 3), keepdim=True)

    # L2-normalize along channels, then flatten the spatial dims
    x_flat = F.normalize(x - mu, p=2, dim=1).reshape(batch, channels, -1)
    y_flat = F.normalize(y - mu, p=2, dim=1).reshape(batch, channels, -1)

    # (N, H*W, H*W) cosine similarities between every x/y feature pair
    sim = torch.bmm(x_flat.transpose(1, 2), y_flat)

    # cosine distance = 1 - cosine similarity
    return 1 - sim
# TODO: Considering avoiding OOM.
def compute_l1_distance(x: torch.Tensor, y: torch.Tensor):
    """Pairwise L1 distance between the spatial features of x and y.

    Both inputs are (N, C, H, W); the result is (N, H*W, H*W) where entry
    (i, j) is sum_c |x[:, c, i] - y[:, c, j]|. Materializes an
    (N, C, H*W, H*W) intermediate, so large inputs may OOM.
    """
    N, C, H, W = x.size()
    hw = H * W
    x_flat = x.view(N, C, hw)
    y_flat = y.view(N, C, hw)

    # diff[n, c, i, j] = x_flat[n, c, j] - y_flat[n, c, i]
    diff = x_flat.unsqueeze(2) - y_flat.unsqueeze(3)
    out = diff.abs().sum(dim=1)                   # (N, HW, HW), rows index y
    out = out.transpose(1, 2).reshape(N, hw, hw)  # rows index x
    return out.clamp(min=0.)
# TODO: Considering avoiding OOM.
def compute_l2_distance(x, y):
    """Pairwise squared L2 distance between the spatial features of x and y.

    Parameters
    ---
    x, y : torch.Tensor
        features of shape (N, C, H, W).

    Returns
    ---
    dist : torch.Tensor
        (N, H*W, H*W) tensor with dist[n, i, j] = ||x_i - y_j||^2, matching
        the row=x / column=y convention of the l1 and cosine helpers.

    Fixes the original implementation, which paired ||y_i||^2 + ||x_j||^2
    with the cross term x_i . y_j (wrong combination whenever x != y) and
    whose (N, HW) broadcast against (N, HW, HW) only worked for N == 1.
    For x == y (e.g. the spatial grid in CoBi) the result is unchanged.
    """
    N, C, H, W = x.size()
    x_vec = x.view(N, C, -1)
    y_vec = y.view(N, C, -1)

    x_s = torch.sum(x_vec ** 2, dim=1)   # (N, HW): ||x_i||^2
    y_s = torch.sum(y_vec ** 2, dim=1)   # (N, HW): ||y_j||^2
    A = x_vec.transpose(1, 2) @ y_vec    # (N, HW, HW): x_i . y_j

    # ||x_i - y_j||^2 = ||x_i||^2 + ||y_j||^2 - 2 x_i . y_j
    dist = x_s.unsqueeze(2) + y_s.unsqueeze(1) - 2 * A
    # clamp guards tiny negatives from floating-point cancellation
    dist = dist.clamp(min=0.)
    return dist
def compute_meshgrid(shape):
    """Build an (N, 2, H, W) grid of normalized (row, col) coordinates.

    Channel 0 holds row coordinates in [0, (H-1)/(H+1)] and channel 1 holds
    column coordinates in [0, (W-1)/(W+1)]; the single grid is repeated
    across the batch dimension.
    """
    N, C, H, W = shape
    row_coords = torch.arange(0, H, dtype=torch.float32) / (H + 1)
    col_coords = torch.arange(0, W, dtype=torch.float32) / (W + 1)
    rr, cc = torch.meshgrid(row_coords, col_coords)
    grid = torch.stack((rr, cc)).unsqueeze(0)  # (1, 2, H, W)
    return grid.repeat(N, 1, 1, 1)