import os

import torch
import torch.nn.functional as F
import torchvision


class VGGPerceptualLoss(torch.nn.Module):
    def __init__(self):
        super(VGGPerceptualLoss, self).__init__()
        # Cache the downloaded VGG weights in the current working directory.
        os.environ['TORCH_HOME'] = os.path.abspath(os.getcwd())
        # Load the ImageNet-pretrained VGG16 once and slice its feature extractor
        # into the relu1_2, relu2_2, relu3_3 and relu4_3 blocks.
        features = torchvision.models.vgg16(
            weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1).features
        blocks = [features[:4].eval(),
                  features[4:9].eval(),
                  features[9:16].eval(),
                  features[16:23].eval()]
        # The loss network is fixed; freeze all of its parameters.
        for bl in blocks:
            for p in bl.parameters():
                p.requires_grad = False
        self.blocks = torch.nn.ModuleList(blocks)

        # ImageNet channel statistics used to normalize inputs for VGG.
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, x, y):
        # Map inputs from [-1, 1] to [0, 1], then apply ImageNet normalization.
        x = x * 0.5 + 0.5
        y = y * 0.5 + 0.5
        x = (x - self.mean) / self.std
        y = (y - self.mean) / self.std

        # Resize to the 224x224 resolution VGG16 was trained on.
        x = F.interpolate(x, mode='bilinear', size=(224, 224), align_corners=False)
        y = F.interpolate(y, mode='bilinear', size=(224, 224), align_corners=False)
        perceptual_loss = 0.0
        style_loss = 0.0

        for block in self.blocks:
            x = block(x)
            y = block(y)

            # L1 distance between the VGG feature maps of the two images.
            perceptual_loss += F.l1_loss(x, y)

            # Optional Gram-matrix style loss (currently disabled):
            # b, ch, h, w = x.shape
            # act_x = x.reshape(b, ch, -1)
            # act_y = y.reshape(b, ch, -1)
            # gram_x = act_x @ act_x.permute(0, 2, 1) / (ch * h * w)
            # gram_y = act_y @ act_y.permute(0, 2, 1) / (ch * h * w)
            # style_loss += F.l1_loss(gram_x, gram_y)

        return perceptual_loss  # , style_loss
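

# Minimal usage sketch (not part of the original file): the tensors below are
# hypothetical and only illustrate the expected NCHW inputs scaled to [-1, 1],
# e.g. generator outputs and their targets.
if __name__ == "__main__":
    loss_fn = VGGPerceptualLoss()
    fake = torch.rand(2, 3, 128, 128) * 2 - 1  # hypothetical generated batch
    real = torch.rand(2, 3, 128, 128) * 2 - 1  # hypothetical target batch
    loss = loss_fn(fake, real)
    print(loss.item())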