from typing import Optional

import torch
from torch import FloatTensor, LongTensor, Size, Tensor
from torch import nn

from prior import generate_beta_tensor


class InterpolatedAttnProcessor(nn.Module):
    """
    Base class for the interpolated attention processors below.

    Holds the interpolation coefficients (`self.coef`): either a single point
    `t` (giving the sequence 0, t, 1 and size 3) or `size` points drawn via
    `generate_beta_tensor(size, alpha, beta)` with the endpoints pinned to 0 and 1.
    """

    def __init__(
        self,
        t: Optional[float] = None,
        size: int = 7,
        is_fused: bool = False,
        alpha: float = 1,
        beta: float = 1,
    ):
        super().__init__()
        if t is None:
            ts = generate_beta_tensor(size, alpha=alpha, beta=beta)
            ts[0], ts[-1] = 0, 1
        else:
            assert t > 0 and t < 1, "t must be between 0 and 1"
            ts = [0, t, 1]
            ts = torch.tensor(ts)
            size = 3

        self.size = size
        self.coef = ts
        self.is_fused = is_fused
        self.activated = True

    def deactivate(self):
        self.activated = False

    def activate(self, t):
        self.activated = True
        assert t > 0 and t < 1, "t must be between 0 and 1"
        ts = [0, t, 1]
        ts = torch.tensor(ts)
        self.coef = ts

    def load_end_point(self, key_begin, value_begin, key_end, value_end):
        self.key_begin = key_begin
        self.value_begin = value_begin
        self.key_end = key_end
        self.value_end = value_end


class ScaleControlIPAttnProcessor(InterpolatedAttnProcessor):
    r"""
    Personalized processor for controlling the impact of the image prompt via attention interpolation.
    """

    def __init__(
        self,
        t: Optional[float] = None,
        size: int = 7,
        is_fused: bool = False,
        alpha: float = 1,
        beta: float = 1,
        ip_attn: Optional[nn.Module] = None,
    ):
        """
        t: float, interpolation point between 0 and 1, if specified, size is set to 3
        """
        super().__init__(t=t, size=size, is_fused=is_fused, alpha=alpha, beta=beta)
        self.num_tokens = (
            ip_attn.num_tokens if hasattr(ip_attn, "num_tokens") else (16,)
        )
        self.scale = ip_attn.scale if hasattr(ip_attn, "scale") else None
        self.ip_attn = ip_attn

    def __call__(
        self,
        attn,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        residual = hidden_states

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
            ip_hidden_states = None
        else:
            if isinstance(encoder_hidden_states, tuple):
                encoder_hidden_states, ip_hidden_states = encoder_hidden_states
            else:
                end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
                encoder_hidden_states, ip_hidden_states = (
                    encoder_hidden_states[:, :end_pos, :],
                    [encoder_hidden_states[:, end_pos:, :]],
                )

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(
            attention_mask, sequence_length, batch_size
        )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )

        query = attn.to_q(hidden_states)
        query = attn.head_to_batch_dim(query)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        if not self.activated:
            key = attn.head_to_batch_dim(key)
            value = attn.head_to_batch_dim(value)

            attention_probs = attn.get_attention_scores(query, key, attention_mask)
            hidden_states = torch.bmm(attention_probs, value)
            hidden_states = attn.batch_to_head_dim(hidden_states)

            if ip_hidden_states is not None:
                key = self.ip_attn.to_k_ip[0](ip_hidden_states[0][6:9])
                value = self.ip_attn.to_v_ip[0](ip_hidden_states[0][6:9])

                key = attn.head_to_batch_dim(key)
                value = attn.head_to_batch_dim(value)
                ip_attention_probs = attn.get_attention_scores(
                    query, key, attention_mask
                )
                ip_hidden_states = torch.bmm(ip_attention_probs, value)
                ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)

                hidden_states = (
                    hidden_states
                    + self.coef.reshape(-1, 1, 1).to(key.device, key.dtype)
                    * ip_hidden_states
                )
        else:
            key_begin = key[0:1].expand(3, *key.shape[1:])
            key_end = key[-1:].expand(3, *key.shape[1:])
            value_begin = value[0:1].expand(3, *value.shape[1:])
            value_end = value[-1:].expand(3, *value.shape[1:])

            key_begin = attn.head_to_batch_dim(key_begin)
            value_begin = attn.head_to_batch_dim(value_begin)
            key_end = attn.head_to_batch_dim(key_end)
            value_end = attn.head_to_batch_dim(value_end)

            if self.is_fused:
                key = attn.head_to_batch_dim(key)
                value = attn.head_to_batch_dim(value)
                key_end = torch.cat([key, key_end], dim=-2)
                value_end = torch.cat([value, value_end], dim=-2)
                key_begin = torch.cat([key, key_begin], dim=-2)
                value_begin = torch.cat([value, value_begin], dim=-2)

            attention_probs_end = attn.get_attention_scores(
                query, key_end, attention_mask
            )
            hidden_states_end = torch.bmm(attention_probs_end, value_end)
            hidden_states_end = attn.batch_to_head_dim(hidden_states_end)

            attention_probs_begin = attn.get_attention_scores(
                query, key_begin, attention_mask
            )
            hidden_states_begin = torch.bmm(attention_probs_begin, value_begin)
            hidden_states_begin = attn.batch_to_head_dim(hidden_states_begin)

            # Apply outer interpolation on attention
            coef = self.coef.reshape(-1, 1, 1)
            coef = coef.to(key.device, key.dtype)
            hidden_states = (1 - coef) * hidden_states_begin + coef * hidden_states_end

            # for ip-adapter
            if ip_hidden_states is not None:
                key = self.ip_attn.to_k_ip[0](ip_hidden_states[0][6:9])
                value = self.ip_attn.to_v_ip[0](ip_hidden_states[0][6:9])

                key = attn.head_to_batch_dim(key)
                value = attn.head_to_batch_dim(value)

                ip_attention_probs = attn.get_attention_scores(
                    query, key, attention_mask
                )
                ip_hidden_states = torch.bmm(ip_attention_probs, value)
                ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)

                hidden_states = hidden_states + coef * ip_hidden_states

        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class OuterInterpolatedIPAttnProcessor(InterpolatedAttnProcessor):
    r"""
    Personalized processor for performing outer attention interpolation.

    Combined with IP-Adapter attention processor.
""" def __init__( self, t: Optional[float] = None, size: int = 7, is_fused: bool = False, alpha: float = 1, beta: float = 1, ip_attn: Optional[nn.Module] = None, ): """ t: float, interpolation point between 0 and 1, if specified, size is set to 3 """ super().__init__(t=t, size=size, is_fused=is_fused, alpha=alpha, beta=beta) self.num_tokens = ( ip_attn.num_tokens if hasattr(ip_attn, "num_tokens") else (16,) ) self.scale = ip_attn.scale if hasattr(ip_attn, "scale") else None self.ip_attn = ip_attn def __call__( self, attn, hidden_states: torch.FloatTensor, encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, temb: Optional[torch.FloatTensor] = None, ) -> torch.Tensor: if not self.activated: return self.ip_attn( attn, hidden_states, encoder_hidden_states, attention_mask, temb ) residual = hidden_states if encoder_hidden_states is None: encoder_hidden_states = hidden_states ip_hidden_states = None else: if isinstance(encoder_hidden_states, tuple): encoder_hidden_states, ip_hidden_states = encoder_hidden_states else: end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0] encoder_hidden_states, ip_hidden_states = ( encoder_hidden_states[:, :end_pos, :], [encoder_hidden_states[:, end_pos:, :]], ) if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view( batch_size, channel, height * width ).transpose(1, 2) batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) attention_mask = attn.prepare_attention_mask( attention_mask, sequence_length, batch_size ) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose( 1, 2 ) query = attn.to_q(hidden_states) query = attn.head_to_batch_dim(query) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) # Specify the first and last key and value key_begin = key[0:1].expand(3, *key.shape[1:]) key_end = key[-1:].expand(3, *key.shape[1:]) value_begin = value[0:1].expand(3, *value.shape[1:]) value_end = value[-1:].expand(3, *value.shape[1:]) key_begin = attn.head_to_batch_dim(key_begin) value_begin = attn.head_to_batch_dim(value_begin) key_end = attn.head_to_batch_dim(key_end) value_end = attn.head_to_batch_dim(value_end) # Fused with self-attention if self.is_fused: key = attn.head_to_batch_dim(key) value = attn.head_to_batch_dim(value) key_end = torch.cat([key, key_end], dim=-2) value_end = torch.cat([value, value_end], dim=-2) key_begin = torch.cat([key, key_begin], dim=-2) value_begin = torch.cat([value, value_begin], dim=-2) attention_probs_end = attn.get_attention_scores(query, key_end, attention_mask) hidden_states_end = torch.bmm(attention_probs_end, value_end) hidden_states_end = attn.batch_to_head_dim(hidden_states_end) attention_probs_begin = attn.get_attention_scores( query, key_begin, attention_mask ) hidden_states_begin = torch.bmm(attention_probs_begin, value_begin) hidden_states_begin = attn.batch_to_head_dim(hidden_states_begin) # for ip-adapter if ip_hidden_states is not None: key = self.ip_attn.to_k_ip[0](ip_hidden_states[0][::3]) value = self.ip_attn.to_v_ip[0](ip_hidden_states[0][::3]) # Specify the first and last key and value key_begin = key[0:1].expand(3, *key.shape[1:]) key_end = key[-1:].expand(3, *key.shape[1:]) value_begin = value[0:1].expand(3, *value.shape[1:]) 
            value_end = value[-1:].expand(3, *value.shape[1:])

            key_begin = attn.head_to_batch_dim(key_begin)
            value_begin = attn.head_to_batch_dim(value_begin)
            key_end = attn.head_to_batch_dim(key_end)
            value_end = attn.head_to_batch_dim(value_end)

            # Fused with self-attention
            if self.is_fused:
                key = attn.head_to_batch_dim(key)
                value = attn.head_to_batch_dim(value)
                key_end = torch.cat([key, key_end], dim=-2)
                value_end = torch.cat([value, value_end], dim=-2)
                key_begin = torch.cat([key, key_begin], dim=-2)
                value_begin = torch.cat([value, value_begin], dim=-2)

            ip_attention_probs_end = attn.get_attention_scores(
                query, key_end, attention_mask
            )
            ip_hidden_states_end = torch.bmm(ip_attention_probs_end, value_end)
            ip_hidden_states_end = attn.batch_to_head_dim(ip_hidden_states_end)

            ip_attention_probs_begin = attn.get_attention_scores(
                query, key_begin, attention_mask
            )
            ip_hidden_states_begin = torch.bmm(ip_attention_probs_begin, value_begin)
            ip_hidden_states_begin = attn.batch_to_head_dim(ip_hidden_states_begin)

            hidden_states_begin = (
                hidden_states_begin + self.scale[0] * ip_hidden_states_begin
            )
            hidden_states_end = (
                hidden_states_end + self.scale[0] * ip_hidden_states_end
            )

        # Apply outer interpolation on attention
        coef = self.coef.reshape(-1, 1, 1)
        coef = coef.to(key.device, key.dtype)
        hidden_states = (1 - coef) * hidden_states_begin + coef * hidden_states_end

        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class InnerInterpolatedIPAttnProcessor(InterpolatedAttnProcessor):
    r"""
    Personalized processor for performing inner attention interpolation.

    With IP-Adapter.
""" def __init__( self, t: Optional[float] = None, size: int = 7, is_fused: bool = False, alpha: float = 1, beta: float = 1, ip_attn: Optional[nn.Module] = None, ): """ t: float, interpolation point between 0 and 1, if specified, size is set to 3 """ super().__init__(t=t, size=size, is_fused=is_fused, alpha=alpha, beta=beta) self.num_tokens = ( ip_attn.num_tokens if hasattr(ip_attn, "num_tokens") else (16,) ) self.scale = ip_attn.scale if hasattr(ip_attn, "scale") else None self.ip_attn = ip_attn def __call__( self, attn, hidden_states: torch.FloatTensor, encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, temb: Optional[torch.FloatTensor] = None, ) -> torch.Tensor: if not self.activated: return self.ip_attn( attn, hidden_states, encoder_hidden_states, attention_mask, temb ) residual = hidden_states if encoder_hidden_states is None: encoder_hidden_states = hidden_states ip_hidden_states = None else: if isinstance(encoder_hidden_states, tuple): encoder_hidden_states, ip_hidden_states = encoder_hidden_states else: end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0] encoder_hidden_states, ip_hidden_states = ( encoder_hidden_states[:, :end_pos, :], [encoder_hidden_states[:, end_pos:, :]], ) if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view( batch_size, channel, height * width ).transpose(1, 2) batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) attention_mask = attn.prepare_attention_mask( attention_mask, sequence_length, batch_size ) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose( 1, 2 ) query = attn.to_q(hidden_states) query = attn.head_to_batch_dim(query) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) # Specify the first and last key and value key_begin = key[0:1].expand(3, *key.shape[1:]) key_end = key[-1:].expand(3, *key.shape[1:]) value_begin = value[0:1].expand(3, *value.shape[1:]) value_end = value[-1:].expand(3, *value.shape[1:]) coef = self.coef.reshape(-1, 1, 1) coef = coef.to(key.device, key.dtype) key_cross = (1 - coef) * key_begin + coef * key_end value_cross = (1 - coef) * value_begin + coef * value_end key_cross = attn.head_to_batch_dim(key_cross) value_cross = attn.head_to_batch_dim(value_cross) # Fused with self-attention if self.is_fused: key = attn.head_to_batch_dim(key) value = attn.head_to_batch_dim(value) key_cross = torch.cat([key, key_cross], dim=-2) value_cross = torch.cat([value, value_cross], dim=-2) attention_probs = attn.get_attention_scores(query, key_cross, attention_mask) hidden_states = torch.bmm(attention_probs, value_cross) hidden_states = attn.batch_to_head_dim(hidden_states) # for ip-adapter if ip_hidden_states is not None: key = self.ip_attn.to_k_ip[0](ip_hidden_states[0][::3]) value = self.ip_attn.to_v_ip[0](ip_hidden_states[0][::3]) key = key.squeeze() value = value.squeeze() # Specify the first and last key and value key_begin = key[0:1].expand(3, *key.shape[1:]) key_end = key[-1:].expand(3, *key.shape[1:]) value_begin = value[0:1].expand(3, *value.shape[1:]) value_end = value[-1:].expand(3, *value.shape[1:]) key_cross = (1 - coef) * key_begin + coef * key_end value_cross = (1 - coef) * value_begin + coef * value_end key_cross = 
            value_cross = attn.head_to_batch_dim(value_cross)

            # Fused with self-attention
            if self.is_fused:
                key = attn.head_to_batch_dim(key)
                value = attn.head_to_batch_dim(value)
                key_cross = torch.cat([key, key_cross], dim=-2)
                value_cross = torch.cat([value, value_cross], dim=-2)

            # Attend over the interpolated (and optionally fused) image-prompt
            # keys/values, mirroring the text branch above
            attention_probs = attn.get_attention_scores(
                query, key_cross, attention_mask
            )
            ip_hidden_states = torch.bmm(attention_probs, value_cross)
            ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)

            hidden_states = hidden_states + self.scale[0] * ip_hidden_states

        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class OuterInterpolatedAttnProcessor(InterpolatedAttnProcessor):
    r"""
    Personalized processor for performing outer attention interpolation.

    The attention output of the interpolated image is obtained by:
    (1 - t) * Q_t * K_1 * V_1 + t * Q_t * K_m * V_m;
    If fused with self-attention:
    (1 - t) * Q_t * [K_1, K_t] * [V_1, V_t] + t * Q_t * [K_m, K_t] * [V_m, V_t];
    """

    def __init__(
        self,
        t: Optional[float] = None,
        size: int = 7,
        is_fused: bool = False,
        alpha: float = 1,
        beta: float = 1,
        original_attn: Optional[nn.Module] = None,
    ):
        """
        t: float, interpolation point between 0 and 1, if specified, size is set to 3
        """
        super().__init__(t=t, size=size, is_fused=is_fused, alpha=alpha, beta=beta)
        self.original_attn = original_attn

    def __call__(
        self,
        attn,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        if not self.activated:
            return self.original_attn(
                attn, hidden_states, encoder_hidden_states, attention_mask, temb
            )

        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(
            attention_mask, sequence_length, batch_size
        )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )

        query = attn.to_q(hidden_states)
        query = attn.head_to_batch_dim(query)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(
                encoder_hidden_states
            )

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # Specify the first and last key and value
        key_begin = key[0:1]
        key_end = key[-1:]
        value_begin = value[0:1]
        value_end = value[-1:]

        key_begin = torch.cat([key_begin] * (self.size))
        key_end = torch.cat([key_end] * (self.size))
        value_begin = torch.cat([value_begin] * (self.size))
        value_end = torch.cat([value_end] * (self.size))

        key_begin = attn.head_to_batch_dim(key_begin)
        value_begin = attn.head_to_batch_dim(value_begin)
        key_end = attn.head_to_batch_dim(key_end)
        value_end = attn.head_to_batch_dim(value_end)

        # Fused with self-attention
        if self.is_fused:
            key = attn.head_to_batch_dim(key)
            value = attn.head_to_batch_dim(value)
            key_end = torch.cat([key, key_end], dim=-2)
            value_end = torch.cat([value, value_end], dim=-2)
            key_begin = torch.cat([key, key_begin], dim=-2)
            value_begin = torch.cat([value, value_begin], dim=-2)

        attention_probs_end = attn.get_attention_scores(query, key_end, attention_mask)
        hidden_states_end = torch.bmm(attention_probs_end, value_end)
        hidden_states_end = attn.batch_to_head_dim(hidden_states_end)

        attention_probs_begin = attn.get_attention_scores(
            query, key_begin, attention_mask
        )
        hidden_states_begin = torch.bmm(attention_probs_begin, value_begin)
        hidden_states_begin = attn.batch_to_head_dim(hidden_states_begin)

        # Apply outer interpolation on attention
        coef = self.coef.reshape(-1, 1, 1)
        coef = coef.to(key.device, key.dtype)
        hidden_states = (1 - coef) * hidden_states_begin + coef * hidden_states_end

        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class InnerInterpolatedAttnProcessor(InterpolatedAttnProcessor):
    r"""
    Personalized processor for performing inner attention interpolation.

    The attention output of the interpolated image is obtained by:
    Q_t * [(1 - t) * K_1 + t * K_m] * [(1 - t) * V_1 + t * V_m];
    If fused with self-attention:
    Q_t * [(1 - t) * K_1 + t * K_m, K_t] * [(1 - t) * V_1 + t * V_m, V_t];
    """

    def __init__(
        self,
        t: Optional[float] = None,
        size: int = 7,
        is_fused: bool = False,
        alpha: float = 1,
        beta: float = 1,
        original_attn: Optional[nn.Module] = None,
    ):
        """
        t: float, interpolation point between 0 and 1, if specified, size is set to 3
        """
        super().__init__(t=t, size=size, is_fused=is_fused, alpha=alpha, beta=beta)
        self.original_attn = original_attn

    def __call__(
        self,
        attn,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        if not self.activated:
            return self.original_attn(
                attn, hidden_states, encoder_hidden_states, attention_mask, temb
            )

        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(
            attention_mask, sequence_length, batch_size
        )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )

        query = attn.to_q(hidden_states)
        query = attn.head_to_batch_dim(query)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(
                encoder_hidden_states
            )

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # Specify the first and last key and value
        key_start = key[0:1]
        key_end = key[-1:]
        value_start = value[0:1]
        value_end = value[-1:]

        key_start = torch.cat([key_start] * (self.size))
        key_end = torch.cat([key_end] * (self.size))
        value_start = torch.cat([value_start] * (self.size))
        value_end = torch.cat([value_end] * (self.size))
        # Apply inner interpolation on attention
        coef = self.coef.reshape(-1, 1, 1)
        coef = coef.to(key.device, key.dtype)
        key_cross = (1 - coef) * key_start + coef * key_end
        value_cross = (1 - coef) * value_start + coef * value_end

        key_cross = attn.head_to_batch_dim(key_cross)
        value_cross = attn.head_to_batch_dim(value_cross)

        # Fused with self-attention
        if self.is_fused:
            key = attn.head_to_batch_dim(key)
            value = attn.head_to_batch_dim(value)
            key_cross = torch.cat([key, key_cross], dim=-2)
            value_cross = torch.cat([value, value_cross], dim=-2)

        attention_probs = attn.get_attention_scores(query, key_cross, attention_mask)
        hidden_states = torch.bmm(attention_probs, value_cross)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


def linear_interpolation(
    l1: FloatTensor, l2: FloatTensor, ts: Optional[FloatTensor] = None, size: int = 5
) -> FloatTensor:
    """
    Linear interpolation

    Args:
        l1: Starting vector: (1, *)
        l2: Final vector: (1, *)
        ts: FloatTensor, interpolation points between 0 and 1
        size: int, number of interpolation points including l1 and l2

    Returns:
        Interpolated vectors: (size, *)
    """
    assert l1.shape == l2.shape, "shapes of l1 and l2 must match"

    res = []
    if ts is not None:
        for t in ts:
            li = torch.lerp(l1, l2, t)
            res.append(li)
    else:
        for i in range(size):
            t = i / (size - 1)
            li = torch.lerp(l1, l2, t)
            res.append(li)
    res = torch.cat(res, dim=0)
    return res


def spherical_interpolation(l1: FloatTensor, l2: FloatTensor, size=5) -> FloatTensor:
    """
    Spherical interpolation

    Args:
        l1: Starting vector: (1, *)
        l2: Final vector: (1, *)
        size: int, number of interpolation points including l1 and l2

    Returns:
        Interpolated vectors: (size, *)
    """
    assert l1.shape == l2.shape, "shapes of l1 and l2 must match"

    res = []
    for i in range(size):
        t = i / (size - 1)
        li = slerp(l1, l2, t)
        res.append(li)
    res = torch.cat(res, dim=0)
    return res


def slerp(v0: FloatTensor, v1: FloatTensor, t, threshold=0.9995):
    """
    Spherical linear interpolation

    Args:
        v0: Starting vector
        v1: Final vector
        t: Float value between 0.0 and 1.0
        threshold: Threshold for considering the two vectors as colinear. Not recommended to alter this.
    Returns:
        Interpolation vector between v0 and v1
    """
    assert v0.shape == v1.shape, "shapes of v0 and v1 must match"

    # Normalize the vectors to get the directions and angles
    v0_norm: FloatTensor = torch.norm(v0, dim=-1)
    v1_norm: FloatTensor = torch.norm(v1, dim=-1)

    v0_normed: FloatTensor = v0 / v0_norm.unsqueeze(-1)
    v1_normed: FloatTensor = v1 / v1_norm.unsqueeze(-1)

    # Dot product with the normalized vectors
    dot: FloatTensor = (v0_normed * v1_normed).sum(-1)
    dot_mag: FloatTensor = dot.abs()

    # if dot is NaN, it's because the v0 or v1 row was filled with 0s
    # If absolute value of dot product is almost 1, vectors are ~colinear, so use torch.lerp
    gotta_lerp: LongTensor = dot_mag.isnan() | (dot_mag > threshold)
    can_slerp: LongTensor = ~gotta_lerp

    t_batch_dim_count: int = max(0, t.dim() - v0.dim()) if isinstance(t, Tensor) else 0
    t_batch_dims: Size = (
        t.shape[:t_batch_dim_count] if isinstance(t, Tensor) else Size([])
    )
    out: FloatTensor = torch.zeros_like(v0.expand(*t_batch_dims, *[-1] * v0.dim()))

    # if no elements are lerpable, our vectors become 0-dimensional, preventing broadcasting
    if gotta_lerp.any():
        lerped: FloatTensor = torch.lerp(v0, v1, t)

        out: FloatTensor = lerped.where(gotta_lerp.unsqueeze(-1), out)

    # if no elements are slerpable, our vectors become 0-dimensional, preventing broadcasting
    if can_slerp.any():
        # Calculate initial angle between v0 and v1
        theta_0: FloatTensor = dot.arccos().unsqueeze(-1)
        sin_theta_0: FloatTensor = theta_0.sin()

        # Angle at timestep t
        theta_t: FloatTensor = theta_0 * t
        sin_theta_t: FloatTensor = theta_t.sin()

        # Finish the slerp algorithm
        s0: FloatTensor = (theta_0 - theta_t).sin() / sin_theta_0
        s1: FloatTensor = sin_theta_t / sin_theta_0
        slerped: FloatTensor = s0 * v0 + s1 * v1

        out: FloatTensor = slerped.where(can_slerp.unsqueeze(-1), out)

    return out
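
# Minimal usage sketch (illustrative, not part of the original module). It
# exercises only the standalone interpolation utilities above on toy tensors;
# the attention processors expect a diffusers-style `Attention` module and a
# batch of `size` latents, so they are not exercised here.
#
# Hypothetical wiring for the processors, assuming a diffusers UNet whose
# self-attention processor names end with "attn1.processor" (`unet` is not
# defined in this module):
#
#     procs = {
#         name: OuterInterpolatedAttnProcessor(size=7, is_fused=True)
#         if name.endswith("attn1.processor")
#         else unet.attn_processors[name]
#         for name in unet.attn_processors
#     }
#     unet.set_attn_processor(procs)
if __name__ == "__main__":
    torch.manual_seed(0)

    # Two "endpoint" latents of shape (1, 4); any (1, *) tensors work.
    l1 = torch.randn(1, 4)
    l2 = torch.randn(1, 4)

    # Five evenly spaced points between l1 and l2, endpoints included.
    lin = linear_interpolation(l1, l2, size=5)
    sph = spherical_interpolation(l1, l2, size=5)
    print(lin.shape, sph.shape)  # torch.Size([5, 4]) torch.Size([5, 4])

    # A single slerp step at t=0.5 matches the middle row of `sph`.
    mid = slerp(l1, l2, 0.5)
    print(torch.allclose(mid, sph[2:3], atol=1e-6))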