Spaces:
Runtime error
conditional sagan 10 attributes
app.py
CHANGED
@@ -1,7 +1,270 @@
Old app.py (seven lines; … marks elided lines):

import gradio as gr

…
…

iface = gr.Interface(
iface.launch()
New app.py:

# -*- coding: utf-8 -*-
"""evaluate_gan_gradio.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1ckZU76dq3XWcpa5PpQF8a6qJwkTttg8v

# ⚙️ Setup

## Fix random seeds
"""

SEED = 11
import os
os.environ['PYTHONHASHSEED'] = str(SEED)
import random
import numpy as np
import tensorflow as tf

random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

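As a quick, illustrative check of what the seeding above buys (not part of the commit): re-setting the global TensorFlow seed replays the same random sequence, which is what makes evaluation runs repeatable across restarts.

```python
import numpy as np
import tensorflow as tf

SEED = 11
tf.random.set_seed(SEED)
a = tf.random.normal((2,))
tf.random.set_seed(SEED)
b = tf.random.normal((2,))
assert np.allclose(a.numpy(), b.numpy())  # re-seeding replays the same draws
```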
"""## Imports"""

!pip install gradio -q

import gradio as gr

from scipy import linalg
import matplotlib.pyplot as plt
import pandas as pd

from tensorflow import keras
from tensorflow.keras import layers
from keras.applications.inception_v3 import InceptionV3, preprocess_input

from tensorflow.keras.layers import Layer, Input, Dense, Reshape, Flatten
from tensorflow.keras.layers import Conv2D, Conv2DTranspose, ReLU, LeakyReLU
from tensorflow.keras.layers import Dropout, Embedding, Concatenate, Add, Activation
from tensorflow.keras.layers import GlobalAveragePooling2D, UpSampling2D, BatchNormalization
import tensorflow.keras.backend as K

from tensorflow.python.keras.utils import conv_utils
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.optimizers import Adam

!pip install tensorflow_addons
import tensorflow_addons as tfa
from tensorflow_addons.layers import SpectralNormalization

import gdown
from zipfile import ZipFile

from tqdm.notebook import tqdm

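One caveat worth flagging here: the `!pip ...` lines are IPython shell escapes, which are syntax errors in a plain app.py and a likely cause of the Space's runtime error. A hedged plain-Python sketch, if the dependencies cannot simply be listed in requirements.txt (the usual Spaces route):

```python
# Sketch only; assumes installing at runtime is acceptable for this Space.
import subprocess
import sys

subprocess.run(
    [sys.executable, "-m", "pip", "install", "-q", "gradio", "tensorflow_addons"],
    check=True,  # fail fast instead of failing later at import time
)
```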
"""## Download CelebA attributes
|
59 |
+
|
60 |
+
We'll use face images from the CelebA dataset, resized to 64x64.
|
61 |
+
"""
|
62 |
+
|
63 |
+
#Download labels from public github, they have been processed in a 0,1 csv file
|
64 |
+
!mkdir "/content/celeba_gan"
|
65 |
+
!wget -q -O "/content/celeba_gan/list_attr_celeba01.csv.zip" "https://github.com/buoi/conditional-face-GAN/blob/main/list_attr_celeba01.csv.zip?raw=true"
|
66 |
+
!unzip -o "/content/celeba_gan/list_attr_celeba01.csv.zip" -d "/content/celeba_gan"
|
67 |
+
|
68 |
+
"""## Dataset preprocessing functions"""
|
69 |
+
|
70 |
+
# image utils functions
|
71 |
+
|
72 |
+
def conv_range(in_range=(-1,1), out_range=(0,255)):
|
73 |
+
""" Returns range conversion function"""
|
74 |
+
|
75 |
+
# compute means and spans once
|
76 |
+
in_mean, out_mean = np.mean(in_range), np.mean(out_range)
|
77 |
+
in_span, out_span = np.ptp(in_range), np.ptp(out_range)
|
78 |
+
|
79 |
+
# return function
|
80 |
+
def convert_img_range(in_img):
|
81 |
+
out_img = (in_img - in_mean) / in_span
|
82 |
+
out_img = out_img * out_span + out_mean
|
83 |
+
return out_img
|
84 |
+
|
85 |
+
return convert_img_range
|
86 |
+
|
87 |
+
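For clarity, a small usage example of conv_range (illustrative, not in the commit): it returns a closure that linearly remaps one value range onto another.

```python
import numpy as np

# Map tanh-style pixels in [-1, 1] onto displayable bytes in [0, 255].
to_bytes = conv_range(in_range=(-1, 1), out_range=(0, 255))
x = np.array([-1.0, 0.0, 1.0])
print(to_bytes(x))  # [  0.  127.5 255. ]
```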
def crop128(img):
    # return img[:, 77:141, 57:121]  # 64x64 center crop
    return img[:, 45:173, 25:153]    # 128x128 center crop

def resize64(img):
    return tf.image.resize(img, (64, 64), antialias=True, method='bilinear')

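The two helpers compose into the intended preprocessing pipeline: center-crop the raw 218x178 CelebA frame to 128x128, then resize with antialiasing to the model's 64x64. A minimal sketch on a dummy batch (the batch dimension comes first, as crop128's slicing assumes):

```python
import numpy as np

batch = np.zeros((8, 218, 178, 3), dtype=np.float32)  # stand-in CelebA batch
small = resize64(crop128(batch))
print(small.shape)  # (8, 64, 64, 3)
```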
"""# 📉 Evaluate model
|
95 |
+
|
96 |
+
## Load trained GAN
|
97 |
+
"""
|
98 |
+
|
99 |
+
from collections import namedtuple
|
100 |
+
ModelEntry = namedtuple('ModelEntry', '''entity, resume_id, model_name, exp_name, best_epoch, expected_fid, expected_f1, expected_acc, epoch_range, exp_avg_fid, exp_avg_f1, exp_avg_acc''')
|
101 |
+
|
102 |
+
dcgan = ModelEntry('buio','t1qyadzp', 'dcgan','dcgan',
|
103 |
+
'v24',25.64, 0, 0, (16,28), 24.29, 0,0)
|
104 |
+
|
105 |
+
acgan2 = ModelEntry('buio','rn8xslip','acgan2_BNstdev','acgan2_BNstdev',
|
106 |
+
'v22', 29.05, 0.9182, 0.9295, (20,31), 24.79, 0.918, 0.926)
|
107 |
+
|
108 |
+
acgan10 = ModelEntry('buio','3ja6uvac','acgan10_nonseparBNstdev_split','acgan10_nonseparBNstdev_split',
|
109 |
+
'v24', 26.89, 0.859, 0.785, (18,30), 24.59, 0.789,0.858)
|
110 |
+
|
111 |
+
acgan40 = ModelEntry('buio','2ev65fpt','acgan40_BNstdev','acgan40_BNstdev',
|
112 |
+
'v15', 28.23, 0.430, 0.842, (15,25), 27.72, 04.6, 0.851)
|
113 |
+
|
114 |
+
|
115 |
+
acgan2_hd = ModelEntry('buio','6km6fdgr','acgan2_BNstdev_218x178','acgan2_BNstdev_218x178',
|
116 |
+
'v11', 0,0,0 ,(0,0), 0, 0, 0)
|
117 |
+
|
118 |
+
acgan10_hd = ModelEntry('buio','3v366skw','acgan40_BNstdev_218x178','acgan40_BNstdev_218x178',
|
119 |
+
'v14', 0,0,0, (0,0),0, 0, 0)
|
120 |
+
|
121 |
+
acgan40_hd = ModelEntry('buio','booicugb','acgan10_nonseparBNstdev_split_299_218x178','acgan10_nonseparBNstdev_split_299_218x178',
|
122 |
+
'v14', 52.9, 0.410, 0.834, (12,15), 0, 0, 0)
|
123 |
+
|
124 |
+
|
125 |
+
|
126 |
+
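Since ModelEntry is a namedtuple, each run's bookkeeping stays attribute-addressable; illustrative use:

```python
print(acgan10.model_name)    # 'acgan10_nonseparBNstdev_split'
print(acgan10.best_epoch)    # 'v24'
print(acgan10.expected_fid)  # 26.89
```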
# 1cr1a5w4 SAGAN_3 v31 buianifolli
# 2o3z6bqb SAGAN_5 v17 buianifolli
# zscel8bz SAGAN_6 v29 buianifolli

# sagan40 v18
keras_metadata_url = "https://api.wandb.ai/artifactsV2/gcp-us/buianifolli/QXJ0aWZhY3Q6Mjg3MzA4NTY=/5f09f68e9bb5b09efbc37ad76cdcdbb0"
saved_model_url = "https://api.wandb.ai/artifactsV2/gcp-us/buianifolli/QXJ0aWZhY3Q6Mjg3NDY1OTU=/2676cd88ef1866d6e572916e413a933e"
variables_url = "https://api.wandb.ai/artifactsV2/gcp-us/buianifolli/QXJ0aWZhY3Q6Mjg3NDY1OTU=/5cab1cb7351f0732ea137fb2d2e0d4ec"
index_url = "https://api.wandb.ai/artifactsV2/gcp-us/buianifolli/QXJ0aWZhY3Q6Mjg3NDY1OTU=/480b55762c3358f868b8cce53984736b"

# sagan10 v16 (these assignments overwrite the sagan40 URLs above, so the
# sagan10 checkpoint is the one that actually gets downloaded)
keras_metadata_url = "https://api.wandb.ai/artifactsV2/gcp-us/buianifolli/QXJ0aWZhY3Q6MjYxMDQwMDE=/392d036bf91d3648eb5a2fa74c1eb716"
saved_model_url = "https://api.wandb.ai/artifactsV2/gcp-us/buianifolli/QXJ0aWZhY3Q6MjYxMzQ0Mjg=/a5f8608efcc5dafbe780babcffbc79a9"
variables_url = "https://api.wandb.ai/artifactsV2/gcp-us/buianifolli/QXJ0aWZhY3Q6MjYxMzQ0Mjg=/a62bf0c4bf7047c0a31df7d2cfdb54f0"
index_url = "https://api.wandb.ai/artifactsV2/gcp-us/buianifolli/QXJ0aWZhY3Q6MjYxMzQ0Mjg=/de6539a7f0909d1dafa89571c7df43d1"

# download model
gan_path = "/content/gan_model/"
try:
    os.remove(gan_path + "keras_metadata.pb")
    os.remove(gan_path + "saved_model.pb")
    os.remove(gan_path + "variables/variables.data-00000-of-00001")
    os.remove(gan_path + "variables/variables.index")
except FileNotFoundError:
    pass
os.makedirs(gan_path, exist_ok=True)
os.makedirs(gan_path + "/variables", exist_ok=True)

!pip install wget -q
import wget
wget.download(keras_metadata_url, gan_path + "keras_metadata.pb")
wget.download(saved_model_url, gan_path + "saved_model.pb")
wget.download(variables_url, gan_path + "variables/variables.data-00000-of-00001")
wget.download(index_url, gan_path + "variables/variables.index")

gan = tf.keras.models.load_model(gan_path)

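As with the earlier pip lines, `!pip install wget -q` would not parse outside a notebook. A hedged sketch of the same four downloads using only the standard library, assuming the artifact URLs above stay valid:

```python
import os
import urllib.request

os.makedirs(gan_path + "variables", exist_ok=True)
for url, dest in [
    (keras_metadata_url, gan_path + "keras_metadata.pb"),
    (saved_model_url, gan_path + "saved_model.pb"),
    (variables_url, gan_path + "variables/variables.data-00000-of-00001"),
    (index_url, gan_path + "variables/variables.index"),
]:
    urllib.request.urlretrieve(url, dest)  # fetch each SavedModel artifact
```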
IMAGE_RANGE = '11'
IMAGE_SIZE = gan.discriminator.input_shape[1]
if IMAGE_SIZE == 64:
    IMAGE_SHAPE = (64, 64, 3)
elif IMAGE_SIZE == 218:
    IMAGE_SHAPE = (218, 178, 3)

try:
    LATENT_DIM = gan.generator.input_shape[0][1]
    N_ATTRIBUTES = gan.generator.input_shape[1][1]
except TypeError:
    LATENT_DIM = gan.generator.input_shape[1]
    N_ATTRIBUTES = 0

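The try/except above is shape introspection: a conditional generator takes two inputs, so generator.input_shape is a list of tuples, while an unconditional generator exposes a single tuple whose first entry is the batch dimension None; indexing that with [0][1] raises TypeError. An illustration with assumed shapes:

```python
conditional = [(None, 128), (None, 10)]  # [latent input, attribute input] (assumed)
unconditional = (None, 128)

conditional[0][1]  # 128, the latent dimension
conditional[1][1]  # 10, the number of attributes
# unconditional[0][1] raises TypeError ('NoneType' object is not subscriptable),
# which is exactly the case the except branch catches.
```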
"""## 💾 Dataset"""

#@title Select Attributes {form-width: "50%", display-mode: "both" }

#NUMBER_OF_ATTRIBUTES = "10" #@param [0, 2, 10, 12, 40]
#N_ATTRIBUTES = int(NUMBER_OF_ATTRIBUTES)

IMAGE_RANGE = '11'

BATCH_SIZE = 64 #@param {type: "number"}
if N_ATTRIBUTES == 2:
    LABELS = ["Male", "Smiling"]

elif N_ATTRIBUTES == 10:
    LABELS = [
        "Mouth_Slightly_Open", "Wearing_Lipstick", "High_Cheekbones", "Male", "Smiling",
        "Heavy_Makeup", "Wavy_Hair", "Oval_Face", "Pointy_Nose", "Arched_Eyebrows"]

elif N_ATTRIBUTES == 12:
    LABELS = ['Wearing_Lipstick', 'Mouth_Slightly_Open', 'Male', 'Smiling',
              'High_Cheekbones', 'Heavy_Makeup', 'Attractive', 'Young',
              'No_Beard', 'Black_Hair', 'Arched_Eyebrows', 'Big_Nose']

elif N_ATTRIBUTES == 40:
    LABELS = [
        '5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive',
        'Bags_Under_Eyes', 'Bald', 'Bangs', 'Big_Lips', 'Big_Nose',
        'Black_Hair', 'Blond_Hair', 'Blurry', 'Brown_Hair', 'Bushy_Eyebrows',
        'Chubby', 'Double_Chin', 'Eyeglasses', 'Goatee', 'Gray_Hair',
        'Heavy_Makeup', 'High_Cheekbones', 'Male', 'Mouth_Slightly_Open',
        'Mustache', 'Narrow_Eyes', 'No_Beard', 'Oval_Face', 'Pale_Skin',
        'Pointy_Nose', 'Receding_Hairline', 'Rosy_Cheeks', 'Sideburns',
        'Smiling', 'Straight_Hair', 'Wavy_Hair', 'Wearing_Earrings',
        'Wearing_Hat', 'Wearing_Lipstick', 'Wearing_Necklace',
        'Wearing_Necktie', 'Young']

else:
    LABELS = ["Male", "Smiling"]  # just for dataset creation

# Load the attribute table and keep the selected label columns in memory
df = pd.read_csv(r"/content/celeba_gan/list_attr_celeba01.csv")
attr_list = df[LABELS].values.tolist()

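A quick, illustrative sanity check on the attribute table (not in the commit): the selected columns should be 0/1 flags, one row per CelebA image.

```python
print(df[LABELS].head())  # first rows of the selected attribute columns
print(len(attr_list))     # one entry per image in the CSV
```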
def gen_img(attributes):

    attr = np.zeros((1, N_ATTRIBUTES))
    for a in attributes:
        attr[0, int(a)] = 1
    num_img = 1
    random_latent_vectors = tf.random.normal(shape=(num_img, LATENT_DIM))

    generated_images = gan.generator((random_latent_vectors, attr))
    generated_images = (generated_images * 0.5 + 0.5).numpy()
    print(generated_images[0].shape)
    return generated_images[0]

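gen_img can also be exercised directly; the indices follow the LABELS order, so for the 10-attribute model indices 3 and 4 select Male and Smiling (illustrative):

```python
import matplotlib.pyplot as plt

img = gen_img([3, 4])  # Male + Smiling; float array in [0, 1], shape (64, 64, 3)
plt.imshow(img)
plt.axis('off')
```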
iface = gr.Interface(
    gen_img,
    gr.inputs.CheckboxGroup([LABELS[i] for i in range(N_ATTRIBUTES)], type='index'),
    "image",
    layout='unaligned'
)
iface.launch(debug=True)

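Two hedged notes on this launch. First, launch(debug=True) blocks in a script, so the sentence_builder demo below only runs after this first server exits. Second, the gr.inputs namespace and the layout= argument come from old Gradio releases; on current Gradio a rough equivalent (an assumption, not what the commit ships) would be:

```python
iface = gr.Interface(
    fn=gen_img,
    inputs=gr.CheckboxGroup(choices=list(LABELS), type="index"),
    outputs=gr.Image(),
)
```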
def sentence_builder(quantity, animal, place, activity_list, morning):
    return f"""The {quantity} {animal}s went to the {place} where they {" and ".join(activity_list)} until the {"morning" if morning else "night"}"""

# def generate_image(attributes): ...  # unfinished stub (no colon or body); kept commented out so the file parses

iface = gr.Interface(
    sentence_builder,
    [
        gr.inputs.Slider(2, 20),
        gr.inputs.Dropdown(["cat", "dog", "bird"]),
        gr.inputs.Radio(["park", "zoo", "road"]),
        gr.inputs.CheckboxGroup(["ran", "swam", "ate", "slept"]),
        gr.inputs.Checkbox(label="Is it the morning?"),
    ],
    "text",
    examples=[
        [2, "cat", "park", ["ran", "swam"], True],
        [4, "dog", "zoo", ["ate", "swam"], False],
        [10, "bird", "road", ["ran"], False],
        [8, "cat", "zoo", ["ate"], True],
    ],
)
iface.launch()