GunaKoppula committed on
Commit 475842f · verified · 1 Parent(s): 0cb85e1

Upload 10 files

README.md CHANGED
@@ -1,13 +1,89 @@
1
  ---
2
- title: Session20
3
- emoji: 🏢
4
- colorFrom: red
5
- colorTo: purple
6
  sdk: gradio
7
- sdk_version: 3.47.1
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
  ---
2
+ title: "ERA SESSION20 - Stable Diffusion: Generative Art with Guidance"
3
+ emoji: 🌍
4
+ colorFrom: indigo
5
+ colorTo: pink
6
  sdk: gradio
7
+ sdk_version: 3.48.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
  ---
12
 
13
+ **Styles Used:**
14
+ 1. [Oil style](https://huggingface.co/sd-concepts-library/oil-style)
15
+ 2. [Xyz](https://huggingface.co/sd-concepts-library/xyz)
16
+ 3. [Allante](https://huggingface.co/sd-concepts-library/style-of-marc-allante)
17
+ 4. [Moebius](https://huggingface.co/sd-concepts-library/moebius)
18
+ 5. [Polygons](https://huggingface.co/sd-concepts-library/low-poly-hd-logos-icons)
19
+
20
+ ### Results of experiments with different styles:
21
+ **Prompt:** `"a cat and dog in the style of cs"` \
22
+ _"cs" in the prompt is a placeholder token ("custom style"); its embedding is replaced in turn by each of the concept embeddings, producing the results shown below (a minimal sketch of the swap follows these results)._
23
+ ![image](https://github.com/RaviNaik/ERA-SESSION20/assets/23289802/1effe375-6ef4-4adc-be7b-d6311fdaa50d)
24
+
25
+ ---
26
+ **Prompt:** `"dolphin swimming on Mars in the style of cs"`
27
+ ![image](https://github.com/RaviNaik/ERA-SESSION20/assets/23289802/2cd32248-4233-42c0-97c0-00e1ae8fdc85)
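+
+ _A minimal sketch of the token-embedding swap behind these results, mirroring `src/stable_diffusion.py` (the concept file and variable names here are illustrative):_
+ ```python
+ import torch
+ from transformers import CLIPTokenizer, CLIPTextModel
+
+ tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
+ text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
+
+ prompt = "a cat and dog in the style of cs"
+ input_ids = tokenizer(
+     prompt, padding="max_length", max_length=tokenizer.model_max_length,
+     truncation=True, return_tensors="pt",
+ ).input_ids
+
+ # A learned concept is a dict with a single key mapping to a 768-dim embedding
+ concept = torch.load("concept_libs/cube.bin")
+ style_embed = list(concept.values())[0]
+
+ # Swap the embedding of the placeholder token "cs" for the concept embedding
+ token_embeddings = text_encoder.text_model.embeddings.token_embedding(input_ids)
+ cs_token = tokenizer.encode("cs", add_special_tokens=False)[0]
+ token_embeddings[0, torch.where(input_ids[0] == cs_token)] = style_embed
+ ```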
28
+
29
+ ### Results of experiments with guidance loss functions:
30
+ **Prompt:** `"a mouse in the style of cs"`
31
+ **Loss Function:**
32
+ ```python
33
+ def loss_fn(images):
34
+     return images.mean()
35
+ ```
36
+ ![image](https://github.com/RaviNaik/ERA-SESSION20/assets/23289802/c9d46e14-44bb-4ea7-88a4-26ef46344fce)
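+
+ _Each of the guidance losses in this section is evaluated on the decoded image prediction during sampling, and its gradient is used to nudge the latents. A simplified sketch of the guidance step in `src/stable_diffusion.py` (`sigma` is the scheduler's noise level at the current step):_
+ ```python
+ import torch
+
+ def guide_latents(latents, noise_pred, t, sigma, scheduler, vae, loss_fn, loss_scale=8):
+     """One guidance step: move the latents against the gradient of the loss."""
+     latents = latents.detach().requires_grad_()
+
+     # Predicted clean latents (x0) at the current timestep
+     latents_x0 = scheduler.step(noise_pred, t, latents).pred_original_sample
+
+     # Decode to image space, roughly in the range (0, 1)
+     images = vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5
+
+     loss = loss_fn(images) * loss_scale
+     grad = torch.autograd.grad(loss, latents)[0]
+
+     # Step size is scaled by the current noise level
+     return latents.detach() - grad * sigma**2
+ ```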
37
+ ---
38
+ ```python
39
+ def loss_fn(images):
40
+     return -images.median()/3
41
+ ```
42
+ ![image](https://github.com/RaviNaik/ERA-SESSION20/assets/23289802/2649e4f6-3de5-4e54-8f22-3d65874b7b07)
43
+ ---
44
+ ```python
45
+ def loss_fn(images):
46
+     error = (images - images.min()) / 255*(images.max() - images.min())
47
+     return error.mean()
48
+ ```
49
+ ![image](https://github.com/RaviNaik/ERA-SESSION20/assets/23289802/6399c780-e9b7-42f8-8d90-44c8b40d5265)
50
+ ---
51
+ **Prompt:** `"angry german shephard in the style of cs"`
52
+ ```python
53
+ def loss_fn(images):
54
+     error1 = torch.abs(images[:, 0] - 0.9)
55
+     error2 = torch.abs(images[:, 1] - 0.9)
56
+     error3 = torch.abs(images[:, 2] - 0.9)
57
+     return (
58
+         torch.sin(error1.mean()) + torch.sin(error2.mean()) + torch.sin(error3.mean())
59
+     ) / 3
60
+ ```
61
+ ![image](https://github.com/RaviNaik/ERA-SESSION20/assets/23289802/fa7d30ed-4efd-4504-b89c-94e093f51f9c)
62
+
63
+ ---
64
+ **Prompt:** `"A campfire (oil on canvas)"`
65
+ ```python
66
+ def loss_fn(images):
67
+     error1 = torch.abs(images[:, 0] - 0.9)
68
+     error2 = torch.abs(images[:, 1] - 0.9)
69
+     error3 = torch.abs(images[:, 2] - 0.9)
70
+     return (
71
+         torch.sin((error1 * error2 * error3)).mean()
72
+         + torch.cos((error1 * error2 * error3)).mean()
73
+     )
74
+ ```
75
+ ![image](https://github.com/RaviNaik/ERA-SESSION20/assets/23289802/88382dae-6701-4103-a664-ed17727b690f)
76
+
77
+ ---
78
+ ```python
79
+ def loss_fn(images):
80
+     error1 = torch.abs(images[:, 0] - 0.9)
81
+     error2 = torch.abs(images[:, 1] - 0.9)
82
+     error3 = torch.abs(images[:, 2] - 0.9)
83
+     return (
84
+         torch.sin(error1.mean()) + torch.sin(error2.mean()) + torch.sin(error3.mean())
85
+     ) / 3
86
+ ```
87
+ ![image](https://github.com/RaviNaik/ERA-SESSION20/assets/23289802/0ab3edad-579d-4821-b992-6c18b61bd444)
88
+
89
+
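+ ### Usage sketch
+ _One illustrative way to run a single generation with the modules in this repo, following `app.py` (the style key, prompt, and numbers below are examples only):_
+ ```python
+ import torch
+ from src.utils import concept_styles, loss_fn
+ from src.stable_diffusion import StableDiffusion
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ style = "Cube"  # any key of concept_styles
+ concept_embed = torch.load(f"concept_libs/{concept_styles[style]}")
+
+ diffusion = StableDiffusion(device=device, num_inference_steps=30, manual_seed=42)
+ image = diffusion.generate_image(
+     prompt="a campfire",  # " in the style of cs" is appended internally
+     loss_fn=loss_fn,
+     loss_scale=8,  # 0 effectively disables the extra guidance
+     concept_embed=concept_embed,
+ )
+ image.save("campfire_cube.png")
+ ```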
app.py CHANGED
@@ -1,307 +1,91 @@
1
  import gradio as gr
2
- from base64 import b64encode
3
- import numpy
4
  import torch
5
- from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
6
- from PIL import Image
7
- from torch import autocast
8
- from torchvision import transforms as tfms
9
- from tqdm.auto import tqdm
10
- from transformers import CLIPTextModel, CLIPTokenizer, logging
11
- import torchvision.transforms as T
12
-
13
- torch.manual_seed(1)
14
- logging.set_verbosity_error()
15
- torch_device = "cpu"
16
-
17
- vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
18
-
19
- # Load the tokenizer and text encoder to tokenize and encode the text.
20
- tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
21
- text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
22
-
23
- # The UNet model for generating the latents.
24
- unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
25
-
26
- # The noise scheduler
27
- scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
28
-
29
- vae = vae.to(torch_device)
30
- text_encoder = text_encoder.to(torch_device)
31
- unet = unet.to(torch_device);
32
-
33
- token_emb_layer = text_encoder.text_model.embeddings.token_embedding
34
- pos_emb_layer = text_encoder.text_model.embeddings.position_embedding
35
- position_ids = text_encoder.text_model.embeddings.position_ids[:, :77]
36
- position_embeddings = pos_emb_layer(position_ids)
37
-
38
- def pil_to_latent(input_im):
39
- # Single image -> single latent in a batch (so size 1, 4, 64, 64)
40
- with torch.no_grad():
41
- latent = vae.encode(tfms.ToTensor()(input_im).unsqueeze(0).to(torch_device)*2-1) # Note scaling
42
- return 0.18215 * latent.latent_dist.sample()
43
-
44
- def latents_to_pil(latents):
45
- # bath of latents -> list of images
46
- latents = (1 / 0.18215) * latents
47
- with torch.no_grad():
48
- image = vae.decode(latents).sample
49
- image = (image / 2 + 0.5).clamp(0, 1)
50
- image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
51
- images = (image * 255).round().astype("uint8")
52
- pil_images = [Image.fromarray(image) for image in images]
53
- return pil_images
54
-
55
- def get_output_embeds(input_embeddings):
56
- # CLIP's text model uses causal mask, so we prepare it here:
57
- bsz, seq_len = input_embeddings.shape[:2]
58
- causal_attention_mask = text_encoder.text_model._build_causal_attention_mask(bsz, seq_len, dtype=input_embeddings.dtype)
59
-
60
- # Getting the output embeddings involves calling the model with passing output_hidden_states=True
61
- # so that it doesn't just return the pooled final predictions:
62
- encoder_outputs = text_encoder.text_model.encoder(
63
- inputs_embeds=input_embeddings,
64
- attention_mask=None, # We aren't using an attention mask so that can be None
65
- causal_attention_mask=causal_attention_mask.to(torch_device),
66
- output_attentions=None,
67
- output_hidden_states=True, # We want the output embs not the final output
68
- return_dict=None,
69
- )
70
-
71
- # We're interested in the output hidden state only
72
- output = encoder_outputs[0]
73
-
74
- # There is a final layer norm we need to pass these through
75
- output = text_encoder.text_model.final_layer_norm(output)
76
-
77
- # And now they're ready!
78
- return output
79
-
80
- def generate_with_embs(text_embeddings, seed, max_length):
81
- height = 512 # default height of Stable Diffusion
82
- width = 512 # default width of Stable Diffusion
83
- num_inference_steps = 10 # Number of denoising steps
84
- guidance_scale = 7.5 # Scale for classifier-free guidance
85
- generator = torch.manual_seed(seed) # Seed generator to create the inital latent noise
86
- batch_size = 1
87
-
88
- # max_length = text_input.input_ids.shape[-1]
89
- uncond_input = tokenizer(
90
- [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
91
- )
92
- with torch.no_grad():
93
- uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
94
- text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
95
-
96
- # Prep Scheduler
97
- set_timesteps(scheduler, num_inference_steps)
98
-
99
- # Prep latents
100
- latents = torch.randn(
101
- (batch_size, unet.in_channels, height // 8, width // 8),
102
- generator=generator,
103
- )
104
- latents = latents.to(torch_device)
105
- latents = latents * scheduler.init_noise_sigma
106
-
107
- # Loop
108
- for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):
109
- # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
110
- latent_model_input = torch.cat([latents] * 2)
111
- sigma = scheduler.sigmas[i]
112
- latent_model_input = scheduler.scale_model_input(latent_model_input, t)
113
-
114
- # predict the noise residual
115
- with torch.no_grad():
116
- noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
117
-
118
- # perform guidance
119
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
120
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
121
-
122
- # compute the previous noisy sample x_t -> x_t-1
123
- latents = scheduler.step(noise_pred, t, latents).prev_sample
124
-
125
- return latents_to_pil(latents)[0]
126
-
127
- # Prep Scheduler
128
- def set_timesteps(scheduler, num_inference_steps):
129
- scheduler.set_timesteps(num_inference_steps)
130
- scheduler.timesteps = scheduler.timesteps.to(torch.float32) # minor fix to ensure MPS compatibility, fixed in diffusers PR 3925
131
-
132
- def embed_style(prompt, style_embed, style_seed):
133
- # Tokenize
134
- text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
135
- input_ids = text_input.input_ids.to(torch_device)
136
-
137
- # Get token embeddings
138
- token_embeddings = token_emb_layer(input_ids)
139
-
140
- replacement_token_embedding = style_embed.to(torch_device)
141
-
142
- # Insert this into the token embeddings
143
- token_embeddings[0, torch.where(input_ids[0]==6829)] = replacement_token_embedding.to(torch_device)
144
-
145
- # Combine with pos embs
146
- input_embeddings = token_embeddings + position_embeddings
147
-
148
- # Feed through to get final output embs
149
- modified_output_embeddings = get_output_embeds(input_embeddings)
150
-
151
- # And generate an image with this:
152
- max_length = text_input.input_ids.shape[-1]
153
- return generate_with_embs(modified_output_embeddings, style_seed, max_length)
154
-
155
- def loss_style(prompt, style_embed, style_seed):
156
- # Tokenize
157
- text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
158
- input_ids = text_input.input_ids.to(torch_device)
159
-
160
- # Get token embeddings
161
- token_embeddings = token_emb_layer(input_ids)
162
-
163
- # The new embedding - our special birb word
164
- replacement_token_embedding = style_embed.to(torch_device)
165
-
166
- # Insert this into the token embeddings
167
- token_embeddings[0, torch.where(input_ids[0]==6829)] = replacement_token_embedding.to(torch_device)
168
-
169
- # Combine with pos embs
170
- input_embeddings = token_embeddings + position_embeddings
171
-
172
- # Feed through to get final output embs
173
- modified_output_embeddings = get_output_embeds(input_embeddings)
174
-
175
- # And generate an image with this:
176
- max_length = text_input.input_ids.shape[-1]
177
- return generate_loss_based_image(modified_output_embeddings, style_seed,max_length)
178
-
179
-
180
- def color_loss(image):
181
- color_channel = image[:, 1]
182
- target_value = 0.7
183
- error = torch.abs(color_channel - target_value).mean()
184
- return error
185
-
186
- def generate_loss_based_image(text_embeddings, seed, max_length):
187
-
188
- height = 64
189
- width = 64
190
- num_inference_steps = 10
191
- guidance_scale = 8
192
- generator = torch.manual_seed(64)
193
- batch_size = 1
194
- loss_scale = 200
195
-
196
- uncond_input = tokenizer(
197
- [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
198
- )
199
- with torch.no_grad():
200
- uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
201
- text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
202
-
203
- # Prep Scheduler
204
- set_timesteps(scheduler, num_inference_steps+1)
205
-
206
- # Prep latents
207
- latents = torch.randn(
208
- (batch_size, unet.in_channels, height // 8, width // 8),
209
- generator=generator,
210
- )
211
- latents = latents.to(torch_device)
212
- latents = latents * scheduler.init_noise_sigma
213
-
214
- sched_out = None
215
-
216
- # Loop
217
- for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):
218
- # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
219
- latent_model_input = torch.cat([latents] * 2)
220
- sigma = scheduler.sigmas[i]
221
- latent_model_input = scheduler.scale_model_input(latent_model_input, t)
222
-
223
- # predict the noise residual
224
- with torch.no_grad():
225
- noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
226
-
227
- # perform CFG
228
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
229
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
230
-
231
- ### ADDITIONAL GUIDANCE ###
232
- if i%5 == 0 and i>0:
233
- # Requires grad on the latents
234
- latents = latents.detach().requires_grad_()
235
-
236
- # Get the predicted x0:
237
- scheduler._step_index -= 1
238
- latents_x0 = scheduler.step(noise_pred, t, latents).pred_original_sample
239
-
240
- # Decode to image space
241
- denoised_images = vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5 # range (0, 1)
242
-
243
-
244
- # Calculate loss
245
- loss = color_loss(denoised_images) * loss_scale
246
-
247
- # Occasionally print it out
248
- # if i%10==0:
249
- print(i, 'loss:', loss)
250
-
251
- # Get gradient
252
- cond_grad = torch.autograd.grad(loss, latents)[0]
253
-
254
- # Modify the latents based on this gradient
255
- latents = latents.detach() - cond_grad * sigma**2
256
- # To PIL Images
257
- im_t0 = latents_to_pil(latents_x0)[0]
258
- im_next = latents_to_pil(latents)[0]
259
-
260
- # Now step with scheduler
261
- latents = scheduler.step(noise_pred, t, latents).prev_sample
262
-
263
- return latents_to_pil(latents)[0]
264
-
265
-
266
- def generate_image_from_prompt(text_in, style_in):
267
- STYLE_LIST = ['coffeemachine.bin', 'collage_style.bin', 'cube.bin', 'jerrymouse2.bin', 'zero.bin']
268
- STYLE_SEEDS = [32, 64, 128, 16, 8]
269
-
270
- print(text_in)
271
- print(style_in)
272
- style_file = style_in + '.bin'
273
- idx = STYLE_LIST.index(style_file)
274
- print(style_file)
275
- print(idx)
276
-
277
- prompt = text_in + ' a puppy'
278
-
279
- style_seed = STYLE_SEEDS[idx]
280
- style_dict = torch.load(style_file)
281
- style_embed = [v for v in style_dict.values()]
282
-
283
- generated_image = embed_style(prompt, style_embed[0], style_seed)
284
-
285
- loss_generated_img = (loss_style(prompt, style_embed[0], style_seed))
286
-
287
- return [generated_image, loss_generated_img]
288
-
289
-
290
- # Define Interface
291
-
292
- title = 'ERA-SESSION20 Generative Art and Stable Diffusion'
293
-
294
- demo = gr.Interface(generate_image_from_prompt,
295
- inputs = [gr.Textbox(1, label='prompt'),
296
- gr.Dropdown(
297
- ['coffeemachine', 'collage_style', 'cube', 'jerrymouse2', 'zero'],value="cube", label="Pretrained Styles"
298
- )
299
- ],
300
- outputs = [
301
-
302
- gr.Gallery(label="Generated images", show_label=True, elem_id="gallery", columns=[2], rows=[2], object_fit="contain", height="auto")
303
- ],
304
-
305
- title = title
306
- )
307
- demo.launch(debug=True)
 
1
  import gradio as gr
2
+ import random
 
3
  import torch
4
+ import pathlib
5
+
6
+ from src.utils import concept_styles, loss_fn
7
+ from src.stable_diffusion import StableDiffusion
8
+
9
+ PROJECT_PATH = "."
10
+ CONCEPT_LIBS_PATH = f"{PROJECT_PATH}/concept_libs"
11
+
12
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
13
+
14
+
15
+ def generate(prompt, styles, gen_steps, loss_scale):
16
+ lossless_images, lossy_images = [], []
17
+ for style in styles:
18
+ concept_lib_path = f"{CONCEPT_LIBS_PATH}/{concept_styles[style]}"
19
+ concept_lib = pathlib.Path(concept_lib_path)
20
+ concept_embed = torch.load(concept_lib)
21
+
22
+ manual_seed = random.randint(0, 100)
23
+ diffusion = StableDiffusion(
24
+ device=DEVICE,
25
+ num_inference_steps=gen_steps,
26
+ manual_seed=manual_seed,
27
+ )
28
+ generated_image_lossless = diffusion.generate_image(
29
+ prompt=prompt,
30
+ loss_fn=loss_fn,
31
+ loss_scale=0,
32
+ concept_embed=concept_embed,
33
+ )
34
+ generated_image_lossy = diffusion.generate_image(
35
+ prompt=prompt,
36
+ loss_fn=loss_fn,
37
+ loss_scale=loss_scale,
38
+ concept_embed=concept_embed,
39
+ )
40
+ lossless_images.append((generated_image_lossless, style))
41
+ lossy_images.append((generated_image_lossy, style))
42
+ return {lossless_gallery: lossless_images, lossy_gallery: lossy_images}
43
+
44
+
45
+ with gr.Blocks() as app:
46
+ gr.Markdown("## ERA Session20 - Stable Diffusion: Generative Art with Guidance")
47
+ with gr.Row():
48
+ with gr.Column():
49
+ prompt_box = gr.Textbox(label="Prompt", interactive=True)
50
+ style_selector = gr.Dropdown(
51
+ choices=list(concept_styles.keys()),
52
+ value=list(concept_styles.keys())[0],
53
+ multiselect=True,
54
+ label="Select a Concept Style",
55
+ interactive=True,
56
+ )
57
+ gen_steps = gr.Slider(
58
+ minimum=10,
59
+ maximum=50,
60
+ value=30,
61
+ step=10,
62
+ label="Select Number of Steps",
63
+ interactive=True,
64
+ )
65
+
66
+ loss_scale = gr.Slider(
67
+ minimum=0,
68
+ maximum=32,
69
+ value=8,
70
+ step=8,
71
+ label="Select Guidance Scale",
72
+ interactive=True,
73
+ )
74
+
75
+ submit_btn = gr.Button(value="Generate")
76
+
77
+ with gr.Column():
78
+ lossless_gallery = gr.Gallery(
79
+ label="Generated Images without Guidance", show_label=True
80
+ )
81
+ lossy_gallery = gr.Gallery(
82
+ label="Generated Images with Guidance", show_label=True
83
+ )
84
+
85
+ submit_btn.click(
86
+ generate,
87
+ inputs=[prompt_box, style_selector, gen_steps, loss_scale],
88
+ outputs=[lossless_gallery, lossy_gallery],
89
+ )
90
+
91
+ app.launch()
concept_libs/coffeemachine.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cc3a85dc9cbdf6ab5fca4056c473da1b632c0565030be918682ce3e62095b4b1
3
+ size 3840
concept_libs/collage_style.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b143c4841c5f2d39d0eb2015d62c17d1b18da9bb0a42c76320df7acfe1e144bf
3
+ size 3840
concept_libs/cube.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8a6d6394f0cd38847259c42746a6b0e50ca1e76e6ddc8e217ff14f2feb7dbca4
3
+ size 3819
concept_libs/jerrymouse2.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a9713d9367f1faa6ebd753db5c8a209c565be0b25e32051c723c4533dd9df605
3
+ size 3840
concept_libs/zero.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:78286aa910deafe4e46c6e38a86f464a246aef95ad5611a756dd99405f418a85
3
+ size 3819
requirements.txt CHANGED
Binary files a/requirements.txt and b/requirements.txt differ
 
src/stable_diffusion.py ADDED
@@ -0,0 +1,222 @@
1
+ import torch
2
+ from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
3
+ from transformers import CLIPTextModel, CLIPTokenizer
4
+ from PIL import Image
5
+ from tqdm import tqdm
6
+
7
+
8
+ class StableDiffusion:
9
+ def __init__(
10
+ self,
11
+ vae_arch="CompVis/stable-diffusion-v1-4",
12
+ tokenizer_arch="openai/clip-vit-large-patch14",
13
+ encoder_arch="openai/clip-vit-large-patch14",
14
+ unet_arch="CompVis/stable-diffusion-v1-4",
15
+ device="cpu",
16
+ height=512,
17
+ width=512,
18
+ num_inference_steps=30,
19
+ guidance_scale=7.5,
20
+ manual_seed=1,
21
+ ) -> None:
22
+ self.height = height # default height of Stable Diffusion
23
+ self.width = width # default width of Stable Diffusion
24
+ self.num_inference_steps = num_inference_steps # Number of denoising steps
25
+ self.guidance_scale = guidance_scale # Scale for classifier-free guidance
26
+ self.device = device
27
+ self.manual_seed = manual_seed
28
+
29
+ vae = AutoencoderKL.from_pretrained(vae_arch, subfolder="vae")
30
+ # Load the tokenizer and text encoder to tokenize and encode the text.
31
+ self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_arch)
32
+ text_encoder = CLIPTextModel.from_pretrained(encoder_arch)
33
+
34
+ # The UNet model for generating the latents.
35
+ unet = UNet2DConditionModel.from_pretrained(unet_arch, subfolder="unet")
36
+
37
+ # The noise scheduler
38
+ self.scheduler = LMSDiscreteScheduler(
39
+ beta_start=0.00085,
40
+ beta_end=0.012,
41
+ beta_schedule="scaled_linear",
42
+ num_train_timesteps=1000,
43
+ )
44
+
45
+ # To the GPU we go!
46
+ self.vae = vae.to(self.device)
47
+ self.text_encoder = text_encoder.to(self.device)
48
+ self.unet = unet.to(self.device)
49
+
50
+ self.token_emb_layer = text_encoder.text_model.embeddings.token_embedding
51
+ pos_emb_layer = text_encoder.text_model.embeddings.position_embedding
52
+ position_ids = text_encoder.text_model.embeddings.position_ids[:, :77]
53
+ self.position_embeddings = pos_emb_layer(position_ids)
54
+
55
+ def get_output_embeds(self, input_embeddings):
56
+ # CLIP's text model uses causal mask, so we prepare it here:
57
+ bsz, seq_len = input_embeddings.shape[:2]
58
+ causal_attention_mask = (
59
+ self.text_encoder.text_model._build_causal_attention_mask(
60
+ bsz, seq_len, dtype=input_embeddings.dtype
61
+ )
62
+ )
63
+
64
+ # Getting the output embeddings involves calling the model with passing output_hidden_states=True
65
+ # so that it doesn't just return the pooled final predictions:
66
+ encoder_outputs = self.text_encoder.text_model.encoder(
67
+ inputs_embeds=input_embeddings,
68
+ attention_mask=None, # We aren't using an attention mask so that can be None
69
+ causal_attention_mask=causal_attention_mask.to(self.device),
70
+ output_attentions=None,
71
+ output_hidden_states=True, # We want the output embs not the final output
72
+ return_dict=None,
73
+ )
74
+
75
+ # We're interested in the output hidden state only
76
+ output = encoder_outputs[0]
77
+
78
+ # There is a final layer norm we need to pass these through
79
+ output = self.text_encoder.text_model.final_layer_norm(output)
80
+
81
+ # And now they're ready!
82
+ return output
83
+
84
+ def set_timesteps(self, scheduler, num_inference_steps):
85
+ scheduler.set_timesteps(num_inference_steps)
86
+ scheduler.timesteps = scheduler.timesteps.to(torch.float32)
87
+
88
+ def latents_to_pil(self, latents):
89
+ # batch of latents -> list of images
90
+ latents = (1 / 0.18215) * latents
91
+ with torch.no_grad():
92
+ image = self.vae.decode(latents).sample
93
+ image = (image / 2 + 0.5).clamp(0, 1)
94
+ image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
95
+ images = (image * 255).round().astype("uint8")
96
+ pil_images = [Image.fromarray(image) for image in images]
97
+ return pil_images
98
+
99
+ def generate_with_embs(self, text_embeddings, text_input, loss_fn, loss_scale):
100
+ generator = torch.manual_seed(
101
+ self.manual_seed
102
+ ) # Seed generator to create the initial latent noise
103
+ batch_size = 1
104
+
105
+ max_length = text_input.input_ids.shape[-1]
106
+ uncond_input = self.tokenizer(
107
+ [""] * batch_size,
108
+ padding="max_length",
109
+ max_length=max_length,
110
+ return_tensors="pt",
111
+ )
112
+ with torch.no_grad():
113
+ uncond_embeddings = self.text_encoder(
114
+ uncond_input.input_ids.to(self.device)
115
+ )[0]
116
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
117
+
118
+ # Prep Scheduler
119
+ self.set_timesteps(self.scheduler, self.num_inference_steps)
120
+
121
+ # Prep latents
122
+ latents = torch.randn(
123
+ (batch_size, self.unet.in_channels, self.height // 8, self.width // 8),
124
+ generator=generator,
125
+ )
126
+ latents = latents.to(self.device)
127
+ latents = latents * self.scheduler.init_noise_sigma
128
+
129
+ # Loop
130
+ for i, t in tqdm(
131
+ enumerate(self.scheduler.timesteps), total=len(self.scheduler.timesteps)
132
+ ):
133
+ # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
134
+ latent_model_input = torch.cat([latents] * 2)
135
+ sigma = self.scheduler.sigmas[i]
136
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
137
+
138
+ # predict the noise residual
139
+ with torch.no_grad():
140
+ noise_pred = self.unet(
141
+ latent_model_input, t, encoder_hidden_states=text_embeddings
142
+ )["sample"]
143
+
144
+ # perform guidance
145
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
146
+ noise_pred = noise_pred_uncond + self.guidance_scale * (
147
+ noise_pred_text - noise_pred_uncond
148
+ )
149
+ if i % 5 == 0:
150
+ # Requires grad on the latents
151
+ latents = latents.detach().requires_grad_()
152
+
153
+ # Get the predicted x0:
154
+ # latents_x0 = latents - sigma * noise_pred
155
+ latents_x0 = self.scheduler.step(
156
+ noise_pred, t, latents
157
+ ).pred_original_sample
158
+
159
+ # Decode to image space
160
+ denoised_images = (
161
+ self.vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5
162
+ ) # range (0, 1)
163
+
164
+ # Calculate loss
165
+ loss = loss_fn(denoised_images) * loss_scale
166
+
167
+ # Occasionally print it out
168
+ # if i % 10 == 0:
169
+ # print(i, "loss:", loss.item())
170
+
171
+ # Get gradient
172
+ cond_grad = torch.autograd.grad(loss, latents)[0]
173
+
174
+ # Modify the latents based on this gradient
175
+ latents = latents.detach() - cond_grad * sigma**2
176
+ self.scheduler._step_index = self.scheduler._step_index - 1
177
+
178
+ # compute the previous noisy sample x_t -> x_t-1
179
+ latents = self.scheduler.step(noise_pred, t, latents).prev_sample
180
+
181
+ return self.latents_to_pil(latents)[0]
182
+
183
+ def generate_image(
184
+ self,
185
+ prompt="A campfire (oil on canvas)",
186
+ loss_fn=None,
187
+ loss_scale=200,
188
+ concept_embed=None, # e.g. a dict loaded from one of the concept_libs/*.bin files
189
+ ):
190
+ prompt += " in the style of cs"
191
+ text_input = self.tokenizer(
192
+ prompt,
193
+ padding="max_length",
194
+ max_length=self.tokenizer.model_max_length,
195
+ truncation=True,
196
+ return_tensors="pt",
197
+ )
198
+ input_ids = text_input.input_ids.to(self.device)
199
+ custom_style_token = self.tokenizer.encode("cs", add_special_tokens=False)[0]
200
+ # Get token embeddings
201
+ token_embeddings = self.token_emb_layer(input_ids)
202
+
203
+ # The new embedding - our special birb word
204
+ embed_key = list(concept_embed.keys())[0]
205
+ replacement_token_embedding = concept_embed[embed_key]
206
+
207
+ # Insert this into the token embeddings
208
+ token_embeddings[
209
+ 0, torch.where(input_ids[0] == custom_style_token)
210
+ ] = replacement_token_embedding.to(self.device)
211
+ # token_embeddings = token_embeddings + (replacement_token_embedding * 0.9)
212
+ # Combine with pos embs
213
+ input_embeddings = token_embeddings + self.position_embeddings
214
+
215
+ # Feed through to get final output embs
216
+ modified_output_embeddings = self.get_output_embeds(input_embeddings)
217
+
218
+ # And generate an image with this:
219
+ generated_image = self.generate_with_embs(
220
+ modified_output_embeddings, text_input, loss_fn, loss_scale
221
+ )
222
+ return generated_image
src/utils.py ADDED
@@ -0,0 +1,11 @@
1
+ def loss_fn(images):
2
+ return -images.median() / 3
3
+
4
+
5
+ concept_styles = {
6
+ "Coffee Machine": "coffeemachine.bin",
7
+ "Collage Style": "collage_style.bin",
8
+ "Cube": "cube.bin",
9
+ "Jerry Mouse": "jerrymouse2.bin",
10
+ "Zero": "zero.bin",
11
+ }