multimodalart HF staff commited on
Commit
ebf93e0
·
verified ·
1 Parent(s): 29a07fc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -1
app.py CHANGED
@@ -1,10 +1,38 @@
1
- from diffusers import StableDiffusionXLPipeline, AutoencoderKL, DPMSolverMultistepScheduler
 
 
2
  import random
3
  import torch
4
  import numpy as np
5
  import gradio as gr
6
  import spaces
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
9
  pipe = StableDiffusionXLPipeline.from_pretrained(
10
  "stabilityai/stable-diffusion-xl-base-1.0",
 
1
+ from diffusers import StableDiffusionXLPipeline, AutoencoderKL
2
+ from diffusers import DPMSolverMultistepScheduler as DefaultDPMSolver
3
+
4
  import random
5
  import torch
6
  import numpy as np
7
  import gradio as gr
8
  import spaces
9
 
10
+ # Add support for setting custom timesteps
11
+ class DPMSolverMultistepScheduler(DefaultDPMSolver):
12
+ def set_timesteps(
13
+ self, num_inference_steps=None, device=None,
14
+ timesteps=None
15
+ ):
16
+ if timesteps is None:
17
+ super().set_timesteps(num_inference_steps, device)
18
+ return
19
+
20
+ all_sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
21
+ self.sigmas = torch.from_numpy(all_sigmas[timesteps])
22
+ self.timesteps = torch.tensor(timesteps[:-1]).to(device=device, dtype=torch.int64) # Ignore the last 0
23
+
24
+ self.num_inference_steps = len(timesteps)
25
+
26
+ self.model_outputs = [
27
+ None,
28
+ ] * self.config.solver_order
29
+ self.lower_order_nums = 0
30
+
31
+ # add an index counter for schedulers that allow duplicated timesteps
32
+ self._step_index = None
33
+ self._begin_index = None
34
+ self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
35
+
36
  vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
37
  pipe = StableDiffusionXLPipeline.from_pretrained(
38
  "stabilityai/stable-diffusion-xl-base-1.0",