import copy
from typing import Any, Callable, Dict, Iterable, Union
import PIL
import cv2
import torch
import argparse
import datetime
import logging
import inspect
import math
import os
import shutil
from typing import Dict, List, Optional, Tuple
from pprint import pprint
from collections import OrderedDict
from dataclasses import dataclass
import gc
import time

import numpy as np
from omegaconf import OmegaConf
from omegaconf import SCMode
import torch
from torch import nn
import torch.nn.functional as F
import torch.utils.checkpoint
from einops import rearrange, repeat
import pandas as pd
import h5py
from diffusers.models.modeling_utils import load_state_dict

# NOTE: this rebinds the name `logging` (imported from the stdlib above);
# the `logging.get_logger` call below resolves to this diffusers import.
from diffusers.utils import (
    logging,
)
from diffusers.utils.import_utils import is_xformers_available

from ..models.unet_3d_condition import UNet3DConditionModel

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def update_unet_with_sd(
    unet: nn.Module, sd_model: Union[str, nn.Module], subfolder: str = "unet"
) -> nn.Module:
    """Update the T2I parameters inside a T2V model.

    The weights taken from ``sd_model`` are loaded into ``unet`` with
    ``strict=False``, so keys missing from ``sd_model`` are tolerated, while
    keys that ``unet`` does not know about trigger an assertion failure.

    Args:
        unet (nn.Module): Model whose parameters are overwritten in place.
        sd_model (Union[str, nn.Module]): Source of the weights. One of:
            * a directory path; weights are read from
              ``<sd_model>/<subfolder>/diffusion_pytorch_model.bin``;
            * a file path ending in ``pth``, read with ``torch.load``;
            * any other file path, read with diffusers' ``load_state_dict``;
            * an ``nn.Module``, whose ``state_dict()`` is used.
        subfolder (str, optional): Sub-directory used when ``sd_model`` is a
            directory. Defaults to "unet".

    Raises:
        FileNotFoundError: ``sd_model`` is a string that is neither an
            existing directory nor an existing file.
        ValueError: ``sd_model`` is neither a string nor an ``nn.Module``.
        AssertionError: The loaded state dict contains keys unexpected by
            ``unet``.

    Returns:
        nn.Module: The same ``unet`` object, after loading.
    """
    # TODO: in this way, sd_model_path must be absolute path, to be more dynamic
    if isinstance(sd_model, str):
        if os.path.isdir(sd_model):
            unet_state_dict = load_state_dict(
                os.path.join(sd_model, subfolder, "diffusion_pytorch_model.bin"),
            )
        elif os.path.isfile(sd_model):
            if sd_model.endswith("pth"):
                # SECURITY: torch.load unpickles the file; only pass
                # checkpoints that come from a trusted source.
                unet_state_dict = torch.load(sd_model, map_location="cpu")
                print(f"referencenet successful load ={sd_model} with torch.load")
            else:
                try:
                    unet_state_dict = load_state_dict(sd_model)
                except Exception as e:
                    # Previously the error was only printed and execution
                    # continued into an UnboundLocalError; surface the real
                    # failure instead.
                    print(e)
                    raise
                print(
                    f"referencenet successful load with {sd_model} with load_state_dict"
                )
        else:
            # Neither a directory nor a file: fail with a clear message
            # rather than an unbound-variable error further down.
            raise FileNotFoundError(
                f"sd_model path does not exist or is not a file/directory: {sd_model}"
            )
    elif isinstance(sd_model, nn.Module):
        unet_state_dict = sd_model.state_dict()
    else:
        raise ValueError(f"given {type(sd_model)}, but only support nn.Module or str")
    _, unexpected = unet.load_state_dict(unet_state_dict, strict=False)
    assert len(unexpected) == 0, f"unet load_state_dict error, unexpected={unexpected}"
    return unet


