Spaces:
Runtime error
Runtime error
# Copyright 2024 The HuggingFace Team. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import math | |
from typing import List, Optional, Tuple, Union | |
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
from torch import nn | |
from diffusers.utils import deprecate | |
from .activations import FP32SiLU, get_activation | |
from .attention_processor import Attention | |
def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """
    Build sinusoidal timestep embeddings as in Denoising Diffusion Probabilistic Models.

    Args:
        timesteps (torch.Tensor):
            1-D tensor with one (possibly fractional) timestep per batch element.
        embedding_dim (int):
            Width of each output embedding.
        flip_sin_to_cos (bool):
            If True the two halves are ordered `cos, sin`; otherwise `sin, cos`.
        downscale_freq_shift (float):
            Subtracted from `embedding_dim // 2` in the denominator of the frequency exponent.
        scale (float):
            Multiplier applied to the phase arguments before taking sin/cos.
        max_period (int):
            Base of the geometric frequency progression
            `max_period ** (-i / (embedding_dim // 2 - downscale_freq_shift))`.

    Returns:
        torch.Tensor: float32 tensor of shape `[N, embedding_dim]`; for odd `embedding_dim` the last column is zero.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2

    # Channel frequencies: exp(-log(max_period) * i / (half_dim - downscale_freq_shift)).
    index = torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
    exponent = (-math.log(max_period) * index) / (half_dim - downscale_freq_shift)
    frequencies = torch.exp(exponent)

    # Outer product timesteps x frequencies -> [N, half_dim] phase matrix, then scaled.
    phases = scale * (timesteps[:, None].float() * frequencies[None, :])

    sin_part = torch.sin(phases)
    cos_part = torch.cos(phases)
    halves = [cos_part, sin_part] if flip_sin_to_cos else [sin_part, cos_part]
    emb = torch.cat(halves, dim=-1)

    if embedding_dim % 2 == 1:
        # Odd widths get a single zero column on the right.
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
def get_2d_sincos_pos_embed(
    embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
):
    """
    Build a fixed 2D sine-cosine positional embedding table for a patch grid.

    Args:
        embed_dim (int): Embedding width. `get_2d_sincos_pos_embed_from_grid` and
            `get_1d_sincos_pos_embed_from_grid` raise `ValueError` unless it is divisible by 4.
        grid_size (int or Tuple[int, int]): Grid as `(height, width)` in patches; an int means a square grid.
        cls_token (bool): When True and `extra_tokens > 0`, prepend `extra_tokens` all-zero rows.
        extra_tokens (int): Number of zero rows to prepend (only used together with `cls_token`).
        interpolation_scale (float): Divisor applied to the grid coordinates.
        base_size (int): Each axis is rescaled to span `base_size` units before interpolation scaling.

    Returns:
        np.ndarray: `[grid_h * grid_w, embed_dim]`, or `[extra_tokens + grid_h * grid_w, embed_dim]` when zero rows
        are prepended.
    """
    if isinstance(grid_size, int):
        grid_size = (grid_size, grid_size)
    rows, cols = grid_size[0], grid_size[1]

    # Axis coordinates normalised to span `base_size` units, then divided by the interpolation scale.
    coords_h = np.arange(rows, dtype=np.float32) / (rows / base_size) / interpolation_scale
    coords_w = np.arange(cols, dtype=np.float32) / (cols / base_size) / interpolation_scale

    # meshgrid with w first: plane 0 holds w coordinates, plane 1 holds h coordinates, each shaped (rows, cols).
    # A singleton axis is inserted; the consumer flattens each plane, so only the row-major order matters.
    grid = np.stack(np.meshgrid(coords_w, coords_h), axis=0)[:, np.newaxis]

    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        padding = np.zeros([extra_tokens, embed_dim])
        pos_embed = np.concatenate([padding, pos_embed], axis=0)
    return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """
    Concatenate the 1D sin-cos embeddings of the two coordinate planes in `grid`.

    Args:
        embed_dim (int): Total embedding width; each plane receives `embed_dim // 2` channels.
        grid (np.ndarray): `grid[0]` and `grid[1]` hold coordinates of any shape (they are flattened). As built by
            `get_2d_sincos_pos_embed`, `grid[0]` is the width plane and `grid[1]` the height plane.

    Returns:
        np.ndarray: `[M, embed_dim]` with the `grid[0]` embedding in the first half.

    Raises:
        ValueError: If `embed_dim` is odd.
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    half = embed_dim // 2
    first_plane = get_1d_sincos_pos_embed_from_grid(half, grid[0])  # (M, D/2)
    second_plane = get_1d_sincos_pos_embed_from_grid(half, grid[1])  # (M, D/2)
    return np.concatenate([first_plane, second_plane], axis=1)  # (M, D)
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Encode scalar positions with fixed sine/cosine features.

    Args:
        embed_dim (int): Output width D; must be even.
        pos (np.ndarray): Positions of any shape; flattened to M values.

    Returns:
        np.ndarray: `[M, D]` with `sin` features in the first D/2 columns and `cos` features in the last D/2.

    Raises:
        ValueError: If `embed_dim` is odd.
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    # Frequencies 1 / 10000 ** (i / (D / 2)) for i in [0, D / 2).
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega

    angles = np.outer(pos.reshape(-1), omega)  # (M, D/2)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)
class PatchEmbed(nn.Module):
    """2D Image to Patch Embedding with support for SD3 cropping.

    A strided `nn.Conv2d` turns `(batch, in_channels, H, W)` inputs into patch features. These are optionally
    flattened to tokens and layer-normed, then summed with a fixed 2D sin-cos positional embedding.

    Args:
        height (`int`): Input height in pixels used to build the default positional table.
        width (`int`): Input width in pixels used to build the default positional table.
        patch_size (`int`): Side length of the square patches (also the convolution stride).
        in_channels (`int`): Number of input channels.
        embed_dim (`int`): Output channels per patch.
        layer_norm (`bool`): Apply a non-affine `LayerNorm` (eps 1e-6) over the last dimension of the output.
        flatten (`bool`): Flatten the `(B, C, h, w)` convolution output to `(B, h*w, C)` tokens.
        bias (`bool`): Whether the projection convolution has a bias.
        interpolation_scale (`float`): Forwarded to `get_2d_sincos_pos_embed`.
        pos_embed_type (`str` or `None`): `"sincos"` for a fixed sin-cos table, `None` for no positional embedding.
        pos_embed_max_size (`int`, *optional*): If set, build a `pos_embed_max_size x pos_embed_max_size` table that
            is saved in the state dict, and center-crop it to each input's patch grid (SD3 style).

    Raises:
        ValueError: For an unsupported `pos_embed_type`.
    """

    def __init__(
        self,
        height=224,
        width=224,
        patch_size=16,
        in_channels=3,
        embed_dim=768,
        layer_norm=False,
        flatten=True,
        bias=True,
        interpolation_scale=1,
        pos_embed_type="sincos",
        pos_embed_max_size=None,  # For SD3 cropping
    ):
        super().__init__()
        self.flatten = flatten
        self.layer_norm = layer_norm
        self.pos_embed_max_size = pos_embed_max_size

        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
        )
        if layer_norm:
            self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
        else:
            self.norm = None

        self.patch_size = patch_size
        # Native patch grid; `forward` reuses the precomputed table only when the input grid matches it.
        self.height, self.width = height // patch_size, width // patch_size
        self.base_size = height // patch_size
        self.interpolation_scale = interpolation_scale

        if pos_embed_max_size:
            grid_size = pos_embed_max_size
        else:
            # Build the table on the real (rows, cols) patch grid. The previous `int(sqrt(num_patches))` only
            # matched square inputs: for a non-square height/width it produced a table with the wrong number of
            # tokens, and `forward` at the native resolution failed with a shape mismatch. Square inputs give an
            # identical table either way.
            grid_size = (self.height, self.width)

        if pos_embed_type is None:
            self.pos_embed = None
        elif pos_embed_type == "sincos":
            pos_embed = get_2d_sincos_pos_embed(
                embed_dim, grid_size, base_size=self.base_size, interpolation_scale=self.interpolation_scale
            )
            # Only the croppable max-size table goes into the state dict; the default one is rebuilt on init.
            persistent = True if pos_embed_max_size else False
            self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=persistent)
        else:
            raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}")

    def cropped_pos_embed(self, height, width):
        """Crops positional embeddings for SD3 compatibility.

        Args:
            height (`int`): Input height in pixels (divided by `patch_size` here).
            width (`int`): Input width in pixels (divided by `patch_size` here).

        Returns:
            `torch.Tensor`: `(1, (height // patch_size) * (width // patch_size), embed_dim)` center crop of the table.

        Raises:
            ValueError: If `pos_embed_max_size` is unset or the patch grid exceeds it.
        """
        if self.pos_embed_max_size is None:
            raise ValueError("`pos_embed_max_size` must be set for cropping.")

        height = height // self.patch_size
        width = width // self.patch_size
        if height > self.pos_embed_max_size:
            raise ValueError(
                f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
            )
        if width > self.pos_embed_max_size:
            raise ValueError(
                f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
            )

        top = (self.pos_embed_max_size - height) // 2
        left = (self.pos_embed_max_size - width) // 2
        spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)
        spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :]
        spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
        return spatial_pos_embed

    def forward(self, latent):
        """Embed `latent` (`(B, C, H, W)`) into patch tokens and add positional embeddings when configured."""
        # The cropping path takes pixel sizes (it divides by patch_size itself); otherwise use patch units.
        if self.pos_embed_max_size is not None:
            height, width = latent.shape[-2:]
        else:
            height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size

        latent = self.proj(latent)
        if self.flatten:
            latent = latent.flatten(2).transpose(1, 2)  # BCHW -> BNC
        if self.layer_norm:
            latent = self.norm(latent)
        if self.pos_embed is None:
            return latent

        # Crop the max-size table, rebuild for a non-native grid, or reuse the precomputed table.
        if self.pos_embed_max_size:
            pos_embed = self.cropped_pos_embed(height, width)
        elif self.height != height or self.width != width:
            pos_embed = get_2d_sincos_pos_embed(
                embed_dim=self.pos_embed.shape[-1],
                grid_size=(height, width),
                base_size=self.base_size,
                interpolation_scale=self.interpolation_scale,
            )
            pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device)
        else:
            pos_embed = self.pos_embed

        return (latent + pos_embed).to(latent.dtype)
