from typing import Callable, Optional

import torch
from accelerate.logging import get_logger
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.cross_attention import CrossAttention
from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import (
    StableDiffusionSafetyChecker,
)
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from diffusers.utils.import_utils import is_xformers_available
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None

logger = get_logger(__name__)


def set_use_memory_efficient_attention_xformers(
    self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
):
    # Written as an unbound function so it can replace
    # `CrossAttention.set_use_memory_efficient_attention_xformers` and install the
    # Custom Diffusion attention processors defined below.
    if use_memory_efficient_attention_xformers:
        if self.added_kv_proj_dim is not None:
            # TODO(Anton, Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
            # which uses this type of cross attention ONLY because the attention mask of format
            # [0, ..., -10.000, ..., 0, ...,] is not supported
            raise NotImplementedError(
                "Memory efficient attention with `xformers` is currently not supported when"
                " `self.added_kv_proj_dim` is defined."
            )
        elif not is_xformers_available():
            raise ModuleNotFoundError(
                (
                    "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
                    " xformers"
                ),
                name="xformers",
            )
        elif not torch.cuda.is_available():
            raise ValueError(
                "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
                " only available for GPU"
            )
        else:
            try:
                # Make sure we can run the memory efficient attention
                _ = xformers.ops.memory_efficient_attention(
                    torch.randn((1, 2, 40), device="cuda"),
                    torch.randn((1, 2, 40), device="cuda"),
                    torch.randn((1, 2, 40), device="cuda"),
                )
            except Exception as e:
                raise e

        processor = CustomDiffusionXFormersAttnProcessor(attention_op=attention_op)
    else:
        processor = CustomDiffusionAttnProcessor()

    self.set_processor(processor)


class CustomDiffusionAttnProcessor:
    def __call__(
        self,
        attn: CrossAttention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
    ):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        query = attn.to_q(hidden_states)

        crossattn = False
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            crossattn = True
            if attn.cross_attention_norm:
                encoder_hidden_states = attn.norm_cross(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        if crossattn:
            # Stop gradients through the key/value of the first text token; the remaining
            # tokens keep gradients so only their key/value projections are updated.
            detach = torch.ones_like(key)
            detach[:, :1, :] = detach[:, :1, :] * 0.0
            key = detach * key + (1 - detach) * key.detach()
            value = detach * value + (1 - detach) * value.detach()

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states


class CustomDiffusionXFormersAttnProcessor:
    def __init__(self, attention_op: Optional[Callable] = None):
        self.attention_op = attention_op

    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        query = attn.to_q(hidden_states)

        crossattn = False
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            crossattn = True
            if attn.cross_attention_norm:
                encoder_hidden_states = attn.norm_cross(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        if crossattn:
            # Same trick as in CustomDiffusionAttnProcessor: detach the key/value of the
            # first text token so gradients only flow through the remaining tokens.
            detach = torch.ones_like(key)
            detach[:, :1, :] = detach[:, :1, :] * 0.0
            key = detach * key + (1 - detach) * key.detach()
            value = detach * value + (1 - detach) * value.detach()

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        hidden_states = xformers.ops.memory_efficient_attention(
            query, key, value, attn_bias=attention_mask, op=self.attention_op
        )
        hidden_states = hidden_states.to(query.dtype)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states


class CustomDiffusionPipeline(StableDiffusionPipeline):
    r"""
    Pipeline for the Custom Diffusion model.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.).

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]):
            Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents.
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
        feature_extractor ([`CLIPFeatureExtractor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
        modifier_token_id (`list`):
            Ids of the tokens related to the target concept that are modified when it is ablated.
    """
    _optional_components = ["safety_checker", "feature_extractor", "modifier_token_id"]

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: SchedulerMixin,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
        requires_safety_checker: bool = True,
        modifier_token_id: list = [],
    ):
        super().__init__(
            vae,
            text_encoder,
            tokenizer,
            unet,
            scheduler,
            safety_checker,
            feature_extractor,
            requires_safety_checker,
        )
        self.modifier_token_id = modifier_token_id

    def save_pretrained(self, save_path, parameter_group="cross-attn", all=False):
        if all:
            super().save_pretrained(save_path)
        else:
            if parameter_group not in ("cross-attn", "full-weight", "embedding"):
                raise ValueError(
                    "parameter_group argument only supports one of [cross-attn, full-weight, embedding]"
                )
            delta_dict = {"unet": {}}
            if parameter_group == "embedding":
                delta_dict["text_encoder"] = self.text_encoder.state_dict()
            for name, params in self.unet.named_parameters():
                if parameter_group == "cross-attn":
                    # Only the key/value projections of the cross-attention layers are saved.
                    if "attn2.to_k" in name or "attn2.to_v" in name:
                        delta_dict["unet"][name] = params.cpu().clone()
                elif parameter_group == "full-weight":
                    delta_dict["unet"][name] = params.cpu().clone()
            torch.save(delta_dict, save_path)

    def load_model(self, save_path):
        st = torch.load(save_path)
        logger.info(f"Loaded delta with keys: {list(st.keys())}")
        if "text_encoder" in st:
            self.text_encoder.load_state_dict(st["text_encoder"])
        for name, params in self.unet.named_parameters():
            if name in st["unet"]:
                params.data.copy_(st["unet"][name])
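

# Example usage (a minimal sketch, not part of the pipeline itself): the model id, delta
# checkpoint path, prompt, and output filename below are placeholders. It assumes a delta
# file produced by `CustomDiffusionPipeline.save_pretrained` (e.g. with
# parameter_group="cross-attn") and shows how the saved weights are loaded back into the
# pipeline before sampling.
if __name__ == "__main__":
    pipe = CustomDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16  # placeholder model id
    ).to("cuda")
    pipe.load_model("path/to/delta.bin")  # placeholder path to a saved delta
    image = pipe("a photo of a cat", num_inference_steps=50, guidance_scale=6.0).images[0]
    image.save("sample.png")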