mrfakename committed
Commit 5d7a7f8
1 Parent(s): a3428b3

Create app.py

Files changed (1)
  1. app.py +317 -0
app.py ADDED
@@ -0,0 +1,317 @@
+ #!/usr/bin/env python
+
+ from __future__ import annotations
+
+ import os
+ import random
+
+ import gradio as gr
+ import numpy as np
+ import PIL.Image
+ import spaces
+ import torch
+ from diffusers import AutoencoderKL, DiffusionPipeline
+
+ DESCRIPTION = """
+ # OpenDalle
+
+ ## A demo of [OpenDalle](https://huggingface.co/dataautogpt3/OpenDalle) by @dataautogpt3
+
+ **This demo is based on [@hysts's SD-XL demo](https://huggingface.co/spaces/hysts/SD-XL).**
+ """
+ if not torch.cuda.is_available():
+     DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
+
+ MAX_SEED = np.iinfo(np.int32).max
+ CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") == "1"
+ MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "1024"))
+ USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE") == "1"
+ ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
+ ENABLE_REFINER = os.getenv("ENABLE_REFINER", "0") == "1"
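+ # Everything above is configured through environment variables, so the Space
+ # can be tuned without editing code: CACHE_EXAMPLES pre-renders the example
+ # prompts at startup, MAX_IMAGE_SIZE caps the width/height sliders,
+ # USE_TORCH_COMPILE compiles the UNet(s), ENABLE_CPU_OFFLOAD trades speed
+ # for VRAM, and ENABLE_REFINER loads the optional SDXL refiner stage.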
+
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ if torch.cuda.is_available():
+     vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
+     pipe = DiffusionPipeline.from_pretrained(
+         "dataautogpt3/OpenDalle",
+         vae=vae,
+         torch_dtype=torch.float16,
+         use_safetensors=True,
+         variant="fp16",
+     )
+     if ENABLE_REFINER:
+         refiner = DiffusionPipeline.from_pretrained(
+             "stabilityai/stable-diffusion-xl-refiner-1.0",
+             vae=vae,
+             torch_dtype=torch.float16,
+             use_safetensors=True,
+             variant="fp16",
+         )
+
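+     # Either stream submodules to the GPU on demand (saves VRAM at some
+     # speed cost) or pin the whole pipeline on the device.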
+     if ENABLE_CPU_OFFLOAD:
+         pipe.enable_model_cpu_offload()
+         if ENABLE_REFINER:
+             refiner.enable_model_cpu_offload()
+     else:
+         pipe.to(device)
+         if ENABLE_REFINER:
+             refiner.to(device)
+
+     if USE_TORCH_COMPILE:
+         pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+         if ENABLE_REFINER:
+             refiner.unet = torch.compile(refiner.unet, mode="reduce-overhead", fullgraph=True)
+
+
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+     return seed
+
+
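+ # On ZeroGPU Spaces, @spaces.GPU attaches a GPU to the worker only for the
+ # duration of each call; on a dedicated GPU instance it is effectively a no-op.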
+ @spaces.GPU
+ def generate(
+     prompt: str,
+     negative_prompt: str = "",
+     prompt_2: str = "",
+     negative_prompt_2: str = "",
+     use_negative_prompt: bool = False,
+     use_prompt_2: bool = False,
+     use_negative_prompt_2: bool = False,
+     seed: int = 0,
+     width: int = 1024,
+     height: int = 1024,
+     guidance_scale_base: float = 5.0,
+     guidance_scale_refiner: float = 5.0,
+     num_inference_steps_base: int = 25,
+     num_inference_steps_refiner: int = 25,
+     apply_refiner: bool = False,
+ ) -> PIL.Image.Image:
+     generator = torch.Generator().manual_seed(seed)
+
+     if not use_negative_prompt:
+         negative_prompt = None  # type: ignore
+     if not use_prompt_2:
+         prompt_2 = None  # type: ignore
+     if not use_negative_prompt_2:
+         negative_prompt_2 = None  # type: ignore
+
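+     # Single-stage path: decode the base pipeline's output straight to a PIL
+     # image. Two-stage path: keep the base output in latent space and let the
+     # refiner run the final denoising steps before decoding.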
+     if not apply_refiner:
+         return pipe(
+             prompt=prompt,
+             negative_prompt=negative_prompt,
+             prompt_2=prompt_2,
+             negative_prompt_2=negative_prompt_2,
+             width=width,
+             height=height,
+             guidance_scale=guidance_scale_base,
+             num_inference_steps=num_inference_steps_base,
+             generator=generator,
+             output_type="pil",
+         ).images[0]
+     else:
+         latents = pipe(
+             prompt=prompt,
+             negative_prompt=negative_prompt,
+             prompt_2=prompt_2,
+             negative_prompt_2=negative_prompt_2,
+             width=width,
+             height=height,
+             guidance_scale=guidance_scale_base,
+             num_inference_steps=num_inference_steps_base,
+             generator=generator,
+             output_type="latent",
+         ).images
+         image = refiner(
+             prompt=prompt,
+             negative_prompt=negative_prompt,
+             prompt_2=prompt_2,
+             negative_prompt_2=negative_prompt_2,
+             guidance_scale=guidance_scale_refiner,
+             num_inference_steps=num_inference_steps_refiner,
+             image=latents,
+             generator=generator,
+         ).images[0]
+         return image
+
+
+ examples = [
+     "A realistic photograph of an astronaut in a jungle, cold color palette, detailed, 8k",
+     "An astronaut riding a green horse",
+ ]
+
+ theme = gr.themes.Base(
+     font=[gr.themes.GoogleFont('Libre Franklin'), gr.themes.GoogleFont('Public Sans'), 'system-ui', 'sans-serif'],
+ )
+ with gr.Blocks(css="footer{display:none !important}", theme=theme) as demo:
+     gr.Markdown(DESCRIPTION)
+     gr.DuplicateButton(
+         value="Duplicate Space for private use",
+         elem_id="duplicate-button",
+         visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
+     )
+     with gr.Group():
+         with gr.Row():
+             prompt = gr.Text(
+                 label="Prompt",
+                 show_label=False,
+                 max_lines=1,
+                 placeholder="Enter your prompt",
+                 container=False,
+             )
+             run_button = gr.Button("Run", scale=0)
+         result = gr.Image(label="Result", show_label=False)
+     with gr.Accordion("Advanced options", open=False):
+         with gr.Row():
+             use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False)
+             use_prompt_2 = gr.Checkbox(label="Use prompt 2", value=False)
+             use_negative_prompt_2 = gr.Checkbox(label="Use negative prompt 2", value=False)
+         negative_prompt = gr.Text(
+             label="Negative prompt",
+             max_lines=1,
+             placeholder="Enter a negative prompt",
+             visible=False,
+         )
+         prompt_2 = gr.Text(
+             label="Prompt 2",
+             max_lines=1,
+             placeholder="Enter your prompt",
+             visible=False,
+         )
+         negative_prompt_2 = gr.Text(
+             label="Negative prompt 2",
+             max_lines=1,
+             placeholder="Enter a negative prompt",
+             visible=False,
+         )
+
+         seed = gr.Slider(
+             label="Seed",
+             minimum=0,
+             maximum=MAX_SEED,
+             step=1,
+             value=0,
+         )
+         randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+         with gr.Row():
+             width = gr.Slider(
+                 label="Width",
+                 minimum=256,
+                 maximum=MAX_IMAGE_SIZE,
+                 step=32,
+                 value=1024,
+             )
+             height = gr.Slider(
+                 label="Height",
+                 minimum=256,
+                 maximum=MAX_IMAGE_SIZE,
+                 step=32,
+                 value=1024,
+             )
+         apply_refiner = gr.Checkbox(label="Apply refiner", value=False, visible=ENABLE_REFINER)
+         with gr.Row():
+             guidance_scale_base = gr.Slider(
+                 label="Guidance scale for base",
+                 minimum=1,
+                 maximum=20,
+                 step=0.1,
+                 value=5.0,
+             )
+             num_inference_steps_base = gr.Slider(
+                 label="Number of inference steps for base",
+                 minimum=10,
+                 maximum=100,
+                 step=1,
+                 value=25,
+             )
+         with gr.Row(visible=False) as refiner_params:
+             guidance_scale_refiner = gr.Slider(
+                 label="Guidance scale for refiner",
+                 minimum=1,
+                 maximum=20,
+                 step=0.1,
+                 value=5.0,
+             )
+             num_inference_steps_refiner = gr.Slider(
+                 label="Number of inference steps for refiner",
+                 minimum=10,
+                 maximum=100,
+                 step=1,
+                 value=25,
+             )
+
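+     # With CACHE_EXAMPLES=1 these prompts are generated once at startup and
+     # served from cache afterwards.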
+     gr.Examples(
+         examples=examples,
+         inputs=prompt,
+         outputs=result,
+         fn=generate,
+         cache_examples=CACHE_EXAMPLES,
+     )
+
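+     # Each checkbox only toggles the visibility of its matching textbox;
+     # queue=False applies the update immediately and api_name=False keeps
+     # these cosmetic events out of the public API.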
+     use_negative_prompt.change(
+         fn=lambda x: gr.update(visible=x),
+         inputs=use_negative_prompt,
+         outputs=negative_prompt,
+         queue=False,
+         api_name=False,
+     )
+     use_prompt_2.change(
+         fn=lambda x: gr.update(visible=x),
+         inputs=use_prompt_2,
+         outputs=prompt_2,
+         queue=False,
+         api_name=False,
+     )
+     use_negative_prompt_2.change(
+         fn=lambda x: gr.update(visible=x),
+         inputs=use_negative_prompt_2,
+         outputs=negative_prompt_2,
+         queue=False,
+         api_name=False,
+     )
+     apply_refiner.change(
+         fn=lambda x: gr.update(visible=x),
+         inputs=apply_refiner,
+         outputs=refiner_params,
+         queue=False,
+         api_name=False,
+     )
+
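+     # Enter in any prompt box or a click on Run first resolves the seed
+     # (randomizing it when the checkbox is set), then .then() chains into
+     # generate() so the UI displays the seed that was actually used.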
+     gr.on(
+         triggers=[
+             prompt.submit,
+             negative_prompt.submit,
+             prompt_2.submit,
+             negative_prompt_2.submit,
+             run_button.click,
+         ],
+         fn=randomize_seed_fn,
+         inputs=[seed, randomize_seed],
+         outputs=seed,
+         queue=False,
+         api_name=False,
+     ).then(
+         fn=generate,
+         inputs=[
+             prompt,
+             negative_prompt,
+             prompt_2,
+             negative_prompt_2,
+             use_negative_prompt,
+             use_prompt_2,
+             use_negative_prompt_2,
+             seed,
+             width,
+             height,
+             guidance_scale_base,
+             guidance_scale_refiner,
+             num_inference_steps_base,
+             num_inference_steps_refiner,
+             apply_refiner,
+         ],
+         outputs=result,
+         api_name="run",
+     )
+
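+ # The queue holds at most 20 pending requests; beyond that, new submissions
+ # are turned away instead of piling up on a single GPU.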
+ if __name__ == "__main__":
+     demo.queue(max_size=20).launch()