linoyts HF staff commited on
Commit
e92a1c6
Β·
verified Β·
1 Parent(s): 50d5527

Update ledits/pipeline_leditspp_stable_diffusion_xl.py

Browse files
ledits/pipeline_leditspp_stable_diffusion_xl.py CHANGED
@@ -415,10 +415,11 @@ class LEditsPPPipelineStableDiffusionXL(
415
  editing_prompt: Optional[str] = None,
416
  editing_prompt_embeds: Optional[torch.Tensor] = None,
417
  editing_pooled_prompt_embeds: Optional[torch.Tensor] = None,
418
- avg_diff = None,
419
- avg_diff_2 = None,
420
- correlation_weight_factor = 0.7,
421
  scale=2,
 
422
  ) -> object:
423
  r"""
424
  Encodes the prompt into text encoder hidden states.
@@ -538,9 +539,8 @@ class LEditsPPPipelineStableDiffusionXL(
538
  negative_pooled_prompt_embeds = negative_prompt_embeds[0]
539
  negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
540
 
541
- if avg_diff is not None and avg_diff_2 is not None:
542
- #scale=3
543
- print("SHALOM neg")
544
  normed_prompt_embeds = negative_prompt_embeds / negative_prompt_embeds.norm(dim=-1, keepdim=True)
545
  sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
546
  if j == 0:
@@ -549,15 +549,26 @@ class LEditsPPPipelineStableDiffusionXL(
549
  standard_weights = torch.ones_like(weights)
550
 
551
  weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
552
- edit_concepts_embeds = negative_prompt_embeds + (weights * avg_diff[None, :].repeat(1,tokenizer.model_max_length, 1) * scale)
 
 
 
 
 
 
553
  else:
554
  weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 1280)
555
 
556
  standard_weights = torch.ones_like(weights)
557
 
558
  weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
559
- edit_concepts_embeds = negative_prompt_embeds + (weights * avg_diff_2[None, :].repeat(1, tokenizer.model_max_length, 1) * scale)
 
560
 
 
 
 
 
561
 
562
  negative_prompt_embeds_list.append(negative_prompt_embeds)
563
  j+=1
@@ -878,10 +889,12 @@ class LEditsPPPipelineStableDiffusionXL(
878
  clip_skip: Optional[int] = None,
879
  callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
880
  callback_on_step_end_tensor_inputs: List[str] = ["latents"],
881
- avg_diff = None,
882
- avg_diff_2 = None,
883
- correlation_weight_factor = 0.7,
884
  scale=2,
 
 
885
  init_latents: [torch.Tensor] = None,
886
  zs: [torch.Tensor] = None,
887
  **kwargs,
@@ -1088,9 +1101,10 @@ class LEditsPPPipelineStableDiffusionXL(
1088
  editing_prompt_embeds=editing_prompt_embeddings,
1089
  editing_pooled_prompt_embeds=editing_pooled_prompt_embeds,
1090
  avg_diff = avg_diff,
1091
- avg_diff_2 = avg_diff_2,
1092
  correlation_weight_factor = correlation_weight_factor,
1093
  scale=scale,
 
1094
  )
1095
 
1096
  # 4. Prepare timesteps
 
415
  editing_prompt: Optional[str] = None,
416
  editing_prompt_embeds: Optional[torch.Tensor] = None,
417
  editing_pooled_prompt_embeds: Optional[torch.Tensor] = None,
418
+ avg_diff=None, # [0] -> text encoder 1,[1] ->text encoder 2
419
+ avg_diff_2nd=None, # text encoder 1,2
420
+ correlation_weight_factor=0.7,
421
  scale=2,
422
+ scale_2nd=2,
423
  ) -> object:
424
  r"""
425
  Encodes the prompt into text encoder hidden states.
 
539
  negative_pooled_prompt_embeds = negative_prompt_embeds[0]
540
  negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
541
 
542
+ if avg_diff is not None:
543
+ # scale=3
 
544
  normed_prompt_embeds = negative_prompt_embeds / negative_prompt_embeds.norm(dim=-1, keepdim=True)
545
  sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
546
  if j == 0:
 
549
  standard_weights = torch.ones_like(weights)
550
 
551
  weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
552
+ edit_concepts_embeds = negative_prompt_embeds + (
553
+ weights * avg_diff[0][None, :].repeat(1, tokenizer.model_max_length, 1) * scale)
554
+
555
+ if avg_diff_2nd is not None:
556
+ edit_concepts_embeds += (weights * avg_diff_2nd[0][None, :].repeat(1,
557
+ self.pipe.tokenizer.model_max_length,
558
+ 1) * scale_2nd)
559
  else:
560
  weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 1280)
561
 
562
  standard_weights = torch.ones_like(weights)
563
 
564
  weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
565
+ edit_concepts_embeds = negative_prompt_embeds + (
566
+ weights * avg_diff[1][None, :].repeat(1, tokenizer.model_max_length, 1) * scale)
567
 
568
+ if avg_diff_2nd is not None:
569
+ edit_concepts_embeds += (weights * avg_diff_2nd[1][None, :].repeat(1,
570
+ self.pipe.tokenizer_2.model_max_length,
571
+ 1) * scale_2nd)
572
 
573
  negative_prompt_embeds_list.append(negative_prompt_embeds)
574
  j+=1
 
889
  clip_skip: Optional[int] = None,
890
  callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
891
  callback_on_step_end_tensor_inputs: List[str] = ["latents"],
892
+ avg_diff=None, # [0] -> text encoder 1,[1] ->text encoder 2
893
+ avg_diff_2nd=None, # text encoder 1,2
894
+ correlation_weight_factor=0.7,
895
  scale=2,
896
+ scale_2nd=2,
897
+ correlation_weight_factor = 0.7,
898
  init_latents: [torch.Tensor] = None,
899
  zs: [torch.Tensor] = None,
900
  **kwargs,
 
1101
  editing_prompt_embeds=editing_prompt_embeddings,
1102
  editing_pooled_prompt_embeds=editing_pooled_prompt_embeds,
1103
  avg_diff = avg_diff,
1104
+ avg_diff_2nd = avg_diff_2nd,
1105
  correlation_weight_factor = correlation_weight_factor,
1106
  scale=scale,
1107
+ scale_2nd=scale_2nd
1108
  )
1109
 
1110
  # 4. Prepare timesteps