class LuminaPatchEmbed(nn.Module):
    """2D Image to Patch Embedding with support for Lumina-T2X

    Splits a `(batch, channels, height, width)` tensor into non-overlapping `patch_size x patch_size` patches and
    projects each flattened patch with one linear layer.

    Args:
        patch_size (`int`): Side length of the square patches.
        in_channels (`int`): Number of input channels.
        embed_dim (`int`): Output width per patch token.
        bias (`bool`): Whether the projection has a bias.
    """

    def __init__(self, patch_size=2, in_channels=4, embed_dim=768, bias=True):
        super().__init__()
        self.patch_size = patch_size
        patch_features = patch_size * patch_size * in_channels
        self.proj = nn.Linear(in_features=patch_features, out_features=embed_dim, bias=bias)

    def forward(self, x, freqs_cis):
        """
        Patchify and embed `x`.

        Args:
            x (torch.Tensor): `(batch, channels, height, width)` tensor. Height and width must be multiples of
                `patch_size`, otherwise the `view` below fails.
            freqs_cis (torch.Tensor): Frequency table indexed as `freqs_cis[row, col, ...]`; moved to `x`'s device.

        Returns:
            Tuple of:
            - embedded tokens `(batch, num_patches, embed_dim)`, ordered row-major over the patch grid;
            - an all-ones int32 mask `(batch, num_patches)`;
            - a list holding `(height, width)` once per batch element;
            - `freqs_cis` cropped to the patch grid, flattened over rows/cols, with a leading axis of size 1.
        """
        freqs_cis = freqs_cis.to(x[0].device)

        p = self.patch_size
        batch_size, channel, height, width = x.size()
        rows, cols = height // p, width // p

        # (B, C, rows, p, cols, p) -> (B, rows, cols, C, p, p) -> (B, rows, cols, C * p * p)
        patches = x.view(batch_size, channel, rows, p, cols, p).permute(0, 2, 4, 1, 3, 5).flatten(3)
        tokens = self.proj(patches).flatten(1, 2)

        mask = torch.ones(tokens.shape[0], tokens.shape[1], dtype=torch.int32, device=tokens.device)
        image_sizes = [(height, width)] * batch_size
        grid_freqs = freqs_cis[:rows, :cols].flatten(0, 1).unsqueeze(0)
        return tokens, mask, image_sizes, grid_freqs
def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
    """
    RoPE for image tokens with 2d structure.

    Args:
        embed_dim (`int`):
            Rotary embedding width; `get_2d_rotary_pos_embed_from_grid` asserts it is divisible by 4.
        crops_coords (`Tuple[Tuple[float, float], Tuple[float, float]]`):
            `(start, stop)` corners of the crop, each indexed as `[0]` = h and `[1]` = w.
        grid_size (`Tuple[int, int]`):
            Number of samples `(h, w)` taken along each axis over `[start, stop)`.
        use_real (`bool`):
            If True, return separate cosine and sine tensors; otherwise complex values.

    Returns:
        If `use_real`: tuple `(cos, sin)`, each `(grid_h * grid_w, embed_dim)`.
        Otherwise: complex tensor `(grid_h * grid_w, embed_dim // 2)`.
    """
    start, stop = crops_coords

    # Evenly spaced coordinates over [start, stop) along each axis.
    coords_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
    coords_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)

    # meshgrid with w first: plane 0 holds w coordinates, plane 1 holds h; each plane is (grid_h, grid_w).
    planes = np.stack(np.meshgrid(coords_w, coords_h), axis=0)
    planes = planes.reshape([2, 1, *planes.shape[1:]])

    return get_2d_rotary_pos_embed_from_grid(embed_dim, planes, use_real=use_real)
def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
    """
    Combine the 1D rotary embeddings of the two coordinate planes in `grid`.

    Args:
        embed_dim (int): Total rotary width; must be divisible by 4. Each plane receives `embed_dim // 2`.
        grid (np.ndarray): `grid[0]` and `grid[1]` hold coordinates (flattened before use). As built by
            `get_2d_rotary_pos_embed`, `grid[0]` is the width plane and `grid[1]` the height plane.
        use_real (bool): Return real `(cos, sin)` tensors instead of a complex tensor.

    Returns:
        With `use_real`: `(cos, sin)`, each `(M, embed_dim)`. Otherwise a complex tensor `(M, embed_dim // 2)`.
    """
    assert embed_dim % 4 == 0

    half = embed_dim // 2
    # (M, D/2) cos/sin pairs if use_real, else complex (M, D/4), per plane.
    emb_first = get_1d_rotary_pos_embed(half, grid[0].reshape(-1), use_real=use_real)
    emb_second = get_1d_rotary_pos_embed(half, grid[1].reshape(-1), use_real=use_real)

    if not use_real:
        return torch.cat([emb_first, emb_second], dim=1)  # (M, D/2)

    cos = torch.cat([emb_first[0], emb_second[0]], dim=1)  # (M, D)
    sin = torch.cat([emb_first[1], emb_second[1]], dim=1)  # (M, D)
    return cos, sin
def get_2d_rotary_pos_embed_lumina(embed_dim, len_h, len_w, linear_factor=1.0, ntk_factor=1.0):
    """
    Complex 2D rotary table for Lumina, laid out on an `(len_h, len_w)` grid.

    Args:
        embed_dim (int): Rotary width; must be divisible by 4.
        len_h (int): Number of rows.
        len_w (int): Number of columns.
        linear_factor (float): Forwarded to `get_1d_rotary_pos_embed`.
        ntk_factor (float): Forwarded to `get_1d_rotary_pos_embed`.

    Returns:
        torch.Tensor: complex tensor `(len_h, len_w, embed_dim // 2)` whose last axis interleaves the row and
        column frequencies as `h0, w0, h1, w1, ...`.
    """
    assert embed_dim % 4 == 0

    quarter = embed_dim // 4
    rows = get_1d_rotary_pos_embed(
        embed_dim // 2, len_h, linear_factor=linear_factor, ntk_factor=ntk_factor
    )  # (H, D/4)
    cols = get_1d_rotary_pos_embed(
        embed_dim // 2, len_w, linear_factor=linear_factor, ntk_factor=ntk_factor
    )  # (W, D/4)

    # Broadcast each 1D table across the other axis; both become (H, W, D/4, 1).
    rows = rows.view(len_h, 1, quarter, 1).repeat(1, len_w, 1, 1)
    cols = cols.view(1, len_w, quarter, 1).repeat(len_h, 1, 1, 1)

    # Pair on the last axis and flatten -> (H, W, D/2).
    return torch.cat([rows, cols], dim=-1).flatten(2)
