bghira committed
Commit 563f3d0
1 Parent(s): 8be7ccb

Update custom_pipeline.py

Files changed (1):
  1. custom_pipeline.py +78 -72
custom_pipeline.py CHANGED
@@ -56,7 +56,6 @@ EXAMPLE_DOC_STRING = """
         ```py
         >>> import torch
         >>> from diffusers import FluxPipeline
-
         >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
         >>> pipe.to("cuda")
         >>> prompt = "A cat holding a sign that says hello world"
@@ -93,7 +92,6 @@ def retrieve_timesteps(
     """
     Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
-
     Args:
         scheduler (`SchedulerMixin`):
             The scheduler to get timesteps from.
@@ -108,7 +106,6 @@ def retrieve_timesteps(
         sigmas (`List[float]`, *optional*):
             Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.
-
     Returns:
         `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
         second element is the number of inference steps.
@@ -150,9 +147,7 @@ def retrieve_timesteps(
 class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
     r"""
     The Flux pipeline for text-to-image generation.
-
     Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
-
     Args:
         transformer ([`FluxTransformer2DModel`]):
             Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
@@ -334,7 +329,6 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
         lora_scale: Optional[float] = None,
     ):
         r"""
-
         Args:
             prompt (`str` or `List[str]`, *optional*):
                 prompt to be encoded
@@ -612,7 +606,6 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
     ):
         r"""
         Function invoked when calling the pipeline for generation.
-
         Args:
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
@@ -674,9 +667,7 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
                 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                 `._callback_tensor_inputs` attribute of your pipeline class.
             max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
-
         Examples:
-
         Returns:
             [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
                 is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
@@ -797,102 +788,118 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
         latent_image_ids = latent_image_ids.to(self.transformer.device)[0]
         timesteps = timesteps.to(self.transformer.device)
         text_ids = text_ids.to(self.transformer.device)[0]
+        negative_text_ids = negative_text_ids.to(self.transformer.device)[0]
+
+        # Assume 'do_batch_cfg' is a boolean indicating whether to use batched CFG
+        do_batch_cfg = True  # Set this to False to use sequential CFG
 
         # 6. Denoising loop
+
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
 
-                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-                timestep = t.expand(latents.shape[0]).to(latents.dtype)
-
-                # handle guidance
+                # Prepare the latent model input
+                prompt_embeds_input = prompt_embeds
+                pooled_prompt_embeds_input = pooled_prompt_embeds
+                text_ids_input = text_ids
+                latent_image_ids_input = latent_image_ids
+                prompt_mask_input = prompt_mask
+                latent_model_input = latents
+
+                if do_batch_cfg and guidance_scale_real > 1.0 and i >= no_cfg_until_timestep:
+                    # Concatenate prompt embeddings
+                    prompt_embeds_input = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+                    pooled_prompt_embeds_input = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
+
+                    # # Concatenate text IDs if they are used
+                    # if text_ids is not None and negative_text_ids is not None:
+                    #     text_ids_input = torch.cat([negative_text_ids, text_ids], dim=0)
+
+                    # Concatenate latent image IDs if they are used
+                    # if latent_image_ids is not None:
+                    #     latent_image_ids_input = torch.cat([latent_image_ids, latent_image_ids], dim=0)
+
+                    # Concatenate prompt masks if they are used
+                    if prompt_mask is not None and negative_mask is not None:
+                        prompt_mask_input = torch.cat([negative_mask, prompt_mask], dim=0)
+                    # Duplicate latents for unconditional and conditional inputs
+                    latent_model_input = torch.cat([latents] * 2)
+
+                # Expand timestep to match batch size
+                timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
+
+                # Handle guidance
                 if self.transformer.config.guidance_embeds:
-                    guidance = torch.tensor(
-                        [guidance_scale], device=self.transformer.device
-                    )
-                    guidance = guidance.expand(latents.shape[0])
+                    guidance = torch.tensor([guidance_scale], device=self.transformer.device)
+                    guidance = guidance.expand(latent_model_input.shape[0])
                 else:
                     guidance = None
 
+                # Prepare extra transformer arguments
                 extra_transformer_args = {}
                 if prompt_mask is not None:
-                    extra_transformer_args["attention_mask"] = prompt_mask.to(
-                        device=self.transformer.device
-                    )
+                    extra_transformer_args["attention_mask"] = prompt_mask_input.to(device=self.transformer.device)
 
+                # Forward pass through the transformer
                 noise_pred = self.transformer(
-                    hidden_states=latents.to(
-                        device=self.transformer.device  # , dtype=self.transformer.dtype  # can't cast dtype like this because of NF4
-                    ),
-                    # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
+                    hidden_states=latent_model_input.to(device=self.transformer.device),
                     timestep=timestep / 1000,
                     guidance=guidance,
-                    pooled_projections=pooled_prompt_embeds.to(
-                        device=self.transformer.device  # , dtype=self.transformer.dtype  # can't cast dtype like this because of NF4
-                    ),
-                    encoder_hidden_states=prompt_embeds.to(
-                        device=self.transformer.device  # , dtype=self.transformer.dtype  # can't cast dtype like this because of NF4
-                    ),
-                    txt_ids=text_ids,
-                    img_ids=latent_image_ids,
+                    pooled_projections=pooled_prompt_embeds_input.to(device=self.transformer.device),
+                    encoder_hidden_states=prompt_embeds_input.to(device=self.transformer.device),
+                    txt_ids=text_ids_input.to(device=self.transformer.device) if text_ids is not None else None,
+                    img_ids=latent_image_ids_input.to(device=self.transformer.device) if latent_image_ids is not None else None,
                     joint_attention_kwargs=self.joint_attention_kwargs,
                     return_dict=False,
                     **extra_transformer_args,
                 )[0]
 
-                # TODO optionally use batch prediction to speed this up.
+                # Apply real CFG
                 if guidance_scale_real > 1.0 and i >= no_cfg_until_timestep:
-                    noise_pred_uncond = self.transformer(
-                        hidden_states=latents.to(
-                            device=self.transformer.device  # , dtype=self.transformer.dtype  # can't cast dtype like this because of NF4
-                        ),
-                        # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
-                        timestep=timestep / 1000,
-                        guidance=guidance,
-                        pooled_projections=negative_pooled_prompt_embeds.to(
-                            device=self.transformer.device  # , dtype=self.transformer.dtype  # can't cast dtype like this because of NF4
-                        ),
-                        encoder_hidden_states=negative_prompt_embeds.to(
-                            device=self.transformer.device  # , dtype=self.transformer.dtype  # can't cast dtype like this because of NF4
-                        ),
-                        txt_ids=negative_text_ids.to(device=self.transformer.device),
-                        img_ids=latent_image_ids.to(device=self.transformer.device),
-                        joint_attention_kwargs=self.joint_attention_kwargs,
-                        return_dict=False,
-                    )[0]
-
-                    noise_pred = noise_pred_uncond + guidance_scale_real * (
-                        noise_pred - noise_pred_uncond
-                    )
-
-                # compute the previous noisy sample x_t -> x_t-1
+                    if do_batch_cfg:
+                        # Batched CFG: split the prediction into unconditional and conditional parts
+                        noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
+                        noise_pred = noise_pred_uncond + guidance_scale_real * (noise_pred_cond - noise_pred_uncond)
+                    else:
+                        # Sequential CFG: compute the unconditional noise prediction separately
+                        noise_pred_uncond = self.transformer(
+                            hidden_states=latents.to(device=self.transformer.device),
+                            timestep=timestep / 1000,
+                            guidance=guidance,
+                            pooled_projections=negative_pooled_prompt_embeds.to(device=self.transformer.device),
+                            encoder_hidden_states=negative_prompt_embeds.to(device=self.transformer.device),
+                            txt_ids=negative_text_ids.to(device=self.transformer.device) if negative_text_ids is not None else None,
+                            img_ids=latent_image_ids.to(device=self.transformer.device) if latent_image_ids is not None else None,
+                            joint_attention_kwargs=self.joint_attention_kwargs,
+                            return_dict=False,
+                        )[0]
+
+                        # Combine conditional and unconditional predictions
+                        noise_pred = noise_pred_uncond + guidance_scale_real * (noise_pred - noise_pred_uncond)
+
+                # Compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype
-                latents = self.scheduler.step(
-                    noise_pred, t, latents, return_dict=False
-                )[0]
+                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
 
+                # Ensure latents have the correct dtype
                 if latents.dtype != latents_dtype:
                     if torch.backends.mps.is_available():
-                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                         latents = latents.to(latents_dtype)
 
+                # Callback at the end of the step, if provided
                 if callback_on_step_end is not None:
-                    callback_kwargs = {}
-                    for k in callback_on_step_end_tensor_inputs:
-                        callback_kwargs[k] = locals()[k]
+                    callback_kwargs = {k: locals()[k] for k in callback_on_step_end_tensor_inputs}
                     callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+                    latents = callback_outputs.get("latents", latents)
+                    prompt_embeds = callback_outputs.get("prompt_embeds", prompt_embeds)
 
-                    latents = callback_outputs.pop("latents", latents)
-                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-
-                # call the callback, if provided
-                if i == len(timesteps) - 1 or (
-                    (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
-                ):
+                # Update the progress bar
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
+                # Mark step for XLA devices
                 if XLA_AVAILABLE:
                     xm.mark_step()
 
@@ -932,7 +939,6 @@ from diffusers.utils import BaseOutput
 class FluxPipelineOutput(BaseOutput):
     """
     Output class for Stable Diffusion pipelines.
-
     Args:
         images (`List[PIL.Image.Image]` or `np.ndarray`)
             List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
 
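For readers skimming the large hunk above: the substantive change is that real classifier-free guidance (CFG) can now run as a single batched transformer call, with negative and positive conditioning concatenated along the batch dimension, instead of two sequential calls, gated by the new `do_batch_cfg` flag. Below is a minimal, self-contained sketch of that pattern; `denoise` is a hypothetical stand-in for `self.transformer`, and all names here are illustrative rather than part of the pipeline above.

```py
import torch


def batched_cfg_step(denoise, latents, cond_embeds, uncond_embeds, guidance_scale_real):
    # Duplicate the latents and stack [unconditional, conditional] text embeddings
    # so a single forward pass covers both branches of CFG.
    latent_model_input = torch.cat([latents, latents], dim=0)
    text_embeds = torch.cat([uncond_embeds, cond_embeds], dim=0)

    noise_pred = denoise(latent_model_input, text_embeds)

    # Split the stacked prediction back apart and recombine with the CFG weighting
    # used in the diff: uncond + scale * (cond - uncond).
    noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
    return noise_pred_uncond + guidance_scale_real * (noise_pred_cond - noise_pred_uncond)


# Toy usage with a stand-in denoiser that ignores the text embeddings.
toy_denoise = lambda x, e: torch.zeros_like(x)
out = batched_cfg_step(toy_denoise, torch.randn(1, 16), torch.randn(1, 8), torch.randn(1, 8), 5.0)
print(out.shape)  # torch.Size([1, 16])
```

Compared with the sequential path (kept in the diff behind `do_batch_cfg = False`), the batched variant halves the number of transformer forward passes per step at the cost of roughly doubled activation memory for that call.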