def load_unet(
    sd_unet_model: Union[str, nn.Module],
    sd_model: Optional[Union[str, nn.Module]] = None,
    cross_attention_dim: int = 768,
    temporal_transformer: str = "TransformerTemporalModel",
    temporal_conv_block: str = "TemporalConvLayer",
    need_spatial_position_emb: bool = False,
    need_transformer_in: bool = True,
    need_t2i_ip_adapter: bool = False,
    need_adain_temporal_cond: bool = False,
    t2i_ip_adapter_attn_processor: str = "IPXFormersAttnProcessor",
    keep_vision_condtion: bool = False,
    use_anivv1_cfg: bool = False,
    resnet_2d_skip_time_act: bool = False,
    dtype: torch.dtype = torch.float16,
    need_zero_vis_cond_temb: bool = True,
    norm_spatial_length: bool = True,
    spatial_max_length: int = 2048,
    need_refer_emb: bool = False,
    ip_adapter_cross_attn=False,
    t2i_crossattn_ip_adapter_attn_processor="T2IReferencenetIPAdapterXFormersAttnProcessor",
    need_t2i_facein: bool = False,
    need_t2i_ip_adapter_face: bool = False,
    strict: bool = True,
):
    """Initialise a UNet and load pretrained parameters.

    Intended for models that are defined and trained through
    ``models.unet_3d_condition.py:UNet3DConditionModel``.

    Args:
        sd_unet_model (Union[str, nn.Module]): Either a path/identifier that
            is handed to ``UNet3DConditionModel.from_pretrained_2d`` (with
            ``subfolder="unet"``), or an already constructed model that is
            used as is.
        sd_model (Union[str, nn.Module], optional): If not None, its weights
            are loaded on top of the UNet via ``update_unet_with_sd``.
            Defaults to None.
        cross_attention_dim (int, optional): Defaults to 768.
        temporal_transformer (str, optional): Defaults to
            "TransformerTemporalModel".
        temporal_conv_block (str, optional): Defaults to "TemporalConvLayer".
        need_spatial_position_emb (bool, optional): Defaults to False.
        need_transformer_in (bool, optional): Defaults to True.
        need_t2i_ip_adapter (bool, optional): Defaults to False.
        need_adain_temporal_cond (bool, optional): Defaults to False.
        t2i_ip_adapter_attn_processor (str, optional): Defaults to
            "IPXFormersAttnProcessor".
        keep_vision_condtion (bool, optional): Defaults to False.
        use_anivv1_cfg (bool, optional): Defaults to False.
        resnet_2d_skip_time_act (bool, optional): Defaults to False.
        dtype (torch.dtype, optional): Passed as ``torch_dtype``. Defaults to
            torch.float16.
        need_zero_vis_cond_temb (bool, optional): Defaults to True.
        norm_spatial_length (bool, optional): Defaults to True.
        spatial_max_length (int, optional): Defaults to 2048.
        need_refer_emb (bool, optional): Defaults to False.
        ip_adapter_cross_attn (bool, optional): Defaults to False.
        t2i_crossattn_ip_adapter_attn_processor (str, optional): Defaults to
            "T2IReferencenetIPAdapterXFormersAttnProcessor".
        need_t2i_facein (bool, optional): Defaults to False.
        need_t2i_ip_adapter_face (bool, optional): Defaults to False.
        strict (bool, optional): Defaults to True.

    All arguments from ``cross_attention_dim`` onwards are forwarded
    unchanged to ``UNet3DConditionModel.from_pretrained_2d``; see that method
    for their meaning. They are ignored when ``sd_unet_model`` is already an
    ``nn.Module``.

    Raises:
        ValueError: ``sd_unet_model`` is neither a string nor an
            ``nn.Module``.

    Returns:
        The constructed (or passed-in) UNet, optionally updated with the
        weights of ``sd_model``.
    """
    if isinstance(sd_unet_model, str):
        unet = UNet3DConditionModel.from_pretrained_2d(
            sd_unet_model,
            subfolder="unet",
            temporal_transformer=temporal_transformer,
            temporal_conv_block=temporal_conv_block,
            cross_attention_dim=cross_attention_dim,
            need_spatial_position_emb=need_spatial_position_emb,
            need_transformer_in=need_transformer_in,
            need_t2i_ip_adapter=need_t2i_ip_adapter,
            need_adain_temporal_cond=need_adain_temporal_cond,
            t2i_ip_adapter_attn_processor=t2i_ip_adapter_attn_processor,
            keep_vision_condtion=keep_vision_condtion,
            use_anivv1_cfg=use_anivv1_cfg,
            resnet_2d_skip_time_act=resnet_2d_skip_time_act,
            torch_dtype=dtype,
            need_zero_vis_cond_temb=need_zero_vis_cond_temb,
            norm_spatial_length=norm_spatial_length,
            spatial_max_length=spatial_max_length,
            need_refer_emb=need_refer_emb,
            ip_adapter_cross_attn=ip_adapter_cross_attn,
            t2i_crossattn_ip_adapter_attn_processor=t2i_crossattn_ip_adapter_attn_processor,
            need_t2i_facein=need_t2i_facein,
            strict=strict,
            need_t2i_ip_adapter_face=need_t2i_ip_adapter_face,
        )
    elif isinstance(sd_unet_model, nn.Module):
        unet = sd_unet_model
    else:
        # Without this branch `unet` would be unbound below.
        raise ValueError(
            f"given {type(sd_unet_model)}, but only support nn.Module or str"
        )
    if sd_model is not None:
        unet = update_unet_with_sd(unet, sd_model)
    return unet


