Royir committed
Commit c0c23c6
1 Parent(s): 19cb368

Update syngen_diffusion_pipeline.py


Making SynGen a little more efficient: accept an optional pre-parsed prompt, cache the spaCy parse and the extracted subtree indices across denoising steps, and skip the SynGen update entirely after step 25.

Files changed (1)
  1. syngen_diffusion_pipeline.py +71 -23
syngen_diffusion_pipeline.py CHANGED
@@ -19,8 +19,6 @@ from diffusers.utils import (
 )
 from transformers import CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor
 
-from compute_loss import get_attention_map_index_to_wordpiece, split_indices, calculate_positive_loss, calculate_negative_loss, get_indices, start_token, end_token, \
-    align_wordpieces_indices, extract_attribution_indices
 
 logger = logging.get_logger(__name__)
 
@@ -40,6 +38,9 @@ class SynGenDiffusionPipeline(StableDiffusionPipeline):
                          requires_safety_checker)
 
         self.parser = spacy.load("en_core_web_trf")
+        self.subtrees_indices = None
+        self.doc = None
+        # self.doc = ""#self.parser(prompt)
 
     def _aggregate_and_get_attention_maps_per_token(self):
         attention_maps = self.attention_store.aggregate_attention(
@@ -105,6 +106,7 @@ class SynGenDiffusionPipeline(StableDiffusionPipeline):
         callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         syngen_step_size: float = 20.0,
+        parsed_prompt: str = None
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -165,7 +167,7 @@ class SynGenDiffusionPipeline(StableDiffusionPipeline):
                 A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
                 `self.processor` in
                 [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
-            syngen_step_size (`int`, *optional*, defaults to 20):
+            syngen_step_size (`float`, *optional*, defaults to 20.0):
                 Controls the step size of each SynGen update.
 
         Examples:
@@ -177,6 +179,11 @@ class SynGenDiffusionPipeline(StableDiffusionPipeline):
                 list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
                 (nsfw) content, according to the `safety_checker`.
         """
+
+        if parsed_prompt:
+            self.doc = parsed_prompt
+        else:
+            self.doc = self.parser(prompt)
         # 0. Default height and width to unet
         height = height or self.unet.config.sample_size * self.vae_scale_factor
         width = width or self.unet.config.sample_size * self.vae_scale_factor
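
The new `parsed_prompt` argument is what makes repeated generation cheaper: the spaCy transformer parse of the prompt can be computed once and reused across calls (e.g. across seeds) instead of being recomputed inside every call. A minimal usage sketch; the checkpoint name and the seed loop are illustrative, not taken from this repo:

    import torch
    from syngen_diffusion_pipeline import SynGenDiffusionPipeline

    pipe = SynGenDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")  # illustrative checkpoint
    prompt = "a checkered bowl on a red and blue table"

    doc = pipe.parser(prompt)  # parse once with en_core_web_trf
    for seed in range(3):
        generator = torch.Generator().manual_seed(seed)
        # reuse the cached parse on every call instead of re-running the parser
        image = pipe(prompt, parsed_prompt=doc, generator=generator).images[0]
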
@@ -234,7 +241,7 @@ class SynGenDiffusionPipeline(StableDiffusionPipeline):
             latents,
         )
 
-        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        # 6. Prepare extra step kwargs.
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
         # NEW - stores the attention calculated in the unet
@@ -251,16 +258,17 @@ class SynGenDiffusionPipeline(StableDiffusionPipeline):
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 # NEW
-                latents = self._syngen_step(
-                    latents,
-                    text_embeddings,
-                    t,
-                    i,
-                    syngen_step_size,
-                    cross_attention_kwargs,
-                    prompt,
-                    max_iter_to_alter=25,
-                )
+                if i < 25:
+                    latents = self._syngen_step(
+                        latents,
+                        text_embeddings,
+                        t,
+                        i,
+                        syngen_step_size,
+                        cross_attention_kwargs,
+                        prompt,
+                        max_iter_to_alter=25,
+                    )
 
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = (
@@ -325,6 +333,9 @@ class SynGenDiffusionPipeline(StableDiffusionPipeline):
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
             self.final_offload_hook.offload()
 
+        self.doc = None
+        self.subtrees_indices = None
+
         if not return_dict:
             return (image, has_nsfw_concept)
 
@@ -332,6 +343,8 @@ class SynGenDiffusionPipeline(StableDiffusionPipeline):
             images=image, nsfw_content_detected=has_nsfw_concept
         )
 
+
+
     def _syngen_step(
         self,
        latents,
@@ -358,12 +371,9 @@ class SynGenDiffusionPipeline(StableDiffusionPipeline):
                     cross_attention_kwargs=cross_attention_kwargs,
                 ).sample
                 self.unet.zero_grad()
-
                # Get attention maps
                attention_maps = self._aggregate_and_get_attention_maps_per_token()
-
                loss = self._compute_loss(attention_maps=attention_maps, prompt=prompt)
-
                # Perform gradient update
                if i < max_iter_to_alter:
                    if loss != 0:
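
For reference, `syngen_step_size` scales the latent update applied when `loss != 0`. The update itself is not shown in this hunk; in attend-and-excite-style pipelines it is typically a plain gradient step on the latents, along these lines (a sketch of the usual pattern with hypothetical variable names, not code from this diff):

    # Differentiate the attention loss w.r.t. the current latent and
    # nudge the latent against the gradient, scaled by the step size.
    grad = torch.autograd.grad(loss.requires_grad_(True), [latent], retain_graph=True)[0]
    latent = latent - syngen_step_size * grad
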
@@ -393,7 +403,9 @@ class SynGenDiffusionPipeline(StableDiffusionPipeline):
         prompt: Union[str, List[str]],
         attn_map_idx_to_wp,
     ) -> torch.Tensor:
-        subtrees_indices = self._extract_attribution_indices(prompt)
+        if not self.subtrees_indices:
+            self.subtrees_indices = self._extract_attribution_indices(prompt)
+        subtrees_indices = self.subtrees_indices
         loss = 0
 
         for subtree_indices in subtrees_indices:
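
This caching is the heart of the commit: the subtree indices depend only on the prompt, so they are computed on the first denoising step and reused afterwards (then reset at the end of `__call__`, per the earlier hunk). One caveat: because the guard tests truthiness, a prompt that yields an empty list of pairs would be re-extracted on every step. Since `__init__` already initializes the attribute to `None`, a sentinel-based guard avoids that; a sketch of the alternative (not what the commit does):

    # None distinguishes "not computed yet" from "computed, but no pairs found" ([]).
    if self.subtrees_indices is None:
        self.subtrees_indices = self._extract_attribution_indices(prompt)
    subtrees_indices = self.subtrees_indices
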
@@ -474,15 +486,24 @@ class SynGenDiffusionPipeline(StableDiffusionPipeline):
                 collected_spacy_indices.add(collected_idx)
 
             paired_indices.append(curr_collected_wp_indices)
-
+
         return paired_indices
 
+
     def _extract_attribution_indices(self, prompt):
-        pairs = extract_attribution_indices(prompt, self.parser)
-        paired_indices = self._align_indices(prompt, pairs)
-        return paired_indices
+        # extract standard attribution indices
+        pairs = extract_attribution_indices(self.doc)
+
+        # extract attribution indices with verbs in between
+        pairs_2 = extract_attribution_indices_with_verb_root(self.doc)
+        pairs_3 = extract_attribution_indices_with_verbs(self.doc)
+        # make sure there are no duplicates
+        pairs = unify_lists(pairs, pairs_2, pairs_3)
 
 
+        print(f"Final pairs collected: {pairs}")
+        paired_indices = self._align_indices(prompt, pairs)
+        return paired_indices
 
 def _get_attention_maps_list(
     attention_maps: torch.Tensor
@@ -492,4 +513,31 @@ def _get_attention_maps_list(
         attention_maps[:, :, i] for i in range(attention_maps.shape[2])
     ]
 
-    return attention_maps_list
+    return attention_maps_list
+
+def is_sublist(sub, main):
+    # This function checks if 'sub' is a sublist of 'main'
+    return len(sub) < len(main) and all(item in main for item in sub)
+
+def unify_lists(lists_1, lists_2, lists_3):
+    unified_list = lists_1 + lists_2 + lists_3
+    sorted_list = sorted(unified_list, key=len)
+    seen = set()
+
+    result = []
+
+    for i in range(len(sorted_list)):
+        if tuple(sorted_list[i]) in seen:  # Skip if already added
+            continue
+
+        sublist_to_add = True
+        for j in range(i + 1, len(sorted_list)):
+            if is_sublist(sorted_list[i], sorted_list[j]):
+                sublist_to_add = False
+                break
+
+        if sublist_to_add:
+            result.append(sorted_list[i])
+            seen.add(tuple(sorted_list[i]))
+
+    return result
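
`unify_lists` keeps only the maximal index groups: any group whose items all appear in a longer group is dropped, and exact duplicates are emitted once via the `seen` set. (Strictly, `is_sublist` is an order-insensitive containment test rather than a contiguous-sublist test, which is sufficient for these token-index groups.) A quick illustration with toy integer groups standing in for wordpiece-index groups:

    print(unify_lists([[1, 2]], [[1, 2, 3]], []))   # [[1, 2, 3]] -- [1, 2] is contained in [1, 2, 3]
    print(unify_lists([[4, 5]], [[4, 5]], [[6]]))   # [[6], [4, 5]] -- the duplicate is kept once
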