Spaces:
Runtime error
Runtime error
File size: 4,549 Bytes
b9c3db8 e7f7175 ff2b644 0d8c2d3 b9c3db8 3b78729 0c01bda b9c3db8 17dfcbb b9c3db8 cc7e213 2a7d21a b9c3db8 17dfcbb b9c3db8 0d8c2d3 b9c3db8 0d8c2d3 e7f7175 3b78729 cc7e213 3b78729 179cbb8 0c01bda b9c3db8 cc7e213 b9c3db8 6788293 36bec38 0d8c2d3 36bec38 0d8c2d3 36bec38 6788293 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
import argparse
import copy
import functools
import os
import re
import warnings
from types import SimpleNamespace

import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image

import stylegan2
from stylegan2 import utils
# Edited from run_generator.py to return PIL images instead of saving them to the disk.
def generate_images(G, args):
    """Render one image per seed in ``args.seeds`` and return them in seed order.

    Adapted from stylegan2_pytorch's ``run_generator.py``: rather than being
    written to disk, the images are collected into a list and returned.

    Args:
        G: the generator. Attributes read: ``latent_size``, ``label_size``.
            Methods called: ``to``, ``set_truncation``, ``static_noise``,
            ``random_noise`` and ``G(latents, labels=...)``.
        args: namespace with ``gpu`` (sequence of device ids; empty means
            CPU), ``truncation_psi``, ``seeds``, ``batch_size``,
            ``pixel_min`` and ``pixel_max``.

    Returns:
        A list with the images produced by ``utils.tensor_to_PIL`` for every
        batch, concatenated in seed order.

    NOTE(review): ``to``/``set_truncation``/``static_noise`` are called on the
    caller's ``G`` object, so any state they set presumably persists on it.
    """
    latent_size = G.latent_size
    label_size = G.label_size

    device = torch.device('cpu' if not args.gpu else args.gpu[0])
    if device.index is not None:
        torch.cuda.set_device(device.index)
    G.to(device)

    if args.truncation_psi != 1:
        G.set_truncation(truncation_psi=args.truncation_psi)

    # Seed-determined noise is only possible on a single device; with several
    # GPUs the model is wrapped in DataParallel and noise comes from the
    # default random state instead.
    use_static_noise = len(args.gpu) <= 1
    if use_static_noise:
        # Template tensors: only their count and shapes (minus the leading
        # dimension) are used below.
        noise_reference = G.static_noise()
    else:
        warnings.warn(
            'Noise can not be randomized based on the seed '
            'when using more than 1 GPU device. Noise will '
            'now be randomized from default random state.'
        )
        G.random_noise()
        G = torch.nn.DataParallel(G, device_ids=args.gpu)

    def sample_batch(batch_seeds):
        """Draw latents, optional labels and per-layer noise for ``batch_seeds``."""
        latent_list = []
        label_list = []
        noise_buckets = [[] for _ in noise_reference] if use_static_noise else None

        for seed in batch_seeds:
            # One independent generator per seed; the draw order
            # (latent, then noise, then label) decides the image, so keep it.
            rng = np.random.default_rng(seed)
            latent_list.append(torch.from_numpy(rng.standard_normal(latent_size)))
            if use_static_noise:
                for bucket, ref in zip(noise_buckets, noise_reference):
                    bucket.append(
                        torch.from_numpy(rng.standard_normal(tuple(ref.size()[1:])))
                    )
            if label_size:
                label_list.append(torch.tensor([rng.integers(0, label_size)]))

        latents = torch.stack(latent_list, dim=0).to(device=device, dtype=torch.float32)

        if not label_list:
            labels = None
        else:
            labels = torch.cat(label_list, dim=0).to(device=device, dtype=torch.int64)

        if use_static_noise:
            noise = [
                torch.stack(bucket, dim=0).to(device=device, dtype=torch.float32)
                for bucket in noise_buckets
            ]
        else:
            noise = None
        return latents, labels, noise

    images = []
    batch_size = args.batch_size
    for start in range(0, len(args.seeds), batch_size):
        latents, labels, noise = sample_batch(args.seeds[start:start + batch_size])
        if noise is not None:
            G.static_noise(noise_tensors=noise)
        with torch.no_grad():
            output = G(latents, labels=labels)
        images.extend(
            utils.tensor_to_PIL(output, pixel_min=args.pixel_min, pixel_max=args.pixel_max)
        )
    return images
