from argparse import ArgumentParser, Namespace
from typing import (
    List,
    Tuple,
)

import numpy as np
from PIL import Image
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.transforms import (
    Compose,
    Grayscale,
    Resize,
    ToTensor,
)

from models.encoder import Encoder
from models.encoder4editing import (
    get_latents as get_e4e_latents,
    setup_model as setup_e4e_model,
)
from utils.misc import (
    optional_string,
    iterable_to_str,
    stem,
)

class ColorEncoderArguments:
    def __init__(self):
        parser = ArgumentParser("Encode an image via a feed-forward encoder")
        self.add_arguments(parser)
        self.parser = parser

    @staticmethod
    def add_arguments(parser: ArgumentParser):
        parser.add_argument("--encoder_ckpt", default=None,
                            help="encoder checkpoint path; initialize w with the encoder output if specified")
        parser.add_argument("--encoder_size", type=int, default=256,
                            help="resize to this size to pass as input to the encoder")

class InitializerArguments:
    @classmethod
    def add_arguments(cls, parser: ArgumentParser):
        ColorEncoderArguments.add_arguments(parser)
        cls.add_e4e_arguments(parser)

        parser.add_argument("--mix_layer_range", default=[10, 18], type=int, nargs=2,
                            help="replace layers <start> to <end> of the e4e code with the color code")
        parser.add_argument("--init_latent", default=None, help="path to an initial W+ latent")

    @staticmethod
    def to_string(args: Namespace):
        return (f"init{stem(args.init_latent).lstrip('0')[:10]}" if args.init_latent
                else f"init({iterable_to_str(args.mix_layer_range)})")
        #+ optional_string(args.init_noise > 0, f"-initN{args.init_noise}")

    @staticmethod
    def add_e4e_arguments(parser: ArgumentParser):
        parser.add_argument("--e4e_ckpt", default='checkpoint/e4e_ffhq_encode.pt',
                            help="e4e checkpoint path")
        parser.add_argument("--e4e_size", type=int, default=256,
                            help="resize to this size to pass as input to e4e")

def create_color_encoder(args: Namespace):
    # 1-channel (grayscale) input at encoder_size; 512-dimensional latent output.
    encoder = Encoder(1, args.encoder_size, 512)
    ckpt = torch.load(args.encoder_ckpt)
    encoder.load_state_dict(ckpt["model"])
    return encoder

def transform_input(img: Image, args: Namespace):
    tsfm = Compose([
        Grayscale(),
        Resize(args.encoder_size),
        ToTensor(),
    ])
    return tsfm(img)

def encode_color(imgs: torch.Tensor, args: Namespace) -> torch.Tensor:
    """Encode images into the color W code with a freshly loaded encoder."""
    assert args.encoder_size is not None
    imgs = Resize(args.encoder_size)(imgs)

    color_encoder = create_color_encoder(args).to(imgs.device)
    color_encoder.eval()
    with torch.no_grad():
        latent = color_encoder(imgs)
    return latent.detach()

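# Standalone usage sketch (the tensor shape is an assumption; any grayscale batch
# in [0, 1] should work):
#
#     gray = torch.rand(1, 1, 256, 256)   # (N, 1, H, W)
#     w_color = encode_color(gray, args)  # color W code on the same device
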
def resize(imgs: torch.Tensor, size: int) -> torch.Tensor:
    return F.interpolate(imgs, size=size, mode='bilinear')

class Initializer(nn.Module):
    def __init__(self, args: Namespace):
        super().__init__()

        # If a precomputed latent is supplied, load it instead of building encoders.
        self.path = None
        if args.init_latent is not None:
            self.path = args.init_latent
            return

        assert args.encoder_size is not None
        self.color_encoder = create_color_encoder(args)
        self.color_encoder.eval()
        self.color_encoder_size = args.encoder_size

        self.e4e, e4e_opts = setup_e4e_model(args.e4e_ckpt)
        assert 'cars_' not in e4e_opts.dataset_type
        self.e4e.decoder.eval()
        self.e4e.eval()
        self.e4e_size = args.e4e_size

        self.mix_layer_range = args.mix_layer_range

    def encode_color(self, imgs: torch.Tensor) -> torch.Tensor:
        """Get the color W code."""
        imgs = resize(imgs, self.color_encoder_size)
        latent = self.color_encoder(imgs)
        return latent

    def encode_shape(self, imgs: torch.Tensor) -> torch.Tensor:
        imgs = resize(imgs, self.e4e_size)
        imgs = (imgs - 0.5) / 0.5  # map [0, 1] to the [-1, 1] range e4e expects
        if imgs.shape[1] == 1:  # grayscale: repeat to 3 channels
            imgs = imgs.repeat(1, 3, 1, 1)
        return get_e4e_latents(self.e4e, imgs)

    def load(self, device: torch.device):
        latent_np = np.load(self.path)
        return torch.tensor(latent_np, device=device)[None, ...]  # add batch dim

    def forward(self, imgs: torch.Tensor) -> torch.Tensor:
        if self.path is not None:
            return self.load(imgs.device)

        shape_code = self.encode_shape(imgs)
        color_code = self.encode_color(imgs)

        # Style mix: overwrite layers [start, end) of the e4e shape code with the color code.
        latent = shape_code
        start, end = self.mix_layer_range
        latent[:, start:end] = color_code
        return latent
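

# A minimal smoke test: a hedged sketch, assuming --encoder_ckpt points to a real
# checkpoint and "input.png" stands in for an actual image; the printed W+ shape
# reflects the 1024x1024 FFHQ generator.
if __name__ == "__main__":
    parser = ArgumentParser("Initialize a W+ latent for projection")
    InitializerArguments.add_arguments(parser)
    args = parser.parse_args()

    imgs = ToTensor()(Image.open("input.png"))[None]  # (1, C, H, W) in [0, 1]
    with torch.no_grad():
        latent = Initializer(args)(imgs)
    print(latent.shape)  # e.g. torch.Size([1, 18, 512])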