jimmycarter committed
Commit
e5befa2
1 Parent(s): b8de496

Final pipeline fixes

Files changed (2):
  1. README.md +1 -1
  2. pipeline.py +8 -8
README.md CHANGED
@@ -69,12 +69,12 @@ quantize(
 freeze(pipe.transformer)
 pipe.enable_model_cpu_offload()
 
-# If you are still running out of memory, add do_batch_cfg=False below.
 images = pipe(
     prompt=prompt,
     negative_prompt=negative_prompt,
     device=None,
     return_dict=False,
+    do_batch_cfg=False, # https://github.com/huggingface/optimum-quanto/issues/327
 )
 images[0][0].save('chalkboard.png')
 ```
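
The README change replaces the old "add do_batch_cfg=False if you are still running out of memory" hint with passing the flag directly, pointing at optimum-quanto issue #327. Below is a minimal sketch of how the updated snippet fits together end to end; the repository id, prompts, and qint8 weight quantization are assumptions added for illustration, and only the quantize/freeze/offload pattern and the do_batch_cfg=False argument come from the diff itself.

```python
import torch
from diffusers import DiffusionPipeline
from optimum.quanto import freeze, qint8, quantize

# Hypothetical repository id; the custom_pipeline argument is what pulls in
# the repo's pipeline.py as the pipeline class.
pipe = DiffusionPipeline.from_pretrained(
    "jimmycarter/example-flux-pipeline",
    custom_pipeline="jimmycarter/example-flux-pipeline",
    torch_dtype=torch.bfloat16,
)

# Quantize the transformer weights and freeze them, then offload to CPU
# between forward passes to keep VRAM usage down (as in the README).
quantize(pipe.transformer, weights=qint8)  # qint8 is an assumed choice
freeze(pipe.transformer)
pipe.enable_model_cpu_offload()

prompt = "a chalkboard covered in equations"  # assumed prompt
negative_prompt = "blurry, low quality"       # assumed negative prompt

images = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    device=None,
    return_dict=False,
    # Passed unconditionally after this commit; the diff links it to
    # https://github.com/huggingface/optimum-quanto/issues/327
    do_batch_cfg=False,
)
images[0][0].save('chalkboard.png')
```

Per the removed README comment, do_batch_cfg=False was previously only a fallback for running out of memory; the commit now passes it by default in the example.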
pipeline.py CHANGED
@@ -1614,14 +1614,14 @@ class CustomPipeline(DiffusionPipeline, SD3LoraLoaderMixin):
 if guidance_scale_real > 1.0 and i >= no_cfg_until_timestep:
     progress_bar.set_postfix(
         {
-            'ts': t / 1000.0,
+            'ts': t.detach().item() / 1000.0,
             'cfg': self._guidance_scale_real,
         },
     )
 else:
     progress_bar.set_postfix(
         {
-            'ts': t / 1000.0,
+            'ts': t.detach().item() / 1000.0,
             'cfg': 'N/A',
         },
     )
@@ -1658,17 +1658,17 @@ class CustomPipeline(DiffusionPipeline, SD3LoraLoaderMixin):
 # Prepare extra transformer arguments
 extra_transformer_args = {}
 if prompt_mask is not None:
-    extra_transformer_args["attention_mask"] = prompt_mask_input.to(device=self.transformer.device).contiguous()
+    extra_transformer_args["attention_mask"] = prompt_mask_input.to(device=self.transformer.device)
 
 # Forward pass through the transformer
 noise_pred = self.transformer(
-    hidden_states=latent_model_input.to(device=self.transformer.device).contiguous() ,
+    hidden_states=latent_model_input.to(device=self.transformer.device),
     timestep=timestep / 1000,
     guidance=guidance,
-    pooled_projections=pooled_prompt_embeds_input.to(device=self.transformer.device).contiguous() ,
-    encoder_hidden_states=prompt_embeds_input.to(device=self.transformer.device).contiguous() ,
-    txt_ids=text_ids_input.to(device=self.transformer.device).contiguous() if text_ids is not None else None,
-    img_ids=latent_image_ids_input.to(device=self.transformer.device).contiguous() if latent_image_ids is not None else None,
+    pooled_projections=pooled_prompt_embeds_input.to(device=self.transformer.device),
+    encoder_hidden_states=prompt_embeds_input.to(device=self.transformer.device),
+    txt_ids=text_ids_input.to(device=self.transformer.device) if text_ids is not None else None,
+    img_ids=latent_image_ids_input.to(device=self.transformer.device) if latent_image_ids is not None else None,
     joint_attention_kwargs=self.joint_attention_kwargs,
     return_dict=False,
     **extra_transformer_args,
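
The pipeline.py edits are small but easy to misread in the diff: the progress-bar postfix now converts the scheduler timestep tensor to a plain Python float, and the trailing .contiguous() calls after each device move are dropped. A standalone sketch of the underlying PyTorch behaviour, with an invented timestep value for illustration:

```python
import torch

# Scheduler timesteps are 0-dim tensors; dividing one still yields a tensor,
# which prints as tensor(...) in tqdm's set_postfix rather than a clean number.
t = torch.tensor(981.0)
print(t / 1000.0)                   # tensor(0.9810)
print(t.detach().item() / 1000.0)   # 0.981, a plain Python float

# .to(device=...) already returns a tensor that is usable as-is, and
# .contiguous() is a no-op when the tensor is already contiguous, so the
# removed calls were not changing the data passed to the transformer.
x = torch.randn(2, 4)
print(x.to(device="cpu").is_contiguous())   # True
```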