hysts HF staff commited on
Commit
1869bcd
·
1 Parent(s): f214f73
Files changed (2) hide show
  1. app.py +74 -16
  2. model.py +0 -61
app.py CHANGED
@@ -6,8 +6,8 @@ import os
6
 
7
  import gradio as gr
8
  import PIL.Image
9
-
10
- from model import Model
11
 
12
  DESCRIPTION = """\
13
  # Attend-and-Excite
@@ -17,7 +17,63 @@ Attend-and-Excite performs attention-based generative semantic guidance to mitig
17
  Select a prompt and a set of indices matching the subjects you wish to strengthen (the `Check token indices` cell can help map between a word and its index).
18
  """
19
 
20
- model = Model()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
 
23
  def process_example(
@@ -26,11 +82,13 @@ def process_example(
26
  seed: int,
27
  apply_attend_and_excite: bool,
28
  ) -> tuple[list[tuple[int, str]], PIL.Image.Image]:
29
- num_steps = 50
30
- guidance_scale = 7.5
31
-
32
- token_table = model.get_token_table(prompt)
33
- result = model.run(prompt, indices_to_alter_str, seed, apply_attend_and_excite, num_steps, guidance_scale)
 
 
34
  return token_table, result
35
 
36
 
@@ -176,11 +234,11 @@ with gr.Blocks(css="style.css") as demo:
176
  )
177
 
178
  show_token_indices_button.click(
179
- fn=model.get_token_table,
180
  inputs=prompt,
181
  outputs=token_indices_table,
182
  queue=False,
183
- api_name=False,
184
  )
185
 
186
  inputs = [
@@ -192,37 +250,37 @@ with gr.Blocks(css="style.css") as demo:
192
  guidance_scale,
193
  ]
194
  prompt.submit(
195
- fn=model.get_token_table,
196
  inputs=prompt,
197
  outputs=token_indices_table,
198
  queue=False,
199
  api_name=False,
200
  ).then(
201
- fn=model.run,
202
  inputs=inputs,
203
  outputs=result,
204
  api_name=False,
205
  )
206
  token_indices_str.submit(
207
- fn=model.get_token_table,
208
  inputs=prompt,
209
  outputs=token_indices_table,
210
  queue=False,
211
  api_name=False,
212
  ).then(
213
- fn=model.run,
214
  inputs=inputs,
215
  outputs=result,
216
  api_name=False,
217
  )
218
  run_button.click(
219
- fn=model.get_token_table,
220
  inputs=prompt,
221
  outputs=token_indices_table,
222
  queue=False,
223
  api_name=False,
224
  ).then(
225
- fn=model.run,
226
  inputs=inputs,
227
  outputs=result,
228
  api_name="run",
 
6
 
7
  import gradio as gr
8
  import PIL.Image
9
+ import torch
10
+ from diffusers import StableDiffusionAttendAndExcitePipeline, StableDiffusionPipeline
11
 
12
  DESCRIPTION = """\
13
  # Attend-and-Excite
 
17
  Select a prompt and a set of indices matching the subjects you wish to strengthen (the `Check token indices` cell can help map between a word and its index).
18
  """
19
 
20
+ if not torch.cuda.is_available():
21
+ DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
22
+
23
+ if torch.cuda.is_available():
24
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
25
+ model_id = "CompVis/stable-diffusion-v1-4"
26
+ ax_pipe = StableDiffusionAttendAndExcitePipeline.from_pretrained(model_id)
27
+ ax_pipe.to(device)
28
+ sd_pipe = StableDiffusionPipeline.from_pretrained(model_id)
29
+ sd_pipe.to(device)
30
+
31
+
32
+ def get_token_table(prompt: str) -> list[tuple[int, str]]:
33
+ tokens = [ax_pipe.tokenizer.decode(t) for t in ax_pipe.tokenizer(prompt)["input_ids"]]
34
+ tokens = tokens[1:-1]
35
+ return list(enumerate(tokens, start=1))
36
+
37
+
38
+ def run(
39
+ prompt: str,
40
+ indices_to_alter_str: str,
41
+ seed: int = 0,
42
+ apply_attend_and_excite: bool = True,
43
+ num_steps: int = 50,
44
+ guidance_scale: float = 7.5,
45
+ scale_factor: int = 20,
46
+ thresholds: dict[int, float] = {
47
+ 10: 0.5,
48
+ 20: 0.8,
49
+ },
50
+ max_iter_to_alter: int = 25,
51
+ ) -> PIL.Image.Image:
52
+ generator = torch.Generator(device=device).manual_seed(seed)
53
+
54
+ if apply_attend_and_excite:
55
+ try:
56
+ token_indices = list(map(int, indices_to_alter_str.split(",")))
57
+ except Exception:
58
+ raise ValueError("Invalid token indices.")
59
+ out = ax_pipe(
60
+ prompt=prompt,
61
+ token_indices=token_indices,
62
+ guidance_scale=guidance_scale,
63
+ generator=generator,
64
+ num_inference_steps=num_steps,
65
+ max_iter_to_alter=max_iter_to_alter,
66
+ thresholds=thresholds,
67
+ scale_factor=scale_factor,
68
+ )
69
+ else:
70
+ out = sd_pipe(
71
+ prompt=prompt,
72
+ guidance_scale=guidance_scale,
73
+ generator=generator,
74
+ num_inference_steps=num_steps,
75
+ )
76
+ return out.images[0]
77
 
78
 
79
  def process_example(
 
82
  seed: int,
83
  apply_attend_and_excite: bool,
84
  ) -> tuple[list[tuple[int, str]], PIL.Image.Image]:
85
+ token_table = get_token_table(prompt)
86
+ result = run(
87
+ prompt=prompt,
88
+ indices_to_alter_str=indices_to_alter_str,
89
+ seed=seed,
90
+ apply_attend_and_excite=apply_attend_and_excite,
91
+ )
92
  return token_table, result
93
 
94
 
 
234
  )
