buio commited on
Commit
593d914
·
1 Parent(s): b3d2735

conditional sagan 10 attributes

Browse files
Files changed (1) hide show
  1. app.py +266 -3
app.py CHANGED
@@ -1,7 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  iface.launch()
 
1
+ # -*- coding: utf-8 -*-
2
+ """evaluate_gan_gradio.ipynb
3
+
4
+ Automatically generated by Colaboratory.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/1ckZU76dq3XWcpa5PpQF8a6qJwkTttg8v
8
+
9
+ # ⚙️ Setup
10
+
11
+ ## Fix random seeds
12
+ """
13
+
14
+ SEED = 11
15
+ import os
16
+ os.environ['PYTHONHASHSEED']=str(SEED)
17
+ import random
18
+ import numpy as np
19
+ import tensorflow as tf
20
+
21
+ random.seed(SEED)
22
+ np.random.seed(SEED)
23
+ tf.random.set_seed(SEED)
24
+
25
+ """## Imports"""
26
+
27
+ !pip install gradio -q
28
+
29
  import gradio as gr
30
 
31
+ from scipy import linalg
32
+ import matplotlib.pyplot as plt
33
+ import pandas as pd
34
+
35
+ from tensorflow import keras
36
+ from tensorflow.keras import layers
37
+ from keras.applications.inception_v3 import InceptionV3, preprocess_input
38
+
39
+ from tensorflow.keras.layers import Layer, Input, Dense, Reshape, Flatten
40
+ from tensorflow.keras.layers import Conv2D, Conv2DTranspose, ReLU, LeakyReLU
41
+ from tensorflow.keras.layers import Dropout, Embedding, Concatenate, Add, Activation
42
+ from tensorflow.keras.layers import GlobalAveragePooling2D, UpSampling2D, BatchNormalization
43
+ import tensorflow.keras.backend as K
44
+
45
+ from tensorflow.python.keras.utils import conv_utils
46
+ from tensorflow.keras.initializers import RandomNormal
47
+ from tensorflow.keras.optimizers import Adam
48
+
49
+ !pip install tensorflow_addons
50
+ import tensorflow_addons as tfa
51
+ from tensorflow_addons.layers import SpectralNormalization
52
+
53
+ import gdown
54
+ from zipfile import ZipFile
55
+
56
+ from tqdm.notebook import tqdm
57
+
58
+ """## Download CelebA attributes
59
+
60
+ We'll use face images from the CelebA dataset, resized to 64x64.
61
+ """
62
+
63
+ #Download labels from public github, they have been processed in a 0,1 csv file
64
+ !mkdir "/content/celeba_gan"
65
+ !wget -q -O "/content/celeba_gan/list_attr_celeba01.csv.zip" "https://github.com/buoi/conditional-face-GAN/blob/main/list_attr_celeba01.csv.zip?raw=true"
66
+ !unzip -o "/content/celeba_gan/list_attr_celeba01.csv.zip" -d "/content/celeba_gan"
67
+
68
+ """## Dataset preprocessing functions"""
69
+
70
+ # image utils functions
71
+
72
+ def conv_range(in_range=(-1,1), out_range=(0,255)):
73
+ """ Returns range conversion function"""
74
+
75
+ # compute means and spans once
76
+ in_mean, out_mean = np.mean(in_range), np.mean(out_range)
77
+ in_span, out_span = np.ptp(in_range), np.ptp(out_range)
78
+
79
+ # return function
80
+ def convert_img_range(in_img):
81
+ out_img = (in_img - in_mean) / in_span
82
+ out_img = out_img * out_span + out_mean
83
+ return out_img
84
+
85
+ return convert_img_range
86
+
87
+ def crop128(img):
88
+ #return img[:, 77:141, 57:121]# 64,64 center crop
89
+ return img[:, 45:173, 25:153] # 128,128 center crop
90
+
91
+ def resize64(img):
92
+ return tf.image.resize(img, (64,64), antialias=True, method='bilinear')
93
+
94
+ """# 📉 Evaluate model
95
+
96
+ ## Load trained GAN
97
+ """
98
+
99
+ from collections import namedtuple
100
+ ModelEntry = namedtuple('ModelEntry', '''entity, resume_id, model_name, exp_name, best_epoch, expected_fid, expected_f1, expected_acc, epoch_range, exp_avg_fid, exp_avg_f1, exp_avg_acc''')
101
+
102
+ dcgan = ModelEntry('buio','t1qyadzp', 'dcgan','dcgan',
103
+ 'v24',25.64, 0, 0, (16,28), 24.29, 0,0)
104
+
105
+ acgan2 = ModelEntry('buio','rn8xslip','acgan2_BNstdev','acgan2_BNstdev',
106
+ 'v22', 29.05, 0.9182, 0.9295, (20,31), 24.79, 0.918, 0.926)
107
+
108
+ acgan10 = ModelEntry('buio','3ja6uvac','acgan10_nonseparBNstdev_split','acgan10_nonseparBNstdev_split',
109
+ 'v24', 26.89, 0.859, 0.785, (18,30), 24.59, 0.789,0.858)
110
+
111
+ acgan40 = ModelEntry('buio','2ev65fpt','acgan40_BNstdev','acgan40_BNstdev',
112
+ 'v15', 28.23, 0.430, 0.842, (15,25), 27.72, 04.6, 0.851)
113
+
114
+
115
+ acgan2_hd = ModelEntry('buio','6km6fdgr','acgan2_BNstdev_218x178','acgan2_BNstdev_218x178',
116
+ 'v11', 0,0,0 ,(0,0), 0, 0, 0)
117
+
118
+ acgan10_hd = ModelEntry('buio','3v366skw','acgan40_BNstdev_218x178','acgan40_BNstdev_218x178',
119
+ 'v14', 0,0,0, (0,0),0, 0, 0)
120
+
121
+ acgan40_hd = ModelEntry('buio','booicugb','acgan10_nonseparBNstdev_split_299_218x178','acgan10_nonseparBNstdev_split_299_218x178',
122
+ 'v14', 52.9, 0.410, 0.834, (12,15), 0, 0, 0)
123
+
124
+
125
+
126
+ #1cr1a5w4 SAGAN_3 v31 buianifolli
127
+ #2o3z6bqb SAGAN_5 v17 buianifolli
128
+ #zscel8bz SAGAN_6 v29 buianifolli
129
+
130
+ #sagan40 v18
131
+
132
+ keras_metadata_url = "https://api.wandb.ai/artifactsV2/gcp-us/buianifolli/QXJ0aWZhY3Q6Mjg3MzA4NTY=/5f09f68e9bb5b09efbc37ad76cdcdbb0"
133
+ saved_model_url = "https://api.wandb.ai/artifactsV2/gcp-us/buianifolli/QXJ0aWZhY3Q6Mjg3NDY1OTU=/2676cd88ef1866d6e572916e413a933e"
134
+ variables_url = "https://api.wandb.ai/artifactsV2/gcp-us/buianifolli/QXJ0aWZhY3Q6Mjg3NDY1OTU=/5cab1cb7351f0732ea137fb2d2e0d4ec"
135
+ index_url = "https://api.wandb.ai/artifactsV2/gcp-us/buianifolli/QXJ0aWZhY3Q6Mjg3NDY1OTU=/480b55762c3358f868b8cce53984736b"
136
+
137
+ #sagan10 v16
138
+
139
+ keras_metadata_url = "https://api.wandb.ai/artifactsV2/gcp-us/buianifolli/QXJ0aWZhY3Q6MjYxMDQwMDE=/392d036bf91d3648eb5a2fa74c1eb716"
140
+ saved_model_url = "https://api.wandb.ai/artifactsV2/gcp-us/buianifolli/QXJ0aWZhY3Q6MjYxMzQ0Mjg=/a5f8608efcc5dafbe780babcffbc79a9"
141
+ variables_url = "https://api.wandb.ai/artifactsV2/gcp-us/buianifolli/QXJ0aWZhY3Q6MjYxMzQ0Mjg=/a62bf0c4bf7047c0a31df7d2cfdb54f0"
142
+ index_url = "https://api.wandb.ai/artifactsV2/gcp-us/buianifolli/QXJ0aWZhY3Q6MjYxMzQ0Mjg=/de6539a7f0909d1dafa89571c7df43d1"
143
+
144
+
145
+
146
+
147
+ #download model
148
+ gan_path = "/content/gan_model/"
149
+ try:
150
+ os.remove(gan_path+"keras_metadata.pb")
151
+ os.remove(gan_path+"saved_model.pb")
152
+ os.remove(gan_path+"variables/variables.data-00000-of-00001")
153
+ os.remove(gan_path+"variables/variables.index")
154
+ except FileNotFoundError:
155
+ pass
156
+ os.makedirs(gan_path,exist_ok =True)
157
+ os.makedirs(gan_path+"/variables",exist_ok =True)
158
+
159
+
160
+ !pip install wget -q
161
+ import wget
162
+ wget.download(keras_metadata_url, gan_path+"keras_metadata.pb",)
163
+ wget.download(saved_model_url, gan_path+"saved_model.pb")
164
+ wget.download(variables_url, gan_path+"variables/variables.data-00000-of-00001")
165
+ wget.download(index_url, gan_path+"variables/variables.index")
166
+
167
+ gan = tf.keras.models.load_model(gan_path)
168
+
169
+ IMAGE_RANGE='11'
170
+ IMAGE_SIZE = gan.discriminator.input_shape[1]
171
+ if IMAGE_SIZE == 64:
172
+ IMAGE_SHAPE = (64,64,3)
173
+ elif IMAGE_SIZE == 218:
174
+ IMAGE_SHAPE = (218,178,3)
175
+
176
+ try:
177
+ LATENT_DIM = gan.generator.input_shape[0][1]
178
+ N_ATTRIBUTES = gan.generator.input_shape[1][1]
179
+ except TypeError:
180
+ LATENT_DIM = gan.generator.input_shape[1]
181
+ N_ATTRIBUTES =0
182
+
183
+ """## 💾 Dataset"""
184
+
185
+ #@title Select Attributes {form-width: "50%", display-mode: "both" }
186
+
187
+ #NUMBER_OF_ATTRIBUTES = "10" #@param [0, 2, 10, 12, 40]
188
+ #N_ATTRIBUTES = int(NUMBER_OF_ATTRIBUTES)
189
+
190
+ IMAGE_RANGE = '11'
191
+
192
+ BATCH_SIZE = 64 #@param {type: "number"}
193
+ if N_ATTRIBUTES == 2:
194
+ LABELS = ["Male", "Smiling"]
195
+
196
+ elif N_ATTRIBUTES == 10:
197
+ LABELS = [
198
+ "Mouth_Slightly_Open", "Wearing_Lipstick", "High_Cheekbones", "Male", "Smiling",
199
+ "Heavy_Makeup", "Wavy_Hair", "Oval_Face", "Pointy_Nose", "Arched_Eyebrows"]
200
+
201
+ elif N_ATTRIBUTES == 12:
202
+ LABELS = ['Wearing_Lipstick','Mouth_Slightly_Open','Male','Smiling',
203
+ 'High_Cheekbones','Heavy_Makeup','Attractive','Young',
204
+ 'No_Beard','Black_Hair','Arched_Eyebrows','Big_Nose']
205
+ elif N_ATTRIBUTES == 40:
206
+ LABELS = [
207
+ '5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive',
208
+ 'Bags_Under_Eyes', 'Bald', 'Bangs', 'Big_Lips', 'Big_Nose',
209
+ 'Black_Hair', 'Blond_Hair', 'Blurry', 'Brown_Hair', 'Bushy_Eyebrows',
210
+ 'Chubby', 'Double_Chin', 'Eyeglasses', 'Goatee', 'Gray_Hair',
211
+ 'Heavy_Makeup', 'High_Cheekbones', 'Male', 'Mouth_Slightly_Open',
212
+ 'Mustache', 'Narrow_Eyes', 'No_Beard', 'Oval_Face', 'Pale_Skin',
213
+ 'Pointy_Nose', 'Receding_Hairline', 'Rosy_Cheeks', 'Sideburns',
214
+ 'Smiling', 'Straight_Hair', 'Wavy_Hair', 'Wearing_Earrings',
215
+ 'Wearing_Hat', 'Wearing_Lipstick', 'Wearing_Necklace',
216
+ 'Wearing_Necktie', 'Young']
217
+
218
+
219
+ else:
220
+ LABELS = ["Male", "Smiling"]# just for dataset creation
221
+
222
+
223
+ # Take labels and a list of image locations in memory
224
+ df = pd.read_csv(r"/content/celeba_gan/list_attr_celeba01.csv")
225
+ attr_list = df[LABELS].values.tolist()
226
+
227
+ def gen_img(attributes):
228
+
229
+ attr = np.zeros((1,N_ATTRIBUTES))
230
+ for a in attributes:
231
+ attr[0,int(a)] = 1
232
+ num_img = 1
233
+ random_latent_vectors = tf.random.normal(shape=(num_img, LATENT_DIM))
234
+
235
+ generated_images = gan.generator((random_latent_vectors, attr))
236
+ generated_images = (generated_images*0.5+0.5).numpy()
237
+ print(generated_images[0].shape)
238
+ return generated_images[0]
239
+
240
+ iface = gr.Interface(
241
+ gen_img,
242
+ gr.inputs.CheckboxGroup([LABELS[i] for i in range(N_ATTRIBUTES)], type='index'),
243
+ "image",
244
+ layout='unaligned'
245
+ )
246
+ iface.launch(debug=True)
247
+
248
+ def sentence_builder(quantity, animal, place, activity_list, morning):
249
+ return f"""The {quantity} {animal}s went to the {place} where they {" and ".join(activity_list)} until the {"morning" if morning else "night"}"""
250
+
251
+ def generate_image(attributes)
252
 
253
+ iface = gr.Interface(
254
+ sentence_builder,
255
+ [
256
+ gr.inputs.Slider(2, 20),
257
+ gr.inputs.Dropdown(["cat", "dog", "bird"]),
258
+ gr.inputs.Radio(["park", "zoo", "road"]),
259
+ gr.inputs.CheckboxGroup(["ran", "swam", "ate", "slept"]),
260
+ gr.inputs.Checkbox(label="Is it the morning?"),
261
+ ],
262
+ "text",
263
+ examples=[
264
+ [2, "cat", "park", ["ran", "swam"], True],
265
+ [4, "dog", "zoo", ["ate", "swam"], False],
266
+ [10, "bird", "road", ["ran"], False],
267
+ [8, "cat", "zoo", ["ate"], True],
268
+ ],
269
+ )
270
  iface.launch()