def get_1d_rotary_pos_embed(
    dim: int,
    pos: Union[np.ndarray, torch.Tensor, int],
    theta: float = 10000.0,
    use_real=False,
    linear_factor=1.0,
    ntk_factor=1.0,
    repeat_interleave_real=True,
):
    """
    Precompute the rotary frequency table for the positions `pos`.

    Base frequencies are `1 / (theta * ntk_factor) ** (2i / dim) / linear_factor` for `i in [0, dim / 2)`. Each
    base frequency is multiplied by every position.

    Args:
        dim (`int`): Rotary dimension; must be even.
        pos (`np.ndarray`, `torch.Tensor` or `int`): Position indices `[S]` (may be fractional). An integer `n`
            (Python or NumPy) means positions `0, 1, ..., n - 1`. Other array-likes are converted with
            `np.asarray`.
        theta (`float`, *optional*, defaults to 10000.0): Base of the frequency progression.
        use_real (`bool`, *optional*): If True, return real `(cos, sin)` tensors; otherwise complex values.
        linear_factor (`float`, *optional*, defaults to 1.0): Divides all frequencies (linear context scaling).
        ntk_factor (`float`, *optional*, defaults to 1.0): Multiplies `theta` (NTK-aware scaling).
        repeat_interleave_real (`bool`, *optional*, defaults to `True`): With `use_real`, True repeats each column
            in place (`c0, c0, c1, c1, ...`). False concatenates the table with itself (`c0, c1, ..., c0, c1, ...`).

    Returns:
        Without `use_real`: complex64 tensor `[S, dim / 2]`. With `use_real`: tuple `(cos, sin)` of float32
        tensors `[S, dim]`. Results live on the device of the frequency table, which is torch's default device
        for `torch.arange`; tensor positions are moved there.
    """
    assert dim % 2 == 0

    if isinstance(pos, (int, np.integer)):
        pos = np.arange(pos)

    theta = theta * ntk_factor
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) / linear_factor  # [D/2]

    # `torch.from_numpy` only accepts ndarrays, so torch tensors are used directly.
    if isinstance(pos, torch.Tensor):
        t = pos.to(freqs.device)  # [S]
    else:
        t = torch.from_numpy(np.asarray(pos)).to(freqs.device)  # [S]
    freqs = torch.outer(t, freqs).float()  # [S, D/2]

    if use_real and repeat_interleave_real:
        freqs_cos = freqs.cos().repeat_interleave(2, dim=1)  # [S, D]
        freqs_sin = freqs.sin().repeat_interleave(2, dim=1)  # [S, D]
        return freqs_cos, freqs_sin
    if use_real:
        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1)  # [S, D]
        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1)  # [S, D]
        return freqs_cos, freqs_sin

    # Unit-magnitude complex phasors e^{i * freqs}.
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64 # [S, D/2]
    return freqs_cis
def apply_rotary_emb(
    x: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    use_real: bool = True,
    use_real_unbind_dim: int = -1,
) -> torch.Tensor:
    """
    Apply rotary embeddings to a query or key tensor.

    Args:
        x (`torch.Tensor`):
            4-D query or key tensor. The real path broadcasts `cos`/`sin` as `[1, 1, S, D]`, so it expects `x` as
            `[B, H, S, D]`. The complex path inserts a singleton at dim 2 of `freqs_cis`. That matches `x` laid out
            as `[B, S, H, D]` with `freqs_cis` shaped `[B, S, D/2]` (as in Lumina) — confirm against the caller.
        freqs_cis (`torch.Tensor` or `Tuple[torch.Tensor, torch.Tensor]`):
            `(cos, sin)`, each `[S, D]`, when `use_real`; otherwise a complex frequency tensor.
        use_real (`bool`):
            Select the real `(cos, sin)` path or the complex path.
        use_real_unbind_dim (`int`):
            How the real path splits `x`'s last axis into real/imaginary parts. `-1` pairs adjacent channels
            (e.g. Lumina); `-2` splits into first and second halves (e.g. Stable Audio).

    Returns:
        `torch.Tensor`: the rotated tensor, with the same shape and dtype as `x`. (The previous annotation
        advertised a tuple, but a single tensor has always been returned.)

    Raises:
        ValueError: If `use_real` and `use_real_unbind_dim` is not -1 or -2.
    """
    if use_real:
        cos, sin = freqs_cis  # [S, D]
        cos = cos[None, None].to(x.device)
        sin = sin[None, None].to(x.device)

        if use_real_unbind_dim == -1:
            # Adjacent pairs (x0, x1), (x2, x3), ... rotate to (-x1, x0), (-x3, x2), ...
            x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
            x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        elif use_real_unbind_dim == -2:
            # First half is the real part, second half the imaginary part; rotate to (-imag, real).
            x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)
            x_rotated = torch.cat([-x_imag, x_real], dim=-1)
        else:
            raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")

        # Compute in float32 for accuracy, then cast back to the input dtype.
        return (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

    # Complex path: view adjacent channel pairs as complex numbers and multiply by the phasors. The frequencies are
    # moved to `x`'s device, consistent with the real path.
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    freqs_cis = freqs_cis.to(x.device).unsqueeze(2)
    x_out = torch.view_as_real(x_complex * freqs_cis).flatten(3)
    return x_out.type_as(x)
class TimestepEmbedding(nn.Module):
    """
    Two-layer MLP mapping a projected timestep (e.g. the output of `Timesteps`) to a time-embedding space.

    Args:
        in_channels (`int`): Width of the incoming `sample`.
        time_embed_dim (`int`): Output width of `linear_1`.
        act_fn (`str`): Activation between the two linear layers, resolved with `get_activation`.
        out_dim (`int`, *optional*): Output width of `linear_2`; defaults to `time_embed_dim`.
        post_act_fn (`str`, *optional*): Activation applied after `linear_2`, if given.
        cond_proj_dim (`int`, *optional*): If set, a `condition` of this width is projected to `in_channels` by a
            bias-free linear layer and added to `sample` before `linear_1`.
        sample_proj_bias (`bool`): Whether `linear_1` and `linear_2` use a bias.
    """

    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: Optional[int] = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
        sample_proj_bias=True,
    ):
        super().__init__()
        self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
        self.cond_proj = None if cond_proj_dim is None else nn.Linear(cond_proj_dim, in_channels, bias=False)
        self.act = get_activation(act_fn)

        time_embed_dim_out = time_embed_dim if out_dim is None else out_dim
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)

        self.post_act = None if post_act_fn is None else get_activation(post_act_fn)

    def forward(self, sample, condition=None):
        """Embed `sample`, optionally shifted by the projected `condition` (requires `cond_proj_dim`)."""
        if condition is not None:
            sample = sample + self.cond_proj(condition)

        hidden = self.linear_1(sample)
        if self.act is not None:
            hidden = self.act(hidden)
        hidden = self.linear_2(hidden)
        if self.post_act is not None:
            hidden = self.post_act(hidden)
        return hidden
class Timesteps(nn.Module):
    """
    Parameter-free module that calls `get_timestep_embedding` with fixed settings.

    Args:
        num_channels (`int`): Embedding width.
        flip_sin_to_cos (`bool`): Order the halves as `cos, sin` instead of `sin, cos`.
        downscale_freq_shift (`float`): Forwarded as `downscale_freq_shift`.
        scale (`int`): Forwarded as `scale` (multiplies the phase arguments).
    """

    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift
        self.scale = scale

    def forward(self, timesteps):
        """Return the `[N, num_channels]` sinusoidal embedding of the 1-D `timesteps`."""
        return get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
            scale=self.scale,
        )
class GaussianFourierProjection(nn.Module):
    """Gaussian Fourier embeddings for noise levels.

    Args:
        embedding_size (`int`): Number of random frequencies; the output has `2 * embedding_size` features.
        scale (`float`): Multiplier for the standard-normal frequencies.
        set_W_to_weight (`bool`): When True, the frequencies come from a second random draw and the first draw is
            discarded. This keeps the same random-number consumption as the legacy attribute shuffle.
        log (`bool`): Apply `torch.log` to the input before projecting.
        flip_sin_to_cos (`bool`): Order the output as `cos, sin` instead of `sin, cos`.
    """

    def __init__(
        self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
    ):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
        self.log = log
        self.flip_sin_to_cos = flip_sin_to_cos

        if set_W_to_weight:
            # Legacy path: replace the first draw with a fresh one registered under the same name.
            self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)

    def forward(self, x):
        """Project 1-D `x` to `[N, 2 * embedding_size]` sin/cos features."""
        if self.log:
            x = torch.log(x)

        x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
        sin_feats = torch.sin(x_proj)
        cos_feats = torch.cos(x_proj)
        ordered = (cos_feats, sin_feats) if self.flip_sin_to_cos else (sin_feats, cos_feats)
        return torch.cat(ordered, dim=-1)
