# Based on the Flux code because of the weird Hunyuan Video code license.
import torch
import comfy.ldm.flux.layers
import comfy.ldm.modules.diffusionmodules.mmdit
from comfy.ldm.modules.attention import optimized_attention

from dataclasses import dataclass
from einops import repeat

from torch import Tensor, nn

from comfy.ldm.flux.layers import (
    DoubleStreamBlock,
    EmbedND,
    LastLayer,
    MLPEmbedder,
    SingleStreamBlock,
    timestep_embedding
)

import comfy.ldm.common_dit
@dataclass
class HunyuanVideoParams:
    in_channels: int
    out_channels: int
    vec_in_dim: int
    context_in_dim: int
    hidden_size: int
    mlp_ratio: float
    num_heads: int
    depth: int
    depth_single_blocks: int
    axes_dim: list
    theta: int
    patch_size: list
    qkv_bias: bool
    guidance_embed: bool
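# Text token refiner stack: SelfAttentionRef only holds the qkv/proj projections;
# the attention itself is computed in TokenRefinerBlock.forward via optimized_attention.
# TokenRefiner (below) embeds the raw text features and runs them through a small stack
# of these blocks, conditioned on the timestep and the mean-pooled text features.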
class SelfAttentionRef(nn.Module):
    def __init__(self, dim: int, qkv_bias: bool = False, dtype=None, device=None, operations=None):
        super().__init__()
        self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
        self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)


class TokenRefinerBlock(nn.Module):
    def __init__(
        self,
        hidden_size,
        heads,
        dtype=None,
        device=None,
        operations=None
    ):
        super().__init__()
        self.heads = heads
        mlp_hidden_dim = hidden_size * 4

        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device),
        )

        self.norm1 = operations.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)
        self.self_attn = SelfAttentionRef(hidden_size, True, dtype=dtype, device=device, operations=operations)

        self.norm2 = operations.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)

        self.mlp = nn.Sequential(
            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
        )
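    # forward(): adaLN_modulation maps the conditioning vector c to two gates
    # (mod1, mod2) that scale the self-attention and MLP residual branches.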
    def forward(self, x, c, mask):
        mod1, mod2 = self.adaLN_modulation(c).chunk(2, dim=1)

        norm_x = self.norm1(x)
        qkv = self.self_attn.qkv(norm_x)
        q, k, v = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, self.heads, -1).permute(2, 0, 3, 1, 4)
        attn = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True)

        x = x + self.self_attn.proj(attn) * mod1.unsqueeze(1)
        x = x + self.mlp(self.norm2(x)) * mod2.unsqueeze(1)
        return x


class IndividualTokenRefiner(nn.Module):
    def __init__(
        self,
        hidden_size,
        heads,
        num_blocks,
        dtype=None,
        device=None,
        operations=None
    ):
        super().__init__()
        self.blocks = nn.ModuleList(
            [
                TokenRefinerBlock(
                    hidden_size=hidden_size,
                    heads=heads,
                    dtype=dtype,
                    device=device,
                    operations=operations
                )
                for _ in range(num_blocks)
            ]
        )
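    # forward(): the per-token mask is broadcast to a full (B, 1, L, L) attention mask;
    # adding it to its transpose makes the mask symmetric, so any query/key pair that
    # involves a masked (padded) token is masked out.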
    def forward(self, x, c, mask):
        m = None
        if mask is not None:
            m = mask.view(mask.shape[0], 1, 1, mask.shape[1]).repeat(1, 1, mask.shape[1], 1)
            m = m + m.transpose(2, 3)

        for block in self.blocks:
            x = block(x, c, m)
        return x


class TokenRefiner(nn.Module):
    def __init__(
        self,
        text_dim,
        hidden_size,
        heads,
        num_blocks,
        dtype=None,
        device=None,
        operations=None
    ):
        super().__init__()
        self.input_embedder = operations.Linear(text_dim, hidden_size, bias=True, dtype=dtype, device=device)
        self.t_embedder = MLPEmbedder(256, hidden_size, dtype=dtype, device=device, operations=operations)
        self.c_embedder = MLPEmbedder(text_dim, hidden_size, dtype=dtype, device=device, operations=operations)
        self.individual_token_refiner = IndividualTokenRefiner(hidden_size, heads, num_blocks, dtype=dtype, device=device, operations=operations)

    def forward(
        self,
        x,
        timesteps,
        mask,
    ):
        t = self.t_embedder(timestep_embedding(timesteps, 256, time_factor=1.0).to(x.dtype))
        # m = mask.float().unsqueeze(-1)
        # c = (x.float() * m).sum(dim=1) / m.sum(dim=1)  #TODO: the following works when the x.shape is the same length as the tokens but might break otherwise
        c = x.sum(dim=1) / x.shape[1]

        c = t + self.c_embedder(c.to(x.dtype))
        x = self.input_embedder(x)
        x = self.individual_token_refiner(x, c, mask)
        return x
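# Main video DiT: a Flux-style MMDiT with a 3D (t, h, w) patch embedding, rotary
# position embeddings over the three axes, a stack of double-stream (img/txt) blocks
# followed by single-stream blocks, and a final projection back to latent patches.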
class HunyuanVideo(nn.Module):
    """
    Transformer model for flow matching on sequences.
    """

    def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
        super().__init__()
        self.dtype = dtype
        params = HunyuanVideoParams(**kwargs)
        self.params = params
        self.patch_size = params.patch_size
        self.in_channels = params.in_channels
        self.out_channels = params.out_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(
                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
            )
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)

        self.img_in = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(None, self.patch_size, self.in_channels, self.hidden_size, conv3d=True, dtype=dtype, device=device, operations=operations)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
        self.guidance_in = (
            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
        )

        self.txt_in = TokenRefiner(params.context_in_dim, self.hidden_size, self.num_heads, 2, dtype=dtype, device=device, operations=operations)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_bias=params.qkv_bias,
                    flipped_img_txt=True,
                    dtype=dtype, device=device, operations=operations
                )
                for _ in range(params.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
                for _ in range(params.depth_single_blocks)
            ]
        )

        if final_layer:
            self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations)
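    # forward_orig(): operates on flattened token sequences. Video latents are
    # patch-embedded, text tokens are refined by TokenRefiner, and both streams run
    # through the double-stream blocks before being concatenated as [img, txt] for the
    # single-stream blocks. A non-float text mask (if given) is converted to a
    # large-negative additive bias so padded text tokens are ignored by attention.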
    def forward_orig(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        txt_mask: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor = None,
        control=None,
        transformer_options={},
    ) -> Tensor:
        patches_replace = transformer_options.get("patches_replace", {})

        initial_shape = list(img.shape)
        # running on sequences img
        img = self.img_in(img)
        vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))

        vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])

        if self.params.guidance_embed:
            if guidance is not None:
                vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))

        if txt_mask is not None and not torch.is_floating_point(txt_mask):
            txt_mask = (txt_mask - 1).to(img.dtype) * torch.finfo(img.dtype).max

        txt = self.txt_in(txt, timesteps, txt_mask)

        ids = torch.cat((img_ids, txt_ids), dim=1)
        pe = self.pe_embedder(ids)

        img_len = img.shape[1]
        if txt_mask is not None:
            attn_mask_len = img_len + txt.shape[1]
            attn_mask = torch.zeros((1, 1, attn_mask_len), dtype=img.dtype, device=img.device)
            attn_mask[:, 0, img_len:] = txt_mask
        else:
            attn_mask = None

        blocks_replace = patches_replace.get("dit", {})
        for i, block in enumerate(self.double_blocks):
            if ("double_block", i) in blocks_replace:
                def block_wrap(args):
                    out = {}
                    out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"])
                    return out

                out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap})
                txt = out["txt"]
                img = out["img"]
            else:
                img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask)

            if control is not None:  # Controlnet
                control_i = control.get("input")
                if i < len(control_i):
                    add = control_i[i]
                    if add is not None:
                        img += add

        img = torch.cat((img, txt), 1)

        for i, block in enumerate(self.single_blocks):
            if ("single_block", i) in blocks_replace:
                def block_wrap(args):
                    out = {}
                    out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"])
                    return out

                out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap})
                img = out["img"]
            else:
                img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)

            if control is not None:  # Controlnet
                control_o = control.get("output")
                if i < len(control_o):
                    add = control_o[i]
                    if add is not None:
                        img[:, : img_len] += add

        img = img[:, : img_len]

        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)

        # unpatchify: fold the projected patch tokens back into the original (B, C, T, H, W) latent shape
        shape = initial_shape[-3:]
        for i in range(len(shape)):
            shape[i] = shape[i] // self.patch_size[i]
        img = img.reshape([img.shape[0]] + shape + [self.out_channels] + self.patch_size)
        img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
        img = img.reshape(initial_shape)
        return img
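    # forward(): builds a (t, h, w) coordinate grid for the rotary position embedding
    # (img_ids), zero ids for the text tokens, and then calls forward_orig on the
    # flattened sequences.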
    def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, control=None, transformer_options={}, **kwargs):
        bs, c, t, h, w = x.shape
        patch_size = self.patch_size
        t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
        h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
        w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
        img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
        img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
        img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
        img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
        img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
        txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
        out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, control, transformer_options)
        return out
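# Rough usage sketch (not part of ComfyUI's loading path): in practice this model is
# constructed by comfy.model_detection / comfy.model_base with `operations` set to a
# comfy.ops class and hyperparameters read from the checkpoint. The values below are
# illustrative assumptions only.
#
#   model = HunyuanVideo(
#       in_channels=16, out_channels=16, vec_in_dim=768, context_in_dim=4096,
#       hidden_size=3072, mlp_ratio=4.0, num_heads=24, depth=20,
#       depth_single_blocks=40, axes_dim=[16, 56, 56], theta=256,
#       patch_size=[1, 2, 2], qkv_bias=True, guidance_embed=True,
#       dtype=torch.bfloat16, device="cuda", operations=some_comfy_ops,  # placeholder ops class
#   )
#   out = model(x, timestep, context, y, guidance=guidance, attention_mask=mask)
#   # x: (B, C, T, H, W) video latent; `out` has the same shape.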