import inspect |
|
from typing import Any, Dict, List, Optional, Tuple, Union |
|
|
|
import os |
|
import torch |
|
import numpy as np |
|
from PIL import Image |
|
|
|
import diffusers |
|
from diffusers.image_processor import VaeImageProcessor |
|
from diffusers.utils.import_utils import is_xformers_available |
|
from diffusers.schedulers import KarrasDiffusionSchedulers |
|
from diffusers.utils.torch_utils import randn_tensor |
|
from diffusers.models.attention_processor import ( |
|
Attention, |
|
AttnProcessor, |
|
XFormersAttnProcessor, |
|
AttnProcessor2_0 |
|
) |
|
from diffusers import ( |
|
AutoencoderKL, |
|
DDPMScheduler, |
|
DiffusionPipeline, |
|
EulerAncestralDiscreteScheduler, |
|
UNet2DConditionModel, |
|
ImagePipelineOutput |
|
) |
|
import transformers |
|
from transformers import ( |
|
CLIPImageProcessor, |
|
CLIPTextModel, |
|
CLIPTokenizer, |
|
CLIPVisionModelWithProjection, |
|
CLIPTextModelWithProjection |
|
) |
|
|
|
from .utils import to_rgb_image, white_out_background, recenter_img |
|
|
|
EXAMPLE_DOC_STRING = """ |
|
Examples: |
|
```py |
|
        >>> import torch
        >>> from PIL import Image
        >>> from diffusers import Hunyuan3d_MVD_XL_Pipeline

        >>> pipe = Hunyuan3d_MVD_XL_Pipeline.from_pretrained(
        ...     "Tencent-Hunyuan-3D/MVD-XL", torch_dtype=torch.float16
        ... )
        >>> pipe.to("cuda")

        >>> img = Image.open("demo.png")
        >>> res_img = pipe(img).images[0]
|
``` |
|
""" |
|
|
|
|
|
|
|
# Affine (de)normalization helpers for latents and decoded images;
# each unscale_* is the inverse of the corresponding scale_*.
def scale_latents(latents):   return (latents - 0.22) * 0.75
def unscale_latents(latents): return (latents / 0.75) + 0.22
def scale_image(image):       return (image - 0.5) / 0.5
def scale_image_2(image):     return (image * 0.5) / 0.8
def unscale_image(image):     return (image * 0.5) + 0.5
def unscale_image_2(image):   return (image * 0.8) / 0.5
|
|
|
|
|
|
|
|
|
class ReferenceOnlyAttnProc(torch.nn.Module): |
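    """Attention processor that wraps a default processor and, when enabled, either records
    ("w" mode) or re-injects ("r" mode) reference hidden states via a shared ref_dict."""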
|
def __init__(self, chained_proc, enabled=False, name=None): |
|
super().__init__() |
|
self.enabled = enabled |
|
self.chained_proc = chained_proc |
|
self.name = name |
|
|
|
    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, mode="w", ref_dict=None):
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        if self.enabled:
            if mode == 'w':
                # "write": cache the reference hidden states for this layer
                ref_dict[self.name] = encoder_hidden_states
            elif mode == 'r':
                # "read": append the cached reference states so attention can attend to them
                encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict.pop(self.name)], dim=1)
            else:
                raise ValueError(f"Unsupported mode {mode!r}; expected 'w' or 'r'.")
        return self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask)
|
|
|
|
|
class RefOnlyNoisedUNet(torch.nn.Module): |
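    """UNet wrapper implementing reference-only conditioning: each forward pass first runs a
    noised copy of the reference latent through the UNet to cache per-layer attention states,
    then denoises the actual sample while attending to those cached states."""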
|
def __init__(self, unet, scheduler) -> None: |
|
super().__init__() |
|
self.unet = unet |
|
self.scheduler = scheduler |
|
|
|
        unet_attn_procs = dict()
        for name, _ in unet.attn_processors.items():
            # Prefer PyTorch 2.x SDPA if available, then xFormers, otherwise the vanilla processor.
            # (Checking for scaled_dot_product_attention is more robust than comparing version strings.)
            if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
                default_attn_proc = AttnProcessor2_0()
            elif is_xformers_available():
                default_attn_proc = XFormersAttnProcessor()
            else:
                default_attn_proc = AttnProcessor()
            unet_attn_procs[name] = ReferenceOnlyAttnProc(
                default_attn_proc, enabled=name.endswith("attn1.processor"), name=name
            )
        unet.set_attn_processor(unet_attn_procs)
|
|
|
def __getattr__(self, name: str): |
|
try: |
|
return super().__getattr__(name) |
|
except AttributeError: |
|
return getattr(self.unet, name) |
|
|
|
def forward( |
|
self, |
|
sample: torch.FloatTensor, |
|
timestep: Union[torch.Tensor, float, int], |
|
encoder_hidden_states: torch.Tensor, |
|
cross_attention_kwargs: Optional[Dict[str, Any]] = None, |
|
class_labels: Optional[torch.Tensor] = None, |
|
down_block_res_samples: Optional[Tuple[torch.Tensor]] = None, |
|
mid_block_res_sample: Optional[Tuple[torch.Tensor]] = None, |
|
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, |
|
return_dict: bool = True, |
|
**kwargs |
|
): |
|
|
|
dtype = self.unet.dtype |
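        # Reference-only conditioning: noise the reference latent to the current timestep and run it
        # through the UNet in "write" mode so each self-attention layer caches its hidden states in
        # ref_dict; the denoising pass below then reads and attends to those cached states.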
|
|
|
|
|
cond_lat = cross_attention_kwargs['cond_lat'] |
|
noise = torch.randn_like(cond_lat) |
|
|
|
noisy_cond_lat = self.scheduler.add_noise(cond_lat, noise, timestep.reshape(-1)) |
|
noisy_cond_lat = self.scheduler.scale_model_input(noisy_cond_lat, timestep.reshape(-1)) |
|
|
|
ref_dict = {} |
|
|
|
_ = self.unet( |
|
noisy_cond_lat, |
|
timestep, |
|
encoder_hidden_states = encoder_hidden_states, |
|
class_labels = class_labels, |
|
cross_attention_kwargs = dict(mode="w", ref_dict=ref_dict), |
|
added_cond_kwargs = added_cond_kwargs, |
|
return_dict = return_dict, |
|
**kwargs |
|
) |
|
|
|
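        # Second pass ("read" mode): denoise the actual sample; the attention layers concatenate
        # the cached reference states from ref_dict.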
res = self.unet( |
|
sample, |
|
timestep, |
|
encoder_hidden_states, |
|
class_labels=class_labels, |
|
cross_attention_kwargs = dict(mode="r", ref_dict=ref_dict), |
|
down_block_additional_residuals = [ |
|
sample.to(dtype=dtype) for sample in down_block_res_samples |
|
] if down_block_res_samples is not None else None, |
|
mid_block_additional_residual = ( |
|
mid_block_res_sample.to(dtype=dtype) |
|
if mid_block_res_sample is not None else None), |
|
added_cond_kwargs = added_cond_kwargs, |
|
return_dict = return_dict, |
|
**kwargs |
|
) |
|
return res |
|
|
|
|
|
|
|
class HunYuan3D_MVD_Std_Pipeline(diffusers.DiffusionPipeline): |
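    """Hunyuan3D multi-view diffusion (standard) pipeline: generates a grid of consistent
    multi-view images from a single reference image, using reference-only attention and
    CLIP image embeddings ramped into cached unconditional text embeddings."""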
|
def __init__( |
|
self, |
|
vae: AutoencoderKL, |
|
unet: UNet2DConditionModel, |
|
scheduler: KarrasDiffusionSchedulers, |
|
feature_extractor_vae: CLIPImageProcessor, |
|
vision_processor: CLIPImageProcessor, |
|
vision_encoder: CLIPVisionModelWithProjection, |
|
vision_encoder_2: CLIPVisionModelWithProjection, |
|
ramping_coefficients: Optional[list] = None, |
|
add_watermarker: Optional[bool] = None, |
|
safety_checker = None, |
|
): |
|
DiffusionPipeline.__init__(self) |
|
|
|
self.register_modules( |
|
vae=vae, unet=unet, scheduler=scheduler, safety_checker=None, feature_extractor_vae=feature_extractor_vae, |
|
vision_processor=vision_processor, vision_encoder=vision_encoder, vision_encoder_2=vision_encoder_2, |
|
) |
|
        self.register_to_config(ramping_coefficients=ramping_coefficients)
|
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) |
|
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) |
|
self.default_sample_size = self.unet.config.sample_size |
|
self.watermark = None |
|
self.prepare_init = False |
|
|
|
def prepare(self): |
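        # Wrap the raw UNet with the reference-only wrapper; called lazily on the first __call__.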
|
assert isinstance(self.unet, UNet2DConditionModel), "unet should be UNet2DConditionModel" |
|
self.unet = RefOnlyNoisedUNet(self.unet, self.scheduler).eval() |
|
self.prepare_init = True |
|
|
|
def encode_image(self, image: torch.Tensor, scale_factor: bool = False): |
|
latent = self.vae.encode(image).latent_dist.sample() |
|
return (latent * self.vae.config.scaling_factor) if scale_factor else latent |
|
|
|
|
|
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): |
|
shape = ( |
|
batch_size, |
|
num_channels_latents, |
|
int(height) // self.vae_scale_factor, |
|
int(width) // self.vae_scale_factor, |
|
) |
|
if isinstance(generator, list) and len(generator) != batch_size: |
|
raise ValueError( |
|
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" |
|
f" size of {batch_size}. Make sure the batch size matches the length of the generators." |
|
) |
|
|
|
if latents is None: |
|
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) |
|
else: |
|
latents = latents.to(device) |
|
|
|
|
|
latents = latents * self.scheduler.init_noise_sigma |
|
return latents |
|
|
|
def _get_add_time_ids( |
|
self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None |
|
): |
|
add_time_ids = list(original_size + crops_coords_top_left + target_size) |
|
|
|
passed_add_embed_dim = ( |
|
self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim |
|
) |
|
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features |
|
|
|
if expected_add_embed_dim != passed_add_embed_dim: |
|
raise ValueError( |
|
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, " \ |
|
f"but a vector of {passed_add_embed_dim} was created. The model has an incorrect config." \ |
|
f" Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." |
|
) |
|
|
|
add_time_ids = torch.tensor([add_time_ids], dtype=dtype) |
|
return add_time_ids |
|
|
|
def prepare_extra_step_kwargs(self, generator, eta): |
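        # Not every scheduler's step() accepts `eta` or `generator`; forward them only when supported.
        # `eta` corresponds to η in the DDIM paper and is ignored by non-DDIM schedulers (0 <= eta <= 1).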
|
|
|
|
|
|
|
|
|
|
|
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) |
|
extra_step_kwargs = {} |
|
if accepts_eta: extra_step_kwargs["eta"] = eta |
|
|
|
|
|
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) |
|
if accepts_generator: extra_step_kwargs["generator"] = generator |
|
return extra_step_kwargs |
|
|
|
@property |
|
def guidance_scale(self): |
|
return self._guidance_scale |
|
|
|
@property |
|
def interrupt(self): |
|
return self._interrupt |
|
|
|
@property |
|
def do_classifier_free_guidance(self): |
|
return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None |
|
|
|
@torch.no_grad() |
|
def __call__( |
|
self, |
|
image: Image.Image = None, |
|
guidance_scale = 2.0, |
|
output_type: Optional[str] = "pil", |
|
num_inference_steps: int = 50, |
|
return_dict: bool = True, |
|
eta: float = 0.0, |
|
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, |
|
crops_coords_top_left: Tuple[int, int] = (0, 0), |
|
cross_attention_kwargs: Optional[Dict[str, Any]] = None, |
|
latent: torch.Tensor = None, |
|
guidance_curve = None, |
|
**kwargs |
|
): |
|
if not self.prepare_init: |
|
self.prepare() |
|
|
|
here = dict(device=self.vae.device, dtype=self.vae.dtype) |
|
|
|
batch_size = 1 |
|
num_images_per_prompt = 1 |
|
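        # The UNet denoises a single 1024x1536 canvas, i.e. six 512x512 multi-view tiles in a 3x2 grid.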
width, height = 512 * 2, 512 * 3 |
|
target_size = original_size = (height, width) |
|
|
|
self._guidance_scale = guidance_scale |
|
self._cross_attention_kwargs = cross_attention_kwargs |
|
self._interrupt = False |
|
|
|
device = self._execution_device |
|
|
|
|
|
self.scheduler.set_timesteps(num_inference_steps, device=device) |
|
timesteps = self.scheduler.timesteps |
|
|
|
|
|
num_channels_latents = self.unet.config.in_channels |
|
latents = self.prepare_latents( |
|
batch_size * num_images_per_prompt, |
|
num_channels_latents, |
|
height, |
|
width, |
|
self.vae.dtype, |
|
device, |
|
generator, |
|
latents=latent, |
|
) |
|
|
|
|
|
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) |
|
|
|
|
|
|
|
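        # SDXL-style micro-conditioning: 1280 matches the projection dim of the pooled embedding
        # expected by the UNet's add_embedding layer.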
text_encoder_projection_dim = 1280 |
|
add_time_ids = self._get_add_time_ids( |
|
original_size, |
|
crops_coords_top_left, |
|
target_size, |
|
dtype=self.vae.dtype, |
|
text_encoder_projection_dim=text_encoder_projection_dim, |
|
) |
|
negative_add_time_ids = add_time_ids |
|
|
|
|
|
        # Preprocess the reference image: recenter the object, then convert it to a plain RGB image.
        cond_image = recenter_img(image)
        cond_image = to_rgb_image(cond_image)
|
image_vae = self.feature_extractor_vae(images=cond_image, return_tensors="pt").pixel_values.to(**here) |
|
image_clip = self.vision_processor(images=cond_image, return_tensors="pt").pixel_values.to(**here) |
|
|
|
|
|
cond_lat = self.encode_image(image_vae, scale_factor=False) |
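        # The latent of an all-zeros image serves as the unconditional reference for classifier-free guidance.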
|
negative_lat = self.encode_image(torch.zeros_like(image_vae), scale_factor=False) |
|
cond_lat = torch.cat([negative_lat, cond_lat]) |
|
|
|
|
|
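        # Global CLIP image embeddings from both vision encoders act as the "prompt": they are ramped
        # into the cached unconditional text embeddings with learned per-token coefficients below.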
global_embeds_1 = self.vision_encoder(image_clip, output_hidden_states=False).image_embeds.unsqueeze(-2) |
|
global_embeds_2 = self.vision_encoder_2(image_clip, output_hidden_states=False).image_embeds.unsqueeze(-2) |
|
global_embeds = torch.concat([global_embeds_1, global_embeds_2], dim=-1) |
|
|
|
ramp = global_embeds.new_tensor(self.config.ramping_coefficients).unsqueeze(-1) |
|
prompt_embeds = self.uc_text_emb.to(**here) |
|
pooled_prompt_embeds = self.uc_text_emb_2.to(**here) |
|
|
|
prompt_embeds = prompt_embeds + global_embeds * ramp |
|
add_text_embeds = pooled_prompt_embeds |
|
|
|
if self.do_classifier_free_guidance: |
|
negative_prompt_embeds = torch.zeros_like(prompt_embeds) |
|
negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) |
|
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) |
|
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) |
|
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) |
|
|
|
prompt_embeds = prompt_embeds.to(device) |
|
add_text_embeds = add_text_embeds.to(device) |
|
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) |
|
|
|
|
|
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) |
|
timestep_cond = None |
|
self._num_timesteps = len(timesteps) |
|
|
|
if guidance_curve is None: |
|
guidance_curve = lambda t: guidance_scale |
|
|
|
with self.progress_bar(total=num_inference_steps) as progress_bar: |
|
for i, t in enumerate(timesteps): |
|
if self.interrupt: |
|
continue |
|
|
|
|
|
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents |
|
|
|
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) |
|
|
|
|
|
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} |
|
|
|
noise_pred = self.unet( |
|
latent_model_input, |
|
t, |
|
encoder_hidden_states=prompt_embeds, |
|
timestep_cond=timestep_cond, |
|
cross_attention_kwargs=dict(cond_lat=cond_lat), |
|
added_cond_kwargs=added_cond_kwargs, |
|
return_dict=False, |
|
)[0] |
|
|
|
|
|
|
|
|
|
cur_guidance_scale = guidance_curve(t) |
|
|
|
if self.do_classifier_free_guidance: |
|
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) |
|
noise_pred = noise_pred_uncond + cur_guidance_scale * (noise_pred_text - noise_pred_uncond) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
latents_dtype = latents.dtype |
|
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] |
|
|
|
|
|
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): |
|
progress_bar.update() |
|
|
|
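        # Map the denoised latents back through the inverse of scale_latents() before VAE decoding.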
latents = unscale_latents(latents) |
|
|
|
        if output_type == "latent":
            image = latents
        else:
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
            image = unscale_image(unscale_image_2(image)).clamp(0, 1)
            image = [
                # The denoised multi-view grid as a PIL image, plus the preprocessed reference image.
                Image.fromarray((image[0] * 255 + 0.5).clamp_(0, 255).permute(1, 2, 0).cpu().numpy().astype("uint8")),
                cond_image.resize((512, 512)),
            ]
|
|
|
if not return_dict: return (image,) |
|
return ImagePipelineOutput(images=image) |
|
|
|
def save_pretrained(self, save_directory): |
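        # Persist the cached unconditional text embeddings alongside the standard pipeline components.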
|
|
|
super().save_pretrained(save_directory) |
|
torch.save(self.uc_text_emb, os.path.join(save_directory, "uc_text_emb.pt")) |
|
torch.save(self.uc_text_emb_2, os.path.join(save_directory, "uc_text_emb_2.pt")) |
|
|
|
@classmethod |
|
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): |
|
|
|
pipeline = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) |
|
        # Load the cached unconditional text embeddings saved next to the pipeline weights
        # (expects a local directory); they are moved to the execution device/dtype at call time.
        pipeline.uc_text_emb = torch.load(os.path.join(pretrained_model_name_or_path, "uc_text_emb.pt"), map_location="cpu")
        pipeline.uc_text_emb_2 = torch.load(os.path.join(pretrained_model_name_or_path, "uc_text_emb_2.pt"), map_location="cpu")
|
return pipeline |
|
|