File size: 4,738 Bytes
47c46ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
from argparse import ArgumentParser, Namespace
from typing import (
    List,
    Tuple,
)

import numpy as np
from PIL import Image
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.transforms import (
    Compose,
    Grayscale,
    Resize,
    ToTensor,
)

from models.encoder import Encoder
from models.encoder4editing import (
    get_latents as get_e4e_latents,
    setup_model as setup_e4e_model,
)
from utils.misc import (
    optional_string,
    iterable_to_str,
    stem,
)



class ColorEncoderArguments:
    """Argument container for the feed-forward color-encoder CLI."""

    def __init__(self):
        """Build a parser pre-populated with the encoder options."""
        self.parser = ArgumentParser("Encode an image via a feed-forward encoder")
        self.add_arguments(self.parser)

    @staticmethod
    def add_arguments(parser: ArgumentParser):
        """Register the color-encoder flags on *parser*."""
        parser.add_argument(
            "--encoder_ckpt",
            default=None,
            help="encoder checkpoint path. initialize w with encoder output if specified",
        )
        parser.add_argument(
            "--encoder_size",
            type=int,
            default=256,
            help="Resize to this size to pass as input to the encoder",
        )


class InitializerArguments:
    """CLI options controlling how the optimization latent is initialized."""

    @classmethod
    def add_arguments(cls, parser: ArgumentParser):
        """Register all initializer flags: color encoder, e4e, and mixing."""
        ColorEncoderArguments.add_arguments(parser)
        cls.add_e4e_arguments(parser)
        parser.add_argument(
            "--mix_layer_range",
            default=[10, 18],
            type=int,
            nargs=2,
            help="replace layers <start> to <end> in the e4e code by the color code",
        )
        parser.add_argument("--init_latent", default=None, help="path to init wp")

    @staticmethod
    def to_string(args: Namespace):
        """Short tag describing the chosen initialization, for output naming."""
        if args.init_latent:
            return f"init{stem(args.init_latent).lstrip('0')[:10]}"
        return f"init({iterable_to_str(args.mix_layer_range)})"

    @staticmethod
    def add_e4e_arguments(parser: ArgumentParser):
        """Register the e4e (shape encoder) flags on *parser*."""
        parser.add_argument("--e4e_ckpt", default='checkpoint/e4e_ffhq_encode.pt',
                            help="e4e checkpoint path.")
        parser.add_argument("--e4e_size", type=int, default=256,
                            help="Resize to this size to pass as input to the e4e")



def create_color_encoder(args: Namespace):
    """Build the color Encoder and load its weights from ``args.encoder_ckpt``.

    The checkpoint is deserialized onto the CPU (``map_location="cpu"``) so
    loading works on machines without a GPU and regardless of which device
    the checkpoint was saved from; callers move the module to the target
    device afterwards (see ``encode_color``).

    Args:
        args: parsed CLI namespace; uses ``encoder_ckpt`` and ``encoder_size``.

    Returns:
        The Encoder with weights from the checkpoint's ``"model"`` entry.
    """
    encoder = Encoder(1, args.encoder_size, 512)
    ckpt = torch.load(args.encoder_ckpt, map_location="cpu")
    encoder.load_state_dict(ckpt["model"])
    return encoder


def transform_input(img: Image, size: int = 256) -> torch.Tensor:
    """Convert a PIL image to a grayscale tensor resized for the encoder.

    Bug fix: the original body read ``args.encoder_size`` but ``args`` was
    neither a parameter nor a module-level name, so every call raised
    ``NameError``. The target size is now an explicit parameter whose
    default matches the parser's ``--encoder_size`` default (256).

    Args:
        img: input PIL image.
        size: edge length the image is resized to before tensor conversion.

    Returns:
        A 1-channel float tensor in [0, 1].
    """
    tsfm = Compose([
        Grayscale(),
        Resize(size),
        ToTensor(),
    ])
    return tsfm(img)


def encode_color(imgs: torch.Tensor, args: Namespace) -> torch.Tensor:
    """Run the feed-forward color encoder on *imgs* and return the detached code.

    Args:
        imgs: batch of images (NCHW); resized to ``args.encoder_size`` first.
        args: parsed CLI namespace; ``encoder_size`` must be set.

    Returns:
        The encoder's latent output, detached from the graph.
    """
    assert args.encoder_size is not None

    resized = Resize(args.encoder_size)(imgs)

    model = create_color_encoder(args).to(resized.device)
    model.eval()
    with torch.no_grad():
        code = model(resized)
    return code.detach()


def resize(imgs: torch.Tensor, size: int) -> torch.Tensor:
    """Bilinearly resample a batch of NCHW images to spatial extent *size*."""
    return F.interpolate(imgs, mode='bilinear', size=size)


class Initializer(nn.Module):
    """Produces the initial W+ latent for optimization.

    Either loads a precomputed latent from disk (``--init_latent``), or
    builds one by style-mixing an e4e shape code with a feed-forward
    color code over ``mix_layer_range``.
    """

    def __init__(self, args: Namespace):
        super().__init__()

        # With a precomputed latent there is nothing to encode: remember
        # the path and skip constructing any model.
        self.path = args.init_latent
        if self.path is not None:
            return

        assert args.encoder_size is not None
        self.color_encoder = create_color_encoder(args)
        self.color_encoder.eval()
        self.color_encoder_size = args.encoder_size

        self.e4e, e4e_opts = setup_e4e_model(args.e4e_ckpt)
        # The cars e4e variant is not supported here.
        assert 'cars_' not in e4e_opts.dataset_type
        self.e4e.eval()
        self.e4e.decoder.eval()
        self.e4e_size = args.e4e_size

        self.mix_layer_range = args.mix_layer_range

    def encode_color(self, imgs: torch.Tensor) -> torch.Tensor:
        """Color W code from the feed-forward encoder."""
        scaled = resize(imgs, self.color_encoder_size)
        return self.color_encoder(scaled)

    def encode_shape(self, imgs: torch.Tensor) -> torch.Tensor:
        """e4e shape code; grayscale inputs are tiled to 3 channels."""
        batch = resize(imgs, self.e4e_size)
        batch = (batch - 0.5) / 0.5  # [0, 1] -> [-1, 1], e4e's input range
        if batch.shape[1] == 1:  # single channel
            batch = batch.repeat(1, 3, 1, 1)
        return get_e4e_latents(self.e4e, batch)

    def load(self, device: torch.device):
        """Read the precomputed latent from ``self.path``; adds a batch dim."""
        latent = torch.tensor(np.load(self.path), device=device)
        return latent.unsqueeze(0)

    def forward(self, imgs: torch.Tensor) -> torch.Tensor:
        if self.path is not None:
            return self.load(imgs.device)

        w_shape = self.encode_shape(imgs)
        w_color = self.encode_color(imgs)

        # Style-mix: overwrite the chosen layer span of the shape code
        # with the color code.
        start, end = self.mix_layer_range
        w_shape[:, start:end] = w_color
        return w_shape