Nick088 committed
Commit 9c05349 · verified · 1 Parent(s): 60fba13

sdxl flash & stable cascade, improved advanced settings
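
In short: the Space gains two new generation paths. SDXL Flash is a distilled SDXL checkpoint that runs at low step counts and low guidance and wants a single-step DPM solver with "trailing" timestep spacing; Stable Cascade generates in two stages, a prior that emits image embeddings and a decoder that turns them into pixels. A minimal standalone sketch of both paths, assuming a CUDA machine and a diffusers release that ships the Stable Cascade pipelines (prompts and step/guidance values are illustrative, taken from the defaults in the diff below):

import torch
from diffusers import (
    StableDiffusionXLPipeline,
    DPMSolverSinglestepScheduler,
    StableCascadePriorPipeline,
    StableCascadeDecoderPipeline,
)

# SDXL Flash: a distilled SDXL, so few steps and low guidance.
flash = StableDiffusionXLPipeline.from_pretrained("sd-community/sdxl-flash", torch_dtype=torch.float16)
# The distilled checkpoint expects "trailing" timestep spacing on a single-step DPM solver.
flash.scheduler = DPMSolverSinglestepScheduler.from_config(flash.scheduler.config, timestep_spacing="trailing")
flash.enable_model_cpu_offload()
flash_image = flash("A cat", num_inference_steps=8, guidance_scale=3.5).images[0]

# Stable Cascade: stage one (prior) produces image embeddings,
# stage two (decoder) renders them into pixels.
prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", variant="bf16", torch_dtype=torch.bfloat16)
decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", variant="bf16", torch_dtype=torch.float16)
prior.enable_model_cpu_offload()
decoder.enable_model_cpu_offload()

prior_output = prior("A cat", num_inference_steps=25, guidance_scale=4.0)
cascade_image = decoder(
    image_embeddings=prior_output.image_embeddings.to(torch.float16),  # bf16 prior output -> fp16 decoder
    prompt="A cat",
    num_inference_steps=12,
    guidance_scale=0.0,  # the decoder is usually run without classifier-free guidance
).images[0]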

Files changed (1): app.py (+516 -87)
app.py CHANGED
@@ -1,12 +1,13 @@
 import torch
-from diffusers import StableDiffusion3Pipeline, StableDiffusionPipeline, StableDiffusionXLPipeline, DPMSolverSinglestepScheduler
+from diffusers import StableDiffusion3Pipeline, StableDiffusionPipeline, StableDiffusionXLPipeline, DPMSolverSinglestepScheduler, StableCascadePriorPipeline, StableCascadeDecoderPipeline
 import gradio as gr
 import os
 import random
 import numpy as np
+from PIL import Image
 import spaces

-HF_TOKEN = os.getenv("HF_TOKEN")
+HF_TOKEN = os.getenv("HF_TOKEN") # login with hf token to access sd gated models

 if torch.cuda.is_available():
     device = "cuda"
@@ -19,41 +20,42 @@ else:
 MAX_SEED = np.iinfo(np.int32).max

 # Initialize the pipelines for each sd model
-sd3_medium_pipe = StableDiffusion3Pipeline.from_pretrained(
-    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
-)
+sd3_medium_pipe = StableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16)
 sd3_medium_pipe.enable_model_cpu_offload()

-sd2_1_pipe = StableDiffusionPipeline.from_pretrained(
-    "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16
-)
+sd2_1_pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16)
 sd2_1_pipe.enable_model_cpu_offload()

-sdxl_pipe = StableDiffusionXLPipeline.from_pretrained(
-    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
-)
+sdxl_pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
 sdxl_pipe.enable_model_cpu_offload()

-sdxl_flash_pipe = StableDiffusionXLPipeline.from_pretrained(
-    "sd-community/sdxl-flash", torch_dtype=torch.float16
-)
+sdxl_flash_pipe = StableDiffusionXLPipeline.from_pretrained("sd-community/sdxl-flash", torch_dtype=torch.float16)
 sdxl_flash_pipe.enable_model_cpu_offload()
 # Ensure sampler uses "trailing" timesteps for sdxl flash.
 sdxl_flash_pipe.scheduler = DPMSolverSinglestepScheduler.from_config(sdxl_flash_pipe.scheduler.config, timestep_spacing="trailing")

+stable_cascade_prior_pipe = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", variant="bf16", torch_dtype=torch.bfloat16)
+stable_cascade_decoder_pipe = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", variant="bf16", torch_dtype=torch.float16)
+stable_cascade_prior_pipe.enable_model_cpu_offload()
+stable_cascade_decoder_pipe.enable_model_cpu_offload()
+
 # Helper function to generate images for a single model
 @spaces.GPU(duration=80)
 def generate_single_image(
     prompt,
     negative_prompt,
     num_inference_steps,
+    guidance_scale,
     height,
     width,
-    guidance_scale,
     seed,
     num_images_per_prompt,
     model_choice,
     generator,
+    prior_num_inference_steps=None,
+    prior_guidance_scale=None,
+    decoder_num_inference_steps=None,
+    decoder_guidance_scale=None,
 ):
     # Select the correct pipeline based on the model choice
     if model_choice == "sd3 medium":
@@ -64,19 +66,41 @@ def generate_single_image(
         pipe = sdxl_pipe
     elif model_choice == "sdxl flash":
         pipe = sdxl_flash_pipe
+    elif model_choice == "stable cascade":
+        pipe = stable_cascade_prior_pipe
     else:
         raise ValueError(f"Invalid model choice: {model_choice}")

-    output = pipe(
-        prompt=prompt,
-        negative_prompt=negative_prompt,
-        num_inference_steps=num_inference_steps,
-        height=height,
-        width=width,
-        guidance_scale=guidance_scale,
-        generator=generator,
-        num_images_per_prompt=num_images_per_prompt,
-    ).images
+    if model_choice == "stable cascade":
+        prior_output = pipe(
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            num_inference_steps=prior_num_inference_steps,
+            guidance_scale=prior_guidance_scale,
+            height=height,
+            width=width,
+            generator=generator,
+            num_images_per_prompt=num_images_per_prompt,
+        )
+
+        output = stable_cascade_decoder_pipe(
+            image_embeddings=prior_output.image_embeddings.to(torch.float16),
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            num_inference_steps=decoder_num_inference_steps,
+            guidance_scale=decoder_guidance_scale,
+        ).images
+    else:
+        output = pipe(
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            num_inference_steps=num_inference_steps,
+            guidance_scale=guidance_scale,
+            height=height,
+            width=width,
+            generator=generator,
+            num_images_per_prompt=num_images_per_prompt,
+        ).images

     return output

@@ -85,14 +109,24 @@ def generate_single_image(
 def generate_arena_images(
     prompt,
     negative_prompt,
-    num_inference_steps,
+    num_inference_steps_a,
+    guidance_scale_a,
+    num_inference_steps_b,
+    guidance_scale_b,
     height,
     width,
-    guidance_scale,
     seed,
     num_images_per_prompt,
-    model_choice_1,
-    model_choice_2,
+    model_choice_a,
+    model_choice_b,
+    prior_num_inference_steps_a,
+    prior_guidance_scale_a,
+    decoder_num_inference_steps_a,
+    decoder_guidance_scale_a,
+    prior_num_inference_steps_b,
+    prior_guidance_scale_b,
+    decoder_num_inference_steps_b,
+    decoder_guidance_scale_b,
     progress=gr.Progress(track_tqdm=True),
 ):
     if seed == 0:
@@ -101,32 +135,40 @@ def generate_arena_images(
     generator = torch.Generator().manual_seed(seed)

     # Generate images for both models
-    images_1 = generate_single_image(
+    images_a = generate_single_image(
         prompt,
         negative_prompt,
-        num_inference_steps,
+        num_inference_steps_a,
+        guidance_scale_a,
         height,
         width,
-        guidance_scale,
         seed,
         num_images_per_prompt,
-        model_choice_1,
+        model_choice_a,
         generator,
+        prior_num_inference_steps_a,
+        prior_guidance_scale_a,
+        decoder_num_inference_steps_a,
+        decoder_guidance_scale_a,
     )
-    images_2 = generate_single_image(
+    images_b = generate_single_image(
         prompt,
         negative_prompt,
-        num_inference_steps,
+        num_inference_steps_b,
+        guidance_scale_b,
         height,
         width,
-        guidance_scale,
         seed,
         num_images_per_prompt,
-        model_choice_2,
+        model_choice_b,
         generator,
+        prior_num_inference_steps_b,
+        prior_guidance_scale_b,
+        decoder_num_inference_steps_b,
+        decoder_guidance_scale_b,
     )

-    return images_1, images_2
+    return images_a, images_b

 # Define the image generation function for the Individual tab
 @spaces.GPU(duration=80)
@@ -134,12 +176,16 @@ def generate_individual_image(
     prompt,
     negative_prompt,
     num_inference_steps,
+    guidance_scale,
     height,
     width,
-    guidance_scale,
     seed,
     num_images_per_prompt,
     model_choice,
+    prior_num_inference_steps,
+    prior_guidance_scale,
+    decoder_num_inference_steps,
+    decoder_guidance_scale,
     progress=gr.Progress(track_tqdm=True),
 ):
     if seed == 0:
@@ -151,23 +197,100 @@ def generate_individual_image(
         prompt,
         negative_prompt,
         num_inference_steps,
+        guidance_scale,
         height,
         width,
-        guidance_scale,
         seed,
         num_images_per_prompt,
         model_choice,
         generator,
+        prior_num_inference_steps,
+        prior_guidance_scale,
+        decoder_num_inference_steps,
+        decoder_guidance_scale,
     )

     return output


 # Create the Gradio interface
-examples = [
-    ["A white car racing fast to the moon."],
-    ["A woman in a red dress singing on top of a building."],
-    ["An astronaut on mars in a futuristic cyborg suit."],
+examples_arena = [
+    [
+        "A woman in a red dress singing on top of a building.",
+        "deformed, distorted, disfigured, poorly drawn, bad anatomy, incorrect anatomy, extra limb, missing limb, floating limbs, mutated hands and fingers, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation",
+        25,
+        7.5,
+        25,
+        7.5,
+        1024,
+        1024,
+        42,
+        2,
+        "sd3 medium",
+        "sdxl",
+        25, #prior_num_inference_steps_a
+        4.0, #prior_guidance_scale_a
+        12, #decoder_num_inference_steps_a
+        0.0, #decoder_guidance_scale_a
+        25, #prior_num_inference_steps_b
+        4.0, #prior_guidance_scale_b
+        12, #decoder_num_inference_steps_b
+        0.0 #decoder_guidance_scale_b
+    ],
+    [
+        "An astronaut on mars in a futuristic cyborg suit.",
+        "deformed, distorted, disfigured, poorly drawn, bad anatomy, incorrect anatomy, extra limb, missing limb, floating limbs, mutated hands and fingers, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation",
+        25,
+        7.5,
+        25,
+        7.5,
+        1024,
+        1024,
+        42,
+        2,
+        "sd3 medium",
+        "sdxl",
+        25, #prior_num_inference_steps_a
+        4.0, #prior_guidance_scale_a
+        12, #decoder_num_inference_steps_a
+        0.0, #decoder_guidance_scale_a
+        25, #prior_num_inference_steps_b
+        4.0, #prior_guidance_scale_b
+        12, #decoder_num_inference_steps_b
+        0.0 #decoder_guidance_scale_b
+    ],
+]
+examples_individual = [
+    [
+        "A woman in a red dress singing on top of a building.",
+        "deformed, distorted, disfigured, poorly drawn, bad anatomy, incorrect anatomy, extra limb, missing limb, floating limbs, mutated hands and fingers, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation",
+        25,
+        7.5,
+        1024,
+        1024,
+        42,
+        2,
+        "sdxl",
+        25, #prior_num_inference_steps
+        4.0, #prior_guidance_scale
+        12, #decoder_num_inference_steps
+        0.0 #decoder_guidance_scale
+    ],
+    [
+        "An astronaut on mars in a futuristic cyborg suit.",
+        "deformed, distorted, disfigured, poorly drawn, bad anatomy, incorrect anatomy, extra limb, missing limb, floating limbs, mutated hands and fingers, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation",
+        25,
+        7.5,
+        1024,
+        1024,
+        42,
+        2,
+        "sdxl",
+        25, #prior_num_inference_steps
+        4.0, #prior_guidance_scale
+        12, #decoder_num_inference_steps
+        0.0 #decoder_guidance_scale
+    ],
 ]

 css = """
@@ -199,44 +322,137 @@ with gr.Blocks(css=css) as demo:
             info="Describe the image you want",
             placeholder="A cat...",
         )
-        model_choice_1 = gr.Dropdown(
-            label="Stable Diffusion Model 1",
-            choices=["sd3 medium", "sd2.1", "sdxl", "sdxl flash"],
+        model_choice_a = gr.Dropdown(
+            label="Stable Diffusion Model A",
+            choices=["sd3 medium", "sd2.1", "sdxl", "sdxl flash", "stable cascade"],
             value="sd3 medium",
         )
-        model_choice_2 = gr.Dropdown(
-            label="Stable Diffusion Model 2",
-            choices=["sd3 medium", "sd2.1", "sdxl", "sdxl flash"],
+        model_choice_b = gr.Dropdown(
+            label="Stable Diffusion Model B",
+            choices=["sd3 medium", "sd2.1", "sdxl", "sdxl flash", "stable cascade"],
             value="sdxl",
         )
         run_button = gr.Button("Run")
-        result_1 = gr.Gallery(label="Generated Images (Model 1)", elem_id="gallery_1")
-        result_2 = gr.Gallery(label="Generated Images (Model 2)", elem_id="gallery_2")
+        result_1 = gr.Gallery(label="Generated Images (Model A)", elem_id="gallery_1")
+        result_2 = gr.Gallery(label="Generated Images (Model B)", elem_id="gallery_2")
         with gr.Accordion("Advanced options", open=False):
+            negative_prompt = gr.Textbox(
+                label="Negative Prompt",
+                info="Describe what you don't want in the image",
+                value="deformed, distorted, disfigured, poorly drawn, bad anatomy, incorrect anatomy, extra limb, missing limb, floating limbs, mutated hands and fingers, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation",
+                placeholder="Ugly, bad anatomy...",
+            )
             with gr.Row():
-                negative_prompt = gr.Textbox(
-                    label="Negative Prompt",
-                    info="Describe what you don't want in the image",
-                    value="deformed, distorted, disfigured, poorly drawn, bad anatomy, incorrect anatomy, extra limb, missing limb, floating limbs, mutated hands and fingers, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation",
-                    placeholder="Ugly, bad anatomy...",
-                )
-            with gr.Row():
-                num_inference_steps = gr.Slider(
-                    label="Number of Inference Steps",
-                    info="The number of denoising steps of the image. More denoising steps usually lead to a higher quality image at the cost of slower inference",
-                    minimum=1,
-                    maximum=50,
-                    value=25,
-                    step=1,
-                )
-                guidance_scale = gr.Slider(
-                    label="Guidance Scale",
-                    info="Controls how much the image generation process follows the text prompt. Higher values make the image stick more closely to the input text.",
-                    minimum=0.0,
-                    maximum=10.0,
-                    value=7.5,
-                    step=0.1,
-                )
+                with gr.Column():
+                    num_inference_steps_a = gr.Slider(
+                        label="Inference Steps (Model A)",
+                        info="The number of denoising steps of the image. More denoising steps usually lead to a higher quality image at the cost of slower inference",
+                        minimum=1,
+                        maximum=50,
+                        value=25,
+                        step=1,
+                        visible=True
+                    )
+                    guidance_scale_a = gr.Slider(
+                        label="Guidance Scale (Model A)",
+                        info="Controls how much the image generation process follows the text prompt. Higher values make the image stick more closely to the input text.",
+                        minimum=0.0,
+                        maximum=10.0,
+                        value=7.5,
+                        step=0.1,
+                        visible=True
+                    )
+                    prior_num_inference_steps_a = gr.Slider(
+                        label="Prior Inference Steps (Model A)",
+                        info="The number of denoising steps of the image. More denoising steps usually lead to a higher quality image at the cost of slower inference",
+                        minimum=1,
+                        maximum=50,
+                        value=25,
+                        step=1,
+                        visible=False
+                    )
+                    prior_guidance_scale_a = gr.Slider(
+                        label="Prior Guidance Scale (Model A)",
+                        info="Controls how much the image generation process follows the text prompt. Higher values make the image stick more closely to the input text.",
+                        minimum=0.0,
+                        maximum=10.0,
+                        value=4.0,
+                        step=0.1,
+                        visible=False
+                    )
+                    decoder_num_inference_steps_a = gr.Slider(
+                        label="Decoder Inference Steps (Model A)",
+                        info="The number of denoising steps of the image. More denoising steps usually lead to a higher quality image at the cost of slower inference",
+                        minimum=1,
+                        maximum=15,
+                        value=15,
+                        step=1,
+                        visible=False
+                    )
+                    decoder_guidance_scale_a = gr.Slider(
+                        label="Decoder Guidance Scale (Model A)",
+                        info="Controls how much the image generation process follows the text prompt. Higher values make the image stick more closely to the input text.",
+                        minimum=0.0,
+                        maximum=10.0,
+                        value=0.0,
+                        step=0.1,
+                        visible=False
+                    )
+                with gr.Column():
+                    num_inference_steps_b = gr.Slider(
+                        label="Inference Steps (Model B)",
+                        info="The number of denoising steps of the image. More denoising steps usually lead to a higher quality image at the cost of slower inference",
+                        minimum=1,
+                        maximum=50,
+                        value=25,
+                        step=1,
+                        visible=True
+                    )
+                    guidance_scale_b = gr.Slider(
+                        label="Guidance Scale (Model B)",
+                        info="Controls how much the image generation process follows the text prompt. Higher values make the image stick more closely to the input text.",
+                        minimum=0.0,
+                        maximum=10.0,
+                        value=7.5,
+                        step=0.1,
+                        visible=True
+                    )
+                    prior_num_inference_steps_b = gr.Slider(
+                        label="Prior Inference Steps (Model B)",
+                        info="The number of denoising steps of the image. More denoising steps usually lead to a higher quality image at the cost of slower inference",
+                        minimum=1,
+                        maximum=50,
+                        value=25,
+                        step=1,
+                        visible=False
+                    )
+                    prior_guidance_scale_b = gr.Slider(
+                        label="Prior Guidance Scale (Model B)",
+                        info="Controls how much the image generation process follows the text prompt. Higher values make the image stick more closely to the input text.",
+                        minimum=0.0,
+                        maximum=10.0,
+                        value=4.0,
+                        step=0.1,
+                        visible=False
+                    )
+                    decoder_num_inference_steps_b = gr.Slider(
+                        label="Decoder Inference Steps (Model B)",
+                        info="The number of denoising steps of the image. More denoising steps usually lead to a higher quality image at the cost of slower inference",
+                        minimum=1,
+                        maximum=15,
+                        value=12,
+                        step=1,
+                        visible=False
+                    )
+                    decoder_guidance_scale_b = gr.Slider(
+                        label="Decoder Guidance Scale (Model B)",
+                        info="Controls how much the image generation process follows the text prompt. Higher values make the image stick more closely to the input text.",
+                        minimum=0.0,
+                        maximum=10.0,
+                        value=0.0,
+                        step=0.1,
+                        visible=False
+                    )
             with gr.Row():
                 width = gr.Slider(
                     label="Width",
@@ -272,9 +488,114 @@ with gr.Blocks(css=css) as demo:
                    value=2,
                )

+        def toggle_visibility_arena_a(model_choice_a):
+            if model_choice_a == "stable cascade":
+                return {
+                    num_inference_steps_a: gr.update(visible=False),
+                    guidance_scale_a: gr.update(visible=False),
+                    prior_num_inference_steps_a: gr.update(visible=True),
+                    prior_guidance_scale_a: gr.update(visible=True),
+                    decoder_num_inference_steps_a: gr.update(visible=True),
+                    decoder_guidance_scale_a: gr.update(visible=True),
+                }
+            elif model_choice_a == "sdxl flash":
+                return {
+                    num_inference_steps_a: gr.update(visible=True, maximum=15, value=8),
+                    guidance_scale_a: gr.update(visible=True, maximum=6.0, value=3.5),
+                    prior_num_inference_steps_a: gr.update(visible=False),
+                    prior_guidance_scale_a: gr.update(visible=False),
+                    decoder_num_inference_steps_a: gr.update(visible=False),
+                    decoder_guidance_scale_a: gr.update(visible=False),
+                }
+            else:
+                return {
+                    num_inference_steps_a: gr.update(visible=True, maximum=50, value=25),
+                    guidance_scale_a: gr.update(visible=True, maximum=10.0, value=7.5),
+                    prior_num_inference_steps_a: gr.update(visible=False),
+                    prior_guidance_scale_a: gr.update(visible=False),
+                    decoder_num_inference_steps_a: gr.update(visible=False),
+                    decoder_guidance_scale_a: gr.update(visible=False),
+                }
+
+        def toggle_visibility_arena_b(model_choice_b):
+            if model_choice_b == "stable cascade":
+                return {
+                    num_inference_steps_b: gr.update(visible=False),
+                    guidance_scale_b: gr.update(visible=False),
+                    prior_num_inference_steps_b: gr.update(visible=True),
+                    prior_guidance_scale_b: gr.update(visible=True),
+                    decoder_num_inference_steps_b: gr.update(visible=True),
+                    decoder_guidance_scale_b: gr.update(visible=True),
+                }
+            elif model_choice_b == "sdxl flash":
+                return {
+                    num_inference_steps_b: gr.update(visible=True, maximum=15, value=8),
+                    guidance_scale_b: gr.update(visible=True, maximum=6.0, value=3.5),
+                    prior_num_inference_steps_b: gr.update(visible=False),
+                    prior_guidance_scale_b: gr.update(visible=False),
+                    decoder_num_inference_steps_b: gr.update(visible=False),
+                    decoder_guidance_scale_b: gr.update(visible=False),
+                }
+            else:
+                return {
+                    num_inference_steps_b: gr.update(visible=True, maximum=50, value=25),
+                    guidance_scale_b: gr.update(visible=True, maximum=10.0, value=7.5),
+                    prior_num_inference_steps_b: gr.update(visible=False),
+                    prior_guidance_scale_b: gr.update(visible=False),
+                    decoder_num_inference_steps_b: gr.update(visible=False),
+                    decoder_guidance_scale_b: gr.update(visible=False),
+                }
+
+        model_choice_a.change(
+            toggle_visibility_arena_a,
+            inputs=[model_choice_a],
+            outputs=[
+                num_inference_steps_a,
+                guidance_scale_a,
+                prior_num_inference_steps_a,
+                prior_guidance_scale_a,
+                decoder_num_inference_steps_a,
+                decoder_guidance_scale_a
+            ]
+        )
+        model_choice_b.change(
+            toggle_visibility_arena_b,
+            inputs=[model_choice_b],
+            outputs=[
+                num_inference_steps_b,
+                guidance_scale_b,
+                prior_num_inference_steps_b,
+                prior_guidance_scale_b,
+                decoder_num_inference_steps_b,
+                decoder_guidance_scale_b
+            ]
+        )
+
+
         gr.Examples(
-            examples=examples,
-            inputs=[prompt],
+            examples=examples_arena,
+            inputs=[
+                prompt,
+                negative_prompt,
+                num_inference_steps_a,
+                guidance_scale_a,
+                num_inference_steps_b,
+                guidance_scale_b,
+                height,
+                width,
+                seed,
+                num_images_per_prompt,
+                model_choice_a,
+                model_choice_b,
+                prior_num_inference_steps_a,
+                prior_guidance_scale_a,
+                decoder_num_inference_steps_a,
+                decoder_guidance_scale_a,
+                prior_num_inference_steps_b,
+                prior_guidance_scale_b,
+                decoder_num_inference_steps_b,
+                decoder_guidance_scale_b,
+            ],
             outputs=[result_1, result_2],
             fn=generate_arena_images,
         )
@@ -288,14 +609,24 @@ with gr.Blocks(css=css) as demo:
            inputs=[
                prompt,
                negative_prompt,
-                num_inference_steps,
-                width,
+                num_inference_steps_a,
+                guidance_scale_a,
+                num_inference_steps_b,
+                guidance_scale_b,
                height,
-                guidance_scale,
+                width,
                seed,
                num_images_per_prompt,
-                model_choice_1,
-                model_choice_2,
+                model_choice_a,
+                model_choice_b,
+                prior_num_inference_steps_a,
+                prior_guidance_scale_a,
+                decoder_num_inference_steps_a,
+                decoder_guidance_scale_a,
+                prior_num_inference_steps_b,
+                prior_guidance_scale_b,
+                decoder_num_inference_steps_b,
+                decoder_guidance_scale_b,
            ],
            outputs=[result_1, result_2],
        )
@@ -310,7 +641,7 @@ with gr.Blocks(css=css) as demo:
        )
        model_choice = gr.Dropdown(
            label="Stable Diffusion Model",
-            choices=["sd3 medium", "sd2.1", "sdxl", "sdxl flash"],
+            choices=["sd3 medium", "sd2.1", "sdxl", "sdxl flash", "stable cascade"],
            value="sd3 medium",
        )
        run_button = gr.Button("Run")
@@ -331,6 +662,7 @@ with gr.Blocks(css=css) as demo:
                    maximum=50,
                    value=25,
                    step=1,
+                    visible=True
                )
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
@@ -339,6 +671,43 @@ with gr.Blocks(css=css) as demo:
                    maximum=10.0,
                    value=7.5,
                    step=0.1,
+                    visible=True
+                )
+                prior_num_inference_steps = gr.Slider(
+                    label="Prior Inference Steps",
+                    info="The number of denoising steps of the image. More denoising steps usually lead to a higher quality image at the cost of slower inference",
+                    minimum=1,
+                    maximum=50,
+                    value=25,
+                    step=1,
+                    visible=False
+                )
+                prior_guidance_scale = gr.Slider(
+                    label="Prior Guidance Scale",
+                    info="Controls how much the image generation process follows the text prompt. Higher values make the image stick more closely to the input text.",
+                    minimum=0.0,
+                    maximum=10.0,
+                    value=4.0,
+                    step=0.1,
+                    visible=False
+                )
+                decoder_num_inference_steps = gr.Slider(
+                    label="Decoder Inference Steps",
+                    info="The number of denoising steps of the image. More denoising steps usually lead to a higher quality image at the cost of slower inference",
+                    minimum=1,
+                    maximum=15,
+                    value=12,
+                    step=1,
+                    visible=False
+                )
+                decoder_guidance_scale = gr.Slider(
+                    label="Decoder Guidance Scale",
+                    info="Controls how much the image generation process follows the text prompt. Higher values make the image stick more closely to the input text.",
+                    minimum=0.0,
+                    maximum=10.0,
+                    value=0.0,
+                    step=0.1,
+                    visible=False
                )
            with gr.Row():
                width = gr.Slider(
@@ -375,9 +744,65 @@ with gr.Blocks(css=css) as demo:
                    value=2,
                )

+        def toggle_visibility_individual(model_choice):
+            if model_choice == "stable cascade":
+                return {
+                    num_inference_steps: gr.update(visible=False),
+                    guidance_scale: gr.update(visible=False),
+                    prior_num_inference_steps: gr.update(visible=True),
+                    prior_guidance_scale: gr.update(visible=True),
+                    decoder_num_inference_steps: gr.update(visible=True),
+                    decoder_guidance_scale: gr.update(visible=True),
+                }
+            elif model_choice == "sdxl flash":
+                return {
+                    num_inference_steps: gr.update(visible=True, maximum=15, value=8),
+                    guidance_scale: gr.update(visible=True, maximum=6.0, value=3.5),
+                    prior_num_inference_steps: gr.update(visible=False),
+                    prior_guidance_scale: gr.update(visible=False),
+                    decoder_num_inference_steps: gr.update(visible=False),
+                    decoder_guidance_scale: gr.update(visible=False),
+                }
+            else:
+                return {
+                    num_inference_steps: gr.update(visible=True, maximum=50, value=25),
+                    guidance_scale: gr.update(visible=True, maximum=10.0, value=7.5),
+                    prior_num_inference_steps: gr.update(visible=False),
+                    prior_guidance_scale: gr.update(visible=False),
+                    decoder_num_inference_steps: gr.update(visible=False),
+                    decoder_guidance_scale: gr.update(visible=False),
+                }
+
+        model_choice.change(
+            toggle_visibility_individual,
+            inputs=[model_choice],
+            outputs=[
+                num_inference_steps,
+                guidance_scale,
+                prior_num_inference_steps,
+                prior_guidance_scale,
+                decoder_num_inference_steps,
+                decoder_guidance_scale
+            ]
+        )
+
        gr.Examples(
-            examples=examples,
-            inputs=[prompt],
+            examples=examples_individual,
+            inputs=[
+                prompt,
+                negative_prompt,
+                num_inference_steps,
+                guidance_scale,
+                height,
+                width,
+                seed,
+                num_images_per_prompt,
+                model_choice,
+                prior_num_inference_steps,
+                prior_guidance_scale,
+                decoder_num_inference_steps,
+                decoder_guidance_scale,
+            ],
            outputs=[result],
            fn=generate_individual_image,
        )
@@ -392,12 +817,16 @@ with gr.Blocks(css=css) as demo:
                prompt,
                negative_prompt,
                num_inference_steps,
-                width,
-                height,
                guidance_scale,
+                height,
+                width,
                seed,
                num_images_per_prompt,
                model_choice,
+                prior_num_inference_steps,
+                prior_guidance_scale,
+                decoder_num_inference_steps,
+                decoder_guidance_scale,
            ],
            outputs=[result],
        )
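
The "improved advanced settings" half of the change comes down to one Gradio pattern: each model dropdown's .change event returns gr.update(...) values that show, hide, or re-bound the per-model sliders. A stripped-down sketch of that pattern (component names and ranges here are illustrative, not the app's full set):

import gradio as gr

with gr.Blocks() as demo:
    model = gr.Dropdown(["sdxl", "sdxl flash", "stable cascade"], value="sdxl", label="Model")
    steps = gr.Slider(1, 50, value=25, step=1, label="Inference Steps")
    prior_steps = gr.Slider(1, 50, value=25, step=1, label="Prior Inference Steps", visible=False)

    def toggle(choice):
        if choice == "stable cascade":
            # two-stage model: hide the single-stage slider, reveal the prior-stage one
            return {steps: gr.update(visible=False), prior_steps: gr.update(visible=True)}
        if choice == "sdxl flash":
            # distilled model: clamp the step range and lower the default
            return {steps: gr.update(visible=True, maximum=15, value=8), prior_steps: gr.update(visible=False)}
        return {steps: gr.update(visible=True, maximum=50, value=25), prior_steps: gr.update(visible=False)}

    # Re-render the sliders whenever the model selection changes.
    model.change(toggle, inputs=[model], outputs=[steps, prior_steps])

demo.launch()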