Nick088 commited on
Commit
7c42710
·
verified ·
1 Parent(s): 43cbce0

code optimization

Browse files

free up cuda chache + gc + load models on cpu only when needed and unload after used + saving up on zerogpu duration

Files changed (1) hide show
  1. app.py +234 -210
app.py CHANGED
@@ -3,8 +3,9 @@ from diffusers import StableDiffusion3Pipeline, StableDiffusionPipeline, StableD
3
  import gradio as gr
4
  import os
5
  import random
6
- import numpy as np
7
  from PIL import Image
 
8
  import spaces
9
 
10
  HF_TOKEN = os.getenv("HF_TOKEN") # login with hf read token to access sd gated models
@@ -17,61 +18,65 @@ else:
17
  print("Using CPU")
18
 
19
 
20
- MAX_SEED = np.iinfo(np.int32).max
21
 
22
- # Initialize the pipelines for each sd model
 
23
 
24
- # sd3 medium
25
- sd3_medium_pipe = StableDiffusion3Pipeline.from_pretrained(
26
- "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
27
- )
28
- sd3_medium_pipe.enable_model_cpu_offload()
29
-
30
- # sd 2.1
31
- sd2_1_pipe = StableDiffusionPipeline.from_pretrained(
32
- "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16
33
- )
34
- sd2_1_pipe.enable_model_cpu_offload()
35
-
36
- # sdxl
37
- sdxl_pipe = StableDiffusionXLPipeline.from_pretrained(
38
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
39
- )
40
- sdxl_pipe.enable_model_cpu_offload()
41
 
