jbilcke-hf HF staff commited on
Commit
b13990e
1 Parent(s): cad2673

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -24
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  #!/usr/bin/env python
2
 
3
  import os
@@ -6,29 +7,24 @@ import gradio as gr
6
  import numpy as np
7
  import PIL.Image
8
  import torch
9
- from diffusers import DiffusionPipeline, UNet2DConditionModel, LCMScheduler
10
 
11
  MAX_SEED = np.iinfo(np.int32).max
12
  MAX_IMAGE_SIZE = int(os.getenv('MAX_IMAGE_SIZE', '1024'))
13
  SECRET_TOKEN = os.getenv('SECRET_TOKEN', 'default_secret')
14
 
 
 
 
15
  device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
16
  if torch.cuda.is_available():
17
- unet = UNet2DConditionModel.from_pretrained(
18
- "latent-consistency/lcm-ssd-1b",
19
- torch_dtype=torch.float16,
20
- variant="fp16"
21
- )
22
-
23
- pipe = DiffusionPipeline.from_pretrained(
24
- "segmind/SSD-1B",
25
- unet=unet,
26
- torch_dtype=torch.float16,
27
- variant="fp16"
28
- )
29
-
30
  pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
31
- pipe.to(device)
 
 
 
 
32
  else:
33
  pipe = None
34
 
@@ -44,8 +40,8 @@ def generate(prompt: str,
44
  seed: int = 0,
45
  width: int = 1024,
46
  height: int = 1024,
47
- guidance_scale: float = 1.0,
48
- num_inference_steps: int = 6,
49
  secret_token: str = '') -> PIL.Image.Image:
50
  if secret_token != SECRET_TOKEN:
51
  raise gr.Error(
@@ -69,7 +65,7 @@ with gr.Blocks() as demo:
69
  gr.HTML("""
70
  <div style="z-index: 100; position: fixed; top: 0px; right: 0px; left: 0px; bottom: 0px; width: 100%; height: 100%; background: white; display: flex; align-items: center; justify-content: center; color: black;">
71
  <div style="text-align: center; color: black;">
72
- <p style="color: black;">This space is a REST API to programmatically generate images using LCM-SSD-1B.</p>
73
  <p style="color: black;">It is not meant to be directly used through a user interface, but using code and an access key.</p>
74
  </div>
75
  </div>""")
@@ -117,16 +113,16 @@ with gr.Blocks() as demo:
117
  )
118
  guidance_scale = gr.Slider(
119
  label='Guidance scale',
120
- minimum=1,
121
- maximum=20,
122
  step=0.1,
123
- value=1.0)
124
  num_inference_steps = gr.Slider(
125
  label='Number of inference steps',
126
- minimum=2,
127
- maximum=40,
128
  step=1,
129
- value=6)
130
 
131
  use_negative_prompt.change(
132
  fn=lambda x: gr.update(visible=x),
 
1
+
2
  #!/usr/bin/env python
3
 
4
  import os
 
7
  import numpy as np
8
  import PIL.Image
9
  import torch
10
+ from diffusers import LCMScheduler, AutoPipelineForText2Image
11
 
12
  MAX_SEED = np.iinfo(np.int32).max
13
  MAX_IMAGE_SIZE = int(os.getenv('MAX_IMAGE_SIZE', '1024'))
14
  SECRET_TOKEN = os.getenv('SECRET_TOKEN', 'default_secret')
15
 
16
+ MODEL_ID = "segmind/SSD-1B"
17
+ ADAPTER_ID = "latent-consistency/lcm-lora-ssd-1b"
18
+
19
  device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
20
  if torch.cuda.is_available():
21
+ pipe = AutoPipelineForText2Image.from_pretrained(MODEL_ID, torch_dtype=torch.float16, variant="fp16")
 
 
 
 
 
 
 
 
 
 
 
 
22
  pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
23
+ pipe.to("cuda")
24
+
25
+ # load and fuse
26
+ pipe.load_lora_weights(ADAPTER_ID)
27
+ pipe.fuse_lora()
28
  else:
29
  pipe = None
30
 
 
40
  seed: int = 0,
41
  width: int = 1024,
42
  height: int = 1024,
43
+ guidance_scale: float = 0.0,
44
+ num_inference_steps: int = 4,
45
  secret_token: str = '') -> PIL.Image.Image:
46
  if secret_token != SECRET_TOKEN:
47
  raise gr.Error(
 
65
  gr.HTML("""
66
  <div style="z-index: 100; position: fixed; top: 0px; right: 0px; left: 0px; bottom: 0px; width: 100%; height: 100%; background: white; display: flex; align-items: center; justify-content: center; color: black;">
67
  <div style="text-align: center; color: black;">
68
+ <p style="color: black;">This space is a REST API to programmatically generate images using LCM SDXL LoRA.</p>
69
  <p style="color: black;">It is not meant to be directly used through a user interface, but using code and an access key.</p>
70
  </div>
71
  </div>""")
 
113
  )
114
  guidance_scale = gr.Slider(
115
  label='Guidance scale',
116
+ minimum=0,
117
+ maximum=2,
118
  step=0.1,
119
+ value=0.0)
120
  num_inference_steps = gr.Slider(
121
  label='Number of inference steps',
122
+ minimum=1,
123
+ maximum=8,
124
  step=1,
125
+ value=4)
126
 
127
  use_negative_prompt.change(
128
  fn=lambda x: gr.update(visible=x),