|
import os |
|
import argparse |
|
|
|
import torch |
|
from torchvision import utils |
|
|
|
from model.sg2_model import Generator |
|
from tqdm import tqdm |
|
from pathlib import Path |
|
|
|
import numpy as np |
|
|
|
import subprocess |
|
import shutil |
|
import copy |
|
|
|
# Names of the supported latent edits. Each presumably has a matching
# ``<name>.pt`` file under ``editing/interfacegan_boundaries`` -- verify.
VALID_EDITS = [
    "pose",
    "age",
    "smile",
    "gender",
    "hair_length",
    "beard",
]

# Default signed step taken along each edit's boundary. project_code_by_edit_name
# multiplies the entry by its ``strength`` argument; a negative entry moves the
# latent code against the boundary direction rather than along it.
SUGGESTED_DISTANCES = dict(
    pose=3.0,
    smile=2.0,
    age=4.0,
    gender=3.0,
    hair_length=-4.0,
    beard=2.0,
)
|
|
|
def project_code(latent_code, boundary, distance=3.0):
    """Shift ``latent_code`` by ``distance`` along ``boundary``.

    Computes ``latent_code + distance * boundary`` with array broadcasting.

    Args:
        latent_code: Array or tensor of latent codes. Presumably shaped
            ``(batch, D)`` or ``(batch, n_layers, D)`` -- TODO confirm
            against callers.
        boundary: Array or tensor holding the edit direction. A single row of
            shape ``(1, D)`` is reshaped so it broadcasts across every leading
            axis of ``latent_code``; any other shape is used as-is.
        distance: Signed step size along ``boundary``.

    Returns:
        ``latent_code + distance * boundary``.
    """
    # Test the rank rather than ``len(boundary)``: len() is the size of the
    # first axis, so a len-based check never fires for a (1, D) row and would
    # flatten a boundary that merely has two rows into one long vector.
    if boundary.ndim == 2 and boundary.shape[0] == 1:
        # Keep at least two axes so a 1-D latent still yields a (1, D) result,
        # exactly as plain broadcasting against a (1, D) row would.
        n_lead = max(np.ndim(latent_code) - 1, 1)
        boundary = boundary.reshape((1,) * n_lead + (-1,))
    return latent_code + distance * boundary
|
|
|
def project_code_by_edit_name(latent_code, name, strength):
    """Apply the named edit to ``latent_code``.

    Loads ``editing/interfacegan_boundaries/<name>.pt`` (relative to this
    file) with ``torch.load`` onto the CPU, converts it to a NumPy array, and
    moves ``latent_code`` along it by ``SUGGESTED_DISTANCES[name] * strength``
    via ``project_code``.

    Args:
        latent_code: Latent code(s) to edit; passed unchanged to project_code.
        name: Edit name; must be a key of ``SUGGESTED_DISTANCES``.
        strength: Multiplier on the suggested distance; a negative value
            reverses the direction of the edit.

    Returns:
        The edited latent code produced by project_code.

    Raises:
        KeyError: If ``name`` is not a known edit.
    """
    if name not in SUGGESTED_DISTANCES:
        # Fail before touching the filesystem (``name`` becomes part of a
        # path) and list the valid choices instead of a bare KeyError.
        raise KeyError(
            f"unknown edit {name!r}; expected one of: "
            f"{', '.join(SUGGESTED_DISTANCES)}"
        )
    distance = SUGGESTED_DISTANCES[name] * strength

    # abspath (not resolve) so symlinked checkouts behave as before.
    boundary_dir = (
        Path(os.path.abspath(__file__)).parent / "editing" / "interfacegan_boundaries"
    )
    boundary = torch.load(boundary_dir / f"{name}.pt", map_location="cpu").numpy()

    return project_code(latent_code, boundary, distance)