42
- # sdxl flash
43
- sdxl_flash_pipe = StableDiffusionXLPipeline.from_pretrained(
44
- "sd-community/sdxl-flash", torch_dtype=torch.float16
45
- )
46
- sdxl_flash_pipe.enable_model_cpu_offload()
47
- # Ensure sampler uses "trailing" timesteps for sdxl flash.
48
- sdxl_flash_pipe.scheduler = DPMSolverSinglestepScheduler.from_config(
49
- sdxl_flash_pipe.scheduler.config, timestep_spacing="trailing"
50
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
- # stable cascade
53
- stable_cascade_prior_pipe = StableCascadePriorPipeline.from_pretrained(
54
- "stabilityai/stable-cascade-prior", variant="bf16", torch_dtype=torch.bfloat16
55
- )
56
- stable_cascade_prior_pipe.enable_model_cpu_offload()
57
- stable_cascade_decoder_pipe = StableCascadeDecoderPipeline.from_pretrained(
58
- "stabilityai/stable-cascade", variant="bf16", torch_dtype=torch.float16
59
- )
60
- stable_cascade_decoder_pipe.enable_model_cpu_offload()
61
 
62
- # sd 1.5
63
- sd1_5_pipe = StableDiffusionPipeline.from_pretrained(
64
- "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
65
- )
66
- sd1_5_pipe.enable_model_cpu_offload()
67
 
68
- # empty cache to free up gpu memory before inference
69
- torch.cuda.empty_cache()
 
 
 
 
 
70
 
71
- # Helper function to generate images for a single model
72
  @spaces.GPU(duration=80)
73
- def generate_single_image(
74
  prompt,
 
75
  negative_prompt,
76
  num_inference_steps,
77
  guidance_scale,
@@ -79,71 +84,114 @@ def generate_single_image(
79
  width,
80
  seed,
81
  num_images_per_prompt,
82
- model_choice,
83
- generator,
84
  prior_num_inference_steps=None,
85
  prior_guidance_scale=None,
86
  decoder_num_inference_steps=None,
87
  decoder_guidance_scale=None,
88
  ):
89
- # Select the correct pipeline based on the model choice
90
- if model_choice == "sd3 medium":
91
- pipe = sd3_medium_pipe
92
- elif model_choice == "sd2.1":
93
- pipe = sd2_1_pipe
94
- elif model_choice == "sdxl":
95
- pipe = sdxl_pipe
96
- elif model_choice == "sdxl flash":
97
- pipe = sdxl_flash_pipe
98
- elif model_choice == "stable cascade":
99
- pipe = stable_cascade_prior_pipe
100
- elif model_choice == "sd1.5":
101
- pipe = sd1_5_pipe
102
- else:
103
- raise ValueError(f"Invalid model choice: {model_choice}")
104
 
105
- # stable cascade has 2 different type of pipelines
106
- if model_choice == "stable cascade":
107
- prior_output = pipe(
108
- prompt=prompt,
109
- negative_prompt=negative_prompt,
110
- num_inference_steps=prior_num_inference_steps,
111
- guidance_scale=prior_guidance_scale,
112
- height=height,
113
- width=width,
114
- generator=generator,
115
- num_images_per_prompt=num_images_per_prompt,
116
  )
117
 
118
- output = stable_cascade_decoder_pipe(
119
- image_embeddings=prior_output.image_embeddings.to(torch.float16),
120
- prompt=prompt,
121
- negative_prompt=negative_prompt,
122
- num_inference_steps=decoder_num_inference_steps,
123
- guidance_scale=decoder_guidance_scale,
124
- ).images
125
 
126
- # the rest of the models have similar pipeline
127
- else:
128
- output = pipe(
129
- prompt=prompt,
130
- negative_prompt=negative_prompt,
131
- num_inference_steps=num_inference_steps,
132
- guidance_scale=guidance_scale,
133
- height=height,
134
- width=width,
135
- generator=generator,
136
- num_images_per_prompt=num_images_per_prompt,
137
- ).images
138
 
139
- # empty cache to free up gpu memory
140
- torch.cuda.empty_cache()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
 
142
  return output
143
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
 
145
  # Define the image generation function for the Arena tab
146
- @spaces.GPU(duration=240)
147
  def generate_arena_images(
148
  prompt,
149
  negative_prompt,
@@ -188,15 +236,12 @@ def generate_arena_images(
188
  decoder_guidance_scale_d,
189
  progress=gr.Progress(track_tqdm=True),
190
  ):
191
- if seed == 0:
192
- seed = random.randint(1, MAX_SEED)
193
-
194
- generator = torch.Generator().manual_seed(seed)
195
 
196
  # Generate images for selected models
197
  if num_models_to_compare >= 2:
198
  images_a = generate_single_image(
199
  prompt,
 
200
  negative_prompt,
201
  num_inference_steps_a,
202
  guidance_scale_a,
@@ -204,8 +249,6 @@ def generate_arena_images(
204
  width_a,
205
  seed,
206
  num_images_per_prompt,
207
- model_choice_a,
208
- generator,
209
  prior_num_inference_steps_a,
210
  prior_guidance_scale_a,
211
  decoder_num_inference_steps_a,
@@ -213,6 +256,7 @@ def generate_arena_images(
213
  )
214
  images_b = generate_single_image(
215
  prompt,
 
216
  negative_prompt,
217
  num_inference_steps_b,
218
  guidance_scale_b,
@@ -220,8 +264,6 @@ def generate_arena_images(
220
  width_b,
221
  seed,
222
  num_images_per_prompt,
223
- model_choice_b,
224
- generator,
225
  prior_num_inference_steps_b,
226
  prior_guidance_scale_b,
227
  decoder_num_inference_steps_b,
@@ -233,6 +275,7 @@ def generate_arena_images(
233
  if num_models_to_compare >= 3:
234
  images_c = generate_single_image(
235
  prompt,
 
236
  negative_prompt,
237
  num_inference_steps_c,
238
  guidance_scale_c,
@@ -240,8 +283,6 @@ def generate_arena_images(
240
  width_c,
241
  seed,
242
  num_images_per_prompt,
243
- model_choice_c,
244
- generator,
245
  prior_num_inference_steps_c,
246
  prior_guidance_scale_c,
247
  decoder_num_inference_steps_c,
@@ -253,6 +294,7 @@ def generate_arena_images(
253
  if num_models_to_compare >= 4:
254
  images_d = generate_single_image(
255
  prompt,
 
256
  negative_prompt,
257
  num_inference_steps_d,
258
  guidance_scale_d,
@@ -260,8 +302,6 @@ def generate_arena_images(
260
  width_d,
261
  seed,
262
  num_images_per_prompt,
263
- model_choice_d,
264
- generator,
265
  prior_num_inference_steps_d,
266
  prior_guidance_scale_d,
267
  decoder_num_inference_steps_d,
@@ -274,9 +314,9 @@ def generate_arena_images(
274
 
275
 
276
  # Define the image generation function for the Individual tab
277
- @spaces.GPU(duration=90)
278
  def generate_individual_image(
279
  prompt,
 
280
  negative_prompt,
281
  num_inference_steps,
282
  guidance_scale,
@@ -284,20 +324,16 @@ def generate_individual_image(
284
  width,
285
  seed,
286
  num_images_per_prompt,
287
- model_choice,
288
  prior_num_inference_steps,
289
  prior_guidance_scale,
290
  decoder_num_inference_steps,
291
  decoder_guidance_scale,
292
  progress=gr.Progress(track_tqdm=True),
293
  ):
294
- if seed == 0:
295
- seed = random.randint(1, MAX_SEED)
296
-
297
- generator = torch.Generator().manual_seed(seed)
298
 
299
  output = generate_single_image(
300
  prompt,
 
301
  negative_prompt,
302
  num_inference_steps,
303
  guidance_scale,
@@ -305,8 +341,6 @@ def generate_individual_image(
305
  width,
306
  seed,
307
  num_images_per_prompt,
308
- model_choice,
309
- generator,
310
  prior_num_inference_steps,
311
  prior_guidance_scale,
312
  decoder_num_inference_steps,
@@ -630,18 +664,18 @@ with gr.Blocks(theme=theme, css=css) as demo:
630
  width_a = gr.Slider(
631
  label="Width (Model A)",
632
  info="Width of the Image",
633
- minimum=256,
634
- maximum=1344,
635
- step=32,
636
  value=1024,
 
637
  )
638
  height_a = gr.Slider(
639
  label="Height (Model A)",
640
  info="Height of the Image",
641
- minimum=256,
642
- maximum=1344,
643
- step=32,
644
  value=1024,
 
645
  )
646
  with gr.Column():
647
  num_inference_steps_b = gr.Slider(
@@ -650,7 +684,7 @@ with gr.Blocks(theme=theme, css=css) as demo:
650
  minimum=1,
651
  maximum=50,
652
  value=25,
653
- step=1,
654
  visible=True,
655
  )
656
  guidance_scale_b = gr.Slider(
@@ -701,18 +735,18 @@ with gr.Blocks(theme=theme, css=css) as demo:
701
  width_b = gr.Slider(
702
  label="Width (Model B)",
703
  info="Width of the Image",
704
- minimum=256,
705
- maximum=1344,
706
- step=32,
707
  value=1024,
 
708
  )
709
  height_b = gr.Slider(
710
  label="Height (Model B)",
711
  info="Height of the Image",
712
- minimum=256,
713
- maximum=1344,
714
- step=32,
715
  value=1024,
 
716
  )
717
  with gr.Column(visible=False) as model_c_options:
718
  num_inference_steps_c = gr.Slider(
@@ -772,18 +806,18 @@ with gr.Blocks(theme=theme, css=css) as demo:
772
  width_c = gr.Slider(
773
  label="Width (Model C)",
774
  info="Width of the Image",
775
- minimum=256,
776
- maximum=1344,
777
- step=32,
778
  value=1024,
 
779
  )
780
  height_c = gr.Slider(
781
  label="Height (Model C)",
782
  info="Height of the Image",
783
- minimum=256,
784
- maximum=1344,
785
- step=32,
786
  value=1024,
 
787
  )
788
  with gr.Column(visible=False) as model_d_options:
789
  num_inference_steps_d = gr.Slider(
@@ -843,18 +877,18 @@ with gr.Blocks(theme=theme, css=css) as demo:
843
  width_d = gr.Slider(
844
  label="Width (Model D)",
845
  info="Width of the Image",
846
- minimum=256,
847
- maximum=1344,
848
- step=32,
849
  value=1024,
 
850
  )
851
  height_d = gr.Slider(
852
  label="Height (Model D)",
853
  info="Height of the Image",
854
- minimum=256,
855
- maximum=1344,
856
- step=32,
857
  value=1024,
 
858
  )
859
  with gr.Row():
860
  seed = gr.Slider(
@@ -883,6 +917,8 @@ with gr.Blocks(theme=theme, css=css) as demo:
883
  prior_guidance_scale_a: gr.update(visible=True),
884
  decoder_num_inference_steps_a: gr.update(visible=True),
885
  decoder_guidance_scale_a: gr.update(visible=True),
 
 
886
  }
887
  elif model_choice_a == "sdxl flash":
888
  return {
@@ -892,6 +928,8 @@ with gr.Blocks(theme=theme, css=css) as demo:
892
  prior_guidance_scale_a: gr.update(visible=False),
893
  decoder_num_inference_steps_a: gr.update(visible=False),
894
  decoder_guidance_scale_a: gr.update(visible=False),
 
 
895
  }
896
  elif model_choice_a == "sd1.5":
897
  return {
@@ -900,26 +938,8 @@ with gr.Blocks(theme=theme, css=css) as demo:
900
  prior_guidance_scale_a: gr.update(visible=True),
901
  decoder_num_inference_steps_a: gr.update(visible=True),
902
  decoder_guidance_scale_a: gr.update(visible=True),
903
- }
904
- elif model_choice_a == "sdxl flash":
905
- return {
906
- num_inference_steps_a: gr.update(visible=True, maximum=15, value=8),
907
- guidance_scale_a: gr.update(visible=True, maximum=6.0, value=3.5),
908
- prior_num_inference_steps_a: gr.update(visible=False),
909
- prior_guidance_scale_a: gr.update(visible=False),
910
- decoder_num_inference_steps_a: gr.update(visible=False),
911
- decoder_guidance_scale_a: gr.update(visible=False),
912
- }
913
- elif model_choice_a == "sd1.5":
914
- return {
915
- num_inference_steps_a: gr.update(visible=True, maximum=50, value=25),
916
- guidance_scale_a: gr.update(visible=True, maximum=10.0, value=7.5),
917
- prior_num_inference_steps_a: gr.update(visible=False),
918
- prior_guidance_scale_a: gr.update(visible=False),
919
- decoder_num_inference_steps_a: gr.update(visible=False),
920
- decoder_guidance_scale_a: gr.update(visible=False),
921
- width_a: gr.update(value=512, maximum=768),
922
- height_a: gr.update(value=512, maximum=768),
923
  }
924
  elif model_choice_a == "sd2.1":
925
  return {
@@ -929,8 +949,8 @@ with gr.Blocks(theme=theme, css=css) as demo:
929
  prior_guidance_scale_a: gr.update(visible=False),
930
  decoder_num_inference_steps_a: gr.update(visible=False),
931
  decoder_guidance_scale_a: gr.update(visible=False),
932
- width_a: gr.update(value=768, maximum=1024),
933
- height_a: gr.update(value=768, maximum=1024),
934
  }
935
  else:
936
  return {
@@ -940,8 +960,8 @@ with gr.Blocks(theme=theme, css=css) as demo:
940
  prior_guidance_scale_a: gr.update(visible=False),
941
  decoder_num_inference_steps_a: gr.update(visible=False),
942
  decoder_guidance_scale_a: gr.update(visible=False),
943
- width_a: gr.update(maximum=1344),
944
- height_a: gr.update(maximum=1344),
945
  }
946
 
947
  def toggle_visibility_arena_b(model_choice_b):
@@ -953,6 +973,8 @@ with gr.Blocks(theme=theme, css=css) as demo:
953
  prior_guidance_scale_b: gr.update(visible=True),
954
  decoder_num_inference_steps_b: gr.update(visible=True),
955
  decoder_guidance_scale_b: gr.update(visible=True),
 
 
956
  }
957
  elif model_choice_b == "sdxl flash":
958
  return {
@@ -962,6 +984,8 @@ with gr.Blocks(theme=theme, css=css) as demo:
962
  prior_guidance_scale_b: gr.update(visible=False),
963
  decoder_num_inference_steps_b: gr.update(visible=False),
964
  decoder_guidance_scale_b: gr.update(visible=False),
 
 
965
  }
966
  elif model_choice_b == "sd1.5":
967
  return {
@@ -971,8 +995,8 @@ with gr.Blocks(theme=theme, css=css) as demo:
971
  prior_guidance_scale_b: gr.update(visible=False),
972
  decoder_num_inference_steps_b: gr.update(visible=False),
973
  decoder_guidance_scale_b: gr.update(visible=False),
974
- width_b: gr.update(value=512, maximum=768),
975
- height_b: gr.update(value=512, maximum=768),
976
  }
977
  elif model_choice_b == "sd2.1":
978
  return {
@@ -982,8 +1006,8 @@ with gr.Blocks(theme=theme, css=css) as demo:
982
  prior_guidance_scale_b: gr.update(visible=False),
983
  decoder_num_inference_steps_b: gr.update(visible=False),
984
  decoder_guidance_scale_b: gr.update(visible=False),
985
- width_b: gr.update(value=768, maximum=1024),
986
- height_b: gr.update(value=768, maximum=1024),
987
  }
988
  else:
989
  return {
@@ -993,8 +1017,8 @@ with gr.Blocks(theme=theme, css=css) as demo:
993
  prior_guidance_scale_b: gr.update(visible=False),
994
  decoder_num_inference_steps_b: gr.update(visible=False),
995
  decoder_guidance_scale_b: gr.update(visible=False),
996
- width_b: gr.update(maximum=1344),
997
- height_b: gr.update(maximum=1344),
998
  }
999
 
1000
  def toggle_visibility_arena_c(model_choice_c):
@@ -1006,8 +1030,8 @@ with gr.Blocks(theme=theme, css=css) as demo:
1006
  prior_guidance_scale_c: gr.update(visible=True),
1007
  decoder_num_inference_steps_c: gr.update(visible=True),
1008
  decoder_guidance_scale_c: gr.update(visible=True),
1009
- width_c: gr.update(value=1024, maximum=1344),
1010
- height_c: gr.update(value=1024, maximum=1344),
1011
  }
1012
  elif model_choice_c == "sdxl flash":
1013
  return {
@@ -1017,8 +1041,8 @@ with gr.Blocks(theme=theme, css=css) as demo:
1017
  prior_guidance_scale_c: gr.update(visible=False),
1018
  decoder_num_inference_steps_c: gr.update(visible=False),
1019
  decoder_guidance_scale_c: gr.update(visible=False),
1020
- width_c: gr.update(value=1024, maximum=1344),
1021
- height_c: gr.update(value=1024, maximum=1344),
1022
  }
1023
  elif model_choice_c == "sd1.5":
1024
  return {
@@ -1028,8 +1052,8 @@ with gr.Blocks(theme=theme, css=css) as demo:
1028
  prior_guidance_scale_c: gr.update(visible=False),
1029
  decoder_num_inference_steps_c: gr.update(visible=False),
1030
  decoder_guidance_scale_c: gr.update(visible=False),
1031
- width_c: gr.update(value=512, maximum=768),
1032
- height_c: gr.update(value=512, maximum=768),
1033
  }
1034
  elif model_choice_c == "sd2.1":
1035
  return {
@@ -1039,8 +1063,8 @@ with gr.Blocks(theme=theme, css=css) as demo:
1039
  prior_guidance_scale_c: gr.update(visible=False),
1040
  decoder_num_inference_steps_c: gr.update(visible=False),
1041
  decoder_guidance_scale_c: gr.update(visible=False),
1042
- width_c: gr.update(value=768, maximum=1024),
1043
- height_c: gr.update(value=768, maximum=1024),
1044
  }
1045
  else:
1046
  return {
@@ -1050,8 +1074,8 @@ with gr.Blocks(theme=theme, css=css) as demo:
1050
  prior_guidance_scale_c: gr.update(visible=False),
1051
  decoder_num_inference_steps_c: gr.update(visible=False),
1052
  decoder_guidance_scale_c: gr.update(visible=False),
1053
- width_c: gr.update(value=1024, maximum=1344),
1054
- height_c: gr.update(value=1024, maximum=1344),
1055
  }
1056
 
1057
  def toggle_visibility_arena_d(model_choice_d):
@@ -1063,8 +1087,8 @@ with gr.Blocks(theme=theme, css=css) as demo:
1063
  prior_guidance_scale_d: gr.update(visible=True),
1064
  decoder_num_inference_steps_d: gr.update(visible=True),
1065
  decoder_guidance_scale_d: gr.update(visible=True),
1066
- width_d: gr.update(value=1024, maximum=1344),
1067
- height_d: gr.update(value=1024, maximum=1344),
1068
  }
1069
  elif model_choice_d == "sdxl flash":
1070
  return {
@@ -1074,8 +1098,8 @@ with gr.Blocks(theme=theme, css=css) as demo:
1074
  prior_guidance_scale_d: gr.update(visible=False),
1075
  decoder_num_inference_steps_d: gr.update(visible=False),
1076
  decoder_guidance_scale_d: gr.update(visible=False),
1077
- width_d: gr.update(value=1024, maximum=1344),
1078
- height_d: gr.update(value=1024, maximum=1344),
1079
  }
1080
  elif model_choice_d == "sd1.5":
1081
  return {
@@ -1085,8 +1109,8 @@ with gr.Blocks(theme=theme, css=css) as demo:
1085
  prior_guidance_scale_d: gr.update(visible=False),
1086
  decoder_num_inference_steps_d: gr.update(visible=False),
1087
  decoder_guidance_scale_d: gr.update(visible=False),
1088
- width_d: gr.update(value=512, maximum=768),
1089
- height_d: gr.update(value=512, maximum=768),
1090
  }
1091
  elif model_choice_d == "sd2.1":
1092
  return {
@@ -1096,8 +1120,8 @@ with gr.Blocks(theme=theme, css=css) as demo:
1096
  prior_guidance_scale_d: gr.update(visible=False),
1097
  decoder_num_inference_steps_d: gr.update(visible=False),
1098
  decoder_guidance_scale_d: gr.update(visible=False),
1099
- width_d: gr.update(value=768, maximum=1024),
1100
- height_d: gr.update(value=768, maximum=1024),
1101
  }
1102
  else:
1103
  return {
@@ -1107,8 +1131,8 @@ with gr.Blocks(theme=theme, css=css) as demo:
1107
  prior_guidance_scale_d: gr.update(visible=False),
1108
  decoder_num_inference_steps_d: gr.update(visible=False),
1109
  decoder_guidance_scale_d: gr.update(visible=False),
1110
- width_d: gr.update(value=1024, maximum=1344),
1111
- height_d: gr.update(value=1024, maximum=1344),
1112
  }
1113
 
1114
  model_choice_a.change(
@@ -1402,18 +1426,18 @@ with gr.Blocks(theme=theme, css=css) as demo:
1402
  width = gr.Slider(
1403
  label="Width",
1404
  info="Width of the Image",
1405
- minimum=256,
1406
- maximum=1344,
1407
- step=32,
1408
  value=1024,
 
1409
  )
1410
  height = gr.Slider(
1411
  label="Height",
1412
  info="Height of the Image",
1413
- minimum=256,
1414
- maximum=1344,
1415
- step=32,
1416
  value=1024,
 
1417
  )
1418
  with gr.Row():
1419
  seed = gr.Slider(
@@ -1442,8 +1466,8 @@ with gr.Blocks(theme=theme, css=css) as demo:
1442
  prior_guidance_scale: gr.update(visible=True),
1443
  decoder_num_inference_steps: gr.update(visible=True),
1444
  decoder_guidance_scale: gr.update(visible=True),
1445
- width: gr.update(value=1024, maximum=1344),
1446
- height: gr.update(value=1024, maximum=1344),
1447
  }
1448
  elif model_choice == "sdxl flash":
1449
  return {
@@ -1453,8 +1477,8 @@ with gr.Blocks(theme=theme, css=css) as demo:
1453
  prior_guidance_scale: gr.update(visible=False),
1454
  decoder_num_inference_steps: gr.update(visible=False),
1455
  decoder_guidance_scale: gr.update(visible=False),
1456
- width: gr.update(value=1024, maximum=1344),
1457
- height: gr.update(value=1024, maximum=1344),
1458
  }
1459
  elif model_choice == "sd1.5":
1460
  return {
@@ -1464,8 +1488,8 @@ with gr.Blocks(theme=theme, css=css) as demo:
1464
  prior_guidance_scale: gr.update(visible=False),
1465
  decoder_num_inference_steps: gr.update(visible=False),
1466
  decoder_guidance_scale: gr.update(visible=False),
1467
- width: gr.update(value=512, maximum=768),
1468
- height: gr.update(value=512, maximum=768),
1469
  }
1470
  elif model_choice == "sd2.1":
1471
  return {
@@ -1475,8 +1499,8 @@ with gr.Blocks(theme=theme, css=css) as demo:
1475
  prior_guidance_scale: gr.update(visible=False),
1476
  decoder_num_inference_steps: gr.update(visible=False),
1477
  decoder_guidance_scale: gr.update(visible=False),
1478
- width: gr.update(value=768, maximum=1024),
1479
- height: gr.update(value=768, maximum=1024),
1480
  }
1481
  else:
1482
  return {
@@ -1486,8 +1510,8 @@ with gr.Blocks(theme=theme, css=css) as demo:
1486
  prior_guidance_scale: gr.update(visible=False),
1487
  decoder_num_inference_steps: gr.update(visible=False),
1488
  decoder_guidance_scale: gr.update(visible=False),
1489
- width: gr.update(value=1024, maximum=1344),
1490
- height: gr.update(value=1024, maximum=1344),
1491
  }
1492
 
1493
  model_choice.change(
@@ -1509,6 +1533,7 @@ with gr.Blocks(theme=theme, css=css) as demo:
1509
  examples=examples_individual,
1510
  inputs=[
1511
  prompt,
 
1512
  negative_prompt,
1513
  num_inference_steps,
1514
  guidance_scale,
@@ -1516,7 +1541,6 @@ with gr.Blocks(theme=theme, css=css) as demo:
1516
  width,
1517
  seed,
1518
  num_images_per_prompt,
1519
- model_choice,
1520
  prior_num_inference_steps,
1521
  prior_guidance_scale,
1522
  decoder_num_inference_steps,
@@ -1534,6 +1558,7 @@ with gr.Blocks(theme=theme, css=css) as demo:
1534
  fn=generate_individual_image,
1535
  inputs=[
1536
  prompt,
 
1537
  negative_prompt,
1538
  num_inference_steps,
1539
  guidance_scale,
@@ -1541,7 +1566,6 @@ with gr.Blocks(theme=theme, css=css) as demo:
1541
  width,
1542
  seed,
1543
  num_images_per_prompt,
1544
- model_choice,
1545
  prior_num_inference_steps,
1546
  prior_guidance_scale,
1547
  decoder_num_inference_steps,
 
3
  import gradio as gr
4
  import os
5
  import random
6
+ import numpy
7
  from PIL import Image
8
+ import gc # free up memory
9
  import spaces
10
 
11
  HF_TOKEN = os.getenv("HF_TOKEN") # login with hf read token to access sd gated models
 
18
  print("Using CPU")
19
 
20
 
21
+ MAX_SEED = numpy.iinfo(numpy.int32).max
22
 
23
+ # Global dictionary to store pipelines
24
+ PIPELINES = {}
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
+ def load_pipeline(model_choice):
28
+ """Loads the specified pipeline and stores it in the PIPELINES dictionary."""
29
+ if model_choice not in PIPELINES:
30
+ if model_choice == "sd3 medium":
31
+ PIPELINES[model_choice] = StableDiffusion3Pipeline.from_pretrained(
32
+ "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
33
+ )
34
+ elif model_choice == "sd2.1":
35
+ PIPELINES[model_choice] = StableDiffusionPipeline.from_pretrained(
36
+ "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16
37
+ )
38
+ elif model_choice == "sdxl":
39
+ PIPELINES[model_choice] = StableDiffusionXLPipeline.from_pretrained(
40
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
41
+ )
42
+ elif model_choice == "sdxl flash":
43
+ PIPELINES[model_choice] = StableDiffusionXLPipeline.from_pretrained(
44
+ "sd-community/sdxl-flash", torch_dtype=torch.float16
45
+ )
46
+ # Store the original scheduler for resetting
47
+ PIPELINES[model_choice].original_scheduler = PIPELINES[model_choice].scheduler
48
+ elif model_choice == "stable cascade":
49
+ # Store both prior and decoder pipelines under 'stable cascade'
50
+ PIPELINES[model_choice] = {
51
+ 'prior': StableCascadePriorPipeline.from_pretrained(
52
+ "stabilityai/stable-cascade-prior", variant="bf16", torch_dtype=torch.bfloat16
53
+ ),
54
+ 'decoder': StableCascadeDecoderPipeline.from_pretrained(
55
+ "stabilityai/stable-cascade", variant="bf16", torch_dtype=torch.float16
56
+ )
57
+ }
58
+ elif model_choice == "sd1.5":
59
+ PIPELINES[model_choice] = StableDiffusionPipeline.from_pretrained(
60
+ "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
61
+ )
62
+ else:
63
+ raise ValueError(f"Invalid model choice: {model_choice}")
64
 
65
+ return PIPELINES[model_choice]
 
 
 
 
 
 
 
 
66
 
 
 
 
 
 
67
 
68
+ def unload_pipeline(model_choice):
69
+ """Unloads the specified pipeline from the PIPELINES dictionary and frees GPU memory."""
70
+ if model_choice in PIPELINES:
71
+ del PIPELINES[model_choice]
72
+
73
+ torch.cuda.empty_cache()
74
+ gc.collect()
75
 
 
76
  @spaces.GPU(duration=80)
77
+ def run_inference(
78
  prompt,
79
+ pipe,
80
  negative_prompt,
81
  num_inference_steps,
82
  guidance_scale,
 
84
  width,
85
  seed,
86
  num_images_per_prompt,
 
 
87
  prior_num_inference_steps=None,
88
  prior_guidance_scale=None,
89
  decoder_num_inference_steps=None,
90
  decoder_guidance_scale=None,
91
  ):
92
+ """Runs inference with the specified pipeline and parameters."""
93
+
94
+ # Enable CPU offloading only if a GPU is available, for saving up RAM
95
+ if torch.cuda.is_available():
96
+ if isinstance(pipe, dict): # Special handling for stable cascade
97
+ pipe['prior'].enable_model_cpu_offload()
98
+ pipe['decoder'].enable_model_cpu_offload()
99
+ else:
100
+ pipe.enable_model_cpu_offload()
 
 
 
 
 
 
101
 
102
+ # Reset the sampler if the model is NOT SDXL Flash
103
+ if model_choice != "sdxl flash" and "sdxl flash" in PIPELINES:
104
+ PIPELINES["sdxl flash"].scheduler = PIPELINES["sdxl flash"].original_scheduler
105
+
106
+ # Apply SDXL Flash sampler ONLY if model_choice is 'sdxl flash'
107
+ if model_choice == "sdxl flash":
108
+ pipe.scheduler = DPMSolverSinglestepScheduler.from_config(
109
+ pipe.scheduler.config, timestep_spacing="trailing"
 
 
 
110
  )
111
 
112
+ if seed == 0:
113
+ seed = random.randint(1, MAX_SEED)
114
+
115
+ generator = torch.Generator().manual_seed(seed)
 
 
 
116
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
+ if isinstance(pipe, dict): # Stable Cascade
119
+ with torch.inference_mode():
120
+ prior_output = pipe['prior'](
121
+ prompt=prompt,
122
+ negative_prompt=negative_prompt,
123
+ num_inference_steps=prior_num_inference_steps,
124
+ guidance_scale=prior_guidance_scale,
125
+ height=height,
126
+ width=width,
127
+ generator=generator,
128
+ num_images_per_prompt=num_images_per_prompt,
129
+ )
130
+ with torch.inference_mode():
131
+ output = pipe['decoder'](
132
+ image_embeddings=prior_output.image_embeddings.to(torch.float16),
133
+ prompt=prompt,
134
+ negative_prompt=negative_prompt,
135
+ num_inference_steps=decoder_num_inference_steps,
136
+ guidance_scale=decoder_guidance_scale,
137
+ ).images
138
+ else: # Other pipelines
139
+ with torch.inference_mode():
140
+ output = pipe(
141
+ prompt=prompt,
142
+ negative_prompt=negative_prompt,
143
+ num_inference_steps=num_inference_steps,
144
+ guidance_scale=guidance_scale,
145
+ height=height,
146
+ width=width,
147
+ generator=generator,
148
+ num_images_per_prompt=num_images_per_prompt,
149
+ ).images
150
 
151
  return output
152
 
153
+ # Helper function to generate images for a single model
154
+ def generate_single_image(
155
+ prompt,
156
+ model_choice,
157
+ negative_prompt,
158
+ num_inference_steps,
159
+ guidance_scale,
160
+ height,
161
+ width,
162
+ seed,
163
+ num_images_per_prompt,
164
+ prior_num_inference_steps=None,
165
+ prior_guidance_scale=None,
166
+ decoder_num_inference_steps=None,
167
+ decoder_guidance_scale=None,
168
+ ):
169
+ # Load the pipeline
170
+ pipe = load_pipeline(model_choice)
171
+
172
+ # Run inference
173
+ output = run_inference(
174
+ prompt,
175
+ pipe,
176
+ negative_prompt,
177
+ num_inference_steps,
178
+ guidance_scale,
179
+ height,
180
+ width,
181
+ seed,
182
+ num_images_per_prompt,
183
+ prior_num_inference_steps,
184
+ prior_guidance_scale,
185
+ decoder_num_inference_steps,
186
+ decoder_guidance_scale,
187
+ )
188
+
189
+ # Unload the pipeline
190
+ unload_pipeline(model_choice)
191
+
192
+ return output
193
 
194
  # Define the image generation function for the Arena tab
 
195
  def generate_arena_images(
196
  prompt,
197
  negative_prompt,
 
236
  decoder_guidance_scale_d,
237
  progress=gr.Progress(track_tqdm=True),
238
  ):
 
 
 
 
239
 
240
  # Generate images for selected models
241
  if num_models_to_compare >= 2:
242
  images_a = generate_single_image(
243
  prompt,
244
+ model_choice_a,
245
  negative_prompt,
246
  num_inference_steps_a,
247
  guidance_scale_a,
 
249
  width_a,
250
  seed,
251
  num_images_per_prompt,
 
 
252
  prior_num_inference_steps_a,
253
  prior_guidance_scale_a,
254
  decoder_num_inference_steps_a,
 
256
  )
257
  images_b = generate_single_image(
258
  prompt,
259
+ model_choice_b,
260
  negative_prompt,
261
  num_inference_steps_b,
262
  guidance_scale_b,
 
264
  width_b,
265
  seed,
266
  num_images_per_prompt,
 
 
267
  prior_num_inference_steps_b,
268
  prior_guidance_scale_b,
269
  decoder_num_inference_steps_b,
 
275
  if num_models_to_compare >= 3:
276
  images_c = generate_single_image(
277
  prompt,
278
+ model_choice_c,
279
  negative_prompt,
280
  num_inference_steps_c,
281
  guidance_scale_c,
 
283
  width_c,
284
  seed,
285
  num_images_per_prompt,
 
 
286
  prior_num_inference_steps_c,
287
  prior_guidance_scale_c,
288
  decoder_num_inference_steps_c,
 
294
  if num_models_to_compare >= 4:
295
  images_d = generate_single_image(
296
  prompt,
297
+ model_choice_d,
298
  negative_prompt,
299
  num_inference_steps_d,
300
  guidance_scale_d,
 
302
  width_d,
303
  seed,
304
  num_images_per_prompt,
 
 
305
  prior_num_inference_steps_d,
306
  prior_guidance_scale_d,
307
  decoder_num_inference_steps_d,
 
314
 
315
 
316
  # Define the image generation function for the Individual tab
 
317
  def generate_individual_image(
318
  prompt,
319
+ model_choice,
320
  negative_prompt,
321
  num_inference_steps,
322
  guidance_scale,
 
324
  width,
325
  seed,
326
  num_images_per_prompt,
 
327
  prior_num_inference_steps,
328
  prior_guidance_scale,
329
  decoder_num_inference_steps,
330
  decoder_guidance_scale,
331
  progress=gr.Progress(track_tqdm=True),
332
  ):
 
 
 
 
333
 
334
  output = generate_single_image(
335
  prompt,
336
+ model_choice,
337
  negative_prompt,
338
  num_inference_steps,
339
  guidance_scale,
 
341
  width,
342
  seed,
343
  num_images_per_prompt,
 
 
344
  prior_num_inference_steps,
345
  prior_guidance_scale,
346
  decoder_num_inference_steps,
 
664
  width_a = gr.Slider(
665
  label="Width (Model A)",
666
  info="Width of the Image",
667
+ minimum=512,
668
+ maximum=1280,
 
669
  value=1024,
670
+ step=32,
671
  )
672
  height_a = gr.Slider(
673
  label="Height (Model A)",
674
  info="Height of the Image",
675
+ minimum=512,
676
+ maximum=1280,
 
677
  value=1024,
678
+ step=32,
679
  )
680
  with gr.Column():
681
  num_inference_steps_b = gr.Slider(
 
684
  minimum=1,
685
  maximum=50,
686
  value=25,
687
+ step=32,
688
  visible=True,
689
  )
690
  guidance_scale_b = gr.Slider(
 
735
  width_b = gr.Slider(
736
  label="Width (Model B)",
737
  info="Width of the Image",
738
+ minimum=512,
739
+ maximum=1280,
 
740
  value=1024,
741
+ step=32,
742
  )
743
  height_b = gr.Slider(
744
  label="Height (Model B)",
745
  info="Height of the Image",
746
+ minimum=512,
747
+ maximum=1280,
 
748
  value=1024,
749
+ step=32,
750
  )
751
  with gr.Column(visible=False) as model_c_options:
752
  num_inference_steps_c = gr.Slider(
 
806
  width_c = gr.Slider(
807
  label="Width (Model C)",
808
  info="Width of the Image",
809
+ minimum=512,
810
+ maximum=1280,
 
811
  value=1024,
812
+ step=32,
813
  )
814
  height_c = gr.Slider(
815
  label="Height (Model C)",
816
  info="Height of the Image",
817
+ minimum=512,
818
+ maximum=1280,
 
819
  value=1024,
820
+ step=32,
821
  )
822
  with gr.Column(visible=False) as model_d_options:
823
  num_inference_steps_d = gr.Slider(
 
877
  width_d = gr.Slider(
878
  label="Width (Model D)",
879
  info="Width of the Image",
880
+ minimum=512,
881
+ maximum=1280,
 
882
  value=1024,
883
+ step=32,
884
  )
885
  height_d = gr.Slider(
886
  label="Height (Model D)",
887
  info="Height of the Image",
888
+ minimum=512,
889
+ maximum=1280,
 
890
  value=1024,
891
+ step=32,
892
  )
893
  with gr.Row():
894
  seed = gr.Slider(
 
917
  prior_guidance_scale_a: gr.update(visible=True),
918
  decoder_num_inference_steps_a: gr.update(visible=True),
919
  decoder_guidance_scale_a: gr.update(visible=True),
920
+ width_a: gr.update(step=512, value=1024, maximum=1536),
921
+ height_a: gr.update(step=512, value=1024, maximum=1536),
922
  }
923
  elif model_choice_a == "sdxl flash":
924
  return {
 
928
  prior_guidance_scale_a: gr.update(visible=False),
929
  decoder_num_inference_steps_a: gr.update(visible=False),
930
  decoder_guidance_scale_a: gr.update(visible=False),
931
+ width_a: gr.update(step=32, value=1024, maximum=1536),
932
+ height_a: gr.update(step=32, value=1024, maximum=1536),
933
  }
934
  elif model_choice_a == "sd1.5":
935
  return {
 
938
  prior_guidance_scale_a: gr.update(visible=True),
939
  decoder_num_inference_steps_a: gr.update(visible=True),
940
  decoder_guidance_scale_a: gr.update(visible=True),
941
+ width_a: gr.update(step=32, value=512, maximum=768),
942
+ height_a: gr.update(step=32, value=512, maximum=768),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
943
  }
944
  elif model_choice_a == "sd2.1":
945
  return {
 
949
  prior_guidance_scale_a: gr.update(visible=False),
950
  decoder_num_inference_steps_a: gr.update(visible=False),
951
  decoder_guidance_scale_a: gr.update(visible=False),
952
+ width_a: gr.update(step=32, value=768, maximum=1024),
953
+ height_a: gr.update(step=32, value=768, maximum=1024),
954
  }
955
  else:
956
  return {
 
960
  prior_guidance_scale_a: gr.update(visible=False),
961
  decoder_num_inference_steps_a: gr.update(visible=False),
962
  decoder_guidance_scale_a: gr.update(visible=False),
963
+ width_a: gr.update(step=32, value=1024, maximum=1536),
964
+ height_a: gr.update(step=32, value=1024, maximum=1536),
965
  }
966
 
967
  def toggle_visibility_arena_b(model_choice_b):
 
973
  prior_guidance_scale_b: gr.update(visible=True),
974
  decoder_num_inference_steps_b: gr.update(visible=True),
975
  decoder_guidance_scale_b: gr.update(visible=True),
976
+ width_b: gr.update(step=256, value=1024, maximum=1536),
977
+ height_b: gr.update(step=256, value=1024, maximum=1536),
978
  }
979
  elif model_choice_b == "sdxl flash":
980
  return {
 
984
  prior_guidance_scale_b: gr.update(visible=False),
985
  decoder_num_inference_steps_b: gr.update(visible=False),
986
  decoder_guidance_scale_b: gr.update(visible=False),
987
+ width_a: gr.update(step=32, value=1024, maximum=1536),
988
+ height_a: gr.update(step=32, value=1024, maximum=1536),
989
  }
990
  elif model_choice_b == "sd1.5":
991
  return {
 
995
  prior_guidance_scale_b: gr.update(visible=False),
996
  decoder_num_inference_steps_b: gr.update(visible=False),
997
  decoder_guidance_scale_b: gr.update(visible=False),
998
+ width_b: gr.update(step=32, value=512, maximum=768),
999
+ height_b: gr.update(step=32, value=512, maximum=768),
1000
  }
1001
  elif model_choice_b == "sd2.1":
1002
  return {
 
1006
  prior_guidance_scale_b: gr.update(visible=False),
1007
  decoder_num_inference_steps_b: gr.update(visible=False),
1008
  decoder_guidance_scale_b: gr.update(visible=False),
1009
+ width_b: gr.update(step=32, value=768, maximum=1024),
1010
+ height_b: gr.update(step=32, value=768, maximum=1024),
1011
  }
1012
  else:
1013
  return {
 
1017
  prior_guidance_scale_b: gr.update(visible=False),
1018
  decoder_num_inference_steps_b: gr.update(visible=False),
1019
  decoder_guidance_scale_b: gr.update(visible=False),
1020
+ width_b: gr.update(step=32, value=1024, maximum=1536),
1021
+ height_b: gr.update(step=32, value=1024, maximum=1536),
1022
  }
1023
 
1024
  def toggle_visibility_arena_c(model_choice_c):
 
1030
  prior_guidance_scale_c: gr.update(visible=True),
1031
  decoder_num_inference_steps_c: gr.update(visible=True),
1032
  decoder_guidance_scale_c: gr.update(visible=True),
1033
+ width_c: gr.update(step=256, value=1024, maximum=1536),
1034
+ height_c: gr.update(step=256, value=1024, maximum=1536),
1035
  }
1036
  elif model_choice_c == "sdxl flash":
1037
  return {
 
1041
  prior_guidance_scale_c: gr.update(visible=False),
1042
  decoder_num_inference_steps_c: gr.update(visible=False),
1043
  decoder_guidance_scale_c: gr.update(visible=False),
1044
+ width_c: gr.update(step=32, value=1024, maximum=1536),
1045
+ height_c: gr.update(step=32, value=1024, maximum=1536),
1046
  }
1047
  elif model_choice_c == "sd1.5":
1048
  return {
 
1052
  prior_guidance_scale_c: gr.update(visible=False),
1053
  decoder_num_inference_steps_c: gr.update(visible=False),
1054
  decoder_guidance_scale_c: gr.update(visible=False),
1055
+ width_c: gr.update(step=32, value=512, maximum=768),
1056
+ height_c: gr.update(step=32, value=512, maximum=768),
1057
  }
1058
  elif model_choice_c == "sd2.1":
1059
  return {
 
1063
  prior_guidance_scale_c: gr.update(visible=False),
1064
  decoder_num_inference_steps_c: gr.update(visible=False),
1065
  decoder_guidance_scale_c: gr.update(visible=False),
1066
+ width_c: gr.update(step=32, value=768, maximum=1024),
1067
+ height_c: gr.update(step=32, value=768, maximum=1024),
1068
  }
1069
  else:
1070
  return {
 
1074
  prior_guidance_scale_c: gr.update(visible=False),
1075
  decoder_num_inference_steps_c: gr.update(visible=False),
1076
  decoder_guidance_scale_c: gr.update(visible=False),
1077
+ width_c: gr.update(step=32, value=1024, maximum=1536),
1078
+ height_c: gr.update(step=32, value=1024, maximum=1536),
1079
  }
1080
 
1081
  def toggle_visibility_arena_d(model_choice_d):
 
1087
  prior_guidance_scale_d: gr.update(visible=True),
1088
  decoder_num_inference_steps_d: gr.update(visible=True),
1089
  decoder_guidance_scale_d: gr.update(visible=True),
1090
+ width_d: gr.update(step=256, value=1024, maximum=1536),
1091
+ height_d: gr.update(step=256, value=1024, maximum=1536),
1092
  }
1093
  elif model_choice_d == "sdxl flash":
1094
  return {
 
1098
  prior_guidance_scale_d: gr.update(visible=False),
1099
  decoder_num_inference_steps_d: gr.update(visible=False),
1100
  decoder_guidance_scale_d: gr.update(visible=False),
1101
+ width_d: gr.update(step=32, value=1024, maximum=1536),
1102
+ height_d: gr.update(step=32, value=1024, maximum=1536),
1103
  }
1104
  elif model_choice_d == "sd1.5":
1105
  return {
 
1109
  prior_guidance_scale_d: gr.update(visible=False),
1110
  decoder_num_inference_steps_d: gr.update(visible=False),
1111
  decoder_guidance_scale_d: gr.update(visible=False),
1112
+ width_d: gr.update(step=32, value=512, maximum=768),
1113
+ height_d: gr.update(step=32, value=512, maximum=768),
1114
  }
1115
  elif model_choice_d == "sd2.1":
1116
  return {
 
1120
  prior_guidance_scale_d: gr.update(visible=False),
1121
  decoder_num_inference_steps_d: gr.update(visible=False),
1122
  decoder_guidance_scale_d: gr.update(visible=False),
1123
+ width_d: gr.update(step=32, value=768, maximum=1024),
1124
+ height_d: gr.update(step=32, value=768, maximum=1024),
1125
  }
1126
  else:
1127
  return {
 
1131
  prior_guidance_scale_d: gr.update(visible=False),
1132
  decoder_num_inference_steps_d: gr.update(visible=False),
1133
  decoder_guidance_scale_d: gr.update(visible=False),
1134
+ width_d: gr.update(step=32, value=1024, maximum=1536),
1135
+ height_d: gr.update(step=32, value=1024, maximum=1536),
1136
  }
1137
 
1138
  model_choice_a.change(
 
1426
  width = gr.Slider(
1427
  label="Width",
1428
  info="Width of the Image",
1429
+ minimum=512,
1430
+ maximum=1280,
 
1431
  value=1024,
1432
+ step=32,
1433
  )
1434
  height = gr.Slider(
1435
  label="Height",
1436
  info="Height of the Image",
1437
+ minimum=512,
1438
+ maximum=1280,
 
1439
  value=1024,
1440
+ step=32,
1441
  )
1442
  with gr.Row():
1443
  seed = gr.Slider(
 
1466
  prior_guidance_scale: gr.update(visible=True),
1467
  decoder_num_inference_steps: gr.update(visible=True),
1468
  decoder_guidance_scale: gr.update(visible=True),
1469
+ width: gr.update(step=256, value=1024, maximum=1536),
1470
+ height: gr.update(step=256, value=1024, maximum=1536),
1471
  }
1472
  elif model_choice == "sdxl flash":
1473
  return {
 
1477
  prior_guidance_scale: gr.update(visible=False),
1478
  decoder_num_inference_steps: gr.update(visible=False),
1479
  decoder_guidance_scale: gr.update(visible=False),
1480
+ width: gr.update(step=32, value=1024, maximum=1536),
1481
+ height: gr.update(step=32, value=1024, maximum=1536),
1482
  }
1483
  elif model_choice == "sd1.5":
1484
  return {
 
1488
  prior_guidance_scale: gr.update(visible=False),
1489
  decoder_num_inference_steps: gr.update(visible=False),
1490
  decoder_guidance_scale: gr.update(visible=False),
1491
+ width: gr.update(step=32, value=512, maximum=768),
1492
+ height: gr.update(step=32, value=512, maximum=768),
1493
  }
1494
  elif model_choice == "sd2.1":
1495
  return {
 
1499
  prior_guidance_scale: gr.update(visible=False),
1500
  decoder_num_inference_steps: gr.update(visible=False),
1501
  decoder_guidance_scale: gr.update(visible=False),
1502
+ width: gr.update(step=32, value=768, maximum=1024),
1503
+ height: gr.update(step=32, value=768, maximum=1024),
1504
  }
1505
  else:
1506
  return {
 
1510
  prior_guidance_scale: gr.update(visible=False),
1511
  decoder_num_inference_steps: gr.update(visible=False),
1512
  decoder_guidance_scale: gr.update(visible=False),
1513
+ width: gr.update(step=32, value=1024, maximum=1536),
1514
+ height: gr.update(step=32, value=1024, maximum=1536),
1515
  }
1516
 
1517
  model_choice.change(
 
1533
  examples=examples_individual,
1534
  inputs=[
1535
  prompt,
1536
+ model_choice,
1537
  negative_prompt,
1538
  num_inference_steps,
1539
  guidance_scale,
 
1541
  width,
1542
  seed,
1543
  num_images_per_prompt,
 
1544
  prior_num_inference_steps,
1545
  prior_guidance_scale,
1546
  decoder_num_inference_steps,
 
1558
  fn=generate_individual_image,
1559
  inputs=[
1560
  prompt,
1561
+ model_choice,
1562
  negative_prompt,
1563
  num_inference_steps,
1564
  guidance_scale,
 
1566
  width,
1567
  seed,
1568
  num_images_per_prompt,
 
1569
  prior_num_inference_steps,
1570
  prior_guidance_scale,
1571
  decoder_num_inference_steps,