from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
import warnings
import os

import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.modeling_utils import ModelMixin
import PIL
from einops import rearrange, repeat
import numpy as np
import torch
import torch.nn.init as init
from diffusers.models.controlnet import ControlNetModel
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers
from diffusers.utils.torch_utils import is_compiled_module


class ControlnetPredictor(object):
    def __init__(self, controlnet_model_path: str, *args, **kwargs):
        """ControlNet inference predictor, used to pre-extract the embeddings of the
        ControlNet backbone so they are not re-extracted repeatedly during training.

        Args:
            controlnet_model_path (str): controlnet model path.
        """
        super(ControlnetPredictor, self).__init__(*args, **kwargs)
        self.controlnet = ControlNetModel.from_pretrained(
            controlnet_model_path,
        )

    def prepare_image(
        self,
        image,  # b c t h w
        width,
        height,
        batch_size,
        num_images_per_prompt,
        device,
        dtype,
        do_classifier_free_guidance=False,
        guess_mode=False,
    ):
        # NOTE: relies on self.control_image_processor (e.g. a diffusers
        # VaeImageProcessor) being attached by the caller; it is not created in __init__.
        if height is None:
            height = image.shape[-2]
        if width is None:
            width = image.shape[-1]
        # Round width/height down to a multiple of the VAE scale factor.
        width, height = (
            x - x % self.control_image_processor.vae_scale_factor
            for x in (width, height)
        )
        image = rearrange(image, "b c t h w-> (b t) c h w")
        image = torch.from_numpy(image).to(dtype=torch.float32) / 255.0
        image = torch.nn.functional.interpolate(
            image,
            size=(height, width),
            mode="bilinear",
        )

        do_normalize = self.control_image_processor.config.do_normalize
        if image.min() < 0:
            warnings.warn(
                "Passing `image` as torch tensor with value range in [-1,1] is deprecated. "
                "The expected value range for image tensor is [0,1] when passing as pytorch tensor or numpy Array. "
                f"You passed `image` with value range [{image.min()},{image.max()}]",
                FutureWarning,
            )
            do_normalize = False

        if do_normalize:
            image = self.control_image_processor.normalize(image)

        image_batch_size = image.shape[0]
        if image_batch_size == 1:
            repeat_by = batch_size
        else:
            # image batch size is the same as prompt batch size
            repeat_by = num_images_per_prompt
        image = image.repeat_interleave(repeat_by, dim=0)

        image = image.to(device=device, dtype=dtype)

        if do_classifier_free_guidance and not guess_mode:
            image = torch.cat([image] * 2)

        return image

    @torch.no_grad()
    def __call__(
        self,
        batch_size: int,
        device: str,
        dtype: torch.dtype,
        timesteps: List[float],
        i: int,
        scheduler: KarrasDiffusionSchedulers,
        prompt_embeds: torch.Tensor,
        do_classifier_free_guidance: bool = False,
        # 2b co t ho wo
        latent_model_input: torch.Tensor = None,
        # b co t ho wo
        latents: torch.Tensor = None,
        # b c t h w
        image: Union[
            torch.FloatTensor,
            PIL.Image.Image,
            np.ndarray,
            List[torch.FloatTensor],
            List[PIL.Image.Image],
            List[np.ndarray],
        ] = None,
        # b c t(1) hi wi
        controlnet_condition_frames: Optional[torch.FloatTensor] = None,
        # b c t ho wo
        controlnet_latents: Union[torch.FloatTensor, np.ndarray] = None,
        # b c t(1) ho wo
        controlnet_condition_latents: Optional[torch.FloatTensor] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_videos_per_prompt: Optional[int] = 1,
        return_dict: bool = True,
        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
        guess_mode: bool = False,
        control_guidance_start: Union[float, List[float]] = 0.0,
        control_guidance_end: Union[float, List[float]] = 1.0,
        latent_index: torch.LongTensor = None,
        vision_condition_latent_index: torch.LongTensor = None,
        **kwargs,
    ):
        assert (
            image is not None or controlnet_latents is not None
        ), "should set one of image and controlnet_latents"

        controlnet = (
            self.controlnet._orig_mod
            if is_compiled_module(self.controlnet)
            else self.controlnet
        )

        # align format for control guidance
        if not isinstance(control_guidance_start, list) and isinstance(
            control_guidance_end, list
        ):
            control_guidance_start = len(control_guidance_end) * [
                control_guidance_start
            ]
        elif not isinstance(control_guidance_end, list) and isinstance(
            control_guidance_start, list
        ):
            control_guidance_end = len(control_guidance_start) * [control_guidance_end]
        elif not isinstance(control_guidance_start, list) and not isinstance(
            control_guidance_end, list
        ):
            mult = (
                len(controlnet.nets)
                if isinstance(controlnet, MultiControlNetModel)
                else 1
            )
            control_guidance_start, control_guidance_end = mult * [
                control_guidance_start
            ], mult * [control_guidance_end]

        if isinstance(controlnet, MultiControlNetModel) and isinstance(
            controlnet_conditioning_scale, float
        ):
            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(
                controlnet.nets
            )

        global_pool_conditions = (
            controlnet.config.global_pool_conditions
            if isinstance(controlnet, ControlNetModel)
            else controlnet.nets[0].config.global_pool_conditions
        )
        guess_mode = guess_mode or global_pool_conditions

        # 4. Prepare image
        if isinstance(controlnet, ControlNetModel):
            if (
                controlnet_latents is not None
                and controlnet_condition_latents is not None
            ):
                if isinstance(controlnet_latents, np.ndarray):
                    controlnet_latents = torch.from_numpy(controlnet_latents)
                if isinstance(controlnet_condition_latents, np.ndarray):
                    controlnet_condition_latents = torch.from_numpy(
                        controlnet_condition_latents
                    )
                # TODO: concatenate using index
                controlnet_latents = torch.concat(
                    [controlnet_condition_latents, controlnet_latents], dim=2
                )
                if not guess_mode and do_classifier_free_guidance:
                    controlnet_latents = torch.concat([controlnet_latents] * 2, dim=0)
                controlnet_latents = rearrange(
                    controlnet_latents, "b c t h w->(b t) c h w"
                )
                controlnet_latents = controlnet_latents.to(device=device, dtype=dtype)
            else:
                # TODO: concatenate using index
                if controlnet_condition_frames is not None:
                    if isinstance(controlnet_condition_frames, np.ndarray):
                        image = np.concatenate(
                            [controlnet_condition_frames, image], axis=2
                        )
                image = self.prepare_image(
                    image=image,
                    width=width,
                    height=height,
                    batch_size=batch_size * num_videos_per_prompt,
                    num_images_per_prompt=num_videos_per_prompt,
                    device=device,
                    dtype=controlnet.dtype,
                    do_classifier_free_guidance=do_classifier_free_guidance,
                    guess_mode=guess_mode,
                )
                height, width = image.shape[-2:]
        elif isinstance(controlnet, MultiControlNetModel):
            images = []
            # TODO: support using controlnet_latents directly instead of frames
            if controlnet_latents is not None:
                raise NotImplementedError
            else:
                for idx, image_ in enumerate(image):
                    if controlnet_condition_frames is not None and isinstance(
                        controlnet_condition_frames, list
                    ):
                        if isinstance(controlnet_condition_frames[idx], np.ndarray):
                            image_ = np.concatenate(
                                [controlnet_condition_frames[idx], image_], axis=2
                            )
                    image_ = self.prepare_image(
                        image=image_,
                        width=width,
                        height=height,
                        batch_size=batch_size * num_videos_per_prompt,
                        num_images_per_prompt=num_videos_per_prompt,
                        device=device,
                        dtype=controlnet.dtype,
                        do_classifier_free_guidance=do_classifier_free_guidance,
                        guess_mode=guess_mode,
                    )

                    images.append(image_)

                image = images
                height, width = image[0].shape[-2:]
        else:
            assert False, f"unsupported controlnet type: {type(controlnet)}"

        # 7.1 Create tensor stating which controlnets to keep
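        # For example (illustrative numbers, not from this repo): with 20 denoising
        # steps, control_guidance_start=0.0 and control_guidance_end=0.5, the loop
        # below yields keep=1.0 for steps 0-9 and keep=0.0 for steps 10-19, i.e. the
        # ControlNet residuals are only applied during the first half of denoising.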
        controlnet_keep = []
        for step in range(len(timesteps)):
            keeps = [
                1.0
                - float(step / len(timesteps) < s or (step + 1) / len(timesteps) > e)
                for s, e in zip(control_guidance_start, control_guidance_end)
            ]
            controlnet_keep.append(
                keeps[0] if isinstance(controlnet, ControlNetModel) else keeps
            )

        t = timesteps[i]

        # controlnet(s) inference
        if guess_mode and do_classifier_free_guidance:
            # Infer ControlNet only for the conditional batch.
            control_model_input = latents
            control_model_input = scheduler.scale_model_input(control_model_input, t)
            controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
        else:
            control_model_input = latent_model_input
            controlnet_prompt_embeds = prompt_embeds

        if isinstance(controlnet_keep[i], list):
            cond_scale = [
                c * s
                for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])
            ]
        else:
            cond_scale = controlnet_conditioning_scale * controlnet_keep[i]

        control_model_input_reshape = rearrange(
            control_model_input, "b c t h w -> (b t) c h w"
        )
        encoder_hidden_states_repeat = repeat(
            controlnet_prompt_embeds,
            "b n q->(b t) n q",
            t=control_model_input.shape[2],
        )

        down_block_res_samples, mid_block_res_sample = self.controlnet(
            control_model_input_reshape,
            t,
            encoder_hidden_states_repeat,
            controlnet_cond=image,
            controlnet_cond_latents=controlnet_latents,
            conditioning_scale=cond_scale,
            guess_mode=guess_mode,
            return_dict=False,
        )

        return down_block_res_samples, mid_block_res_sample
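
# A minimal usage sketch for ControlnetPredictor (kept as a comment; the checkpoint id,
# tensor names and shapes below are illustrative assumptions, not values required by
# this module):
#
#   predictor = ControlnetPredictor(
#       controlnet_model_path="lllyasviel/sd-controlnet-openpose"
#   )
#   down_block_res_samples, mid_block_res_sample = predictor(
#       batch_size=1,
#       device="cuda",
#       dtype=torch.float16,
#       timesteps=scheduler.timesteps,
#       i=step_index,                            # current denoising step
#       scheduler=scheduler,
#       prompt_embeds=prompt_embeds,             # b n q text embeddings
#       latent_model_input=latent_model_input,   # 2b co t ho wo with CFG
#       latents=latents,                         # b co t ho wo
#       image=control_video,                     # b c t h w uint8 numpy array
#       do_classifier_free_guidance=True,
#       controlnet_conditioning_scale=1.0,
#   )
#   # The returned residuals can then be fed to the UNet, so the ControlNet backbone
#   # does not have to be re-run repeatedly during training.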


class InflatedConv3d(nn.Conv2d):
    def forward(self, x):
        video_length = x.shape[2]

        x = rearrange(x, "b c f h w -> (b f) c h w")
        x = super().forward(x)
        x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)

        return x


def zero_module(module):
    # Zero out the parameters of a module and return it.
    for p in module.parameters():
        p.detach().zero_()
    return module


class PoseGuider(ModelMixin):
    def __init__(
        self,
        conditioning_embedding_channels: int,
        conditioning_channels: int = 3,
        block_out_channels: Tuple[int] = (16, 32, 64, 128),
    ):
        super().__init__()
        self.conv_in = InflatedConv3d(
            conditioning_channels, block_out_channels[0], kernel_size=3, padding=1
        )

        self.blocks = nn.ModuleList([])

        for i in range(len(block_out_channels) - 1):
            channel_in = block_out_channels[i]
            channel_out = block_out_channels[i + 1]
            self.blocks.append(
                InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1)
            )
            self.blocks.append(
                InflatedConv3d(
                    channel_in, channel_out, kernel_size=3, padding=1, stride=2
                )
            )

        self.conv_out = zero_module(
            InflatedConv3d(
                block_out_channels[-1],
                conditioning_embedding_channels,
                kernel_size=3,
                padding=1,
            )
        )

    def forward(self, conditioning):
        embedding = self.conv_in(conditioning)
        embedding = F.silu(embedding)

        for block in self.blocks:
            embedding = block(embedding)
            embedding = F.silu(embedding)

        embedding = self.conv_out(embedding)

        return embedding

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_path,
        conditioning_embedding_channels: int,
        conditioning_channels: int = 3,
        block_out_channels: Tuple[int] = (16, 32, 64, 128),
    ):
        if not os.path.exists(pretrained_model_path):
            print(f"There is no model file in {pretrained_model_path}")
        print(
            f"loaded PoseGuider's pretrained weights from {pretrained_model_path} ..."
        )

        state_dict = torch.load(pretrained_model_path, map_location="cpu")
        model = PoseGuider(
            conditioning_embedding_channels=conditioning_embedding_channels,
            conditioning_channels=conditioning_channels,
            block_out_channels=block_out_channels,
        )

        m, u = model.load_state_dict(state_dict, strict=False)
        # print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
        params = [p.numel() for n, p in model.named_parameters()]
        print(f"### PoseGuider's Parameters: {sum(params) / 1e6} M")

        return model
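

if __name__ == "__main__":
    # Minimal shape-check sketch for PoseGuider (sizes here are illustrative only and
    # not tied to any particular training configuration): a 4-frame, 64x64, 3-channel
    # pose video is mapped to a conditioning embedding downsampled 8x spatially by the
    # three stride-2 convolutions. Because conv_out is wrapped in zero_module, the
    # output is all zeros before any training.
    pose_guider = PoseGuider(conditioning_embedding_channels=320)
    pose = torch.rand(1, 3, 4, 64, 64)  # b c t h w
    emb = pose_guider(pose)
    print(emb.shape)  # torch.Size([1, 320, 4, 8, 8])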