import copy
import argparse
import datetime
import inspect
import math
import os
import shutil
import gc
import time
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from pprint import pprint
from collections import OrderedDict
from dataclasses import dataclass

import PIL
import cv2
import numpy as np
import pandas as pd
import h5py
import torch
from torch import nn
import torch.nn.functional as F
import torch.utils.checkpoint
from einops import rearrange, repeat
from omegaconf import OmegaConf
from omegaconf import SCMode
from diffusers.models.modeling_utils import load_state_dict
from diffusers.utils import logging
from diffusers.utils.import_utils import is_xformers_available

from mmcm.vision.feature_extractor import clip_vision_extractor
from mmcm.vision.feature_extractor.clip_vision_extractor import (
    ImageClipVisionFeatureExtractor,
    ImageClipVisionFeatureExtractorV2,
    VerstailSDLastHiddenState2ImageEmb,
)
from ip_adapter.resampler import Resampler
from ip_adapter.ip_adapter import ImageProjModel

from .unet_loader import update_unet_with_sd
from .unet_3d_condition import UNet3DConditionModel

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def load_vision_clip_encoder_by_name(
    ip_image_encoder: Union[str, nn.Module] = None,
    dtype: torch.dtype = torch.float16,
    device: str = "cuda",
    vision_clip_extractor_class_name: str = None,
) -> nn.Module:
    """Instantiate a vision CLIP feature extractor by its class name in mmcm."""
    if vision_clip_extractor_class_name is not None:
        vision_clip_extractor = getattr(
            clip_vision_extractor, vision_clip_extractor_class_name
        )(
            pretrained_model_name_or_path=ip_image_encoder,
            dtype=dtype,
            device=device,
        )
    else:
        vision_clip_extractor = None
    return vision_clip_extractor


def load_ip_adapter_image_proj_by_name(
    model_name: str,
    ip_ckpt: Union[str, nn.Module] = None,
    cross_attention_dim: int = 768,
    clip_embeddings_dim: int = 1024,
    clip_extra_context_tokens: int = 4,
    ip_scale: float = 0.0,
    dtype: torch.dtype = torch.float16,
    device: str = "cuda",
    unet: nn.Module = None,
    vision_clip_extractor_class_name: str = None,
    ip_image_encoder: Union[str, nn.Module] = None,
) -> nn.Module:
    """Build the IP-Adapter image-projection module and optionally load its checkpoint."""
    if model_name in [
        "IPAdapter",
        "musev_referencenet",
        "musev_referencenet_pose",
    ]:
        ip_adapter_image_proj = ImageProjModel(
            cross_attention_dim=cross_attention_dim,
            clip_embeddings_dim=clip_embeddings_dim,
            clip_extra_context_tokens=clip_extra_context_tokens,
        )
    elif model_name == "IPAdapterPlus":
        vision_clip_extractor = ImageClipVisionFeatureExtractorV2(
            pretrained_model_name_or_path=ip_image_encoder,
            dtype=dtype,
            device=device,
        )
        ip_adapter_image_proj = Resampler(
            dim=cross_attention_dim,
            depth=4,
            dim_head=64,
            heads=12,
            num_queries=clip_extra_context_tokens,
            embedding_dim=vision_clip_extractor.image_encoder.config.hidden_size,
            output_dim=cross_attention_dim,
            ff_mult=4,
        )
    elif model_name in [
        "VerstailSDLastHiddenState2ImageEmb",
        "OriginLastHiddenState2ImageEmbd",
        "OriginLastHiddenState2Poolout",
    ]:
        ip_adapter_image_proj = getattr(
            clip_vision_extractor, model_name
        ).from_pretrained(ip_image_encoder)
    else:
        raise ValueError(
            f"unsupported model_name={model_name}, only support IPAdapter, "
            "IPAdapterPlus, musev_referencenet, musev_referencenet_pose, "
            "VerstailSDLastHiddenState2ImageEmb, OriginLastHiddenState2ImageEmbd, "
            "OriginLastHiddenState2Poolout"
        )
    if ip_ckpt is not None:
        ip_adapter_state_dict = torch.load(
            ip_ckpt,
            map_location="cpu",
        )
        ip_adapter_image_proj.load_state_dict(ip_adapter_state_dict["image_proj"])
        if (
            unet is not None
            and unet.ip_adapter_cross_attn
            and "ip_adapter" in ip_adapter_state_dict
        ):
            update_unet_ip_adapter_cross_attn_param(
                unet, ip_adapter_state_dict["ip_adapter"]
            )
            logger.info(
                f"update unet.spatial_cross_attn_ip_adapter parameter with {ip_ckpt}"
            )
    return ip_adapter_image_proj
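# Usage sketch (illustrative only, not part of the shipped API surface): build
# the plain IPAdapter projection and load its weights from a checkpoint. The
# path "./checkpoints/ip-adapter_sd15.bin" is a hypothetical placeholder; the
# dimension values match the SD1.5-style defaults used above.
#
#     image_proj = load_ip_adapter_image_proj_by_name(
#         model_name="IPAdapter",
#         ip_ckpt="./checkpoints/ip-adapter_sd15.bin",  # hypothetical path
#         cross_attention_dim=768,
#         clip_embeddings_dim=1024,
#         clip_extra_context_tokens=4,
#         dtype=torch.float16,
#         device="cuda",
#     )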
ip_adapter_state_dict["ip_adapter"] ) logger.info( f"update unet.spatial_cross_attn_ip_adapter parameter with {ip_ckpt}" ) return ip_adapter_image_proj def load_ip_adapter_vision_clip_encoder_by_name( model_name: str, ip_ckpt: Tuple[str, nn.Module], ip_image_encoder: Tuple[str, nn.Module] = None, cross_attention_dim: int = 768, clip_embeddings_dim: int = 1024, clip_extra_context_tokens: int = 4, ip_scale: float = 0.0, dtype: torch.dtype = torch.float16, device: str = "cuda", unet: nn.Module = None, vision_clip_extractor_class_name: str = None, ) -> nn.Module: if vision_clip_extractor_class_name is not None: vision_clip_extractor = getattr( clip_vision_extractor, vision_clip_extractor_class_name )( pretrained_model_name_or_path=ip_image_encoder, dtype=dtype, device=device, ) else: vision_clip_extractor = None if model_name in [ "IPAdapter", "musev_referencenet", ]: if ip_image_encoder is not None: if vision_clip_extractor_class_name is None: vision_clip_extractor = ImageClipVisionFeatureExtractor( pretrained_model_name_or_path=ip_image_encoder, dtype=dtype, device=device, ) else: vision_clip_extractor = None ip_adapter_image_proj = ImageProjModel( cross_attention_dim=cross_attention_dim, clip_embeddings_dim=clip_embeddings_dim, clip_extra_context_tokens=clip_extra_context_tokens, ) elif model_name == "IPAdapterPlus": if ip_image_encoder is not None: if vision_clip_extractor_class_name is None: vision_clip_extractor = ImageClipVisionFeatureExtractorV2( pretrained_model_name_or_path=ip_image_encoder, dtype=dtype, device=device, ) else: vision_clip_extractor = None ip_adapter_image_proj = Resampler( dim=cross_attention_dim, depth=4, dim_head=64, heads=12, num_queries=clip_extra_context_tokens, embedding_dim=vision_clip_extractor.image_encoder.config.hidden_size, output_dim=cross_attention_dim, ff_mult=4, ).to(dtype=torch.float16) else: raise ValueError( f"unsupport model_name={model_name}, only support IPAdapter, IPAdapterPlus" ) ip_adapter_state_dict = torch.load( ip_ckpt, map_location="cpu", ) ip_adapter_image_proj.load_state_dict(ip_adapter_state_dict["image_proj"]) if ( unet is not None and unet.ip_adapter_cross_attn and "ip_adapter" in ip_adapter_state_dict ): update_unet_ip_adapter_cross_attn_param( unet, ip_adapter_state_dict["ip_adapter"] ) logger.info( f"update unet.spatial_cross_attn_ip_adapter parameter with {ip_ckpt}" ) return ( vision_clip_extractor, ip_adapter_image_proj, ) # refer https://github.com/tencent-ailab/IP-Adapter/issues/168#issuecomment-1846771651 unet_keys_list = [ "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight", "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight", "down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor.to_k_ip.weight", "down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor.to_v_ip.weight", "down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight", "down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight", "down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.to_k_ip.weight", "down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.to_v_ip.weight", "down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight", "down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight", "down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.to_k_ip.weight", "down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.to_v_ip.weight", 
"up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight", "up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight", "up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.to_k_ip.weight", "up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.to_v_ip.weight", "up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor.to_k_ip.weight", "up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor.to_v_ip.weight", "up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight", "up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight", "up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.to_k_ip.weight", "up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.to_v_ip.weight", "up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor.to_k_ip.weight", "up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor.to_v_ip.weight", "up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight", "up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight", "up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor.to_k_ip.weight", "up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor.to_v_ip.weight", "up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor.to_k_ip.weight", "up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor.to_v_ip.weight", "mid_block.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight", "mid_block.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight", ] ip_adapter_keys_list = [ "1.to_k_ip.weight", "1.to_v_ip.weight", "3.to_k_ip.weight", "3.to_v_ip.weight", "5.to_k_ip.weight", "5.to_v_ip.weight", "7.to_k_ip.weight", "7.to_v_ip.weight", "9.to_k_ip.weight", "9.to_v_ip.weight", "11.to_k_ip.weight", "11.to_v_ip.weight", "13.to_k_ip.weight", "13.to_v_ip.weight", "15.to_k_ip.weight", "15.to_v_ip.weight", "17.to_k_ip.weight", "17.to_v_ip.weight", "19.to_k_ip.weight", "19.to_v_ip.weight", "21.to_k_ip.weight", "21.to_v_ip.weight", "23.to_k_ip.weight", "23.to_v_ip.weight", "25.to_k_ip.weight", "25.to_v_ip.weight", "27.to_k_ip.weight", "27.to_v_ip.weight", "29.to_k_ip.weight", "29.to_v_ip.weight", "31.to_k_ip.weight", "31.to_v_ip.weight", ] UNET2IPAadapter_Keys_MAPIING = { k: v for k, v in zip(unet_keys_list, ip_adapter_keys_list) } def update_unet_ip_adapter_cross_attn_param( unet: UNet3DConditionModel, ip_adapter_state_dict: Dict ) -> None: """use independent ip_adapter attn 中的 to_k, to_v in unet ip_adapter: dict whose keys are ['1.to_k_ip.weight', '1.to_v_ip.weight', '3.to_k_ip.weight'] Args: unet (UNet3DConditionModel): _description_ ip_adapter_state_dict (Dict): _description_ """ unet_spatial_cross_atnns = unet.spatial_cross_attns[0] unet_spatial_cross_atnns_dct = {k: v for k, v in unet_spatial_cross_atnns} for i, (unet_key_more, ip_adapter_key) in enumerate( UNET2IPAadapter_Keys_MAPIING.items() ): ip_adapter_value = ip_adapter_state_dict[ip_adapter_key] unet_key_more_spit = unet_key_more.split(".") unet_key = ".".join(unet_key_more_spit[:-3]) suffix = ".".join(unet_key_more_spit[-3:]) logger.debug( f"{i}: unet_key_more = {unet_key_more}, {unet_key}=unet_key, suffix={suffix}", ) if "to_k" in suffix: with torch.no_grad(): unet_spatial_cross_atnns_dct[unet_key].to_k_ip.weight.copy_( ip_adapter_value.data ) else: with torch.no_grad(): unet_spatial_cross_atnns_dct[unet_key].to_v_ip.weight.copy_( ip_adapter_value.data )