235
 
236
  show_token_indices_button.click(
237
+ fn=get_token_table,
238
  inputs=prompt,
239
  outputs=token_indices_table,
240
  queue=False,
241
+ api_name="get-token-table",
242
  )
243
 
244
  inputs = [
 
250
  guidance_scale,
251
  ]
252
  prompt.submit(
253
+ fn=get_token_table,
254
  inputs=prompt,
255
  outputs=token_indices_table,
256
  queue=False,
257
  api_name=False,
258
  ).then(
259
+ fn=run,
260
  inputs=inputs,
261
  outputs=result,
262
  api_name=False,
263
  )
264
  token_indices_str.submit(
265
+ fn=get_token_table,
266
  inputs=prompt,
267
  outputs=token_indices_table,
268
  queue=False,
269
  api_name=False,
270
  ).then(
271
+ fn=run,
272
  inputs=inputs,
273
  outputs=result,
274
  api_name=False,
275
  )
276
  run_button.click(
277
+ fn=get_token_table,
278
  inputs=prompt,
279
  outputs=token_indices_table,
280
  queue=False,
281
  api_name=False,
282
  ).then(
283
+ fn=run,
284
  inputs=inputs,
285
  outputs=result,
286
  api_name="run",
model.py DELETED
@@ -1,61 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import PIL.Image
4
- import torch
5
- from diffusers import StableDiffusionAttendAndExcitePipeline, StableDiffusionPipeline
6
-
7
-
8
- class Model:
9
- def __init__(self):
10
- self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
11
- model_id = "CompVis/stable-diffusion-v1-4"
12
- self.ax_pipe = StableDiffusionAttendAndExcitePipeline.from_pretrained(model_id)
13
- self.ax_pipe.to(self.device)
14
- self.sd_pipe = StableDiffusionPipeline.from_pretrained(model_id)
15
- self.sd_pipe.to(self.device)
16
-
17
- def get_token_table(self, prompt: str):
18
- tokens = [self.ax_pipe.tokenizer.decode(t) for t in self.ax_pipe.tokenizer(prompt)["input_ids"]]
19
- tokens = tokens[1:-1]
20
- return list(enumerate(tokens, start=1))
21
-
22
- def run(
23
- self,
24
- prompt: str,
25
- indices_to_alter_str: str,
26
- seed: int = 0,
27
- apply_attend_and_excite: bool = True,
28
- num_steps: int = 50,
29
- guidance_scale: float = 7.5,
30
- scale_factor: int = 20,
31
- thresholds: dict[int, float] = {
32
- 10: 0.5,
33
- 20: 0.8,
34
- },
35
- max_iter_to_alter: int = 25,
36
- ) -> PIL.Image.Image:
37
- generator = torch.Generator(device=self.device).manual_seed(seed)
38
-
39
- if apply_attend_and_excite:
40
- try:
41
- token_indices = list(map(int, indices_to_alter_str.split(",")))
42
- except Exception:
43
- raise ValueError("Invalid token indices.")
44
- out = self.ax_pipe(
45
- prompt=prompt,
46
- token_indices=token_indices,
47
- guidance_scale=guidance_scale,
48
- generator=generator,
49
- num_inference_steps=num_steps,
50
- max_iter_to_alter=max_iter_to_alter,
51
- thresholds=thresholds,
52
- scale_factor=scale_factor,
53
- )
54
- else:
55
- out = self.sd_pipe(
56
- prompt=prompt,
57
- guidance_scale=guidance_scale,
58
- generator=generator,
59
- num_inference_steps=num_steps,
60
- )
61
- return out.images[0]