class SinusoidalPositionalEmbedding(nn.Module):
    """Apply positional information to a sequence of embeddings.

    Takes in a sequence of embeddings with shape (batch_size, seq_length, embed_dim) and adds fixed sinusoidal
    positional embeddings to them: sine on even channels, cosine on odd channels.

    Args:
        embed_dim: (int): Dimension of the positional embedding. Odd values are supported.
        max_seq_length: Maximum sequence length to apply positional embeddings. The table is sliced to the input's
            `seq_length`, so longer inputs fail to broadcast in the addition.
    """

    def __init__(self, embed_dim: int, max_seq_length: int = 32):
        super().__init__()
        position = torch.arange(max_seq_length).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim))

        pe = torch.zeros(1, max_seq_length, embed_dim)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        # There are embed_dim // 2 odd channels, one fewer than `div_term` entries when embed_dim is odd. Slicing
        # keeps the assignment shape-compatible; it is a no-op for even embed_dim.
        pe[0, :, 1::2] = torch.cos(position * div_term[: embed_dim // 2])
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Add the first `seq_length` positional rows to `x` of shape `(batch, seq_length, embed_dim)`."""
        _, seq_length, _ = x.shape
        return x + self.pe[:, :seq_length]
class ImagePositionalEmbeddings(nn.Module):
    """
    Converts latent image classes into vector embeddings and sums them with learned row and column positional
    embeddings for the latent grid.

    For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092

    For VQ-diffusion:

    Output vector embeddings are used as input for the transformer.

    Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE.

    Args:
        num_embed (`int`):
            Number of embeddings for the latent pixels embeddings.
        height (`int`):
            Height of the latent image, i.e. the number of height embeddings.
        width (`int`):
            Width of the latent image, i.e. the number of width embeddings.
        embed_dim (`int`):
            Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings.
    """

    def __init__(
        self,
        num_embed: int,
        height: int,
        width: int,
        embed_dim: int,
    ):
        super().__init__()

        self.height = height
        self.width = width
        self.num_embed = num_embed
        self.embed_dim = embed_dim

        self.emb = nn.Embedding(self.num_embed, embed_dim)
        self.height_emb = nn.Embedding(self.height, embed_dim)
        self.width_emb = nn.Embedding(self.width, embed_dim)

    def forward(self, index):
        """Embed `index` (presumably `(batch, seq)` class ids — TODO confirm) and add row/column embeddings."""
        emb = self.emb(index)
        device = index.device

        # Row embeddings, broadcast across columns: (1, H, D) -> (1, H, 1, D)
        row_ids = torch.arange(self.height, device=device).view(1, self.height)
        rows = self.height_emb(row_ids).unsqueeze(2)

        # Column embeddings, broadcast across rows: (1, W, D) -> (1, 1, W, D)
        col_ids = torch.arange(self.width, device=device).view(1, self.width)
        cols = self.width_emb(col_ids).unsqueeze(1)

        # (1, H, W, D) flattened row-major to (1, H * W, D)
        pos_emb = (rows + cols).view(1, self.height * self.width, -1)

        # Use only as many positions as there are tokens along dim 1.
        return emb + pos_emb[:, : emb.shape[1], :]
class LabelEmbedding(nn.Module):
    """
    Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.

    When `dropout_prob > 0` the table gets one extra row (index `num_classes`); dropped labels are mapped to it.

    Args:
        num_classes (`int`): The number of classes.
        hidden_size (`int`): The size of the vector embeddings.
        dropout_prob (`float`): The probability of dropping a label.
    """

    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        use_cfg_embedding = dropout_prob > 0
        self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        """
        Drops labels to enable classifier-free guidance.

        Args:
            labels (`torch.LongTensor`): Class indices.
            force_drop_ids (optional): Tensor or NumPy array compared elementwise with 1; matching entries are
                dropped. When `None`, each label is dropped independently with probability `dropout_prob`.

        Returns:
            `torch.LongTensor`: `labels` with dropped entries replaced by `num_classes`.
        """
        if force_drop_ids is None:
            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        else:
            # Build the mask on `labels`' device so `torch.where` works when `labels` is on an accelerator and the
            # ids come from the CPU (e.g. NumPy). `as_tensor` also avoids the copy-construct warning that
            # `torch.tensor(tensor)` emits.
            drop_ids = torch.as_tensor(force_drop_ids == 1, device=labels.device)
        labels = torch.where(drop_ids, self.num_classes, labels)
        return labels

    def forward(self, labels: torch.LongTensor, force_drop_ids=None):
        """Embed `labels`, applying label dropout in training mode or whenever `force_drop_ids` is given."""
        use_dropout = self.dropout_prob > 0
        if (self.training and use_dropout) or (force_drop_ids is not None):
            labels = self.token_drop(labels, force_drop_ids)
        embeddings = self.embedding_table(labels)
        return embeddings
class TextImageProjection(nn.Module):
    """
    Expands an image embedding into `num_image_text_embeds` tokens and prepends them to projected text tokens.

    Args:
        text_embed_dim (`int`): Last-dimension width of `text_embeds`.
        image_embed_dim (`int`): Last-dimension width of `image_embeds`.
        cross_attention_dim (`int`): Width of every output token.
        num_image_text_embeds (`int`): Number of tokens produced per image embedding.
    """

    def __init__(
        self,
        text_embed_dim: int = 1024,
        image_embed_dim: int = 768,
        cross_attention_dim: int = 768,
        num_image_text_embeds: int = 10,
    ):
        super().__init__()
        self.num_image_text_embeds = num_image_text_embeds
        self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
        self.text_proj = nn.Linear(text_embed_dim, cross_attention_dim)

    def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
        """Return image tokens followed by projected text tokens, concatenated along dim 1."""
        batch = text_embeds.shape[0]
        image_tokens = self.image_embeds(image_embeds).reshape(batch, self.num_image_text_embeds, -1)
        text_tokens = self.text_proj(text_embeds)
        return torch.cat([image_tokens, text_tokens], dim=1)
class ImageProjection(nn.Module):
    """
    Expands an image embedding into `num_image_text_embeds` layer-normed tokens.

    Args:
        image_embed_dim (`int`): Last-dimension width of `image_embeds`.
        cross_attention_dim (`int`): Width of every output token.
        num_image_text_embeds (`int`): Number of tokens produced per image embedding.
    """

    def __init__(
        self,
        image_embed_dim: int = 768,
        cross_attention_dim: int = 768,
        num_image_text_embeds: int = 32,
    ):
        super().__init__()
        self.num_image_text_embeds = num_image_text_embeds
        self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
        self.norm = nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds: torch.Tensor):
        """Return `(batch, num_image_text_embeds, cross_attention_dim)` normalized tokens."""
        batch = image_embeds.shape[0]
        tokens = self.image_embeds(image_embeds).reshape(batch, self.num_image_text_embeds, -1)
        return self.norm(tokens)
class IPAdapterFullImageProjection(nn.Module):
    """
    IP-Adapter image projection: a `FeedForward` block (GELU, `mult=1`) followed by `LayerNorm`.

    Args:
        image_embed_dim (`int`): Input width.
        cross_attention_dim (`int`): Output width.
    """

    def __init__(self, image_embed_dim=1024, cross_attention_dim=1024):
        super().__init__()
        # Imported locally (presumably to avoid a circular import with `.attention` — confirm).
        from .attention import FeedForward

        self.ff = FeedForward(image_embed_dim, cross_attention_dim, mult=1, activation_fn="gelu")
        self.norm = nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds: torch.Tensor):
        """Project and normalize `image_embeds`."""
        projected = self.ff(image_embeds)
        return self.norm(projected)
class IPAdapterFaceIDImageProjection(nn.Module):
    """
    IP-Adapter FaceID projection: a `FeedForward` block expands the input into `num_tokens` tokens of width
    `cross_attention_dim`, which are then layer-normed.

    Args:
        image_embed_dim (`int`): Input width.
        cross_attention_dim (`int`): Width of each output token.
        mult (`int`): Hidden-width multiplier forwarded to `FeedForward`.
        num_tokens (`int`): Number of output tokens per input row.
    """

    def __init__(self, image_embed_dim=1024, cross_attention_dim=1024, mult=1, num_tokens=1):
        super().__init__()
        # Imported locally (presumably to avoid a circular import with `.attention` — confirm).
        from .attention import FeedForward

        self.num_tokens = num_tokens
        self.cross_attention_dim = cross_attention_dim
        self.ff = FeedForward(image_embed_dim, cross_attention_dim * num_tokens, mult=mult, activation_fn="gelu")
        self.norm = nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds: torch.Tensor):
        """Return `(-1, num_tokens, cross_attention_dim)` normalized tokens."""
        tokens = self.ff(image_embeds).reshape(-1, self.num_tokens, self.cross_attention_dim)
        return self.norm(tokens)
class CombinedTimestepLabelEmbeddings(nn.Module):
    """
    Sum of a 256-channel sinusoidal timestep embedding (passed through an MLP) and a class-label embedding.

    Args:
        num_classes (`int`): Number of classes for `LabelEmbedding`.
        embedding_dim (`int`): Output width.
        class_dropout_prob (`float`): Label dropout probability for classifier-free guidance.
    """

    def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
        super().__init__()
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
        self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob)

    def forward(self, timestep, class_labels, hidden_dtype=None):
        """Return the `(N, D)` conditioning vector for `timestep` and `class_labels`."""
        time_features = self.time_proj(timestep).to(dtype=hidden_dtype)
        timesteps_emb = self.timestep_embedder(time_features)  # (N, D)
        class_emb = self.class_embedder(class_labels)  # (N, D)
        return timesteps_emb + class_emb
class CombinedTimestepTextProjEmbeddings(nn.Module):
    """
    Sum of a timestep embedding and a projection of pooled text features.

    Args:
        embedding_dim (`int`): Output width.
        pooled_projection_dim (`int`): Width of `pooled_projection`.
    """

    def __init__(self, embedding_dim, pooled_projection_dim):
        super().__init__()
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
        self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")

    def forward(self, timestep, pooled_projection):
        """Return the conditioning vector; the timestep MLP runs in `pooled_projection`'s dtype."""
        time_features = self.time_proj(timestep).to(dtype=pooled_projection.dtype)
        timesteps_emb = self.timestep_embedder(time_features)  # (N, D)
        return timesteps_emb + self.text_embedder(pooled_projection)
class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
    """
    Sum of timestep, guidance-scale, and pooled-text embeddings.

    Timestep and guidance share one sinusoidal projection (`time_proj`) but have separate MLPs.

    Args:
        embedding_dim (`int`): Output width.
        pooled_projection_dim (`int`): Width of `pooled_projection`.
    """

    def __init__(self, embedding_dim, pooled_projection_dim):
        super().__init__()
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
        self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
        self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")

    def forward(self, timestep, guidance, pooled_projection):
        """Return the conditioning vector; both MLPs run in `pooled_projection`'s dtype."""
        dtype = pooled_projection.dtype

        timesteps_emb = self.timestep_embedder(self.time_proj(timestep).to(dtype=dtype))  # (N, D)
        guidance_emb = self.guidance_embedder(self.time_proj(guidance).to(dtype=dtype))  # (N, D)
        time_guidance_emb = timesteps_emb + guidance_emb

        return time_guidance_emb + self.text_embedder(pooled_projection)
class HunyuanDiTAttentionPool(nn.Module):
    # Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6
    """
    Attention pooling: a mean token is prepended to the sequence and attends over all tokens. Its output, projected
    by `c_proj`, is the pooled vector.

    Args:
        spacial_dim (`int`): Sequence length of the input (the positional table has `spacial_dim + 1` rows).
        embed_dim (`int`): Channel width of the input.
        num_heads (`int`): Number of attention heads.
        output_dim (`int`, *optional*): Output width; defaults to `embed_dim`.
    """

    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: Optional[int] = None):
        super().__init__()
        # One learned position per input token plus one for the prepended mean token.
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim**0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        """Pool `x` of shape `(N, L, C)` to `(N, output_dim)`."""
        tokens = x.permute(1, 0, 2)  # NLC -> LNC
        summary = tokens.mean(dim=0, keepdim=True)
        tokens = torch.cat([summary, tokens], dim=0)  # (L+1)NC
        tokens = tokens + self.positional_embedding[:, None, :].to(tokens.dtype)  # (L+1)NC

        # Only the mean token queries; keys/values are the full sequence.
        pooled, _ = F.multi_head_attention_forward(
            query=tokens[:1],
            key=tokens,
            value=tokens,
            embed_dim_to_check=tokens.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False,
        )
        return pooled.squeeze(0)
class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module):
    """
    HunyuanDiT conditioning: a timestep embedding plus a projection of extra conditions.

    The extra conditions are an attention-pooled summary of `encoder_hidden_states`. When
    `use_style_cond_and_image_meta_size` is True they also include six sinusoidally embedded image-meta-size values
    and a style embedding from a single-entry table.

    Args:
        embedding_dim (`int`): Output width.
        pooled_projection_dim (`int`): Output width of the attention pooler.
        seq_len (`int`): Sequence length the pooler's positional table is built for.
        cross_attention_dim (`int`): Channel width of `encoder_hidden_states`.
        use_style_cond_and_image_meta_size (`bool`): Include the image-meta-size and style conditions.
    """

    def __init__(
        self,
        embedding_dim,
        pooled_projection_dim=1024,
        seq_len=256,
        cross_attention_dim=2048,
        use_style_cond_and_image_meta_size=True,
    ):
        super().__init__()

        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

        self.size_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)

        self.pooler = HunyuanDiTAttentionPool(
            seq_len, cross_attention_dim, num_heads=8, output_dim=pooled_projection_dim
        )

        # A single-row learned style table, kept as a hook for future extension.
        self.use_style_cond_and_image_meta_size = use_style_cond_and_image_meta_size
        if use_style_cond_and_image_meta_size:
            self.style_embedder = nn.Embedding(1, embedding_dim)
            # 6 meta-size values x 256 sinusoidal channels, plus style and pooled text.
            extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim
        else:
            extra_in_dim = pooled_projection_dim

        self.extra_embedder = PixArtAlphaTextProjection(
            in_features=extra_in_dim,
            hidden_size=embedding_dim * 4,
            out_features=embedding_dim,
            act_fn="silu_fp32",
        )

    def forward(self, timestep, encoder_hidden_states, image_meta_size, style, hidden_dtype=None):
        """Return the `(N, embedding_dim)` conditioning vector."""
        time_features = self.time_proj(timestep).to(dtype=hidden_dtype)
        timesteps_emb = self.timestep_embedder(time_features)  # (N, embedding_dim)

        # Text condition: attention-pooled encoder states.
        pooled_projections = self.pooler(encoder_hidden_states)  # (N, pooled_projection_dim)

        if not self.use_style_cond_and_image_meta_size:
            extra_cond = pooled_projections
        else:
            # Image meta size: each scalar gets a 256-channel sinusoidal embedding; six per sample.
            size_emb = self.size_proj(image_meta_size.view(-1)).to(dtype=hidden_dtype)
            size_emb = size_emb.view(-1, 6 * 256)  # (N, 1536)
            style_embedding = self.style_embedder(style)  # (N, embedding_dim)
            extra_cond = torch.cat([pooled_projections, size_emb, style_embedding], dim=1)

        return timesteps_emb + self.extra_embedder(extra_cond)
