Spatial_Control_for_SD / modules /model_k_diffusion.py
duongve's picture
Upload 24 files
7ef93e7 verified
import importlib
import inspect
import math
from pathlib import Path
import re
from collections import defaultdict
from typing import List, Optional, Union
import cv2
import time
import k_diffusion
import numpy as np
import PIL
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from .external_k_diffusion import CompVisDenoiser, CompVisVDenoiser
#from .prompt_parser import FrozenCLIPEmbedderWithCustomWords
from torch import einsum
from torch.autograd.function import Function
from diffusers.utils import PIL_INTERPOLATION, is_accelerate_available
from diffusers.utils import logging
from diffusers.utils.torch_utils import randn_tensor,is_compiled_module
from diffusers.image_processor import VaeImageProcessor,PipelineImageInput
from safetensors.torch import load_file
from diffusers import ControlNetModel
from PIL import Image
import torchvision.transforms as transforms
from diffusers.models import AutoencoderKL, ImageProjection
from .ip_adapter import IPAdapterMixin
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
import gc
from .t2i_adapter import preprocessing_t2i_adapter,default_height_width
from .encoder_prompt_modify import encode_prompt_function
from .encode_region_map_function import encode_region_map
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from diffusers.loaders import LoraLoaderMixin
from diffusers.loaders import TextualInversionLoaderMixin
def get_image_size(image):
height, width = None, None
if isinstance(image, Image.Image):
return image.size
elif isinstance(image, np.ndarray):
height, width = image.shape[:2]
return (width, height)
elif torch.is_tensor(image):
#RGB image
if len(image.shape) == 3:
_, height, width = image.shape
else:
height, width = image.shape
return (width, height)
else:
raise TypeError("The image must be an instance of PIL.Image, numpy.ndarray, or torch.Tensor.")
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
# from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
"""
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
# rescale the results from guidance (fixes overexposure)
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
return noise_cfg
class ModelWrapper:
def __init__(self, model, alphas_cumprod):
self.model = model
self.alphas_cumprod = alphas_cumprod
def apply_model(self, *args, **kwargs):
if len(args) == 3:
encoder_hidden_states = args[-1]
args = args[:2]
if kwargs.get("cond", None) is not None:
encoder_hidden_states = kwargs.pop("cond")
return self.model(
*args, encoder_hidden_states=encoder_hidden_states, **kwargs
).sample
class StableDiffusionPipeline(IPAdapterMixin,DiffusionPipeline,StableDiffusionMixin,LoraLoaderMixin,TextualInversionLoaderMixin):
_optional_components = ["safety_checker", "feature_extractor"]
def __init__(
self,
vae,
text_encoder,
tokenizer,
unet,
scheduler,
feature_extractor,
image_encoder = None,
):
super().__init__()
# get correct sigmas from LMS
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
self.controlnet = None
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
)
self.setup_unet(self.unet)
#self.setup_text_encoder()
'''def setup_text_encoder(self, n=1, new_encoder=None):
if new_encoder is not None:
self.text_encoder = new_encoder
self.prompt_parser = FrozenCLIPEmbedderWithCustomWords(self.tokenizer, self.text_encoder,n)'''
#self.prompt_parser.CLIP_stop_at_last_layers = n
def setup_unet(self, unet):
unet = unet.to(self.device)
model = ModelWrapper(unet, self.scheduler.alphas_cumprod)
if self.scheduler.config.prediction_type == "v_prediction":
self.k_diffusion_model = CompVisVDenoiser(model)
else:
self.k_diffusion_model = CompVisDenoiser(model)
def get_scheduler(self, scheduler_type: str):
library = importlib.import_module("k_diffusion")
sampling = getattr(library, "sampling")
return getattr(sampling, scheduler_type)
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = self.feature_extractor(image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=dtype)
if output_hidden_states:
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_enc_hidden_states = self.image_encoder(
torch.zeros_like(image), output_hidden_states=True
).hidden_states[-2]
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
num_images_per_prompt, dim=0
)
return image_enc_hidden_states, uncond_image_enc_hidden_states
else:
image_embeds = self.image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_embeds = torch.zeros_like(image_embeds)
return image_embeds, uncond_image_embeds
def prepare_ip_adapter_image_embeds(
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
if ip_adapter_image_embeds is None:
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
raise ValueError(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)
image_embeds = []
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
):
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
single_image_embeds, single_negative_image_embeds = self.encode_image(
single_ip_adapter_image, device, 1, output_hidden_state
)
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
single_negative_image_embeds = torch.stack(
[single_negative_image_embeds] * num_images_per_prompt, dim=0
)
if do_classifier_free_guidance:
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
single_image_embeds = single_image_embeds.to(device)
image_embeds.append(single_image_embeds)
else:
repeat_dims = [1]
image_embeds = []
for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else:
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
image_embeds.append(single_image_embeds)
return image_embeds
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.
Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
self.unet.set_attention_slice(slice_size)
def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
`torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
"""
if is_accelerate_available():
from accelerate import cpu_offload
else:
raise ImportError("Please install accelerate via `pip install accelerate`")
device = torch.device(f"cuda:{gpu_id}")
for cpu_offloaded_model in [
self.unet,
self.text_encoder,
self.vae,
self.safety_checker,
]:
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)
@property
def _execution_device(self):
r"""
Returns the device on which the pipeline's models will be executed. After calling
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
hooks.
"""
if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
return self.device
for module in self.unet.modules():
if (
hasattr(module, "_hf_hook")
and hasattr(module._hf_hook, "execution_device")
and module._hf_hook.execution_device is not None
):
return torch.device(module._hf_hook.execution_device)
return self.device
def decode_latents(self, latents):
latents = latents.to(self.device, dtype=self.vae.dtype)
#latents = 1 / 0.18215 * latents
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image
def _default_height_width(self, height, width, image):
if isinstance(image, list):
image = image[0]
if height is None:
if isinstance(image, PIL.Image.Image):
height = image.height
elif isinstance(image, torch.Tensor):
height = image.shape[3]
height = (height // 8) * 8 # round down to nearest multiple of 8
if width is None:
if isinstance(image, PIL.Image.Image):
width = image.width
elif isinstance(image, torch.Tensor):
width = image.shape[2]
width = (width // 8) * 8 # round down to nearest multiple of 8
return height, width
def check_inputs(self, prompt, height, width, callback_steps):
if not isinstance(prompt, str) and not isinstance(prompt, list):
raise ValueError(
f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
)
if height % 8 != 0 or width % 8 != 0:
raise ValueError(
f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
)
if (callback_steps is None) or (
callback_steps is not None
and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
@property
def do_classifier_free_guidance(self):
return self._do_classifier_free_guidance and self.unet.config.time_cond_proj_dim is None
def setup_controlnet(self,controlnet):
if isinstance(controlnet, (list, tuple)):
controlnet = MultiControlNetModel(controlnet)
self.register_modules(
controlnet=controlnet,
)
def preprocess_controlnet(self,controlnet_conditioning_scale,control_guidance_start,control_guidance_end,image,width,height,num_inference_steps,batch_size,num_images_per_prompt):
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
# align format for control guidance
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
control_guidance_start, control_guidance_end = (
mult * [control_guidance_start],
mult * [control_guidance_end],
)
if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
global_pool_conditions = (
controlnet.config.global_pool_conditions
if isinstance(controlnet, ControlNetModel)
else controlnet.nets[0].config.global_pool_conditions
)
guess_mode = False or global_pool_conditions
# 4. Prepare image
if isinstance(controlnet, ControlNetModel):
image = self.prepare_image(
image=image,
width=width,
height=height,
batch_size=batch_size,
num_images_per_prompt=num_images_per_prompt,
device=self._execution_device,
dtype=controlnet.dtype,
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode,
)
height, width = image.shape[-2:]
elif isinstance(controlnet, MultiControlNetModel):
images = []
for image_ in image:
image_ = self.prepare_image(
image=image_,
width=width,
height=height,
batch_size=batch_size,
num_images_per_prompt=num_images_per_prompt,
device=self._execution_device,
dtype=controlnet.dtype,
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode,
)
images.append(image_)
image = images
height, width = image[0].shape[-2:]
else:
assert False
# 7.2 Create tensor stating which controlnets to keep
controlnet_keep = []
for i in range(num_inference_steps):
keeps = [
1.0 - float(i / num_inference_steps < s or (i + 1) / num_inference_steps > e)
for s, e in zip(control_guidance_start, control_guidance_end)
]
controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
return image,controlnet_keep,guess_mode,controlnet_conditioning_scale
def prepare_latents(
self,
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
):
shape = (batch_size, num_channels_latents, int(height) // self.vae_scale_factor, int(width) // self.vae_scale_factor)
if latents is None:
if device.type == "mps":
# randn does not work reproducibly on mps
latents = torch.randn(
shape, generator=generator, device="cpu", dtype=dtype
).to(device)
else:
latents = torch.randn(
shape, generator=generator, device=device, dtype=dtype
)
else:
# if latents.shape != shape:
# raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
return latents
def preprocess(self, image):
if isinstance(image, torch.Tensor):
return image
elif isinstance(image, PIL.Image.Image):
image = [image]
if isinstance(image[0], PIL.Image.Image):
w, h = image[0].size
w, h = map(lambda x: x - x % 8, (w, h)) # resize to integer multiple of 8
image = [
np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[
None, :
]
for i in image
]
image = np.concatenate(image, axis=0)
image = np.array(image).astype(np.float32) / 255.0
image = image.transpose(0, 3, 1, 2)
image = 2.0 * image - 1.0
image = torch.from_numpy(image)
elif isinstance(image[0], torch.Tensor):
image = torch.cat(image, dim=0)
return image
def prepare_image(
self,
image,
width,
height,
batch_size,
num_images_per_prompt,
device,
dtype,
do_classifier_free_guidance=False,
guess_mode=False,
):
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
)
image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
image_batch_size = image.shape[0]
if image_batch_size == 1:
repeat_by = batch_size
else:
# image batch size is the same as prompt batch size
repeat_by = num_images_per_prompt
#image = image.repeat_interleave(repeat_by, dim=0)
image = image.to(device=device, dtype=dtype)
if do_classifier_free_guidance and not guess_mode:
image = torch.cat([image] * 2)
return image
def numpy_to_pil(self,images):
r"""
Convert a numpy image or a batch of images to a PIL image.
"""
if images.ndim == 3:
images = images[None, ...]
#images = (images * 255).round().astype("uint8")
images = np.clip((images * 255).round(), 0, 255).astype("uint8")
if images.shape[-1] == 1:
# special case for grayscale (single channel) images
pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
else:
pil_images = [Image.fromarray(image) for image in images]
return pil_images
def latent_to_image(self,latent,output_type):
image = self.decode_latents(latent)
if output_type == "pil":
image = self.numpy_to_pil(image)
if len(image) > 1:
return image
return image[0]
@torch.no_grad()
def img2img(
self,
prompt: Union[str, List[str]],
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
generator: Optional[torch.Generator] = None,
image: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
latents=None,
strength=1.0,
region_map_state=None,
sampler_name="",
sampler_opt={},
start_time=-1,
timeout=180,
scale_ratio=8.0,
latent_processing = 0,
weight_func = lambda w, sigma, qk: w * sigma * qk.std(),
upscale=False,
upscale_x: float = 2.0,
upscale_method: str = "bicubic",
upscale_antialias: bool = False,
upscale_denoising_strength: int = 0.7,
width = None,
height = None,
seed = 0,
sampler_name_hires="",
sampler_opt_hires= {},
latent_upscale_processing = False,
ip_adapter_image = None,
control_img = None,
controlnet_conditioning_scale = None,
control_guidance_start = None,
control_guidance_end = None,
image_t2i_adapter : Optional[PipelineImageInput] = None,
adapter_conditioning_scale: Union[float, List[float]] = 1.0,
adapter_conditioning_factor: float = 1.0,
guidance_rescale: float = 0.0,
cross_attention_kwargs = None,
clip_skip = None,
long_encode = 0,
num_images_per_prompt = 1,
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
):
if isinstance(sampler_name, str):
sampler = self.get_scheduler(sampler_name)
else:
sampler = sampler_name
if height is None:
_,height = get_image_size(image)
height = int((height // 8)*8)
if width is None:
width,_ = get_image_size(image)
width = int((width // 8)*8)
if image_t2i_adapter is not None:
height, width = default_height_width(self,height, width, image_t2i_adapter)
if image is not None:
image = self.preprocess(image)
image = image.to(self.vae.device, dtype=self.vae.dtype)
init_latents = self.vae.encode(image).latent_dist.sample(generator)
latents = 0.18215 * init_latents
# 2. Define call parameters
batch_size = 1 if isinstance(prompt, str) else len(prompt)
device = self._execution_device
latents = latents.to(device, dtype=self.unet.dtype)
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
lora_scale = (
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
)
self._do_classifier_free_guidance = False if guidance_scale <= 1.0 else True
'''if guidance_scale <= 1.0:
raise ValueError("has to use guidance_scale")'''
# 3. Encode input prompt
text_embeddings, negative_prompt_embeds, text_input_ids = encode_prompt_function(
self,
prompt,
device,
num_images_per_prompt,
self.do_classifier_free_guidance,
negative_prompt,
lora_scale = lora_scale,
clip_skip = clip_skip,
long_encode = long_encode,
)
if self.do_classifier_free_guidance:
text_embeddings = torch.cat([negative_prompt_embeds, text_embeddings])
#text_ids, text_embeddings = self.prompt_parser([negative_prompt, prompt])
text_embeddings = text_embeddings.to(self.unet.dtype)
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
t_start = max(num_inference_steps - init_timestep, 0)
sigmas = self.get_sigmas(num_inference_steps, sampler_opt).to(
text_embeddings.device, dtype=text_embeddings.dtype
)
sigma_sched = sigmas[t_start:]
noise = randn_tensor(
latents.shape,
generator=generator,
device=device,
dtype=text_embeddings.dtype,
)
latents = latents.to(device)
latents = latents + noise * (sigma_sched[0]**2 + 1) ** 0.5
#latents = latents + noise * sigma_sched[0] #Nearly
steps_denoising = len(sigma_sched)
# 5. Prepare latent variables
self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device)
self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(
latents.device
)
region_state = encode_region_map(
self,
region_map_state,
width = width,
height = height,
num_images_per_prompt = num_images_per_prompt,
text_ids=text_input_ids,
)
if cross_attention_kwargs is None:
cross_attention_kwargs ={}
controlnet_conditioning_scale_copy = controlnet_conditioning_scale.copy() if isinstance(controlnet_conditioning_scale, list) else controlnet_conditioning_scale
control_guidance_start_copy = control_guidance_start.copy() if isinstance(control_guidance_start, list) else control_guidance_start
control_guidance_end_copy = control_guidance_end.copy() if isinstance(control_guidance_end, list) else control_guidance_end
guess_mode = False
if self.controlnet is not None:
img_control,controlnet_keep,guess_mode,controlnet_conditioning_scale = self.preprocess_controlnet(controlnet_conditioning_scale,control_guidance_start,control_guidance_end,control_img,width,height,len(sigma_sched),batch_size,num_images_per_prompt)
#print(len(controlnet_keep))
#controlnet_conditioning_scale_copy = controlnet_conditioning_scale.copy()
#sp_control = 1
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
image_embeds = self.prepare_ip_adapter_image_embeds(
ip_adapter_image,
ip_adapter_image_embeds,
device,
batch_size * num_images_per_prompt,
self.do_classifier_free_guidance,
)
# 6.1 Add image embeds for IP-Adapter
added_cond_kwargs = (
{"image_embeds": image_embeds}
if (ip_adapter_image is not None or ip_adapter_image_embeds is not None)
else None
)
#if controlnet_img is not None:
#controlnet_img_processing = controlnet_img.convert("RGB")
#transform = transforms.Compose([transforms.PILToTensor()])
#controlnet_img_processing = transform(controlnet_img)
#controlnet_img_processing=controlnet_img_processing.to(device=device, dtype=self.cnet.dtype)
#controlnet_img = torch.from_numpy(controlnet_img).half()
#controlnet_img = controlnet_img.unsqueeze(0)
#controlnet_img = controlnet_img.repeat_interleave(3, dim=0)
#controlnet_img=controlnet_img.to(device)
#controlnet_img = controlnet_img.repeat_interleave(4 // len(controlnet_img), 0)
if latent_processing == 1:
latents_process = [self.latent_to_image(latents,output_type)]
lst_latent_sigma = []
step_control = -1
adapter_state = None
adapter_sp_count = []
if image_t2i_adapter is not None:
adapter_state = preprocessing_t2i_adapter(self,image_t2i_adapter,width,height,adapter_conditioning_scale,1)
def model_fn(x, sigma):
nonlocal step_control,lst_latent_sigma,adapter_sp_count
if start_time > 0 and timeout > 0:
assert (time.time() - start_time) < timeout, "inference process timed out"
latent_model_input = torch.cat([x] * 2) if self.do_classifier_free_guidance else x
region_prompt = {
"region_state": region_state,
"sigma": sigma[0],
"weight_func": weight_func,
}
cross_attention_kwargs["region_prompt"] = region_prompt
#print(self.k_diffusion_model.sigma_to_t(sigma[0]))
if latent_model_input.dtype != text_embeddings.dtype:
latent_model_input = latent_model_input.to(text_embeddings.dtype)
ukwargs = {}
down_intrablock_additional_residuals = None
if adapter_state is not None:
if len(adapter_sp_count) < int( steps_denoising* adapter_conditioning_factor):
down_intrablock_additional_residuals = [state.clone() for state in adapter_state]
else:
down_intrablock_additional_residuals = None
sigma_string_t2i = str(sigma.item())
if sigma_string_t2i not in adapter_sp_count:
adapter_sp_count.append(sigma_string_t2i)
if self.controlnet is not None :
sigma_string = str(sigma.item())
if sigma_string not in lst_latent_sigma:
#sigmas_sp = sigma.detach().clone()
step_control+=1
lst_latent_sigma.append(sigma_string)
if isinstance(controlnet_keep[step_control], list):
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[step_control])]
else:
controlnet_cond_scale = controlnet_conditioning_scale
if isinstance(controlnet_cond_scale, list):
controlnet_cond_scale = controlnet_cond_scale[0]
cond_scale = controlnet_cond_scale * controlnet_keep[step_control]
down_block_res_samples = None
mid_block_res_sample = None
down_block_res_samples, mid_block_res_sample = self.controlnet(
latent_model_input / ((sigma**2 + 1) ** 0.5),
self.k_diffusion_model.sigma_to_t(sigma),
encoder_hidden_states=text_embeddings,
controlnet_cond=img_control,
conditioning_scale=cond_scale,
guess_mode=guess_mode,
return_dict=False,
)
if guess_mode and self.do_classifier_free_guidance:
# Infered ControlNet only for the conditional batch.
# To apply the output of ControlNet to both the unconditional and conditional batches,
# add 0 to the unconditional batch to keep it unchanged.
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
ukwargs ={
"down_block_additional_residuals": down_block_res_samples,
"mid_block_additional_residual":mid_block_res_sample,
}
noise_pred = self.k_diffusion_model(
latent_model_input, sigma, cond=text_embeddings,cross_attention_kwargs = cross_attention_kwargs,down_intrablock_additional_residuals = down_intrablock_additional_residuals,added_cond_kwargs=added_cond_kwargs, **ukwargs
)
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
if guidance_rescale > 0.0:
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
if latent_processing == 1:
latents_process.append(self.latent_to_image(noise_pred,output_type))
# noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=0.7)
return noise_pred
sampler_args = self.get_sampler_extra_args_i2i(sigma_sched,len(sigma_sched),sampler_opt,latents,seed, sampler)
latents = sampler(model_fn, latents, **sampler_args)
self.maybe_free_model_hooks()
torch.cuda.empty_cache()
gc.collect()
if upscale:
vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
target_height = int(height * upscale_x // vae_scale_factor )* 8
target_width = int(width * upscale_x // vae_scale_factor)*8
latents = torch.nn.functional.interpolate(
latents,
size=(
int(target_height // vae_scale_factor),
int(target_width // vae_scale_factor),
),
mode=upscale_method,
antialias=upscale_antialias,
)
#if controlnet_img is not None:
#controlnet_img = cv2.resize(controlnet_img,(latents.size(0), latents.size(1)))
#controlnet_img=controlnet_img.resize((latents.size(0), latents.size(1)), Image.LANCZOS)
#region_map_state = apply_size_sketch(int(target_width),int(target_height),region_map_state)
latent_reisze= self.img2img(
prompt=prompt,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
negative_prompt=negative_prompt,
generator=generator,
latents=latents,
strength=upscale_denoising_strength,
sampler_name=sampler_name_hires,
sampler_opt=sampler_opt_hires,
region_map_state=region_map_state,
latent_processing = latent_upscale_processing,
width = int(target_width),
height = int(target_height),
seed = seed,
ip_adapter_image = ip_adapter_image,
control_img = control_img,
controlnet_conditioning_scale = controlnet_conditioning_scale_copy,
control_guidance_start = control_guidance_start_copy,
control_guidance_end = control_guidance_end_copy,
image_t2i_adapter= image_t2i_adapter,
adapter_conditioning_scale = adapter_conditioning_scale,
adapter_conditioning_factor = adapter_conditioning_factor,
guidance_rescale = guidance_rescale,
cross_attention_kwargs = cross_attention_kwargs,
clip_skip = clip_skip,
long_encode = long_encode,
num_images_per_prompt = num_images_per_prompt,
)
'''if latent_processing == 1:
latents = latents_process.copy()
images = []
for i in latents:
images.append(self.decode_latents(i))
image = []
if output_type == "pil":
for i in images:
image.append(self.numpy_to_pil(i))
image[-1] = latent_reisze
return image'''
if latent_processing == 1:
latents_process= latents_process+latent_reisze
return latents_process
torch.cuda.empty_cache()
gc.collect()
return latent_reisze
'''if latent_processing == 1:
latents = latents_process.copy()
images = []
for i in latents:
images.append(self.decode_latents(i))
image = []
# 10. Convert to PIL
if output_type == "pil":
for i in images:
image.append(self.numpy_to_pil(i))
else:
image = self.decode_latents(latents)
# 10. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image)'''
if latent_processing == 1:
return latents_process
self.maybe_free_model_hooks()
torch.cuda.empty_cache()
gc.collect()
return [self.latent_to_image(latents,output_type)]
def get_sigmas(self, steps, params):
discard_next_to_last_sigma = params.get("discard_next_to_last_sigma", False)
steps += 1 if discard_next_to_last_sigma else 0
if params.get("scheduler", None) == "karras":
sigma_min, sigma_max = (
self.k_diffusion_model.sigmas[0].item(),
self.k_diffusion_model.sigmas[-1].item(),
)
sigmas = k_diffusion.sampling.get_sigmas_karras(
n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=self.device
)
elif params.get("scheduler", None) == "exponential":
sigma_min, sigma_max = (
self.k_diffusion_model.sigmas[0].item(),
self.k_diffusion_model.sigmas[-1].item(),
)
sigmas = k_diffusion.sampling.get_sigmas_exponential(
n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=self.device
)
elif params.get("scheduler", None) == "polyexponential":
sigma_min, sigma_max = (
self.k_diffusion_model.sigmas[0].item(),
self.k_diffusion_model.sigmas[-1].item(),
)
sigmas = k_diffusion.sampling.get_sigmas_polyexponential(
n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=self.device
)
else:
sigmas = self.k_diffusion_model.get_sigmas(steps)
if discard_next_to_last_sigma:
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
return sigmas
def create_noise_sampler(self, x, sigmas, p,seed):
"""For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes"""
from k_diffusion.sampling import BrownianTreeNoiseSampler
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
#current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]
return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed)
# https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/48a15821de768fea76e66f26df83df3fddf18f4b/modules/sd_samplers.py#L454
def get_sampler_extra_args_t2i(self, sigmas, eta, steps,sampler_opt,latents,seed, func):
extra_params_kwargs = {}
if "eta" in inspect.signature(func).parameters:
extra_params_kwargs["eta"] = eta
if "sigma_min" in inspect.signature(func).parameters:
extra_params_kwargs["sigma_min"] = sigmas[0].item()
extra_params_kwargs["sigma_max"] = sigmas[-1].item()
if "n" in inspect.signature(func).parameters:
extra_params_kwargs["n"] = steps
else:
extra_params_kwargs["sigmas"] = sigmas
if sampler_opt.get('brownian_noise', False):
noise_sampler = self.create_noise_sampler(latents, sigmas, steps,seed)
extra_params_kwargs['noise_sampler'] = noise_sampler
if sampler_opt.get('solver_type', None) == 'heun':
extra_params_kwargs['solver_type'] = 'heun'
return extra_params_kwargs
# https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/48a15821de768fea76e66f26df83df3fddf18f4b/modules/sd_samplers.py#L454
def get_sampler_extra_args_i2i(self, sigmas,steps,sampler_opt,latents,seed, func):
extra_params_kwargs = {}
if "sigma_min" in inspect.signature(func).parameters:
## last sigma is zero which isn't allowed by DPM Fast & Adaptive so taking value before last
extra_params_kwargs["sigma_min"] = sigmas[-2]
if "sigma_max" in inspect.signature(func).parameters:
extra_params_kwargs["sigma_max"] = sigmas[0]
if "n" in inspect.signature(func).parameters:
extra_params_kwargs["n"] = len(sigmas) - 1
if "sigma_sched" in inspect.signature(func).parameters:
extra_params_kwargs["sigma_sched"] = sigmas
if "sigmas" in inspect.signature(func).parameters:
extra_params_kwargs["sigmas"] = sigmas
if sampler_opt.get('brownian_noise', False):
noise_sampler = self.create_noise_sampler(latents, sigmas, steps,seed)
extra_params_kwargs['noise_sampler'] = noise_sampler
if sampler_opt.get('solver_type', None) == 'heun':
extra_params_kwargs['solver_type'] = 'heun'
return extra_params_kwargs
@torch.no_grad()
def txt2img(
self,
prompt: Union[str, List[str]],
height: int = 512,
width: int = 512,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
eta: float = 0.0,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
callback_steps: Optional[int] = 1,
upscale=False,
upscale_x: float = 2.0,
upscale_method: str = "bicubic",
upscale_antialias: bool = False,
upscale_denoising_strength: int = 0.7,
region_map_state=None,
sampler_name="",
sampler_opt={},
start_time=-1,
timeout=180,
latent_processing = 0,
weight_func = lambda w, sigma, qk: w * sigma * qk.std(),
seed = 0,
sampler_name_hires= "",
sampler_opt_hires= {},
latent_upscale_processing = False,
ip_adapter_image = None,
control_img = None,
controlnet_conditioning_scale = None,
control_guidance_start = None,
control_guidance_end = None,
image_t2i_adapter : Optional[PipelineImageInput] = None,
adapter_conditioning_scale: Union[float, List[float]] = 1.0,
adapter_conditioning_factor: float = 1.0,
guidance_rescale: float = 0.0,
cross_attention_kwargs = None,
clip_skip = None,
long_encode = 0,
num_images_per_prompt = 1,
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
):
height, width = self._default_height_width(height, width, None)
if isinstance(sampler_name, str):
sampler = self.get_scheduler(sampler_name)
else:
sampler = sampler_name
# 1. Check inputs. Raise error if not correct
if image_t2i_adapter is not None:
height, width = default_height_width(self,height, width, image_t2i_adapter)
#print(default_height_width(self,height, width, image_t2i_adapter))
self.check_inputs(prompt, height, width, callback_steps)
# 2. Define call parameters
batch_size = 1 if isinstance(prompt, str) else len(prompt)
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
'''do_classifier_free_guidance = True
if guidance_scale <= 1.0:
raise ValueError("has to use guidance_scale")'''
lora_scale = (
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
)
self._do_classifier_free_guidance = False if guidance_scale <= 1.0 else True
'''if guidance_scale <= 1.0:
raise ValueError("has to use guidance_scale")'''
# 3. Encode input prompt
text_embeddings, negative_prompt_embeds, text_input_ids = encode_prompt_function(
self,
prompt,
device,
num_images_per_prompt,
self.do_classifier_free_guidance,
negative_prompt,
lora_scale = lora_scale,
clip_skip = clip_skip,
long_encode = long_encode,
)
if self.do_classifier_free_guidance:
text_embeddings = torch.cat([negative_prompt_embeds, text_embeddings])
# 3. Encode input prompt
#text_ids, text_embeddings = self.prompt_parser([negative_prompt, prompt])
text_embeddings = text_embeddings.to(self.unet.dtype)
# 4. Prepare timesteps
sigmas = self.get_sigmas(num_inference_steps, sampler_opt).to(
text_embeddings.device, dtype=text_embeddings.dtype
)
# 5. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
text_embeddings.dtype,
device,
generator,
latents,
)
latents = latents * (sigmas[0]**2 + 1) ** 0.5
#latents = latents * sigmas[0]#Nearly
steps_denoising = len(sigmas)
self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device)
self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(
latents.device
)
region_state = encode_region_map(
self,
region_map_state,
width = width,
height = height,
num_images_per_prompt = num_images_per_prompt,
text_ids=text_input_ids,
)
if cross_attention_kwargs is None:
cross_attention_kwargs ={}
controlnet_conditioning_scale_copy = controlnet_conditioning_scale.copy() if isinstance(controlnet_conditioning_scale, list) else controlnet_conditioning_scale
control_guidance_start_copy = control_guidance_start.copy() if isinstance(control_guidance_start, list) else control_guidance_start
control_guidance_end_copy = control_guidance_end.copy() if isinstance(control_guidance_end, list) else control_guidance_end
guess_mode = False
if self.controlnet is not None:
img_control,controlnet_keep,guess_mode,controlnet_conditioning_scale = self.preprocess_controlnet(controlnet_conditioning_scale,control_guidance_start,control_guidance_end,control_img,width,height,num_inference_steps,batch_size,num_images_per_prompt)
#print(len(controlnet_keep))
#controlnet_conditioning_scale_copy = controlnet_conditioning_scale.copy()
#sp_control = 1
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
image_embeds = self.prepare_ip_adapter_image_embeds(
ip_adapter_image,
ip_adapter_image_embeds,
device,
batch_size * num_images_per_prompt,
self.do_classifier_free_guidance,
)
# 6.1 Add image embeds for IP-Adapter
added_cond_kwargs = (
{"image_embeds": image_embeds}
if (ip_adapter_image is not None or ip_adapter_image_embeds is not None)
else None
)
#if controlnet_img is not None:
#controlnet_img_processing = controlnet_img.convert("RGB")
#transform = transforms.Compose([transforms.PILToTensor()])
#controlnet_img_processing = transform(controlnet_img)
#controlnet_img_processing=controlnet_img_processing.to(device=device, dtype=self.cnet.dtype)
if latent_processing == 1:
latents_process = [self.latent_to_image(latents,output_type)]
#sp_find_new = None
lst_latent_sigma = []
step_control = -1
adapter_state = None
adapter_sp_count = []
if image_t2i_adapter is not None:
adapter_state = preprocessing_t2i_adapter(self,image_t2i_adapter,width,height,adapter_conditioning_scale,1)
def model_fn(x, sigma):
nonlocal step_control,lst_latent_sigma,adapter_sp_count
if start_time > 0 and timeout > 0:
assert (time.time() - start_time) < timeout, "inference process timed out"
latent_model_input = torch.cat([x] * 2) if self.do_classifier_free_guidance else x
region_prompt = {
"region_state": region_state,
"sigma": sigma[0],
"weight_func": weight_func,
}
cross_attention_kwargs["region_prompt"] = region_prompt
#print(self.k_diffusion_model.sigma_to_t(sigma[0]))
if latent_model_input.dtype != text_embeddings.dtype:
latent_model_input = latent_model_input.to(text_embeddings.dtype)
ukwargs = {}
down_intrablock_additional_residuals = None
if adapter_state is not None:
if len(adapter_sp_count) < int( steps_denoising* adapter_conditioning_factor):
down_intrablock_additional_residuals = [state.clone() for state in adapter_state]
else:
down_intrablock_additional_residuals = None
sigma_string_t2i = str(sigma.item())
if sigma_string_t2i not in adapter_sp_count:
adapter_sp_count.append(sigma_string_t2i)
if self.controlnet is not None :
sigma_string = str(sigma.item())
if sigma_string not in lst_latent_sigma:
#sigmas_sp = sigma.detach().clone()
step_control+=1
lst_latent_sigma.append(sigma_string)
if isinstance(controlnet_keep[step_control], list):
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[step_control])]
else:
controlnet_cond_scale = controlnet_conditioning_scale
if isinstance(controlnet_cond_scale, list):
controlnet_cond_scale = controlnet_cond_scale[0]
cond_scale = controlnet_cond_scale * controlnet_keep[step_control]
down_block_res_samples = None
mid_block_res_sample = None
down_block_res_samples, mid_block_res_sample = self.controlnet(
latent_model_input / ((sigma**2 + 1) ** 0.5),
self.k_diffusion_model.sigma_to_t(sigma),
encoder_hidden_states=text_embeddings,
controlnet_cond=img_control,
conditioning_scale=cond_scale,
guess_mode=guess_mode,
return_dict=False,
)
if guess_mode and self.do_classifier_free_guidance:
# Infered ControlNet only for the conditional batch.
# To apply the output of ControlNet to both the unconditional and conditional batches,
# add 0 to the unconditional batch to keep it unchanged.
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
ukwargs ={
"down_block_additional_residuals": down_block_res_samples,
"mid_block_additional_residual":mid_block_res_sample,
}
noise_pred = self.k_diffusion_model(
latent_model_input, sigma, cond=text_embeddings,cross_attention_kwargs=cross_attention_kwargs,down_intrablock_additional_residuals=down_intrablock_additional_residuals,added_cond_kwargs=added_cond_kwargs, **ukwargs
)
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
if guidance_rescale > 0.0:
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
if latent_processing == 1:
latents_process.append(self.latent_to_image(noise_pred,output_type))
# noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=0.7)
return noise_pred
extra_args = self.get_sampler_extra_args_t2i(
sigmas, eta, num_inference_steps,sampler_opt,latents,seed, sampler
)
latents = sampler(model_fn, latents, **extra_args)
#latents = latents_process[0]
#print(len(latents_process))
self.maybe_free_model_hooks()
torch.cuda.empty_cache()
gc.collect()
if upscale:
vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
target_height = int(height * upscale_x // vae_scale_factor )* 8
target_width = int(width * upscale_x // vae_scale_factor)*8
latents = torch.nn.functional.interpolate(
latents,
size=(
int(target_height // vae_scale_factor),
int(target_width // vae_scale_factor),
),
mode=upscale_method,
antialias=upscale_antialias,
)
#if controlnet_img is not None:
#controlnet_img = cv2.resize(controlnet_img,(latents.size(0), latents.size(1)))
#controlnet_img=controlnet_img.resize((latents.size(0), latents.size(1)), Image.LANCZOS)
latent_reisze= self.img2img(
prompt=prompt,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
negative_prompt=negative_prompt,
generator=generator,
latents=latents,
strength=upscale_denoising_strength,
sampler_name=sampler_name_hires,
sampler_opt=sampler_opt_hires,
region_map_state = region_map_state,
latent_processing = latent_upscale_processing,
width = int(target_width),
height = int(target_height),
seed = seed,
ip_adapter_image = ip_adapter_image,
control_img = control_img,
controlnet_conditioning_scale = controlnet_conditioning_scale_copy,
control_guidance_start = control_guidance_start_copy,
control_guidance_end = control_guidance_end_copy,
image_t2i_adapter= image_t2i_adapter,
adapter_conditioning_scale = adapter_conditioning_scale,
adapter_conditioning_factor = adapter_conditioning_factor,
guidance_rescale = guidance_rescale,
cross_attention_kwargs = cross_attention_kwargs,
clip_skip = clip_skip,
long_encode = long_encode,
num_images_per_prompt = num_images_per_prompt,
)
'''if latent_processing == 1:
latents = latents_process.copy()
images = []
for i in latents:
images.append(self.decode_latents(i))
image = []
if output_type == "pil":
for i in images:
image.append(self.numpy_to_pil(i))
image[-1] = latent_reisze
return image'''
if latent_processing == 1:
latents_process= latents_process+latent_reisze
return latents_process
torch.cuda.empty_cache()
gc.collect()
return latent_reisze
# 8. Post-processing
'''if latent_processing == 1:
latents = latents_process.copy()
images = []
for i in latents:
images.append(self.decode_latents(i))
image = []
# 10. Convert to PIL
if output_type == "pil":
for i in images:
image.append(self.numpy_to_pil(i))
else:
image = self.decode_latents(latents)
# 10. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image)'''
if latent_processing == 1:
return latents_process
return [self.latent_to_image(latents,output_type)]
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
if isinstance(generator, list):
image_latents = [
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
image_latents = self.vae.config.scaling_factor * image_latents
return image_latents
def prepare_mask_latents(
self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
):
# resize the mask to latents shape as we concatenate the mask to the latents
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
# and half precision
mask = torch.nn.functional.interpolate(
mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
)
mask = mask.to(device=device, dtype=dtype)
masked_image = masked_image.to(device=device, dtype=dtype)
if masked_image.shape[1] == 4:
masked_image_latents = masked_image
else:
masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
if mask.shape[0] < batch_size:
if not batch_size % mask.shape[0] == 0:
raise ValueError(
"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
" of masks that you pass is divisible by the total requested batch size."
)
mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
if masked_image_latents.shape[0] < batch_size:
if not batch_size % masked_image_latents.shape[0] == 0:
raise ValueError(
"The passed images and the required batch size don't match. Images are supposed to be duplicated"
f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
" Make sure the number of images that you pass is divisible by the total requested batch size."
)
masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
masked_image_latents = (
torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
)
# aligning device to prevent device errors when concating it with the latent model input
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
return mask, masked_image_latents
'''def get_image_latents(self,batch_size,image,device,dtype,generator):
image = image.to(device=device, dtype=dtype)
if image.shape[1] == 4:
image_latents = image
else:
image_latents = self._encode_vae_image(image=image, generator=generator)
image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
return image_latents'''
def _sigma_to_alpha_sigma_t(self, sigma):
alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
sigma_t = sigma * alpha_t
return alpha_t, sigma_t
def add_noise(self,init_latents_proper,noise,sigma):
if isinstance(sigma, torch.Tensor) and sigma.numel() > 1:
sigma,_ = sigma.sort(descending=True)
sigma = sigma[0].item()
#alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
init_latents_proper = init_latents_proper + sigma * noise
return init_latents_proper
def prepare_latents_inpating(
self,
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
image=None,
sigma=None,
is_strength_max=True,
return_noise=False,
return_image_latents=False,
):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if (image is None or sigma is None) and not is_strength_max:
raise ValueError(
"Since strength < 1. initial latents are to be initialised as a combination of Image + Noise."
"However, either the image or the noise sigma has not been provided."
)
if return_image_latents or (latents is None and not is_strength_max):
image = image.to(device=device, dtype=dtype)
if image.shape[1] == 4:
image_latents = image
else:
image_latents = self._encode_vae_image(image=image, generator=generator)
image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
if latents is None:
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
# if strength is 1. then initialise the latents to noise, else initial to image + noise
latents = noise if is_strength_max else self.add_noise(image_latents, noise, sigma)
# if pure noise then scale the initial latents by the Scheduler's init sigma
latents = latents * (sigma.item()**2 + 1) ** 0.5 if is_strength_max else latents
#latents = latents * sigma.item() if is_strength_max else latents #Nearly
else:
noise = latents.to(device)
latents = noise * (sigma.item()**2 + 1) ** 0.5
#latents = noise * sigma.item() #Nearly
outputs = (latents,)
if return_noise:
outputs += (noise,)
if return_image_latents:
outputs += (image_latents,)
return outputs
@torch.no_grad()
def inpaiting(
self,
prompt: Union[str, List[str]],
height: int = 512,
width: int = 512,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
eta: float = 0.0,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
callback_steps: Optional[int] = 1,
upscale=False,
upscale_x: float = 2.0,
upscale_method: str = "bicubic",
upscale_antialias: bool = False,
upscale_denoising_strength: int = 0.7,
region_map_state=None,
sampler_name="",
sampler_opt={},
start_time=-1,
timeout=180,
latent_processing = 0,
weight_func = lambda w, sigma, qk: w * sigma * qk.std(),
seed = 0,
sampler_name_hires= "",
sampler_opt_hires= {},
latent_upscale_processing = False,
ip_adapter_image = None,
control_img = None,
controlnet_conditioning_scale = None,
control_guidance_start = None,
control_guidance_end = None,
image_t2i_adapter : Optional[PipelineImageInput] = None,
adapter_conditioning_scale: Union[float, List[float]] = 1.0,
adapter_conditioning_factor: float = 1.0,
guidance_rescale: float = 0.0,
cross_attention_kwargs = None,
clip_skip = None,
long_encode = 0,
num_images_per_prompt = 1,
image: Union[torch.Tensor, PIL.Image.Image] = None,
mask_image: Union[torch.Tensor, PIL.Image.Image] = None,
masked_image_latents: torch.Tensor = None,
padding_mask_crop: Optional[int] = None,
strength: float = 1.0,
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
):
height, width = self._default_height_width(height, width, None)
if isinstance(sampler_name, str):
sampler = self.get_scheduler(sampler_name)
else:
sampler = sampler_name
# 1. Check inputs. Raise error if not correct
if image_t2i_adapter is not None:
height, width = default_height_width(self,height, width, image_t2i_adapter)
#print(default_height_width(self,height, width, image_t2i_adapter))
self.check_inputs(prompt, height, width, callback_steps)
# 2. Define call parameters
batch_size = 1 if isinstance(prompt, str) else len(prompt)
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
'''do_classifier_free_guidance = True
if guidance_scale <= 1.0:
raise ValueError("has to use guidance_scale")'''
lora_scale = (
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
)
self._do_classifier_free_guidance = False if guidance_scale <= 1.0 else True
'''if guidance_scale <= 1.0:
raise ValueError("has to use guidance_scale")'''
# 3. Encode input prompt
text_embeddings, negative_prompt_embeds, text_input_ids = encode_prompt_function(
self,
prompt,
device,
num_images_per_prompt,
self.do_classifier_free_guidance,
negative_prompt,
lora_scale = lora_scale,
clip_skip = clip_skip,
long_encode = long_encode,
)
if self.do_classifier_free_guidance:
text_embeddings = torch.cat([negative_prompt_embeds, text_embeddings])
text_embeddings = text_embeddings.to(self.unet.dtype)
# 4. Prepare timesteps
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
t_start = max(num_inference_steps - init_timestep, 0)
sigmas = self.get_sigmas(num_inference_steps, sampler_opt).to(
text_embeddings.device, dtype=text_embeddings.dtype
)
sigmas = sigmas[t_start:] if strength >= 0 and strength < 1.0 else sigmas
is_strength_max = strength == 1.0
'''if latents is None:
noise_inpaiting = randn_tensor((batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8), generator=generator, device=device, dtype=text_embeddings.dtype)
else:
noise_inpaiting = latents.to(device)'''
# 5. Prepare mask, image,
if padding_mask_crop is not None:
crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
resize_mode = "fill"
else:
crops_coords = None
resize_mode = "default"
original_image = image
init_image = self.image_processor.preprocess(
image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
)
init_image = init_image.to(dtype=torch.float32)
# 6. Prepare latent variables
num_channels_latents = self.vae.config.latent_channels
num_channels_unet = self.unet.config.in_channels
return_image_latents = num_channels_unet == 4
image_latents = None
noise_inpaiting = None
'''latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_unet,
height,
width,
text_embeddings.dtype,
device,
generator,
latents,
)'''
#latents = latents * sigmas[0]
latents_outputs = self.prepare_latents_inpating(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
text_embeddings.dtype,
device,
generator,
latents,
image=init_image,
sigma=sigmas[0],
is_strength_max=is_strength_max,
return_noise=True,
return_image_latents=return_image_latents,
)
if return_image_latents:
latents, noise_inpaiting, image_latents = latents_outputs
else:
latents, noise_inpaiting = latents_outputs
# 7. Prepare mask latent variables
mask_condition = self.mask_processor.preprocess(
mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
)
if masked_image_latents is None:
masked_image = init_image * (mask_condition < 0.5)
else:
masked_image = masked_image_latents
mask, masked_image_latents = self.prepare_mask_latents(
mask_condition,
masked_image,
batch_size * num_images_per_prompt,
height,
width,
text_embeddings.dtype,
device,
generator,
self.do_classifier_free_guidance,
)
# 8. Check that sizes of mask, masked image and latents match
if num_channels_unet == 9:
# default case for runwayml/stable-diffusion-inpainting
num_channels_mask = mask.shape[1]
num_channels_masked_image = masked_image_latents.shape[1]
if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
raise ValueError(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
elif num_channels_unet != 4:
raise ValueError(
f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
)
steps_denoising = len(sigmas)
self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device)
self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(
latents.device
)
region_state = encode_region_map(
self,
region_map_state,
width = width,
height = height,
num_images_per_prompt = num_images_per_prompt,
text_ids=text_input_ids,
)
if cross_attention_kwargs is None:
cross_attention_kwargs ={}
controlnet_conditioning_scale_copy = controlnet_conditioning_scale.copy() if isinstance(controlnet_conditioning_scale, list) else controlnet_conditioning_scale
control_guidance_start_copy = control_guidance_start.copy() if isinstance(control_guidance_start, list) else control_guidance_start
control_guidance_end_copy = control_guidance_end.copy() if isinstance(control_guidance_end, list) else control_guidance_end
guess_mode = False
if self.controlnet is not None:
img_control,controlnet_keep,guess_mode,controlnet_conditioning_scale = self.preprocess_controlnet(controlnet_conditioning_scale,control_guidance_start,control_guidance_end,control_img,width,height,num_inference_steps,batch_size,num_images_per_prompt)
#print(len(controlnet_keep))
#controlnet_conditioning_scale_copy = controlnet_conditioning_scale.copy()
#sp_control = 1
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
image_embeds = self.prepare_ip_adapter_image_embeds(
ip_adapter_image,
ip_adapter_image_embeds,
device,
batch_size * num_images_per_prompt,
self.do_classifier_free_guidance,
)
# 6.1 Add image embeds for IP-Adapter
added_cond_kwargs = (
{"image_embeds": image_embeds}
if (ip_adapter_image is not None or ip_adapter_image_embeds is not None)
else None
)
#if controlnet_img is not None:
#controlnet_img_processing = controlnet_img.convert("RGB")
#transform = transforms.Compose([transforms.PILToTensor()])
#controlnet_img_processing = transform(controlnet_img)
#controlnet_img_processing=controlnet_img_processing.to(device=device, dtype=self.cnet.dtype)
if latent_processing == 1:
latents_process = [self.latent_to_image(latents,output_type)]
#sp_find_new = None
lst_latent_sigma = []
step_control = -1
adapter_state = None
adapter_sp_count = []
flag_add_noise_inpaiting = 0
if image_t2i_adapter is not None:
adapter_state = preprocessing_t2i_adapter(self,image_t2i_adapter,width,height,adapter_conditioning_scale,1)
def model_fn(x, sigma):
nonlocal step_control,lst_latent_sigma,adapter_sp_count,flag_add_noise_inpaiting
if start_time > 0 and timeout > 0:
assert (time.time() - start_time) < timeout, "inference process timed out"
if num_channels_unet == 4 and flag_add_noise_inpaiting:
init_latents_proper = image_latents
if self.do_classifier_free_guidance:
init_mask, _ = mask.chunk(2)
else:
init_mask = mask
if sigma.item() > sigmas[-1].item():
#indices = torch.where(sigmas == sigma.item())[0]
#sigma_next = sigmas[indices+1]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma.item())
init_latents_proper = alpha_t * init_latents_proper + sigma_t * noise_inpaiting
rate_latent_timestep_sigma = (sigma**2 + 1) ** 0.5
x = ((1 - init_mask) * init_latents_proper + init_mask * x/ rate_latent_timestep_sigma ) * rate_latent_timestep_sigma
non_inpainting_latent_model_input = (
torch.cat([x] * 2) if self.do_classifier_free_guidance else x
)
inpainting_latent_model_input = torch.cat(
[non_inpainting_latent_model_input,mask, masked_image_latents], dim=1
) if num_channels_unet == 9 else non_inpainting_latent_model_input
region_prompt = {
"region_state": region_state,
"sigma": sigma[0],
"weight_func": weight_func,
}
cross_attention_kwargs["region_prompt"] = region_prompt
#print(self.k_diffusion_model.sigma_to_t(sigma[0]))
if non_inpainting_latent_model_input.dtype != text_embeddings.dtype:
non_inpainting_latent_model_input = non_inpainting_latent_model_input.to(text_embeddings.dtype)
if inpainting_latent_model_input.dtype != text_embeddings.dtype:
inpainting_latent_model_input = inpainting_latent_model_input.to(text_embeddings.dtype)
ukwargs = {}
down_intrablock_additional_residuals = None
if adapter_state is not None:
if len(adapter_sp_count) < int( steps_denoising* adapter_conditioning_factor):
down_intrablock_additional_residuals = [state.clone() for state in adapter_state]
else:
down_intrablock_additional_residuals = None
sigma_string_t2i = str(sigma.item())
if sigma_string_t2i not in adapter_sp_count:
adapter_sp_count.append(sigma_string_t2i)
if self.controlnet is not None :
sigma_string = str(sigma.item())
if sigma_string not in lst_latent_sigma:
#sigmas_sp = sigma.detach().clone()
step_control+=1
lst_latent_sigma.append(sigma_string)
if isinstance(controlnet_keep[step_control], list):
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[step_control])]
else:
controlnet_cond_scale = controlnet_conditioning_scale
if isinstance(controlnet_cond_scale, list):
controlnet_cond_scale = controlnet_cond_scale[0]
cond_scale = controlnet_cond_scale * controlnet_keep[step_control]
down_block_res_samples = None
mid_block_res_sample = None
down_block_res_samples, mid_block_res_sample = self.controlnet(
non_inpainting_latent_model_input / ((sigma**2 + 1) ** 0.5),
self.k_diffusion_model.sigma_to_t(sigma),
encoder_hidden_states=text_embeddings,
controlnet_cond=img_control,
conditioning_scale=cond_scale,
guess_mode=guess_mode,
return_dict=False,
)
if guess_mode and self.do_classifier_free_guidance:
# Infered ControlNet only for the conditional batch.
# To apply the output of ControlNet to both the unconditional and conditional batches,
# add 0 to the unconditional batch to keep it unchanged.
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
ukwargs ={
"down_block_additional_residuals": down_block_res_samples,
"mid_block_additional_residual":mid_block_res_sample,
}
noise_pred = self.k_diffusion_model(
inpainting_latent_model_input, sigma, cond=text_embeddings,cross_attention_kwargs=cross_attention_kwargs,down_intrablock_additional_residuals=down_intrablock_additional_residuals,added_cond_kwargs=added_cond_kwargs, **ukwargs
)
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
if guidance_rescale > 0.0:
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
if latent_processing == 1:
latents_process.append(self.latent_to_image(noise_pred,output_type))
flag_add_noise_inpaiting = 1
return noise_pred
extra_args = self.get_sampler_extra_args_t2i(
sigmas, eta, num_inference_steps,sampler_opt,latents,seed, sampler
)
latents = sampler(model_fn, latents, **extra_args)
#latents = latents_process[0]
#print(len(latents_process))
self.maybe_free_model_hooks()
torch.cuda.empty_cache()
gc.collect()
if upscale:
vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
target_height = int(height * upscale_x // vae_scale_factor )* 8
target_width = int(width * upscale_x // vae_scale_factor)*8
latents = torch.nn.functional.interpolate(
latents,
size=(
int(target_height // vae_scale_factor),
int(target_width // vae_scale_factor),
),
mode=upscale_method,
antialias=upscale_antialias,
)
#if controlnet_img is not None:
#controlnet_img = cv2.resize(controlnet_img,(latents.size(0), latents.size(1)))
#controlnet_img=controlnet_img.resize((latents.size(0), latents.size(1)), Image.LANCZOS)
latent_reisze= self.img2img(
prompt=prompt,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
negative_prompt=negative_prompt,
generator=generator,
latents=latents,
strength=upscale_denoising_strength,
sampler_name=sampler_name_hires,
sampler_opt=sampler_opt_hires,
region_map_state = region_map_state,
latent_processing = latent_upscale_processing,
width = int(target_width),
height = int(target_height),
seed = seed,
ip_adapter_image = ip_adapter_image,
control_img = control_img,
controlnet_conditioning_scale = controlnet_conditioning_scale_copy,
control_guidance_start = control_guidance_start_copy,
control_guidance_end = control_guidance_end_copy,
image_t2i_adapter= image_t2i_adapter,
adapter_conditioning_scale = adapter_conditioning_scale,
adapter_conditioning_factor = adapter_conditioning_factor,
guidance_rescale = guidance_rescale,
cross_attention_kwargs = cross_attention_kwargs,
clip_skip = clip_skip,
long_encode = long_encode,
num_images_per_prompt = num_images_per_prompt,
)
'''if latent_processing == 1:
latents = latents_process.copy()
images = []
for i in latents:
images.append(self.decode_latents(i))
image = []
if output_type == "pil":
for i in images:
image.append(self.numpy_to_pil(i))
image[-1] = latent_reisze
return image'''
if latent_processing == 1:
latents_process= latents_process+latent_reisze
return latents_process
torch.cuda.empty_cache()
gc.collect()
return latent_reisze
# 8. Post-processing
'''if latent_processing == 1:
latents = latents_process.copy()
images = []
for i in latents:
images.append(self.decode_latents(i))
image = []
# 10. Convert to PIL
if output_type == "pil":
for i in images:
image.append(self.numpy_to_pil(i))
else:
image = self.decode_latents(latents)
# 10. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image)'''
if latent_processing == 1:
return latents_process
return [self.latent_to_image(latents,output_type)]