from argparse import (
    ArgumentParser,
    Namespace,
)
from typing import Optional

import numpy as np
import torch
from torch import nn

from losses.perceptual_loss import PerceptualLoss
from models.degrade import Downsample
from utils.misc import optional_string


class ReconstructionArguments:
    @staticmethod
    def add_arguments(parser: ArgumentParser):
        parser.add_argument("--vggface", type=float, default=0.3, help="vggface")
        parser.add_argument("--vgg", type=float, default=1, help="vgg")
        parser.add_argument('--recon_size', type=int, default=256, help="size for face reconstruction loss")

    @staticmethod
    def to_string(args: Namespace) -> str:
        return (
            f"s{args.recon_size}"
            + optional_string(args.vgg > 0, f"-vgg{args.vgg}")
            + optional_string(args.vggface > 0, f"-vggface{args.vggface}")
        )
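
# A minimal usage sketch (illustrative; assumes utils.misc.optional_string
# returns its string only when the condition holds):
#
#     parser = ArgumentParser()
#     ReconstructionArguments.add_arguments(parser)
#     args = parser.parse_args([])             # all defaults
#     ReconstructionArguments.to_string(args)  # -> "s256-vgg1-vggface0.3"
#
# Passing --vgg 0 or --vggface 0 drops the corresponding segment from the tag.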


def create_perceptual_loss(args: Namespace) -> PerceptualLoss:
    return PerceptualLoss(lambda_vgg=args.vgg, lambda_vggface=args.vggface, cos_dist=False)


class EyeLoss(nn.Module):
    def __init__(
            self,
            target: torch.Tensor,
            input_size: int = 1024,
            input_channels: int = 3,
            percept: Optional[nn.Module] = None,
            args: Optional[Namespace] = None
    ):
        """
        target: target image
        """
        assert not (percept is None and args is None)

        super().__init__()

        self.target = target

        target_size = target.shape[-1]
        # Bring inputs down to the target resolution (no-op when they match).
        self.downsample = Downsample(input_size, target_size, input_channels) \
                if target_size != input_size else (lambda x: x)

        self.percept = percept if percept is not None else create_perceptual_loss(args)

        # Eye boxes are defined on a 1024 x 1024 reference layout and rescaled
        # to the target resolution. sgn = -1 mirrors the box: it yields
        # negative x coordinates, which Python slicing resolves from the right
        # edge of the image.
        eye_size = np.array((224, 224))
        btlrs = []
        for sgn in [1, -1]:
            center = np.array((480, 384 * sgn))  # (y, x) on the 1024 layout
            b, t = center[0] - eye_size[0] // 2, center[0] + eye_size[0] // 2
            l, r = center[1] - eye_size[1] // 2, center[1] + eye_size[1] // 2
            btlrs.append((np.array((b, t, l, r)) / 1024 * target_size).astype(int))
        self.btlrs = np.stack(btlrs, axis=0)

    def forward(self, img: torch.Tensor, degrade: Optional[nn.Module] = None):
        """
        img: generated image; if `degrade` is None it must already be degraded
            to match the target, otherwise `degrade` is applied here
        """
        if degrade is not None:
            img = degrade(img, downsample=self.downsample)

        # Accumulate the perceptual loss over both eye crops, comparing VGG
        # features only (VGGFace disabled, capped at max_vgg_layer=4).
        loss = 0
        for (b, t, l, r) in self.btlrs:
            loss = loss + self.percept(
                img[:, :, b:t, l:r], self.target[:, :, b:t, l:r],
                use_vggface=False, max_vgg_layer=4,
                # use_vgg=False,
            )
        return loss
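
# Worked example of the crop geometry above: for a 256 x 256 target, the two
# boxes come out as (b, t, l, r) = (92, 148, 68, 124) and (92, 148, -124, -68),
# i.e. two mirrored 56 x 56 eye patches.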


class FaceLoss(nn.Module):
    def __init__(
            self,
            target: torch.Tensor,
            input_size: int = 1024,
            input_channels: int = 3,
            size: int = 256,
            percept: Optional[nn.Module] = None,
            args: Optional[Namespace] = None
    ):
        """
        target: target image
        """
        assert not (percept is None and args is None)

        super().__init__()

        # Bring the target to the working resolution once, at construction time.
        target_size = target.shape[-1]
        self.target = target if target_size == size \
                else Downsample(target_size, size, target.shape[1]).to(target.device)(target)

        # Inputs are brought to the same resolution in forward(), via degrade.
        self.downsample = Downsample(input_size, size, input_channels) \
                if size != input_size else (lambda x: x)

        self.percept = percept if percept is not None else create_perceptual_loss(args)

    def forward(self, img: torch.Tensor, degrade: Optional[nn.Module] = None):
        """
        img: generated image; if `degrade` is None it must already be degraded
            to match the target, otherwise `degrade` is applied here
        """
        if degrade is not None:
            img = degrade(img, downsample=self.downsample)
        loss = self.percept(img, self.target)
        return loss
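
# A minimal usage sketch (hypothetical setup: `target` is an aligned
# 1 x 3 x 1024 x 1024 tensor, `args` comes from a parser populated by
# ReconstructionArguments.add_arguments, and `generated`/`degrade` stand in
# for the generator output and a degradation module):
#
#     percept = create_perceptual_loss(args)
#     face_loss = FaceLoss(target, input_size=1024, size=args.recon_size, percept=percept)
#     eye_loss = EyeLoss(target, input_size=1024, percept=percept)
#     loss = face_loss(generated, degrade=degrade) + eye_loss(generated, degrade=degrade)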