feng2022's picture
losses
4bf9ab0
raw
history blame
3.98 kB
import random
from typing import (
Iterable,
List,
Optional,
)
import numpy as np
import torch
import torch.nn as nn
from .vgg import VGG19
from .. import functional as F
from ..config import LOSS_TYPES
class ContextualLoss(nn.Module):
    """
    Creates a criterion that measures the contextual loss.

    Parameters
    ---
    band_width : float, optional
        a band_width parameter described as :math:`h` in the paper.
        Must be positive.
    loss_type : str, optional
        loss type forwarded to `F.contextual_loss`.
        Must be one of `LOSS_TYPES`.
    use_vgg : bool, optional
        if you want to use VGG feature, set this `True`.
    vgg_model : nn.Module, optional
        feature extractor used when `use_vgg` is `True`.
        A new `VGG19()` is created when this is `None`.
    vgg_layers : List[str], optional
        intermediate layer names for VGG feature; each name is read as an
        attribute of the VGG output. Defaults to `['relu3_4']`.
        Now we support layer names:
        `['relu1_2', 'relu2_2', 'relu3_4', 'relu4_4', 'relu5_4']`
    feature_1d_size : int, optional
        inputs with more than `feature_1d_size ** 2` spatial positions are
        randomly subsampled to a `feature_1d_size` x `feature_1d_size` map
        before the loss is computed.
    """

    def __init__(
        self,
        band_width: float = 0.5,
        loss_type: str = 'cosine',
        use_vgg: bool = False,
        vgg_model: Optional[nn.Module] = None,
        vgg_layers: Optional[List[str]] = None,
        feature_1d_size: int = 64,
    ):
        super().__init__()

        assert band_width > 0, 'band_width parameter must be positive.'
        assert loss_type in LOSS_TYPES,\
            f'select a loss type from {LOSS_TYPES}.'

        self.loss_type = loss_type
        self.band_width = band_width
        self.feature_1d_size = feature_1d_size

        # The VGG attributes only exist when `use_vgg` is set; `forward`
        # relies on the presence of `vgg_model` to pick its code path.
        if use_vgg:
            self.vgg_model = VGG19() if vgg_model is None else vgg_model
            # A `None` default (instead of a list literal in the signature)
            # keeps instances from sharing one mutable default list.
            self.vgg_layers = (
                ['relu3_4'] if vgg_layers is None else list(vgg_layers)
            )
            # Per-channel mean/std shaped (3, 1, 1) so they broadcast over
            # NCHW images (the values match the usual ImageNet statistics).
            self.register_buffer(
                name='vgg_mean',
                tensor=torch.tensor(
                    [[[0.485]], [[0.456]], [[0.406]]], requires_grad=False)
            )
            self.register_buffer(
                name='vgg_std',
                tensor=torch.tensor(
                    [[[0.229]], [[0.224]], [[0.225]]], requires_grad=False)
            )

    def forward(self, x: torch.Tensor, y: torch.Tensor, all_dist: bool = False):
        """
        Compute the contextual loss between `x` and `y`.

        Without a VGG model the loss is computed directly on the inputs.
        With one, both inputs go through `forward_vgg` and the losses of
        all layers in `self.vgg_layers` are summed.

        `all_dist` is forwarded unchanged to `F.contextual_loss`.
        """
        if not hasattr(self, 'vgg_model'):
            # Pass the configured loss type here too, so both code paths
            # honour the `loss_type` given to the constructor.
            return self.contextual_loss(
                x, y, self.feature_1d_size, self.band_width,
                all_dist=all_dist, loss_type=self.loss_type,
            )

        x = self.forward_vgg(x)
        y = self.forward_vgg(y)

        loss = 0
        for layer in self.vgg_layers:
            # picking up vgg feature maps
            fx = getattr(x, layer)
            fy = getattr(y, layer)
            loss = loss + self.contextual_loss(
                fx, fy, self.feature_1d_size, self.band_width,
                all_dist=all_dist, loss_type=self.loss_type,
            )
        return loss

    def forward_vgg(self, x: torch.Tensor):
        """
        Normalize a 3-channel image batch and run it through the VGG model.

        The input is shifted/scaled with `(x + 1) * 0.5` (i.e. it is
        expected in [-1, 1]) and then standardized with the registered
        `vgg_mean` / `vgg_std` buffers.
        """
        assert x.shape[1] == 3, 'VGG model takes 3 chennel images.'
        # [-1, 1] -> [0, 1]
        x = (x + 1) * 0.5
        # normalization
        x = x.sub(self.vgg_mean.detach()).div(self.vgg_std)
        return self.vgg_model(x)

    @classmethod
    def contextual_loss(
        cls,
        x: torch.Tensor, y: torch.Tensor,
        feature_1d_size: int,
        band_width: float,
        all_dist: bool = False,
        loss_type: str = 'cosine',
    ) -> torch.Tensor:
        """
        Compute the contextual loss, subsampling large feature maps first.

        If either input has more than `feature_1d_size ** 2` positions over
        its trailing (spatial) dimensions, both inputs are randomly
        subsampled at the same positions before calling
        `F.contextual_loss`.

        NOTE(review): the same indices are applied to `x` and `y`, so this
        presumably expects both to share one spatial size -- confirm.
        """
        feature_size = feature_1d_size ** 2
        if np.prod(x.shape[2:]) > feature_size or np.prod(y.shape[2:]) > feature_size:
            x, indices = cls.random_sampling(x, feature_1d_size=feature_1d_size)
            # Reuse the indices so x and y are sampled at matching positions.
            y, _ = cls.random_sampling(y, feature_1d_size=feature_1d_size, indices=indices)

        return F.contextual_loss(x, y, band_width, all_dist=all_dist, loss_type=loss_type)

    @staticmethod
    def random_sampling(
        tensor_NCHW: torch.Tensor, feature_1d_size: int, indices: Optional[List] = None
    ):
        """
        Pick `feature_1d_size ** 2` spatial positions from an NCHW tensor.

        Parameters
        ---
        tensor_NCHW : torch.Tensor
            4-D tensor of shape (N, C, H, W).
        feature_1d_size : int
            side length of the square output map.
        indices : list, optional
            flat positions (into H * W) to pick. When `None`, positions are
            drawn without replacement using the `random` module.

        Returns
        ---
        (torch.Tensor, list)
            the sampled tensor reshaped to
            (N, -1, feature_1d_size, feature_1d_size) and the indices used.
        """
        N, C, H, W = tensor_NCHW.shape
        S = H * W
        tensor_NCS = tensor_NCHW.reshape([N, C, S])
        if indices is None:
            all_indices = list(range(S))
            random.shuffle(all_indices)
            indices = all_indices[:feature_1d_size**2]
        # NOTE(review): with fewer than feature_1d_size ** 2 positions
        # available, the `-1` lets reshape change the channel count (or
        # fail) -- callers should only sample maps that are large enough.
        res = tensor_NCS[:, :, indices].reshape(N, -1, feature_1d_size, feature_1d_size)
        return res, indices