hysts (HF staff) committed
Commit 72b0fe6
1 parent: 1869bcd
Files changed (1):
  1. app.py +46 -12

app.py CHANGED
@@ -3,8 +3,10 @@
 from __future__ import annotations
 
 import os
+import random
 
 import gradio as gr
+import numpy as np
 import PIL.Image
 import torch
 from diffusers import StableDiffusionAttendAndExcitePipeline, StableDiffusionPipeline
@@ -29,6 +31,16 @@ if torch.cuda.is_available():
     sd_pipe.to(device)
 
 
+MAX_INFERENCE_STEPS = 100
+MAX_SEED = np.iinfo(np.int32).max
+
+
+def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
+    if randomize_seed:
+        seed = random.randint(0, MAX_SEED)
+    return seed
+
+
 def get_token_table(prompt: str) -> list[tuple[int, str]]:
     tokens = [ax_pipe.tokenizer.decode(t) for t in ax_pipe.tokenizer(prompt)["input_ids"]]
     tokens = tokens[1:-1]
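
The new helper is ordinary Python, so its behavior is easy to pin down in isolation. A minimal, self-contained sketch (the example calls and comments are illustrative, not part of the commit):

import random

import numpy as np

MAX_SEED = np.iinfo(np.int32).max  # 2147483647, the largest signed 32-bit integer

def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    # Passes the user's seed through unchanged unless randomization is requested.
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed

print(randomize_seed_fn(42, False))  # always 42
print(randomize_seed_fn(42, True))   # a fresh value in [0, MAX_SEED]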
@@ -40,7 +52,7 @@ def run(
     indices_to_alter_str: str,
     seed: int = 0,
     apply_attend_and_excite: bool = True,
-    num_steps: int = 50,
+    num_inference_steps: int = 50,
     guidance_scale: float = 7.5,
     scale_factor: int = 20,
     thresholds: dict[int, float] = {
@@ -48,9 +60,12 @@ def run(
         20: 0.8,
     },
     max_iter_to_alter: int = 25,
+    progress=gr.Progress(track_tqdm=True),
 ) -> PIL.Image.Image:
-    generator = torch.Generator(device=device).manual_seed(seed)
+    if num_inference_steps > MAX_INFERENCE_STEPS:
+        raise gr.Error(f"Number of steps cannot exceed {MAX_INFERENCE_STEPS}.")
 
+    generator = torch.Generator(device=device).manual_seed(seed)
     if apply_attend_and_excite:
         try:
             token_indices = list(map(int, indices_to_alter_str.split(",")))
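
Two details in this hunk are worth noting. `progress=gr.Progress(track_tqdm=True)` asks Gradio to mirror any tqdm progress bar running inside the function (diffusers pipelines use tqdm internally) in the web UI. And the explicit `gr.Error` guard matters because a slider maximum only constrains the rendered form; API callers can pass any value. A minimal standalone sketch of both patterns (the function and limit here are assumptions for illustration):

import time

import gradio as gr
from tqdm import tqdm

MAX_STEPS = 100

def slow_job(n: int, progress=gr.Progress(track_tqdm=True)) -> str:
    # Server-side validation: the slider bound below can be bypassed via the API.
    if n > MAX_STEPS:
        raise gr.Error(f"Number of steps cannot exceed {MAX_STEPS}.")
    for _ in tqdm(range(int(n))):  # this tqdm loop is mirrored in the Gradio UI
        time.sleep(0.05)
    return f"ran {int(n)} steps"

demo = gr.Interface(fn=slow_job, inputs=gr.Slider(1, MAX_STEPS, step=1, value=10), outputs="text")
demo.queue().launch()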
@@ -61,7 +76,7 @@ def run(
             token_indices=token_indices,
             guidance_scale=guidance_scale,
             generator=generator,
-            num_inference_steps=num_steps,
+            num_inference_steps=num_inference_steps,
             max_iter_to_alter=max_iter_to_alter,
             thresholds=thresholds,
             scale_factor=scale_factor,
@@ -71,7 +86,7 @@ def run(
             prompt=prompt,
             guidance_scale=guidance_scale,
             generator=generator,
-            num_inference_steps=num_steps,
+            num_inference_steps=num_inference_steps,
         )
     return out.images[0]
 
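Both the Attend-and-Excite branch and the plain Stable Diffusion branch now receive the same `generator` and `num_inference_steps`, which is what makes a result reproducible from its seed. A minimal sketch of that diffusers pattern (the model id and CUDA device are assumptions):

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe.to("cuda")

# Re-running with the same seed and step count yields the same image.
generator = torch.Generator(device="cuda").manual_seed(0)
image = pipe("a cat and a frog", num_inference_steps=50, generator=generator).images[0]
image.save("out.png")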
@@ -115,23 +130,24 @@ with gr.Blocks(css="style.css") as demo:
                 max_lines=1,
                 placeholder="4,16",
             )
+            apply_attend_and_excite = gr.Checkbox(label="Apply Attend-and-Excite", value=True)
             seed = gr.Slider(
                 label="Seed",
                 minimum=0,
-                maximum=100000,
+                maximum=MAX_SEED,
                 step=1,
                 value=0,
             )
-            apply_attend_and_excite = gr.Checkbox(label="Apply Attend-and-Excite", value=True)
-            num_steps = gr.Slider(
-                label="Number of steps",
-                minimum=0,
-                maximum=100,
+            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+            num_inference_steps = gr.Slider(
+                label="Number of inference steps",
+                minimum=1,
+                maximum=MAX_INFERENCE_STEPS,
                 step=1,
                 value=50,
             )
             guidance_scale = gr.Slider(
-                label="CFG scale",
+                label="Guidance scale",
                 minimum=0,
                 maximum=50,
                 step=0.1,
@@ -246,10 +262,16 @@ with gr.Blocks(css="style.css") as demo:
         token_indices_str,
         seed,
         apply_attend_and_excite,
-        num_steps,
+        num_inference_steps,
         guidance_scale,
     ]
     prompt.submit(
+        fn=randomize_seed_fn,
+        inputs=[seed, randomize_seed],
+        outputs=seed,
+        queue=False,
+        api_name=False,
+    ).then(
         fn=get_token_table,
         inputs=prompt,
         outputs=token_indices_table,
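
The chain `prompt.submit(randomize_seed_fn, ...).then(get_token_table, ...)` runs the cheap seed refresh first (with `queue=False`, so it does not wait behind queued generation jobs) and only then continues to the next step; the same pattern repeats for `token_indices_str.submit` and `run_button.click` below. A minimal sketch of `.then()` chaining (component and function names are illustrative):

import random

import gradio as gr

def pick_seed() -> int:
    return random.randint(0, 100)

def generate(seed: float) -> str:
    return f"generated with seed {int(seed)}"

with gr.Blocks() as demo:
    seed = gr.Number(label="Seed", value=0)
    result = gr.Textbox(label="Result")
    run = gr.Button("Run")
    # Refresh the seed first, then run the main job with the updated value.
    run.click(fn=pick_seed, inputs=None, outputs=seed, queue=False).then(
        fn=generate, inputs=seed, outputs=result
    )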
@@ -262,6 +284,12 @@ with gr.Blocks(css="style.css") as demo:
         api_name=False,
     )
     token_indices_str.submit(
+        fn=randomize_seed_fn,
+        inputs=[seed, randomize_seed],
+        outputs=seed,
+        queue=False,
+        api_name=False,
+    ).then(
         fn=get_token_table,
         inputs=prompt,
         outputs=token_indices_table,
@@ -274,6 +302,12 @@ with gr.Blocks(css="style.css") as demo:
         api_name=False,
     )
     run_button.click(
+        fn=randomize_seed_fn,
+        inputs=[seed, randomize_seed],
+        outputs=seed,
+        queue=False,
+        api_name=False,
+    ).then(
         fn=get_token_table,
         inputs=prompt,
         outputs=token_indices_table,