hr16 commited on
Commit
b9c3db8
·
1 Parent(s): 480bfbc

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -0
app.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ import argparse
3
+ import os
4
+ from PIL import Image
5
+ import numpy as np
6
+ import torch
7
+
8
+ import stylegan2
9
+ from stylegan2 import utils
10
+
11
+ def generate_images(G, args):
12
+ latent_size, label_size = G.latent_size, G.label_size
13
+ device = torch.device(args.gpu[0] if args.gpu else 'cpu')
14
+ if device.index is not None:
15
+ torch.cuda.set_device(device.index)
16
+ G.to(device)
17
+ if args.truncation_psi != 1:
18
+ G.set_truncation(truncation_psi=args.truncation_psi)
19
+ if len(args.gpu) > 1:
20
+ warnings.warn(
21
+ 'Noise can not be randomized based on the seed ' +
22
+ 'when using more than 1 GPU device. Noise will ' +
23
+ 'now be randomized from default random state.'
24
+ )
25
+ G.random_noise()
26
+ G = torch.nn.DataParallel(G, device_ids=args.gpu)
27
+ else:
28
+ noise_reference = G.static_noise()
29
+
30
+ def get_batch(seeds):
31
+ latents = []
32
+ labels = []
33
+ if len(args.gpu) <= 1:
34
+ noise_tensors = [[] for _ in noise_reference]
35
+ for seed in seeds:
36
+ rnd = np.random.RandomState(seed)
37
+ latents.append(torch.from_numpy(rnd.randn(latent_size)))
38
+ if len(args.gpu) <= 1:
39
+ for i, ref in enumerate(noise_reference):
40
+ noise_tensors[i].append(
41
+ torch.from_numpy(rnd.randn(*ref.size()[1:])))
42
+ if label_size:
43
+ labels.append(torch.tensor([rnd.randint(0, label_size)]))
44
+ latents = torch.stack(latents, dim=0).to(
45
+ device=device, dtype=torch.float32)
46
+ if labels:
47
+ labels = torch.cat(labels, dim=0).to(
48
+ device=device, dtype=torch.int64)
49
+ else:
50
+ labels = None
51
+ if len(args.gpu) <= 1:
52
+ noise_tensors = [
53
+ torch.stack(noise, dim=0).to(
54
+ device=device, dtype=torch.float32)
55
+ for noise in noise_tensors
56
+ ]
57
+ else:
58
+ noise_tensors = None
59
+ return latents, labels, noise_tensors
60
+ return_images = []
61
+ for i in range(0, len(args.seeds), args.batch_size):
62
+ latents, labels, noise_tensors = get_batch(
63
+ args.seeds[i: i + args.batch_size])
64
+ if noise_tensors is not None:
65
+ G.static_noise(noise_tensors=noise_tensors)
66
+ with torch.no_grad():
67
+ generated = G(latents, labels=labels)
68
+ images = utils.tensor_to_PIL(
69
+ generated, pixel_min=args.pixel_min, pixel_max=args.pixel_max)
70
+ return_images.extend(images)
71
+ return return_images
72
+
73
+
74
+ #----------------------------------------------------------------------------
75
+
76
+ def inference(seed):
77
+
78
+
79
+ title = "TWDNEv3 CPU Generator"
80
+ description = "Gradio Demo for TWDNEv3 CPU Generator (stylegan2_pytorch port). To use it, simply put your random seed."
81
+ article = ""
82
+ gr.Interface(inference, ["number"], gr.outputs.Image(type="pil"),title=title,description=description,article=article,allow_flagging=False,allow_screenshot=False).launch()