def load_unet_custom_unet(
    sd_unet_model: Union[str, nn.Module],
    sd_model: Union[str, nn.Module],
    unet_class: nn.Module,
):
    """Initialise a UNet of a custom class and load pretrained parameters.

    Intended for models that are NOT defined through
    ``models.unet_3d_condition.py:UNet3DConditionModel``.

    Args:
        sd_unet_model (Union[str, nn.Module]): Either a path/identifier that
            is handed to ``unet_class.from_pretrained`` (with
            ``subfolder="unet"``), or an already constructed model that is
            used as is.
        sd_model (Union[str, nn.Module]): Source of the weights loaded on top
            of the UNet. A string is treated as a directory and weights are
            read from ``<sd_model>/unet/diffusion_pytorch_model.bin``; an
            ``nn.Module`` contributes its ``state_dict()``.
        unet_class (nn.Module): Model class providing ``from_pretrained``;
            only used when ``sd_unet_model`` is a string.

    Raises:
        ValueError: ``sd_unet_model`` or ``sd_model`` is neither a string nor
            an ``nn.Module``.
        AssertionError: The loaded state dict contains keys unexpected by the
            UNet.

    Returns:
        The UNet after loading the weights of ``sd_model`` (non-strict).
    """
    if isinstance(sd_unet_model, str):
        unet = unet_class.from_pretrained(
            sd_unet_model,
            subfolder="unet",
        )
    elif isinstance(sd_unet_model, nn.Module):
        unet = sd_unet_model
    else:
        # Without this branch `unet` would be unbound below.
        raise ValueError(
            f"given {type(sd_unet_model)}, but only support nn.Module or str"
        )

    # TODO: in this way, sd_model_path must be absolute path, to be more dynamic
    if isinstance(sd_model, str):
        unet_state_dict = load_state_dict(
            os.path.join(sd_model, "unet/diffusion_pytorch_model.bin"),
        )
    elif isinstance(sd_model, nn.Module):
        unet_state_dict = sd_model.state_dict()
    else:
        # Without this branch `unet_state_dict` would be unbound below.
        raise ValueError(f"given {type(sd_model)}, but only support nn.Module or str")
    _, unexpected = unet.load_state_dict(unet_state_dict, strict=False)
    assert len(unexpected) == 0, "unet load_state_dict error"
    return unet


def load_unet_by_name(
    model_name: str,
    sd_unet_model: Union[str, nn.Module],
    sd_model: Optional[Union[str, nn.Module]] = None,
    cross_attention_dim: int = 768,
    dtype: torch.dtype = torch.float16,
    need_t2i_facein: bool = False,
    need_t2i_ip_adapter_face: bool = False,
    strict: bool = True,
) -> nn.Module:
    """Initialise a UNet from a short model name and load pretrained parameters.

    To make a pretrained configuration usable by a simple name, its
    ``load_unet`` arguments have to be defined here.

    Supported names are "musev", "musev_referencenet" and
    "musev_referencenet_pose"; the latter two share one configuration.

    Args:
        model_name (str): Name selecting the preset passed to ``load_unet``.
        sd_unet_model (Union[str, nn.Module]): Forwarded to ``load_unet``.
        sd_model (Union[str, nn.Module], optional): Forwarded to
            ``load_unet``. Defaults to None.
        cross_attention_dim (int, optional): Forwarded to ``load_unet``.
            Defaults to 768.
        dtype (torch.dtype, optional): Forwarded to ``load_unet``. Defaults
            to torch.float16.
        need_t2i_facein (bool, optional): Forwarded to ``load_unet`` for the
            referencenet presets only. Defaults to False.
        need_t2i_ip_adapter_face (bool, optional): Forwarded to ``load_unet``
            for the referencenet presets only. Defaults to False.
        strict (bool, optional): Forwarded to ``load_unet`` for the
            referencenet presets only. Defaults to True.

    Raises:
        ValueError: ``model_name`` is not one of the supported names.

    Returns:
        nn.Module: The UNet returned by ``load_unet``.
    """
    if model_name in ["musev"]:
        # NOTE(review): need_t2i_facein, need_t2i_ip_adapter_face and strict
        # are not forwarded for this preset, so load_unet's own defaults
        # apply -- confirm this is intended.
        unet = load_unet(
            sd_unet_model=sd_unet_model,
            sd_model=sd_model,
            need_spatial_position_emb=False,
            cross_attention_dim=cross_attention_dim,
            need_t2i_ip_adapter=True,
            need_adain_temporal_cond=True,
            t2i_ip_adapter_attn_processor="NonParamReferenceIPXFormersAttnProcessor",
            dtype=dtype,
        )
    elif model_name in [
        "musev_referencenet",
        "musev_referencenet_pose",
    ]:
        unet = load_unet(
            sd_unet_model=sd_unet_model,
            sd_model=sd_model,
            cross_attention_dim=cross_attention_dim,
            temporal_conv_block="TemporalConvLayer",
            need_transformer_in=False,
            temporal_transformer="TransformerTemporalModel",
            use_anivv1_cfg=True,
            resnet_2d_skip_time_act=True,
            need_t2i_ip_adapter=True,
            need_adain_temporal_cond=True,
            keep_vision_condtion=True,
            t2i_ip_adapter_attn_processor="NonParamReferenceIPXFormersAttnProcessor",
            dtype=dtype,
            need_refer_emb=True,
            need_zero_vis_cond_temb=True,
            ip_adapter_cross_attn=True,
            t2i_crossattn_ip_adapter_attn_processor="T2IReferencenetIPAdapterXFormersAttnProcessor",
            need_t2i_facein=need_t2i_facein,
            strict=strict,
            need_t2i_ip_adapter_face=need_t2i_ip_adapter_face,
        )
    else:
        raise ValueError(
            f"unsupport model_name={model_name}, only support musev, musev_referencenet, musev_referencenet_pose"
        )
    return unet