# interfacegan_pp / app.py
# Author: ybelkada — commit 4d6b877 ("commit files"), 2.75 kB
import os
import sys
import torch
import cv2
import PIL.Image
import numpy as np
import gradio as gr
from yarg import get
from models.stylegan_generator import StyleGANGenerator
from models.stylegan2_generator import StyleGAN2Generator
# Attributes that can be edited; each one needs a boundary file at
# ``boundaries/<model_name>/boundary_<attribute>.npy``.
VALID_CHOICES = ["Bald", "Young", "Mustache", "Eyeglasses", "Hat", "Smiling"]

# Generator backbones selectable in the UI.
MODEL_NAMES = ['stylegan_ffhq', 'stylegan2_ffhq']

# When True, the generator and latent codes are moved to CUDA via ``.cuda()``.
ENABLE_GPU = False

# How many images are generated (and then edited) per request.
NB_IMG = 4
# NB_IMG slots for the generated images, then NB_IMG slots for their edited
# versions; this order mirrors the list returned by ``inference``.
OUTPUT_LIST = [
    gr.outputs.Image(type="pil", label=caption)
    for caption in ("Generated Image", "Modified Image")
    for _ in range(NB_IMG)
]
def tensor_to_pil(input_object):
    """Convert a batch of image arrays into a list of PIL images.

    Args:
        input_object: Either an iterable of image arrays, or a dict whose
            ``'image'`` entry holds that iterable. Each element is passed to
            ``PIL.Image.fromarray`` (presumably uint8 arrays of shape
            (H, W, C) — TODO confirm against the generator's output).

    Returns:
        A list with one ``PIL.Image.Image`` per input image, in input order.
    """
    if isinstance(input_object, dict):
        images = input_object['image']
    else:
        images = input_object
    # The result list used to be created only on the dict branch, so a plain
    # iterable raised UnboundLocalError; the comprehension covers both paths.
    return [PIL.Image.fromarray(image) for image in images]
def get_generator(model_name):
    """Instantiate the generator registered under ``model_name``.

    Args:
        model_name: ``'stylegan_ffhq'`` or ``'stylegan2_ffhq'``; the name is
            also forwarded to the generator's constructor.

    Returns:
        The generator, moved with ``.cuda()`` when ``ENABLE_GPU`` is true.

    Raises:
        ValueError: If ``model_name`` matches neither supported name.
    """
    # Builders are deferred (lambdas) so a generator class is only touched
    # once its name has matched.
    builders = (
        ('stylegan_ffhq', lambda: StyleGANGenerator(model_name)),
        ('stylegan2_ffhq', lambda: StyleGAN2Generator(model_name)),
    )
    for known_name, build in builders:
        if model_name == known_name:
            generator = build()
            break
    else:
        raise ValueError('Model name not recognized')
    if not ENABLE_GPU:
        return generator
    return generator.cuda()
def inference(seed, choice, model_name, coef, nb_images=NB_IMG):
    """Generate random faces and edit them along an attribute boundary.

    Args:
        seed: Seed for NumPy's global RNG, so a given seed reproduces the same
            latent codes (assuming ``easy_sample`` draws from that RNG — TODO
            confirm). Cast to ``int`` because ``np.random.seed`` rejects
            floats such as ``264.0``.
        choice: Attribute to edit; must be one of ``VALID_CHOICES``.
        model_name: Generator to use; must be one of ``MODEL_NAMES``.
        coef: Signed step along the boundary direction (0 leaves the latent
            codes unchanged).
        nb_images: Number of latent codes to sample.

    Returns:
        A list of PIL images: the generated images followed by their edited
        counterparts (the order ``OUTPUT_LIST`` expects).

    Raises:
        ValueError: If ``choice`` or ``model_name`` is not a supported value.
    """
    # These values come from the web request and are used to build a file
    # path, so reject anything outside the known lists up front.
    if choice not in VALID_CHOICES:
        raise ValueError('Unknown attribute: %r' % (choice,))
    if model_name not in MODEL_NAMES:
        raise ValueError('Model name not recognized')

    np.random.seed(int(seed))
    boundary_path = os.path.join(
        'boundaries', model_name, 'boundary_%s.npy' % choice)
    # Passing the path lets np.load open and close the file itself; the
    # previous bare open() leaked the handle on every request.
    boundary = np.squeeze(np.load(boundary_path))

    # get_generator already applies .cuda() when ENABLE_GPU is set.
    generator = get_generator(model_name)
    latent_codes = generator.easy_sample(nb_images)
    if ENABLE_GPU:
        # NOTE(review): the .copy() below implies latent_codes is a NumPy
        # array, which has no .cuda(); this path looks broken — confirm
        # before enabling the GPU.
        latent_codes = latent_codes.cuda()

    generated_images = tensor_to_pil(generator.easy_synthesize(latent_codes))

    # Shift every latent code along the boundary in one broadcast op. The
    # in-place += on a copy keeps the codes' dtype, as the old per-row loop did
    # (assumes one synthesized image per latent code).
    new_latent_codes = latent_codes.copy()
    new_latent_codes += boundary * coef
    modified_generated_images = tensor_to_pil(
        generator.easy_synthesize(new_latent_codes))
    return generated_images + modified_generated_images
# Web UI. The inputs map positionally onto inference(seed, choice, model_name,
# coef); the outputs are the generated images followed by the edited ones.
iface = gr.Interface(
    fn=inference,
    inputs=[
        # seed
        gr.inputs.Slider(minimum=0, maximum=1000, step=1, default=264),
        # choice: attribute whose boundary is applied
        gr.inputs.Dropdown(choices=VALID_CHOICES, type="value"),
        # model_name: generator backbone
        gr.inputs.Dropdown(choices=MODEL_NAMES, type="value"),
        # coef: signed edit strength along the boundary
        gr.inputs.Slider(minimum=-3, maximum=3, step=0.1, default=0),
    ],
    outputs=OUTPUT_LIST,
    layout="horizontal",
    theme="peach",
)
# Start the web server only when run as a script (e.g. ``python app.py``), so
# importing this module to reuse ``inference`` has no server side effect.
if __name__ == "__main__":
    iface.launch()