Spaces:
Build error
Build error
import math | |
from typing import List, Optional, Tuple | |
import torch | |
import torch.nn as nn | |
from torch import Tensor | |
from comfy.ldm.modules.diffusionmodules.mmdit import DismantledBlock, PatchEmbed, VectorEmbedder, TimestepEmbedder, get_2d_sincos_pos_embed_torch | |
class ControlNetEmbedder(nn.Module):
    """ControlNet side branch built from MMDiT ``DismantledBlock`` layers.

    The latent ``x`` and the control ``hint`` are patch-embedded by two
    separate ``PatchEmbed`` layers and summed. The result is run through
    ``num_layers`` transformer blocks conditioned on a timestep embedding
    (plus an optional pooled-vector embedding). Every block output is
    projected by its own ``Linear`` layer and repeated so the returned tuple
    holds at least ``main_model_double`` tensors.

    Args:
        img_size: Forwarded to both ``PatchEmbed`` layers.
        patch_size: Patch size of both ``PatchEmbed`` layers; also used to
            derive the grid for the sin/cos positional embedding.
        in_chans: Input channel count shared by ``x`` and ``hint`` embedders.
        attention_head_dim: Width of each attention head.
        num_attention_heads: Heads per block; ``hidden_size`` is
            ``num_attention_heads * attention_head_dim``.
        adm_in_channels: Width of the pooled conditioning vector ``y``.
        num_layers: Number of transformer blocks (and projection layers).
            Must be >= 1, because ``forward`` divides by it.
        main_model_double: Minimum number of outputs ``forward`` returns;
            presumably the number of double blocks in the main model
            (TODO confirm against the caller).
        double_y_emb: When True, ``y`` is embedded twice
            (adm -> hidden -> hidden), no sin/cos positional embedding is
            added to ``x``, and every block reads the same input instead of
            the previous block's output.
        device: Forwarded to all submodules.
        dtype: Forwarded to all submodules; also stored as ``self.dtype``.
        pos_embed_max_size: Only controls ``strict_img_size`` of
            ``x_embedder`` (strict when None).
        operations: Namespace providing ``Linear`` (and consumed by the
            embedders/blocks). Required in practice despite the None default.
    """

    def __init__(
        self,
        img_size: int,
        patch_size: int,
        in_chans: int,
        attention_head_dim: int,
        num_attention_heads: int,
        adm_in_channels: int,
        num_layers: int,
        main_model_double: int,
        double_y_emb: bool,
        device: torch.device,
        dtype: torch.dtype,
        pos_embed_max_size: Optional[int] = None,
        operations = None,
    ):
        super().__init__()
        self.main_model_double = main_model_double
        self.dtype = dtype
        self.hidden_size = num_attention_heads * attention_head_dim
        self.patch_size = patch_size

        # Embeds the latent input x.
        self.x_embedder = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=self.hidden_size,
            strict_img_size=pos_embed_max_size is None,
            device=device,
            dtype=dtype,
            operations=operations,
        )
        self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype, device=device, operations=operations)

        self.double_y_emb = double_y_emb
        if self.double_y_emb:
            # Two-stage pooled embedding: adm_in_channels -> hidden -> hidden.
            self.orig_y_embedder = VectorEmbedder(
                adm_in_channels, self.hidden_size, dtype, device, operations=operations
            )
            self.y_embedder = VectorEmbedder(
                self.hidden_size, self.hidden_size, dtype, device, operations=operations
            )
        else:
            self.y_embedder = VectorEmbedder(
                adm_in_channels, self.hidden_size, dtype, device, operations=operations
            )

        self.transformer_blocks = nn.ModuleList([
            DismantledBlock(
                hidden_size=self.hidden_size, num_heads=num_attention_heads, qkv_bias=True,
                dtype=dtype, device=device, operations=operations
            )
            for _ in range(num_layers)
        ])

        # TODO double check this logic when 8b: upstream derived this from
        # pooled_projection_dim != text_embedder in_features; hard-coded for now.
        self.use_y_embedder = True

        # One output projection per transformer block.
        self.controlnet_blocks = nn.ModuleList([
            operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
            for _ in range(num_layers)
        ])

        # Embeds the control hint; never strict about input size.
        self.pos_embed_input = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=self.hidden_size,
            strict_img_size=False,
            device=device,
            dtype=dtype,
            operations=operations,
        )

    def forward(
        self,
        x: torch.Tensor,
        timesteps: torch.Tensor,
        y: Optional[torch.Tensor] = None,
        context: Optional[torch.Tensor] = None,
        hint = None,
    ) -> dict:
        """Run the control branch and return the projected per-layer outputs.

        Args:
            x: Latent input for ``x_embedder``; its last two dims are read as
                (height, width) to size the positional-embedding grid.
            timesteps: Passed to ``t_embedder``.
            y: Optional pooled conditioning vector; skipped when None.
            context: Accepted for interface compatibility; not used here.
            hint: Control input for ``pos_embed_input``. Required.

        Returns:
            ``{"output": outs}``, where ``outs`` is a tuple of
            ``num_layers * ceil(main_model_double / num_layers)`` tensors:
            each layer's projected output, repeated consecutively.

        Raises:
            ValueError: If ``hint`` is None.
        """
        if hint is None:
            raise ValueError("ControlNetEmbedder.forward requires a `hint` tensor")

        x_shape = list(x.shape)
        x = self.x_embedder(x)
        if not self.double_y_emb:
            # Patch-grid size. (n + 1) // p equals ceil(n / p) only for p == 2;
            # NOTE(review): confirm it matches x_embedder's grid for other p.
            h = (x_shape[-2] + 1) // self.patch_size
            w = (x_shape[-1] + 1) // self.patch_size
            # Deliberately in-place: `+=` keeps x's dtype even if the positional
            # embedding is produced in a different dtype.
            x += get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=x.device)

        c = self.t_embedder(timesteps, dtype=x.dtype)
        if y is not None and self.y_embedder is not None:
            if self.double_y_emb:
                y = self.orig_y_embedder(y)
            c = c + self.y_embedder(y)

        x = x + self.pos_embed_input(hint)

        # Each layer's output is repeated so the tuple covers main_model_double slots.
        repeat = math.ceil(self.main_model_double / len(self.transformer_blocks))
        block_out = ()
        for block, projection in zip(self.transformer_blocks, self.controlnet_blocks):
            out = block(x, c)
            if not self.double_y_emb:
                # Chain blocks; with double_y_emb every block sees the same x.
                x = out
            block_out += (projection(out),) * repeat

        return {"output": block_out}