"""
Code borrowed from https://gist.github.com/alper111/8233cdb0414b4cb5853f2f730ab95a49#file-vgg_perceptual_loss-py-L5
"""
import torch
import torchvision
from models.vggface import VGGFaceFeats


def cos_loss(fi, ft):
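    """Cosine distance: 1 minus the mean cosine similarity along dim 1 (channels)."""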
    return 1 - torch.nn.functional.cosine_similarity(fi, ft).mean()


class VGGPerceptualLoss(torch.nn.Module):
    def __init__(self, resize=False):
        super().__init__()
        # Load the pretrained VGG-16 once and split its feature extractor into
        # the four standard perceptual-loss blocks
        # (relu1_2, relu2_2, relu3_3, relu4_3).
        features = torchvision.models.vgg16(pretrained=True).features
        blocks = [
            features[:4].eval(),
            features[4:9].eval(),
            features[9:16].eval(),
            features[16:23].eval(),
        ]
        # Freeze the weights: iterate over parameters(), not the modules
        # themselves (setting requires_grad on a Module has no effect).
        for bl in blocks:
            for p in bl.parameters():
                p.requires_grad = False
        self.blocks = torch.nn.ModuleList(blocks)
        self.transform = torch.nn.functional.interpolate
        # ImageNet statistics as buffers (not Parameters), so they move with
        # .to()/.cuda() but are never trained.
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
        self.resize = resize

    def forward(self, input, target, max_layer=4, cos_dist: bool = False):
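        # Inputs are expected in [-1, 1]; remap to [0, 1] before applying the
        # ImageNet normalization below.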
        target = (target + 1) * 0.5
        input = (input + 1) * 0.5

        if input.shape[1] != 3:
            input = input.repeat(1, 3, 1, 1)
            target = target.repeat(1, 3, 1, 1)
        input = (input-self.mean) / self.std
        target = (target-self.mean) / self.std
        if self.resize:
            input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
            target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)
        x = input
        y = target
        loss = 0.0
        loss_func = cos_loss if cos_dist else torch.nn.functional.l1_loss
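        # Accumulate per-block feature distances; target features are detached
        # so gradients flow only through the input branch.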
        for bi, block in enumerate(self.blocks[:max_layer]):
            x = block(x)
            y = block(y)
            loss += loss_func(x, y.detach())
        return loss


class VGGFacePerceptualLoss(torch.nn.Module):
    def __init__(self, weight_path: str = "checkpoint/vgg_face_dag.pt", resize: bool = False):
        super().__init__()
        self.vgg = VGGFaceFeats()
        self.vgg.load_state_dict(torch.load(weight_path, map_location="cpu"))
        # Freeze the feature extractor, mirroring VGGPerceptualLoss above.
        self.vgg.eval()
        for p in self.vgg.parameters():
            p.requires_grad = False

        # VGGFace uses mean subtraction only; the stored mean is in [0, 255],
        # so rescale it to the [0, 1] range used after the forward() remapping.
        mean = torch.tensor(self.vgg.meta["mean"]).view(1, 3, 1, 1) / 255.0
        self.register_buffer("mean", mean)

        self.transform = torch.nn.functional.interpolate
        self.resize = resize

    def forward(self, input, target, max_layer: int = 4, cos_dist: bool = False):
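        # Inputs are expected in [-1, 1]; remap to [0, 1] to match the scale of
        # the mean buffer subtracted below.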
        target = (target + 1) * 0.5
        input = (input + 1) * 0.5

        # preprocessing
        if input.shape[1] != 3:
            input = input.repeat(1, 3, 1, 1)
            target = target.repeat(1, 3, 1, 1)
        input = input - self.mean
        target = target - self.mean
        if self.resize:
            input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
            target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)

        input_feats = self.vgg(input)
        target_feats = self.vgg(target)

        loss_func = cos_loss if cos_dist else torch.nn.functional.l1_loss
        # calc perceptual loss
        loss = 0.0
        for fi, ft in zip(input_feats[:max_layer], target_feats[:max_layer]):
            loss = loss + loss_func(fi, ft.detach())
        return loss


class PerceptualLoss(torch.nn.Module):
    def __init__(
            self, lambda_vggface: float = 0.025 / 0.15, lambda_vgg: float = 1.0, eps: float = 1e-8, cos_dist: bool = False
    ):
        super().__init__()
        self.register_buffer("lambda_vggface", torch.tensor(lambda_vggface))
        self.register_buffer("lambda_vgg", torch.tensor(lambda_vgg))
        self.cos_dist = cos_dist

        if lambda_vgg > eps:
            self.vgg = VGGPerceptualLoss()
        if lambda_vggface > eps:
            self.vggface = VGGFacePerceptualLoss()

    def forward(self, input, target, eps=1e-8, use_vggface: bool = True, use_vgg: bool = True, max_vgg_layer: int = 4):
        loss = 0.0
        if self.lambda_vgg > eps and use_vgg:
            loss = loss + self.lambda_vgg * self.vgg(input, target, max_layer=max_vgg_layer)
        # cos_dist is applied to the VGGFace term only; the plain VGG term
        # always uses the L1 feature distance.
        if self.lambda_vggface > eps and use_vggface:
            loss = loss + self.lambda_vggface * self.vggface(input, target, cos_dist=self.cos_dist)
        return loss
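

# Minimal usage sketch (illustrative; not part of the original module). It runs
# the combined loss on random images in [-1, 1]. The VGGFace branch is disabled
# (lambda_vggface=0) so no vgg_face_dag.pt checkpoint is needed, but the
# torchvision VGG-16 weights are downloaded on first use.
if __name__ == "__main__":
    loss_fn = PerceptualLoss(lambda_vggface=0.0, lambda_vgg=1.0)
    pred = torch.rand(2, 3, 224, 224) * 2 - 1  # fake prediction in [-1, 1]
    gt = torch.rand(2, 3, 224, 224) * 2 - 1    # fake target in [-1, 1]
    print("perceptual loss:", loss_fn(pred, gt).item())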