RohitGandikota commited on
Commit
7ad90cf
·
verified ·
1 Parent(s): dd7ab29

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -8
app.py CHANGED
@@ -2,26 +2,37 @@ import gradio as gr
2
  import numpy as np
3
  import random
4
 
5
- # import spaces #[uncomment to use ZeroGPU]
6
  from diffusers import DiffusionPipeline
7
  import torch
 
 
 
 
 
 
 
8
 
9
- device = "cuda" if torch.cuda.is_available() else "cpu"
10
- model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use
11
 
 
12
  if torch.cuda.is_available():
13
- torch_dtype = torch.float16
14
  else:
15
  torch_dtype = torch.float32
16
 
17
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
18
- pipe = pipe.to(device)
 
 
 
 
 
19
 
20
  MAX_SEED = np.iinfo(np.int32).max
21
  MAX_IMAGE_SIZE = 1024
22
 
23
 
24
- # @spaces.GPU #[uncomment to use ZeroGPU]
25
  def infer(
26
  prompt,
27
  negative_prompt,
@@ -38,6 +49,7 @@ def infer(
38
 
39
  generator = torch.Generator().manual_seed(seed)
40
 
 
41
  image = pipe(
42
  prompt=prompt,
43
  negative_prompt=negative_prompt,
@@ -151,4 +163,4 @@ with gr.Blocks(css=css) as demo:
151
  )
152
 
153
  if __name__ == "__main__":
154
- demo.launch()
 
2
  import numpy as np
3
  import random
4
 
5
+ import spaces #[uncomment to use ZeroGPU]
6
  from diffusers import DiffusionPipeline
7
  import torch
8
+ from diffusers import DiffusionPipeline, UNet2DConditionModel, LCMScheduler
9
+ from huggingface_hub import hf_hub_download
10
+ from safetensors.torch import load_file
11
+
12
+ model_repo_id = "stabilityai/stable-diffusion-xl-base-1.0"
13
+ repo_name = "tianweiy/DMD2"
14
+ ckpt_name = "dmd2_sdxl_4step_unet_fp16.bin"
15
 
 
 
16
 
17
+ device = "cuda" if torch.cuda.is_available() else "cpu"
18
  if torch.cuda.is_available():
19
+ torch_dtype = torch.bfloat16
20
  else:
21
  torch_dtype = torch.float32
22
 
23
+ # Load model.
24
+ unet = UNet2DConditionModel.from_config(model_repo_id, subfolder="unet").to(device, torch_dtype)
25
+ unet.load_state_dict(torch.load(hf_hub_download(repo_name, ckpt_name)))
26
+ pipe = DiffusionPipeline.from_pretrained(model_repo_id, unet=unet, torch_dtype=torch_dtype).to(device)
27
+ pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
28
+
29
+
30
 
31
  MAX_SEED = np.iinfo(np.int32).max
32
  MAX_IMAGE_SIZE = 1024
33
 
34
 
35
+ @spaces.GPU #[uncomment to use ZeroGPU]
36
  def infer(
37
  prompt,
38
  negative_prompt,
 
49
 
50
  generator = torch.Generator().manual_seed(seed)
51
 
52
+ # with network:
53
  image = pipe(
54
  prompt=prompt,
55
  negative_prompt=negative_prompt,
 
163
  )
164
 
165
  if __name__ == "__main__":
166
+ demo.launch()