# BlendGAN/generate_image_pairs.py
import argparse
import os
import random

import cv2
import numpy as np
import torch
from tqdm import tqdm

from model import Generator
from utils import ten2cv, cv2ten

# Fix all RNG seeds so repeated runs sample the same latent codes.
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
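# (Note: seeding alone does not guarantee bit-exact GPU outputs; cuDNN
# autotuning can still vary unless torch.backends.cudnn is pinned as well.)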
def generate(args, g_ema, device, mean_latent, sample_style, add_weight_index):
    # Optionally reuse a saved batch of latent codes instead of sampling fresh ones.
    if args.sample_zs is not None:
        sample_zs = torch.load(args.sample_zs)
    else:
        sample_zs = None

    with torch.no_grad():
        g_ema.eval()
        for i in tqdm(range(args.pics)):
            if sample_zs is not None:
                sample_z = sample_zs[i]
            else:
                sample_z = torch.randn(1, args.latent, device=device)

            # sample1: plain generation from the latent code alone.
            sample1, _ = g_ema([sample_z], truncation=args.truncation, truncation_latent=mean_latent,
                               return_latents=False, randomize_noise=False)
            # sample2: the same latent code, blended with the style embedding.
            sample2, _ = g_ema([sample_z], z_embed=sample_style, add_weight_index=add_weight_index,
                               truncation=args.truncation, truncation_latent=mean_latent,
                               return_latents=False, randomize_noise=False)

            sample1 = ten2cv(sample1)
            sample2 = ten2cv(sample2)
            # Save the pair side by side: original on the left, stylized on the right.
            out = np.concatenate([sample1, sample2], axis=1)
            cv2.imwrite(f'{args.outdir}/{str(i).zfill(6)}.jpg', out)
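# Note on ten2cv / cv2ten: from how they are called above, they appear to map
# between the generator's [-1, 1] float tensors and OpenCV's uint8 BGR arrays
# (and back). See the repo's utils.py for the authoritative implementations;
# the behavior described here is inferred from the call sites, not copied.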
if __name__ == '__main__':
    device = 'cuda'

    parser = argparse.ArgumentParser()
    parser.add_argument('--size', type=int, default=1024)
    parser.add_argument('--pics', type=int, default=20, help='number of image pairs to generate')
    parser.add_argument('--truncation', type=float, default=0.75)
    parser.add_argument('--truncation_mean', type=int, default=4096)
    parser.add_argument('--ckpt', type=str, default='', help='path to BlendGAN checkpoint')
    parser.add_argument('--style_img', type=str, default=None, help='path to style image')
    parser.add_argument('--sample_zs', type=str, default=None, help='path to saved latent codes')
    parser.add_argument('--add_weight_index', type=int, default=6)
    parser.add_argument('--channel_multiplier', type=int, default=2)
    parser.add_argument('--outdir', type=str, default='')
    args = parser.parse_args()

    os.makedirs(args.outdir, exist_ok=True)

    args.latent = 512
    args.n_mlp = 8
    checkpoint = torch.load(args.ckpt)
    model_dict = checkpoint['g_ema']

    # Reuse the average latent stored in the checkpoint, if present.
    latent_avg = checkpoint.get('latent_avg', None)
    # A truncation value stored in the checkpoint overrides the CLI flag.
    if 'truncation' in checkpoint:
        args.truncation = checkpoint['truncation']

    print('ckpt: ', args.ckpt)
    print('truncation: ', args.truncation)
    g_ema = Generator(
        args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
    ).to(device)
    g_ema.load_state_dict(model_dict)

    if args.truncation < 1:
        if latent_avg is not None:
            mean_latent = latent_avg
            print('### use mean_latent in ckpt["latent_avg"]')
        else:
            with torch.no_grad():
                mean_latent = g_ema.mean_latent(args.truncation_mean)
            print("### generate mean_latent with 'g_ema.mean_latent'")
    else:
        mean_latent = None
        print('### args.truncation = 1, mean_latent is None')
    # Style code: embed the reference image if one is given, otherwise draw a random style.
    if args.style_img is not None:
        img = cv2.imread(args.style_img, 1)
        img = cv2ten(img, device)
        sample_style = g_ema.get_z_embed(img)
    else:
        sample_style = torch.randn(1, args.latent, device=device)

    generate(args, g_ema, device, mean_latent, sample_style, args.add_weight_index)

    print('Done!')
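# Usage sketch (paths are illustrative, not shipped defaults):
#   python generate_image_pairs.py --size 1024 --pics 20 \
#       --ckpt pretrained_models/blendgan.pt \
#       --style_img test_imgs/style_imgs/example.jpg \
#       --outdir results/image_pairs
# Each results/image_pairs/000000.jpg, 000001.jpg, ... then holds one
# side-by-side pair: the plain face and its stylized counterpart.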