import argparse
import functools
import os
import threading
import warnings
from types import SimpleNamespace

import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image

import stylegan2
from stylegan2 import utils

# Generation logic adapted from
# https://github.com/adriansahlman/stylegan2_pytorch/blob/master/run_generator.py


def generate_images(G, args):
    """Generate one image per seed in ``args.seeds`` with generator ``G``.

    Every seed gets its own ``np.random.RandomState``, from which the latent
    vector, the per-layer noise (single-device runs only) and, when the
    generator has labels, a class label are drawn in that order.  Hence a
    given seed reproduces the same image on a single device.

    Args:
        G: stylegan2 generator exposing ``latent_size``, ``label_size``,
            ``set_truncation``, ``static_noise`` and ``random_noise``.
            It is mutated in place (device, truncation, static noise).
        args: namespace with ``seeds`` (sequence of ints accepted by
            ``np.random.RandomState``), ``batch_size``, ``truncation_psi``,
            ``pixel_min``, ``pixel_max`` and optionally ``gpu`` (list of
            CUDA device indices; missing/None/empty means CPU).

    Returns:
        list of images produced by ``utils.tensor_to_PIL``, in seed order.
    """
    # ``gpu`` is optional so CPU-only callers need not set it.
    gpus = list(getattr(args, 'gpu', None) or [])
    multi_gpu = len(gpus) > 1
    latent_size, label_size = G.latent_size, G.label_size

    device = torch.device(gpus[0] if gpus else 'cpu')
    if device.index is not None:
        torch.cuda.set_device(device.index)
    G.to(device)
    if args.truncation_psi != 1:
        G.set_truncation(truncation_psi=args.truncation_psi)

    if multi_gpu:
        warnings.warn(
            'Noise can not be randomized based on the seed ' +
            'when using more than 1 GPU device. Noise will ' +
            'now be randomized from default random state.'
        )
        G.random_noise()
        G = torch.nn.DataParallel(G, device_ids=gpus)
        noise_reference = None
    else:
        # Only the shapes of these reference tensors are used below.
        noise_reference = G.static_noise()

    def get_batch(seeds):
        """Build (latents, labels, noise_tensors) for one batch of seeds."""
        latents = []
        labels = []
        noise_tensors = None if multi_gpu else [[] for _ in noise_reference]
        for seed in seeds:
            rnd = np.random.RandomState(seed)
            # Draw order (latent, noise, label) must stay fixed so a seed
            # always maps to the same image.
            latents.append(torch.from_numpy(rnd.randn(latent_size)))
            if noise_tensors is not None:
                for i, ref in enumerate(noise_reference):
                    noise_tensors[i].append(
                        torch.from_numpy(rnd.randn(*ref.size()[1:])))
            if label_size:
                labels.append(torch.tensor([rnd.randint(0, label_size)]))
        latents = torch.stack(latents, dim=0).to(
            device=device, dtype=torch.float32)
        if labels:
            labels = torch.cat(labels, dim=0).to(
                device=device, dtype=torch.int64)
        else:
            labels = None
        if noise_tensors is not None:
            noise_tensors = [
                torch.stack(noise, dim=0).to(
                    device=device, dtype=torch.float32)
                for noise in noise_tensors
            ]
        return latents, labels, noise_tensors

    seeds = list(args.seeds)
    return_images = []
    for start in range(0, len(seeds), args.batch_size):
        latents, labels, noise_tensors = get_batch(
            seeds[start:start + args.batch_size])
        if noise_tensors is not None:
            G.static_noise(noise_tensors=noise_tensors)
        with torch.no_grad():
            generated = G(latents, labels=labels)
        return_images.extend(utils.tensor_to_PIL(
            generated, pixel_min=args.pixel_min, pixel_max=args.pixel_max))
    return return_images

#----------------------------------------------------------------------------

_MODEL_REPO = "hr16/Gwern-TWDNEv3-pytorch_ckpt"
_MODEL_FILE = "Gs.pth"
# np.random.RandomState only accepts seeds in [0, 2**32 - 1].
_SEED_MODULUS = 2 ** 32
# generate_images mutates the shared cached generator (device, truncation,
# static noise), so concurrent requests must not interleave.
_generate_lock = threading.Lock()


@functools.lru_cache(maxsize=1)
def _load_generator():
    """Download (if needed) and load the TWDNEv3 generator once per process."""
    checkpoint = hf_hub_download(
        _MODEL_REPO,
        _MODEL_FILE,
        # None falls back to the default Hugging Face credentials instead of
        # crashing with KeyError when the variable is unset.
        use_auth_token=os.environ.get('MODEL_READING_TOKEN'),
    )
    G = stylegan2.models.load(checkpoint)
    G.eval()
    return G


def inference(seed):
    """Generate the TWDNEv3 image for ``seed``.

    Args:
        seed: number from the Gradio input (may arrive as a float). It is
            truncated to int and reduced modulo 2**32 so RandomState accepts
            it.

    Returns:
        a single image (the Gradio output component expects one image, not
        a list).
    """
    seed = int(seed) % _SEED_MODULUS
    G = _load_generator()
    args = SimpleNamespace(
        truncation_psi=0.7,  # It seems like 0.7 gives the best result.
        seeds=[seed],
        batch_size=1,
        pixel_min=-1,
        pixel_max=1,
        gpu=[],  # CPU-only demo; generate_images reads args.gpu.
    )
    with _generate_lock:
        images = generate_images(G, args)
    return images[0]


title = "TWDNEv3 CPU Generator"
description = "Gradio Demo for TWDNEv3 CPU Generator (stylegan2_pytorch port). To use it, simply put your random seed."
article = ""

if __name__ == "__main__":
    # NOTE(review): gr.outputs.Image and allow_screenshot are legacy Gradio
    # APIs; confirm the pinned gradio version still supports them.
    gr.Interface(
        inference,
        ["number"],
        gr.outputs.Image(type="pil"),
        title=title,
        description=description,
        article=article,
        allow_flagging=False,
        allow_screenshot=False,
    ).launch()