from typing import Any, Callable, Dict, List, Optional, Union

import torch
from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline, StableDiffusionPipelineOutput
|
|
@torch.no_grad()
def sd_pipeline_call(
        pipeline: StableDiffusionPipeline,
        prompt_embeds: torch.FloatTensor,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None):
    """Modification of the standard SD pipeline call to support NeTI embeddings passed via the prompt_embeds
    argument. prompt_embeds may be a single embedding tensor, or a list holding one embedding per denoising
    timestep."""
    # 0. Default height and width to the UNet's native sample size
    height = height or pipeline.unet.config.sample_size * pipeline.vae_scale_factor
    width = width or pipeline.unet.config.sample_size * pipeline.vae_scale_factor

    batch_size = 1
    device = pipeline._execution_device

    # 1. Encode the (possibly empty) negative prompt
    neg_prompt = get_neg_prompt_input_ids(pipeline, negative_prompt)
    negative_prompt_embeds, _ = pipeline.text_encoder(
        input_ids=neg_prompt.input_ids.to(device),
        attention_mask=None,
    )
    negative_prompt_embeds = negative_prompt_embeds[0]

    do_classifier_free_guidance = guidance_scale > 1.0

    # 2. Prepare timesteps
    pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = pipeline.scheduler.timesteps

    # 3. Prepare the initial latent noise
    num_channels_latents = pipeline.unet.config.in_channels
    latents = pipeline.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        pipeline.text_encoder.dtype,
        device,
        generator,
        latents,
    )

    # 4. Prepare extra step kwargs (eta is only used by DDIM-like schedulers)
    extra_step_kwargs = pipeline.prepare_extra_step_kwargs(generator, eta)

    # 5. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * pipeline.scheduler.order
    with pipeline.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            latent_model_input = pipeline.scheduler.scale_model_input(latents, t)

            # Unconditional prediction, only needed for classifier-free guidance
            if do_classifier_free_guidance:
                noise_pred_uncond = pipeline.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=negative_prompt_embeds.repeat(num_images_per_prompt, 1, 1),
                    cross_attention_kwargs=cross_attention_kwargs,
                ).sample

            # Conditional prediction; with NeTI, a different embedding may be used at every timestep
            embed = prompt_embeds[i] if isinstance(prompt_embeds, list) else prompt_embeds
            noise_pred_text = pipeline.unet(
                latent_model_input,
                t,
                encoder_hidden_states=embed,
                cross_attention_kwargs=cross_attention_kwargs,
            ).sample

            # Perform guidance: uncond + scale * (text - uncond)
            if do_classifier_free_guidance:
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            else:
                noise_pred = noise_pred_text

            # Compute the previous noisy sample x_t -> x_{t-1}
            latents = pipeline.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

    # 6. Post-processing
    if output_type == "latent":
        image = latents
        has_nsfw_concept = None
    elif output_type == "pil":
        image = pipeline.decode_latents(latents)
        image, has_nsfw_concept = pipeline.run_safety_checker(image, device, pipeline.text_encoder.dtype)
        image = pipeline.numpy_to_pil(image)
    else:
        image = pipeline.decode_latents(latents)
        image, has_nsfw_concept = pipeline.run_safety_checker(image, device, pipeline.text_encoder.dtype)

    # Offload last model to CPU if sequential offloading was enabled
    if hasattr(pipeline, "final_offload_hook") and pipeline.final_offload_hook is not None:
        pipeline.final_offload_hook.offload()

    if not return_dict:
        return image, has_nsfw_concept

    return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
|
def get_neg_prompt_input_ids(pipeline: StableDiffusionPipeline,
                             negative_prompt: Optional[Union[str, List[str]]] = None):
    """Tokenize the negative prompt, defaulting to the empty string when none is given."""
    if negative_prompt is None:
        negative_prompt = ""
    uncond_tokens = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
    uncond_input = pipeline.tokenizer(
        uncond_tokens,
        padding="max_length",
        max_length=pipeline.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    return uncond_input
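

# A minimal usage sketch (illustrative, not part of the original module). It
# assumes `pipe` is a StableDiffusionPipeline whose text encoder follows NeTI's
# interface (sd_pipeline_call unpacks the text encoder output as a 2-tuple,
# which the stock CLIPTextModel does not return), and that `per_step_embeds`
# already holds one conditioning tensor of shape
# (num_images_per_prompt, seq_len, hidden_dim) per denoising step, e.g. as
# produced by a trained NeTI mapper.
def run_neti_inference(pipe: StableDiffusionPipeline,
                       per_step_embeds: List[torch.FloatTensor],
                       seed: int = 42):
    generator = torch.Generator(device=pipe.device).manual_seed(seed)
    output = sd_pipeline_call(
        pipe,
        prompt_embeds=per_step_embeds,  # one embedding per denoising timestep
        num_inference_steps=len(per_step_embeds),
        guidance_scale=7.5,
        generator=generator,
    )
    return output.images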