import warnings
import os
import re
from types import SimpleNamespace

import numpy as np
import torch
import gradio as gr
from huggingface_hub import hf_hub_download

import stylegan2
from stylegan2 import utils

# Edited from run_generator.py to return PIL images instead of saving them to disk.
def generate_images(G, args):
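    """Generate one PIL image per seed in args.seeds.

    Latents (and, on a single device, the per-layer noise maps) are drawn from a
    seed-specific NumPy generator, so results are reproducible per seed.
    """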
    latent_size, label_size = G.latent_size, G.label_size
    device = torch.device(args.gpu[0] if args.gpu else 'cpu')
    if device.index is not None:
        torch.cuda.set_device(device.index)
    G.to(device)
    if args.truncation_psi != 1:
        G.set_truncation(truncation_psi=args.truncation_psi)
    if len(args.gpu) > 1:
        warnings.warn(
            'Noise can not be randomized based on the seed ' +
            'when using more than 1 GPU device. Noise will ' +
            'now be randomized from default random state.'
        )
        G.random_noise()
        G = torch.nn.DataParallel(G, device_ids=args.gpu)
    else:
        noise_reference = G.static_noise()

    def get_batch(seeds):
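        # Each seed gets its own PCG64 generator (np.random.default_rng), which
        # deterministically yields the latent vector, the per-layer noise maps
        # (single-device only) and, if the model is conditional, a class label.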
        latents = []
        labels = []
        if len(args.gpu) <= 1:
            noise_tensors = [[] for _ in noise_reference]
        for seed in seeds:
            rnd = np.random.default_rng(seed)
            latents.append(torch.from_numpy(rnd.standard_normal(latent_size)))
            if len(args.gpu) <= 1:
                for i, ref in enumerate(noise_reference):
                    noise_tensors[i].append(
                        torch.from_numpy(rnd.standard_normal(tuple(ref.size()[1:])))
                    )
            if label_size:
                labels.append(torch.tensor([rnd.integers(0, label_size)]))
        latents = torch.stack(latents, dim=0).to(
            device=device, dtype=torch.float32)
        if labels:
            labels = torch.cat(labels, dim=0).to(
                device=device, dtype=torch.int64)
        else:
            labels = None
        if len(args.gpu) <= 1:
            noise_tensors = [
                torch.stack(noise, dim=0).to(
                    device=device, dtype=torch.float32)
                for noise in noise_tensors
            ]
        else:
            noise_tensors = None
        return latents, labels, noise_tensors
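
    # Run the generator batch by batch; on a single device the per-seed noise
    # tensors are loaded back into the model via static_noise() before each pass.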
    return_images = []
    for i in range(0, len(args.seeds), args.batch_size):
        latents, labels, noise_tensors = get_batch(
            args.seeds[i: i + args.batch_size])
        if noise_tensors is not None:
            G.static_noise(noise_tensors=noise_tensors)
        with torch.no_grad():
            generated = G(latents, labels=labels)
        images = utils.tensor_to_PIL(
            generated, pixel_min=args.pixel_min, pixel_max=args.pixel_max)
        return_images.extend(images)
    return return_images
 
 
#----------------------------------------------------------------------------
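# Gradio front end: a seed, a truncation psi and a model version are mapped to
# a single generated image by inference() below.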
interface_modelversion_labels = [
    "TWDNEv3 iteration 24664 (best and current version on TWDNE)",
    "TWDNEv3 iteration 18528 (the most used version on the Internet)",
    "TWDNEv3 iteration 17325"
]

def inference(seed, truncation_psi, modelversion_label):
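    """Gradio callback: parse the model iteration out of the selected radio label,
    download that checkpoint from the Hugging Face Hub (requires the
    MODEL_READING_TOKEN environment variable) and generate a single image.
    """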
    model_iteration = re.search(r"TWDNEv3 iteration (\d{5})", modelversion_label).group(1)
    G = stylegan2.models.load(
        hf_hub_download(
            "hr16/Gwern-TWDNEv3-pytorch_ckpt",
            f"iteration-{model_iteration}/Gs.pth",
            use_auth_token=os.environ['MODEL_READING_TOKEN']
        )
    )
    G.eval()
    return generate_images(
        G,
        SimpleNamespace(
            truncation_psi=truncation_psi,
            seeds=[seed],
            batch_size=1,
            pixel_min=-1,
            pixel_max=1,
            gpu=[]
        )  # Stands in for the argparse.Namespace built by run_generator.py
    )[0]
    
title = "TWDNEv3 CPU Generator"
description = "Gradio Demo for TWDNEv3 CPU Generator (stylegan2_pytorch port)"
article = ""
gr.Interface(
    inference,
    [
        gr.Number(precision=0, label="PCG64 PRNG seed (unsigned integer of any bit size; note that it may differ from the original site)"),
        gr.Slider(0, 2, step=0.1, value=0.7, label='Truncation psi (aka creative level, between 0 and 2)'),
        gr.Radio(
            interface_modelversion_labels,
            value="TWDNEv3 iteration 24664 (best and current version on TWDNE)",
            type="value",
            label="Model versions"
        )
    ],
    gr.Image(type="pil"),
    title=title, description=description, article=article, allow_flagging="never"
).launch()