# ************************************************************************* # Copyright (2023) Bytedance Inc. # # Copyright (2023) DragDiffusion Authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ************************************************************************* import torch import numpy as np import torch.nn.functional as F from tqdm import tqdm from PIL import Image from typing import Any, Dict, List, Optional, Tuple, Union from diffusers import StableDiffusionPipeline # override unet forward # The only difference from diffusers: # return intermediate UNet features of all UpSample blocks def override_forward(self): def forward( sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, class_labels: Optional[torch.Tensor] = None, timestep_cond: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, mid_block_additional_residual: Optional[torch.Tensor] = None, down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None, encoder_attention_mask: Optional[torch.Tensor] = None, return_intermediates: bool = False, ): # By default samples have to be AT least a multiple of the overall upsampling factor. # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). # However, the upsampling interpolation output size can be forced to fit any upsampling size # on the fly if necessary. default_overall_up_factor = 2**self.num_upsamplers # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` forward_upsample_size = False upsample_size = None for dim in sample.shape[-2:]: if dim % default_overall_up_factor != 0: # Forward upsample size to force interpolation output size. forward_upsample_size = True break # ensure attention_mask is a bias, and give it a singleton query_tokens dimension # expects mask of shape: # [batch, key_tokens] # adds singleton query_tokens dimension: # [batch, 1, key_tokens] # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) if attention_mask is not None: # assume that mask is expressed as: # (1 = keep, 0 = discard) # convert mask into a bias that can be added to attention scores: # (keep = +0, discard = -10000.0) attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 attention_mask = attention_mask.unsqueeze(1) # convert encoder_attention_mask to a bias the same way we do for attention_mask if encoder_attention_mask is not None: encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 encoder_attention_mask = encoder_attention_mask.unsqueeze(1) # 0. center input if necessary if self.config.center_input_sample: sample = 2 * sample - 1.0 # 1. time timesteps = timestep if not torch.is_tensor(timesteps): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" if isinstance(timestep, float): dtype = torch.float32 if is_mps else torch.float64 else: dtype = torch.int32 if is_mps else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timesteps = timesteps.expand(sample.shape[0]) t_emb = self.time_proj(timesteps) # `Timesteps` does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. t_emb = t_emb.to(dtype=sample.dtype) emb = self.time_embedding(t_emb, timestep_cond) aug_emb = None if self.class_embedding is not None: if class_labels is None: raise ValueError("class_labels should be provided when num_class_embeds > 0") if self.config.class_embed_type == "timestep": class_labels = self.time_proj(class_labels) # `Timesteps` does not contain any weights and will always return f32 tensors # there might be better ways to encapsulate this. class_labels = class_labels.to(dtype=sample.dtype) class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) if self.config.class_embeddings_concat: emb = torch.cat([emb, class_emb], dim=-1) else: emb = emb + class_emb if self.config.addition_embed_type == "text": aug_emb = self.add_embedding(encoder_hidden_states) elif self.config.addition_embed_type == "text_image": # Kandinsky 2.1 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" ) image_embs = added_cond_kwargs.get("image_embeds") text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) aug_emb = self.add_embedding(text_embs, image_embs) elif self.config.addition_embed_type == "text_time": # SDXL - style if "text_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" ) text_embeds = added_cond_kwargs.get("text_embeds") if "time_ids" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" ) time_ids = added_cond_kwargs.get("time_ids") time_embeds = self.add_time_proj(time_ids.flatten()) time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) add_embeds = add_embeds.to(emb.dtype) aug_emb = self.add_embedding(add_embeds) elif self.config.addition_embed_type == "image": # Kandinsky 2.2 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" ) image_embs = added_cond_kwargs.get("image_embeds") aug_emb = self.add_embedding(image_embs) elif self.config.addition_embed_type == "image_hint": # Kandinsky 2.2 - style if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" ) image_embs = added_cond_kwargs.get("image_embeds") hint = added_cond_kwargs.get("hint") aug_emb, hint = self.add_embedding(image_embs, hint) sample = torch.cat([sample, hint], dim=1) emb = emb + aug_emb if aug_emb is not None else emb if self.time_embed_act is not None: emb = self.time_embed_act(emb) if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": # Kadinsky 2.1 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" ) image_embeds = added_cond_kwargs.get("image_embeds") encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": # Kandinsky 2.2 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" ) image_embeds = added_cond_kwargs.get("image_embeds") encoder_hidden_states = self.encoder_hid_proj(image_embeds) elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" ) image_embeds = added_cond_kwargs.get("image_embeds") image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype) encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1) # 2. pre-process sample = self.conv_in(sample) # 2.5 GLIGEN position net if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None: cross_attention_kwargs = cross_attention_kwargs.copy() gligen_args = cross_attention_kwargs.pop("gligen") cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} # 3. down lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 # if USE_PEFT_BACKEND: # # weight the lora layers by setting `lora_scale` for each PEFT layer # scale_lora_layers(self, lora_scale) is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets is_adapter = down_intrablock_additional_residuals is not None # maintain backward compatibility for legacy usage, where # T2I-Adapter and ControlNet both use down_block_additional_residuals arg # but can only use one or the other if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None: deprecate( "T2I should not use down_block_additional_residuals", "1.3.0", "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \ and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \ for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ", standard_warn=False, ) down_intrablock_additional_residuals = down_block_additional_residuals is_adapter = True down_block_res_samples = (sample,) for downsample_block in self.down_blocks: if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: # For t2i-adapter CrossAttnDownBlock2D additional_residuals = {} if is_adapter and len(down_intrablock_additional_residuals) > 0: additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) sample, res_samples = downsample_block( hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, **additional_residuals, ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale) if is_adapter and len(down_intrablock_additional_residuals) > 0: sample += down_intrablock_additional_residuals.pop(0) down_block_res_samples += res_samples if is_controlnet: new_down_block_res_samples = () for down_block_res_sample, down_block_additional_residual in zip( down_block_res_samples, down_block_additional_residuals ): down_block_res_sample = down_block_res_sample + down_block_additional_residual new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) down_block_res_samples = new_down_block_res_samples # 4. mid if self.mid_block is not None: if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: sample = self.mid_block( sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, ) else: sample = self.mid_block(sample, emb) # To support T2I-Adapter-XL if ( is_adapter and len(down_intrablock_additional_residuals) > 0 and sample.shape == down_intrablock_additional_residuals[0].shape ): sample += down_intrablock_additional_residuals.pop(0) if is_controlnet: sample = sample + mid_block_additional_residual all_intermediate_features = [sample] # 5. up for i, upsample_block in enumerate(self.up_blocks): is_final_block = i == len(self.up_blocks) - 1 res_samples = down_block_res_samples[-len(upsample_block.resnets) :] down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # if we have not reached the final block and need to forward the # upsample size, we do it here if not is_final_block and forward_upsample_size: upsample_size = down_block_res_samples[-1].shape[2:] if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, upsample_size=upsample_size, attention_mask=attention_mask, encoder_attention_mask=encoder_attention_mask, ) else: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, scale=lora_scale, ) all_intermediate_features.append(sample) # 6. post-process if self.conv_norm_out: sample = self.conv_norm_out(sample) sample = self.conv_act(sample) sample = self.conv_out(sample) # if USE_PEFT_BACKEND: # # remove `lora_scale` from each PEFT layer # unscale_lora_layers(self, lora_scale) # only difference from diffusers, return intermediate results if return_intermediates: return sample, all_intermediate_features else: return sample return forward class DragPipeline(StableDiffusionPipeline): # must call this function when initialize def modify_unet_forward(self): self.unet.forward = override_forward(self.unet) def inv_step( self, model_output: torch.FloatTensor, timestep: int, x: torch.FloatTensor, eta=0., verbose=False ): """ Inverse sampling for DDIM Inversion """ if verbose: print("timestep: ", timestep) next_step = timestep timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999) alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod alpha_prod_t_next = self.scheduler.alphas_cumprod[next_step] beta_prod_t = 1 - alpha_prod_t pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5 pred_dir = (1 - alpha_prod_t_next)**0.5 * model_output x_next = alpha_prod_t_next**0.5 * pred_x0 + pred_dir return x_next, pred_x0 def step( self, model_output: torch.FloatTensor, timestep: int, x: torch.FloatTensor, ): """ predict the sample of the next step in the denoise process. """ prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps alpha_prod_t = self.scheduler.alphas_cumprod[timestep] alpha_prod_t_prev = self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep > 0 else self.scheduler.final_alpha_cumprod beta_prod_t = 1 - alpha_prod_t pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5 pred_dir = (1 - alpha_prod_t_prev)**0.5 * model_output x_prev = alpha_prod_t_prev**0.5 * pred_x0 + pred_dir return x_prev, pred_x0 @torch.no_grad() def image2latent(self, image): DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") if type(image) is Image: image = np.array(image) image = torch.from_numpy(image).float() / 127.5 - 1 image = image.permute(2, 0, 1).unsqueeze(0).to(DEVICE) # input image density range [-1, 1] latents = self.vae.encode(image)['latent_dist'].mean latents = latents * 0.18215 return latents @torch.no_grad() def latent2image(self, latents, return_type='np'): latents = 1 / 0.18215 * latents.detach() image = self.vae.decode(latents)['sample'] if return_type == 'np': image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy()[0] image = (image * 255).astype(np.uint8) elif return_type == "pt": image = (image / 2 + 0.5).clamp(0, 1) return image def latent2image_grad(self, latents): latents = 1 / 0.18215 * latents image = self.vae.decode(latents)['sample'] return image # range [-1, 1] @torch.no_grad() def get_text_embeddings(self, prompt): DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") # text embeddings text_input = self.tokenizer( prompt, padding="max_length", max_length=77, return_tensors="pt" ) text_embeddings = self.text_encoder(text_input.input_ids.to(DEVICE))[0] return text_embeddings # get all intermediate features and then do bilinear interpolation # return features in the layer_idx list def forward_unet_features( self, z, t, encoder_hidden_states, layer_idx=[0], interp_res_h=256, interp_res_w=256): unet_output, all_intermediate_features = self.unet( z, t, encoder_hidden_states=encoder_hidden_states, return_intermediates=True ) all_return_features = [] for idx in layer_idx: feat = all_intermediate_features[idx] feat = F.interpolate(feat, (interp_res_h, interp_res_w), mode='bilinear') all_return_features.append(feat) return_features = torch.cat(all_return_features, dim=1) return unet_output, return_features @torch.no_grad() def __call__( self, prompt, encoder_hidden_states=None, batch_size=1, height=512, width=512, num_inference_steps=50, num_actual_inference_steps=None, guidance_scale=7.5, latents=None, neg_prompt=None, return_intermediates=False, **kwds): DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") if encoder_hidden_states is None: if isinstance(prompt, list): batch_size = len(prompt) elif isinstance(prompt, str): if batch_size > 1: prompt = [prompt] * batch_size # text embeddings encoder_hidden_states = self.get_text_embeddings(prompt) # define initial latents if not predefined if latents is None: latents_shape = (batch_size, self.unet.in_channels, height//8, width//8) latents = torch.randn(latents_shape, device=DEVICE, dtype=self.vae.dtype) # unconditional embedding for classifier free guidance if guidance_scale > 1.: if neg_prompt: uc_text = neg_prompt else: uc_text = "" unconditional_embeddings = self.get_text_embeddings([uc_text]*batch_size) encoder_hidden_states = torch.cat([unconditional_embeddings, encoder_hidden_states], dim=0) print("latents shape: ", latents.shape) # iterative sampling self.scheduler.set_timesteps(num_inference_steps) # print("Valid timesteps: ", reversed(self.scheduler.timesteps)) if return_intermediates: latents_list = [latents] for i, t in enumerate(tqdm(self.scheduler.timesteps, desc="DDIM Sampler")): if num_actual_inference_steps is not None and i < num_inference_steps - num_actual_inference_steps: continue if guidance_scale > 1.: model_inputs = torch.cat([latents] * 2) else: model_inputs = latents # predict the noise noise_pred = self.unet( model_inputs, t, encoder_hidden_states=encoder_hidden_states, ) if guidance_scale > 1.0: noise_pred_uncon, noise_pred_con = noise_pred.chunk(2, dim=0) noise_pred = noise_pred_uncon + guidance_scale * (noise_pred_con - noise_pred_uncon) # compute the previous noise sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] if return_intermediates: latents_list.append(latents) image = self.latent2image(latents, return_type="pt") if return_intermediates: return image, latents_list return image @torch.no_grad() def invert( self, image: torch.Tensor, prompt, encoder_hidden_states=None, num_inference_steps=50, num_actual_inference_steps=None, guidance_scale=7.5, eta=0.0, return_intermediates=False, **kwds): """ invert a real image into noise map with determinisc DDIM inversion """ DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") batch_size = image.shape[0] if encoder_hidden_states is None: if isinstance(prompt, list): if batch_size == 1: image = image.expand(len(prompt), -1, -1, -1) elif isinstance(prompt, str): if batch_size > 1: prompt = [prompt] * batch_size encoder_hidden_states = self.get_text_embeddings(prompt) # define initial latents latents = self.image2latent(image) # unconditional embedding for classifier free guidance if guidance_scale > 1.: max_length = text_input.input_ids.shape[-1] unconditional_input = self.tokenizer( [""] * batch_size, padding="max_length", max_length=77, return_tensors="pt" ) unconditional_embeddings = self.text_encoder(unconditional_input.input_ids.to(DEVICE))[0] encoder_hidden_states = torch.cat([unconditional_embeddings, encoder_hidden_states], dim=0) print("latents shape: ", latents.shape) # interative sampling self.scheduler.set_timesteps(num_inference_steps) print("Valid timesteps: ", reversed(self.scheduler.timesteps)) # print("attributes: ", self.scheduler.__dict__) latents_list = [latents] pred_x0_list = [latents] for i, t in enumerate(tqdm(reversed(self.scheduler.timesteps), desc="DDIM Inversion")): if num_actual_inference_steps is not None and i >= num_actual_inference_steps: continue if guidance_scale > 1.: model_inputs = torch.cat([latents] * 2) else: model_inputs = latents # predict the noise noise_pred = self.unet(model_inputs, t, encoder_hidden_states=encoder_hidden_states, ) if guidance_scale > 1.: noise_pred_uncon, noise_pred_con = noise_pred.chunk(2, dim=0) noise_pred = noise_pred_uncon + guidance_scale * (noise_pred_con - noise_pred_uncon) # compute the previous noise sample x_t-1 -> x_t latents, pred_x0 = self.inv_step(noise_pred, t, latents) latents_list.append(latents) pred_x0_list.append(pred_x0) if return_intermediates: # return the intermediate laters during inversion # pred_x0_list = [self.latent2image(img, return_type="pt") for img in pred_x0_list] return latents, latents_list return latents