from typing import Optional, Tuple, Union

import torch
from torch import nn

from diffusers.utils import logging

from .referencenet import ReferenceNet2D
from .unet_loader import update_unet_with_sd

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def load_referencenet(
    sd_referencenet_model: Union[str, nn.Module],
    sd_model: Optional[nn.Module] = None,
    need_self_attn_block_embs: bool = False,
    need_block_embs: bool = False,
    dtype: torch.dtype = torch.float16,
    cross_attention_dim: int = 768,
    subfolder: str = "unet",
):
    """
    Loads the ReferenceNet model.

    Args:
        sd_referencenet_model (Union[str, nn.Module]): The pretrained ReferenceNet model or the path to the model.
        sd_model (nn.Module, optional): The sd_model to update the ReferenceNet with. Defaults to None.
        need_self_attn_block_embs (bool, optional): Whether to compute self-attention block embeddings. Defaults to False.
        need_block_embs (bool, optional): Whether to compute block embeddings. Defaults to False.
        dtype (torch.dtype, optional): The data type of the tensors. Defaults to torch.float16.
        cross_attention_dim (int, optional): The dimension of the cross-attention. Defaults to 768.
        subfolder (str, optional): The subfolder of the model. Defaults to "unet".

    Returns:
        nn.Module: The loaded ReferenceNet model.
    """
    if isinstance(sd_referencenet_model, str):
        referencenet = ReferenceNet2D.from_pretrained(
            sd_referencenet_model,
            subfolder=subfolder,
            need_self_attn_block_embs=need_self_attn_block_embs,
            need_block_embs=need_block_embs,
            torch_dtype=dtype,
            cross_attention_dim=cross_attention_dim,
        )
    elif isinstance(sd_referencenet_model, nn.Module):
        referencenet = sd_referencenet_model
    else:
        raise TypeError(
            f"sd_referencenet_model must be a str path or nn.Module, got {type(sd_referencenet_model)}"
        )
    if sd_model is not None:
        referencenet = update_unet_with_sd(referencenet, sd_model)
    return referencenet


def load_referencenet_by_name(
    model_name: str,
    sd_referencenet_model: Union[str, nn.Module],
    sd_model: Optional[nn.Module] = None,
    cross_attention_dim: int = 768,
    dtype: torch.dtype = torch.float16,
) -> nn.Module:
    """Init a ReferenceNet by model_name and load its pretrained weights.
    If you want to use a pretrained model via a simple name, define it here.

    Args:
        model_name (str): Name of the predefined ReferenceNet configuration.
        sd_referencenet_model (Union[str, nn.Module]): The pretrained ReferenceNet model or the path to it.
        sd_model (nn.Module, optional): Stable Diffusion model used to update the ReferenceNet weights. Defaults to None.
        cross_attention_dim (int, optional): The dimension of the cross-attention. Defaults to 768.
        dtype (torch.dtype, optional): The data type of the tensors. Defaults to torch.float16.
    Raises:
        ValueError: If model_name is not a supported ReferenceNet name.

    Returns:
        nn.Module: The loaded ReferenceNet model.
    """
    if model_name in [
        "musev_referencenet",
    ]:
        unet = load_referencenet(
            sd_referencenet_model=sd_referencenet_model,
            sd_model=sd_model,
            cross_attention_dim=cross_attention_dim,
            dtype=dtype,
            need_self_attn_block_embs=False,
            need_block_embs=True,
            subfolder="referencenet",
        )
    else:
        raise ValueError(
            f"unsupported model_name={model_name}, only support musev_referencenet"
        )
    return unet
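
# Usage sketch (illustrative, not part of the library API): load the
# "musev_referencenet" configuration from a local checkpoint directory and move
# it to GPU for inference. The checkpoint path below is a placeholder
# assumption; substitute the directory where you downloaded the weights, which
# must contain a "referencenet" subfolder as expected by load_referencenet_by_name.
#
#     referencenet = load_referencenet_by_name(
#         model_name="musev_referencenet",
#         sd_referencenet_model="./checkpoints/musev_referencenet",  # hypothetical path
#         cross_attention_dim=768,
#         dtype=torch.float16,
#     )
#     referencenet = referencenet.to(device="cuda")
#     referencenet.eval()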