from diffusers.utils import ( USE_PEFT_BACKEND, _get_model_file, delete_adapter_layers, is_accelerate_available, logging, set_adapter_layers, set_weights_and_activate_adapters, ) import torch import torch.nn.functional as F from torch.autograd.function import Function import torch.nn as nn from torch import einsum import os from collections import defaultdict from contextlib import nullcontext from typing import Callable, Dict, List, Optional, Union from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers from diffusers.models.embeddings import ImageProjection from diffusers.models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta import math from einops import rearrange from diffusers.image_processor import IPAdapterMaskProcessor xformers_available = False try: import xformers xformers_available = True except ImportError: pass EPSILON = 1e-6 exists = lambda val: val is not None default = lambda val, d: val if exists(val) else d logger = logging.get_logger(__name__) # pylint: disable=invalid-name def get_attention_scores(attn, query, key, attention_mask=None): if attn.upcast_attention: query = query.float() key = key.float() if attention_mask is None: baddbmm_input = torch.empty( query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device, ) beta = 0 else: baddbmm_input = attention_mask beta = 1 attention_scores = torch.baddbmm( baddbmm_input, query, key.transpose(-1, -2), beta=beta, alpha=attn.scale, ) del baddbmm_input if attn.upcast_softmax: attention_scores = attention_scores.float() return attention_scores.to(query.dtype) # Get attention_score with this: def scaled_dot_product_attention_regionstate(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None,weight_func =None, region_state = None, sigma = None) -> torch.Tensor: # Efficient implementation equivalent to the following: L, S = query.size(-2), key.size(-2) scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale attn_bias = torch.zeros(L, S, dtype=query.dtype,device = query.device) if is_causal: assert attn_mask is None temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) attn_bias.to(query.dtype) if attn_mask is not None: if attn_mask.dtype == torch.bool: attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf")) else: attn_bias += attn_mask attn_weight = query @ key.transpose(-2, -1) * scale_factor attn_weight += attn_bias batch_size, num_heads, sequence_length, embed_dim = attn_weight.shape attn_weight = attn_weight.reshape((-1,sequence_length,embed_dim)) cross_attention_weight = weight_func(region_state, sigma, attn_weight) repeat_time = attn_weight.shape[0]//cross_attention_weight.shape[0] attn_weight += torch.repeat_interleave( cross_attention_weight, repeats=repeat_time, dim=0 ) attn_weight = attn_weight.reshape((-1,num_heads,sequence_length,embed_dim)) attn_weight = torch.softmax(attn_weight, dim=-1) attn_weight = torch.dropout(attn_weight, dropout_p, train=True) return attn_weight @ value class FlashAttentionFunction(Function): @staticmethod @torch.no_grad() def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size): """Algorithm 2 in the paper""" device = q.device max_neg_value = -torch.finfo(q.dtype).max qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) o = torch.zeros_like(q) all_row_sums = torch.zeros((*q.shape[:-1], 1), device=device) all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, device=device) scale = q.shape[-1] ** -0.5 if not exists(mask): mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size) else: mask = rearrange(mask, "b n -> b 1 1 n") mask = mask.split(q_bucket_size, dim=-1) row_splits = zip( q.split(q_bucket_size, dim=-2), o.split(q_bucket_size, dim=-2), mask, all_row_sums.split(q_bucket_size, dim=-2), all_row_maxes.split(q_bucket_size, dim=-2), ) for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits): q_start_index = ind * q_bucket_size - qk_len_diff col_splits = zip( k.split(k_bucket_size, dim=-2), v.split(k_bucket_size, dim=-2), ) for k_ind, (kc, vc) in enumerate(col_splits): k_start_index = k_ind * k_bucket_size attn_weights = einsum("... i d, ... j d -> ... i j", qc, kc) * scale if exists(row_mask): attn_weights.masked_fill_(~row_mask, max_neg_value) if causal and q_start_index < (k_start_index + k_bucket_size - 1): causal_mask = torch.ones( (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device ).triu(q_start_index - k_start_index + 1) attn_weights.masked_fill_(causal_mask, max_neg_value) block_row_maxes = attn_weights.amax(dim=-1, keepdims=True) attn_weights -= block_row_maxes exp_weights = torch.exp(attn_weights) if exists(row_mask): exp_weights.masked_fill_(~row_mask, 0.0) block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp( min=EPSILON ) new_row_maxes = torch.maximum(block_row_maxes, row_maxes) exp_values = einsum("... i j, ... j d -> ... i d", exp_weights, vc) exp_row_max_diff = torch.exp(row_maxes - new_row_maxes) exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes) new_row_sums = ( exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums ) oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_( (exp_block_row_max_diff / new_row_sums) * exp_values ) row_maxes.copy_(new_row_maxes) row_sums.copy_(new_row_sums) lse = all_row_sums.log() + all_row_maxes ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size) ctx.save_for_backward(q, k, v, o, lse) return o @staticmethod @torch.no_grad() def backward(ctx, do): """Algorithm 4 in the paper""" causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args q, k, v, o, lse = ctx.saved_tensors device = q.device max_neg_value = -torch.finfo(q.dtype).max qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) dq = torch.zeros_like(q) dk = torch.zeros_like(k) dv = torch.zeros_like(v) row_splits = zip( q.split(q_bucket_size, dim=-2), o.split(q_bucket_size, dim=-2), do.split(q_bucket_size, dim=-2), mask, lse.split(q_bucket_size, dim=-2), dq.split(q_bucket_size, dim=-2), ) for ind, (qc, oc, doc, row_mask, lsec, dqc) in enumerate(row_splits): q_start_index = ind * q_bucket_size - qk_len_diff col_splits = zip( k.split(k_bucket_size, dim=-2), v.split(k_bucket_size, dim=-2), dk.split(k_bucket_size, dim=-2), dv.split(k_bucket_size, dim=-2), ) for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): k_start_index = k_ind * k_bucket_size attn_weights = einsum("... i d, ... j d -> ... i j", qc, kc) * scale if causal and q_start_index < (k_start_index + k_bucket_size - 1): causal_mask = torch.ones( (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device ).triu(q_start_index - k_start_index + 1) attn_weights.masked_fill_(causal_mask, max_neg_value) p = torch.exp(attn_weights - lsec) if exists(row_mask): p.masked_fill_(~row_mask, 0.0) dv_chunk = einsum("... i j, ... i d -> ... j d", p, doc) dp = einsum("... i d, ... j d -> ... i j", doc, vc) D = (doc * oc).sum(dim=-1, keepdims=True) ds = p * scale * (dp - D) dq_chunk = einsum("... i j, ... j d -> ... i d", ds, kc) dk_chunk = einsum("... i j, ... i d -> ... j d", ds, qc) dqc.add_(dq_chunk) dkc.add_(dk_chunk) dvc.add_(dv_chunk) return dq, dk, dv, None, None, None, None class AttnProcessor(nn.Module): def __call__( self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb: Optional[torch.Tensor] = None, region_prompt = None, ip_adapter_masks = None, *args, **kwargs, ): if len(args) > 0 or kwargs.get("scale", None) is not None: deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecate("scale", "1.0.0", deprecation_message) residual = hidden_states #_,img_sequence_length,_ = hidden_states.shape img_sequence_length = hidden_states.shape[1] if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) is_xattn = False if encoder_hidden_states is not None and region_prompt is not None: is_xattn = True region_state = region_prompt["region_state"] weight_func = region_prompt["weight_func"] sigma = region_prompt["sigma"] batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length,batch_size) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) query = attn.head_to_batch_dim(query) key = attn.head_to_batch_dim(key) value = attn.head_to_batch_dim(value) if is_xattn and isinstance(region_state, dict): # use torch.baddbmm method (slow) attention_scores = get_attention_scores(attn, query, key, attention_mask) cross_attention_weight = weight_func(region_state[img_sequence_length].to(query.device), sigma, attention_scores) attention_scores += torch.repeat_interleave( cross_attention_weight, repeats=attention_scores.shape[0] // cross_attention_weight.shape[0], dim=0 ) # calc probs attention_probs = attention_scores.softmax(dim=-1) attention_probs = attention_probs.to(query.dtype) hidden_states = torch.bmm(attention_probs, value) elif xformers_available: hidden_states = xformers.ops.memory_efficient_attention( query.contiguous(), key.contiguous(), value.contiguous(), attn_bias=attention_mask, ) hidden_states = hidden_states.to(query.dtype) else: '''q_bucket_size = 512 k_bucket_size = 1024 # use flash-attention hidden_states = FlashAttentionFunction.apply( query.contiguous(), key.contiguous(), value.contiguous(), attention_mask, False, q_bucket_size, k_bucket_size, )''' attention_probs = attn.get_attention_scores(query, key, attention_mask) hidden_states = torch.bmm(attention_probs, value) hidden_states = hidden_states.to(query.dtype) hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor return hidden_states class IPAdapterAttnProcessor(nn.Module): r""" Attention processor for Multiple IP-Adapters. Args: hidden_size (`int`): The hidden size of the attention layer. cross_attention_dim (`int`): The number of channels in the `encoder_hidden_states`. num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`): The context length of the image features. scale (`float` or List[`float`], defaults to 1.0): the weight scale of image prompt. """ def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0): super().__init__() self.hidden_size = hidden_size self.cross_attention_dim = cross_attention_dim if not isinstance(num_tokens, (tuple, list)): num_tokens = [num_tokens] self.num_tokens = num_tokens if not isinstance(scale, list): scale = [scale] * len(num_tokens) if len(scale) != len(num_tokens): raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.") self.scale = scale self.to_k_ip = nn.ModuleList( [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))] ) self.to_v_ip = nn.ModuleList( [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))] ) def __call__( self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None, scale=1.0, region_prompt = None, ip_adapter_masks = None, ): #_,img_sequence_length,_ = hidden_states.shape img_sequence_length= hidden_states.shape[1] residual = hidden_states is_xattn = False if encoder_hidden_states is not None and region_prompt is not None: is_xattn = True region_state = region_prompt["region_state"] weight_func = region_prompt["weight_func"] sigma = region_prompt["sigma"] # separate ip_hidden_states from encoder_hidden_states if encoder_hidden_states is not None: if isinstance(encoder_hidden_states, tuple): encoder_hidden_states, ip_hidden_states = encoder_hidden_states else: deprecation_message = ( "You have passed a tensor as `encoder_hidden_states`. This is deprecated and will be removed in a future release." " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to suppress this warning." ) deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False) end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0] encoder_hidden_states, ip_hidden_states = ( encoder_hidden_states[:, :end_pos, :], [encoder_hidden_states[:, end_pos:, :]], ) if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) query = attn.head_to_batch_dim(query) key = attn.head_to_batch_dim(key) value = attn.head_to_batch_dim(value) if is_xattn and isinstance(region_state, dict): # use torch.baddbmm method (slow) attention_scores = get_attention_scores(attn, query, key, attention_mask) cross_attention_weight = weight_func(region_state[img_sequence_length].to(query.device), sigma, attention_scores) attention_scores += torch.repeat_interleave( cross_attention_weight, repeats=attention_scores.shape[0] // cross_attention_weight.shape[0], dim=0 ) # calc probs attention_probs = attention_scores.softmax(dim=-1) attention_probs = attention_probs.to(query.dtype) hidden_states = torch.bmm(attention_probs, value) elif xformers_available: hidden_states = xformers.ops.memory_efficient_attention( query.contiguous(), key.contiguous(), value.contiguous(), attn_bias=attention_mask, ) hidden_states = hidden_states.to(query.dtype) else: '''q_bucket_size = 512 k_bucket_size = 1024 # use flash-attention hidden_states = FlashAttentionFunction.apply( query.contiguous(), key.contiguous(), value.contiguous(), attention_mask, False, q_bucket_size, k_bucket_size, )''' attention_probs = attn.get_attention_scores(query, key, attention_mask) hidden_states = torch.bmm(attention_probs, value) hidden_states = hidden_states.to(query.dtype) hidden_states = attn.batch_to_head_dim(hidden_states) '''# for ip-adapter for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip( ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip ): ip_key = to_k_ip(current_ip_hidden_states) ip_value = to_v_ip(current_ip_hidden_states) ip_key = attn.head_to_batch_dim(ip_key) ip_value = attn.head_to_batch_dim(ip_value) if xformers_available: current_ip_hidden_states = xformers.ops.memory_efficient_attention( query.contiguous(), ip_key.contiguous(), ip_value.contiguous(), attn_bias=None, ) current_ip_hidden_states = current_ip_hidden_states.to(query.dtype) else: ip_attention_probs = attn.get_attention_scores(query, ip_key, None) current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) current_ip_hidden_states = current_ip_hidden_states.to(query.dtype) current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states) hidden_states = hidden_states + scale * current_ip_hidden_states''' #control region apply ip-adapter if ip_adapter_masks is not None: if not isinstance(ip_adapter_masks, List): # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width] ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1)) if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)): raise ValueError( f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match " f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states " f"({len(ip_hidden_states)})" ) else: for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)): if not isinstance(mask, torch.Tensor) or mask.ndim != 4: raise ValueError( "Each element of the ip_adapter_masks array should be a tensor with shape " "[1, num_images_for_ip_adapter, height, width]." " Please use `IPAdapterMaskProcessor` to preprocess your mask" ) if mask.shape[1] != ip_state.shape[1]: raise ValueError( f"Number of masks ({mask.shape[1]}) does not match " f"number of ip images ({ip_state.shape[1]}) at index {index}" ) if isinstance(scale, list) and not len(scale) == mask.shape[1]: raise ValueError( f"Number of masks ({mask.shape[1]}) does not match " f"number of scales ({len(scale)}) at index {index}" ) else: ip_adapter_masks = [None] * len(self.scale) # for ip-adapter for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip( ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks ): skip = False if isinstance(scale, list): if all(s == 0 for s in scale): skip = True elif scale == 0: skip = True if not skip: if mask is not None: if not isinstance(scale, list): scale = [scale] * mask.shape[1] current_num_images = mask.shape[1] for i in range(current_num_images): ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :]) ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :]) ip_key = attn.head_to_batch_dim(ip_key) ip_value = attn.head_to_batch_dim(ip_value) ip_attention_probs = attn.get_attention_scores(query, ip_key, None) _current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) _current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states) mask_downsample = IPAdapterMaskProcessor.downsample( mask[:, i, :, :], batch_size, _current_ip_hidden_states.shape[1], _current_ip_hidden_states.shape[2], ) mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device) hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample) else: ip_key = to_k_ip(current_ip_hidden_states) ip_value = to_v_ip(current_ip_hidden_states) ip_key = attn.head_to_batch_dim(ip_key) ip_value = attn.head_to_batch_dim(ip_value) ip_attention_probs = attn.get_attention_scores(query, ip_key, None) current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states) hidden_states = hidden_states + scale * current_ip_hidden_states # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor return hidden_states class AttnProcessor2_0: r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). """ def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") def __call__( self, attn, hidden_states: torch.Tensor, encoder_hidden_states = None, attention_mask: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None, region_prompt = None, ip_adapter_masks = None, *args, **kwargs, ) -> torch.Tensor: if len(args) > 0 or kwargs.get("scale", None) is not None: deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecate("scale", "1.0.0", deprecation_message) residual = hidden_states #_,img_sequence_length,_ = hidden_states.shape img_sequence_length= hidden_states.shape[1] if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) is_xattn = False if encoder_hidden_states is not None and region_prompt is not None: is_xattn = True region_state = region_prompt["region_state"] weight_func = region_prompt["weight_func"] sigma = region_prompt["sigma"] batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) if attention_mask is not None: attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) # scaled_dot_product_attention expects attention_mask shape to be # (batch, heads, source_length, target_length) attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 if is_xattn and isinstance(region_state, dict): #w = attn.head_to_batch_dim(w,out_dim = 4).transpose(1, 2) hidden_states = scaled_dot_product_attention_regionstate(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False,weight_func = weight_func,region_state=region_state[img_sequence_length].to(query.device),sigma = sigma) else: hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor return hidden_states class IPAdapterAttnProcessor2_0(torch.nn.Module): r""" Attention processor for IP-Adapter for PyTorch 2.0. Args: hidden_size (`int`): The hidden size of the attention layer. cross_attention_dim (`int`): The number of channels in the `encoder_hidden_states`. num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`): The context length of the image features. scale (`float` or `List[float]`, defaults to 1.0): the weight scale of image prompt. """ def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0): super().__init__() if not hasattr(F, "scaled_dot_product_attention"): raise ImportError( f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." ) self.hidden_size = hidden_size self.cross_attention_dim = cross_attention_dim if not isinstance(num_tokens, (tuple, list)): num_tokens = [num_tokens] self.num_tokens = num_tokens if not isinstance(scale, list): scale = [scale] * len(num_tokens) if len(scale) != len(num_tokens): raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.") self.scale = scale self.to_k_ip = nn.ModuleList( [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))] ) self.to_v_ip = nn.ModuleList( [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))] ) def __call__( self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None, scale=1.0, region_prompt = None, ip_adapter_masks = None, ): residual = hidden_states #_,img_sequence_length,_ = hidden_states.shape img_sequence_length= hidden_states.shape[1] is_xattn = False if encoder_hidden_states is not None and region_prompt is not None: is_xattn = True region_state = region_prompt["region_state"] weight_func = region_prompt["weight_func"] sigma = region_prompt["sigma"] # separate ip_hidden_states from encoder_hidden_states if encoder_hidden_states is not None: if isinstance(encoder_hidden_states, tuple): encoder_hidden_states, ip_hidden_states = encoder_hidden_states else: deprecation_message = ( "You have passed a tensor as `encoder_hidden_states`. This is deprecated and will be removed in a future release." " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to suppress this warning." ) deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False) end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0] encoder_hidden_states, ip_hidden_states = ( encoder_hidden_states[:, :end_pos, :], [encoder_hidden_states[:, end_pos:, :]], ) if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) if attention_mask is not None: attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) # scaled_dot_product_attention expects attention_mask shape to be # (batch, heads, source_length, target_length) attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 if is_xattn and isinstance(region_state, dict): #w = attn.head_to_batch_dim(w,out_dim = 4).transpose(1, 2) hidden_states = scaled_dot_product_attention_regionstate(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False,weight_func = weight_func,region_state=region_state[img_sequence_length].to(query.device),sigma = sigma) else: hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) ''''# for ip-adapter for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip( ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip ): ip_key = to_k_ip(current_ip_hidden_states) ip_value = to_v_ip(current_ip_hidden_states) ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 current_ip_hidden_states = F.scaled_dot_product_attention( query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False ) current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape( batch_size, -1, attn.heads * head_dim ) current_ip_hidden_states = current_ip_hidden_states.to(query.dtype) hidden_states = hidden_states + scale * current_ip_hidden_states''' if ip_adapter_masks is not None: if not isinstance(ip_adapter_masks, List): # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width] ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1)) if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)): raise ValueError( f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match " f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states " f"({len(ip_hidden_states)})" ) else: for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)): if not isinstance(mask, torch.Tensor) or mask.ndim != 4: raise ValueError( "Each element of the ip_adapter_masks array should be a tensor with shape " "[1, num_images_for_ip_adapter, height, width]." " Please use `IPAdapterMaskProcessor` to preprocess your mask" ) if mask.shape[1] != ip_state.shape[1]: raise ValueError( f"Number of masks ({mask.shape[1]}) does not match " f"number of ip images ({ip_state.shape[1]}) at index {index}" ) if isinstance(scale, list) and not len(scale) == mask.shape[1]: raise ValueError( f"Number of masks ({mask.shape[1]}) does not match " f"number of scales ({len(scale)}) at index {index}" ) else: ip_adapter_masks = [None] * len(self.scale) # for ip-adapter for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip( ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks ): skip = False if isinstance(scale, list): if all(s == 0 for s in scale): skip = True elif scale == 0: skip = True if not skip: if mask is not None: if not isinstance(scale, list): scale = [scale] * mask.shape[1] current_num_images = mask.shape[1] for i in range(current_num_images): ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :]) ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :]) ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 _current_ip_hidden_states = F.scaled_dot_product_attention( query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False ) _current_ip_hidden_states = _current_ip_hidden_states.transpose(1, 2).reshape( batch_size, -1, attn.heads * head_dim ) _current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype) mask_downsample = IPAdapterMaskProcessor.downsample( mask[:, i, :, :], batch_size, _current_ip_hidden_states.shape[1], _current_ip_hidden_states.shape[2], ) mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device) hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample) else: ip_key = to_k_ip(current_ip_hidden_states) ip_value = to_v_ip(current_ip_hidden_states) ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 current_ip_hidden_states = F.scaled_dot_product_attention( query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False ) current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape( batch_size, -1, attn.heads * head_dim ) current_ip_hidden_states = current_ip_hidden_states.to(query.dtype) hidden_states = hidden_states + scale * current_ip_hidden_states # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor return hidden_states