Spaces:
Runtime error
Runtime error
File size: 2,302 Bytes
cc80adf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
import torch
import torch.nn as nn
from .vgg import VGG19
from .. import functional as F
from ..config import LOSS_TYPES
class ContextualBilateralLoss(nn.Module):
"""
Creates a criterion that measures the contextual bilateral loss.
Parameters
---
weight_sp : float, optional
a balancing weight between spatial and feature loss.
band_width : int, optional
a band_width parameter described as :math:`h` in the paper.
use_vgg : bool, optional
if you want to use VGG feature, set this `True`.
vgg_layer : str, optional
intermidiate layer name for VGG feature.
Now we support layer names:
`['relu1_2', 'relu2_2', 'relu3_4', 'relu4_4', 'relu5_4']`
"""
def __init__(self,
weight_sp: float = 0.1,
band_width: float = 0.5,
loss_type: str = 'cosine',
use_vgg: bool = False,
vgg_layer: str = 'relu3_4'):
super(ContextualBilateralLoss, self).__init__()
assert band_width > 0, 'band_width parameter must be positive.'
assert loss_type in LOSS_TYPES,\
f'select a loss type from {LOSS_TYPES}.'
self.band_width = band_width
if use_vgg:
self.vgg_model = VGG19()
self.vgg_layer = vgg_layer
self.register_buffer(
name='vgg_mean',
tensor=torch.tensor(
[[[0.485]], [[0.456]], [[0.406]]], requires_grad=False)
)
self.register_buffer(
name='vgg_std',
tensor=torch.tensor(
[[[0.229]], [[0.224]], [[0.225]]], requires_grad=False)
)
def forward(self, x, y):
if hasattr(self, 'vgg_model'):
assert x.shape[1] == 3 and y.shape[1] == 3,\
'VGG model takes 3 chennel images.'
# normalization
x = x.sub(self.vgg_mean.detach()).div(self.vgg_std.detach())
y = y.sub(self.vgg_mean.detach()).div(self.vgg_std.detach())
# picking up vgg feature maps
x = getattr(self.vgg_model(x), self.vgg_layer)
y = getattr(self.vgg_model(y), self.vgg_layer)
return F.contextual_bilateral_loss(x, y, self.band_width)
|