# Source: app.py by hr16, commit 0d8c2d3 ("Update app.py").
import warnings
import argparse
import os
from PIL import Image
import numpy as np
import torch
import stylegan2
from stylegan2 import utils
from huggingface_hub import hf_hub_download
import gradio as gr
import re
from types import SimpleNamespace
# Edited from run_generator.py to return PIL images instead of saving them to disk.
def generate_images(G, args):
    """Run generator ``G`` once per seed in ``args.seeds`` and return the images.

    Adapted from run_generator.py: images are returned instead of being
    written to disk.

    Args:
        G: stylegan2 generator; must expose ``latent_size``, ``label_size``,
            ``to``, ``set_truncation``, ``random_noise`` and ``static_noise``.
        args: Namespace with ``gpu`` (device ids; empty means CPU),
            ``truncation_psi``, ``seeds``, ``batch_size``, ``pixel_min`` and
            ``pixel_max``.

    Returns:
        List of the images produced by ``utils.tensor_to_PIL``, collected
        batch by batch in the order of ``args.seeds``.
    """
    latent_size = G.latent_size
    label_size = G.label_size

    device = torch.device(args.gpu[0]) if args.gpu else torch.device('cpu')
    if device.index is not None:
        torch.cuda.set_device(device.index)

    generator = G
    generator.to(device)
    if args.truncation_psi != 1:
        generator.set_truncation(truncation_psi=args.truncation_psi)

    use_data_parallel = len(args.gpu) > 1
    noise_reference = None
    if use_data_parallel:
        warnings.warn(
            'Noise can not be randomized based on the seed '
            'when using more than 1 GPU device. Noise will '
            'now be randomized from default random state.'
        )
        generator.random_noise()
        generator = torch.nn.DataParallel(generator, device_ids=args.gpu)
    else:
        # Current noise tensors of the generator; per-seed noise with the same
        # shape (minus the leading batch dim) is drawn from them below.
        noise_reference = generator.static_noise()

    def draw_batch(batch_seeds):
        # Returns (latents, labels or None, per-layer noise list or None).
        latent_rows = []
        label_rows = []
        noise_rows = None if use_data_parallel else [[] for _ in noise_reference]
        for seed in batch_seeds:
            rng = np.random.default_rng(seed)
            # Draw order per seed is fixed so a seed always yields the same
            # image: latent first, then every noise layer, then the label.
            latent_rows.append(torch.from_numpy(rng.standard_normal(latent_size)))
            if noise_rows is not None:
                for layer_rows, ref in zip(noise_rows, noise_reference):
                    layer_rows.append(
                        torch.from_numpy(rng.standard_normal(tuple(ref.size()[1:])))
                    )
            if label_size:
                label_rows.append(torch.tensor([rng.integers(0, label_size)]))

        latents = torch.stack(latent_rows, dim=0).to(
            device=device, dtype=torch.float32)
        labels = None
        if label_rows:
            labels = torch.cat(label_rows, dim=0).to(
                device=device, dtype=torch.int64)
        if noise_rows is not None:
            noise_rows = [
                torch.stack(layer_rows, dim=0).to(
                    device=device, dtype=torch.float32)
                for layer_rows in noise_rows
            ]
        return latents, labels, noise_rows

    images = []
    seeds = args.seeds
    step = args.batch_size
    for start in range(0, len(seeds), step):
        latents, labels, noise = draw_batch(seeds[start:start + step])
        if noise is not None:
            generator.static_noise(noise_tensors=noise)
        with torch.no_grad():
            output = generator(latents, labels=labels)
        images.extend(
            utils.tensor_to_PIL(
                output, pixel_min=args.pixel_min, pixel_max=args.pixel_max)
        )
    return images
#----------------------------------------------------------------------------
# Choices for the checkpoint selector. inference() extracts the 5-digit
# iteration number from the selected label, so every entry must keep the
# "TWDNEv3 iteration NNNNN" prefix. The first entry is the UI default.
interface_modelversion_labels = [
    "TWDNEv3 iteration 24664 (best and current version on TWDNE)",
    "TWDNEv3 iteration 18528 (the most used version on the Internet)",
    "TWDNEv3 iteration 17325",
]
def inference(seed, truncation_psi, modelversion_label):
    """Generate one TWDNEv3 image for the Gradio interface.

    Args:
        seed: Seed handed to ``np.random.default_rng`` (PCG64) inside
            ``generate_images``; numpy rejects negative seeds.
        truncation_psi: Truncation psi for the generator (a value of 1
            leaves truncation unset).
        modelversion_label: One of ``interface_modelversion_labels``; the
            5-digit iteration number it contains selects the checkpoint.

    Returns:
        The single generated PIL image.

    Raises:
        ValueError: If ``modelversion_label`` does not name a TWDNEv3
            iteration.
        KeyError: If the ``MODEL_READING_TOKEN`` environment variable is
            not set.
    """
    # Raw string: "\d" inside a plain literal is an invalid escape sequence
    # (DeprecationWarning, and a SyntaxWarning since Python 3.12).
    match = re.search(r"TWDNEv3 iteration (\d{5})", modelversion_label or "")
    if match is None:
        # Clear error instead of an AttributeError from None.group().
        raise ValueError(f"Unknown model version: {modelversion_label!r}")
    model_iteration = match.group(1)

    checkpoint_path = hf_hub_download(
        "hr16/Gwern-TWDNEv3-pytorch_ckpt",
        f"iteration-{model_iteration}/Gs.pth",
        use_auth_token=os.environ['MODEL_READING_TOKEN'],
    )
    # A fresh generator is loaded per request: generate_images mutates it
    # (device placement, truncation, static noise), so reusing one instance
    # could carry settings over from a previous request.
    G = stylegan2.models.load(checkpoint_path)
    G.eval()

    # Stand-in for the ArgumentParser namespace used by run_generator.py.
    args = SimpleNamespace(
        truncation_psi=truncation_psi,
        seeds=[seed],
        batch_size=1,
        pixel_min=-1,
        pixel_max=1,
        gpu=[],  # empty -> generate_images runs on CPU
    )
    return generate_images(G, args)[0]
title = "TWDNEv3 CPU Generator"
description = "Gradio Demo for TWDNEv3 CPU Generator (stylegan2_pytorch port)"
article = ""

# Gradio UI: seed, truncation psi and checkpoint choice in; one PIL image out.
demo = gr.Interface(
    inference,
    [
        gr.Number(
            precision=0,
            label="PCG64 PRNG Seed (any-bit-size unsigned int, note that it may differ from the original site)",
        ),
        gr.Slider(0, 2, step=0.1, value=0.7, label='Truncation psi (aka creative level, between 0 and 2)'),
        gr.Radio(
            interface_modelversion_labels,
            # Reference the first label instead of repeating its text, so the
            # default cannot drift from the list of choices.
            value=interface_modelversion_labels[0],
            type="value",
            label="Model versions",
        ),
    ],
    # gr.Image replaces the deprecated gr.outputs.Image namespace and matches
    # the new-style input components above.
    gr.Image(type="pil"),
    title=title,
    description=description,
    article=article,
    # "never" is the documented string form (the boolean False is
    # deprecated); allow_screenshot is not a supported Interface option in
    # current Gradio releases and is dropped.
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()