Nick088 committed (verified)
Commit ec79957 · Parent(s): 3190260

Update app.py

Files changed (1): app.py (+11, -83)
app.py CHANGED
@@ -36,23 +36,10 @@ sdxl_pipe = DiffusionPipeline.from_pretrained(
 )
 sdxl_pipe.to(device)
 
-# superprompt-v1
-tokenizer = T5Tokenizer.from_pretrained("roborovski/superprompt-v1")
-model = T5ForConditionalGeneration.from_pretrained(
-    "roborovski/superprompt-v1", device_map="auto", torch_dtype="auto"
-)
-model.to(device)
-
-# toggle visibility the enhanced prompt output
-def update_visibility(enhance_prompt):
-    return gr.update(visible=enhance_prompt)
-
-
 # Define the image generation function for the Arena tab
 @spaces.GPU(duration=80)
 def generate_arena_images(
     prompt,
-    enhance_prompt,
     negative_prompt,
     num_inference_steps,
     height,
@@ -67,23 +54,6 @@ def generate_arena_images(
     if seed == 0:
         seed = random.randint(1, 2**32 - 1)
 
-    if enhance_prompt:
-        transformers.set_seed(seed)
-
-        input_text = f"Expand the following prompt to add more detail: {prompt}"
-        input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
-
-        outputs = model.generate(
-            input_ids,
-            max_new_tokens=512,
-            repetition_penalty=1.2,
-            do_sample=True,
-            temperature=0.7,
-            top_p=1,
-            top_k=50,
-        )
-        prompt = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
     generator = torch.Generator().manual_seed(seed)
 
     # Generate images for both models
@@ -112,7 +82,7 @@ def generate_arena_images(
         generator,
     )
 
-    return images_1, images_2, prompt
+    return images_1, images_2
 
 
 # Helper function to generate images for a single model
@@ -155,7 +125,6 @@ def generate_single_image(
 @spaces.GPU(duration=80)
 def generate_individual_image(
     prompt,
-    enhance_prompt,
     negative_prompt,
     num_inference_steps,
     height,
@@ -169,23 +138,6 @@ def generate_individual_image(
     if seed == 0:
         seed = random.randint(1, 2**32 - 1)
 
-    if enhance_prompt:
-        transformers.set_seed(seed)
-
-        input_text = f"Expand the following prompt to add more detail: {prompt}"
-        input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
-
-        outputs = model.generate(
-            input_ids,
-            max_new_tokens=512,
-            repetition_penalty=1.2,
-            do_sample=True,
-            temperature=0.7,
-            top_p=1,
-            top_k=50,
-        )
-        prompt = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
     generator = torch.Generator().manual_seed(seed)
 
     output = generate_single_image(
@@ -201,14 +153,14 @@ def generate_individual_image(
         generator,
     )
 
-    return output, prompt
+    return output
 
 
 # Create the Gradio interface
 examples = [
-    ["A white car racing fast to the moon.", True],
-    ["A woman in a red dress singing on top of a building.", True],
-    ["An astronaut on mars in a futuristic cyborg suit.", True],
+    ["A white car racing fast to the moon."],
+    ["A woman in a red dress singing on top of a building."],
+    ["An astronaut on mars in a futuristic cyborg suit."],
 ]
 
 css = """
@@ -240,9 +192,6 @@ with gr.Blocks(css=css) as demo:
                 info="Describe the image you want",
                 placeholder="A cat...",
             )
-            enhance_prompt = gr.Checkbox(
-                label="Prompt Enhancement with SuperPrompt-v1", value=True
-            )
             model_choice_1 = gr.Dropdown(
                 label="Stable Diffusion Model 1",
                 choices=["sd3 medium", "sd2.1", "sdxl"],
@@ -256,14 +205,6 @@ with gr.Blocks(css=css) as demo:
             run_button = gr.Button("Run")
             result_1 = gr.Gallery(label="Generated Images (Model 1)", elem_id="gallery_1")
             result_2 = gr.Gallery(label="Generated Images (Model 2)", elem_id="gallery_2")
-            better_prompt = gr.Textbox(
-                label="Enhanced Prompt",
-                info="The output of your enhanced prompt used for the Image Generation",
-                visible=True,
-            )
-            enhance_prompt.change(
-                fn=update_visibility, inputs=enhance_prompt, outputs=better_prompt
-            )
             with gr.Accordion("Advanced options", open=False):
                 with gr.Row():
                     negative_prompt = gr.Textbox(
@@ -326,8 +267,8 @@ with gr.Blocks(css=css) as demo:
 
             gr.Examples(
                 examples=examples,
-                inputs=[prompt, enhance_prompt],
-                outputs=[result_1, result_2, better_prompt],
+                inputs=[prompt],
+                outputs=[result_1, result_2],
                 fn=generate_arena_images,
             )
 
@@ -339,7 +280,6 @@ with gr.Blocks(css=css) as demo:
                 fn=generate_arena_images,
                 inputs=[
                     prompt,
-                    enhance_prompt,
                     negative_prompt,
                     num_inference_steps,
                     width,
@@ -350,7 +290,7 @@ with gr.Blocks(css=css) as demo:
                     model_choice_1,
                     model_choice_2,
                 ],
-                outputs=[result_1, result_2, better_prompt],
+                outputs=[result_1, result_2],
             )
 
         with gr.TabItem("Individual"):
@@ -361,9 +301,6 @@ with gr.Blocks(css=css) as demo:
                 info="Describe the image you want",
                 placeholder="A cat...",
             )
-            enhance_prompt = gr.Checkbox(
-                label="Prompt Enhancement with SuperPrompt-v1", value=True
-            )
             model_choice = gr.Dropdown(
                 label="Stable Diffusion Model",
                 choices=["sd3 medium", "sd2.1", "sdxl"],
@@ -371,14 +308,6 @@ with gr.Blocks(css=css) as demo:
             )
             run_button = gr.Button("Run")
             result = gr.Gallery(label="Generated AI Images", elem_id="gallery")
-            better_prompt = gr.Textbox(
-                label="Enhanced Prompt",
-                info="The output of your enhanced prompt used for the Image Generation",
-                visible=True,
-            )
-            enhance_prompt.change(
-                fn=update_visibility, inputs=enhance_prompt, outputs=better_prompt
-            )
             with gr.Accordion("Advanced options", open=False):
                 with gr.Row():
                     negative_prompt = gr.Textbox(
@@ -441,8 +370,8 @@ with gr.Blocks(css=css) as demo:
 
             gr.Examples(
                 examples=examples,
-                inputs=[prompt, enhance_prompt],
-                outputs=[result, better_prompt],
+                inputs=[prompt],
+                outputs=[result],
                 fn=generate_individual_image,
             )
 
@@ -454,7 +383,6 @@ with gr.Blocks(css=css) as demo:
                 fn=generate_individual_image,
                 inputs=[
                     prompt,
-                    enhance_prompt,
                     negative_prompt,
                     num_inference_steps,
                     width,
@@ -464,7 +392,7 @@ with gr.Blocks(css=css) as demo:
                     num_images_per_prompt,
                     model_choice,
                 ],
-                outputs=[result, better_prompt],
+                outputs=[result],
             )
 
 demo.queue().launch(share=False)