Royir commited on
Commit
e47c7c5
·
1 Parent(s): 334d5a7

Upload 7 files

Browse files
Files changed (7) hide show
  1. README.md +25 -13
  2. app.py +37 -0
  3. compute_loss.py +211 -0
  4. environment.yaml +15 -0
  5. requirements.txt +7 -0
  6. run.py +66 -0
  7. syngen_diffusion_pipeline.py +495 -0
README.md CHANGED
@@ -1,13 +1,25 @@
1
- ---
2
- title: SynGen
3
- emoji: 🦀
4
- colorFrom: gray
5
- colorTo: purple
6
- sdk: gradio
7
- sdk_version: 3.34.0
8
- app_file: app.py
9
- pinned: false
10
- license: cc-by-sa-4.0
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Linguistic Binding in Diffusion Models: Enhancing Attribute Correspondence through Attention Map Alignment
2
+
3
+ ## Setup
4
+ Clone this repository and create a conda environment:
5
+ ```
6
+ conda env create -f environment.yaml
7
+ conda activate syngen
8
+ ```
9
+
10
+ If you rather use an existing environment, just run:
11
+ ```
12
+ pip install -r requirements.txt
13
+ ```
14
+
15
+ Finally, run:
16
+ ```
17
+ python -m spacy download en_core_web_trf
18
+ ```
19
+
20
+ ## Inference
21
+ ```
22
+ python run.py --prompt "a horned lion and a spotted monkey" --seed 1269
23
+ ```
24
+
25
+ Note that this will download the stable diffusion model `CompVis/stable-diffusion-v1-4`. If you rather use an existing copy of the model, provide the absolute path using `--model_path`.
app.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+
4
+ from syngen_diffusion_pipeline import SynGenDiffusionPipeline
5
+
6
+ model_path = 'CompVis/stable-diffusion-v1-4'
7
+ device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
8
+ pipe = SynGenDiffusionPipeline.from_pretrained(model_path).to(device)
9
+
10
+
11
+ def generate_fn(prompt, seed):
12
+ generator = torch.Generator(device.type).manual_seed(int(seed))
13
+ result = pipe(prompt=prompt, generator=generator, num_inference_steps=50)
14
+ return result['images'][0]
15
+
16
+ title = "SynGen"
17
+ description = """
18
+ This is the demo for [SynGen](https://github.com/RoyiRa/Syntax-Guided-Generation), an image synthesis approach which first syntactically analyses the prompt to identify entities and their modifiers, and then uses a novel loss function that encourages the cross-attention maps to agree with the linguistic binding reflected by the syntax. Preprint: \"Linguistic Binding in Diffusion Models: Enhancing Attribute Correspondence through Attention Map Alignment\" (arxiv link coming soon).
19
+ """
20
+
21
+ examples = [
22
+ ["a yellow flamingo and a pink sunflower", "16"],
23
+ ["a yellow flamingo and a pink sunflower", "60"],
24
+ ["a checkered bowl in a cluttered room", "69"],
25
+ ["a checkered bowl in a cluttered room", "77"],
26
+ ["a horned lion and a spotted monkey", "1269"],
27
+ ["a horned lion and a spotted monkey", "9146"]
28
+ ]
29
+
30
+ prompt_textbox = gr.Textbox(label="Prompt", placeholder="A yellow flamingo and a pink sunflower", lines=1)
31
+ seed_textbox = gr.Textbox(label="Seed", placeholder="42", lines=1)
32
+
33
+ output = gr.Image(label="generation")
34
+ demo = gr.Interface(fn=generate_fn, inputs=[prompt_textbox, seed_textbox], outputs=output, examples=examples,
35
+ title=title, description=description, allow_flagging=False)
36
+
37
+ demo.launch()
compute_loss.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.distributions as dist
2
+ from typing import List, Dict
3
+ import itertools
4
+
5
+ start_token = "<|startoftext|>"
6
+ end_token = "<|endoftext|>"
7
+
8
+
9
+ def _get_outside_indices(subtree_indices, attn_map_idx_to_wp):
10
+ flattened_subtree_indices = _flatten_indices(subtree_indices)
11
+ outside_indices = [
12
+ map_idx
13
+ for map_idx in attn_map_idx_to_wp.keys() if (map_idx not in flattened_subtree_indices)
14
+ ]
15
+ return outside_indices
16
+
17
+
18
+ def _flatten_indices(related_indices):
19
+ flattened_related_indices = []
20
+ for item in related_indices:
21
+ if isinstance(item, list):
22
+ flattened_related_indices.extend(item)
23
+ else:
24
+ flattened_related_indices.append(item)
25
+ return flattened_related_indices
26
+
27
+
28
+ def split_indices(related_indices: List[int]):
29
+ noun = [related_indices[-1]] # assumes noun is always last in the list
30
+ modifier = related_indices[:-1]
31
+ if isinstance(modifier, int):
32
+ modifier = [modifier]
33
+ return noun, modifier
34
+
35
+
36
+ def _symmetric_kl(attention_map1, attention_map2):
37
+ # Convert map into a single distribution: 16x16 -> 256
38
+ if len(attention_map1.shape) > 1:
39
+ attention_map1 = attention_map1.reshape(-1)
40
+ if len(attention_map2.shape) > 1:
41
+ attention_map2 = attention_map2.reshape(-1)
42
+
43
+ p = dist.Categorical(probs=attention_map1)
44
+ q = dist.Categorical(probs=attention_map2)
45
+
46
+ kl_divergence_pq = dist.kl_divergence(p, q)
47
+ kl_divergence_qp = dist.kl_divergence(q, p)
48
+
49
+ avg_kl_divergence = (kl_divergence_pq + kl_divergence_qp) / 2
50
+ return avg_kl_divergence
51
+
52
+
53
+ def calculate_positive_loss(attention_maps, modifier, noun):
54
+ src_indices = modifier
55
+ dest_indices = noun
56
+
57
+ if isinstance(src_indices, list) and isinstance(dest_indices, list):
58
+ wp_pos_loss = [
59
+ _symmetric_kl(attention_maps[s], attention_maps[d])
60
+ for (s, d) in itertools.product(src_indices, dest_indices)
61
+ ]
62
+ positive_loss = max(wp_pos_loss)
63
+ elif isinstance(dest_indices, list):
64
+ wp_pos_loss = [
65
+ _symmetric_kl(attention_maps[src_indices], attention_maps[d])
66
+ for d in dest_indices
67
+ ]
68
+ positive_loss = max(wp_pos_loss)
69
+ elif isinstance(src_indices, list):
70
+ wp_pos_loss = [
71
+ _symmetric_kl(attention_maps[s], attention_maps[dest_indices])
72
+ for s in src_indices
73
+ ]
74
+ positive_loss = max(wp_pos_loss)
75
+ else:
76
+ positive_loss = _symmetric_kl(
77
+ attention_maps[src_indices], attention_maps[dest_indices]
78
+ )
79
+
80
+ return positive_loss
81
+
82
+
83
+ def _calculate_outside_loss(attention_maps, src_indices, outside_loss):
84
+ negative_loss = []
85
+ computed_pairs = set()
86
+ pair_counter = 0
87
+
88
+ for outside_idx in outside_loss:
89
+ if isinstance(src_indices, list):
90
+ wp_neg_loss = []
91
+ for t in src_indices:
92
+ pair_key = (t, outside_idx)
93
+ if pair_key not in computed_pairs:
94
+ wp_neg_loss.append(
95
+ _symmetric_kl(
96
+ attention_maps[t], attention_maps[outside_idx]
97
+ )
98
+ )
99
+ computed_pairs.add(pair_key)
100
+ negative_loss.append(max(wp_neg_loss) if wp_neg_loss else 0)
101
+ pair_counter += 1
102
+
103
+ else:
104
+ pair_key = (src_indices, outside_idx)
105
+ if pair_key not in computed_pairs:
106
+ negative_loss.append(
107
+ _symmetric_kl(
108
+ attention_maps[src_indices], attention_maps[outside_idx]
109
+ )
110
+ )
111
+ computed_pairs.add(pair_key)
112
+ pair_counter += 1
113
+
114
+ return negative_loss, pair_counter
115
+
116
+
117
+ def align_wordpieces_indices(
118
+ wordpieces2indices, start_idx, target_word
119
+ ):
120
+ """
121
+ Aligns a `target_word` that contains more than one wordpiece (the first wordpiece is `start_idx`)
122
+ """
123
+
124
+ wp_indices = [start_idx]
125
+ wp = wordpieces2indices[start_idx].replace("</w>", "")
126
+
127
+ # Run over the next wordpieces in the sequence (which is why we use +1)
128
+ for wp_idx in range(start_idx + 1, len(wordpieces2indices)):
129
+ if wp == target_word:
130
+ break
131
+
132
+ wp2 = wordpieces2indices[wp_idx].replace("</w>", "")
133
+ if target_word.startswith(wp + wp2) and wp2 != target_word:
134
+ wp += wordpieces2indices[wp_idx].replace("</w>", "")
135
+ wp_indices.append(wp_idx)
136
+ else:
137
+ wp_indices = (
138
+ []
139
+ ) # if there's no match, you want to clear the list and finish
140
+ break
141
+
142
+ return wp_indices
143
+
144
+
145
+ def extract_attribution_indices(prompt, parser):
146
+ doc = parser(prompt)
147
+ subtrees = []
148
+ modifiers = ["amod", "nmod", "compound", "npadvmod", "advmod", "acomp"]
149
+
150
+ for w in doc:
151
+ if w.pos_ not in ["NOUN", "PROPN"] or w.dep_ in modifiers:
152
+ continue
153
+ subtree = []
154
+ stack = []
155
+ for child in w.children:
156
+ if child.dep_ in modifiers:
157
+ subtree.append(child)
158
+ stack.extend(child.children)
159
+
160
+ while stack:
161
+ node = stack.pop()
162
+ if node.dep_ in modifiers or node.dep_ == "conj":
163
+ subtree.append(node)
164
+ stack.extend(node.children)
165
+ if subtree:
166
+ subtree.append(w)
167
+ subtrees.append(subtree)
168
+ return subtrees
169
+
170
+
171
+ def calculate_negative_loss(
172
+ attention_maps, modifier, noun, subtree_indices, attn_map_idx_to_wp
173
+ ):
174
+ outside_indices = _get_outside_indices(subtree_indices, attn_map_idx_to_wp)
175
+ negative_modifier_loss, num_modifier_pairs = _calculate_outside_loss(
176
+ attention_maps, modifier, outside_indices
177
+ )
178
+ negative_noun_loss, num_noun_pairs = _calculate_outside_loss(
179
+ attention_maps, noun, outside_indices
180
+ )
181
+
182
+ negative_modifier_loss = -sum(negative_modifier_loss) / len(outside_indices)
183
+ negative_noun_loss = -sum(negative_noun_loss) / len(outside_indices)
184
+
185
+ negative_loss = (negative_modifier_loss + negative_noun_loss) / 2
186
+
187
+ return negative_loss
188
+
189
+ def get_indices(tokenizer, prompt: str) -> Dict[str, int]:
190
+ """Utility function to list the indices of the tokens you wish to alte"""
191
+ ids = tokenizer(prompt).input_ids
192
+ indices = {
193
+ i: tok
194
+ for tok, i in zip(
195
+ tokenizer.convert_ids_to_tokens(ids), range(len(ids))
196
+ )
197
+ }
198
+ return indices
199
+
200
+ def get_attention_map_index_to_wordpiece(tokenizer, prompt):
201
+ attn_map_idx_to_wp = {}
202
+
203
+ wordpieces2indices = get_indices(tokenizer, prompt)
204
+
205
+ # Ignore `start_token` and `end_token`
206
+ for i in list(wordpieces2indices.keys())[1:-1]:
207
+ wordpiece = wordpieces2indices[i]
208
+ wordpiece = wordpiece.replace("</w>", "")
209
+ attn_map_idx_to_wp[i] = wordpiece
210
+
211
+ return attn_map_idx_to_wp
environment.yaml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: syngen
2
+ channels:
3
+ - pytorch
4
+ - defaults
5
+ dependencies:
6
+ - python=3.11.2
7
+ - pip=23.0.1
8
+ - pip:
9
+ - diffusers==0.14.0
10
+ - numpy==1.23.3
11
+ - spacy==3.5.2
12
+ - tqdm==4.65.0
13
+ - transformers @ git+https://github.com/huggingface/transformers.git@dbc12269ed5546b2da9236b9f1078b95b6a4d3d5
14
+ - torch==2.0.0
15
+ - accelerate==0.18.0
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ diffusers==0.14.0
2
+ numpy==1.23.3
3
+ spacy==3.5.2
4
+ tqdm==4.65.0
5
+ transformers @ git+https://github.com/huggingface/transformers.git@dbc12269ed5546b2da9236b9f1078b95b6a4d3d5
6
+ torch==2.0.0
7
+ accelerate==0.18.0
run.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+
4
+ import torch
5
+ from syngen_diffusion_pipeline import SynGenDiffusionPipeline
6
+
7
+
8
+ def main(prompt, seed, output_directory, model_path):
9
+ pipe = load_model(model_path)
10
+ image = generate(pipe, prompt, seed)
11
+ save_image(image, prompt, seed, output_directory)
12
+
13
+
14
+ def load_model(model_path):
15
+ device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
16
+ pipe = SynGenDiffusionPipeline.from_pretrained(model_path).to(device)
17
+
18
+ return pipe
19
+
20
+
21
+ def generate(pipe, prompt, seed):
22
+ device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
23
+ generator = torch.Generator(device.type).manual_seed(seed)
24
+ result = pipe(prompt=prompt, generator=generator)
25
+ return result['images'][0]
26
+
27
+
28
+ def save_image(image, prompt, seed, output_directory):
29
+ if not os.path.exists(output_directory):
30
+ os.makedirs(output_directory)
31
+
32
+ file_name = f"{output_directory}/{prompt}_{seed}.png"
33
+ image.save(file_name)
34
+
35
+
36
+ if __name__ == "__main__":
37
+ parser = argparse.ArgumentParser()
38
+
39
+ parser.add_argument(
40
+ "--prompt",
41
+ type=str,
42
+ default="a checkered bowl on a red and blue table"
43
+ )
44
+
45
+ parser.add_argument(
46
+ '--seed',
47
+ type=int,
48
+ default=1924
49
+ )
50
+
51
+ parser.add_argument(
52
+ '--output_directory',
53
+ type=str,
54
+ default='./output'
55
+ )
56
+
57
+ parser.add_argument(
58
+ '--model_path',
59
+ type=str,
60
+ default='CompVis/stable-diffusion-v1-4',
61
+ help='The path to the model (this will download the model if the path doesn\'t exist)'
62
+ )
63
+
64
+ args = parser.parse_args()
65
+
66
+ main(args.prompt, args.seed, args.output_directory, args.model_path)
syngen_diffusion_pipeline.py ADDED
@@ -0,0 +1,495 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ from typing import Any, Callable, Dict, Optional, Union, List
3
+
4
+ import spacy
5
+ import torch
6
+ from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel
7
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
8
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
9
+ EXAMPLE_DOC_STRING,
10
+ )
11
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_attend_and_excite import (
12
+ AttentionStore,
13
+ AttendExciteCrossAttnProcessor,
14
+ )
15
+ from diffusers.schedulers import KarrasDiffusionSchedulers
16
+ from diffusers.utils import (
17
+ logging,
18
+ replace_example_docstring,
19
+ )
20
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor
21
+
22
+ from compute_loss import get_attention_map_index_to_wordpiece, split_indices, calculate_positive_loss, calculate_negative_loss, get_indices, start_token, end_token, \
23
+ align_wordpieces_indices, extract_attribution_indices
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+
28
+ class SynGenDiffusionPipeline(StableDiffusionPipeline):
29
+ def __init__(self,
30
+ vae: AutoencoderKL,
31
+ text_encoder: CLIPTextModel,
32
+ tokenizer: CLIPTokenizer,
33
+ unet: UNet2DConditionModel,
34
+ scheduler: KarrasDiffusionSchedulers,
35
+ safety_checker: StableDiffusionSafetyChecker,
36
+ feature_extractor: CLIPFeatureExtractor,
37
+ requires_safety_checker: bool = True,
38
+ ):
39
+ super().__init__(vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor,
40
+ requires_safety_checker)
41
+
42
+ self.parser = spacy.load("en_core_web_trf")
43
+
44
+ def _aggregate_and_get_attention_maps_per_token(self):
45
+ attention_maps = self.attention_store.aggregate_attention(
46
+ from_where=("up", "down", "mid"),
47
+ )
48
+ attention_maps_list = _get_attention_maps_list(
49
+ attention_maps=attention_maps
50
+ )
51
+ return attention_maps_list
52
+
53
+ @staticmethod
54
+ def _update_latent(
55
+ latents: torch.Tensor, loss: torch.Tensor, step_size: float
56
+ ) -> torch.Tensor:
57
+ """Update the latent according to the computed loss."""
58
+ grad_cond = torch.autograd.grad(
59
+ loss.requires_grad_(True), [latents], retain_graph=True
60
+ )[0]
61
+ latents = latents - step_size * grad_cond
62
+ return latents
63
+
64
+ def register_attention_control(self):
65
+ attn_procs = {}
66
+ cross_att_count = 0
67
+ for name in self.unet.attn_processors.keys():
68
+ if name.startswith("mid_block"):
69
+ place_in_unet = "mid"
70
+ elif name.startswith("up_blocks"):
71
+ place_in_unet = "up"
72
+ elif name.startswith("down_blocks"):
73
+ place_in_unet = "down"
74
+ else:
75
+ continue
76
+
77
+ cross_att_count += 1
78
+ attn_procs[name] = AttendExciteCrossAttnProcessor(
79
+ attnstore=self.attention_store, place_in_unet=place_in_unet
80
+ )
81
+
82
+ self.unet.set_attn_processor(attn_procs)
83
+ self.attention_store.num_att_layers = cross_att_count
84
+
85
+ # Based on StableDiffusionPipeline.__call__ . New code is annotated with NEW.
86
+ @torch.no_grad()
87
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
88
+ def __call__(
89
+ self,
90
+ prompt: Union[str, List[str]] = None,
91
+ height: Optional[int] = None,
92
+ width: Optional[int] = None,
93
+ num_inference_steps: int = 50,
94
+ guidance_scale: float = 7.5,
95
+ negative_prompt: Optional[Union[str, List[str]]] = None,
96
+ num_images_per_prompt: Optional[int] = 1,
97
+ eta: float = 0.0,
98
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
99
+ latents: Optional[torch.FloatTensor] = None,
100
+ prompt_embeds: Optional[torch.FloatTensor] = None,
101
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
102
+ output_type: Optional[str] = "pil",
103
+ return_dict: bool = True,
104
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
105
+ callback_steps: int = 1,
106
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
107
+ syngen_step_size: int = 20,
108
+ ):
109
+ r"""
110
+ Function invoked when calling the pipeline for generation.
111
+
112
+ Args:
113
+ prompt (`str` or `List[str]`, *optional*):
114
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
115
+ instead.
116
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
117
+ The height in pixels of the generated image.
118
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
119
+ The width in pixels of the generated image.
120
+ num_inference_steps (`int`, *optional*, defaults to 50):
121
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
122
+ expense of slower inference.
123
+ guidance_scale (`float`, *optional*, defaults to 7.5):
124
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
125
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
126
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
127
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
128
+ usually at the expense of lower image quality.
129
+ negative_prompt (`str` or `List[str]`, *optional*):
130
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
131
+ `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
132
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
133
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
134
+ The number of images to generate per prompt.
135
+ eta (`float`, *optional*, defaults to 0.0):
136
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
137
+ [`schedulers.DDIMScheduler`], will be ignored for others.
138
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
139
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
140
+ to make generation deterministic.
141
+ latents (`torch.FloatTensor`, *optional*):
142
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
143
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
144
+ tensor will ge generated by sampling using the supplied random `generator`.
145
+ prompt_embeds (`torch.FloatTensor`, *optional*):
146
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
147
+ provided, text embeddings will be generated from `prompt` input argument.
148
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
149
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
150
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
151
+ argument.
152
+ output_type (`str`, *optional*, defaults to `"pil"`):
153
+ The output format of the generate image. Choose between
154
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
155
+ return_dict (`bool`, *optional*, defaults to `True`):
156
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
157
+ plain tuple.
158
+ callback (`Callable`, *optional*):
159
+ A function that will be called every `callback_steps` steps during inference. The function will be
160
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
161
+ callback_steps (`int`, *optional*, defaults to 1):
162
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
163
+ called at every step.
164
+ cross_attention_kwargs (`dict`, *optional*):
165
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
166
+ `self.processor` in
167
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
168
+ syngen_step_size (`int`, *optional*, default to 20):
169
+ Controls the step size of each SynGen update.
170
+
171
+ Examples:
172
+
173
+ Returns:
174
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
175
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
176
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
177
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
178
+ (nsfw) content, according to the `safety_checker`.
179
+ """
180
+ # 0. Default height and width to unet
181
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
182
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
183
+
184
+ # 1. Check inputs. Raise error if not correct
185
+ self.check_inputs(
186
+ prompt,
187
+ height,
188
+ width,
189
+ callback_steps,
190
+ negative_prompt,
191
+ prompt_embeds,
192
+ negative_prompt_embeds,
193
+ )
194
+
195
+ # 2. Define call parameters
196
+ if prompt is not None and isinstance(prompt, str):
197
+ batch_size = 1
198
+ elif prompt is not None and isinstance(prompt, list):
199
+ batch_size = len(prompt)
200
+ else:
201
+ batch_size = prompt_embeds.shape[0]
202
+
203
+ device = self._execution_device
204
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
205
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
206
+ # corresponds to doing no classifier free guidance.
207
+ do_classifier_free_guidance = guidance_scale > 1.0
208
+
209
+ # 3. Encode input prompt
210
+ prompt_embeds = self._encode_prompt(
211
+ prompt,
212
+ device,
213
+ num_images_per_prompt,
214
+ do_classifier_free_guidance,
215
+ negative_prompt,
216
+ prompt_embeds=prompt_embeds,
217
+ negative_prompt_embeds=negative_prompt_embeds,
218
+ )
219
+
220
+ # 4. Prepare timesteps
221
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
222
+ timesteps = self.scheduler.timesteps
223
+
224
+ # 5. Prepare latent variables
225
+ num_channels_latents = self.unet.in_channels
226
+ latents = self.prepare_latents(
227
+ batch_size * num_images_per_prompt,
228
+ num_channels_latents,
229
+ height,
230
+ width,
231
+ prompt_embeds.dtype,
232
+ device,
233
+ generator,
234
+ latents,
235
+ )
236
+
237
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
238
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
239
+
240
+ # NEW - stores the attention calculated in the unet
241
+ self.attention_store = AttentionStore()
242
+ self.register_attention_control()
243
+
244
+ # NEW
245
+ text_embeddings = (
246
+ prompt_embeds[batch_size * num_images_per_prompt:] if do_classifier_free_guidance else prompt_embeds
247
+ )
248
+
249
+ # 7. Denoising loop
250
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
251
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
252
+ for i, t in enumerate(timesteps):
253
+ # NEW
254
+ latents = self._syngen_step(
255
+ latents,
256
+ text_embeddings,
257
+ t,
258
+ i,
259
+ syngen_step_size,
260
+ cross_attention_kwargs,
261
+ prompt,
262
+ max_iter_to_alter=25,
263
+ )
264
+
265
+ # expand the latents if we are doing classifier free guidance
266
+ latent_model_input = (
267
+ torch.cat([latents] * 2) if do_classifier_free_guidance else latents
268
+ )
269
+ latent_model_input = self.scheduler.scale_model_input(
270
+ latent_model_input, t
271
+ )
272
+
273
+ # predict the noise residual
274
+ noise_pred = self.unet(
275
+ latent_model_input,
276
+ t,
277
+ encoder_hidden_states=prompt_embeds,
278
+ cross_attention_kwargs=cross_attention_kwargs,
279
+ ).sample
280
+
281
+ # perform guidance
282
+ if do_classifier_free_guidance:
283
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
284
+ noise_pred = noise_pred_uncond + guidance_scale * (
285
+ noise_pred_text - noise_pred_uncond
286
+ )
287
+
288
+ # compute the previous noisy sample x_t -> x_t-1
289
+ latents = self.scheduler.step(
290
+ noise_pred, t, latents, **extra_step_kwargs
291
+ ).prev_sample
292
+
293
+ # call the callback, if provided
294
+ if i == len(timesteps) - 1 or (
295
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
296
+ ):
297
+ progress_bar.update()
298
+ if callback is not None and i % callback_steps == 0:
299
+ callback(i, t, latents)
300
+
301
+ if output_type == "latent":
302
+ image = latents
303
+ has_nsfw_concept = None
304
+ elif output_type == "pil":
305
+ # 8. Post-processing
306
+ image = self.decode_latents(latents)
307
+
308
+ # 9. Run safety checker
309
+ image, has_nsfw_concept = self.run_safety_checker(
310
+ image, device, prompt_embeds.dtype
311
+ )
312
+
313
+ # 10. Convert to PIL
314
+ image = self.numpy_to_pil(image)
315
+ else:
316
+ # 8. Post-processing
317
+ image = self.decode_latents(latents)
318
+
319
+ # 9. Run safety checker
320
+ image, has_nsfw_concept = self.run_safety_checker(
321
+ image, device, prompt_embeds.dtype
322
+ )
323
+
324
+ # Offload last model to CPU
325
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
326
+ self.final_offload_hook.offload()
327
+
328
+ if not return_dict:
329
+ return (image, has_nsfw_concept)
330
+
331
+ return StableDiffusionPipelineOutput(
332
+ images=image, nsfw_content_detected=has_nsfw_concept
333
+ )
334
+
335
+ def _syngen_step(
336
+ self,
337
+ latents,
338
+ text_embeddings,
339
+ t,
340
+ i,
341
+ step_size,
342
+ cross_attention_kwargs,
343
+ prompt,
344
+ max_iter_to_alter=25,
345
+ ):
346
+ with torch.enable_grad():
347
+ latents = latents.clone().detach().requires_grad_(True)
348
+ updated_latents = []
349
+ for latent, text_embedding in zip(latents, text_embeddings):
350
+ # Forward pass of denoising with text conditioning
351
+ latent = latent.unsqueeze(0)
352
+ text_embedding = text_embedding.unsqueeze(0)
353
+
354
+ self.unet(
355
+ latent,
356
+ t,
357
+ encoder_hidden_states=text_embedding,
358
+ cross_attention_kwargs=cross_attention_kwargs,
359
+ ).sample
360
+ self.unet.zero_grad()
361
+
362
+ # Get attention maps
363
+ attention_maps = self._aggregate_and_get_attention_maps_per_token()
364
+
365
+ loss = self._compute_loss(attention_maps=attention_maps, prompt=prompt)
366
+
367
+ # Perform gradient update
368
+ if i < max_iter_to_alter:
369
+ if loss != 0:
370
+ latent = self._update_latent(
371
+ latents=latent, loss=loss, step_size=step_size
372
+ )
373
+ logger.info(f"Iteration {i} | Loss: {loss:0.4f}")
374
+
375
+ updated_latents.append(latent)
376
+
377
+ latents = torch.cat(updated_latents, dim=0)
378
+
379
+ return latents
380
+
381
+ def _compute_loss(
382
+ self, attention_maps: List[torch.Tensor], prompt: Union[str, List[str]]
383
+ ) -> torch.Tensor:
384
+ attn_map_idx_to_wp = get_attention_map_index_to_wordpiece(self.tokenizer, prompt)
385
+ loss = self._attribution_loss(attention_maps, prompt, attn_map_idx_to_wp)
386
+
387
+ return loss
388
+
389
+
390
+ def _attribution_loss(
391
+ self,
392
+ attention_maps: List[torch.Tensor],
393
+ prompt: Union[str, List[str]],
394
+ attn_map_idx_to_wp,
395
+ ) -> torch.Tensor:
396
+ subtrees_indices = self._extract_attribution_indices(prompt)
397
+ loss = 0
398
+
399
+ for subtree_indices in subtrees_indices:
400
+ noun, modifier = split_indices(subtree_indices)
401
+ all_subtree_pairs = list(itertools.product(noun, modifier))
402
+ positive_loss, negative_loss = self._calculate_losses(
403
+ attention_maps,
404
+ all_subtree_pairs,
405
+ subtree_indices,
406
+ attn_map_idx_to_wp,
407
+ )
408
+ loss += positive_loss
409
+ loss += negative_loss
410
+
411
+ return loss
412
+
413
+ def _calculate_losses(
414
+ self,
415
+ attention_maps,
416
+ all_subtree_pairs,
417
+ subtree_indices,
418
+ attn_map_idx_to_wp,
419
+ ):
420
+ positive_loss = []
421
+ negative_loss = []
422
+ for pair in all_subtree_pairs:
423
+ noun, modifier = pair
424
+ positive_loss.append(
425
+ calculate_positive_loss(attention_maps, modifier, noun)
426
+ )
427
+ negative_loss.append(
428
+ calculate_negative_loss(
429
+ attention_maps, modifier, noun, subtree_indices, attn_map_idx_to_wp
430
+ )
431
+ )
432
+
433
+ positive_loss = sum(positive_loss)
434
+ negative_loss = sum(negative_loss)
435
+
436
+ return positive_loss, negative_loss
437
+
438
+ def _align_indices(self, prompt, spacy_pairs):
439
+ wordpieces2indices = get_indices(self.tokenizer, prompt)
440
+ paired_indices = []
441
+ collected_spacy_indices = (
442
+ set()
443
+ ) # helps track recurring nouns across different relations (i.e., cases where there is more than one instance of the same word)
444
+
445
+ for pair in spacy_pairs:
446
+ curr_collected_wp_indices = (
447
+ []
448
+ ) # helps track which nouns and amods were added to the current pair (this is useful in sentences with repeating amod on the same relation (e.g., "a red red red bear"))
449
+ for member in pair:
450
+ for idx, wp in wordpieces2indices.items():
451
+ if wp in [start_token, end_token]:
452
+ continue
453
+
454
+ wp = wp.replace("</w>", "")
455
+ if member.text == wp:
456
+ if idx not in curr_collected_wp_indices and idx not in collected_spacy_indices:
457
+ curr_collected_wp_indices.append(idx)
458
+ break
459
+ # take care of wordpieces that are split up
460
+ elif member.text.startswith(wp) and wp != member.text: # can maybe be while loop
461
+ wp_indices = align_wordpieces_indices(
462
+ wordpieces2indices, idx, member.text
463
+ )
464
+ # check if all wp_indices are not already in collected_spacy_indices
465
+ if wp_indices and (wp_indices not in curr_collected_wp_indices) and all([wp_idx not in collected_spacy_indices for wp_idx in wp_indices]):
466
+ curr_collected_wp_indices.append(wp_indices)
467
+ break
468
+
469
+ for collected_idx in curr_collected_wp_indices:
470
+ if isinstance(collected_idx, list):
471
+ for idx in collected_idx:
472
+ collected_spacy_indices.add(idx)
473
+ else:
474
+ collected_spacy_indices.add(collected_idx)
475
+
476
+ paired_indices.append(curr_collected_wp_indices)
477
+
478
+ return paired_indices
479
+
480
+ def _extract_attribution_indices(self, prompt):
481
+ pairs = extract_attribution_indices(prompt, self.parser)
482
+ paired_indices = self._align_indices(prompt, pairs)
483
+ return paired_indices
484
+
485
+
486
+
487
+ def _get_attention_maps_list(
488
+ attention_maps: torch.Tensor
489
+ ) -> List[torch.Tensor]:
490
+ attention_maps *= 100
491
+ attention_maps_list = [
492
+ attention_maps[:, :, i] for i in range(attention_maps.shape[2])
493
+ ]
494
+
495
+ return attention_maps_list