Spaces:
Paused
Paused
# Copyright 2024 The HuggingFace Team. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
from typing import Any, Callable, Dict, List, Optional, Tuple, Union | |
import torch | |
from transformers import CLIPTextModelWithProjection, CLIPTokenizer | |
from ...image_processor import VaeImageProcessor | |
from ...models import UVit2DModel, VQModel | |
from ...schedulers import AmusedScheduler | |
from ...utils import replace_example_docstring | |
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput | |
# Usage example for `AmusedPipeline`. The indentation inside the string is
# deliberate: the text is meant to be spliced in place of an `Examples:`
# placeholder line of a method docstring (see `replace_example_docstring`,
# imported above), so it has to line up with the surrounding docstring body.
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import AmusedPipeline

        >>> pipe = AmusedPipeline.from_pretrained(
        ...     "amused/amused-512", variant="fp16", torch_dtype=torch.float16
        ... )
        >>> pipe = pipe.to("cuda")

        >>> prompt = "a photo of an astronaut riding a horse on mars"
        >>> image = pipe(prompt).images[0]
        ```
"""
class AmusedPipeline(DiffusionPipeline):
    """
    Text-to-image pipeline that iteratively fills in masked VQ token ids.

    Generation starts from a grid of token ids (all set to the scheduler's `mask_token_id` unless `latents` is
    supplied), repeatedly runs `transformer` to obtain logits, lets `scheduler.step` produce the next token grid, and
    finally decodes the tokens to an image with `vqvae`.
    """

    image_processor: VaeImageProcessor
    vqvae: VQModel
    tokenizer: CLIPTokenizer
    text_encoder: CLIPTextModelWithProjection
    transformer: UVit2DModel
    scheduler: AmusedScheduler

    # Order in which sub-models are used, consumed by the base class' CPU offloading support.
    model_cpu_offload_seq = "text_encoder->transformer->vqvae"

    def __init__(
        self,
        vqvae: VQModel,
        tokenizer: CLIPTokenizer,
        text_encoder: CLIPTextModelWithProjection,
        transformer: UVit2DModel,
        scheduler: AmusedScheduler,
    ):
        """Register the sub-models and build the image post-processor."""
        super().__init__()

        self.register_modules(
            vqvae=vqvae,
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            transformer=transformer,
            scheduler=scheduler,
        )
        # Ratio between pixel resolution and token-grid resolution, derived from the VQ-VAE config.
        self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)
        # `do_normalize=False`: decoded samples are clipped to [0, 1] in `__call__` before post-processing.
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)

    def _encode_text(self, text: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Tokenize `text`, run the text encoder, and return `(text_embeds, hidden_states[-2])`."""
        input_ids = self.tokenizer(
            text,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
        ).input_ids.to(self._execution_device)

        outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
        return outputs.text_embeds, outputs.hidden_states[-2]

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Optional[Union[List[str], str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 12,
        guidance_scale: float = 10.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.IntTensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        negative_encoder_hidden_states: Optional[torch.Tensor] = None,
        output_type="pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        micro_conditioning_aesthetic_score: int = 6,
        micro_conditioning_crop_coord: Tuple[int, int] = (0, 0),
        temperature: Union[int, Tuple[int, int], List[int]] = (2, 0),
    ):
        """
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
            height (`int`, *optional*, defaults to `self.transformer.config.sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to `self.transformer.config.sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 12):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 10.0):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide what to not include in image generation. If neither this nor
                `negative_prompt_embeds` is given, empty prompts are used. A single `str` is applied to every
                prompt. Ignored when not using guidance (`guidance_scale <= 1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.IntTensor`, *optional*):
                Pre-generated tokens representing latent vectors in `self.vqvae`, to be used as inputs for image
                generation. If not provided, the starting latents will be completely masked.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument. A single vector from the
                pooled and projected final hidden states. Must be passed together with `encoder_hidden_states`.
            encoder_hidden_states (`torch.Tensor`, *optional*):
                Pre-generated penultimate hidden states from the text encoder providing additional text conditioning.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. Must
                be passed together with `negative_encoder_hidden_states`.
            negative_encoder_hidden_states (`torch.Tensor`, *optional*):
                Analogous to `encoder_hidden_states` for the positive prompt.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`. Passing
                `"latent"` skips decoding and returns the final token ids.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.pipeline_utils.ImagePipelineOutput`] instead of a plain
                tuple.
            callback (`Callable`, *optional*):
                A function that calls every `callback_steps` steps during inference. The function is called with the
                following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. If not specified, the callback is called at
                every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            micro_conditioning_aesthetic_score (`int`, *optional*, defaults to 6):
                The targeted aesthetic score according to the laion aesthetic classifier. See
                https://laion.ai/blog/laion-aesthetics/ and the micro-conditioning section of
                https://arxiv.org/abs/2307.01952.
            micro_conditioning_crop_coord (`Tuple[int]`, *optional*, defaults to (0, 0)):
                The targeted height, width crop coordinates. See the micro-conditioning section of
                https://arxiv.org/abs/2307.01952.
            temperature (`Union[int, Tuple[int, int], List[int]]`, *optional*, defaults to (2, 0)):
                Configures the temperature scheduler on `self.scheduler` see `AmusedScheduler#set_timesteps`.

        Examples:

        Returns:
            [`~pipelines.pipeline_utils.ImagePipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.pipeline_utils.ImagePipelineOutput`] is returned, otherwise a
                `tuple` is returned where the first element is a list with the generated images.

        Raises:
            ValueError: If only one of `prompt_embeds` / `encoder_hidden_states` is given, if only one of
                `negative_prompt_embeds` / `negative_encoder_hidden_states` is given, or if `prompt` and
                `prompt_embeds` are both given or both omitted.
        """
        # --- Input validation: pooled embeds and hidden states always travel in pairs. ---
        if (prompt_embeds is None) != (encoder_hidden_states is None):
            raise ValueError("pass either both `prompt_embeds` and `encoder_hidden_states` or neither")

        if (negative_prompt_embeds is None) != (negative_encoder_hidden_states is None):
            raise ValueError(
                "pass either both `negative_prompt_embeds` and `negative_encoder_hidden_states` or neither"
            )

        if (prompt is None) == (prompt_embeds is None):
            raise ValueError("pass only one of `prompt` or `prompt_embeds`")

        if isinstance(prompt, str):
            prompt = [prompt]

        # Number of distinct prompts, kept separately from `batch_size` so that default/broadcast negative prompts
        # can be sized correctly even when only `prompt_embeds` was supplied (`prompt is None`).
        num_prompts = len(prompt) if prompt is not None else prompt_embeds.shape[0]
        batch_size = num_prompts * num_images_per_prompt

        default_size = self.transformer.config.sample_size * self.vae_scale_factor
        if height is None:
            height = default_size
        if width is None:
            width = default_size

        use_guidance = guidance_scale > 1.0

        # --- Text conditioning ---
        if prompt_embeds is None:
            prompt_embeds, encoder_hidden_states = self._encode_text(prompt)

        prompt_embeds = prompt_embeds.repeat(num_images_per_prompt, 1)
        encoder_hidden_states = encoder_hidden_states.repeat(num_images_per_prompt, 1, 1)

        if use_guidance:
            if negative_prompt_embeds is None:
                if negative_prompt is None:
                    negative_prompt = [""] * num_prompts
                elif isinstance(negative_prompt, str):
                    # Apply the one negative prompt to every prompt so both halves of the batch line up.
                    negative_prompt = [negative_prompt] * num_prompts

                negative_prompt_embeds, negative_encoder_hidden_states = self._encode_text(negative_prompt)

            negative_prompt_embeds = negative_prompt_embeds.repeat(num_images_per_prompt, 1)
            negative_encoder_hidden_states = negative_encoder_hidden_states.repeat(num_images_per_prompt, 1, 1)

            # Unconditional half first, conditional half second; must match the `chunk(2)` unpacking below.
            prompt_embeds = torch.concat([negative_prompt_embeds, prompt_embeds])
            encoder_hidden_states = torch.concat([negative_encoder_hidden_states, encoder_hidden_states])

        # Note that the micro conditionings _do_ flip the order of width, height for the original size
        # and the crop coordinates. This is how it was done in the original code base
        micro_conds = torch.tensor(
            [
                width,
                height,
                micro_conditioning_crop_coord[0],
                micro_conditioning_crop_coord[1],
                micro_conditioning_aesthetic_score,
            ],
            device=self._execution_device,
            dtype=encoder_hidden_states.dtype,
        )
        micro_conds = micro_conds.unsqueeze(0)
        micro_conds = micro_conds.expand(2 * batch_size if use_guidance else batch_size, -1)

        # --- Initial token grid ---
        latent_height = height // self.vae_scale_factor
        latent_width = width // self.vae_scale_factor

        if latents is None:
            latents = torch.full(
                (batch_size, latent_height, latent_width),
                self.scheduler.config.mask_token_id,
                dtype=torch.long,
                device=self._execution_device,
            )

        # --- Iterative unmasking loop ---
        self.scheduler.set_timesteps(num_inference_steps, temperature, self._execution_device)

        timesteps = self.scheduler.timesteps
        last_index = len(timesteps) - 1
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, timestep in enumerate(timesteps):
                # With guidance the same tokens are scored under both the negative and the positive conditioning.
                model_input = torch.cat([latents] * 2) if use_guidance else latents

                model_output = self.transformer(
                    model_input,
                    micro_conds=micro_conds,
                    pooled_text_emb=prompt_embeds,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                )

                if use_guidance:
                    uncond_logits, cond_logits = model_output.chunk(2)
                    model_output = uncond_logits + guidance_scale * (cond_logits - uncond_logits)

                latents = self.scheduler.step(
                    model_output=model_output,
                    timestep=timestep,
                    sample=latents,
                    generator=generator,
                ).prev_sample

                if i == last_index or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, timestep, latents)

        # --- Decode ---
        if output_type == "latent":
            output = latents
        else:
            needs_upcasting = self.vqvae.dtype == torch.float16 and self.vqvae.config.force_upcast

            if needs_upcasting:
                self.vqvae.float()

            try:
                output = self.vqvae.decode(
                    latents,
                    force_not_quantize=True,
                    shape=(
                        batch_size,
                        latent_height,
                        latent_width,
                        self.vqvae.config.latent_channels,
                    ),
                ).sample.clip(0, 1)
                output = self.image_processor.postprocess(output, output_type)
            finally:
                # Always hand the VQ-VAE back in fp16, even if decoding/post-processing failed.
                if needs_upcasting:
                    self.vqvae.half()

        self.maybe_free_model_hooks()

        if not return_dict:
            return (output,)

        return ImagePipelineOutput(output)