padmanabhbosamia committed on
Commit
f6a2113
1 Parent(s): af5f9f5

Update app.py

Files changed (1)
  1. app.py +178 -229
app.py CHANGED
@@ -1,16 +1,9 @@
- #!pip install -q --upgrade transformers diffusers ftfy
- #!pip install -q --upgrade transformers==4.25.1 diffusers ftfy
- #!pip install accelerate -q
-
  from base64 import b64encode
-
- import numpy
  import torch
  from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
- from huggingface_hub import notebook_login

- # For video display:
- from IPython.display import HTML
  from matplotlib import pyplot as plt
  from pathlib import Path
  from PIL import Image
@@ -18,57 +11,51 @@ from torch import autocast
  from torchvision import transforms as tfms
  from tqdm.auto import tqdm
  from transformers import CLIPTextModel, CLIPTokenizer, logging
- import gradio as gr
- torch.manual_seed(1)
- #if not (Path.home()/'.huggingface'/'token').exists(): notebook_login()

- # Supress some unnecessary warnings when loading the CLIPTextModel
  logging.set_verbosity_error()

- # Set device
  torch_device = "cuda" if torch.cuda.is_available() else "cpu"

- #import os
- #MY_TOKEN=os.environ.get('Learning')
-

- # Load the autoencoder model which will be used to decode the latents into image space.
- vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae") #,use_auth_token=MY_TOKEN)
-
- # Load the tokenizer and text encoder to tokenize and encode the text.
  tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
  text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

- # The UNet model for generating the latents.
- unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")

- # The noise scheduler
  scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)

- # To the GPU we go!
  vae = vae.to(torch_device)
  text_encoder = text_encoder.to(torch_device)
  unet = unet.to(torch_device)

- """Functions"""

- def pil_to_latent(input_im):
-     # Single image -> single latent in a batch (so size 1, 4, 64, 64)
-     with torch.no_grad():
-         latent = vae.encode(tfms.ToTensor()(input_im).unsqueeze(0).to(torch_device)*2-1) # Note scaling
-     return 0.18215 * latent.latent_dist.sample()

- def latents_to_pil(latents):
-     # bath of latents -> list of images
-     latents = (1 / 0.18215) * latents
-     with torch.no_grad():
-         image = vae.decode(latents).sample
-     image = (image / 2 + 0.5).clamp(0, 1)
-     image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
-     images = (image * 255).round().astype("uint8")
-     pil_images = [Image.fromarray(image) for image in images]
-     return pil_images

  def get_output_embeds(input_embeddings):
      # CLIP's text model uses causal mask, so we prepare it here:
@@ -95,177 +82,92 @@ def get_output_embeds(input_embeddings):
      # And now they're ready!
      return output

- #Generating an image with these modified embeddings
-
- def generate_with_embs(text_embeddings, text_input):
-     height = 512 # default height of Stable Diffusion
-     width = 512 # default width of Stable Diffusion
-     num_inference_steps = 7 # Number of denoising steps
-     guidance_scale = 7.5 # Scale for classifier-free guidance
-     generator = torch.manual_seed(64) # Seed generator to create the inital latent noise
-     batch_size = 1
-
-     max_length = text_input.input_ids.shape[-1]
-     uncond_input = tokenizer(
-         [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
-     )
-     with torch.no_grad():
-         uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
-     text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
-
-     # Prep Scheduler
-     scheduler.set_timesteps(num_inference_steps)
-
-     # Prep latents
-     latents = torch.randn(
-         (batch_size, unet.config.in_channels, height // 8, width // 8),
-         generator=generator,
-     )
-     latents = latents.to(torch_device)
-     latents = latents * scheduler.init_noise_sigma
-
-     # Loop
-     for i, t in tqdm(enumerate(scheduler.timesteps)):
-         # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
-         latent_model_input = torch.cat([latents] * 2)
-         sigma = scheduler.sigmas[i]
-         latent_model_input = scheduler.scale_model_input(latent_model_input, t)
-
-         # predict the noise residual
-         with torch.no_grad():
-             noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
-
-         # perform guidance
-         noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-         noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-         # compute the previous noisy sample x_t -> x_t-1
-         latents = scheduler.step(noise_pred, t, latents).prev_sample
-
-     return latents_to_pil(latents)[0]
-
- def ref_loss(images,ref_image):
-     # Reference image
-     error = torch.abs(images - ref_image).mean()
-     return error
-
- def inference(prompt, style_index):
-
-     styles = ['<snoopy>', '<boot-mjstyle>','<birb-style>','<pop_art>','<ronaldo>','<Thumps_up>']
-     embed = ['snoopy.bin','boot-mjstyle.bin', 'bird_style.bin', 'pop_art.bin','ronaldo.bin','Thumps_up.bin']

-     # Tokenize
-     text_input = tokenizer(prompt+" .", padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
-     # Access the embedding layer
-     token_emb_layer = text_encoder.text_model.embeddings.token_embedding
-     token_embeddings = token_emb_layer(text_input.input_ids.to(torch_device))
-     pos_emb_layer = text_encoder.text_model.embeddings.position_embedding
-
-     position_ids = text_encoder.text_model.embeddings.position_ids[:, :77]
-     position_embeddings = pos_emb_layer(position_ids)
-
-     ## Without any Textual Inversion
-     input_ids = text_input.input_ids.to(torch_device)
-
-     # Get token embeddings
-     token_embeddings = token_emb_layer(input_ids)
-
-     # Combine with pos embs
-     input_embeddings = token_embeddings + position_embeddings
-
-     # Feed through to get final output embs
-     modified_output_embeddings = get_output_embeds(input_embeddings)
-
-     # And generate an image with this:
-     image1 = generate_with_embs(modified_output_embeddings,text_input)
-
-     replace_id=269 #replaced dot with Textual Inversion
-
-     ## midjourney-style
-     style = styles[style_index]
-     emb = embed[style_index]
-
-     x_embed = torch.load(emb)
-
-     # The new embedding - our special birb word
-     replacement_token_embedding = x_embed[style].to(torch_device)
-
-     # Insert this into the token embeddings
-     token_embeddings[0, torch.where(input_ids[0]==replace_id)] = replacement_token_embedding.to(torch_device)
-
-     # Combine with pos embs
-     input_embeddings = token_embeddings + position_embeddings
-
-     # Feed through to get final output embs
-     modified_output_embeddings = get_output_embeds(input_embeddings)

-     # And generate an image with this:
-     image2 = generate_with_embs(modified_output_embeddings,text_input)

-     prompt1 = 'rainbow'

-     # Tokenize
-     text_input1 = tokenizer(prompt1, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")

-     # Access the embedding layer
-     token_emb_layer = text_encoder.text_model.embeddings.token_embedding

-     pos_emb_layer = text_encoder.text_model.embeddings.position_embedding
-     position_ids = text_encoder.text_model.embeddings.position_ids[:, :77]
-     position_embeddings1 = pos_emb_layer(position_ids)

-     input_ids1 = text_input1.input_ids.to(torch_device)

-     # Get token embeddings
-     token_embeddings1 = token_emb_layer(input_ids1)

-     # Combine with pos embs
-     input_embeddings1 = token_embeddings1 + position_embeddings1

-     # Feed through to get final output embs
-     modified_output_embeddings1 = get_output_embeds(input_embeddings1)

-     # And generate an image with this:
-     ref_image = generate_with_embs(modified_output_embeddings1, text_input1)

-     ref_latent = pil_to_latent(ref_image)

-     height = 512 # default height of Stable Diffusion
-     width = 512 # default width of Stable Diffusion
-     num_inference_steps = 7 # # Number of denoising steps
-     guidance_scale = 8 # # Scale for classifier-free guidance
-     generator = torch.manual_seed(64) # Seed generator to create the inital latent noise
      batch_size = 1
-     blue_loss_scale = 200 #
-
-     # Prep text
-     text_input = tokenizer([prompt], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
-     with torch.no_grad():
-         text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]

-     # And the uncond. input as before:
-     max_length = text_input.input_ids.shape[-1]
      uncond_input = tokenizer(
-         [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
      )
      with torch.no_grad():
          uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
      text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

      # Prep Scheduler
-     scheduler.set_timesteps(num_inference_steps)

      # Prep latents
      latents = torch.randn(
-         (batch_size, unet.config.in_channels, height // 8, width // 8),
-         generator=generator,
      )
      latents = latents.to(torch_device)
      latents = latents * scheduler.init_noise_sigma

      # Loop
-     for i, t in tqdm(enumerate(scheduler.timesteps)):
          # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
          latent_model_input = torch.cat([latents] * 2)
          sigma = scheduler.sigmas[i]
@@ -275,63 +177,110 @@ def inference(prompt, style_index):
          with torch.no_grad():
              noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]

-         # perform CFG
          noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
          noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

-         #### ADDITIONAL GUIDANCE ###
-         if i%5 == 0:
-             # Requires grad on the latents
-             latents = latents.detach().requires_grad_()
-
-             # Get the predicted x0:
-             # latents_x0 = latents - sigma * noise_pred
-             latents_x0 = scheduler.step(noise_pred, t, latents).pred_original_sample
-
-             # Decode to image space
-             denoised_images = vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5 # range (0, 1)
-
-             #ref image
-             with torch.no_grad():
-                 ref_images = vae.decode((1 / 0.18215) * ref_latent).sample / 2 + 0.5 # range (0, 1)
-
-             # Calculate loss
-             loss = ref_loss(denoised_images,ref_images) * blue_loss_scale
-
-             # Occasionally print it out
-             # if i%10==0:
-             #     print(i, 'loss:', loss.item())
-
-             # Get gradient
-             cond_grad = torch.autograd.grad(loss, latents)[0]

-             # Modify the latents based on this gradient
-             latents = latents.detach() - cond_grad * sigma**2
-             scheduler._step_index = scheduler._step_index - 1

-         # Now step with scheduler
-         latents = scheduler.step(noise_pred, t, latents).prev_sample
-         #latents = scheduler.step(noise_pred, t, latents).pred_original_sample

-     image3 = latents_to_pil(latents)[0]
-
-     return (image1, 'Original Image'), (image2, 'Styled Image'), (image3, 'After Textual Inversion')

- # Gradio App with num_inference_steps=10

- title="Textual Inversion in Stable Diffusion"
- description="<p style='text-align: center;'>Textual Inversion in Stable Diffusion.</b></p>"
- gallery = gr.Gallery(label="Generated images", show_label=True, elem_id="gallery", columns=3).style(grid=[2], height="auto")

- gr.Interface(fn=inference, inputs=["text",
-                                    gr.Radio([('<snoopy>',0), ('<boot-mjstyle>',1),('<birb-style>',2),
-                                              ('<pop_art>',3),(' <ronaldo>',4),('<Thumps_up>',5)], value = 0, label = 'Style')],
-              outputs = gallery, title = title,
-              examples = [['a girl playing in snow',0],
-                          #['an oil painting of a goddess',6],
-                          #['a rabbit on the moon', 5 ]
-                         ],
-              ).launch(debug=True)
 
  from base64 import b64encode
+ import gradio as gr
+ import numpy as np
  import torch
  from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel

  from matplotlib import pyplot as plt
  from pathlib import Path
  from PIL import Image
  from torchvision import transforms as tfms
  from tqdm.auto import tqdm
  from transformers import CLIPTextModel, CLIPTokenizer, logging
+ import os
+ import cv2
+ import torchvision.transforms as T

+ torch.manual_seed(1)
  logging.set_verbosity_error()

  torch_device = "cuda" if torch.cuda.is_available() else "cpu"

+ # Load the autoencoder
+ vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder='vae')

+ # Load tokenizer and text encoder to tokenize and encode the text
  tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
  text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

+ # Unet model for generating latents
+ unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder='unet')

+ # Noise scheduler
  scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)

+ # Move everything to GPU
  vae = vae.to(torch_device)
  text_encoder = text_encoder.to(torch_device)
  unet = unet.to(torch_device)

+ style_files = ['Thumps_up.bin', 'birb_style.bin',
+                'snoopy.bin', 'pop_art.bin',
+                'boot-mjstyle.bin']

+ images_without_loss = []
+ images_with_loss = []

+ seed_values = [8,16,50,80,128]
+ height = 512 # default height of Stable Diffusion
+ width = 512 # default width of Stable Diffusion
+ num_inference_steps = 5 # Number of denoising steps
+ guidance_scale = 7.5 # Scale for classifier-free guidance
+ num_styles = len(style_files)

+ # Prep Scheduler
+ def set_timesteps(scheduler, num_inference_steps):
+     scheduler.set_timesteps(num_inference_steps)
+     scheduler.timesteps = scheduler.timesteps.to(torch.float32) # minor fix to ensure MPS compatibility, fixed in diffusers PR 3925

  def get_output_embeds(input_embeddings):
      # CLIP's text model uses causal mask, so we prepare it here:

      # And now they're ready!
      return output

+ def get_style_embeddings(style_file):
+     style_embed = torch.load(style_file)
+     style_name = list(style_embed.keys())[0]
+     return style_embed[style_name]
 
+ import torch

+ def vibrance_loss(image):
+     # Calculate the standard deviation of color channels
+     std_dev = torch.std(image, dim=(2, 3)) # Compute standard deviation over height and width
+     # Calculate the mean standard deviation across the batch
+     mean_std_dev = torch.mean(std_dev)
+     # You can adjust a scale factor to control the strength of vibrance regularization
+     scale_factor = 100.0
+     # Calculate the vibrance loss
+     loss = -scale_factor * mean_std_dev
+     return loss
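vibrance_loss returns the negative of the mean per-channel spatial standard deviation, so driving it down pushes the decoded image toward stronger colour variation. A minimal sketch of what it rewards, using dummy tensors (not part of the commit; the values in the comments are only rough expectations):

# Flat grey batch: per-channel spatial std is ~0, so the loss is ~0.
flat = torch.full((1, 3, 64, 64), 0.5)
# Uniform random batch: per-channel std is ~0.29, so the loss is roughly -29.
vivid = torch.rand(1, 3, 64, 64)
print(vibrance_loss(flat).item(), vibrance_loss(vivid).item())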
+ from torchvision.transforms import ToTensor

+ def pil_to_latent(input_im):
+     # Single image -> single latent in a batch (so size 1, 4, 64, 64)
+     with torch.no_grad():
+         latent = vae.encode(tfms.ToTensor()(input_im).unsqueeze(0).to(torch_device)*2-1) # Note scaling
+     return 0.18215 * latent.latent_dist.sample()

+ def latents_to_pil(latents):
+     # batch of latents -> list of images
+     latents = (1 / 0.18215) * latents
+     with torch.no_grad():
+         image = vae.decode(latents).sample
+     image = (image / 2 + 0.5).clamp(0, 1)
+     image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
+     images = (image * 255).round().astype("uint8")
+     pil_images = [Image.fromarray(image) for image in images]
+     return pil_images
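The 0.18215 factor is the latent scaling constant of the Stable Diffusion v1 VAE: pil_to_latent applies it after encoding and latents_to_pil divides it back out before decoding. A small round-trip sketch (not part of the commit), assuming a 512x512 RGB input so the encoder downsamples to a 1x4x64x64 latent:

# Hypothetical round trip through the VAE helpers defined above.
test_im = Image.new("RGB", (512, 512), color=(120, 80, 200))
lat = pil_to_latent(test_im)      # torch.Size([1, 4, 64, 64])
rec = latents_to_pil(lat)[0]      # a PIL image close to the input
print(lat.shape)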
+ def additional_guidance(latents, scheduler, noise_pred, t, sigma, custom_loss_fn):
+     #### ADDITIONAL GUIDANCE ###
+     # Requires grad on the latents
+     latents = latents.detach().requires_grad_()

+     # Get the predicted x0:
+     latents_x0 = latents - sigma * noise_pred
+     #print(f"latents: {latents.shape}, noise_pred:{noise_pred.shape}")
+     #latents_x0 = scheduler.step(noise_pred, t, latents).pred_original_sample

+     # Decode to image space
+     denoised_images = vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5 # range (0, 1)

+     # Calculate loss
+     loss = custom_loss_fn(denoised_images)

+     # Get gradient
+     cond_grad = torch.autograd.grad(loss, latents, allow_unused=False)[0]

+     # Modify the latents based on this gradient
+     latents = latents.detach() - cond_grad * sigma**2
+     return latents, loss
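additional_guidance is the usual loss-guided sampling trick: estimate the denoised x0 from the current latents, decode it, score it with custom_loss_fn, then move the latents a step down the gradient scaled by sigma**2. A toy sketch of that final update on dummy tensors (not part of the commit; the quadratic loss and sigma value are stand-ins so it runs without the VAE):

# Stand-in gradient step illustrating latents <- latents - sigma**2 * grad(loss).
toy_latents = torch.randn(1, 4, 64, 64, requires_grad=True)
toy_loss = toy_latents.square().mean()      # placeholder for custom_loss_fn(decoded image)
grad = torch.autograd.grad(toy_loss, toy_latents)[0]
sigma_toy = 1.5                             # illustrative sigma value
guided = toy_latents.detach() - grad * sigma_toy**2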
+ def generate_with_embs(text_embeddings, max_length, random_seed, loss_fn = None):
+     generator = torch.manual_seed(random_seed) # Seed generator to create the initial latent noise
      batch_size = 1

      uncond_input = tokenizer(
+         [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
      )
      with torch.no_grad():
          uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
      text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

      # Prep Scheduler
+     set_timesteps(scheduler, num_inference_steps)

      # Prep latents
      latents = torch.randn(
+         (batch_size, unet.in_channels, height // 8, width // 8),
+         generator=generator,
      )
      latents = latents.to(torch_device)
      latents = latents * scheduler.init_noise_sigma

      # Loop
+     for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):
          # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
          latent_model_input = torch.cat([latents] * 2)
          sigma = scheduler.sigmas[i]

          with torch.no_grad():
              noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]

+         # perform guidance
          noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
          noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+         if loss_fn is not None:
+             if i%2 == 0:
+                 latents, custom_loss = additional_guidance(latents, scheduler, noise_pred, t, sigma, loss_fn)

+         # compute the previous noisy sample x_t -> x_t-1
+         latents = scheduler.step(noise_pred, t, latents).prev_sample

+     return latents_to_pil(latents)[0]
+ def generate_images(prompt, style_num=None, random_seed=41, custom_loss_fn = None):
+     eos_pos = len(prompt.split())+1

+     style_token_embedding = None
+     if style_num is not None:
+         style_token_embedding = get_style_embeddings(style_files[style_num])

+     # tokenize
+     text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
+     max_length = text_input.input_ids.shape[-1]
+     input_ids = text_input.input_ids.to(torch_device)

+     # get token embeddings
+     token_emb_layer = text_encoder.text_model.embeddings.token_embedding
+     token_embeddings = token_emb_layer(input_ids)

+     # Append style token towards the end of the sentence embeddings
+     if style_token_embedding is not None:
+         token_embeddings[-1, eos_pos, :] = style_token_embedding

+     # combine with pos embs
+     pos_emb_layer = text_encoder.text_model.embeddings.position_embedding
+     position_ids = text_encoder.text_model.embeddings.position_ids[:, :77]
+     position_embeddings = pos_emb_layer(position_ids)
+     input_embeddings = token_embeddings + position_embeddings

+     # Feed through to get final output embs
+     modified_output_embeddings = get_output_embeds(input_embeddings)

+     # And generate an image with this:
+     generated_image = generate_with_embs(modified_output_embeddings, max_length, random_seed, custom_loss_fn)
+     return generated_image
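For reference, calling the pipeline directly (outside the Gradio UI) could look like the sketch below; this is not part of the commit, style_num indexes into style_files, and the .bin embeddings are assumed to be present in the working directory:

# Hypothetical direct call: snoopy style, with the vibrance loss applied every other step.
img = generate_images("dog sitting on a bench", style_num=2, random_seed=50,
                      custom_loss_fn=vibrance_loss)
img.save("snoopy_dog.png")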
+
+ import matplotlib.pyplot as plt
+
+ def display_images_in_rows(images_with_titles, titles):
+     num_images = len(images_with_titles)
+     rows = 5 # Display 5 rows always
+     columns = 1 if num_images == 5 else 2 # Use 1 column if there are 5 images, otherwise 2 columns
+     fig, axes = plt.subplots(rows, columns + 1, figsize=(15, 5 * rows)) # Add an extra column for titles
+
+     for r in range(rows):
+         # Add the title on the extreme left in the middle of each picture
+         axes[r, 0].text(0.5, 0.5, titles[r], ha='center', va='center')
+         axes[r, 0].axis('off')
+
+         # Add "Without Loss" label above the first column and "With Loss" label above the second column (if applicable)
+         if columns == 2:
+             axes[r, 1].set_title("Without Loss", pad=10)
+             axes[r, 2].set_title("With Loss", pad=10)
+
+         for c in range(1, columns + 1):
+             index = r * columns + c - 1
+             if index < num_images:
+                 image, _ = images_with_titles[index]
+                 axes[r, c].imshow(image)
+                 axes[r, c].axis('off')
+
+     return fig
+     # plt.show()
+
+
+ def image_generator(prompt = "dog", loss_function=None):
+     images_without_loss = []
+     images_with_loss = []
+     if loss_function == "Yes":
+         loss_function = vibrance_loss
+     else:
+         loss_function = None
+
+     for i in range(num_styles):
+         generated_img = generate_images(prompt, style_num = i, random_seed = seed_values[i], custom_loss_fn = None)
+         images_without_loss.append(generated_img)
+         if loss_function:
+             generated_img = generate_images(prompt, style_num = i, random_seed = seed_values[i], custom_loss_fn = loss_function)
+             images_with_loss.append(generated_img)
+
+     generated_sd_images = []
+     titles = ["Bird_style", "Boot-mjstyle", "Snoopy Style", "Pop Art Style", "Thumpsup Style"]
+
+     for i in range(len(titles)):
+         generated_sd_images.append((images_without_loss[i], titles[i]))
+         if images_with_loss != []:
+             generated_sd_images.append((images_with_loss[i], titles[i]))
+
+     return display_images_in_rows(generated_sd_images, titles)
+
+ description = "Generate an image from a prompt and optionally apply a vibrance loss. Note that the app is hosted on a CPU and takes at least 15 minutes to generate images even without the loss. Feel free to clone the Space and run it on a GPU, increasing the inference steps to more than 10 for better results."
+
+ demo = gr.Interface(image_generator,
+                     inputs=[gr.Textbox(label="Enter prompt for generation", type="text", value="dog sitting on a bench"),
+                             gr.Radio(["Yes", "No"], value="No", label="Apply vibrance loss")],
+                     outputs=gr.Plot(label="Generated Images"), title="Stable Diffusion using Textual Inversion", description=description)
+ demo.launch()
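As the description suggests, a cloned copy of the Space running on a GPU would mainly raise the global num_inference_steps defined near the top of the file; a hedged sketch of that local tweak (not part of the commit; share=True is a standard Gradio launch option that requests a temporary public URL):

num_inference_steps = 20   # more denoising steps than the CPU-friendly default of 5
demo.launch(share=True)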