Update custom_pipeline.py

custom_pipeline.py  (+78 -72)

--- a/custom_pipeline.py
+++ b/custom_pipeline.py
@@ -56,7 +56,6 @@ EXAMPLE_DOC_STRING = """
         ```py
         >>> import torch
         >>> from diffusers import FluxPipeline
-
         >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
         >>> pipe.to("cuda")
         >>> prompt = "A cat holding a sign that says hello world"
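In the upstream diffusers docstring this example continues by actually sampling an image. A minimal sketch of that continuation, assuming schnell's distilled defaults (`guidance_scale=0.0`, few steps) rather than anything specific to this custom pipeline:

```py
>>> # Hypothetical continuation of the docstring example above
>>> image = pipe(
...     prompt,
...     guidance_scale=0.0,
...     num_inference_steps=4,
...     max_sequence_length=256,
...     generator=torch.Generator("cpu").manual_seed(0),
... ).images[0]
>>> image.save("flux-schnell.png")
```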
@@ -93,7 +92,6 @@ def retrieve_timesteps(
     """
     Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
     custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
-
     Args:
         scheduler (`SchedulerMixin`):
             The scheduler to get timesteps from.
@@ -108,7 +106,6 @@ def retrieve_timesteps(
         sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.
-
     Returns:
         `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
         second element is the number of inference steps.
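A short usage sketch of `retrieve_timesteps`, assuming the file is importable as `custom_pipeline` and a diffusers version whose `FlowMatchEulerDiscreteScheduler.set_timesteps` accepts `sigmas`:

```py
from diffusers import FlowMatchEulerDiscreteScheduler
from custom_pipeline import retrieve_timesteps

scheduler = FlowMatchEulerDiscreteScheduler()

# Default spacing: derive the schedule from a step count
timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=28)

# Custom sigmas override the spacing strategy; num_inference_steps and timesteps stay None
timesteps, num_inference_steps = retrieve_timesteps(
    scheduler, sigmas=[1.0 - i / 10 for i in range(10)]
)
```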
@@ -150,9 +147,7 @@ def retrieve_timesteps(
 class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
     r"""
     The Flux pipeline for text-to-image generation.
-
     Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
-
     Args:
         transformer ([`FluxTransformer2DModel`]):
             Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
@@ -334,7 +329,6 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
         lora_scale: Optional[float] = None,
     ):
         r"""
-
         Args:
             prompt (`str` or `List[str]`, *optional*):
                 prompt to be encoded
@@ -612,7 +606,6 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
     ):
         r"""
         Function invoked when calling the pipeline for generation.
-
         Args:
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
@@ -674,9 +667,7 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
                 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                 `._callback_tensor_inputs` attribute of your pipeline class.
             max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
-
         Examples:
-
         Returns:
             [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
             is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
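The callback contract documented above can be exercised like this; a minimal sketch, assuming `"latents"` is listed in this pipeline's `_callback_tensor_inputs`:

```py
def on_step_end(pipeline, step, timestep, callback_kwargs):
    latents = callback_kwargs["latents"]
    print(f"step {step}: t={timestep}, latents shape={tuple(latents.shape)}")
    # Values returned here are read back into the denoising loop
    return {"latents": latents}

image = pipe(
    prompt="A cat holding a sign that says hello world",
    callback_on_step_end=on_step_end,
    callback_on_step_end_tensor_inputs=["latents"],
).images[0]
```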
@@ -797,102 +788,118 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
         latent_image_ids = latent_image_ids.to(self.transformer.device)[0]
         timesteps = timesteps.to(self.transformer.device)
         text_ids = text_ids.to(self.transformer.device)[0]
+        negative_text_ids = negative_text_ids.to(self.transformer.device)[0]
+
+        # Assume 'do_batch_cfg' is a boolean indicating whether to use batched CFG
+        do_batch_cfg = True  # Set this to False to use sequential CFG
 
         # 6. Denoising loop
+
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
 
-                #
-                timestep = t.expand(latents.shape[0]).to(latents.dtype)
-
-                # handle guidance
+                # Prepare the latent model input
+                prompt_embeds_input = prompt_embeds
+                pooled_prompt_embeds_input = pooled_prompt_embeds
+                text_ids_input = text_ids
+                latent_image_ids_input = latent_image_ids
+                prompt_mask_input = prompt_mask
+                latent_model_input = latents
+
+                if do_batch_cfg and guidance_scale_real > 1.0 and i >= no_cfg_until_timestep:
+                    # Concatenate prompt embeddings
+                    prompt_embeds_input = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+                    pooled_prompt_embeds_input = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
+
+                    # # Concatenate text IDs if they are used
+                    # if text_ids is not None and negative_text_ids is not None:
+                    #     text_ids_input = torch.cat([negative_text_ids, text_ids], dim=0)
+
+                    # Concatenate latent image IDs if they are used
+                    # if latent_image_ids is not None:
+                    #     latent_image_ids_input = torch.cat([latent_image_ids, latent_image_ids], dim=0)
+
+                    # Concatenate prompt masks if they are used
+                    if prompt_mask is not None and negative_mask is not None:
+                        prompt_mask_input = torch.cat([negative_mask, prompt_mask], dim=0)
+                    # Duplicate latents for unconditional and conditional inputs
+                    latent_model_input = torch.cat([latents] * 2)
+
+                # Expand timestep to match batch size
+                timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
+
+                # Handle guidance
                 if self.transformer.config.guidance_embeds:
-                    guidance = torch.tensor(
-                        [guidance_scale], device=self.transformer.device
-                    )
-                    guidance = guidance.expand(latents.shape[0])
+                    guidance = torch.tensor([guidance_scale], device=self.transformer.device)
+                    guidance = guidance.expand(latent_model_input.shape[0])
                 else:
                     guidance = None
 
+                # Prepare extra transformer arguments
                 extra_transformer_args = {}
                 if prompt_mask is not None:
-                    extra_transformer_args["attention_mask"] = prompt_mask.to(
-                        device=self.transformer.device
-                    )
+                    extra_transformer_args["attention_mask"] = prompt_mask.to(device=self.transformer.device)
 
+                # Forward pass through the transformer
                 noise_pred = self.transformer(
-                    hidden_states=latents.to(
-                        device=self.transformer.device  # , dtype=self.transformer.dtype  # can't cast dtype like this because of NF4
-                    ),
-                    # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
+                    hidden_states=latent_model_input.to(device=self.transformer.device),
                     timestep=timestep / 1000,
                     guidance=guidance,
-                    pooled_projections=pooled_prompt_embeds.to(
-                        device=self.transformer.device
-                    ),
-                    encoder_hidden_states=prompt_embeds.to(
-                        device=self.transformer.device  # , dtype=self.transformer.dtype  # can't cast dtype like this because of NF4
-                    ),
-                    txt_ids=text_ids,
-                    img_ids=latent_image_ids,
+                    pooled_projections=pooled_prompt_embeds_input.to(device=self.transformer.device),
+                    encoder_hidden_states=prompt_embeds_input.to(device=self.transformer.device),
+                    txt_ids=text_ids_input.to(device=self.transformer.device) if text_ids is not None else None,
+                    img_ids=latent_image_ids_input.to(device=self.transformer.device) if latent_image_ids is not None else None,
                     joint_attention_kwargs=self.joint_attention_kwargs,
                     return_dict=False,
                     **extra_transformer_args,
                 )[0]
 
-                #
+                # Apply real CFG
                 if guidance_scale_real > 1.0 and i >= no_cfg_until_timestep:
-                    noise_pred_uncond = self.transformer(
-                        hidden_states=latents.to(
-                            device=self.transformer.device
-                        ),
-                        timestep=timestep / 1000,
-                        guidance=guidance,
-                        pooled_projections=negative_pooled_prompt_embeds.to(
-                            device=self.transformer.device
-                        ),
-                        encoder_hidden_states=negative_prompt_embeds.to(
-                            device=self.transformer.device
-                        ),
-                        txt_ids=negative_text_ids,
-                        img_ids=latent_image_ids,
-                        joint_attention_kwargs=self.joint_attention_kwargs,
-                        return_dict=False,
-                    )[0]
-
-                    noise_pred = noise_pred_uncond + guidance_scale_real * (
-                        noise_pred - noise_pred_uncond
-                    )
+                    if do_batch_cfg:
+                        # Batched CFG: Split the noise prediction into unconditional and conditional parts
+                        noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
+                        noise_pred = noise_pred_uncond + guidance_scale_real * (noise_pred_cond - noise_pred_uncond)
+                    else:
+                        # Sequential CFG: Compute unconditional noise prediction separately
+                        noise_pred_uncond = self.transformer(
+                            hidden_states=latents.to(device=self.transformer.device),
+                            timestep=timestep / 1000,
+                            guidance=guidance,
+                            pooled_projections=negative_pooled_prompt_embeds.to(device=self.transformer.device),
+                            encoder_hidden_states=negative_prompt_embeds.to(device=self.transformer.device),
+                            txt_ids=negative_text_ids.to(device=self.transformer.device) if negative_text_ids is not None else None,
+                            img_ids=latent_image_ids.to(device=self.transformer.device) if latent_image_ids is not None else None,
+                            joint_attention_kwargs=self.joint_attention_kwargs,
+                            return_dict=False,
+                        )[0]
+
+                        # Combine conditional and unconditional predictions
+                        noise_pred = noise_pred_uncond + guidance_scale_real * (noise_pred - noise_pred_uncond)
 
-                # compute the previous noisy sample x_t -> x_t-1
+                # Compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype
-                latents = self.scheduler.step(
-                    noise_pred, t, latents, return_dict=False
-                )[0]
+                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
 
+                # Ensure latents have the correct dtype
                 if latents.dtype != latents_dtype:
                     if torch.backends.mps.is_available():
-                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                         latents = latents.to(latents_dtype)
 
+                # Callback at the end of the step, if provided
                 if callback_on_step_end is not None:
-                    callback_kwargs = {}
-                    for k in callback_on_step_end_tensor_inputs:
-                        callback_kwargs[k] = locals()[k]
+                    callback_kwargs = {k: locals()[k] for k in callback_on_step_end_tensor_inputs}
                     callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+                    latents = callback_outputs.get("latents", latents)
+                    prompt_embeds = callback_outputs.get("prompt_embeds", prompt_embeds)
 
-                    latents = callback_outputs.pop("latents", latents)
-                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-
-                # call the callback, if provided
-                if i == len(timesteps) - 1 or (
-                    (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
-                ):
+                # Update the progress bar
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
+                # Mark step for XLA devices
                 if XLA_AVAILABLE:
                     xm.mark_step()
 
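The substance of this hunk is the new `do_batch_cfg` path: instead of running the transformer twice per step for real classifier-free guidance (once on the conditional embeddings, once on the negative ones), it concatenates both conditionings into one doubled batch, runs a single forward pass, and splits the result with `chunk(2)`. Both paths compute the same guided prediction `uncond + guidance_scale_real * (cond - uncond)`; batching trades roughly double the peak activation memory for one forward pass instead of two. A self-contained sketch of that equivalence, with `toy_model` standing in for the transformer (all names here are illustrative, not the pipeline's API):

```py
import torch

def toy_model(x, emb):
    # Per-sample toy denoiser: batch elements do not interact, as in the real transformer
    return 0.5 * x + emb.mean(dim=-1, keepdim=True)

latents = torch.randn(2, 16)
cond_emb = torch.randn(2, 16)
uncond_emb = torch.randn(2, 16)
guidance_scale_real = 4.0

# Sequential CFG: two forward passes per step
pred_cond = toy_model(latents, cond_emb)
pred_uncond = toy_model(latents, uncond_emb)
sequential = pred_uncond + guidance_scale_real * (pred_cond - pred_uncond)

# Batched CFG: one forward pass on a doubled batch, unconditional half first
batched_pred = toy_model(torch.cat([latents] * 2), torch.cat([uncond_emb, cond_emb]))
noise_pred_uncond, noise_pred_cond = batched_pred.chunk(2)
batched = noise_pred_uncond + guidance_scale_real * (noise_pred_cond - noise_pred_uncond)

assert torch.allclose(sequential, batched)
```

Note that the diff leaves the `txt_ids` and `latent_image_ids` concatenations commented out, presumably because Flux's id tensors are shared across the batch after the `[0]` indexing above, so duplicating them is unnecessary.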
@@ -932,7 +939,6 @@ from diffusers.utils import BaseOutput
 class FluxPipelineOutput(BaseOutput):
     """
     Output class for Stable Diffusion pipelines.
-
     Args:
         images (`List[PIL.Image.Image]` or `np.ndarray`)
             List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
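Downstream code consumes this output object like any other diffusers pipeline output; a small sketch:

```py
result = pipe(prompt="A cat holding a sign that says hello world")
result.images[0].save("cat.png")  # `images` attribute, as documented above

# Tuple form when return_dict=False, per the Returns section earlier in the diff
(images,) = pipe(prompt="A cat holding a sign that says hello world", return_dict=False)
```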