class LuminaCombinedTimestepCaptionEmbedding(nn.Module):
    """
    Conditioning embedding for Lumina: timestep embedding plus an embedding of the mask-averaged caption features.

    Args:
        hidden_size (`int`, defaults to 4096): Width of the returned conditioning vector.
        cross_attention_dim (`int`, defaults to 2048): Feature width of the caption features.
        frequency_embedding_size (`int`, defaults to 256): Number of sinusoidal channels for the timestep.
    """

    def __init__(self, hidden_size=4096, cross_attention_dim=2048, frequency_embedding_size=256):
        super().__init__()
        self.time_proj = Timesteps(
            num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0
        )

        self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size)

        self.caption_embedder = nn.Sequential(
            nn.LayerNorm(cross_attention_dim),
            nn.Linear(
                cross_attention_dim,
                hidden_size,
                bias=True,
            ),
        )

    def forward(self, timestep, caption_feat, caption_mask):
        """
        Args:
            timestep: Timesteps passed to the sinusoidal projection.
            caption_feat: Caption features; pooled over dim 1 (presumably `(N, L, cross_attention_dim)`).
            caption_mask: Token mask over dim 1 of `caption_feat`; nonzero entries are included in the average.

        Returns:
            `time_embed + caption_embed`.
        """
        # timestep embedding:
        time_freq = self.time_proj(timestep)
        time_embed = self.timestep_embedder(time_freq.to(dtype=self.timestep_embedder.linear_1.weight.dtype))

        # caption condition embedding: masked mean over the token axis.
        caption_mask_float = caption_mask.float().unsqueeze(-1)
        num_tokens = caption_mask_float.sum(dim=1)
        # An all-padding caption would otherwise divide 0 / 0 and produce NaN; its (zero) numerator
        # is divided by 1 instead, yielding a zero pooled vector. Non-zero counts are untouched.
        num_tokens = num_tokens.masked_fill(num_tokens == 0, 1.0)
        caption_feats_pool = (caption_feat * caption_mask_float).sum(dim=1) / num_tokens
        caption_feats_pool = caption_feats_pool.to(caption_feat)
        caption_embed = self.caption_embedder(caption_feats_pool)

        conditioning = time_embed + caption_embed

        return conditioning