#----------------------------------------------------------------------------
# Choices offered by the "Model versions" radio button. Each label embeds the
# five-digit checkpoint iteration, which the inference code parses back out.
interface_modelversion_labels = [
    "TWDNEv3 iteration 24664 (best and current version on TWDNE)",
    "TWDNEv3 iteration 18528 (the most used version on the Internet)",
    "TWDNEv3 iteration 17325",
]
# Model-version labels look like "TWDNEv3 iteration 24664 (...)". A raw string
# keeps ``\d`` a regex escape instead of an invalid string escape.
_MODEL_ITERATION_PATTERN = re.compile(r"TWDNEv3 iteration (\d{5})")


@functools.lru_cache(maxsize=4)
def _load_pristine_generator(model_iteration):
    """Download (if needed) and load the ``Gs.pth`` generator for one iteration.

    Cached so each checkpoint is fetched and deserialized once per process
    rather than on every request. The cached module must never be mutated;
    ``inference`` always works on a deep copy of it.

    Raises:
        KeyError: if the ``MODEL_READING_TOKEN`` environment variable is unset.
    """
    G = stylegan2.models.load(
        hf_hub_download(
            "hr16/Gwern-TWDNEv3-pytorch_ckpt",
            f"iteration-{model_iteration}/Gs.pth",
            use_auth_token=os.environ['MODEL_READING_TOKEN'],
        )
    )
    G.eval()
    return G


def inference(seed, truncation_psi, modelversion_label):
    """Generate a single TWDNEv3 image for the Gradio UI.

    Args:
        seed: seed passed to ``np.random.default_rng``; it determines the
            latent vector and noise inside ``generate_images``.
        truncation_psi: truncation psi forwarded to ``generate_images``.
        modelversion_label: a model-version label containing
            "TWDNEv3 iteration NNNNN"; the five digits select the checkpoint.

    Returns:
        The single image produced by ``generate_images`` for ``seed``.

    Raises:
        ValueError: if ``modelversion_label`` does not name an iteration.
        KeyError: if ``MODEL_READING_TOKEN`` is unset and the checkpoint has
            not been loaded yet in this process.
    """
    match = _MODEL_ITERATION_PATTERN.search(modelversion_label or "")
    if match is None:
        raise ValueError(f"Unknown model version label: {modelversion_label!r}")

    # generate_images mutates the generator it receives (static noise every
    # call, truncation only when psi != 1), so give it a private copy: a
    # shared instance would leak one request's settings into the next and
    # race under concurrent requests.
    G = copy.deepcopy(_load_pristine_generator(match.group(1)))
    return generate_images(
        G,
        # Stand-in for the ArgumentParser namespace used by run_generator.py.
        SimpleNamespace(
            truncation_psi=truncation_psi,
            seeds=[seed],
            batch_size=1,
            pixel_min=-1,
            pixel_max=1,
            gpu=[],
        ),
    )[0]
title = "TWDNEv3 CPU Generator"
description = "Gradio Demo for TWDNEv3 CPU Generator (stylegan2_pytorch port)"
article = ""
gr.Interface(
inference,
[
gr.Number(precision=0, label="PCG64 PRNG Seed (any-bit-size unsigned int, note that it may different from the original site)"),
gr.Slider(0, 2, step=0.1, value=0.7, label='Truncation psi (aka creative level, between 0 and 2)'),
gr.Radio(
interface_modelversion_labels,
value="TWDNEv3 iteration 24664 (best and current version on TWDNE)",
type="value",
label="Model versions"
)
],
gr.outputs.Image(type="pil"),
title=title,description=description,article=article,allow_flagging=False,allow_screenshot=False
).launch() |