# interfacegan_pp / app.py
# Commit 4ae241e (ybelkada): "try with html to fix links"
# (Hugging Face file-viewer metadata: raw | history | blame | 2.84 kB)
import os
import torch
import PIL.Image
import numpy as np
import gradio as gr
from yarg import get
from models.stylegan_generator import StyleGANGenerator
from models.stylegan2_generator import StyleGAN2Generator
from utils.constants import VALID_CHOICES, ENABLE_GPU, MODEL_NAME, OUTPUT_LIST, description, title, css, article
from utils.image_manip import tensor_to_pil, concat_images
def get_generator(model_name):
    """Instantiate the generator registered under ``model_name``.

    Args:
        model_name: Either ``'stylegan_ffhq'`` or ``'stylegan2_ffhq'``.

    Returns:
        A ``StyleGANGenerator`` or ``StyleGAN2Generator`` built from
        ``model_name``, moved with ``.cuda()`` when ``ENABLE_GPU`` is set.

    Raises:
        ValueError: If ``model_name`` is not one of the two supported names.
    """
    # Reject unknown names before touching either generator class.
    if model_name not in ('stylegan_ffhq', 'stylegan2_ffhq'):
        raise ValueError('Model name not recognized')

    generator_cls = (
        StyleGANGenerator if model_name == 'stylegan_ffhq' else StyleGAN2Generator
    )
    model = generator_cls(model_name)
    return model.cuda() if ENABLE_GPU else model
# Build the generator once at import time so every request reuses it.
generator = get_generator(MODEL_NAME)

# One editing direction per attribute in VALID_CHOICES, read from
# boundaries/<MODEL_NAME>/boundary_<attribute>.npy. np.squeeze drops singleton
# axes so each direction can be added to a latent row (presumably the saved
# arrays carry a leading axis of size 1 -- TODO confirm).
# Passing the path (rather than an already-open handle) lets np.load open and
# close the file itself; the previous np.load(open(...)) leaked every handle.
boundaries = {
    boundary: np.squeeze(
        np.load(os.path.join('boundaries', MODEL_NAME, 'boundary_%s.npy' % boundary))
    )
    for boundary in VALID_CHOICES
}
@torch.no_grad()
def inference(seed, coef, nb_images, list_choices):
    """Generate faces and their attribute-edited counterparts.

    Args:
        seed: Seed for NumPy's global RNG. Cast to ``int`` because
            ``np.random.seed`` rejects floats and slider values may arrive
            as floats.
        coef: Scalar multiplier applied to every selected boundary direction.
        nb_images: Number of latent codes to sample (cast to ``int`` for the
            same reason as ``seed``).
        list_choices: Attribute names, i.e. keys of the module-level
            ``boundaries`` dict, whose directions are added to each latent.

    Returns:
        The result of ``concat_images(original_images, edited_images)``.
    """
    # NOTE(review): seeding the *global* RNG is what makes easy_sample
    # reproducible here (presumably it draws from np.random -- TODO confirm);
    # concurrent requests share this state.
    np.random.seed(int(seed))

    # latent_codes stays a NumPy array: it is copied with .copy() and shifted
    # by NumPy boundary vectors below. The generator was already moved to the
    # GPU by get_generator, so no .cuda() calls are needed (or valid) here.
    latent_codes = generator.easy_sample(int(nb_images))

    generated_images = tensor_to_pil(generator.easy_synthesize(latent_codes))

    # Sum the selected directions once and broadcast the shift over every
    # latent row. This is equivalent to adding each boundary * coef to each
    # row in turn, up to floating-point summation order. With no choices the
    # shift is 0. In-place += keeps the latent array's original dtype.
    shift = sum(boundaries[choice] for choice in list_choices) * coef
    new_latent_codes = latent_codes.copy()
    new_latent_codes += shift

    modified_generated_images = tensor_to_pil(
        generator.easy_synthesize(new_latent_codes)
    )
    return concat_images(generated_images, modified_generated_images)
def _build_inputs():
    """Return the Gradio input components, in the order ``inference`` expects.

    Order: seed, modification coefficient, number of images, attributes.
    """
    seed_slider = gr.inputs.Slider(
        minimum=0,
        maximum=1000,
        step=1,
        default=644,
        label="Random seed to use for the generation",
    )
    coef_slider = gr.inputs.Slider(
        minimum=-3,
        maximum=3,
        step=0.1,
        default=1,
        label="Modification coefficient",
    )
    count_slider = gr.inputs.Slider(
        minimum=1,
        maximum=10,
        step=1,
        default=2,
        label="Number of images to generate",
    )
    attribute_picker = gr.inputs.CheckboxGroup(
        VALID_CHOICES,
        default=[],
        type="value",
        label="Select attributes to modify",
        optional=False,
    )
    return [seed_slider, coef_slider, count_slider, attribute_picker]


# Interface layout follows
# https://huggingface.co/spaces/osanseviero/6DRepNet/blob/main/app.py
iface = gr.Interface(
    fn=inference,
    inputs=_build_inputs(),
    outputs=OUTPUT_LIST,
    layout="horizontal",
    theme="peach",
    description=description,
    title=title,
    css=css,
    article=article,
)
iface.launch()