import torch
import torch.nn as nn

from .vgg import VGG19
from .. import functional as F
from ..config import LOSS_TYPES


class ContextualBilateralLoss(nn.Module):
    """
    Creates a criterion that measures the contextual bilateral loss.

    Parameters
    ---
    weight_sp : float, optional
        a balancing weight between spatial and feature loss.
    band_width : float, optional
        a band_width parameter described as :math:`h` in the paper.
        Must be positive.
    loss_type : str, optional
        name of the feature distance to use; must be one of `LOSS_TYPES`.
    use_vgg : bool, optional
        if you want to use VGG feature, set this `True`.
    vgg_layer : str, optional
        intermediate layer name for VGG feature.
        Now we support layer names:
            `['relu1_2', 'relu2_2', 'relu3_4', 'relu4_4', 'relu5_4']`

    Raises
    ---
    AssertionError
        if `band_width` is not positive or `loss_type` is not in
        `LOSS_TYPES`.
    """

    def __init__(self,
                 weight_sp: float = 0.1,
                 band_width: float = 0.5,
                 loss_type: str = 'cosine',
                 use_vgg: bool = False,
                 vgg_layer: str = 'relu3_4'):
        super(ContextualBilateralLoss, self).__init__()

        assert band_width > 0, 'band_width parameter must be positive.'
        assert loss_type in LOSS_TYPES,\
            f'select a loss type from {LOSS_TYPES}.'

        # Keep every hyper-parameter so that `forward` can hand all of them
        # to the functional; previously `weight_sp` and `loss_type` were
        # accepted here and then silently dropped.
        self.weight_sp = weight_sp
        self.band_width = band_width
        self.loss_type = loss_type

        if use_vgg:
            self.vgg_model = VGG19()
            self.vgg_layer = vgg_layer
            # Per-channel statistics with shape (3, 1, 1) so they broadcast
            # over (N, 3, H, W) inputs; registered as buffers so they move
            # together with the module on `.to(device)`.
            self.register_buffer(
                name='vgg_mean',
                tensor=torch.tensor(
                    [[[0.485]], [[0.456]], [[0.406]]], requires_grad=False)
            )
            self.register_buffer(
                name='vgg_std',
                tensor=torch.tensor(
                    [[[0.229]], [[0.224]], [[0.225]]], requires_grad=False)
            )

    def _vgg_features(self, image):
        """Normalize `image` and return the configured VGG feature map."""
        normalized = image.sub(self.vgg_mean.detach())
        normalized = normalized.div(self.vgg_std.detach())
        return getattr(self.vgg_model(normalized), self.vgg_layer)

    def forward(self, x, y):
        """
        Compute the contextual bilateral loss between `x` and `y`.

        When the module was built with `use_vgg=True`, both inputs must have
        3 channels; they are normalized and replaced by the `vgg_layer`
        feature maps before the loss is evaluated.
        """
        # `vgg_model` only exists when the module was built with use_vgg=True.
        if hasattr(self, 'vgg_model'):
            assert x.shape[1] == 3 and y.shape[1] == 3,\
                'VGG model takes 3 chennel images.'
            x = self._vgg_features(x)
            y = self._vgg_features(y)

        # Pass by keyword: the old positional call supplied `band_width` as
        # the third argument and never forwarded `weight_sp` / `loss_type`.
        # NOTE(review): keyword names assume the functional is declared as
        # contextual_bilateral_loss(x, y, weight_sp, band_width, loss_type)
        # -- confirm against functional.py.
        return F.contextual_bilateral_loss(
            x, y,
            weight_sp=self.weight_sp,
            band_width=self.band_width,
            loss_type=self.loss_type,
        )