class TextTimeEmbedding(nn.Module):
    """
    Reduce a sequence of text-encoder states to one vector of width `time_embed_dim`.

    The states go through LayerNorm, `AttentionPooling` (which pools dim 1), a Linear projection and a final LayerNorm.

    Args:
        encoder_dim (`int`): Feature width of the incoming text states.
        time_embed_dim (`int`): Width of the returned embedding.
        num_heads (`int`, defaults to 64): Attention heads used by the pooling layer.
    """

    def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
        super().__init__()
        self.norm1 = nn.LayerNorm(encoder_dim)
        self.pool = AttentionPooling(num_heads, encoder_dim)
        self.proj = nn.Linear(encoder_dim, time_embed_dim)
        self.norm2 = nn.LayerNorm(time_embed_dim)

    def forward(self, hidden_states):
        # Apply each stage in order; every stage consumes the previous stage's output.
        for stage in (self.norm1, self.pool, self.proj, self.norm2):
            hidden_states = stage(hidden_states)
        return hidden_states
class TextImageTimeEmbedding(nn.Module):
    """
    Sum of a normalized text projection and a (non-normalized) image projection, both of width `time_embed_dim`.

    Args:
        text_embed_dim (`int`, defaults to 768): Width of the text embeddings.
        image_embed_dim (`int`, defaults to 768): Width of the image embeddings.
        time_embed_dim (`int`, defaults to 1536): Width of the returned embedding.
    """

    def __init__(self, text_embed_dim: int = 768, image_embed_dim: int = 768, time_embed_dim: int = 1536):
        super().__init__()
        # Registration order (text_proj, text_norm, image_proj) is kept so initialization and state_dict match.
        self.text_proj = nn.Linear(text_embed_dim, time_embed_dim)
        self.text_norm = nn.LayerNorm(time_embed_dim)
        self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)

    def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
        text_part = self.text_norm(self.text_proj(text_embeds))
        image_part = self.image_proj(image_embeds)
        return image_part + text_part
class ImageTimeEmbedding(nn.Module):
    """
    Project image embeddings to `time_embed_dim` and layer-normalize the result.

    Args:
        image_embed_dim (`int`, defaults to 768): Width of the image embeddings.
        time_embed_dim (`int`, defaults to 1536): Width of the returned embedding.
    """

    def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
        super().__init__()
        self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
        self.image_norm = nn.LayerNorm(time_embed_dim)

    def forward(self, image_embeds: torch.Tensor):
        return self.image_norm(self.image_proj(image_embeds))
class ImageHintTimeEmbedding(nn.Module):
    """
    Image-embedding projection plus a small convolutional encoder for a 3-channel hint image.

    The hint encoder is eight 3x3 convolutions (padding 1) separated by SiLU. Three of them use stride 2, so the
    spatial size is reduced by 8x. The output has 4 channels.

    Args:
        image_embed_dim (`int`, defaults to 768): Width of the image embeddings.
        time_embed_dim (`int`, defaults to 1536): Width of the projected image embedding.
    """

    def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
        super().__init__()
        self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
        self.image_norm = nn.LayerNorm(time_embed_dim)

        # (in_channels, out_channels, stride) for each 3x3 convolution, in order.
        conv_specs = (
            (3, 16, 1),
            (16, 16, 1),
            (16, 32, 2),
            (32, 32, 1),
            (32, 96, 2),
            (96, 96, 1),
            (96, 256, 2),
            (256, 4, 1),
        )
        layers = []
        for position, (c_in, c_out, step) in enumerate(conv_specs):
            # A SiLU sits between consecutive convolutions, keeping convs at even Sequential indices.
            if position > 0:
                layers.append(nn.SiLU())
            layers.append(nn.Conv2d(c_in, c_out, 3, padding=1, stride=step))
        self.input_hint_block = nn.Sequential(*layers)

    def forward(self, image_embeds: torch.Tensor, hint: torch.Tensor):
        projected = self.image_norm(self.image_proj(image_embeds))
        encoded_hint = self.input_hint_block(hint)
        return projected, encoded_hint
class AttentionPooling(nn.Module):
    # Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54
    """
    Pool a `(bs, length, width)` sequence into a `(bs, width)` vector with multi-head attention.

    The query is a single "class" token: the sequence mean plus a learned positional embedding. It attends over
    itself concatenated with the input sequence.

    Args:
        num_heads (`int`): Number of attention heads. Must divide `embed_dim`.
        embed_dim (`int`): Feature width of the input tokens.
        dtype (`torch.dtype`, *optional*): dtype of the q/k/v projection layers.

    Raises:
        ValueError: If `embed_dim` is not divisible by `num_heads`.
    """

    def __init__(self, num_heads, embed_dim, dtype=None):
        super().__init__()
        # With a non-divisible width, the head view in forward() can silently regroup channels into wrong heads.
        if embed_dim % num_heads != 0:
            raise ValueError(f"`embed_dim` ({embed_dim}) must be divisible by `num_heads` ({num_heads}).")
        self.dtype = dtype
        self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim**0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
        self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
        self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
        self.num_heads = num_heads
        self.dim_per_head = embed_dim // self.num_heads

    def forward(self, x):
        # x: (bs, length, width); the unpacking requires a 3-D input.
        bs, _, _ = x.size()

        def shape(x):
            # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
            x = x.view(bs, -1, self.num_heads, self.dim_per_head)
            # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
            x = x.transpose(1, 2)
            # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
            x = x.reshape(bs * self.num_heads, -1, self.dim_per_head)
            # (bs*n_heads, length, dim_per_head) --> (bs*n_heads, dim_per_head, length)
            x = x.transpose(1, 2)
            return x

        # Query token: mean over the sequence plus the learned positional embedding.
        class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype)
        x = torch.cat([class_token, x], dim=1)  # (bs, length+1, width)

        # (bs*n_heads, dim_per_head, class_token_length)
        q = shape(self.q_proj(class_token))
        # (bs*n_heads, dim_per_head, length+class_token_length)
        k = shape(self.k_proj(x))
        v = shape(self.v_proj(x))

        # (bs*n_heads, class_token_length, length+class_token_length):
        scale = 1 / math.sqrt(math.sqrt(self.dim_per_head))
        weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)

        # (bs*n_heads, dim_per_head, class_token_length)
        a = torch.einsum("bts,bcs->bct", weight, v)

        # (bs*n_heads, dim_per_head, 1) --> (bs, width, 1) --> (bs, 1, width)
        a = a.reshape(bs, -1, 1).transpose(1, 2)

        return a[:, 0, :]  # cls_token, (bs, width)
def get_fourier_embeds_from_boundingbox(embed_dim, box):
    """
    Sinusoidal Fourier features of bounding-box coordinates (used by the GLIGEN pipeline).

    Every coordinate is multiplied by the `embed_dim` frequencies `100 ** (i / embed_dim)`, `i = 0..embed_dim-1`, and
    both the sine and the cosine of each product are kept.

    Args:
        embed_dim: int
            Number of frequencies per coordinate (not the output width).
        box: a 3-D tensor [B x N x C] of box coordinates, with C == 4 (xyxy) for the GLIGEN pipeline

    Returns:
        [B x N x (embed_dim * 2 * C)] tensor of positional embeddings, ordered as
        (frequency, sin/cos, coordinate) along the last axis.
    """
    batch_size, num_boxes = box.shape[:2]
    num_coords = box.shape[-1]

    emb = 100 ** (torch.arange(embed_dim) / embed_dim)
    emb = emb[None, None, None].to(device=box.device, dtype=box.dtype)
    emb = emb * box.unsqueeze(-1)  # (B, N, C, embed_dim)

    emb = torch.stack((emb.sin(), emb.cos()), dim=-1)  # (B, N, C, embed_dim, 2)
    # The width is derived from the box itself rather than hard-coding 4 coordinates.
    emb = emb.permute(0, 1, 3, 4, 2).reshape(batch_size, num_boxes, embed_dim * 2 * num_coords)

    return emb
