import warnings
import argparse
import os
from PIL import Image
import numpy as np
import torch
import stylegan2
from stylegan2 import utils
from huggingface_hub import hf_hub_download
import gradio as gr
import re
from types import SimpleNamespace
# Edited from run_generator.py to return PIL images instead of saving them to disk.
def generate_images(G, args):
    latent_size, label_size = G.latent_size, G.label_size
    device = torch.device(args.gpu[0] if args.gpu else 'cpu')
    if device.index is not None:
        torch.cuda.set_device(device.index)
    G.to(device)
    if args.truncation_psi != 1:
        G.set_truncation(truncation_psi=args.truncation_psi)
    if len(args.gpu) > 1:
        warnings.warn(
            'Noise can not be randomized based on the seed '
            'when using more than 1 GPU device. Noise will '
            'now be randomized from default random state.'
        )
        G.random_noise()
        G = torch.nn.DataParallel(G, device_ids=args.gpu)
    else:
        noise_reference = G.static_noise()

    def get_batch(seeds):
        latents = []
        labels = []
        if len(args.gpu) <= 1:
            noise_tensors = [[] for _ in noise_reference]
        for seed in seeds:
            rnd = np.random.default_rng(seed)
            latents.append(torch.from_numpy(rnd.standard_normal(latent_size)))
            if len(args.gpu) <= 1:
                for i, ref in enumerate(noise_reference):
                    noise_tensors[i].append(
                        torch.from_numpy(rnd.standard_normal(tuple(ref.size()[1:])))
                    )
            if label_size:
                labels.append(torch.tensor([rnd.integers(0, label_size)]))
        latents = torch.stack(latents, dim=0).to(
            device=device, dtype=torch.float32)
        if labels:
            labels = torch.cat(labels, dim=0).to(
                device=device, dtype=torch.int64)
        else:
            labels = None
        if len(args.gpu) <= 1:
            noise_tensors = [
                torch.stack(noise, dim=0).to(
                    device=device, dtype=torch.float32)
                for noise in noise_tensors
            ]
        else:
            noise_tensors = None
        return latents, labels, noise_tensors

    return_images = []
    for i in range(0, len(args.seeds), args.batch_size):
        latents, labels, noise_tensors = get_batch(
            args.seeds[i: i + args.batch_size])
        if noise_tensors is not None:
            G.static_noise(noise_tensors=noise_tensors)
        with torch.no_grad():
            generated = G(latents, labels=labels)
        images = utils.tensor_to_PIL(
            generated, pixel_min=args.pixel_min, pixel_max=args.pixel_max)
        return_images.extend(images)
    return return_images
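
# Usage sketch (left as a comment so nothing extra runs on Space startup).
# generate_images() expects the same attribute names that run_generator.py's
# ArgumentParser would provide, so a SimpleNamespace works as a drop-in
# replacement; the checkpoint path below is only an illustrative placeholder:
#
#   G = stylegan2.models.load("Gs.pth")
#   G.eval()
#   images = generate_images(G, SimpleNamespace(
#       truncation_psi=0.7, seeds=[42, 43], batch_size=2,
#       pixel_min=-1, pixel_max=1, gpu=[]))  # gpu=[] selects CPU inference
#   images[0].save("sample.png")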
#----------------------------------------------------------------------------

interface_modelversion_labels = [
    "TWDNEv3 iteration 24664 (best and current version on TWDNE)",
    "TWDNEv3 iteration 18528 (the most used version on the Internet)",
    "TWDNEv3 iteration 17325"
]
def inference(seed, truncation_psi, modelversion_label):
    model_iteration = re.search(r"TWDNEv3 iteration (\d{5})", modelversion_label).group(1)
    G = stylegan2.models.load(
        hf_hub_download(
            "hr16/Gwern-TWDNEv3-pytorch_ckpt",
            f"iteration-{model_iteration}/Gs.pth",
            use_auth_token=os.environ['MODEL_READING_TOKEN']
        )
    )
    G.eval()
    return generate_images(
        G,
        SimpleNamespace(**{
            'truncation_psi': truncation_psi,
            'seeds': [seed],
            'batch_size': 1,
            'pixel_min': -1,
            'pixel_max': 1,
            'gpu': []
        })  # Replaces the ArgumentParser namespace from run_generator.py
    )[0]
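
# Local test sketch (commented out; it needs MODEL_READING_TOKEN in the
# environment and downloads the selected checkpoint from the Hub):
#
#   img = inference(12345, 0.7, interface_modelversion_labels[0])
#   img.save("preview.png")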
title = "TWDNEv3 CPU Generator"
description = "Gradio Demo for TWDNEv3 CPU Generator (stylegan2_pytorch port)"
article = ""
gr.Interface(
    inference,
    [
        gr.Number(precision=0, label="PCG64 PRNG Seed (any-bit-size unsigned int; results may differ from the original site)"),
        gr.Slider(0, 2, step=0.1, value=0.7, label="Truncation psi (aka creative level, between 0 and 2)"),
        gr.Radio(
            interface_modelversion_labels,
            value="TWDNEv3 iteration 24664 (best and current version on TWDNE)",
            type="value",
            label="Model versions"
        )
    ],
    gr.Image(type="pil"),
    title=title, description=description, article=article, allow_flagging="never"
).launch()
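
# On Spaces the bare launch() above is enough; for local debugging, Gradio's
# standard launch options (e.g. launch(debug=True) or launch(share=True))
# can be passed instead.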