NCJ committed
Commit 2b18e88
1 Parent(s): 733d5b4

Create pipeline.py

Files changed (1)
  1. pipeline.py +481 -0
pipeline.py ADDED
@@ -0,0 +1,481 @@
from typing import Optional, Tuple, Union

import torch
from diffusers.configuration_utils import register_to_config
from diffusers.models.controlnet import ControlNetModel, zero_module
from diffusers.models.embeddings import (
    TextImageProjection,
    TextImageTimeEmbedding,
    TextTimeEmbedding,
    TimestepEmbedding,
    Timesteps,
)
from diffusers.models.unets.unet_2d_blocks import (
    CrossAttnDownBlock2D,
    DownBlock2D,
    UNetMidBlock2D,
    UNetMidBlock2DCrossAttn,
    get_down_block,
)
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.utils import logging
from torch import nn
from torch.nn import functional as F
from torch.utils.checkpoint import checkpoint

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class ResBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(dim, dim, 3, 1, 1),
            nn.GroupNorm(num_groups=8, num_channels=dim),
            nn.SiLU(inplace=True),
            nn.Conv2d(dim, dim, 3, 1, 1),
        )

    def forward(self, x):
        return x + self.conv(x)


class NeuralTextureEncoder(nn.Module):
    def __init__(self, in_dim=3, out_dim=16, dims=(32, 64, 128), groups=8):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_dim, dims[0], kernel_size=3, padding=1),
            nn.SiLU(inplace=True),

            # down 1
            nn.Conv2d(dims[0], dims[1], kernel_size=3, padding=1, stride=2),
            nn.GroupNorm(num_groups=groups, num_channels=dims[1]),
            nn.SiLU(inplace=True),

            # down 2
            nn.Conv2d(dims[1], dims[2], kernel_size=3, padding=1, stride=2),
            nn.GroupNorm(num_groups=groups, num_channels=dims[2]),
            nn.SiLU(inplace=True),

            # res blocks
            ResBlock(dims[2]),
            ResBlock(dims[2]),
            ResBlock(dims[2]),
            ResBlock(dims[2]),

            # up 1
            nn.ConvTranspose2d(dims[2], dims[1], kernel_size=4, padding=1, stride=2),
            nn.GroupNorm(num_groups=groups, num_channels=dims[1]),
            nn.SiLU(inplace=True),

            # up 2
            nn.ConvTranspose2d(dims[1], dims[0], kernel_size=4, padding=1, stride=2),
            nn.GroupNorm(num_groups=groups, num_channels=dims[0]),
            nn.SiLU(inplace=True),

            # out
            nn.Conv2d(dims[0], out_dim, kernel_size=3, padding=1),
        )
        self.gradient_checkpointing = False

    def forward(self, x):
        if self.training and self.gradient_checkpointing:
            x = checkpoint(self.model, x, use_reentrant=False)
        else:
            x = self.model(x)
        return x


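# Conditioning path: `all_conditioning` is expected to be the reference image and
# the shading hint stacked along the channel axis, e.g. (illustrative):
#     all_conditioning = torch.cat([reference_image, shading_hint], dim=1)
# NeuralTextureEmbedding below builds the encoder with out_dim=shading_hint_channels,
# so the encoder output (same spatial size as its input: down 2x2, res blocks, up 2x2)
# can be multiplied channel-wise with the shading hint before being downsampled 8x
# (three stride-2 convs) to match the latent resolution.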
class NeuralTextureEmbedding(nn.Module):
    def __init__(
        self,
        conditioning_embedding_channels: int,
        conditioning_channels: int = 3,
        block_out_channels: Tuple[int] = (16, 32, 96, 256),
        shading_hint_channels: int = 12,  # diffuse + 3 * ggx
    ):
        super().__init__()
        self.conditioning_channels = conditioning_channels
        self.shading_hint_channels = shading_hint_channels

        self.conv_in = nn.Conv2d(shading_hint_channels, block_out_channels[0], kernel_size=3, padding=1)
        self.neural_texture_encoder = NeuralTextureEncoder(in_dim=conditioning_channels, out_dim=shading_hint_channels)

        self.blocks = nn.ModuleList([])

        for i in range(len(block_out_channels) - 1):
            channel_in = block_out_channels[i]
            channel_out = block_out_channels[i + 1]
            self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
            self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))

        self.conv_out = zero_module(
            nn.Conv2d(
                block_out_channels[-1],
                conditioning_embedding_channels,
                kernel_size=3,
                padding=1,
            )
        )

    def forward(self, all_conditioning):
        # all_conditioning: [BS, conditioning_channels + shading_hint_channels, 512, 512]
        # reference image + shading hint (diffuse + 3 * ggx)
        conditioning, shading_hint = torch.split(
            all_conditioning,
            [self.conditioning_channels, self.shading_hint_channels],
            dim=1,
        )
        embedding = self.neural_texture_encoder(conditioning)  # [BS, shading_hint_channels, 512, 512]

        # multiply the shading hint into each channel of the neural texture
        embedding = embedding * shading_hint
        embedding = self.conv_in(embedding)
        embedding = F.silu(embedding)

        for block in self.blocks:
            embedding = block(embedding)
            embedding = F.silu(embedding)

        embedding = self.conv_out(embedding)

        return embedding


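# A ControlNet variant whose conditioning embedding is replaced by NeuralTextureEmbedding
# above. It only supports the "rgb" conditioning channel order and does not support
# global pooling of the conditioning (see the asserts in __init__).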
class NeuralTextureControlNetModel(ControlNetModel):
    """
    A Neural Texture ControlNet Model.

    Args:
        in_channels (`int`, defaults to 4):
            The number of channels in the input sample (RGBA).
        shading_hint_channels (`int`, defaults to 12):
            The number of channels in the shading hint (diffuse + 3 * ggx).
    """

    @register_to_config
    def __init__(
        self,
        in_channels: int = 4,
        conditioning_channels: int = 3,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str, ...] = (
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "DownBlock2D",
        ),
        mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
        only_cross_attention: Union[bool, Tuple[bool]] = False,
        block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
        layers_per_block: int = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        act_fn: str = "silu",
        norm_num_groups: Optional[int] = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: int = 1280,
        transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
        encoder_hid_dim: Optional[int] = None,
        encoder_hid_dim_type: Optional[str] = None,
        attention_head_dim: Union[int, Tuple[int, ...]] = 8,
        num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
        use_linear_projection: bool = False,
        class_embed_type: Optional[str] = None,
        addition_embed_type: Optional[str] = None,
        addition_time_embed_dim: Optional[int] = None,
        num_class_embeds: Optional[int] = None,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = "default",
        projection_class_embeddings_input_dim: Optional[int] = None,
        controlnet_conditioning_channel_order: str = "rgb",
        conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
        global_pool_conditions: bool = False,
        addition_embed_type_num_heads: int = 64,
        shading_hint_channels: int = 12,
    ):
        super().__init__()

        num_attention_heads = num_attention_heads or attention_head_dim

        assert controlnet_conditioning_channel_order == "rgb", "Only RGB channel order is supported."
        assert global_pool_conditions is False, "Global pooling conditions is not supported."

        # Check inputs
        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
            )

        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)

        # input
        conv_in_kernel = 3
        conv_in_padding = (conv_in_kernel - 1) // 2
        self.conv_in = nn.Conv2d(
            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
        )

        # time
        time_embed_dim = block_out_channels[0] * 4
        self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
        timestep_input_dim = block_out_channels[0]
        self.time_embedding = TimestepEmbedding(
            timestep_input_dim,
            time_embed_dim,
            act_fn=act_fn,
        )

        if encoder_hid_dim_type is None and encoder_hid_dim is not None:
            encoder_hid_dim_type = "text_proj"
            self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
            logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")

        if encoder_hid_dim is None and encoder_hid_dim_type is not None:
            raise ValueError(
                f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
            )

        if encoder_hid_dim_type == "text_proj":
            self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
        elif encoder_hid_dim_type == "text_image_proj":
            # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the
            # currently only use case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)
            self.encoder_hid_proj = TextImageProjection(
                text_embed_dim=encoder_hid_dim,
                image_embed_dim=cross_attention_dim,
                cross_attention_dim=cross_attention_dim,
            )

        elif encoder_hid_dim_type is not None:
            raise ValueError(
                f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
            )
        else:
            self.encoder_hid_proj = None

        # class embedding
        if class_embed_type is None and num_class_embeds is not None:
            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
        elif class_embed_type == "timestep":
            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
        elif class_embed_type == "identity":
            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
        elif class_embed_type == "projection":
            if projection_class_embeddings_input_dim is None:
                raise ValueError(
                    "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
                )
            self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
        else:
            self.class_embedding = None

        if addition_embed_type == "text":
            if encoder_hid_dim is not None:
                text_time_embedding_from_dim = encoder_hid_dim
            else:
                text_time_embedding_from_dim = cross_attention_dim

            self.add_embedding = TextTimeEmbedding(
                text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
            )
        elif addition_embed_type == "text_image":
            self.add_embedding = TextImageTimeEmbedding(
                text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
            )
        elif addition_embed_type == "text_time":
            self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
            self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)

        elif addition_embed_type is not None:
            raise ValueError(
                f"addition_embed_type: {addition_embed_type} must be None, 'text', 'text_image' or 'text_time'."
            )

        # control net conditioning embedding
        self.controlnet_cond_embedding = NeuralTextureEmbedding(
            conditioning_embedding_channels=block_out_channels[0],
            block_out_channels=conditioning_embedding_out_channels,
            conditioning_channels=conditioning_channels,
            shading_hint_channels=shading_hint_channels,
        )

        self.down_blocks = nn.ModuleList([])
        self.controlnet_down_blocks = nn.ModuleList([])

        if isinstance(only_cross_attention, bool):
            only_cross_attention = [only_cross_attention] * len(down_block_types)

        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * len(down_block_types)

        if isinstance(num_attention_heads, int):
            num_attention_heads = (num_attention_heads,) * len(down_block_types)

        # down
        output_channel = block_out_channels[0]

        controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
        controlnet_block = zero_module(controlnet_block)
        self.controlnet_down_blocks.append(controlnet_block)

        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block,
                transformer_layers_per_block=transformer_layers_per_block[i],
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
                num_attention_heads=num_attention_heads[i],
                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
                downsample_padding=downsample_padding,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
            )
            self.down_blocks.append(down_block)

            for _ in range(layers_per_block):
                controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
                controlnet_block = zero_module(controlnet_block)
                self.controlnet_down_blocks.append(controlnet_block)

            if not is_final_block:
                controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
                controlnet_block = zero_module(controlnet_block)
                self.controlnet_down_blocks.append(controlnet_block)

        # mid
        mid_block_channel = block_out_channels[-1]

        controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
        controlnet_block = zero_module(controlnet_block)
        self.controlnet_mid_block = controlnet_block

        if mid_block_type == "UNetMidBlock2DCrossAttn":
            self.mid_block = UNetMidBlock2DCrossAttn(
                transformer_layers_per_block=transformer_layers_per_block[-1],
                in_channels=mid_block_channel,
                temb_channels=time_embed_dim,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                output_scale_factor=mid_block_scale_factor,
                resnet_time_scale_shift=resnet_time_scale_shift,
                cross_attention_dim=cross_attention_dim,
                num_attention_heads=num_attention_heads[-1],
                resnet_groups=norm_num_groups,
                use_linear_projection=use_linear_projection,
                upcast_attention=upcast_attention,
            )
        elif mid_block_type == "UNetMidBlock2D":
            self.mid_block = UNetMidBlock2D(
                in_channels=block_out_channels[-1],
                temb_channels=time_embed_dim,
                num_layers=0,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                output_scale_factor=mid_block_scale_factor,
                resnet_groups=norm_num_groups,
                resnet_time_scale_shift=resnet_time_scale_shift,
                add_attention=False,
            )
        else:
            raise ValueError(f"unknown mid_block_type : {mid_block_type}")

    @classmethod
    def from_unet(
        cls,
        unet: UNet2DConditionModel,
        controlnet_conditioning_channel_order: str = "rgb",
        conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
        load_weights_from_unet: bool = True,
        shading_hint_channels: int = 12,
        conditioning_channels: int = 4,
    ):
        r"""
        Instantiate a [`NeuralTextureControlNetModel`] from a [`UNet2DConditionModel`].

        Parameters:
            unet (`UNet2DConditionModel`):
                The UNet model whose weights are copied to the [`NeuralTextureControlNetModel`]. All configuration
                options are also copied where applicable.
        """
        transformer_layers_per_block = (
            unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
        )
        encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
        encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
        addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
        addition_time_embed_dim = (
            unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
        )

        controlnet = cls(
            encoder_hid_dim=encoder_hid_dim,
            encoder_hid_dim_type=encoder_hid_dim_type,
            addition_embed_type=addition_embed_type,
            addition_time_embed_dim=addition_time_embed_dim,
            transformer_layers_per_block=transformer_layers_per_block,
            in_channels=unet.config.in_channels,
            flip_sin_to_cos=unet.config.flip_sin_to_cos,
            freq_shift=unet.config.freq_shift,
            down_block_types=unet.config.down_block_types,
            only_cross_attention=unet.config.only_cross_attention,
            block_out_channels=unet.config.block_out_channels,
            layers_per_block=unet.config.layers_per_block,
            downsample_padding=unet.config.downsample_padding,
            mid_block_scale_factor=unet.config.mid_block_scale_factor,
            act_fn=unet.config.act_fn,
            norm_num_groups=unet.config.norm_num_groups,
            norm_eps=unet.config.norm_eps,
            cross_attention_dim=unet.config.cross_attention_dim,
            attention_head_dim=unet.config.attention_head_dim,
            num_attention_heads=unet.config.num_attention_heads,
            use_linear_projection=unet.config.use_linear_projection,
            class_embed_type=unet.config.class_embed_type,
            num_class_embeds=unet.config.num_class_embeds,
            upcast_attention=unet.config.upcast_attention,
            resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
            projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
            controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
            conditioning_embedding_out_channels=conditioning_embedding_out_channels,
            shading_hint_channels=shading_hint_channels,
            conditioning_channels=conditioning_channels,
        )

        if load_weights_from_unet:
            controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
            controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
            controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())

            if controlnet.class_embedding:
                controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())

            controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
            controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())

        return controlnet

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, NeuralTextureEncoder)):
            module.gradient_checkpointing = value
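A minimal usage sketch, assuming a Stable Diffusion 1.5-style UNet: the checkpoint id, resolutions, and batch size below are illustrative, not taken from this commit, and the forward call relies on the signature inherited from diffusers' ControlNetModel.

    import torch
    from diffusers import UNet2DConditionModel

    # Build the ControlNet from a base UNet (checkpoint id is illustrative).
    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="unet"
    )
    controlnet = NeuralTextureControlNetModel.from_unet(
        unet, shading_hint_channels=12, conditioning_channels=4
    )

    # The conditioning stacks the reference image and the shading hints along the
    # channel axis; with conditioning_channels=4 and shading_hint_channels=12 this
    # is a 16-channel map at image resolution.
    reference = torch.randn(1, 4, 512, 512)
    shading_hint = torch.randn(1, 12, 512, 512)
    all_conditioning = torch.cat([reference, shading_hint], dim=1)

    latents = torch.randn(1, 4, 64, 64)       # noisy latent sample
    timestep = torch.tensor([10])
    text_embeds = torch.randn(1, 77, 768)      # text encoder hidden states

    # Forward pass inherited from ControlNetModel: returns the per-resolution
    # residuals that are added to the base UNet's down and mid blocks.
    down_res, mid_res = controlnet(
        latents,
        timestep,
        encoder_hidden_states=text_embeds,
        controlnet_cond=all_conditioning,
        return_dict=False,
    )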