class GLIGENTextBoundingboxProjection(nn.Module):
    """
    GLIGEN grounding-token projection: fuses per-box features with Fourier embeddings of the box coordinates.

    Args:
        positive_len (`int`): Width of the per-box text (and image) feature vectors.
        out_dim (`int` or `tuple`): Output width of each grounding token. If a tuple is given, its first element is
            used to build the layers, while `self.out_dim` keeps the value exactly as passed.
        feature_type (`str`, defaults to `"text-only"`): `"text-only"` builds a single MLP (`linears`);
            `"text-image"` builds separate MLPs for text (`linears_text`) and image (`linears_image`) features.
        fourier_freqs (`int`, defaults to 8): Number of Fourier frequencies per box coordinate.

    Raises:
        ValueError: If `feature_type` is neither `"text-only"` nor `"text-image"`.
    """

    def __init__(self, positive_len, out_dim, feature_type="text-only", fourier_freqs=8):
        super().__init__()
        self.positive_len = positive_len
        self.out_dim = out_dim

        self.fourier_embedder_dim = fourier_freqs
        self.position_dim = fourier_freqs * 2 * 4  # 2: sin/cos, 4: xyxy

        if isinstance(out_dim, tuple):
            out_dim = out_dim[0]

        if feature_type == "text-only":
            self.linears = nn.Sequential(
                nn.Linear(self.positive_len + self.position_dim, 512),
                nn.SiLU(),
                nn.Linear(512, 512),
                nn.SiLU(),
                nn.Linear(512, out_dim),
            )
            self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))

        elif feature_type == "text-image":
            self.linears_text = nn.Sequential(
                nn.Linear(self.positive_len + self.position_dim, 512),
                nn.SiLU(),
                nn.Linear(512, 512),
                nn.SiLU(),
                nn.Linear(512, out_dim),
            )
            self.linears_image = nn.Sequential(
                nn.Linear(self.positive_len + self.position_dim, 512),
                nn.SiLU(),
                nn.Linear(512, 512),
                nn.SiLU(),
                nn.Linear(512, out_dim),
            )
            self.null_text_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
            self.null_image_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))

        else:
            # Previously no projection layers were built here and forward() failed later with an AttributeError.
            raise ValueError(f"Unknown feature_type {feature_type!r}; expected 'text-only' or 'text-image'.")

        self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))

    def forward(
        self,
        boxes,
        masks,
        positive_embeddings=None,
        phrases_masks=None,
        image_masks=None,
        phrases_embeddings=None,
        image_embeddings=None,
    ):
        """
        Args:
            boxes: Box coordinates passed to `get_fourier_embeds_from_boundingbox` (assumed `[B, N, 4]`).
            masks: Per-box validity mask. Where it is 0, the position embedding is replaced by the learned null
                position feature.
            positive_embeddings: Per-box text features for the text-only path. When given, `self.linears` is used.
            phrases_masks, image_masks: Per-box masks for the text-image path (used when `positive_embeddings` is None).
            phrases_embeddings, image_embeddings: Per-box text / image features for the text-image path.

        Returns:
            Grounding tokens. In the text-image path, text and image tokens are concatenated along dim 1.
        """
        masks = masks.unsqueeze(-1)

        # embedding position (it may include padding as placeholder)
        xyxy_embedding = get_fourier_embeds_from_boundingbox(self.fourier_embedder_dim, boxes)  # B*N*4 -> B*N*C

        # learnable null embedding
        xyxy_null = self.null_position_feature.view(1, 1, -1)

        # replace padding with learnable null embedding
        xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null

        # PositionNet with text-only information
        if positive_embeddings is not None:
            # learnable null embedding
            positive_null = self.null_positive_feature.view(1, 1, -1)

            # replace padding with learnable null embedding
            positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null

            objs = self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1))

        # PositionNet with text and image information
        else:
            phrases_masks = phrases_masks.unsqueeze(-1)
            image_masks = image_masks.unsqueeze(-1)

            # learnable null embedding
            text_null = self.null_text_feature.view(1, 1, -1)
            image_null = self.null_image_feature.view(1, 1, -1)

            # replace padding with learnable null embedding
            phrases_embeddings = phrases_embeddings * phrases_masks + (1 - phrases_masks) * text_null
            image_embeddings = image_embeddings * image_masks + (1 - image_masks) * image_null

            objs_text = self.linears_text(torch.cat([phrases_embeddings, xyxy_embedding], dim=-1))
            objs_image = self.linears_image(torch.cat([image_embeddings, xyxy_embedding], dim=-1))
            objs = torch.cat([objs_text, objs_image], dim=1)

        return objs
class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
    """
    For PixArt-Alpha.

    Reference:
    https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29

    Embeds the timestep. When `use_additional_conditions` is True, embeddings of the resolution and the aspect
    ratio are concatenated (per sample) and added to the timestep embedding.

    Args:
        embedding_dim (`int`): Width of the timestep embedding.
        size_emb_dim (`int`): Width of each size-condition embedding.
        use_additional_conditions (`bool`, defaults to False): Whether to build and use the size conditions.
    """

    def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False):
        super().__init__()

        self.outdim = size_emb_dim
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

        self.use_additional_conditions = use_additional_conditions
        if not use_additional_conditions:
            return

        # One shared sinusoidal projection for both size conditions, then a dedicated embedder for each.
        self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
        self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)

    def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
        time_emb = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_dtype))  # (N, D)
        if not self.use_additional_conditions:
            return time_emb

        def embed_condition(values, embedder):
            # Project every scalar sinusoidally, embed it, then regroup the results per sample.
            projected = self.additional_condition_proj(values.flatten()).to(hidden_dtype)
            return embedder(projected).reshape(batch_size, -1)

        size_emb = torch.cat(
            [
                embed_condition(resolution, self.resolution_embedder),
                embed_condition(aspect_ratio, self.aspect_ratio_embedder),
            ],
            dim=1,
        )
        return time_emb + size_emb
class PixArtAlphaTextProjection(nn.Module):
    """
    Projects caption embeddings. Also handles dropout for classifier-free guidance.

    Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py

    Two-layer MLP: Linear(in_features -> hidden_size), activation, Linear(hidden_size -> out_features).

    Args:
        in_features (`int`): Input width.
        hidden_size (`int`): Hidden width (also the output width when `out_features` is None).
        out_features (`int`, *optional*): Output width.
        act_fn (`str`, defaults to `"gelu_tanh"`): One of `"gelu_tanh"`, `"silu"`, `"silu_fp32"`.

    Raises:
        ValueError: If `act_fn` is not one of the supported names.
    """

    def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"):
        super().__init__()
        out_features = hidden_size if out_features is None else out_features

        # Factories are lambdas so each activation class is only looked up when it is selected.
        activation_factories = {
            "gelu_tanh": lambda: nn.GELU(approximate="tanh"),
            "silu": lambda: nn.SiLU(),
            "silu_fp32": lambda: FP32SiLU(),
        }

        self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
        for name, make_activation in activation_factories.items():
            if act_fn == name:
                self.act_1 = make_activation()
                break
        else:
            raise ValueError(f"Unknown activation function: {act_fn}")
        self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True)

    def forward(self, caption):
        return self.linear_2(self.act_1(self.linear_1(caption)))
class IPAdapterPlusImageProjectionBlock(nn.Module):
    """
    One layer of the IP-Adapter Plus resampler.

    The normalized latents attend over the normalized image features concatenated with themselves, and `residual`
    is added. That is followed by a pre-norm feed-forward with a skip connection.

    Args:
        embed_dims (`int`, defaults to 768): Feature width of the latents and image features.
        dim_head (`int`, defaults to 64): Per-head width of the attention.
        heads (`int`, defaults to 16): Number of attention heads.
        ffn_ratio (`float`, defaults to 4): Hidden-width multiplier of the feed-forward.
    """

    def __init__(
        self,
        embed_dims: int = 768,
        dim_head: int = 64,
        heads: int = 16,
        ffn_ratio: float = 4,
    ) -> None:
        super().__init__()
        from .attention import FeedForward

        self.ln0 = nn.LayerNorm(embed_dims)
        self.ln1 = nn.LayerNorm(embed_dims)
        self.attn = Attention(query_dim=embed_dims, dim_head=dim_head, heads=heads, out_bias=False)

        ff_norm = nn.LayerNorm(embed_dims)
        ff_body = FeedForward(embed_dims, embed_dims, activation_fn="gelu", mult=ffn_ratio, bias=False)
        self.ff = nn.Sequential(ff_norm, ff_body)

    def forward(self, x, latents, residual):
        context = self.ln0(x)
        queries = self.ln1(latents)
        # Keys/values cover both the image features and the latent queries themselves.
        context = torch.cat([context, queries], dim=-2)
        hidden = self.attn(queries, context) + residual
        return hidden + self.ff(hidden)
