HuiZhang commited on
Commit
be186ed
Β·
verified Β·
1 Parent(s): e6bd7ff

Upload 8 files

Browse files
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
  title: CreatiLayout
3
- emoji: πŸƒ
4
- colorFrom: purple
5
- colorTo: purple
6
  sdk: gradio
7
  sdk_version: 5.9.1
8
  app_file: app.py
 
1
  ---
2
  title: CreatiLayout
3
+ emoji: πŸŒ–
4
+ colorFrom: pink
5
+ colorTo: gray
6
  sdk: gradio
7
  sdk_version: 5.9.1
8
  app_file: app.py
app.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import spaces
4
+ from src.models.transformer_sd3_SiamLayout import SiamLayoutSD3Transformer2DModel
5
+ from src.pipeline.pipeline_CreatiLayout import CreatiLayoutSD3Pipeline
6
+ from utils.bbox_visualization import bbox_visualization,scale_boxes
7
+ from PIL import Image
8
+ import os
9
+ import pandas as pd
10
+ from huggingface_hub import login
11
+
12
+ hf_token = os.getenv("HF_TOKEN")
13
+
14
+ if hf_token is None:
15
+ raise ValueError("Hugging Face token not found. Please set the HF_TOKEN secret.")
16
+
17
+ login(token=hf_token)
18
+
19
+ model_path = "stabilityai/stable-diffusion-3-medium-diffusers"
20
+ ckpt_path = "Benson1237/CreatiLayout"
21
+
22
+ transformer_additional_kwargs = dict(attention_type="layout",strict=True)
23
+
24
+ transformer = SiamLayoutSD3Transformer2DModel.from_pretrained(
25
+ ckpt_path, subfolder="transformer", torch_dtype=torch.float16,**transformer_additional_kwargs)
26
+
27
+ pipe = CreatiLayoutSD3Pipeline.from_pretrained(model_path, transformer=transformer, torch_dtype=torch.float16)
28
+ pipe = pipe.to("cuda")
29
+
30
+ print("pipeline is loaded.")
31
+
32
+ @spaces.GPU
33
+ def process_image_and_text(global_caption, box_detail_phrases_list:pd.DataFrame, boxes:pd.DataFrame,seed: int=42, randomize_seed: bool=False, guidance_scale: float=7.5, num_inference_steps: int=50):
34
+
35
+ if randomize_seed:
36
+ seed = torch.randint(0, 100, (1,)).item()
37
+
38
+ height = 1024
39
+ width = 1024
40
+
41
+ box_detail_phrases_list_tmp = box_detail_phrases_list.values.tolist()
42
+ box_detail_phrases_list_tmp = [c[0] for c in box_detail_phrases_list_tmp]
43
+ boxes = boxes.astype(float).values.tolist()
44
+
45
+ white_image = Image.new('RGB', (width, height), color='rgb(256,256,256)')
46
+ show_input = {"boxes":scale_boxes(boxes,width,height),"labels":box_detail_phrases_list_tmp}
47
+ bbox_visualization_img = bbox_visualization(white_image,show_input)
48
+
49
+ result_img = pipe(
50
+ prompt=global_caption,
51
+ generator=torch.Generator(device="cuda").manual_seed(seed),
52
+ guidance_scale=guidance_scale,
53
+ num_inference_steps=num_inference_steps,
54
+ bbox_phrases=box_detail_phrases_list_tmp,
55
+ bbox_raw=boxes,
56
+ height=height,
57
+ width=width
58
+ ).images[0]
59
+
60
+ return bbox_visualization_img, result_img
61
+
62
+ def get_samples():
63
+ sample_list = [
64
+ {
65
+ "global_caption": "A picturesque scene features Iron Man standing confidently on a rugged rock by the sea, holding a drawing board with his hands. The board displays the words 'Creative Layout' in a playful, hand-drawn font. The serene sea shimmers under the setting sun. The sky is painted with a gradient of warm colors, from deep oranges to soft purples.",
66
+ "region_caption_list": [
67
+ "Iron Man standing confidently on a rugged rock.",
68
+ "A rugged rock by the sea.",
69
+ "A drawing board with the words \"Creative Layout\" in a playful, hand-drawn font.",
70
+ "The serene sea shimmers under the setting sun.",
71
+ "The sky is a shade of deep orange to soft purple."
72
+ ],
73
+ "region_bboxes_list": [
74
+ [0.40, 0.35, 0.55, 0.80],
75
+ [0.35, 0.75, 0.60, 0.95],
76
+ [0.40, 0.45, 0.55, 0.65],
77
+ [0.00, 0.30, 1.00, 0.90],
78
+ [0.00, 0.00, 1.00, 0.30]
79
+ ]
80
+ },
81
+ {
82
+ "global_caption": "This is a photo showcasing two wooden benches in a park. The bench on the left is painted in a vibrant blue, while the one on the right is painted in a green. Both are placed on a path paved with stones, surrounded by lush trees and shrubs. The sunlight filters through the leaves, casting dappled shadows on the ground, creating a tranquil and comfortable atmosphere.",
83
+ "region_caption_list": [
84
+ "A weathered, blue wooden bench with green elements in a natural setting.",
85
+ "Old, weathered wooden benches with green and blue paint.",
86
+ "A dirt path in a park with green grass on the sides and two colorful wooden benches.",
87
+ "Thick, verdant foliage of mature trees in a dense forest."
88
+ ],
89
+ "region_bboxes_list": [
90
+ [0.30, 0.44, 0.62, 0.78],
91
+ [0.54, 0.41, 0.75, 0.65],
92
+ [0.00, 0.39, 1.00, 1.00],
93
+ [0.00, 0.00, 1.00, 0.43]
94
+ ]
95
+ },
96
+ {
97
+ "global_caption": "This is a wedding photo taken in a photography studio, showing a newlywed couple sitting on a brown leather sofa in a modern indoor setting. The groom is dressed in a pink suit, paired with a pink tie and white shirt, while the bride is wearing a white wedding dress with a long veil. They are sitting on a brown leather sofa, with a wooden table in front of them, on which a bouquet of flowers is placed. The background is a bar with a staircase and a wall decorated with lights, creating a warm and romantic atmosphere.",
98
+ "region_caption_list": [
99
+ "A floral arrangement consisting of roses, carnations, and eucalyptus leaves on a wooden surface.",
100
+ "A white wedding dress with off-the-shoulder ruffles and a long, sheer veil.",
101
+ "A polished wooden table with visible grain and knots.",
102
+ "A close-up of a dark brown leather sofa with tufted upholstery and button details.",
103
+ "A man in a pink suit with a white shirt and red tie, sitting on a leather armchair.",
104
+ "A person in a suit seated on a leather armchair near a wooden staircase with books and bottles.",
105
+ "Bride in white gown with veil, groom in maroon suit and pink tie, seated on leather armchairs."
106
+ ],
107
+ "region_bboxes_list": [
108
+ [0.09, 0.65, 0.31, 0.93],
109
+ [0.62, 0.25, 0.89, 0.90],
110
+ [0.01, 0.70, 0.78, 0.99],
111
+ [0.76, 0.65, 1.00, 0.99],
112
+ [0.27, 0.32, 0.72, 0.75],
113
+ [0.00, 0.01, 0.52, 0.72],
114
+ [0.27, 0.09, 0.94, 0.89]
115
+ ]
116
+ }
117
+
118
+ ]
119
+ return [[sample["global_caption"], [[caption] for caption in sample["region_caption_list"]], sample["region_bboxes_list"]] for sample in sample_list]
120
+
121
+
122
+
123
+ with gr.Blocks() as demo:
124
+ gr.Markdown("# CreatiLayout / Layout-to-Image generation")
125
+
126
+ with gr.Row():
127
+ with gr.Column():
128
+ global_caption = gr.Textbox(lines=2, label="Global Caption")
129
+ box_detail_phrases_list = gr.Dataframe(headers=["Region Captions"], label="Region Captions")
130
+ boxes = gr.Dataframe(headers=["x1", "y1", "x2", "y2"], label="Region Bounding Boxes (x_min,y_min,x_max,y_max)")
131
+ with gr.Accordion("Advanced Settings", open=False):
132
+ seed = gr.Slider(0, 100, step=1, label="Seed", value=42)
133
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
134
+ guidance_scale = gr.Slider(1, 30, step=0.5, label="Guidance Scale", value=7.5)
135
+ num_inference_steps = gr.Slider(1, 50, step=1, label="Number of inference steps", value=50)
136
+ with gr.Column():
137
+ bbox_visualization_img = gr.Image(type="pil", label="Bounding Box Visualization")
138
+
139
+ with gr.Column():
140
+ output_image = gr.Image(type="pil", label="Generated Image")
141
+
142
+
143
+
144
+ gr.Button("Generate").click(
145
+ fn=process_image_and_text,
146
+ inputs=[global_caption, box_detail_phrases_list, boxes, seed, randomize_seed, guidance_scale, num_inference_steps],
147
+ outputs=[bbox_visualization_img, output_image]
148
+ )
149
+
150
+
151
+ gr.Examples(
152
+ examples=get_samples(),
153
+ inputs=[global_caption, box_detail_phrases_list, boxes],
154
+ outputs=[bbox_visualization_img, output_image],
155
+ fn=process_image_and_text,
156
+ cache_examples=True
157
+ )
158
+
159
+
160
+
161
+
162
+
163
+ if __name__ == "__main__":
164
+ demo.launch()
165
+
166
+
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate
2
+ diffusers
3
+ invisible_watermark
4
+ torch
5
+ transformers
6
+ xformers
7
+ huggingface_hub
8
+ sentencepiece
9
+ protobuf
10
+ opencv-python
11
+ bitsandbytes
12
+ prodigyopt
13
+ beautifulsoup4
src/models/attention_SiamLayout.py ADDED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List, Optional, Tuple
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torch import nn
5
+
6
+ from diffusers.utils import deprecate, logging
7
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
8
+ from diffusers.models.activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU
9
+ from diffusers.models.attention_processor import Attention, JointAttnProcessor2_0
10
+ from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
11
+
12
+
13
+ logger = logging.get_logger(__name__)
14
+
15
+ def zero_module(module):
16
+ """
17
+ Zero out the parameters of a module and return it.
18
+ """
19
+ for p in module.parameters():
20
+ p.detach().zero_()
21
+ return module
22
+
23
+ def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
24
+ # "feed_forward_chunk_size" can be used to save memory
25
+ if hidden_states.shape[chunk_dim] % chunk_size != 0:
26
+ raise ValueError(
27
+ f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
28
+ )
29
+
30
+ num_chunks = hidden_states.shape[chunk_dim] // chunk_size
31
+ ff_output = torch.cat(
32
+ [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
33
+ dim=chunk_dim,
34
+ )
35
+ return ff_output
36
+
37
+ @maybe_allow_in_graph
38
+ class SiamLayoutJointTransformerBlock(nn.Module):
39
+
40
+ def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_only=False,attention_type="default",bbox_pre_only=True,bbox_with_temb = False):
41
+ super().__init__()
42
+
43
+ # text
44
+ self.context_pre_only = context_pre_only
45
+ context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"
46
+
47
+ # bbox
48
+ self.bbox_pre_only = bbox_pre_only
49
+
50
+ if bbox_pre_only:
51
+ if bbox_with_temb:
52
+ bbox_norm_type = "ada_norm_continous"
53
+ else:
54
+ bbox_norm_type = "LayerNorm"
55
+ else:
56
+ bbox_norm_type = "ada_norm_zero"
57
+
58
+ self.bbox_norm_type = bbox_norm_type
59
+
60
+ # img
61
+ self.norm1 = AdaLayerNormZero(dim)
62
+
63
+ # text
64
+ if context_norm_type == "ada_norm_continous":
65
+ self.norm1_context = AdaLayerNormContinuous(
66
+ dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm"
67
+ )
68
+ elif context_norm_type == "ada_norm_zero":
69
+ self.norm1_context = AdaLayerNormZero(dim)
70
+ else:
71
+ raise ValueError(
72
+ f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
73
+ )
74
+ if hasattr(F, "scaled_dot_product_attention"):
75
+ processor = JointAttnProcessor2_0()
76
+ else:
77
+ raise ValueError(
78
+ "The current PyTorch version does not support the `scaled_dot_product_attention` function."
79
+ )
80
+ self.attn = Attention(
81
+ query_dim=dim,
82
+ cross_attention_dim=None,
83
+ added_kv_proj_dim=dim,
84
+ dim_head=attention_head_dim,
85
+ heads=num_attention_heads,
86
+ out_dim=dim,
87
+ context_pre_only=context_pre_only,
88
+ bias=True,
89
+ processor=processor,
90
+ )
91
+
92
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
93
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
94
+
95
+ if not context_pre_only:
96
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
97
+ self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
98
+ else:
99
+ self.norm2_context = None
100
+ self.ff_context = None
101
+
102
+ # let chunk size default to None
103
+ self._chunk_size = None
104
+ self._chunk_dim = 0
105
+
106
+ self.attention_type = attention_type
107
+ if self.attention_type == "layout":
108
+ self.bbox_fuser_block = Attention(
109
+ query_dim=dim,
110
+ cross_attention_dim=None,
111
+ added_kv_proj_dim=dim,
112
+ dim_head=attention_head_dim,
113
+ heads=num_attention_heads,
114
+ out_dim=dim,
115
+ context_pre_only=bbox_pre_only,
116
+ bias=True,
117
+ processor=processor,
118
+ )
119
+
120
+ self.bbox_forward = zero_module(nn.Linear(dim, dim))
121
+
122
+ self.bbox_pre_only = bbox_pre_only
123
+
124
+
125
+ if self.bbox_norm_type == "ada_norm_continous":
126
+ self.norm1_bbox = AdaLayerNormContinuous(
127
+ dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm"
128
+ )
129
+ elif self.bbox_norm_type == "LayerNorm":
130
+ self.norm1_bbox = nn.LayerNorm(dim)
131
+ elif self.bbox_norm_type == "ada_norm_zero":
132
+ self.norm1_bbox = AdaLayerNormZero(dim)
133
+
134
+
135
+ if not self.bbox_pre_only:
136
+ self.norm2_bbox = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
137
+ self.ff_bbox = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
138
+ else:
139
+ self.norm2_bbox = None
140
+ self.ff_bbox = None
141
+
142
+
143
+
144
+ # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
145
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
146
+ # Sets chunk feed-forward
147
+ self._chunk_size = chunk_size
148
+ self._chunk_dim = dim
149
+
150
+ def forward(
151
+ self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor,bbox_hidden_states=None,bbox_scale=1.0
152
+ ):
153
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
154
+
155
+ if self.context_pre_only:
156
+ norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
157
+ else:
158
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
159
+ encoder_hidden_states, emb=temb
160
+ )
161
+
162
+ # img-txt MM-Attention.
163
+ attn_output, context_attn_output = self.attn(
164
+ hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
165
+ )
166
+
167
+ attn_output = gate_msa.unsqueeze(1) * attn_output #gate_msa
168
+
169
+ # Layout
170
+ if self.attention_type == "layout" and bbox_scale!=0.0:
171
+
172
+ if self.bbox_pre_only:
173
+ norm_bbox_hidden_states = self.norm1_bbox(bbox_hidden_states, temb)
174
+ else:
175
+ norm_bbox_hidden_states, bbox_gate_msa, bbox_shift_mlp, bbox_scale_mlp, bbox_gate_mlp = self.norm1_bbox(
176
+ bbox_hidden_states, emb=temb
177
+ )
178
+ # img-bbox MM-Attention.
179
+ img_attn_output, bbox_attn_output = self.bbox_fuser_block(
180
+ hidden_states=norm_hidden_states, encoder_hidden_states=norm_bbox_hidden_states
181
+ )
182
+
183
+ attn_output = attn_output + bbox_scale*self.bbox_forward(img_attn_output)
184
+
185
+ if self.bbox_pre_only:
186
+ bbox_hidden_states = None
187
+ else:
188
+ bbox_attn_output = bbox_gate_msa.unsqueeze(1) * bbox_attn_output
189
+ bbox_hidden_states = bbox_hidden_states + bbox_attn_output
190
+
191
+ norm_bbox_hidden_states = self.norm2_bbox(bbox_hidden_states)
192
+ norm_bbox_hidden_states = norm_bbox_hidden_states * (1 + bbox_scale_mlp[:, None]) + bbox_shift_mlp[:, None]
193
+ if self._chunk_size is not None:
194
+ # "feed_forward_chunk_size" can be used to save memory
195
+ bbox_ff_output = _chunked_feed_forward(
196
+ self.ff_bbox, norm_bbox_hidden_states, self._chunk_dim, self._chunk_size
197
+ )
198
+ else:
199
+ bbox_ff_output = self.ff_bbox(norm_bbox_hidden_states)
200
+ bbox_hidden_states = bbox_hidden_states + bbox_gate_mlp.unsqueeze(1) * bbox_ff_output
201
+
202
+
203
+ # Process attention outputs for the `hidden_states`.
204
+ hidden_states = hidden_states + attn_output
205
+
206
+ norm_hidden_states = self.norm2(hidden_states)
207
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
208
+ if self._chunk_size is not None:
209
+ # "feed_forward_chunk_size" can be used to save memory
210
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
211
+ else:
212
+ ff_output = self.ff(norm_hidden_states)
213
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
214
+
215
+ hidden_states = hidden_states + ff_output
216
+
217
+ # Process attention outputs for the `encoder_hidden_states`.
218
+ if self.context_pre_only:
219
+ encoder_hidden_states = None
220
+ else:
221
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
222
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
223
+
224
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
225
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
226
+ if self._chunk_size is not None:
227
+ # "feed_forward_chunk_size" can be used to save memory
228
+ context_ff_output = _chunked_feed_forward(
229
+ self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size
230
+ )
231
+ else:
232
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
233
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
234
+
235
+ return encoder_hidden_states, hidden_states,bbox_hidden_states
236
+
237
+
238
+ class FeedForward(nn.Module):
239
+ r"""
240
+ A feed-forward layer.
241
+
242
+ Parameters:
243
+ dim (`int`): The number of channels in the input.
244
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
245
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
246
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
247
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
248
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
249
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
250
+ """
251
+
252
+ def __init__(
253
+ self,
254
+ dim: int,
255
+ dim_out: Optional[int] = None,
256
+ mult: int = 4,
257
+ dropout: float = 0.0,
258
+ activation_fn: str = "geglu",
259
+ final_dropout: bool = False,
260
+ inner_dim=None,
261
+ bias: bool = True,
262
+ ):
263
+ super().__init__()
264
+ if inner_dim is None:
265
+ inner_dim = int(dim * mult)
266
+ dim_out = dim_out if dim_out is not None else dim
267
+
268
+ if activation_fn == "gelu":
269
+ act_fn = GELU(dim, inner_dim, bias=bias)
270
+ if activation_fn == "gelu-approximate":
271
+ act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
272
+ elif activation_fn == "geglu":
273
+ act_fn = GEGLU(dim, inner_dim, bias=bias)
274
+ elif activation_fn == "geglu-approximate":
275
+ act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
276
+ elif activation_fn == "swiglu":
277
+ act_fn = SwiGLU(dim, inner_dim, bias=bias)
278
+
279
+ self.net = nn.ModuleList([])
280
+ # project in
281
+ self.net.append(act_fn)
282
+ # project dropout
283
+ self.net.append(nn.Dropout(dropout))
284
+ # project out
285
+ self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
286
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
287
+ if final_dropout:
288
+ self.net.append(nn.Dropout(dropout))
289
+
290
+ def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
291
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
292
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
293
+ deprecate("scale", "1.0.0", deprecation_message)
294
+ for module in self.net:
295
+ hidden_states = module(hidden_states)
296
+ return hidden_states
src/models/transformer_sd3_SiamLayout.py ADDED
@@ -0,0 +1,499 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import Any, Dict, List, Optional, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
23
+ from src.models.attention_SiamLayout import SiamLayoutJointTransformerBlock
24
+ from diffusers.models.attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
25
+ from diffusers.models.modeling_utils import ModelMixin
26
+ from diffusers.models.normalization import AdaLayerNormContinuous
27
+ from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
28
+ from diffusers.models.embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
29
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
30
+ from diffusers.models.activations import FP32SiLU, get_activation
31
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
32
+
33
+ def get_fourier_embeds_from_boundingbox(embed_dim, box):
34
+ """
35
+ Args:
36
+ embed_dim: int
37
+ box: a 3-D tensor [B x N x 4] representing the bounding boxes for GLIGEN pipeline
38
+ Returns:
39
+ [B x N x embed_dim] tensor of positional embeddings
40
+ """
41
+
42
+ batch_size, num_boxes = box.shape[:2]
43
+
44
+ emb = 100 ** (torch.arange(embed_dim) / embed_dim)
45
+ emb = emb[None, None, None].to(device=box.device, dtype=box.dtype)
46
+ emb = emb * box.unsqueeze(-1)
47
+
48
+ emb = torch.stack((emb.sin(), emb.cos()), dim=-1)
49
+ emb = emb.permute(0, 1, 3, 4, 2).reshape(batch_size, num_boxes, embed_dim * 2 * 4)
50
+
51
+ return emb
52
+
53
+ class PixArtAlphaTextProjection(nn.Module):
54
+ """
55
+ Projects caption embeddings. Also handles dropout for classifier-free guidance.
56
+
57
+ Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
58
+ """
59
+
60
+ def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"):
61
+ super().__init__()
62
+ if out_features is None:
63
+ out_features = hidden_size
64
+ self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
65
+ if act_fn == "gelu_tanh":
66
+ self.act_1 = nn.GELU(approximate="tanh")
67
+ elif act_fn == "silu":
68
+ self.act_1 = nn.SiLU()
69
+ elif act_fn == "silu_fp32":
70
+ self.act_1 = FP32SiLU()
71
+ else:
72
+ raise ValueError(f"Unknown activation function: {act_fn}")
73
+ self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True)
74
+
75
+ def forward(self, caption):
76
+ hidden_states = self.linear_1(caption)
77
+ hidden_states = self.act_1(hidden_states)
78
+ hidden_states = self.linear_2(hidden_states)
79
+ return hidden_states
80
+
81
+ class TextBoundingboxProjection(nn.Module):
82
+ def __init__(self, pooled_projection_dim,positive_len, out_dim, fourier_freqs=8):
83
+ super().__init__()
84
+ self.positive_len = positive_len
85
+ self.out_dim = out_dim
86
+
87
+ self.fourier_embedder_dim = fourier_freqs
88
+ self.position_dim = fourier_freqs * 2 * 4 # 2: sin/cos, 4: xyxy #64
89
+
90
+ if isinstance(out_dim, tuple):
91
+ out_dim = out_dim[0]
92
+
93
+ self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, positive_len, act_fn="silu")
94
+ self.linears = PixArtAlphaTextProjection(in_features=self.positive_len + self.position_dim,hidden_size=out_dim//2,out_features=out_dim, act_fn="silu")
95
+ self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
96
+ self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))
97
+
98
+ def forward(
99
+ self,
100
+ boxes,
101
+ masks,
102
+ positive_embeddings,
103
+ phrases_masks=None,
104
+ image_masks=None,
105
+ phrases_embeddings=None,
106
+ image_embeddings=None,
107
+ ):
108
+
109
+ masks = masks.unsqueeze(-1)
110
+
111
+ # embedding position (it may includes padding as placeholder)
112
+ xyxy_embedding = get_fourier_embeds_from_boundingbox(self.fourier_embedder_dim, boxes)
113
+
114
+ # learnable null embedding
115
+ xyxy_null = self.null_position_feature.view(1, 1, -1)
116
+
117
+ # replace padding with learnable null embedding
118
+ xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null
119
+
120
+ # learnable null embedding
121
+ positive_null = self.null_positive_feature.view(1, 1, -1)
122
+
123
+ positive_embeddings = self.text_embedder(positive_embeddings)
124
+
125
+ # replace padding with learnable null embedding
126
+ positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null
127
+
128
+ objs = self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1))
129
+
130
+
131
+ return objs
132
+
133
+
134
+
135
+ class SiamLayoutSD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
136
+ """
137
+ The Transformer model introduced in Stable Diffusion 3.
138
+
139
+ Reference: https://arxiv.org/abs/2403.03206
140
+
141
+ Parameters:
142
+ sample_size (`int`): The width of the latent images. This is fixed during training since
143
+ it is used to learn a number of position embeddings.
144
+ patch_size (`int`): Patch size to turn the input data into small patches.
145
+ in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
146
+ num_layers (`int`, *optional*, defaults to 18): The number of layers of Transformer blocks to use.
147
+ attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
148
+ num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
149
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
150
+ caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`.
151
+ pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
152
+ out_channels (`int`, defaults to 16): Number of output channels.
153
+
154
+ """
155
+
156
+ _supports_gradient_checkpointing = True
157
+
158
+ @register_to_config
159
+ def __init__(
160
+ self,
161
+ sample_size: int = 128,
162
+ patch_size: int = 2,
163
+ in_channels: int = 16,
164
+ num_layers: int = 18,
165
+ attention_head_dim: int = 64,
166
+ num_attention_heads: int = 18,
167
+ joint_attention_dim: int = 4096,
168
+ caption_projection_dim: int = 1152,
169
+ pooled_projection_dim: int = 2048,
170
+ out_channels: int = 16,
171
+ pos_embed_max_size: int = 96,
172
+ attention_type = "layout",
173
+ max_boxes_per_image =10
174
+ ):
175
+ super().__init__()
176
+ default_out_channels = in_channels
177
+ self.out_channels = out_channels if out_channels is not None else default_out_channels
178
+ self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
179
+
180
+ self.pos_embed = PatchEmbed(
181
+ height=self.config.sample_size,
182
+ width=self.config.sample_size,
183
+ patch_size=self.config.patch_size,
184
+ in_channels=self.config.in_channels,
185
+ embed_dim=self.inner_dim,
186
+ pos_embed_max_size=pos_embed_max_size, # hard-code for now.
187
+ )
188
+ self.time_text_embed = CombinedTimestepTextProjEmbeddings(
189
+ embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
190
+ )
191
+ self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.config.caption_projection_dim)
192
+
193
+ # `attention_head_dim` is doubled to account for the mixing.
194
+ # It needs to crafted when we get the actual checkpoints.
195
+ self.transformer_blocks = nn.ModuleList(
196
+ [
197
+ SiamLayoutJointTransformerBlock(
198
+ dim=self.inner_dim,
199
+ num_attention_heads=self.config.num_attention_heads,
200
+ attention_head_dim=self.config.attention_head_dim,
201
+ context_pre_only=i == num_layers - 1,
202
+ attention_type=attention_type,
203
+ bbox_pre_only= i == num_layers - 1,
204
+ bbox_with_temb= True,
205
+ )
206
+ for i in range(self.config.num_layers)
207
+ ]
208
+ )
209
+
210
+
211
+ self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
212
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
213
+
214
+ self.gradient_checkpointing = False
215
+
216
+ self.attention_type = attention_type
217
+ self.max_boxes_per_image = max_boxes_per_image
218
+ if self.attention_type == "layout":
219
+ self.position_net = TextBoundingboxProjection(
220
+ pooled_projection_dim=self.config.pooled_projection_dim,positive_len=self.inner_dim, out_dim=self.inner_dim
221
+ )
222
+
223
+
224
+
225
+ # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
226
+ def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
227
+ """
228
+ Sets the attention processor to use [feed forward
229
+ chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
230
+
231
+ Parameters:
232
+ chunk_size (`int`, *optional*):
233
+ The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
234
+ over each tensor of dim=`dim`.
235
+ dim (`int`, *optional*, defaults to `0`):
236
+ The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
237
+ or dim=1 (sequence length).
238
+ """
239
+ if dim not in [0, 1]:
240
+ raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
241
+
242
+ # By default chunk size is 1
243
+ chunk_size = chunk_size or 1
244
+
245
+ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
246
+ if hasattr(module, "set_chunk_feed_forward"):
247
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
248
+
249
+ for child in module.children():
250
+ fn_recursive_feed_forward(child, chunk_size, dim)
251
+
252
+ for module in self.children():
253
+ fn_recursive_feed_forward(module, chunk_size, dim)
254
+
255
+ # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
256
+ def disable_forward_chunking(self):
257
+ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
258
+ if hasattr(module, "set_chunk_feed_forward"):
259
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
260
+
261
+ for child in module.children():
262
+ fn_recursive_feed_forward(child, chunk_size, dim)
263
+
264
+ for module in self.children():
265
+ fn_recursive_feed_forward(module, None, 0)
266
+
267
+ @property
268
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
269
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
270
+ r"""
271
+ Returns:
272
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
273
+ indexed by its weight name.
274
+ """
275
+ # set recursively
276
+ processors = {}
277
+
278
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
279
+ if hasattr(module, "get_processor"):
280
+ processors[f"{name}.processor"] = module.get_processor()
281
+
282
+ for sub_name, child in module.named_children():
283
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
284
+
285
+ return processors
286
+
287
+ for name, module in self.named_children():
288
+ fn_recursive_add_processors(name, module, processors)
289
+
290
+ return processors
291
+
292
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
293
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
294
+ r"""
295
+ Sets the attention processor to use to compute attention.
296
+
297
+ Parameters:
298
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
299
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
300
+ for **all** `Attention` layers.
301
+
302
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
303
+ processor. This is strongly recommended when setting trainable attention processors.
304
+
305
+ """
306
+ count = len(self.attn_processors.keys())
307
+
308
+ if isinstance(processor, dict) and len(processor) != count:
309
+ raise ValueError(
310
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
311
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
312
+ )
313
+
314
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
315
+ if hasattr(module, "set_processor"):
316
+ if not isinstance(processor, dict):
317
+ module.set_processor(processor)
318
+ else:
319
+ module.set_processor(processor.pop(f"{name}.processor"))
320
+
321
+ for sub_name, child in module.named_children():
322
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
323
+
324
+ for name, module in self.named_children():
325
+ fn_recursive_attn_processor(name, module, processor)
326
+
327
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedJointAttnProcessor2_0
328
+ def fuse_qkv_projections(self):
329
+ """
330
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
331
+ are fused. For cross-attention modules, key and value projection matrices are fused.
332
+
333
+ <Tip warning={true}>
334
+
335
+ This API is πŸ§ͺ experimental.
336
+
337
+ </Tip>
338
+ """
339
+ self.original_attn_processors = None
340
+
341
+ for _, attn_processor in self.attn_processors.items():
342
+ if "Added" in str(attn_processor.__class__.__name__):
343
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
344
+
345
+ self.original_attn_processors = self.attn_processors
346
+
347
+ for module in self.modules():
348
+ if isinstance(module, Attention):
349
+ module.fuse_projections(fuse=True)
350
+
351
+ self.set_attn_processor(FusedJointAttnProcessor2_0())
352
+
353
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
354
+ def unfuse_qkv_projections(self):
355
+ """Disables the fused QKV projection if enabled.
356
+
357
+ <Tip warning={true}>
358
+
359
+ This API is πŸ§ͺ experimental.
360
+
361
+ </Tip>
362
+
363
+ """
364
+ if self.original_attn_processors is not None:
365
+ self.set_attn_processor(self.original_attn_processors)
366
+
367
+ def _set_gradient_checkpointing(self, module, value=False):
368
+ if hasattr(module, "gradient_checkpointing"):
369
+ module.gradient_checkpointing = value
370
+
371
+ def forward(
372
+ self,
373
+ hidden_states: torch.FloatTensor,
374
+ encoder_hidden_states: torch.FloatTensor = None,
375
+ pooled_projections: torch.FloatTensor = None,
376
+ timestep: torch.LongTensor = None,
377
+ block_controlnet_hidden_states: List = None,
378
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
379
+ return_dict: bool = True,
380
+ layout_kwargs = None,
381
+ bbox_scale=1.0,
382
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
383
+ """
384
+ The [`SD3Transformer2DModel`] forward method.
385
+
386
+ Args:
387
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
388
+ Input `hidden_states`.
389
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
390
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
391
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
392
+ from the embeddings of input conditions.
393
+ timestep ( `torch.LongTensor`):
394
+ Used to indicate denoising step.
395
+ block_controlnet_hidden_states: (`list` of `torch.Tensor`):
396
+ A list of tensors that if specified are added to the residuals of transformer blocks.
397
+ joint_attention_kwargs (`dict`, *optional*):
398
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
399
+ `self.processor` in
400
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
401
+ return_dict (`bool`, *optional*, defaults to `True`):
402
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
403
+ tuple.
404
+
405
+ Returns:
406
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
407
+ `tuple` where the first element is the sample tensor.
408
+ """
409
+ if joint_attention_kwargs is not None:
410
+ joint_attention_kwargs = joint_attention_kwargs.copy()
411
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
412
+ else:
413
+ lora_scale = 1.0
414
+
415
+ if USE_PEFT_BACKEND:
416
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
417
+ scale_lora_layers(self, lora_scale)
418
+ else:
419
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
420
+ logger.warning(
421
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
422
+ )
423
+
424
+ height, width = hidden_states.shape[-2:]
425
+
426
+ hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too.
427
+ temb = self.time_text_embed(timestep, pooled_projections)
428
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
429
+
430
+
431
+ if self.attention_type=="layout" and layout_kwargs is not None and layout_kwargs.get("layout", None) is not None:
432
+
433
+ layout_args = layout_kwargs["layout"]
434
+ bbox_raw = layout_args["boxes"]
435
+ bbox_text_embeddings = layout_args["positive_embeddings"].to(dtype=hidden_states.dtype,device=hidden_states.device)
436
+ bbox_masks = layout_args["masks"]
437
+ bbox_hidden_states = self.position_net(boxes=bbox_raw,masks=bbox_masks,positive_embeddings=bbox_text_embeddings)
438
+
439
+ else:
440
+ N = hidden_states.shape[0]
441
+ bbox_hidden_states = torch.zeros(N, self.max_boxes_per_image,self.inner_dim, dtype=hidden_states.dtype, device=hidden_states.device)
442
+ bbox_masks = torch.zeros(N, self.max_boxes_per_image, dtype=hidden_states.dtype, device=hidden_states.device)
443
+
444
+ for index_block, block in enumerate(self.transformer_blocks):
445
+ if self.training and self.gradient_checkpointing:
446
+
447
+ def create_custom_forward(module, return_dict=None):
448
+ def custom_forward(*inputs):
449
+ if return_dict is not None:
450
+ return module(*inputs, return_dict=return_dict)
451
+ else:
452
+ return module(*inputs)
453
+
454
+ return custom_forward
455
+
456
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
457
+ encoder_hidden_states, hidden_states,bbox_hidden_states = torch.utils.checkpoint.checkpoint(
458
+ create_custom_forward(block),
459
+ hidden_states,
460
+ encoder_hidden_states,
461
+ temb,
462
+ bbox_hidden_states,
463
+ **ckpt_kwargs,
464
+ )
465
+
466
+ else:
467
+ encoder_hidden_states, hidden_states,bbox_hidden_states = block(
468
+ hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb,bbox_hidden_states= bbox_hidden_states,bbox_scale=bbox_scale
469
+ )
470
+
471
+ # controlnet residual
472
+ if block_controlnet_hidden_states is not None and block.context_pre_only is False:
473
+ interval_control = len(self.transformer_blocks) // len(block_controlnet_hidden_states)
474
+ hidden_states = hidden_states + block_controlnet_hidden_states[index_block // interval_control]
475
+
476
+ hidden_states = self.norm_out(hidden_states, temb)
477
+ hidden_states = self.proj_out(hidden_states)
478
+
479
+ # unpatchify
480
+ patch_size = self.config.patch_size
481
+ height = height // patch_size
482
+ width = width // patch_size
483
+
484
+ hidden_states = hidden_states.reshape(
485
+ shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels)
486
+ )
487
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
488
+ output = hidden_states.reshape(
489
+ shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
490
+ )
491
+
492
+ if USE_PEFT_BACKEND:
493
+ # remove `lora_scale` from each PEFT layer
494
+ unscale_lora_layers(self, lora_scale)
495
+
496
+ if not return_dict:
497
+ return (output,)
498
+
499
+ return Transformer2DModelOutput(sample=output)
src/pipeline/pipeline_CreatiLayout.py ADDED
@@ -0,0 +1,1013 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Union
17
+
18
+ import torch
19
+ from transformers import (
20
+ CLIPTextModelWithProjection,
21
+ CLIPTokenizer,
22
+ T5EncoderModel,
23
+ T5TokenizerFast,
24
+ )
25
+
26
+ from diffusers.image_processor import VaeImageProcessor
27
+ from diffusers.loaders import FromSingleFileMixin, SD3LoraLoaderMixin
28
+ from diffusers.models.autoencoders import AutoencoderKL
29
+
30
+ from src.models.transformer_sd3_SiamLayout import SiamLayoutSD3Transformer2DModel
31
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
32
+ from diffusers.utils import (
33
+ USE_PEFT_BACKEND,
34
+ is_torch_xla_available,
35
+ logging,
36
+ replace_example_docstring,
37
+ scale_lora_layers,
38
+ unscale_lora_layers,
39
+ )
40
+ from diffusers.utils.torch_utils import randn_tensor
41
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
42
+ from diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput
43
+
44
+ if is_torch_xla_available():
45
+ import torch_xla.core.xla_model as xm
46
+
47
+ XLA_AVAILABLE = True
48
+ else:
49
+ XLA_AVAILABLE = False
50
+
51
+
52
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
53
+
54
+ EXAMPLE_DOC_STRING = """
55
+ Examples:
56
+ ```py
57
+ >>> import torch
58
+ >>> from diffusers import StableDiffusion3Pipeline
59
+
60
+ >>> pipe = StableDiffusion3Pipeline.from_pretrained(
61
+ ... "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
62
+ ... )
63
+ >>> pipe.to("cuda")
64
+ >>> prompt = "A cat holding a sign that says hello world"
65
+ >>> image = pipe(prompt).images[0]
66
+ >>> image.save("sd3.png")
67
+ ```
68
+ """
69
+
70
+
71
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
72
+ def retrieve_timesteps(
73
+ scheduler,
74
+ num_inference_steps: Optional[int] = None,
75
+ device: Optional[Union[str, torch.device]] = None,
76
+ timesteps: Optional[List[int]] = None,
77
+ sigmas: Optional[List[float]] = None,
78
+ **kwargs,
79
+ ):
80
+ """
81
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
82
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
83
+
84
+ Args:
85
+ scheduler (`SchedulerMixin`):
86
+ The scheduler to get timesteps from.
87
+ num_inference_steps (`int`):
88
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
89
+ must be `None`.
90
+ device (`str` or `torch.device`, *optional*):
91
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
92
+ timesteps (`List[int]`, *optional*):
93
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
94
+ `num_inference_steps` and `sigmas` must be `None`.
95
+ sigmas (`List[float]`, *optional*):
96
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
97
+ `num_inference_steps` and `timesteps` must be `None`.
98
+
99
+ Returns:
100
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
101
+ second element is the number of inference steps.
102
+ """
103
+ if timesteps is not None and sigmas is not None:
104
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
105
+ if timesteps is not None:
106
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
107
+ if not accepts_timesteps:
108
+ raise ValueError(
109
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
110
+ f" timestep schedules. Please check whether you are using the correct scheduler."
111
+ )
112
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
113
+ timesteps = scheduler.timesteps
114
+ num_inference_steps = len(timesteps)
115
+ elif sigmas is not None:
116
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
117
+ if not accept_sigmas:
118
+ raise ValueError(
119
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
120
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
121
+ )
122
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
123
+ timesteps = scheduler.timesteps
124
+ num_inference_steps = len(timesteps)
125
+ else:
126
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
127
+ timesteps = scheduler.timesteps
128
+ return timesteps, num_inference_steps
129
+
130
+
131
+ class CreatiLayoutSD3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
132
+ r"""
133
+ Args:
134
+ transformer ([`SD3Transformer2DModel`]):
135
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
136
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
137
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
138
+ vae ([`AutoencoderKL`]):
139
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
140
+ text_encoder ([`CLIPTextModelWithProjection`]):
141
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
142
+ specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant,
143
+ with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size`
144
+ as its dimension.
145
+ text_encoder_2 ([`CLIPTextModelWithProjection`]):
146
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
147
+ specifically the
148
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
149
+ variant.
150
+ text_encoder_3 ([`T5EncoderModel`]):
151
+ Frozen text-encoder. Stable Diffusion 3 uses
152
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
153
+ [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
154
+ tokenizer (`CLIPTokenizer`):
155
+ Tokenizer of class
156
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
157
+ tokenizer_2 (`CLIPTokenizer`):
158
+ Second Tokenizer of class
159
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
160
+ tokenizer_3 (`T5TokenizerFast`):
161
+ Tokenizer of class
162
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
163
+ """
164
+
165
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
166
+ _optional_components = []
167
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
168
+
169
+ def __init__(
170
+ self,
171
+ transformer: SiamLayoutSD3Transformer2DModel,
172
+ scheduler: FlowMatchEulerDiscreteScheduler,
173
+ vae: AutoencoderKL,
174
+ text_encoder: CLIPTextModelWithProjection,
175
+ tokenizer: CLIPTokenizer,
176
+ text_encoder_2: CLIPTextModelWithProjection,
177
+ tokenizer_2: CLIPTokenizer,
178
+ text_encoder_3: T5EncoderModel,
179
+ tokenizer_3: T5TokenizerFast,
180
+ ):
181
+ super().__init__()
182
+
183
+ self.register_modules(
184
+ vae=vae,
185
+ text_encoder=text_encoder,
186
+ text_encoder_2=text_encoder_2,
187
+ text_encoder_3=text_encoder_3,
188
+ tokenizer=tokenizer,
189
+ tokenizer_2=tokenizer_2,
190
+ tokenizer_3=tokenizer_3,
191
+ transformer=transformer,
192
+ scheduler=scheduler,
193
+ )
194
+ self.vae_scale_factor = (
195
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
196
+ )
197
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
198
+ self.tokenizer_max_length = (
199
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
200
+ )
201
+ self.default_sample_size = (
202
+ self.transformer.config.sample_size
203
+ if hasattr(self, "transformer") and self.transformer is not None
204
+ else 128
205
+ )
206
+
207
+ def _get_t5_prompt_embeds(
208
+ self,
209
+ prompt: Union[str, List[str]] = None,
210
+ num_images_per_prompt: int = 1,
211
+ max_sequence_length: int = 256, # 256
212
+ device: Optional[torch.device] = None,
213
+ dtype: Optional[torch.dtype] = None,
214
+ ):
215
+ device = device or self._execution_device
216
+ dtype = dtype or self.text_encoder.dtype
217
+
218
+ prompt = [prompt] if isinstance(prompt, str) else prompt
219
+ batch_size = len(prompt)
220
+
221
+ if self.text_encoder_3 is None:
222
+ return torch.zeros(
223
+ (
224
+ batch_size * num_images_per_prompt,
225
+ self.tokenizer_max_length,
226
+ self.transformer.config.joint_attention_dim,
227
+ ),
228
+ device=device,
229
+ dtype=dtype,
230
+ )
231
+
232
+ text_inputs = self.tokenizer_3(
233
+ prompt,
234
+ padding="max_length",
235
+ max_length=max_sequence_length,
236
+ truncation=True,
237
+ add_special_tokens=True,
238
+ return_tensors="pt",
239
+ )
240
+ text_input_ids = text_inputs.input_ids
241
+ untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids
242
+
243
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
244
+ removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
245
+ logger.warning(
246
+ "The following part of your input was truncated because `max_sequence_length` is set to "
247
+ f" {max_sequence_length} tokens: {removed_text}"
248
+ )
249
+
250
+ prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0]
251
+
252
+ dtype = self.text_encoder_3.dtype
253
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
254
+
255
+ _, seq_len, _ = prompt_embeds.shape
256
+
257
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
258
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
259
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
260
+
261
+ return prompt_embeds #[B,256,4096]
262
+
263
+ def _get_clip_prompt_embeds(
264
+ self,
265
+ prompt: Union[str, List[str]],
266
+ num_images_per_prompt: int = 1,
267
+ device: Optional[torch.device] = None,
268
+ clip_skip: Optional[int] = None,
269
+ clip_model_index: int = 0,
270
+ ):
271
+ device = device or self._execution_device
272
+
273
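+ # clip_model_index selects CLIP-L (tokenizer/text_encoder) or CLIP-G
+ # (tokenizer_2/text_encoder_2).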
+ clip_tokenizers = [self.tokenizer, self.tokenizer_2]
274
+ clip_text_encoders = [self.text_encoder, self.text_encoder_2]
275
+
276
+ tokenizer = clip_tokenizers[clip_model_index]
277
+ text_encoder = clip_text_encoders[clip_model_index]
278
+
279
+ prompt = [prompt] if isinstance(prompt, str) else prompt
280
+ batch_size = len(prompt)
281
+
282
+ text_inputs = tokenizer(
283
+ prompt,
284
+ padding="max_length",
285
+ max_length=self.tokenizer_max_length,
286
+ truncation=True,
287
+ return_tensors="pt",
288
+ )
289
+
290
+ text_input_ids = text_inputs.input_ids
291
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
292
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
293
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
294
+ logger.warning(
295
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
296
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
297
+ )
298
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
299
+ pooled_prompt_embeds = prompt_embeds[0]
300
+
301
+ if clip_skip is None:
302
+ prompt_embeds = prompt_embeds.hidden_states[-2] # second-to-last hidden layer
303
+ else:
304
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
305
+
306
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
307
+
308
+ _, seq_len, _ = prompt_embeds.shape
309
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
310
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
311
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
312
+
313
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
314
+ pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
315
+
316
+ return prompt_embeds, pooled_prompt_embeds #clip-L [B,77,768], [B,768] #clip-G [B,77,1280], [B,1280]
317
+
318
+ def encode_prompt(
319
+ self,
320
+ prompt: Union[str, List[str]],
321
+ prompt_2: Union[str, List[str]],
322
+ prompt_3: Union[str, List[str]],
323
+ device: Optional[torch.device] = None,
324
+ num_images_per_prompt: int = 1,
325
+ do_classifier_free_guidance: bool = True,
326
+ negative_prompt: Optional[Union[str, List[str]]] = None,
327
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
328
+ negative_prompt_3: Optional[Union[str, List[str]]] = None,
329
+ prompt_embeds: Optional[torch.FloatTensor] = None,
330
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
331
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
332
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
333
+ clip_skip: Optional[int] = None,
334
+ max_sequence_length: int = 256,
335
+ lora_scale: Optional[float] = None,
336
+ ):
337
+ r"""
338
+
339
+ Args:
340
+ prompt (`str` or `List[str]`, *optional*):
341
+ prompt to be encoded
342
+ prompt_2 (`str` or `List[str]`, *optional*):
343
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
344
+ used in all text-encoders
345
+ prompt_3 (`str` or `List[str]`, *optional*):
346
+ The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
347
+ used in all text-encoders
348
+ device: (`torch.device`):
349
+ torch device
350
+ num_images_per_prompt (`int`):
351
+ number of images that should be generated per prompt
352
+ do_classifier_free_guidance (`bool`):
353
+ whether to use classifier free guidance or not
354
+ negative_prompt (`str` or `List[str]`, *optional*):
355
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
356
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
357
+ less than `1`).
358
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
359
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
360
+ `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
361
+ negative_prompt_3 (`str` or `List[str]`, *optional*):
362
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
363
+ `text_encoder_3`. If not defined, `negative_prompt` is used in all text-encoders.
364
+ prompt_embeds (`torch.FloatTensor`, *optional*):
365
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
366
+ provided, text embeddings will be generated from `prompt` input argument.
367
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
368
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
369
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
370
+ argument.
371
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
372
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
373
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
374
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
375
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
376
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
377
+ input argument.
378
+ clip_skip (`int`, *optional*):
379
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
380
+ the output of the pre-final layer will be used for computing the prompt embeddings.
381
+ lora_scale (`float`, *optional*):
382
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
383
+ """
384
+ device = device or self._execution_device
385
+ # set lora scale so that monkey patched LoRA
386
+ # function of text encoder can correctly access it
387
+ if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin):
388
+ self._lora_scale = lora_scale
389
+
390
+ # dynamically adjust the LoRA scale
391
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
392
+ scale_lora_layers(self.text_encoder, lora_scale)
393
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
394
+ scale_lora_layers(self.text_encoder_2, lora_scale)
395
+
396
+ prompt = [prompt] if isinstance(prompt, str) else prompt
397
+ if prompt is not None:
398
+ batch_size = len(prompt)
399
+ else:
400
+ batch_size = prompt_embeds.shape[0]
401
+
402
+ if prompt_embeds is None:
403
+ prompt_2 = prompt_2 or prompt
404
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
405
+
406
+ prompt_3 = prompt_3 or prompt
407
+ prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3
408
+
409
+ prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds(
410
+ prompt=prompt,
411
+ device=device,
412
+ num_images_per_prompt=num_images_per_prompt,
413
+ clip_skip=clip_skip,
414
+ clip_model_index=0,
415
+ )
416
+ prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds(
417
+ prompt=prompt_2,
418
+ device=device,
419
+ num_images_per_prompt=num_images_per_prompt,
420
+ clip_skip=clip_skip,
421
+ clip_model_index=1,
422
+ )
423
+ clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1) # torch.Size([B, 77, 768])+ torch.Size([B, 77, 1280])-> torch.Size([B, 77, 2048])
424
+
425
+ t5_prompt_embed = self._get_t5_prompt_embeds(
426
+ prompt=prompt_3,
427
+ num_images_per_prompt=num_images_per_prompt,
428
+ max_sequence_length=max_sequence_length,
429
+ device=device,
430
+ ) # [B,256,4096]
431
+
432
+ clip_prompt_embeds = torch.nn.functional.pad(
433
+ clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
434
+ ) # [B,77,4096]
435
+
436
+ prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2) # torch.Size([B, 333(256+77), 4096])
437
+ pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)# [B,2048]
438
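+ # At this point prompt_embeds is the joint text sequence of shape [B, 77+256, 4096]
+ # (zero-padded CLIP-L/CLIP-G features followed by the T5 features) and
+ # pooled_prompt_embeds is the concatenated pooled CLIP vector of shape [B, 2048].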
+
439
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
440
+ negative_prompt = negative_prompt or ""
441
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
442
+ negative_prompt_3 = negative_prompt_3 or negative_prompt
443
+
444
+ # normalize str to list
445
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
446
+ negative_prompt_2 = (
447
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
448
+ )
449
+ negative_prompt_3 = (
450
+ batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3
451
+ )
452
+
453
+ if prompt is not None and type(prompt) is not type(negative_prompt):
454
+ raise TypeError(
455
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
456
+ f" {type(prompt)}."
457
+ )
458
+ elif batch_size != len(negative_prompt):
459
+ raise ValueError(
460
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
461
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
462
+ " the batch size of `prompt`."
463
+ )
464
+
465
+ negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds(
466
+ negative_prompt,
467
+ device=device,
468
+ num_images_per_prompt=num_images_per_prompt,
469
+ clip_skip=None,
470
+ clip_model_index=0,
471
+ )
472
+ negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds(
473
+ negative_prompt_2,
474
+ device=device,
475
+ num_images_per_prompt=num_images_per_prompt,
476
+ clip_skip=None,
477
+ clip_model_index=1,
478
+ )
479
+ negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1)
480
+
481
+ t5_negative_prompt_embed = self._get_t5_prompt_embeds(
482
+ prompt=negative_prompt_3,
483
+ num_images_per_prompt=num_images_per_prompt,
484
+ max_sequence_length=max_sequence_length,
485
+ device=device,
486
+ )
487
+
488
+ negative_clip_prompt_embeds = torch.nn.functional.pad(
489
+ negative_clip_prompt_embeds,
490
+ (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]),
491
+ )
492
+
493
+ negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2)
494
+ negative_pooled_prompt_embeds = torch.cat(
495
+ [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
496
+ )
497
+
498
+ if self.text_encoder is not None:
499
+ if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
500
+ # Retrieve the original scale by scaling back the LoRA layers
501
+ unscale_lora_layers(self.text_encoder, lora_scale)
502
+
503
+ if self.text_encoder_2 is not None:
504
+ if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
505
+ # Retrieve the original scale by scaling back the LoRA layers
506
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
507
+
508
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
509
+
510
+ def check_inputs(
511
+ self,
512
+ prompt,
513
+ prompt_2,
514
+ prompt_3,
515
+ height,
516
+ width,
517
+ negative_prompt=None,
518
+ negative_prompt_2=None,
519
+ negative_prompt_3=None,
520
+ prompt_embeds=None,
521
+ negative_prompt_embeds=None,
522
+ pooled_prompt_embeds=None,
523
+ negative_pooled_prompt_embeds=None,
524
+ callback_on_step_end_tensor_inputs=None,
525
+ max_sequence_length=None,
526
+ ):
527
+ if height % 8 != 0 or width % 8 != 0:
528
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
529
+
530
+ if callback_on_step_end_tensor_inputs is not None and not all(
531
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
532
+ ):
533
+ raise ValueError(
534
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
535
+ )
536
+
537
+ if prompt is not None and prompt_embeds is not None:
538
+ raise ValueError(
539
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
540
+ " only forward one of the two."
541
+ )
542
+ elif prompt_2 is not None and prompt_embeds is not None:
543
+ raise ValueError(
544
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
545
+ " only forward one of the two."
546
+ )
547
+ elif prompt_3 is not None and prompt_embeds is not None:
548
+ raise ValueError(
549
+ f"Cannot forward both `prompt_3`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
550
+ " only forward one of the two."
551
+ )
552
+ elif prompt is None and prompt_embeds is None:
553
+ raise ValueError(
554
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
555
+ )
556
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
557
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
558
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
559
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
560
+ elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)):
561
+ raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}")
562
+
563
+ if negative_prompt is not None and negative_prompt_embeds is not None:
564
+ raise ValueError(
565
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
566
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
567
+ )
568
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
569
+ raise ValueError(
570
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
571
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
572
+ )
573
+ elif negative_prompt_3 is not None and negative_prompt_embeds is not None:
574
+ raise ValueError(
575
+ f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:"
576
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
577
+ )
578
+
579
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
580
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
581
+ raise ValueError(
582
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
583
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
584
+ f" {negative_prompt_embeds.shape}."
585
+ )
586
+
587
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
588
+ raise ValueError(
589
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
590
+ )
591
+
592
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
593
+ raise ValueError(
594
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
595
+ )
596
+
597
+ if max_sequence_length is not None and max_sequence_length > 512:
598
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
599
+
600
+ def prepare_latents(
601
+ self,
602
+ batch_size,
603
+ num_channels_latents,
604
+ height,
605
+ width,
606
+ dtype,
607
+ device,
608
+ generator,
609
+ latents=None,
610
+ ):
611
+ if latents is not None:
612
+ return latents.to(device=device, dtype=dtype)
613
+
614
+ shape = (
615
+ batch_size,
616
+ num_channels_latents,
617
+ int(height) // self.vae_scale_factor,
618
+ int(width) // self.vae_scale_factor,
619
+ )
620
+
621
+ if isinstance(generator, list) and len(generator) != batch_size:
622
+ raise ValueError(
623
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
624
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
625
+ )
626
+
627
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
628
+
629
+ return latents
630
+
631
+ @property
632
+ def guidance_scale(self):
633
+ return self._guidance_scale
634
+
635
+ @property
636
+ def clip_skip(self):
637
+ return self._clip_skip
638
+
639
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
640
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
641
+ # corresponds to doing no classifier free guidance.
642
+ @property
643
+ def do_classifier_free_guidance(self):
644
+ return self._guidance_scale > 1
645
+
646
+ @property
647
+ def joint_attention_kwargs(self):
648
+ return self._joint_attention_kwargs
649
+
650
+ @property
651
+ def num_timesteps(self):
652
+ return self._num_timesteps
653
+
654
+ @property
655
+ def interrupt(self):
656
+ return self._interrupt
657
+
658
+ @torch.no_grad()
659
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
660
+ def __call__(
661
+ self,
662
+ prompt: Union[str, List[str]] = None,
663
+ prompt_2: Optional[Union[str, List[str]]] = None,
664
+ prompt_3: Optional[Union[str, List[str]]] = None,
665
+ height: Optional[int] = None,
666
+ width: Optional[int] = None,
667
+ num_inference_steps: int = 28,
668
+ timesteps: List[int] = None,
669
+ guidance_scale: float = 7.0,
670
+ negative_prompt: Optional[Union[str, List[str]]] = None,
671
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
672
+ negative_prompt_3: Optional[Union[str, List[str]]] = None,
673
+ num_images_per_prompt: Optional[int] = 1,
674
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
675
+ latents: Optional[torch.FloatTensor] = None,
676
+ prompt_embeds: Optional[torch.FloatTensor] = None,
677
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
678
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
679
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
680
+ output_type: Optional[str] = "pil",
681
+ return_dict: bool = True,
682
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
683
+ clip_skip: Optional[int] = None,
684
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
685
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
686
+ max_sequence_length: int = 256,
687
+ bbox_phrases=None,
688
+ bbox_raw=None
689
+ ):
690
+ r"""
691
+ Function invoked when calling the pipeline for generation.
692
+
693
+ Args:
694
+ prompt (`str` or `List[str]`, *optional*):
695
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
696
+ instead.
697
+ prompt_2 (`str` or `List[str]`, *optional*):
698
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
699
+ will be used instead
700
+ prompt_3 (`str` or `List[str]`, *optional*):
701
+ The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt`
702
+ will be used instead
703
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
704
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
705
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
706
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
707
+ num_inference_steps (`int`, *optional*, defaults to 50):
708
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
709
+ expense of slower inference.
710
+ timesteps (`List[int]`, *optional*):
711
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
712
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
713
+ passed will be used. Must be in descending order.
714
+ guidance_scale (`float`, *optional*, defaults to 7.0):
715
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
716
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
717
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
718
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
719
+ usually at the expense of lower image quality.
720
+ negative_prompt (`str` or `List[str]`, *optional*):
721
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
722
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
723
+ less than `1`).
724
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
725
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
726
+ `text_encoder_2`. If not defined, `negative_prompt` is used instead
727
+ negative_prompt_3 (`str` or `List[str]`, *optional*):
728
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
729
+ `text_encoder_3`. If not defined, `negative_prompt` is used instead
730
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
731
+ The number of images to generate per prompt.
732
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
733
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
734
+ to make generation deterministic.
735
+ latents (`torch.FloatTensor`, *optional*):
736
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
737
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
738
+ tensor will be generated by sampling using the supplied random `generator`.
739
+ prompt_embeds (`torch.FloatTensor`, *optional*):
740
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
741
+ provided, text embeddings will be generated from `prompt` input argument.
742
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
743
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
744
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
745
+ argument.
746
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
747
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
748
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
749
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
750
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
751
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
752
+ input argument.
753
+ output_type (`str`, *optional*, defaults to `"pil"`):
754
+ The output format of the generated image. Choose between
755
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
756
+ return_dict (`bool`, *optional*, defaults to `True`):
757
+ Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead
758
+ of a plain tuple.
759
+ joint_attention_kwargs (`dict`, *optional*):
760
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
761
+ `self.processor` in
762
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
763
+ callback_on_step_end (`Callable`, *optional*):
764
+ A function that is called at the end of each denoising step during inference. The function is called
765
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
766
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
767
+ `callback_on_step_end_tensor_inputs`.
768
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
769
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
770
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
771
+ `._callback_tensor_inputs` attribute of your pipeline class.
772
+ max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
773
+
774
+ Examples:
775
+
776
+ Returns:
777
+ [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`:
778
+ [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a
779
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
780
+ """
781
+
782
+ height = height or self.default_sample_size * self.vae_scale_factor
783
+ width = width or self.default_sample_size * self.vae_scale_factor
784
+
785
+ # 1. Check inputs. Raise error if not correct
786
+ self.check_inputs(
787
+ prompt,
788
+ prompt_2,
789
+ prompt_3,
790
+ height,
791
+ width,
792
+ negative_prompt=negative_prompt,
793
+ negative_prompt_2=negative_prompt_2,
794
+ negative_prompt_3=negative_prompt_3,
795
+ prompt_embeds=prompt_embeds,
796
+ negative_prompt_embeds=negative_prompt_embeds,
797
+ pooled_prompt_embeds=pooled_prompt_embeds,
798
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
799
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
800
+ max_sequence_length=max_sequence_length,
801
+ )
802
+
803
+ self._guidance_scale = guidance_scale
804
+ self._clip_skip = clip_skip
805
+ self._joint_attention_kwargs = joint_attention_kwargs
806
+ self._interrupt = False
807
+
808
+ # 2. Define call parameters
809
+ if prompt is not None and isinstance(prompt, str):
810
+ batch_size = 1
811
+ elif prompt is not None and isinstance(prompt, list):
812
+ batch_size = len(prompt)
813
+ else:
814
+ batch_size = prompt_embeds.shape[0]
815
+
816
+ device = self._execution_device
817
+
818
+ lora_scale = (
819
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
820
+ )
821
+ (
822
+ prompt_embeds,
823
+ negative_prompt_embeds,
824
+ pooled_prompt_embeds,
825
+ negative_pooled_prompt_embeds,
826
+ ) = self.encode_prompt(
827
+ prompt=prompt,
828
+ prompt_2=prompt_2,
829
+ prompt_3=prompt_3,
830
+ negative_prompt=negative_prompt,
831
+ negative_prompt_2=negative_prompt_2,
832
+ negative_prompt_3=negative_prompt_3,
833
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
834
+ prompt_embeds=prompt_embeds,
835
+ negative_prompt_embeds=negative_prompt_embeds,
836
+ pooled_prompt_embeds=pooled_prompt_embeds,
837
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
838
+ device=device,
839
+ clip_skip=self.clip_skip,
840
+ num_images_per_prompt=num_images_per_prompt,
841
+ max_sequence_length=max_sequence_length,
842
+ lora_scale=lora_scale,
843
+ )
844
+
845
+ if self.do_classifier_free_guidance:
846
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
847
+ pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
848
+
849
+ # 4. Prepare timesteps
850
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
851
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
852
+ self._num_timesteps = len(timesteps)
853
+
854
+ # 5. Prepare latent variables
855
+ num_channels_latents = self.transformer.config.in_channels
856
+ latents = self.prepare_latents(
857
+ batch_size * num_images_per_prompt,
858
+ num_channels_latents,
859
+ height,
860
+ width,
861
+ prompt_embeds.dtype,
862
+ device,
863
+ generator,
864
+ latents,
865
+ )
866
+
867
+ # 5.5 layout
868
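+ # Layout conditioning: each bounding-box phrase is encoded with both CLIP text
+ # encoders, and boxes/phrase embeddings are padded to a fixed number of slots
+ # (max_objs) together with a mask that marks the real entries.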
+ max_objs = 10
869
+ if len(bbox_raw) > max_objs:
870
+
871
+ print(f"More that {max_objs} objects found. Only first {max_objs} objects will be processed.")
872
+
873
+ bbox_phrases = bbox_phrases[:max_objs]
874
+ bbox_raw = bbox_raw[:max_objs]
875
+ # prepare batched input to the GLIGENTextBoundingboxProjection (boxes, phrases, mask)
876
+ # Get tokens for phrases from pre-trained CLIPTokenizer
877
+ tokenizer_inputs = self.tokenizer(
880
+ bbox_phrases,
881
+ padding="max_length",
882
+ max_length=self.tokenizer_max_length,
883
+ truncation=True,
884
+ return_tensors="pt",
885
+ ).input_ids.to(device)
886
+ # For the token, we use the same pre-trained text encoder
887
+ # to obtain its text feature
888
+
889
+ text_embeddings_1 = self.text_encoder(tokenizer_inputs.to(device), output_hidden_states=True)[0]
890
+
891
+
892
+ tokenizer_inputs_2 = self.tokenizer_2(
893
+ bbox_phrases,
894
+ padding="max_length",
895
+ max_length=self.tokenizer_max_length,
896
+ truncation=True,
897
+ return_tensors="pt",
898
+ ).input_ids.to(device)
899
+ # For the token, we use the same pre-trained text encoder
900
+ # to obtain its text feature
901
+
902
+ text_embeddings_2 = self.text_encoder_2(tokenizer_inputs_2.to(device), output_hidden_states=True)[0]
903
+
904
+ clip_text_embeddings = torch.cat([text_embeddings_1, text_embeddings_2], dim=-1)
905
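+ # Each phrase is represented by the concatenation of the projected pooled
+ # outputs of CLIP-L and CLIP-G, i.e. one 2048-dimensional vector per phrase.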
+
906
+
907
+
908
+ n_objs = len(bbox_raw)
909
+ boxes = torch.zeros(max_objs, 4, device=device, dtype=latents.dtype)
910
+ boxes[:n_objs] = torch.tensor(bbox_raw, device=device, dtype=latents.dtype)
911
+ text_embeddings = torch.zeros(
912
+ max_objs, 2048, device=device, dtype=latents.dtype
913
+ )
914
+ text_embeddings[:n_objs] = clip_text_embeddings
915
+ masks = torch.zeros(max_objs, device=device, dtype=latents.dtype)
916
+ masks[:n_objs] = 1
917
+ repeat_batch = batch_size * num_images_per_prompt
918
+ boxes = boxes.unsqueeze(0).expand(repeat_batch, -1, -1).clone()
919
+ boxes = boxes.to(device=device, dtype=latents.dtype)
920
+ text_embeddings = text_embeddings.unsqueeze(0).expand(repeat_batch, -1, -1).clone()
921
+ text_embeddings = text_embeddings.to(device=device, dtype=latents.dtype)
922
+ masks = masks.unsqueeze(0).expand(repeat_batch, -1).clone()
923
+ masks = masks.to(device=device, dtype=latents.dtype)
924
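+ # For classifier-free guidance the layout tensors are duplicated and the mask is
+ # zeroed for the unconditional half, so the grounding inputs only act on the
+ # conditional branch.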
+ if self.do_classifier_free_guidance:
925
+ repeat_batch = repeat_batch * 2
926
+ boxes = torch.cat([boxes] * 2)
927
+ text_embeddings = torch.cat([text_embeddings] * 2)
928
+ masks = torch.cat([masks] * 2)
929
+ masks[: repeat_batch // 2] = 0
930
+
931
+ layout_kwargs = {
932
+ "layout": {"boxes": boxes, "positive_embeddings": text_embeddings, "masks": masks}
933
+ }
934
+
935
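+ # Layout guidance is applied at full strength only for the first ~30% of the
+ # denoising steps; once i reaches num_grounding_steps, bbox_scale is set to 0
+ # in the loop below.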
+ bbox_scale = 1.0
936
+ num_grounding_steps = int(0.3 * len(timesteps))
937
+
938
+ # 6. Denoising loop
939
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
940
+ for i, t in enumerate(timesteps):
941
+ if self.interrupt:
942
+ continue
943
+
944
+ # layout scale
945
+ if i == num_grounding_steps:
946
+ bbox_scale = 0.0
947
+
948
+ # expand the latents if we are doing classifier free guidance
949
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
950
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
951
+ timestep = t.expand(latent_model_input.shape[0])
952
+
953
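+ # In addition to the standard SD3 inputs, the SiamLayout transformer receives
+ # the layout tokens (boxes, phrase embeddings, masks) via layout_kwargs and a
+ # bbox_scale used to modulate (and, when 0, disable) the layout conditioning.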
+ noise_pred = self.transformer(
954
+ hidden_states=latent_model_input,
955
+ timestep=timestep,
956
+ encoder_hidden_states=prompt_embeds,
957
+ pooled_projections=pooled_prompt_embeds,
958
+ joint_attention_kwargs=self.joint_attention_kwargs,
959
+ return_dict=False,
960
+ layout_kwargs=layout_kwargs,
961
+ bbox_scale=bbox_scale
962
+ )[0]
963
+
964
+ # perform guidance
965
+ if self.do_classifier_free_guidance:
966
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
967
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
968
+
969
+ # compute the previous noisy sample x_t -> x_t-1
970
+ latents_dtype = latents.dtype
971
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
972
+
973
+ if latents.dtype != latents_dtype:
974
+ if torch.backends.mps.is_available():
975
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
976
+ latents = latents.to(latents_dtype)
977
+
978
+ if callback_on_step_end is not None:
979
+ callback_kwargs = {}
980
+ for k in callback_on_step_end_tensor_inputs:
981
+ callback_kwargs[k] = locals()[k]
982
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
983
+
984
+ latents = callback_outputs.pop("latents", latents)
985
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
986
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
987
+ negative_pooled_prompt_embeds = callback_outputs.pop(
988
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
989
+ )
990
+
991
+ # call the callback, if provided
992
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
993
+ progress_bar.update()
994
+
995
+ if XLA_AVAILABLE:
996
+ xm.mark_step()
997
+
998
+ if output_type == "latent":
999
+ image = latents
1000
+
1001
+ else:
1002
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
1003
+
1004
+ image = self.vae.decode(latents, return_dict=False)[0]
1005
+ image = self.image_processor.postprocess(image, output_type=output_type)
1006
+
1007
+ # Offload all models
1008
+ self.maybe_free_model_hooks()
1009
+
1010
+ if not return_dict:
1011
+ return (image,)
1012
+
1013
+ return StableDiffusion3PipelineOutput(images=image)
utils/arial.ttf ADDED
Binary file (276 kB).
 
utils/bbox_visualization.py ADDED
@@ -0,0 +1,117 @@
1
+ from typing import Dict
2
+ import numpy as np
3
+ from PIL import Image, ImageDraw, ImageFont
4
+ import random
5
+ def scale_boxes(boxes, width, height):
6
+ scaled_boxes = []
7
+ for box in boxes:
8
+ x_min, y_min, x_max, y_max = box
9
+ scaled_box = [x_min * width, y_min * height, x_max * width, y_max * height]
10
+ scaled_boxes.append(scaled_box)
11
+ return scaled_boxes
12
+
13
+ def draw_mask(mask, draw, random_color=True):
14
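+ # Paint the nonzero pixels of a binary mask onto `draw` with a semi-transparent
+ # (optionally random) color.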
+ if random_color:
15
+ color = (
16
+ random.randint(0, 255),
17
+ random.randint(0, 255),
18
+ random.randint(0, 255),
19
+ 153,
20
+ )
21
+ else:
22
+ color = (30, 144, 255, 153)
23
+
24
+ nonzero_coords = np.transpose(np.nonzero(mask))
25
+
26
+ for coord in nonzero_coords:
27
+ draw.point(coord[::-1], fill=color)
28
+
29
+ def bbox_visualization(image_pil: Image,
30
+ result: Dict,
31
+ draw_width: float = 6.0,
32
+ return_mask=True) -> Image:
33
+ """Plot bounding boxes and labels on an image.
34
+
35
+ Args:
36
+ image_pil (PIL.Image): The input image as a PIL Image object.
37
+ result (Dict[str, Union[torch.Tensor, List[torch.Tensor]]]): The target dictionary containing
38
+ the bounding boxes and labels. The keys are:
39
+ - boxes (List[int]): A list of bounding boxes in shape (N, 4), [x1, y1, x2, y2] format.
40
+ - scores (List[float]): A list of scores for each bounding box. shape (N)
41
+ - labels (List[str]): A list of labels for each object
42
+ - masks (List[PIL.Image]): A list of masks in the format of PIL.Image
43
+ draw_score (bool): Draw score on the image. Defaults to False.
44
+
45
+ Returns:
46
+ PIL.Image: The input image with plotted bounding boxes, labels, and masks.
47
+ """
48
+ # Get the bounding boxes and labels from the target dictionary
49
+ boxes = result["boxes"]
50
+ categorys = result["labels"]
51
+ masks = result.get("masks", [])
52
+
53
+
54
+ color_list= [(177, 214, 144),(255, 162, 76),
55
+ (13, 146, 244),(249, 84, 84),(54, 186, 152),
56
+ (74, 36, 157),(0, 159, 189),
57
+ (80, 118, 135),(188, 90, 148),(119, 205, 255)]
58
+
59
+
60
+ np.random.seed(42)
61
+
62
+ # Find all unique categories and build a cate2color dictionary
63
+ cate2color = {}
64
+ unique_categorys = sorted(set(categorys))
65
+ for idx,cate in enumerate(unique_categorys):
66
+ cate2color[cate] = color_list[idx%len(color_list)]
67
+
68
+ # Load a font with the specified size
69
+ font_size=30
70
+ font = ImageFont.truetype("utils/arial.ttf", font_size)
71
+
72
+ # Create a PIL ImageDraw object to draw on the input image
73
+ if isinstance(image_pil, np.ndarray):
74
+ image_pil = Image.fromarray(image_pil)
75
+ draw = ImageDraw.Draw(image_pil)
76
+
77
+ # Create a new binary mask image with the same size as the input image
78
+ mask = Image.new("L", image_pil.size, 0)
79
+ # Create a PIL ImageDraw object to draw on the mask image
80
+ mask_draw = ImageDraw.Draw(mask)
81
+
82
+ # Draw boxes, labels, and masks for each box and label in the target dictionary
83
+ for box, category in zip(boxes, categorys):
84
+ # Extract the box coordinates
85
+ x0, y0, x1, y1 = box
86
+
87
+ x0, y0, x1, y1 = int(x0), int(y0), int(x1), int(y1)
88
+ color = cate2color[category]
89
+
90
+ # Draw the box outline on the input image
91
+ draw.rectangle([x0, y0, x1, y1], outline=color, width=int(draw_width))
92
+
93
+ # Draw the label and score on the input image
94
+ text = f"{category}"
95
+
96
+ if hasattr(font, "getbbox"):
97
+ bbox = draw.textbbox((x0, y0), text, font)
98
+ else:
99
+ w, h = draw.textsize(text, font)
100
+ bbox = (x0, y0, w + x0, y0 + h)
101
+ draw.rectangle(bbox, fill=color)
102
+ draw.text((x0, y0), text, fill="white",font=font)
103
+
104
+ # Draw the mask on the input image if masks are provided
105
+ if len(masks) > 0 and return_mask:
106
+ size = image_pil.size
107
+ mask_image = Image.new("RGBA", size, color=(0, 0, 0, 0))
108
+ mask_draw = ImageDraw.Draw(mask_image)
109
+ for mask in masks:
110
+ mask = np.array(mask)[:, :, -1]
111
+ draw_mask(mask, mask_draw)
112
+
113
+ image_pil = Image.alpha_composite(image_pil.convert("RGBA"), mask_image).convert("RGB")
114
+ return image_pil
115
+
116
+
117
+