Nick088 commited on
Commit
924eee3
·
verified ·
1 Parent(s): 5718b02

fixed arena tab compare < 4 not working, fixed height width models not changing always

Browse files
Files changed (1) hide show
  1. app.py +29 -14
app.py CHANGED
@@ -121,7 +121,7 @@ def generate_single_image(
121
  guidance_scale=decoder_guidance_scale,
122
  ).images
123
 
124
- # the rest of the models have similar pipeline
125
  else:
126
  output = pipe(
127
  prompt=prompt,
@@ -138,7 +138,7 @@ def generate_single_image(
138
 
139
 
140
  # Define the image generation function for the Arena tab
141
- @spaces.GPU(duration=80)
142
  def generate_arena_images(
143
  prompt,
144
  negative_prompt,
@@ -189,7 +189,6 @@ def generate_arena_images(
189
  generator = torch.Generator().manual_seed(seed)
190
 
191
  # Generate images for selected models
192
- images = []
193
  if num_models_to_compare >= 2:
194
  images_a = generate_single_image(
195
  prompt,
@@ -207,7 +206,6 @@ def generate_arena_images(
207
  decoder_num_inference_steps_a,
208
  decoder_guidance_scale_a,
209
  )
210
- images.append(images_a)
211
  images_b = generate_single_image(
212
  prompt,
213
  negative_prompt,
@@ -224,7 +222,9 @@ def generate_arena_images(
224
  decoder_num_inference_steps_b,
225
  decoder_guidance_scale_b,
226
  )
227
- images.append(images_b)
 
 
228
  if num_models_to_compare >= 3:
229
  images_c = generate_single_image(
230
  prompt,
@@ -242,7 +242,9 @@ def generate_arena_images(
242
  decoder_num_inference_steps_c,
243
  decoder_guidance_scale_c,
244
  )
245
- images.append(images_c)
 
 
246
  if num_models_to_compare >= 4:
247
  images_d = generate_single_image(
248
  prompt,
@@ -260,9 +262,10 @@ def generate_arena_images(
260
  decoder_num_inference_steps_d,
261
  decoder_guidance_scale_d,
262
  )
263
- images.append(images_d)
 
264
 
265
- return images
266
 
267
 
268
  # Define the image generation function for the Individual tab
@@ -998,6 +1001,8 @@ with gr.Blocks(theme=theme, css=css) as demo:
998
  prior_guidance_scale_c: gr.update(visible=True),
999
  decoder_num_inference_steps_c: gr.update(visible=True),
1000
  decoder_guidance_scale_c: gr.update(visible=True),
 
 
1001
  }
1002
  elif model_choice_c == "sdxl flash":
1003
  return {
@@ -1007,6 +1012,8 @@ with gr.Blocks(theme=theme, css=css) as demo:
1007
  prior_guidance_scale_c: gr.update(visible=False),
1008
  decoder_num_inference_steps_c: gr.update(visible=False),
1009
  decoder_guidance_scale_c: gr.update(visible=False),
 
 
1010
  }
1011
  elif model_choice_c == "sd1.5":
1012
  return {
@@ -1038,8 +1045,8 @@ with gr.Blocks(theme=theme, css=css) as demo:
1038
  prior_guidance_scale_c: gr.update(visible=False),
1039
  decoder_num_inference_steps_c: gr.update(visible=False),
1040
  decoder_guidance_scale_c: gr.update(visible=False),
1041
- width_c: gr.update(maximum=1344),
1042
- height_c: gr.update(maximum=1344),
1043
  }
1044
 
1045
  def toggle_visibility_arena_d(model_choice_d):
@@ -1051,6 +1058,8 @@ with gr.Blocks(theme=theme, css=css) as demo:
1051
  prior_guidance_scale_d: gr.update(visible=True),
1052
  decoder_num_inference_steps_d: gr.update(visible=True),
1053
  decoder_guidance_scale_d: gr.update(visible=True),
 
 
1054
  }
1055
  elif model_choice_d == "sdxl flash":
1056
  return {
@@ -1060,6 +1069,8 @@ with gr.Blocks(theme=theme, css=css) as demo:
1060
  prior_guidance_scale_d: gr.update(visible=False),
1061
  decoder_num_inference_steps_d: gr.update(visible=False),
1062
  decoder_guidance_scale_d: gr.update(visible=False),
 
 
1063
  }
1064
  elif model_choice_d == "sd1.5":
1065
  return {
@@ -1091,8 +1102,8 @@ with gr.Blocks(theme=theme, css=css) as demo:
1091
  prior_guidance_scale_d: gr.update(visible=False),
1092
  decoder_num_inference_steps_d: gr.update(visible=False),
1093
  decoder_guidance_scale_d: gr.update(visible=False),
1094
- width_d: gr.update(maximum=1344),
1095
- height_d: gr.update(maximum=1344),
1096
  }
1097
 
1098
  model_choice_a.change(
@@ -1426,6 +1437,8 @@ with gr.Blocks(theme=theme, css=css) as demo:
1426
  prior_guidance_scale: gr.update(visible=True),
1427
  decoder_num_inference_steps: gr.update(visible=True),
1428
  decoder_guidance_scale: gr.update(visible=True),
 
 
1429
  }
1430
  elif model_choice == "sdxl flash":
1431
  return {
@@ -1435,6 +1448,8 @@ with gr.Blocks(theme=theme, css=css) as demo:
1435
  prior_guidance_scale: gr.update(visible=False),
1436
  decoder_num_inference_steps: gr.update(visible=False),
1437
  decoder_guidance_scale: gr.update(visible=False),
 
 
1438
  }
1439
  elif model_choice == "sd1.5":
1440
  return {
@@ -1466,8 +1481,8 @@ with gr.Blocks(theme=theme, css=css) as demo:
1466
  prior_guidance_scale: gr.update(visible=False),
1467
  decoder_num_inference_steps: gr.update(visible=False),
1468
  decoder_guidance_scale: gr.update(visible=False),
1469
- width: gr.update(maximum=1344),
1470
- height: gr.update(maximum=1344),
1471
  }
1472
 
1473
  model_choice.change(
 
121
  guidance_scale=decoder_guidance_scale,
122
  ).images
123
 
124
+ # the rest of the models have similar pipeline
125
  else:
126
  output = pipe(
127
  prompt=prompt,
 
138
 
139
 
140
  # Define the image generation function for the Arena tab
141
+ @spaces.GPU(duration=200)
142
  def generate_arena_images(
143
  prompt,
144
  negative_prompt,
 
189
  generator = torch.Generator().manual_seed(seed)
190
 
191
  # Generate images for selected models
 
192
  if num_models_to_compare >= 2:
193
  images_a = generate_single_image(
194
  prompt,
 
206
  decoder_num_inference_steps_a,
207
  decoder_guidance_scale_a,
208
  )
 
209
  images_b = generate_single_image(
210
  prompt,
211
  negative_prompt,
 
222
  decoder_num_inference_steps_b,
223
  decoder_guidance_scale_b,
224
  )
225
+
226
+ output_arena_images = images_a, images_b, None, None
227
+
228
  if num_models_to_compare >= 3:
229
  images_c = generate_single_image(
230
  prompt,
 
242
  decoder_num_inference_steps_c,
243
  decoder_guidance_scale_c,
244
  )
245
+
246
+ output_arena_images = images_a, images_b, images_c, None
247
+
248
  if num_models_to_compare >= 4:
249
  images_d = generate_single_image(
250
  prompt,
 
262
  decoder_num_inference_steps_d,
263
  decoder_guidance_scale_d,
264
  )
265
+
266
+ output_arena_images = images_a, images_b, images_c, images_d
267
 
268
+ return output_arena_images
269
 
270
 
271
  # Define the image generation function for the Individual tab
 
1001
  prior_guidance_scale_c: gr.update(visible=True),
1002
  decoder_num_inference_steps_c: gr.update(visible=True),
1003
  decoder_guidance_scale_c: gr.update(visible=True),
1004
+ width_c: gr.update(value=1024, maximum=1344),
1005
+ height_c: gr.update(value=1024, maximum=1344),
1006
  }
1007
  elif model_choice_c == "sdxl flash":
1008
  return {
 
1012
  prior_guidance_scale_c: gr.update(visible=False),
1013
  decoder_num_inference_steps_c: gr.update(visible=False),
1014
  decoder_guidance_scale_c: gr.update(visible=False),
1015
+ width_c: gr.update(value=1024, maximum=1344),
1016
+ height_c: gr.update(value=1024, maximum=1344),
1017
  }
1018
  elif model_choice_c == "sd1.5":
1019
  return {
 
1045
  prior_guidance_scale_c: gr.update(visible=False),
1046
  decoder_num_inference_steps_c: gr.update(visible=False),
1047
  decoder_guidance_scale_c: gr.update(visible=False),
1048
+ width_c: gr.update(value=1024, maximum=1344),
1049
+ height_c: gr.update(value=1024, maximum=1344),
1050
  }
1051
 
1052
  def toggle_visibility_arena_d(model_choice_d):
 
1058
  prior_guidance_scale_d: gr.update(visible=True),
1059
  decoder_num_inference_steps_d: gr.update(visible=True),
1060
  decoder_guidance_scale_d: gr.update(visible=True),
1061
+ width_d: gr.update(value=1024, maximum=1344),
1062
+ height_d: gr.update(value=1024, maximum=1344),
1063
  }
1064
  elif model_choice_d == "sdxl flash":
1065
  return {
 
1069
  prior_guidance_scale_d: gr.update(visible=False),
1070
  decoder_num_inference_steps_d: gr.update(visible=False),
1071
  decoder_guidance_scale_d: gr.update(visible=False),
1072
+ width_d: gr.update(value=1024, maximum=1344),
1073
+ height_d: gr.update(value=1024, maximum=1344),
1074
  }
1075
  elif model_choice_d == "sd1.5":
1076
  return {
 
1102
  prior_guidance_scale_d: gr.update(visible=False),
1103
  decoder_num_inference_steps_d: gr.update(visible=False),
1104
  decoder_guidance_scale_d: gr.update(visible=False),
1105
+ width_d: gr.update(value=1024, maximum=1344),
1106
+ height_d: gr.update(value=1024, maximum=1344),
1107
  }
1108
 
1109
  model_choice_a.change(
 
1437
  prior_guidance_scale: gr.update(visible=True),
1438
  decoder_num_inference_steps: gr.update(visible=True),
1439
  decoder_guidance_scale: gr.update(visible=True),
1440
+ width: gr.update(value=1024, maximum=1344),
1441
+ height: gr.update(value=1024, maximum=1344),
1442
  }
1443
  elif model_choice == "sdxl flash":
1444
  return {
 
1448
  prior_guidance_scale: gr.update(visible=False),
1449
  decoder_num_inference_steps: gr.update(visible=False),
1450
  decoder_guidance_scale: gr.update(visible=False),
1451
+ width: gr.update(value=1024, maximum=1344),
1452
+ height: gr.update(value=1024, maximum=1344),
1453
  }
1454
  elif model_choice == "sd1.5":
1455
  return {
 
1481
  prior_guidance_scale: gr.update(visible=False),
1482
  decoder_num_inference_steps: gr.update(visible=False),
1483
  decoder_guidance_scale: gr.update(visible=False),
1484
+ width: gr.update(value=1024, maximum=1344),
1485
+ height: gr.update(value=1024, maximum=1344),
1486
  }
1487
 
1488
  model_choice.change(