Update syngen_diffusion_pipeline.py
making syngen a little more efficient
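
This commit caches the prompt's syntactic analysis instead of redoing it inside the sampling loop: the spaCy doc and the attribution-subtree indices are stored on the pipeline (self.doc, self.subtrees_indices), reused by _compute_loss at every step, and cleared once the image is decoded; the per-step SynGen latent update is also gated by `if i < 25`. The call signature gains an optional parsed_prompt, so a caller generating several images from one prompt can parse it a single time. A minimal usage sketch (hypothetical checkpoint id and prompt; assumes the standard diffusers from_pretrained loading path):

    import torch
    from syngen_diffusion_pipeline import SynGenDiffusionPipeline

    # Hypothetical checkpoint; any Stable Diffusion 1.x weights should work.
    pipe = SynGenDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")

    prompt = "a red apple and a green backpack"
    # Parse once with the pipeline's own spaCy model (en_core_web_trf) so the
    # transformer parser runs a single time across all seeds.
    parsed = pipe.parser(prompt)

    for seed in range(4):
        generator = torch.Generator().manual_seed(seed)
        image = pipe(prompt, parsed_prompt=parsed, generator=generator).images[0]

Note that parsed_prompt is annotated str but is consumed as the parsed document (self.doc = parsed_prompt), so passing the parser's output, as above, is what the code expects.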
- syngen_diffusion_pipeline.py +71 -23
syngen_diffusion_pipeline.py
CHANGED
@@ -19,8 +19,6 @@ from diffusers.utils import (
 )
 from transformers import CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor
 
-from compute_loss import get_attention_map_index_to_wordpiece, split_indices, calculate_positive_loss, calculate_negative_loss, get_indices, start_token, end_token, \
-    align_wordpieces_indices, extract_attribution_indices
 
 logger = logging.get_logger(__name__)
 
@@ -40,6 +38,9 @@ class SynGenDiffusionPipeline(StableDiffusionPipeline):
                          requires_safety_checker)
 
         self.parser = spacy.load("en_core_web_trf")
+        self.subtrees_indices = None
+        self.doc = None
+        # self.doc = ""#self.parser(prompt)
 
     def _aggregate_and_get_attention_maps_per_token(self):
         attention_maps = self.attention_store.aggregate_attention(
@@ -105,6 +106,7 @@ class SynGenDiffusionPipeline(StableDiffusionPipeline):
         callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         syngen_step_size: float = 20.0,
+        parsed_prompt: str = None,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -165,7 +167,7 @@ class SynGenDiffusionPipeline(StableDiffusionPipeline):
                 A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
                 `self.processor` in
                 [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
-            syngen_step_size (`
+            syngen_step_size (`float`, *optional*, defaults to 20.0):
                 Controls the step size of each SynGen update.
 
         Examples:
@@ -177,6 +179,11 @@ class SynGenDiffusionPipeline(StableDiffusionPipeline):
             list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
             (nsfw) content, according to the `safety_checker`.
         """
+
+        if parsed_prompt:
+            self.doc = parsed_prompt
+        else:
+            self.doc = self.parser(prompt)
         # 0. Default height and width to unet
         height = height or self.unet.config.sample_size * self.vae_scale_factor
         width = width or self.unet.config.sample_size * self.vae_scale_factor
@@ -234,7 +241,7 @@ class SynGenDiffusionPipeline(StableDiffusionPipeline):
             latents,
         )
 
-        # 6. Prepare extra step kwargs.
+        # 6. Prepare extra step kwargs.
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
         # NEW - stores the attention calculated in the unet
@@ -251,16 +258,17 @@ class SynGenDiffusionPipeline(StableDiffusionPipeline):
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 # NEW
-                latents = self._syngen_step(
-                    latents,
-                    text_embeddings,
-                    t,
-                    i,
-                    syngen_step_size,
-                    cross_attention_kwargs,
-                    prompt,
-                    max_iter_to_alter=25,
-                )
+                if i < 25:
+                    latents = self._syngen_step(
+                        latents,
+                        text_embeddings,
+                        t,
+                        i,
+                        syngen_step_size,
+                        cross_attention_kwargs,
+                        prompt,
+                        max_iter_to_alter=25,
+                    )
 
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = (
@@ -325,6 +333,9 @@ class SynGenDiffusionPipeline(StableDiffusionPipeline):
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
             self.final_offload_hook.offload()
 
+        self.doc = None
+        self.subtrees_indices = None
+
         if not return_dict:
             return (image, has_nsfw_concept)
 
@@ -332,6 +343,8 @@ class SynGenDiffusionPipeline(StableDiffusionPipeline):
             images=image, nsfw_content_detected=has_nsfw_concept
         )
 
+
+
     def _syngen_step(
         self,
        latents,
@@ -358,12 +371,9 @@ class SynGenDiffusionPipeline(StableDiffusionPipeline):
                 cross_attention_kwargs=cross_attention_kwargs,
             ).sample
             self.unet.zero_grad()
-
             # Get attention maps
             attention_maps = self._aggregate_and_get_attention_maps_per_token()
-
             loss = self._compute_loss(attention_maps=attention_maps, prompt=prompt)
-
             # Perform gradient update
             if i < max_iter_to_alter:
                 if loss != 0:
@@ -393,7 +403,9 @@ class SynGenDiffusionPipeline(StableDiffusionPipeline):
         prompt: Union[str, List[str]],
         attn_map_idx_to_wp,
     ) -> torch.Tensor:
-        subtrees_indices = self._extract_attribution_indices(prompt)
+        if not self.subtrees_indices:
+            self.subtrees_indices = self._extract_attribution_indices(prompt)
+        subtrees_indices = self.subtrees_indices
         loss = 0
 
         for subtree_indices in subtrees_indices:
@@ -474,15 +486,24 @@ class SynGenDiffusionPipeline(StableDiffusionPipeline):
                 collected_spacy_indices.add(collected_idx)
 
             paired_indices.append(curr_collected_wp_indices)
-
+
         return paired_indices
 
+
     def _extract_attribution_indices(self, prompt):
-
-
-
+        # extract standard attribution indices
+        pairs = extract_attribution_indices(self.doc)
+
+        # extract attribution indices with verbs in between
+        pairs_2 = extract_attribution_indices_with_verb_root(self.doc)
+        pairs_3 = extract_attribution_indices_with_verbs(self.doc)
+        # make sure there are no duplicates
+        pairs = unify_lists(pairs, pairs_2, pairs_3)
 
 
+        print(f"Final pairs collected: {pairs}")
+        paired_indices = self._align_indices(prompt, pairs)
+        return paired_indices
 
 def _get_attention_maps_list(
         attention_maps: torch.Tensor
@@ -492,4 +513,31 @@ def _get_attention_maps_list(
        attention_maps[:, :, i] for i in range(attention_maps.shape[2])
    ]
 
-    return attention_maps_list
+    return attention_maps_list
+
+def is_sublist(sub, main):
+    # This function checks if 'sub' is a sublist of 'main'
+    return len(sub) < len(main) and all(item in main for item in sub)
+
+def unify_lists(lists_1, lists_2, lists_3):
+    unified_list = lists_1 + lists_2 + lists_3
+    sorted_list = sorted(unified_list, key=len)
+    seen = set()
+
+    result = []
+
+    for i in range(len(sorted_list)):
+        if tuple(sorted_list[i]) in seen:  # Skip if already added
+            continue
+
+        sublist_to_add = True
+        for j in range(i + 1, len(sorted_list)):
+            if is_sublist(sorted_list[i], sorted_list[j]):
+                sublist_to_add = False
+                break
+
+        if sublist_to_add:
+            result.append(sorted_list[i])
+            seen.add(tuple(sorted_list[i]))
+
+    return result
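
To make the new deduplication helpers concrete, here is a small worked example (made-up index lists): unify_lists concatenates its three inputs, sorts them by length, and keeps only entries that are neither exact duplicates nor strict sublists of a longer entry.

    pairs = [[2, 3], [5, 6, 7]]
    pairs_2 = [[2, 3, 4]]
    pairs_3 = [[5, 6, 7]]

    # [2, 3] is dropped as a strict sublist of [2, 3, 4];
    # the repeated [5, 6, 7] is kept only once.
    print(unify_lists(pairs, pairs_2, pairs_3))
    # -> [[5, 6, 7], [2, 3, 4]]

Note that is_sublist tests membership, not contiguity, so [2, 4] would also count as a sublist of [2, 3, 4].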