ChenWu98 commited on
Commit
5b953fb
·
1 Parent(s): af0ed3b

Create ptp_utils.py

Browse files
Files changed (1) hide show
  1. ptp_utils.py +285 -0
ptp_utils.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import numpy as np
16
+ import torch
17
+ from PIL import Image, ImageDraw, ImageFont
18
+ import cv2
19
+ from typing import Optional, Union, Tuple, List, Callable, Dict
20
+ from IPython.display import display
21
+ from tqdm.notebook import tqdm
22
+
23
+
24
+ def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)):
25
+ h, w, c = image.shape
26
+ offset = int(h * .2)
27
+ img = np.ones((h + offset, w, c), dtype=np.uint8) * 255
28
+ font = cv2.FONT_HERSHEY_SIMPLEX
29
+ # font = ImageFont.truetype("/usr/share/fonts/truetype/noto/NotoMono-Regular.ttf", font_size)
30
+ img[:h] = image
31
+ textsize = cv2.getTextSize(text, font, 1, 2)[0]
32
+ text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2
33
+ cv2.putText(img, text, (text_x, text_y ), font, 1, text_color, 2)
34
+ return img
35
+
36
+
37
+ def view_images(images, num_rows=1, offset_ratio=0.02):
38
+ if type(images) is list:
39
+ num_empty = len(images) % num_rows
40
+ elif images.ndim == 4:
41
+ num_empty = images.shape[0] % num_rows
42
+ else:
43
+ images = [images]
44
+ num_empty = 0
45
+
46
+ empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
47
+ images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
48
+ num_items = len(images)
49
+
50
+ h, w, c = images[0].shape
51
+ offset = int(h * offset_ratio)
52
+ num_cols = num_items // num_rows
53
+ image_ = np.ones((h * num_rows + offset * (num_rows - 1),
54
+ w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
55
+ for i in range(num_rows):
56
+ for j in range(num_cols):
57
+ image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[
58
+ i * num_cols + j]
59
+
60
+ pil_img = Image.fromarray(image_)
61
+ display(pil_img)
62
+
63
+
64
+
65
+ def diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource=False):
66
+ if low_resource:
67
+ noise_pred_uncond = model.unet(latents, t, encoder_hidden_states=context[0])["sample"]
68
+ noise_prediction_text = model.unet(latents, t, encoder_hidden_states=context[1])["sample"]
69
+ else:
70
+ latents_input = torch.cat([latents] * 2)
71
+ noise_pred = model.unet(latents_input, t, encoder_hidden_states=context)["sample"]
72
+ noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
73
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
74
+ latents = model.scheduler.step(noise_pred, t, latents)["prev_sample"]
75
+ latents = controller.step_callback(latents)
76
+ return latents
77
+
78
+
79
+ def latent2image(vae, latents):
80
+ latents = 1 / 0.18215 * latents
81
+ image = vae.decode(latents)['sample']
82
+ image = (image / 2 + 0.5).clamp(0, 1)
83
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
84
+ image = (image * 255).astype(np.uint8)
85
+ return image
86
+
87
+
88
+ def init_latent(latent, model, height, width, generator, batch_size):
89
+ if latent is None:
90
+ latent = torch.randn(
91
+ (1, model.unet.in_channels, height // 8, width // 8),
92
+ generator=generator,
93
+ )
94
+ latents = latent.expand(batch_size, model.unet.in_channels, height // 8, width // 8).to(model.device)
95
+ return latent, latents
96
+
97
+
98
+ @torch.no_grad()
99
+ def text2image_ldm(
100
+ model,
101
+ prompt: List[str],
102
+ controller,
103
+ num_inference_steps: int = 50,
104
+ guidance_scale: Optional[float] = 7.,
105
+ generator: Optional[torch.Generator] = None,
106
+ latent: Optional[torch.FloatTensor] = None,
107
+ ):
108
+ register_attention_control(model, controller)
109
+ height = width = 256
110
+ batch_size = len(prompt)
111
+
112
+ uncond_input = model.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt")
113
+ uncond_embeddings = model.bert(uncond_input.input_ids.to(model.device))[0]
114
+
115
+ text_input = model.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt")
116
+ text_embeddings = model.bert(text_input.input_ids.to(model.device))[0]
117
+ latent, latents = init_latent(latent, model, height, width, generator, batch_size)
118
+ context = torch.cat([uncond_embeddings, text_embeddings])
119
+
120
+ model.scheduler.set_timesteps(num_inference_steps)
121
+ for t in tqdm(model.scheduler.timesteps):
122
+ latents = diffusion_step(model, controller, latents, context, t, guidance_scale)
123
+
124
+ image = latent2image(model.vqvae, latents)
125
+
126
+ return image, latent
127
+
128
+
129
+
130
+ @torch.no_grad()
131
+ def text2image_ldm_stable(
132
+ model,
133
+ prompt: List[str],
134
+ controller,
135
+ num_inference_steps: int = 50,
136
+ guidance_scale: float = 7.5,
137
+ generator: Optional[torch.Generator] = None,
138
+ latent: Optional[torch.FloatTensor] = None,
139
+ low_resource: bool = False,
140
+ ):
141
+ register_attention_control(model, controller)
142
+ height = width = 512
143
+ batch_size = len(prompt)
144
+
145
+ text_input = model.tokenizer(
146
+ prompt,
147
+ padding="max_length",
148
+ max_length=model.tokenizer.model_max_length,
149
+ truncation=True,
150
+ return_tensors="pt",
151
+ )
152
+ text_embeddings = model.text_encoder(text_input.input_ids.to(model.device))[0]
153
+ max_length = text_input.input_ids.shape[-1]
154
+ uncond_input = model.tokenizer(
155
+ [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
156
+ )
157
+ uncond_embeddings = model.text_encoder(uncond_input.input_ids.to(model.device))[0]
158
+
159
+ context = [uncond_embeddings, text_embeddings]
160
+ if not low_resource:
161
+ context = torch.cat(context)
162
+ latent, latents = init_latent(latent, model, height, width, generator, batch_size)
163
+
164
+ # set timesteps
165
+ extra_set_kwargs = {"offset": 1}
166
+ model.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
167
+ for t in tqdm(model.scheduler.timesteps):
168
+ latents = diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource)
169
+
170
+ image = latent2image(model.vae, latents)
171
+
172
+ return image, latent
173
+
174
+
175
+ def register_attention_control(model, controller):
176
+ def ca_forward(self, place_in_unet):
177
+
178
+ def forward(x, context=None, mask=None):
179
+ batch_size, sequence_length, dim = x.shape
180
+ h = self.heads
181
+ q = self.to_q(x)
182
+ is_cross = context is not None
183
+ context = context if is_cross else x
184
+ k = self.to_k(context)
185
+ v = self.to_v(context)
186
+ q = self.reshape_heads_to_batch_dim(q)
187
+ k = self.reshape_heads_to_batch_dim(k)
188
+ v = self.reshape_heads_to_batch_dim(v)
189
+
190
+ sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
191
+
192
+ if mask is not None:
193
+ mask = mask.reshape(batch_size, -1)
194
+ max_neg_value = -torch.finfo(sim.dtype).max
195
+ mask = mask[:, None, :].repeat(h, 1, 1)
196
+ sim.masked_fill_(~mask, max_neg_value)
197
+
198
+ # attention, what we cannot get enough of
199
+ attn = sim.softmax(dim=-1)
200
+ attn = controller(attn, is_cross, place_in_unet)
201
+ out = torch.einsum("b i j, b j d -> b i d", attn, v)
202
+ out = self.reshape_batch_dim_to_heads(out)
203
+
204
+ # TODO: Chen (new version of diffusers)
205
+ # return self.to_out(out)
206
+ # linear proj
207
+ out = self.to_out[0](out)
208
+ # dropout
209
+ out = self.to_out[1](out)
210
+ return out
211
+
212
+ return forward
213
+
214
+ def register_recr(net_, count, place_in_unet):
215
+ if net_.__class__.__name__ == 'CrossAttention':
216
+ net_.forward = ca_forward(net_, place_in_unet)
217
+ return count + 1
218
+ elif hasattr(net_, 'children'):
219
+ for net__ in net_.children():
220
+ count = register_recr(net__, count, place_in_unet)
221
+ return count
222
+
223
+ cross_att_count = 0
224
+ sub_nets = model.unet.named_children()
225
+ for net in sub_nets:
226
+ if "down" in net[0]:
227
+ cross_att_count += register_recr(net[1], 0, "down")
228
+ elif "up" in net[0]:
229
+ cross_att_count += register_recr(net[1], 0, "up")
230
+ elif "mid" in net[0]:
231
+ cross_att_count += register_recr(net[1], 0, "mid")
232
+ controller.num_att_layers = cross_att_count
233
+
234
+
235
+ def get_word_inds(text: str, word_place: int, tokenizer):
236
+ split_text = text.split(" ")
237
+ if type(word_place) is str:
238
+ word_place = [i for i, word in enumerate(split_text) if word_place == word]
239
+ elif type(word_place) is int:
240
+ word_place = [word_place]
241
+ out = []
242
+ if len(word_place) > 0:
243
+ words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1]
244
+ cur_len, ptr = 0, 0
245
+
246
+ for i in range(len(words_encode)):
247
+ cur_len += len(words_encode[i])
248
+ if ptr in word_place:
249
+ out.append(i + 1)
250
+ if cur_len >= len(split_text[ptr]):
251
+ ptr += 1
252
+ cur_len = 0
253
+ return np.array(out)
254
+
255
+
256
+ def update_alpha_time_word(alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int, word_inds: Optional[torch.Tensor]=None):
257
+ if type(bounds) is float:
258
+ bounds = 0, bounds
259
+ start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0])
260
+ if word_inds is None:
261
+ word_inds = torch.arange(alpha.shape[2])
262
+ alpha[: start, prompt_ind, word_inds] = 0
263
+ alpha[start: end, prompt_ind, word_inds] = 1
264
+ alpha[end:, prompt_ind, word_inds] = 0
265
+ return alpha
266
+
267
+
268
+ def get_time_words_attention_alpha(prompts, num_steps, cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]],
269
+ tokenizer, max_num_words=77):
270
+ if type(cross_replace_steps) is not dict:
271
+ cross_replace_steps = {"default_": cross_replace_steps}
272
+ if "default_" not in cross_replace_steps:
273
+ cross_replace_steps["default_"] = (0., 1.)
274
+ alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words)
275
+ for i in range(len(prompts) - 1):
276
+ alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"],
277
+ i)
278
+ for key, item in cross_replace_steps.items():
279
+ if key != "default_":
280
+ inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))]
281
+ for i, ind in enumerate(inds):
282
+ if len(ind) > 0:
283
+ alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind)
284
+ alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words) # time, batch, heads, pixels, words
285
+ return alpha_time_words