File size: 4,204 Bytes
be40477
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import math
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
from torch import Tensor

from comfy.ldm.modules.diffusionmodules.mmdit import DismantledBlock, PatchEmbed, VectorEmbedder, TimestepEmbedder, get_2d_sincos_pos_embed_torch


class ControlNetEmbedder(nn.Module):
    """ControlNet branch producing one projected feature tensor per control block.

    The input ``x`` and the conditioning ``hint`` are each passed through their
    own ``PatchEmbed``; their sum is fed to a stack of ``DismantledBlock``
    layers conditioned on the timestep embedding (plus an optional embedding of
    ``y``). Every block output is projected by a dedicated linear layer and
    the projections are returned under the ``"output"`` key.
    """

    def __init__(
        self,
        img_size: int,
        patch_size: int,
        in_chans: int,
        attention_head_dim: int,
        num_attention_heads: int,
        adm_in_channels: int,
        num_layers: int,
        main_model_double: int,
        double_y_emb: bool,
        device: torch.device,
        dtype: torch.dtype,
        pos_embed_max_size: Optional[int] = None,
        operations = None,
    ):
        """Build the embedders, transformer blocks and output projections.

        Args:
            img_size: Forwarded to both ``PatchEmbed`` instances.
            patch_size: Forwarded to both ``PatchEmbed`` instances; also used
                in ``forward`` to size the sin-cos positional embedding grid.
            in_chans: Forwarded to both ``PatchEmbed`` instances.
            attention_head_dim: Multiplied by ``num_attention_heads`` to get
                ``hidden_size``.
            num_attention_heads: Head count passed to every ``DismantledBlock``.
            adm_in_channels: Input width of the first embedder applied to ``y``.
            num_layers: Number of transformer blocks, and of matching linear
                output projections.
            main_model_double: Target count used in ``forward`` to decide how
                many times each block output is repeated (presumably the
                number of double blocks in the main model -- confirm).
            double_y_emb: When true, ``y`` goes through ``orig_y_embedder``
                and then ``y_embedder``; ``forward`` also skips the sin-cos
                positional embedding and feeds every block the same input
                instead of chaining them.
            device: Forwarded to the submodules.
            dtype: Forwarded to the submodules and stored as ``self.dtype``.
            pos_embed_max_size: Only its presence is used: ``x_embedder`` gets
                ``strict_img_size=True`` exactly when this is ``None``.
            operations: Object providing layer constructors (``Linear`` is
                used directly here) and forwarded to the submodules. The
                ``None`` default cannot build the output projections when
                ``num_layers > 0``, so callers are expected to supply it.
        """
        super().__init__()
        self.main_model_double = main_model_double
        self.dtype = dtype
        self.hidden_size = num_attention_heads * attention_head_dim
        self.patch_size = patch_size
        self.x_embedder = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=self.hidden_size,
            strict_img_size=pos_embed_max_size is None,
            device=device,
            dtype=dtype,
            operations=operations,
        )

        self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype, device=device, operations=operations)

        self.double_y_emb = double_y_emb
        if self.double_y_emb:
            # Two-stage embedding: adm_in_channels -> hidden -> hidden.
            self.orig_y_embedder = VectorEmbedder(
                adm_in_channels, self.hidden_size, dtype, device, operations=operations
            )
            self.y_embedder = VectorEmbedder(
                self.hidden_size, self.hidden_size, dtype, device, operations=operations
            )
        else:
            self.y_embedder = VectorEmbedder(
                adm_in_channels, self.hidden_size, dtype, device, operations=operations
            )

        self.transformer_blocks = nn.ModuleList(
            DismantledBlock(
                hidden_size=self.hidden_size, num_heads=num_attention_heads, qkv_bias=True,
                dtype=dtype, device=device, operations=operations
            )
            for _ in range(num_layers)
        )

        # self.use_y_embedder = pooled_projection_dim != self.time_text_embed.text_embedder.linear_1.in_features
        # TODO double check this logic when 8b
        self.use_y_embedder = True

        # One linear projection per transformer block (same count as above).
        self.controlnet_blocks = nn.ModuleList(
            operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
            for _ in range(num_layers)
        )

        # Separate patch embedder for the conditioning hint; never strict
        # about the image size.
        self.pos_embed_input = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=self.hidden_size,
            strict_img_size=False,
            device=device,
            dtype=dtype,
            operations=operations,
        )

    def forward(
        self,
        x: torch.Tensor,
        timesteps: torch.Tensor,
        y: Optional[torch.Tensor] = None,
        context: Optional[torch.Tensor] = None,
        hint: Optional[torch.Tensor] = None,
    ) -> Dict[str, Tuple[Tensor, ...]]:
        """Compute the projected block outputs for ``x`` conditioned on ``hint``.

        Args:
            x: Input passed to ``x_embedder``; its last two dimensions are
                read as height and width when sizing the positional embedding.
            timesteps: Passed to ``t_embedder``.
            y: Optional vector conditioning; when given, its embedding is
                added to the timestep embedding.
            context: Accepted for call compatibility; not used by this method.
            hint: Conditioning input passed to ``pos_embed_input``. Although
                it defaults to ``None``, it is passed on unconditionally, so
                callers are presumably expected to always supply it.

        Returns:
            A dict with the single key ``"output"`` holding a tuple of
            tensors. Each block contributes its projected output
            ``ceil(main_model_double / num_blocks)`` times in a row, the
            repeats being the same tensor object.
        """
        input_shape = tuple(x.shape)
        x = self.x_embedder(x)
        if not self.double_y_emb:
            h = (input_shape[-2] + 1) // self.patch_size
            w = (input_shape[-1] + 1) // self.patch_size
            # Deliberately an in-place add: the result keeps x's dtype, whereas
            # an out-of-place add could be promoted to the embedding's dtype.
            x += get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=x.device)

        c = self.t_embedder(timesteps, dtype=x.dtype)
        if y is not None and self.y_embedder is not None:
            if self.double_y_emb:
                y = self.orig_y_embedder(y)
            y = self.y_embedder(y)
            c = c + y

        x = x + self.pos_embed_input(hint)

        repeat = math.ceil(self.main_model_double / len(self.transformer_blocks))
        projected = []
        for block, projection in zip(self.transformer_blocks, self.controlnet_blocks):
            out = block(x, c)
            if not self.double_y_emb:
                # Blocks are chained only in the single-embedder variant;
                # otherwise every block sees the same embedded input.
                x = out
            projected.extend([projection(out)] * repeat)

        return {"output": tuple(projected)}