MuseVSpace / MuseV /musev /models /super_model.py
anchorxia's picture
add musev
96d7ad8
raw
history blame
11 kB
from __future__ import annotations
import logging
from typing import Any, Dict, Tuple, Union, Optional
from einops import rearrange, repeat
from torch import nn
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin, load_state_dict
from ..data.data_util import align_repeat_tensor_single_dim
from .unet_3d_condition import UNet3DConditionModel
from .referencenet import ReferenceNet2D
from ip_adapter.ip_adapter import ImageProjModel
logger = logging.getLogger(__name__)
class SuperUNet3DConditionModel(nn.Module):
"""封装了各种子模型的超模型,与 diffusers 的 pipeline 很像,只不过这里是模型定义。
主要作用
1. 将支持controlnet、referencenet等功能的计算封装起来,简洁些;
2. 便于 accelerator 的分布式训练;
wrap the sub-models, such as unet, referencenet, controlnet, vae, text_encoder, tokenizer, text_emb_extractor, clip_vision_extractor, ip_adapter_image_proj
1. support controlnet, referencenet, etc.
2. support accelerator distributed training
"""
_supports_gradient_checkpointing = True
print_idx = 0
# @register_to_config
def __init__(
self,
unet: nn.Module,
referencenet: nn.Module = None,
controlnet: nn.Module = None,
vae: nn.Module = None,
text_encoder: nn.Module = None,
tokenizer: nn.Module = None,
text_emb_extractor: nn.Module = None,
clip_vision_extractor: nn.Module = None,
ip_adapter_image_proj: nn.Module = None,
) -> None:
"""_summary_
Args:
unet (nn.Module): _description_
referencenet (nn.Module, optional): _description_. Defaults to None.
controlnet (nn.Module, optional): _description_. Defaults to None.
vae (nn.Module, optional): _description_. Defaults to None.
text_encoder (nn.Module, optional): _description_. Defaults to None.
tokenizer (nn.Module, optional): _description_. Defaults to None.
text_emb_extractor (nn.Module, optional): wrap text_encoder and tokenizer for str2emb. Defaults to None.
clip_vision_extractor (nn.Module, optional): _description_. Defaults to None.
"""
super().__init__()
self.unet = unet
self.referencenet = referencenet
self.controlnet = controlnet
self.vae = vae
self.text_encoder = text_encoder
self.tokenizer = tokenizer
self.text_emb_extractor = text_emb_extractor
self.clip_vision_extractor = clip_vision_extractor
self.ip_adapter_image_proj = ip_adapter_image_proj
def forward(
self,
unet_params: Dict,
encoder_hidden_states: torch.Tensor,
referencenet_params: Dict = None,
controlnet_params: Dict = None,
controlnet_scale: float = 1.0,
vision_clip_emb: Union[torch.Tensor, None] = None,
prompt_only_use_image_prompt: bool = False,
):
"""_summary_
Args:
unet_params (Dict): _description_
encoder_hidden_states (torch.Tensor): b t n d
referencenet_params (Dict, optional): _description_. Defaults to None.
controlnet_params (Dict, optional): _description_. Defaults to None.
controlnet_scale (float, optional): _description_. Defaults to 1.0.
vision_clip_emb (Union[torch.Tensor, None], optional): b t d. Defaults to None.
prompt_only_use_image_prompt (bool, optional): _description_. Defaults to False.
Returns:
_type_: _description_
"""
batch_size = unet_params["sample"].shape[0]
time_size = unet_params["sample"].shape[2]
# ip_adapter_cross_attn, prepare image prompt
if vision_clip_emb is not None:
# b t n d -> b t n d
if self.print_idx == 0:
logger.debug(
f"vision_clip_emb, before ip_adapter_image_proj, shape={vision_clip_emb.shape} mean={torch.mean(vision_clip_emb)}"
)
if vision_clip_emb.ndim == 3:
vision_clip_emb = rearrange(vision_clip_emb, "b t d-> b t 1 d")
if self.ip_adapter_image_proj is not None:
vision_clip_emb = rearrange(vision_clip_emb, "b t n d ->(b t) n d")
vision_clip_emb = self.ip_adapter_image_proj(vision_clip_emb)
if self.print_idx == 0:
logger.debug(
f"vision_clip_emb, after ip_adapter_image_proj shape={vision_clip_emb.shape} mean={torch.mean(vision_clip_emb)}"
)
if vision_clip_emb.ndim == 2:
vision_clip_emb = rearrange(vision_clip_emb, "b d-> b 1 d")
vision_clip_emb = rearrange(
vision_clip_emb, "(b t) n d -> b t n d", b=batch_size
)
vision_clip_emb = align_repeat_tensor_single_dim(
vision_clip_emb, target_length=time_size, dim=1
)
if self.print_idx == 0:
logger.debug(
f"vision_clip_emb, after reshape shape={vision_clip_emb.shape} mean={torch.mean(vision_clip_emb)}"
)
if vision_clip_emb is None and encoder_hidden_states is not None:
vision_clip_emb = encoder_hidden_states
if vision_clip_emb is not None and encoder_hidden_states is None:
encoder_hidden_states = vision_clip_emb
# 当 prompt_only_use_image_prompt 为True时,
# 1. referencenet 都使用 vision_clip_emb
# 2. unet 如果没有dual_cross_attn,使用vision_clip_emb,有时不更新
# 3. controlnet 当前使用 text_prompt
# when prompt_only_use_image_prompt True,
# 1. referencenet use vision_clip_emb
# 2. unet use vision_clip_emb if no dual_cross_attn, sometimes not update
# 3. controlnet use text_prompt
# extract referencenet emb
if self.referencenet is not None and referencenet_params is not None:
referencenet_encoder_hidden_states = align_repeat_tensor_single_dim(
vision_clip_emb,
target_length=referencenet_params["num_frames"],
dim=1,
)
referencenet_params["encoder_hidden_states"] = rearrange(
referencenet_encoder_hidden_states, "b t n d->(b t) n d"
)
referencenet_out = self.referencenet(**referencenet_params)
(
down_block_refer_embs,
mid_block_refer_emb,
refer_self_attn_emb,
) = referencenet_out
if down_block_refer_embs is not None:
if self.print_idx == 0:
logger.debug(
f"len(down_block_refer_embs)={len(down_block_refer_embs)}"
)
for i, down_emb in enumerate(down_block_refer_embs):
if self.print_idx == 0:
logger.debug(
f"down_emb, {i}, {down_emb.shape}, mean={down_emb.mean()}"
)
else:
if self.print_idx == 0:
logger.debug(f"down_block_refer_embs is None")
if mid_block_refer_emb is not None:
if self.print_idx == 0:
logger.debug(
f"mid_block_refer_emb, {mid_block_refer_emb.shape}, mean={mid_block_refer_emb.mean()}"
)
else:
if self.print_idx == 0:
logger.debug(f"mid_block_refer_emb is None")
if refer_self_attn_emb is not None:
if self.print_idx == 0:
logger.debug(f"refer_self_attn_emb, num={len(refer_self_attn_emb)}")
for i, self_attn_emb in enumerate(refer_self_attn_emb):
if self.print_idx == 0:
logger.debug(
f"referencenet, self_attn_emb, {i}th, shape={self_attn_emb.shape}, mean={self_attn_emb.mean()}"
)
else:
if self.print_idx == 0:
logger.debug(f"refer_self_attn_emb is None")
else:
down_block_refer_embs, mid_block_refer_emb, refer_self_attn_emb = (
None,
None,
None,
)
# extract controlnet emb
if self.controlnet is not None and controlnet_params is not None:
controlnet_encoder_hidden_states = align_repeat_tensor_single_dim(
encoder_hidden_states,
target_length=unet_params["sample"].shape[2],
dim=1,
)
controlnet_params["encoder_hidden_states"] = rearrange(
controlnet_encoder_hidden_states, " b t n d -> (b t) n d"
)
(
down_block_additional_residuals,
mid_block_additional_residual,
) = self.controlnet(**controlnet_params)
if controlnet_scale != 1.0:
down_block_additional_residuals = [
x * controlnet_scale for x in down_block_additional_residuals
]
mid_block_additional_residual = (
mid_block_additional_residual * controlnet_scale
)
for i, down_block_additional_residual in enumerate(
down_block_additional_residuals
):
if self.print_idx == 0:
logger.debug(
f"{i}, down_block_additional_residual mean={torch.mean(down_block_additional_residual)}"
)
if self.print_idx == 0:
logger.debug(
f"mid_block_additional_residual mean={torch.mean(mid_block_additional_residual)}"
)
else:
down_block_additional_residuals = None
mid_block_additional_residual = None
if prompt_only_use_image_prompt and vision_clip_emb is not None:
encoder_hidden_states = vision_clip_emb
# run unet
out = self.unet(
**unet_params,
down_block_refer_embs=down_block_refer_embs,
mid_block_refer_emb=mid_block_refer_emb,
refer_self_attn_emb=refer_self_attn_emb,
down_block_additional_residuals=down_block_additional_residuals,
mid_block_additional_residual=mid_block_additional_residual,
encoder_hidden_states=encoder_hidden_states,
vision_clip_emb=vision_clip_emb,
)
self.print_idx += 1
return out
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (UNet3DConditionModel, ReferenceNet2D)):
module.gradient_checkpointing = value