# Hugging Face Space: Time-TravelRephotography (Space status: Runtime error)
# File: Time_TravelRephotography/losses/contextual_loss/modules/contextual.py
import random | |
from typing import ( | |
Iterable, | |
List, | |
Optional, | |
) | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
from .vgg import VGG19 | |
from .. import functional as F | |
from ..config import LOSS_TYPES | |
class ContextualLoss(nn.Module):
    """
    Creates a criterion that measures the contextual loss.

    Parameters
    ---
    band_width : float, optional
        a band_width parameter described as :math:`h` in the paper.
    loss_type : str, optional
        distance type used for the affinity computation; must be one of
        ``LOSS_TYPES``.
    use_vgg : bool, optional
        if you want to use VGG feature, set this `True`.
    vgg_model : nn.Module, optional
        a pre-built VGG network; when ``None`` and ``use_vgg`` is set,
        a fresh ``VGG19`` is instantiated.
    vgg_layers : list of str, optional
        intermidiate layer names for VGG feature (default: ``['relu3_4']``).
        Now we support layer names:
        `['relu1_2', 'relu2_2', 'relu3_4', 'relu4_4', 'relu5_4']`
    feature_1d_size : int, optional
        feature maps with more than ``feature_1d_size ** 2`` spatial
        positions are randomly subsampled before computing the loss, to
        bound the size of the pairwise-distance matrix.
    """

    def __init__(
        self,
        band_width: float = 0.5,
        loss_type: str = 'cosine',
        use_vgg: bool = False,
        vgg_model: nn.Module = None,
        vgg_layers: Optional[List[str]] = None,  # None avoids a mutable default
        feature_1d_size: int = 64,
    ):
        super().__init__()
        assert band_width > 0, 'band_width parameter must be positive.'
        assert loss_type in LOSS_TYPES,\
            f'select a loss type from {LOSS_TYPES}.'
        self.loss_type = loss_type
        self.band_width = band_width
        self.feature_1d_size = feature_1d_size

        if use_vgg:
            self.vgg_model = VGG19() if vgg_model is None else vgg_model
            self.vgg_layers = ['relu3_4'] if vgg_layers is None else vgg_layers
            # ImageNet mean/std, shaped (3, 1, 1) to broadcast over NCHW input.
            self.register_buffer(
                name='vgg_mean',
                tensor=torch.tensor(
                    [[[0.485]], [[0.456]], [[0.406]]], requires_grad=False)
            )
            self.register_buffer(
                name='vgg_std',
                tensor=torch.tensor(
                    [[[0.229]], [[0.224]], [[0.225]]], requires_grad=False)
            )

    def forward(self, x: torch.Tensor, y: torch.Tensor, all_dist: bool = False):
        """Compute the contextual loss between ``x`` and ``y``.

        When the module was built with ``use_vgg=True``, both inputs are
        passed through VGG and the loss is summed over ``self.vgg_layers``;
        otherwise the raw tensors are compared directly.
        """
        if not hasattr(self, 'vgg_model'):
            # Raw-feature mode. Pass the configured loss_type explicitly —
            # previously this branch silently fell back to 'cosine'.
            return self.contextual_loss(
                x, y, self.feature_1d_size, self.band_width,
                all_dist=all_dist, loss_type=self.loss_type)

        x = self.forward_vgg(x)
        y = self.forward_vgg(y)

        loss = 0
        for layer in self.vgg_layers:
            # picking up vgg feature maps
            fx = getattr(x, layer)
            fy = getattr(y, layer)
            loss = loss + self.contextual_loss(
                fx, fy, self.feature_1d_size, self.band_width,
                all_dist=all_dist, loss_type=self.loss_type
            )
        return loss

    def forward_vgg(self, x: torch.Tensor):
        """Map ``x`` from [-1, 1] into VGG's expected normalized space."""
        assert x.shape[1] == 3, 'VGG model takes 3 channel images.'
        # [-1, 1] -> [0, 1]
        x = (x + 1) * 0.5

        # ImageNet normalization (buffers are constants; detach is a no-op safeguard)
        x = x.sub(self.vgg_mean.detach()).div(self.vgg_std)
        return self.vgg_model(x)

    @classmethod
    def contextual_loss(
        cls,
        x: torch.Tensor, y: torch.Tensor,
        feature_1d_size: int,
        band_width: float,
        all_dist: bool = False,
        loss_type: str = 'cosine',
    ) -> torch.Tensor:
        """Subsample oversized feature maps, then delegate to ``F.contextual_loss``.

        NOTE: the ``@classmethod`` decorator is required — without it,
        ``cls.random_sampling(x, ...)`` binds the *instance* to the tensor
        parameter and raises a TypeError at runtime.
        """
        feature_size = feature_1d_size ** 2
        if np.prod(x.shape[2:]) > feature_size or np.prod(y.shape[2:]) > feature_size:
            # Use the same random indices for x and y so sampled positions
            # stay spatially aligned between the two feature maps.
            x, indices = cls.random_sampling(x, feature_1d_size=feature_1d_size)
            y, _ = cls.random_sampling(y, feature_1d_size=feature_1d_size, indices=indices)

        return F.contextual_loss(x, y, band_width, all_dist=all_dist, loss_type=loss_type)

    @staticmethod
    def random_sampling(
        tensor_NCHW: torch.Tensor, feature_1d_size: int, indices: Optional[List] = None
    ):
        """Randomly pick ``feature_1d_size**2`` spatial positions from an NCHW tensor.

        Returns the sampled tensor reshaped to
        ``(N, C, feature_1d_size, feature_1d_size)`` together with the index
        list used, so a second tensor can be sampled at the same positions.
        """
        N, C, H, W = tensor_NCHW.shape
        S = H * W
        tensor_NCS = tensor_NCHW.reshape([N, C, S])
        if indices is None:
            all_indices = list(range(S))
            random.shuffle(all_indices)
            indices = all_indices[:feature_1d_size ** 2]
        res = tensor_NCS[:, :, indices].reshape(N, -1, feature_1d_size, feature_1d_size)
        return res, indices