# Snapshot of a Hugging Face Spaces file (commit 47c46ea, 4,738 bytes);
# page-chrome lines from the scrape were removed to keep the module importable.
from argparse import ArgumentParser, Namespace
from typing import (
List,
Tuple,
)
import numpy as np
from PIL import Image
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.transforms import (
Compose,
Grayscale,
Resize,
ToTensor,
)
from models.encoder import Encoder
from models.encoder4editing import (
get_latents as get_e4e_latents,
setup_model as setup_e4e_model,
)
from utils.misc import (
optional_string,
iterable_to_str,
stem,
)
class ColorEncoderArguments:
    """Builds the CLI for the feed-forward color encoder.

    Construction creates an ``ArgumentParser`` pre-populated with the encoder
    flags; ``add_arguments`` is also usable standalone to extend an existing
    parser (as ``InitializerArguments`` does).
    """

    def __init__(self):
        # Bug fix: the description string was previously passed positionally,
        # which argparse interprets as ``prog`` (the program name shown in
        # usage lines), not as the parser description.
        parser = ArgumentParser(description="Encode an image via a feed-forward encoder")
        self.add_arguments(parser)
        self.parser = parser

    @staticmethod
    def add_arguments(parser: ArgumentParser):
        """Register the color-encoder flags on *parser*."""
        parser.add_argument("--encoder_ckpt", default=None,
                            help="encoder checkpoint path. initialize w with encoder output if specified")
        parser.add_argument("--encoder_size", type=int, default=256,
                            help="Resize to this size to pass as input to the encoder")
class InitializerArguments:
    """CLI flags controlling how the initial latent code is produced."""

    @classmethod
    def add_arguments(cls, parser: ArgumentParser):
        """Register all initializer flags: color encoder, e4e, and mixing."""
        ColorEncoderArguments.add_arguments(parser)
        cls.add_e4e_arguments(parser)

        parser.add_argument("--mix_layer_range", default=[10, 18], type=int, nargs=2,
                            help="replace layers <start> to <end> in the e4e code by the color code")
        parser.add_argument("--init_latent", default=None, help="path to init wp")

    @staticmethod
    def to_string(args: Namespace):
        """Short tag summarizing the chosen initialization (for run naming)."""
        if args.init_latent:
            return f"init{stem(args.init_latent).lstrip('0')[:10]}"
        return f"init({iterable_to_str(args.mix_layer_range)})"

    @staticmethod
    def add_e4e_arguments(parser: ArgumentParser):
        """Register the e4e (encoder4editing) flags on *parser*."""
        parser.add_argument("--e4e_ckpt", default='checkpoint/e4e_ffhq_encode.pt',
                            help="e4e checkpoint path.")
        parser.add_argument("--e4e_size", type=int, default=256,
                            help="Resize to this size to pass as input to the e4e")
def create_color_encoder(args: Namespace):
    """Instantiate the color encoder and restore its weights.

    Builds an ``Encoder`` with 1 input channel, ``args.encoder_size`` input
    resolution, and a 512-dim output, then loads the ``"model"`` state dict
    from ``args.encoder_ckpt``.
    """
    model = Encoder(1, args.encoder_size, 512)
    checkpoint = torch.load(args.encoder_ckpt)
    model.load_state_dict(checkpoint["model"])
    return model
def transform_input(img: Image, size: int = 256):
    """Convert a PIL image to a grayscale tensor resized for the encoder.

    Bug fix: the original body read a free variable ``args`` that is never
    defined at module scope, so any call raised ``NameError``. The target
    size is now an explicit parameter whose default (256) matches the
    ``--encoder_size`` default, keeping the ``transform_input(img)`` call
    shape working.

    Args:
        img: input PIL image.
        size: edge length passed to ``Resize`` (default 256).

    Returns:
        A 1-channel float tensor in [0, 1].
    """
    tsfm = Compose([
        Grayscale(),
        Resize(size),
        ToTensor(),
    ])
    return tsfm(img)
def encode_color(imgs: torch.Tensor, args: Namespace) -> torch.Tensor:
    """Run the feed-forward color encoder on a batch of images.

    Resizes *imgs* to ``args.encoder_size``, loads the encoder onto the
    batch's device, and returns its output detached from the graph.
    """
    assert args.encoder_size is not None
    batch = Resize(args.encoder_size)(imgs)

    encoder = create_color_encoder(args).to(batch.device)
    encoder.eval()
    with torch.no_grad():
        code = encoder(batch)
    return code.detach()
def resize(imgs: torch.Tensor, size: int) -> torch.Tensor:
    """Bilinearly resample a batch of NCHW images to ``size`` x ``size``."""
    resampled = F.interpolate(imgs, size=size, mode='bilinear')
    return resampled
class Initializer(nn.Module):
    """Produces an initial W+ latent for an input image batch.

    Two modes:
      * ``--init_latent`` given: load a precomputed latent from a ``.npy``
        file and skip network setup entirely.
      * otherwise: run e4e to get a shape code and the color encoder to get a
        color code, then splice the color code into the configured layer
        range of the e4e code.
    """

    def __init__(self, args: Namespace):
        super().__init__()

        # Precomputed-latent shortcut: no networks are loaded in this mode.
        self.path = args.init_latent
        if self.path is not None:
            return

        assert args.encoder_size is not None
        self.color_encoder = create_color_encoder(args)
        self.color_encoder.eval()
        self.color_encoder_size = args.encoder_size

        self.e4e, opts = setup_e4e_model(args.e4e_ckpt)
        assert 'cars_' not in opts.dataset_type
        self.e4e.decoder.eval()
        self.e4e.eval()
        self.e4e_size = args.e4e_size

        self.mix_layer_range = args.mix_layer_range

    def encode_color(self, imgs: torch.Tensor) -> torch.Tensor:
        """Return the color W code for *imgs*."""
        batch = resize(imgs, self.color_encoder_size)
        return self.color_encoder(batch)

    def encode_shape(self, imgs: torch.Tensor) -> torch.Tensor:
        """Return the e4e (shape) W+ code for *imgs*."""
        batch = resize(imgs, self.e4e_size)
        batch = (batch - 0.5) / 0.5  # map [0, 1] -> [-1, 1] as e4e expects
        if batch.shape[1] == 1:
            # e4e takes RGB input; tile a grayscale batch across channels.
            batch = batch.repeat(1, 3, 1, 1)
        return get_e4e_latents(self.e4e, batch)

    def load(self, device: torch.device):
        """Load the precomputed latent from ``self.path`` with a batch dim."""
        return torch.tensor(np.load(self.path), device=device).unsqueeze(0)

    def forward(self, imgs: torch.Tensor) -> torch.Tensor:
        if self.path is not None:
            return self.load(imgs.device)

        latent = self.encode_shape(imgs)
        # Style mix: overwrite the chosen layer span with the color code.
        start, end = self.mix_layer_range
        latent[:, start:end] = self.encode_color(imgs)
        return latent