class IPAdapterPlusImageProjection(nn.Module):
    """Resampler of IP-Adapter Plus.

    Learned query tokens repeatedly attend to the projected image features, then are projected to `output_dims`.

    Args:
        embed_dims (int): The feature dimension of the input image features. Defaults to 768.
        output_dims (int): The number of output channels, which should match `unet.config.cross_attention_dim`.
            Defaults to 1024.
        hidden_dims (int): The number of hidden channels. Defaults to 1280.
        depth (int): The number of blocks. Defaults to 4.
        dim_head (int): The number of head channels. Defaults to 64.
        heads (int): Parallel attention heads. Defaults to 16.
        num_queries (int): The number of learned query tokens, i.e. the output sequence length. Defaults to 8.
        ffn_ratio (float): The expansion ratio of feedforward network hidden layer channels. Defaults to 4.
    """

    def __init__(
        self,
        embed_dims: int = 768,
        output_dims: int = 1024,
        hidden_dims: int = 1280,
        depth: int = 4,
        dim_head: int = 64,
        heads: int = 16,
        num_queries: int = 8,
        ffn_ratio: float = 4,
    ) -> None:
        super().__init__()
        self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dims) / hidden_dims**0.5)

        self.proj_in = nn.Linear(embed_dims, hidden_dims)

        self.proj_out = nn.Linear(hidden_dims, output_dims)
        self.norm_out = nn.LayerNorm(output_dims)

        self.layers = nn.ModuleList(
            [IPAdapterPlusImageProjectionBlock(hidden_dims, dim_head, heads, ffn_ratio) for _ in range(depth)]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            x (torch.Tensor): Input image features; dim 0 is the batch, last dim is `embed_dims`.

        Returns:
            torch.Tensor: Resampled tokens of shape `(batch, num_queries, output_dims)`.
        """
        # One copy of the learned queries per batch element.
        latents = self.latents.repeat(x.size(0), 1, 1)

        x = self.proj_in(x)

        for block in self.layers:
            residual = latents
            latents = block(x, latents, residual)

        latents = self.proj_out(latents)
        return self.norm_out(latents)
class IPAdapterFaceIDPlusImageProjection(nn.Module):
    """FacePerceiverResampler of IP-Adapter Plus.

    ID embeddings are projected into `num_tokens` tokens. Those tokens then attend to `self.clip_embeds`, which must
    be assigned on the module before `forward` is called.

    Args:
        embed_dims (int): The feature dimension of the tokens. Defaults to 768.
        output_dims (int): The number of output channels, which should match `unet.config.cross_attention_dim`.
            Defaults to 768.
        hidden_dims (int): The number of channels of `clip_embeds`. Defaults to 1280.
        id_embeddings_dim (int): The dimension of the ID embeddings. Defaults to 512.
        depth (int): The number of blocks. Defaults to 4.
        dim_head (int): The number of head channels. Defaults to 64.
        heads (int): Parallel attention heads. Defaults to 16.
        num_tokens (int): Number of tokens produced from each ID embedding. Defaults to 4.
        num_queries (int): Accepted for configuration compatibility; not used by this layer. Defaults to 8.
        ffn_ratio (float): The expansion ratio of feedforward network hidden layer channels. Defaults to 4.
        ffproj_ratio (int): The expansion ratio of the feedforward network that projects the ID embeddings.
            Defaults to 2.
    """

    def __init__(
        self,
        embed_dims: int = 768,
        output_dims: int = 768,
        hidden_dims: int = 1280,
        id_embeddings_dim: int = 512,
        depth: int = 4,
        dim_head: int = 64,
        heads: int = 16,
        num_tokens: int = 4,
        num_queries: int = 8,
        ffn_ratio: float = 4,
        ffproj_ratio: int = 2,
    ) -> None:
        super().__init__()
        from .attention import FeedForward

        self.num_tokens = num_tokens
        self.embed_dim = embed_dims
        # Set externally before forward(); see the check at the top of forward().
        self.clip_embeds = None
        self.shortcut = False
        self.shortcut_scale = 1.0

        self.proj = FeedForward(id_embeddings_dim, embed_dims * num_tokens, activation_fn="gelu", mult=ffproj_ratio)
        self.norm = nn.LayerNorm(embed_dims)

        self.proj_in = nn.Linear(hidden_dims, embed_dims)

        self.proj_out = nn.Linear(embed_dims, output_dims)
        self.norm_out = nn.LayerNorm(output_dims)

        self.layers = nn.ModuleList(
            [IPAdapterPlusImageProjectionBlock(embed_dims, dim_head, heads, ffn_ratio) for _ in range(depth)]
        )

    def forward(self, id_embeds: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            id_embeds (torch.Tensor): Input Tensor (ID embeds).

        Returns:
            torch.Tensor: Output Tensor of shape `(-1, num_tokens, output_dims)`. When `self.shortcut` is True, the
            normalized ID tokens are added (this requires `output_dims == embed_dims`).

        Raises:
            ValueError: If `self.clip_embeds` has not been set.
        """
        if self.clip_embeds is None:
            raise ValueError(
                "`clip_embeds` must be set on IPAdapterFaceIDPlusImageProjection before calling forward()."
            )

        id_embeds = id_embeds.to(self.clip_embeds.dtype)
        id_embeds = self.proj(id_embeds)
        id_embeds = id_embeds.reshape(-1, self.num_tokens, self.embed_dim)
        id_embeds = self.norm(id_embeds)
        latents = id_embeds

        clip_embeds = self.proj_in(self.clip_embeds)
        # Merge the leading axes; indexing shape[2] and shape[3] requires clip_embeds to be at least 4-D.
        x = clip_embeds.reshape(-1, clip_embeds.shape[2], clip_embeds.shape[3])

        for block in self.layers:
            residual = latents
            latents = block(x, latents, residual)

        latents = self.proj_out(latents)
        out = self.norm_out(latents)
        if self.shortcut:
            out = id_embeds + self.shortcut_scale * out
        return out
class MultiIPAdapterImageProjection(nn.Module):
    """
    Hold one image-projection layer per IP-Adapter and apply each to its own image embeddings.

    Args:
        IPAdapterImageProjectionLayers (`list` or `tuple` of `nn.Module`): One projection layer per IP-Adapter.
    """

    def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]):
        super().__init__()
        self.image_projection_layers = nn.ModuleList(IPAdapterImageProjectionLayers)

    def forward(self, image_embeds: List[torch.Tensor]):
        # Accepted inputs for `image_embeds`:
        # 1. a tensor (deprecated) with shape [batch_size, embed_dim] or [batch_size, sequence_length, embed_dim]
        # 2. a list of `n` tensors, where `n` is the number of IP-Adapters; each tensor has shape
        #    [batch_size, num_images, embed_dim] or [batch_size, num_images, sequence_length, embed_dim]
        if not isinstance(image_embeds, list):
            deprecation_message = (
                "You have passed a tensor as `image_embeds`.This is deprecated and will be removed in a future release."
                " Please make sure to update your script to pass `image_embeds` as a list of tensors to suppress this warning."
            )
            deprecate("image_embeds not a list", "1.0.0", deprecation_message, standard_warn=False)
            # Treat the legacy tensor as a single adapter with one image per sample.
            image_embeds = [image_embeds.unsqueeze(1)]

        if len(image_embeds) != len(self.image_projection_layers):
            raise ValueError(
                f"image_embeds must have the same length as image_projection_layers, got {len(image_embeds)} and {len(self.image_projection_layers)}"
            )

        def project(embeds, layer):
            # Fold the per-adapter image axis into the batch axis, project, then unfold it again.
            batch_size, num_images = embeds.shape[0], embeds.shape[1]
            flat = embeds.reshape((batch_size * num_images,) + embeds.shape[2:])
            projected = layer(flat)
            return projected.reshape((batch_size, num_images) + projected.shape[1:])

        return [project(embeds, layer) for embeds, layer in zip(image_embeds, self.image_projection_layers)]