Linoy Tsaban committed · Commit 0052810 · 1 Parent(s): b7b2a49

Update pipeline_semantic_stable_diffusion_img2img_solver.py

merging Manuel's updates - removing edit_momentum and adjustments to attention store
pipeline_semantic_stable_diffusion_img2img_solver.py CHANGED
@@ -36,21 +36,19 @@ class AttentionStore():
 
     def __call__(self, attn, is_cross: bool, place_in_unet: str, editing_prompts, PnP=False):
         # attn.shape = batch_size * head_size, seq_len query, seq_len_key
-
-
-
-
-
-
-
-
-
-                place_in_unet)
+        if attn.shape[1] <= self.max_size:
+            bs = 1 + int(PnP) + editing_prompts
+            skip = 2 if PnP else 1  # skip PnP & unconditional
+            attn = torch.stack(attn.split(self.batch_size)).permute(1, 0, 2, 3)
+            source_batch_size = int(attn.shape[1] // bs)
+            self.forward(
+                attn[:, skip * source_batch_size:],
+                is_cross,
+                place_in_unet)
 
     def forward(self, attn, is_cross: bool, place_in_unet: str):
         key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
-
-        self.step_store[key].append(attn)
+        self.step_store[key].append(attn)
 
     def between_steps(self, store_step=True):
         if store_step:
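The reworked __call__ now gates on map size and strips the unconditional (and PnP) chunks before handing maps to forward. Below is a minimal sketch of that tensor bookkeeping with made-up dimensions (batch_size=1, 8 heads, 16x16 resolution, one editing prompt, no PnP); it mirrors the lines added above rather than the pipeline's actual runtime shapes:

import torch

# Toy dimensions, assumed for illustration: batch_size=1, 8 heads,
# 16x16 spatial resolution, 2 prompt chunks (unconditional + 1 edit).
batch_size, heads, res, chunks = 1, 8, 16, 2
attn = torch.randn(chunks * heads * batch_size, res * res, 77)

# Mirror of torch.stack(attn.split(self.batch_size)).permute(1, 0, 2, 3):
# dim 0 becomes the batch, dim 1 enumerates the chunk*head slices.
attn = torch.stack(attn.split(batch_size)).permute(1, 0, 2, 3)
print(attn.shape)    # torch.Size([1, 16, 256, 77])

bs = 1 + 0 + 1       # 1 + int(PnP) + editing_prompts
skip = 1             # no PnP, so only the unconditional chunk is skipped
source_batch_size = int(attn.shape[1] // bs)
stored = attn[:, skip * source_batch_size:]
print(stored.shape)  # torch.Size([1, 8, 256, 77]) -> passed to self.forward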
@@ -96,12 +94,13 @@ class AttentionStore():
         out = out.sum(1) / out.shape[1]
         return out
 
-    def __init__(self, average: bool, batch_size=1):
+    def __init__(self, average: bool, batch_size=1, max_resolution=16):
         self.step_store = self.get_empty_store()
         self.attention_store = []
         self.cur_step = 0
         self.average = average
         self.batch_size = batch_size
+        self.max_size = max_resolution ** 2
 
 
 class CrossAttnProcessor:
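The new max_resolution argument is the other half of that change: __init__ records max_size = max_resolution ** 2, and __call__ only stores maps whose query length (number of spatial tokens) fits under it. A quick sketch, assuming a 64x64 latent (a 512x512 image):

# With the default max_resolution=16, only the two coarsest UNet levels
# of a 64x64 latent pass the gate in __call__ (attn.shape[1] <= max_size):
max_size = 16 ** 2
for res in (64, 32, 16, 8):
    seq_len = res * res
    print(f"{res}x{res}: seq_len={seq_len}, stored={seq_len <= max_size}")
# 64x64: stored=False / 32x32: stored=False / 16x16: stored=True / 8x8: stored=True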
@@ -433,10 +432,10 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
 
     # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
     def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, latents):
-        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+        # shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
 
-        if latents.shape != shape:
-            raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+        # if latents.shape != shape:
+        #     raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
 
         latents = latents.to(device)
 
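With the shape check commented out, prepare_latents accepts latents whose spatial size disagrees with the requested height and width. A sketch of the condition the disabled guard used to enforce (vae_scale_factor is 8 for Stable Diffusion; the 512-vs-768 mismatch is invented for illustration):

import torch

# Hypothetical mismatch the old guard would have caught: latents from a
# 512x512 inversion passed to a 768x768 generation request.
vae_scale_factor = 8
height = width = 768
latents = torch.randn(1, 4, 64, 64)
shape = (1, 4, height // vae_scale_factor, width // vae_scale_factor)
print(latents.shape != shape)  # True -> the old code raised ValueError here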
@@ -456,7 +455,7 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
             else:
                 continue
 
-            if "attn2" in name:
+            if "attn2" in name and place_in_unet != 'mid':
                 attn_procs[name] = CrossAttnProcessor(
                     attention_store=attention_store,
                     place_in_unet=place_in_unet,
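The added place_in_unet != 'mid' condition keeps the mid block's cross-attention out of the attention store. A sketch of how the filter behaves against diffusers-style processor names (the three names are illustrative):

# Illustrative processor names following the diffusers attn_processors
# convention; place_in_unet is derived from the name prefix.
names = [
    "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor",
    "mid_block.attentions.0.transformer_blocks.0.attn2.processor",
    "up_blocks.1.attentions.0.transformer_blocks.0.attn1.processor",
]
for name in names:
    place_in_unet = name.split("_")[0]  # "down" / "mid" / "up"
    uses_store = "attn2" in name and place_in_unet != 'mid'
    print(f"{place_in_unet:>4}: storing processor = {uses_store}")
# down: True  (cross-attention, not in the mid block)
#  mid: False (excluded by the new condition)
#   up: False (attn1 is self-attention)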
@@ -488,12 +487,11 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
         editing_prompt_embeddings: Optional[torch.Tensor] = None,
         reverse_editing_direction: Optional[Union[bool, List[bool]]] = False,
         edit_guidance_scale: Optional[Union[float, List[float]]] = 5,
-        edit_warmup_steps: Optional[Union[int, List[int]]] =
+        edit_warmup_steps: Optional[Union[int, List[int]]] = 0,
         edit_cooldown_steps: Optional[Union[int, List[int]]] = None,
         edit_threshold: Optional[Union[float, List[float]]] = 0.9,
         user_mask: Optional[torch.FloatTensor] = None,
-
-        edit_mom_beta: Optional[float] = 0.4,
+
         edit_weights: Optional[List[float]] = None,
         sem_guidance: Optional[List[torch.Tensor]] = None,
         verbose=True,
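With the momentum arguments gone and edit_warmup_steps defaulting to 0, semantic guidance kicks in from the first denoising step unless a warmup is requested. A hypothetical call sketch; pipe and the editing_prompt keyword are assumed, and only parameters visible in this hunk are spelled out:

# Hypothetical invocation; `pipe` is an already-constructed
# SemanticStableDiffusionImg2ImgPipeline_DPMSolver, and the
# `editing_prompt` keyword is assumed alongside editing_prompt_embeddings.
out = pipe(
    editing_prompt=["smiling face"],
    edit_guidance_scale=5,
    edit_warmup_steps=0,      # new default: guidance starts immediately
    edit_cooldown_steps=None,
    edit_threshold=0.9,
)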
@@ -788,8 +786,6 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
         # 6. Prepare extra step kwargs.
         extra_step_kwargs = self.prepare_extra_step_kwargs(eta)
 
-        # Initialize edit_momentum to None
-        edit_momentum = None
 
         self.uncond_estimates = None
         self.text_estimates = None
@@ -833,12 +829,10 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
                        self.text_estimates = torch.zeros((len(timesteps), *noise_pred_text.shape))
                    self.text_estimates[i] = noise_pred_text.detach().cpu()
 
-
-                   edit_momentum = torch.zeros_like(noise_guidance)
+
 
                    if sem_guidance is not None and len(sem_guidance) > i:
                        edit_guidance = sem_guidance[i].to(self.device)
-                       edit_momentum = edit_mom_beta * edit_momentum + (1 - edit_mom_beta) * edit_guidance
                        noise_guidance = noise_guidance + edit_guidance
 
                    elif enable_edit_guidance:
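For context on what was dropped: the momentum tensor was an exponential moving average of the semantic guidance, initialized to zeros and blended with edit_mom_beta = 0.4. A standalone sketch of the removed update rule; after this commit, noise_guidance accumulates edit_guidance directly:

import torch

# The removed update: m_t = beta * m_{t-1} + (1 - beta) * g_t,
# initialized as torch.zeros_like(noise_guidance).
edit_mom_beta = 0.4
edit_momentum = torch.zeros(3)
for edit_guidance in (torch.ones(3), 2 * torch.ones(3)):
    edit_momentum = edit_mom_beta * edit_momentum + (1 - edit_mom_beta) * edit_guidance
print(edit_momentum)  # tensor([1.4400, 1.4400, 1.4400])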