John6666 commited on
Commit
0eea822
·
verified ·
1 Parent(s): d6df077

Upload 46 files

Browse files
Files changed (47) hide show
  1. .gitattributes +1 -0
  2. LICENSE +21 -0
  3. README.md +13 -12
  4. app.py +229 -0
  5. depth_anything_v2/dinov2.py +415 -0
  6. depth_anything_v2/dinov2_layers/__init__.py +11 -0
  7. depth_anything_v2/dinov2_layers/attention.py +83 -0
  8. depth_anything_v2/dinov2_layers/block.py +252 -0
  9. depth_anything_v2/dinov2_layers/drop_path.py +35 -0
  10. depth_anything_v2/dinov2_layers/layer_scale.py +28 -0
  11. depth_anything_v2/dinov2_layers/mlp.py +41 -0
  12. depth_anything_v2/dinov2_layers/patch_embed.py +89 -0
  13. depth_anything_v2/dinov2_layers/swiglu_ffn.py +63 -0
  14. depth_anything_v2/dpt.py +221 -0
  15. depth_anything_v2/util/blocks.py +148 -0
  16. depth_anything_v2/util/transform.py +158 -0
  17. flux-architecture.svg +169 -0
  18. flux/activations.py +165 -0
  19. flux/attention.py +843 -0
  20. flux/attention_processor.py +0 -0
  21. flux/controlnet_flux.py +617 -0
  22. flux/embeddings.py +1469 -0
  23. flux/flux_network.py +183 -0
  24. flux/lora/lora_base.py +752 -0
  25. flux/lora/lora_conversion_utils.py +328 -0
  26. flux/lora/lora_pipeline.py +0 -0
  27. flux/lora/peft.py +395 -0
  28. flux/normalization.py +393 -0
  29. flux/pipeline_flux.py +749 -0
  30. flux/pipeline_flux_chameleon.py +758 -0
  31. flux/pipeline_flux_controlnet.py +945 -0
  32. flux/pipeline_flux_controlnet_img2img.py +1002 -0
  33. flux/pipeline_flux_controlnet_inpainting.py +1199 -0
  34. flux/pipeline_flux_img2img.py +856 -0
  35. flux/pipeline_flux_inpaint.py +1021 -0
  36. flux/pipeline_output.py +21 -0
  37. flux/scheduling_flow_match_euler_discrete.py +325 -0
  38. flux/transformer_flux.py +572 -0
  39. main.py +154 -0
  40. model.py +644 -0
  41. modelmod.py +650 -0
  42. qwen2_vl/configuration_qwen2_vl.py +206 -0
  43. qwen2_vl/image_processing_qwen2_vl.py +458 -0
  44. qwen2_vl/modeling_qwen2_vl.py +1952 -0
  45. qwen2_vl/processing_qwen2_vl.py +183 -0
  46. requirements.txt +22 -0
  47. technical-report.pdf +3 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ technical-report.pdf filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 erwold
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,12 +1,13 @@
1
- ---
2
- title: Qwen2vl Flux Zero
3
- emoji: 🏆
4
- colorFrom: green
5
- colorTo: gray
6
- sdk: gradio
7
- sdk_version: 5.6.0
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
1
+ ---
2
+ title: Qwen2VL-Flux Zero (failure)
3
+ emoji: 🎨
4
+ colorFrom: yellow
5
+ colorTo: pink
6
+ sdk: gradio
7
+ sdk_version: 4.44.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple
2
+
3
+ import requests
4
+ import random
5
+ import numpy as np
6
+ import gradio as gr
7
+ import spaces
8
+ import torch
9
+ from PIL import Image
10
+ from huggingface_hub import login
11
+ import os
12
+ import time
13
+ from gradio_imageslider import ImageSlider
14
+
15
+ import requests
16
+ from io import BytesIO
17
+ import PIL.Image
18
+ import requests
19
+ import shutil
20
+ import glob
21
+ from huggingface_hub import snapshot_download, hf_hub_download
22
+
23
+ MAX_SEED = np.iinfo(np.int32).max
24
+ IMAGE_SIZE = 1024
25
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
26
+ HF_TOKEN = os.environ.get("HF_TOKEN")
27
+ if HF_TOKEN: login(token=HF_TOKEN)
28
+
29
+ cp_dir = os.getenv('CHECKPOINT_DIR', 'checkpoints')
30
+ snapshot_download("Djrango/Qwen2vl-Flux", local_dir=cp_dir)
31
+ hf_hub_download(repo_id="TheMistoAI/MistoLine", filename="MTEED.pth", subfolder="Anyline", local_dir=f"{cp_dir}/anyline")
32
+ shutil.move("checkpoints/anyline/Anyline/MTEED.pth", f"{cp_dir}/anyline")
33
+ snapshot_download("depth-anything/Depth-Anything-V2-Large", local_dir=f"{cp_dir}/depth-anything-v2")
34
+ snapshot_download("facebook/sam2-hiera-large", local_dir=f"{cp_dir}/segment-anything-2")
35
+ # https://github.com/facebookresearch/sam2/issues/26
36
+ os.makedirs("sam2_configs", exist_ok=True)
37
+ for p in glob.glob(f"{cp_dir}/segment-anything-2/*.yaml"):
38
+ shutil.copy(p, "sam2_configs")
39
+
40
+ from modelmod import FluxModel
41
+ model = FluxModel(device=DEVICE, is_turbo=False, required_features=['controlnet', 'depth', 'line'], is_quantization=True) # , 'sam'
42
+
43
+ QWEN2VLFLUX_MODES = ["variation", "img2img", "inpaint", "controlnet", "controlnet-inpaint"]
44
+ QWEN2VLFLUX_ASPECT_RATIO = ["1:1", "16:9", "9:16", "2.4:1", "3:4", "4:3"]
45
+
46
+ class calculateDuration:
47
+ def __init__(self, activity_name=""):
48
+ self.activity_name = activity_name
49
+
50
+ def __enter__(self):
51
+ self.start_time = time.time()
52
+ self.start_time_formatted = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(self.start_time))
53
+ print(f"Activity: {self.activity_name}, Start time: {self.start_time_formatted}")
54
+ return self
55
+
56
+ def __exit__(self, exc_type, exc_value, traceback):
57
+ self.end_time = time.time()
58
+ self.elapsed_time = self.end_time - self.start_time
59
+ self.end_time_formatted = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(self.end_time))
60
+
61
+ if self.activity_name:
62
+ print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
63
+ else:
64
+ print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
65
+
66
+ print(f"Activity: {self.activity_name}, End time: {self.start_time_formatted}")
67
+
68
+ def resize_image_dimensions(
69
+ original_resolution_wh: Tuple[int, int],
70
+ maximum_dimension: int = IMAGE_SIZE
71
+ ) -> Tuple[int, int]:
72
+ width, height = original_resolution_wh
73
+
74
+ # if width <= maximum_dimension and height <= maximum_dimension:
75
+ # width = width - (width % 32)
76
+ # height = height - (height % 32)
77
+ # return width, height
78
+
79
+ if width > height:
80
+ scaling_factor = maximum_dimension / width
81
+ else:
82
+ scaling_factor = maximum_dimension / height
83
+
84
+ new_width = int(width * scaling_factor)
85
+ new_height = int(height * scaling_factor)
86
+
87
+ new_width = new_width - (new_width % 32)
88
+ new_height = new_height - (new_height % 32)
89
+
90
+ return new_width, new_height
91
+
92
+ def fetch_from_url(url: str, name: str):
93
+ try:
94
+ print(f"start to fetch {name} from url", url)
95
+ response = requests.get(url)
96
+ response.raise_for_status()
97
+ image = PIL.Image.open(BytesIO(response.content))
98
+ print(f"fetch {name} success")
99
+ return image
100
+ except Exception as e:
101
+ print(e)
102
+ return None
103
+
104
+ @spaces.GPU(duration=100)
105
+ @torch.inference_mode()
106
+ def process(
107
+ mode: str,
108
+ input_image_editor: dict,
109
+ ref_image: Image.Image,
110
+ image_url: str,
111
+ mask_url: str,
112
+ ref_url: str,
113
+ input_text: str,
114
+ strength: float,
115
+ num_inference_steps: int,
116
+ guidance_scale: float,
117
+ aspect_ratio: str,
118
+ attn_mode: bool,
119
+ center_x: float,
120
+ center_y: float,
121
+ radius: float,
122
+ line_mode: bool,
123
+ line_strength: float,
124
+ depth_mode: bool,
125
+ depth_strength: float,
126
+ progress=gr.Progress(track_tqdm=True)
127
+ ):
128
+ #if not input_text:
129
+ # gr.Info("Please enter a text prompt.")
130
+ # return None
131
+
132
+ kwargs = {}
133
+
134
+ image = input_image_editor['background']
135
+ mask = input_image_editor['layers'][0]
136
+
137
+ if image_url: image = fetch_from_url(image_url, "image")
138
+ if mask_url: mask = fetch_from_url(mask_url, "mask")
139
+ if ref_url: ref_image = fetch_from_url(ref_url, "refernce image")
140
+
141
+ if not image:
142
+ gr.Info("Please upload an image.")
143
+ return None
144
+
145
+ if ref_image: kwargs["input_image_b"] = ref_image
146
+ if mode == "inpaint" or mode == "controlnet-inpaint":
147
+ if not mask:
148
+ gr.Info("Please draw a mask on the image.")
149
+ return None
150
+ kwargs["mask_image"] = mask
151
+
152
+ if attn_mode:
153
+ kwargs["center_x"] = center_x
154
+ kwargs["center_y"] = center_y
155
+ kwargs["radius"] = radius
156
+
157
+ with calculateDuration("run inference"):
158
+ result = model.generate(
159
+ input_image_a=image,
160
+ prompt=input_text,
161
+ guidance_scale=guidance_scale,
162
+ num_inference_steps=num_inference_steps,
163
+ aspect_ratio=aspect_ratio,
164
+ mode=mode,
165
+ denoise_strength=strength,
166
+ line_mode=line_mode,
167
+ line_strength=line_strength,
168
+ depth_mode=depth_mode,
169
+ depth_strength=depth_strength,
170
+ imageCount=1,
171
+ **kwargs
172
+ )[0]
173
+
174
+ #return result
175
+ return [image, result]
176
+
177
+ CSS = """
178
+ .title { text-align: center; }
179
+ """
180
+
181
+ with gr.Blocks(fill_width=True, css=CSS) as demo:
182
+ gr.Markdown("# Qwen2VL-Flux", elem_classes="title")
183
+ with gr.Row():
184
+ with gr.Column():
185
+ gen_mode = gr.Radio(label="Generation mode", choices=QWEN2VLFLUX_MODES, value="variation")
186
+ with gr.Row():
187
+ input_image_editor = gr.ImageEditor(label='Image', type='pil', sources=["upload", "webcam", "clipboard"], image_mode='RGB',
188
+ layers=False, brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"))
189
+ ref_image = gr.Image(label='Reference image', type='pil', sources=["upload", "webcam", "clipboard"], image_mode='RGB')
190
+ with gr.Accordion("Image from URL", open=False):
191
+ image_url = gr.Textbox(label="Image url", show_label=True, max_lines=1, placeholder="Enter your image url (Optional)")
192
+ mask_url = gr.Textbox(label="Mask image url", show_label=True, max_lines=1, placeholder="Enter your mask image url (Optional)")
193
+ ref_url = gr.Textbox(label="Reference image url", show_label=True, max_lines=1, placeholder="Enter your reference image url (Optional)")
194
+
195
+ with gr.Accordion("Prompt Settings", open=True):
196
+ input_text = gr.Textbox(label="Prompt", show_label=True, max_lines=1, placeholder="Enter your prompt")
197
+ submit_button = gr.Button(value='Submit', variant='primary')
198
+
199
+ with gr.Accordion("Advanced Settings", open=True):
200
+ with gr.Row():
201
+ denoise_strength = gr.Slider(label="Denoise strength", minimum=0, maximum=1, step=0.01, value=0.75)
202
+ aspect_ratio = gr.Radio(label="Output image ratio", choices=QWEN2VLFLUX_ASPECT_RATIO, value="1:1")
203
+ num_inference_steps = gr.Slider(label="Number of inference steps", minimum=1, maximum=50, step=1, value=28)
204
+ guidance_scale = gr.Slider(label="Guidance scale", minimum=0, maximum=20, step=0.5, value=3.5)
205
+ with gr.Accordion("Attention Control", open=True):
206
+ with gr.Row():
207
+ attn_mode = gr.Checkbox(label="Attention Control", value=False)
208
+ center_x = gr.Slider(label="X coordinate of attention center", minimum=0, maximum=1, step=0.01, value=0.5)
209
+ center_y = gr.Slider(label="Y coordinate of attention center", minimum=0, maximum=1, step=0.01, value=0.5)
210
+ radius = gr.Slider(label="Radius of attention circle", minimum=0, maximum=1, step=0.01, value=0.5)
211
+ with gr.Accordion("ControlNet Settings", open=True):
212
+ with gr.Row():
213
+ line_mode = gr.Checkbox(label="Line mode", value=True)
214
+ line_strength = gr.Slider(label="Line strength", minimum=0, maximum=1, step=0.01, value=0.4)
215
+ depth_mode = gr.Checkbox(label="Depth mode", value=True)
216
+ depth_strength = gr.Slider(label="Depth strength", minimum=0, maximum=1, step=0.01, value=0.2)
217
+
218
+ with gr.Column():
219
+ #output_image = gr.Image(label="Generated image", type="pil", format="png", show_download_button=True, show_share_button=False)
220
+ output_image = ImageSlider(label="Generated image", type="pil")
221
+
222
+ gr.on(triggers=[submit_button.click, input_text.submit], fn=process,
223
+ inputs=[gen_mode, input_image_editor, ref_image, image_url, mask_url, ref_url,
224
+ input_text, denoise_strength, num_inference_steps, guidance_scale, aspect_ratio,
225
+ attn_mode, center_x, center_y, radius, line_mode, line_strength, depth_mode, depth_strength],
226
+ outputs=[output_image], queue=True)
227
+
228
+ demo.queue().launch(debug=True, show_error=True)
229
+ #demo.queue().launch(debug=True, show_error=True, ssr_mode=False) # Gradio 5
depth_anything_v2/dinov2.py ADDED
@@ -0,0 +1,415 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+ from functools import partial
11
+ import math
12
+ import logging
13
+ from typing import Sequence, Tuple, Union, Callable
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.utils.checkpoint
18
+ from torch.nn.init import trunc_normal_
19
+
20
+ from .dinov2_layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
21
+
22
+
23
+ logger = logging.getLogger("dinov2")
24
+
25
+
26
+ def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
27
+ if not depth_first and include_root:
28
+ fn(module=module, name=name)
29
+ for child_name, child_module in module.named_children():
30
+ child_name = ".".join((name, child_name)) if name else child_name
31
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
32
+ if depth_first and include_root:
33
+ fn(module=module, name=name)
34
+ return module
35
+
36
+
37
+ class BlockChunk(nn.ModuleList):
38
+ def forward(self, x):
39
+ for b in self:
40
+ x = b(x)
41
+ return x
42
+
43
+
44
+ class DinoVisionTransformer(nn.Module):
45
+ def __init__(
46
+ self,
47
+ img_size=224,
48
+ patch_size=16,
49
+ in_chans=3,
50
+ embed_dim=768,
51
+ depth=12,
52
+ num_heads=12,
53
+ mlp_ratio=4.0,
54
+ qkv_bias=True,
55
+ ffn_bias=True,
56
+ proj_bias=True,
57
+ drop_path_rate=0.0,
58
+ drop_path_uniform=False,
59
+ init_values=None, # for layerscale: None or 0 => no layerscale
60
+ embed_layer=PatchEmbed,
61
+ act_layer=nn.GELU,
62
+ block_fn=Block,
63
+ ffn_layer="mlp",
64
+ block_chunks=1,
65
+ num_register_tokens=0,
66
+ interpolate_antialias=False,
67
+ interpolate_offset=0.1,
68
+ ):
69
+ """
70
+ Args:
71
+ img_size (int, tuple): input image size
72
+ patch_size (int, tuple): patch size
73
+ in_chans (int): number of input channels
74
+ embed_dim (int): embedding dimension
75
+ depth (int): depth of transformer
76
+ num_heads (int): number of attention heads
77
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
78
+ qkv_bias (bool): enable bias for qkv if True
79
+ proj_bias (bool): enable bias for proj in attn if True
80
+ ffn_bias (bool): enable bias for ffn if True
81
+ drop_path_rate (float): stochastic depth rate
82
+ drop_path_uniform (bool): apply uniform drop rate across blocks
83
+ weight_init (str): weight init scheme
84
+ init_values (float): layer-scale init values
85
+ embed_layer (nn.Module): patch embedding layer
86
+ act_layer (nn.Module): MLP activation layer
87
+ block_fn (nn.Module): transformer block class
88
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
89
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
90
+ num_register_tokens: (int) number of extra cls tokens (so-called "registers")
91
+ interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
92
+ interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
93
+ """
94
+ super().__init__()
95
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
96
+
97
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
98
+ self.num_tokens = 1
99
+ self.n_blocks = depth
100
+ self.num_heads = num_heads
101
+ self.patch_size = patch_size
102
+ self.num_register_tokens = num_register_tokens
103
+ self.interpolate_antialias = interpolate_antialias
104
+ self.interpolate_offset = interpolate_offset
105
+
106
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
107
+ num_patches = self.patch_embed.num_patches
108
+
109
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
110
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
111
+ assert num_register_tokens >= 0
112
+ self.register_tokens = (
113
+ nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
114
+ )
115
+
116
+ if drop_path_uniform is True:
117
+ dpr = [drop_path_rate] * depth
118
+ else:
119
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
120
+
121
+ if ffn_layer == "mlp":
122
+ logger.info("using MLP layer as FFN")
123
+ ffn_layer = Mlp
124
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
125
+ logger.info("using SwiGLU layer as FFN")
126
+ ffn_layer = SwiGLUFFNFused
127
+ elif ffn_layer == "identity":
128
+ logger.info("using Identity layer as FFN")
129
+
130
+ def f(*args, **kwargs):
131
+ return nn.Identity()
132
+
133
+ ffn_layer = f
134
+ else:
135
+ raise NotImplementedError
136
+
137
+ blocks_list = [
138
+ block_fn(
139
+ dim=embed_dim,
140
+ num_heads=num_heads,
141
+ mlp_ratio=mlp_ratio,
142
+ qkv_bias=qkv_bias,
143
+ proj_bias=proj_bias,
144
+ ffn_bias=ffn_bias,
145
+ drop_path=dpr[i],
146
+ norm_layer=norm_layer,
147
+ act_layer=act_layer,
148
+ ffn_layer=ffn_layer,
149
+ init_values=init_values,
150
+ )
151
+ for i in range(depth)
152
+ ]
153
+ if block_chunks > 0:
154
+ self.chunked_blocks = True
155
+ chunked_blocks = []
156
+ chunksize = depth // block_chunks
157
+ for i in range(0, depth, chunksize):
158
+ # this is to keep the block index consistent if we chunk the block list
159
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
160
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
161
+ else:
162
+ self.chunked_blocks = False
163
+ self.blocks = nn.ModuleList(blocks_list)
164
+
165
+ self.norm = norm_layer(embed_dim)
166
+ self.head = nn.Identity()
167
+
168
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
169
+
170
+ self.init_weights()
171
+
172
+ def init_weights(self):
173
+ trunc_normal_(self.pos_embed, std=0.02)
174
+ nn.init.normal_(self.cls_token, std=1e-6)
175
+ if self.register_tokens is not None:
176
+ nn.init.normal_(self.register_tokens, std=1e-6)
177
+ named_apply(init_weights_vit_timm, self)
178
+
179
+ def interpolate_pos_encoding(self, x, w, h):
180
+ previous_dtype = x.dtype
181
+ npatch = x.shape[1] - 1
182
+ N = self.pos_embed.shape[1] - 1
183
+ if npatch == N and w == h:
184
+ return self.pos_embed
185
+ pos_embed = self.pos_embed.float()
186
+ class_pos_embed = pos_embed[:, 0]
187
+ patch_pos_embed = pos_embed[:, 1:]
188
+ dim = x.shape[-1]
189
+ w0 = w // self.patch_size
190
+ h0 = h // self.patch_size
191
+ # we add a small number to avoid floating point error in the interpolation
192
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
193
+ # DINOv2 with register modify the interpolate_offset from 0.1 to 0.0
194
+ w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
195
+ # w0, h0 = w0 + 0.1, h0 + 0.1
196
+
197
+ sqrt_N = math.sqrt(N)
198
+ sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
199
+ patch_pos_embed = nn.functional.interpolate(
200
+ patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
201
+ scale_factor=(sx, sy),
202
+ # (int(w0), int(h0)), # to solve the upsampling shape issue
203
+ mode="bicubic",
204
+ antialias=self.interpolate_antialias
205
+ )
206
+
207
+ assert int(w0) == patch_pos_embed.shape[-2]
208
+ assert int(h0) == patch_pos_embed.shape[-1]
209
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
210
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
211
+
212
+ def prepare_tokens_with_masks(self, x, masks=None):
213
+ B, nc, w, h = x.shape
214
+ x = self.patch_embed(x)
215
+ if masks is not None:
216
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
217
+
218
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
219
+ x = x + self.interpolate_pos_encoding(x, w, h)
220
+
221
+ if self.register_tokens is not None:
222
+ x = torch.cat(
223
+ (
224
+ x[:, :1],
225
+ self.register_tokens.expand(x.shape[0], -1, -1),
226
+ x[:, 1:],
227
+ ),
228
+ dim=1,
229
+ )
230
+
231
+ return x
232
+
233
+ def forward_features_list(self, x_list, masks_list):
234
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
235
+ for blk in self.blocks:
236
+ x = blk(x)
237
+
238
+ all_x = x
239
+ output = []
240
+ for x, masks in zip(all_x, masks_list):
241
+ x_norm = self.norm(x)
242
+ output.append(
243
+ {
244
+ "x_norm_clstoken": x_norm[:, 0],
245
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
246
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
247
+ "x_prenorm": x,
248
+ "masks": masks,
249
+ }
250
+ )
251
+ return output
252
+
253
+ def forward_features(self, x, masks=None):
254
+ if isinstance(x, list):
255
+ return self.forward_features_list(x, masks)
256
+
257
+ x = self.prepare_tokens_with_masks(x, masks)
258
+
259
+ for blk in self.blocks:
260
+ x = blk(x)
261
+
262
+ x_norm = self.norm(x)
263
+ return {
264
+ "x_norm_clstoken": x_norm[:, 0],
265
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
266
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
267
+ "x_prenorm": x,
268
+ "masks": masks,
269
+ }
270
+
271
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
272
+ x = self.prepare_tokens_with_masks(x)
273
+ # If n is an int, take the n last blocks. If it's a list, take them
274
+ output, total_block_len = [], len(self.blocks)
275
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
276
+ for i, blk in enumerate(self.blocks):
277
+ x = blk(x)
278
+ if i in blocks_to_take:
279
+ output.append(x)
280
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
281
+ return output
282
+
283
+ def _get_intermediate_layers_chunked(self, x, n=1):
284
+ x = self.prepare_tokens_with_masks(x)
285
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
286
+ # If n is an int, take the n last blocks. If it's a list, take them
287
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
288
+ for block_chunk in self.blocks:
289
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
290
+ x = blk(x)
291
+ if i in blocks_to_take:
292
+ output.append(x)
293
+ i += 1
294
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
295
+ return output
296
+
297
+ def get_intermediate_layers(
298
+ self,
299
+ x: torch.Tensor,
300
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
301
+ reshape: bool = False,
302
+ return_class_token: bool = False,
303
+ norm=True
304
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
305
+ if self.chunked_blocks:
306
+ outputs = self._get_intermediate_layers_chunked(x, n)
307
+ else:
308
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
309
+ if norm:
310
+ outputs = [self.norm(out) for out in outputs]
311
+ class_tokens = [out[:, 0] for out in outputs]
312
+ outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs]
313
+ if reshape:
314
+ B, _, w, h = x.shape
315
+ outputs = [
316
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
317
+ for out in outputs
318
+ ]
319
+ if return_class_token:
320
+ return tuple(zip(outputs, class_tokens))
321
+ return tuple(outputs)
322
+
323
+ def forward(self, *args, is_training=False, **kwargs):
324
+ ret = self.forward_features(*args, **kwargs)
325
+ if is_training:
326
+ return ret
327
+ else:
328
+ return self.head(ret["x_norm_clstoken"])
329
+
330
+
331
+ def init_weights_vit_timm(module: nn.Module, name: str = ""):
332
+ """ViT weight initialization, original timm impl (for reproducibility)"""
333
+ if isinstance(module, nn.Linear):
334
+ trunc_normal_(module.weight, std=0.02)
335
+ if module.bias is not None:
336
+ nn.init.zeros_(module.bias)
337
+
338
+
339
+ def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
340
+ model = DinoVisionTransformer(
341
+ patch_size=patch_size,
342
+ embed_dim=384,
343
+ depth=12,
344
+ num_heads=6,
345
+ mlp_ratio=4,
346
+ block_fn=partial(Block, attn_class=MemEffAttention),
347
+ num_register_tokens=num_register_tokens,
348
+ **kwargs,
349
+ )
350
+ return model
351
+
352
+
353
+ def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
354
+ model = DinoVisionTransformer(
355
+ patch_size=patch_size,
356
+ embed_dim=768,
357
+ depth=12,
358
+ num_heads=12,
359
+ mlp_ratio=4,
360
+ block_fn=partial(Block, attn_class=MemEffAttention),
361
+ num_register_tokens=num_register_tokens,
362
+ **kwargs,
363
+ )
364
+ return model
365
+
366
+
367
+ def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
368
+ model = DinoVisionTransformer(
369
+ patch_size=patch_size,
370
+ embed_dim=1024,
371
+ depth=24,
372
+ num_heads=16,
373
+ mlp_ratio=4,
374
+ block_fn=partial(Block, attn_class=MemEffAttention),
375
+ num_register_tokens=num_register_tokens,
376
+ **kwargs,
377
+ )
378
+ return model
379
+
380
+
381
+ def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
382
+ """
383
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
384
+ """
385
+ model = DinoVisionTransformer(
386
+ patch_size=patch_size,
387
+ embed_dim=1536,
388
+ depth=40,
389
+ num_heads=24,
390
+ mlp_ratio=4,
391
+ block_fn=partial(Block, attn_class=MemEffAttention),
392
+ num_register_tokens=num_register_tokens,
393
+ **kwargs,
394
+ )
395
+ return model
396
+
397
+
398
+ def DINOv2(model_name):
399
+ model_zoo = {
400
+ "vits": vit_small,
401
+ "vitb": vit_base,
402
+ "vitl": vit_large,
403
+ "vitg": vit_giant2
404
+ }
405
+
406
+ return model_zoo[model_name](
407
+ img_size=518,
408
+ patch_size=14,
409
+ init_values=1.0,
410
+ ffn_layer="mlp" if model_name != "vitg" else "swiglufused",
411
+ block_chunks=0,
412
+ num_register_tokens=0,
413
+ interpolate_antialias=False,
414
+ interpolate_offset=0.1
415
+ )
depth_anything_v2/dinov2_layers/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from .mlp import Mlp
8
+ from .patch_embed import PatchEmbed
9
+ from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
10
+ from .block import NestedTensorBlock
11
+ from .attention import MemEffAttention
depth_anything_v2/dinov2_layers/attention.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
10
+
11
+ import logging
12
+
13
+ from torch import Tensor
14
+ from torch import nn
15
+
16
+
17
+ logger = logging.getLogger("dinov2")
18
+
19
+
20
+ try:
21
+ from xformers.ops import memory_efficient_attention, unbind, fmha
22
+
23
+ XFORMERS_AVAILABLE = True
24
+ except ImportError:
25
+ logger.warning("xFormers not available")
26
+ XFORMERS_AVAILABLE = False
27
+
28
+
29
+ class Attention(nn.Module):
30
+ def __init__(
31
+ self,
32
+ dim: int,
33
+ num_heads: int = 8,
34
+ qkv_bias: bool = False,
35
+ proj_bias: bool = True,
36
+ attn_drop: float = 0.0,
37
+ proj_drop: float = 0.0,
38
+ ) -> None:
39
+ super().__init__()
40
+ self.num_heads = num_heads
41
+ head_dim = dim // num_heads
42
+ self.scale = head_dim**-0.5
43
+
44
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
45
+ self.attn_drop = nn.Dropout(attn_drop)
46
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
47
+ self.proj_drop = nn.Dropout(proj_drop)
48
+
49
+ def forward(self, x: Tensor) -> Tensor:
50
+ B, N, C = x.shape
51
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
52
+
53
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
54
+ attn = q @ k.transpose(-2, -1)
55
+
56
+ attn = attn.softmax(dim=-1)
57
+ attn = self.attn_drop(attn)
58
+
59
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
60
+ x = self.proj(x)
61
+ x = self.proj_drop(x)
62
+ return x
63
+
64
+
65
+ class MemEffAttention(Attention):
66
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
67
+ if not XFORMERS_AVAILABLE:
68
+ assert attn_bias is None, "xFormers is required for nested tensors usage"
69
+ return super().forward(x)
70
+
71
+ B, N, C = x.shape
72
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
73
+
74
+ q, k, v = unbind(qkv, 2)
75
+
76
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
77
+ x = x.reshape([B, N, C])
78
+
79
+ x = self.proj(x)
80
+ x = self.proj_drop(x)
81
+ return x
82
+
83
+
depth_anything_v2/dinov2_layers/block.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
10
+
11
+ import logging
12
+ from typing import Callable, List, Any, Tuple, Dict
13
+
14
+ import torch
15
+ from torch import nn, Tensor
16
+
17
+ from .attention import Attention, MemEffAttention
18
+ from .drop_path import DropPath
19
+ from .layer_scale import LayerScale
20
+ from .mlp import Mlp
21
+
22
+
23
+ logger = logging.getLogger("dinov2")
24
+
25
+
26
+ try:
27
+ from xformers.ops import fmha
28
+ from xformers.ops import scaled_index_add, index_select_cat
29
+
30
+ XFORMERS_AVAILABLE = True
31
+ except ImportError:
32
+ logger.warning("xFormers not available")
33
+ XFORMERS_AVAILABLE = False
34
+
35
+
36
+ class Block(nn.Module):
37
+ def __init__(
38
+ self,
39
+ dim: int,
40
+ num_heads: int,
41
+ mlp_ratio: float = 4.0,
42
+ qkv_bias: bool = False,
43
+ proj_bias: bool = True,
44
+ ffn_bias: bool = True,
45
+ drop: float = 0.0,
46
+ attn_drop: float = 0.0,
47
+ init_values=None,
48
+ drop_path: float = 0.0,
49
+ act_layer: Callable[..., nn.Module] = nn.GELU,
50
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
51
+ attn_class: Callable[..., nn.Module] = Attention,
52
+ ffn_layer: Callable[..., nn.Module] = Mlp,
53
+ ) -> None:
54
+ super().__init__()
55
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
56
+ self.norm1 = norm_layer(dim)
57
+ self.attn = attn_class(
58
+ dim,
59
+ num_heads=num_heads,
60
+ qkv_bias=qkv_bias,
61
+ proj_bias=proj_bias,
62
+ attn_drop=attn_drop,
63
+ proj_drop=drop,
64
+ )
65
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
66
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
67
+
68
+ self.norm2 = norm_layer(dim)
69
+ mlp_hidden_dim = int(dim * mlp_ratio)
70
+ self.mlp = ffn_layer(
71
+ in_features=dim,
72
+ hidden_features=mlp_hidden_dim,
73
+ act_layer=act_layer,
74
+ drop=drop,
75
+ bias=ffn_bias,
76
+ )
77
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
78
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
79
+
80
+ self.sample_drop_ratio = drop_path
81
+
82
+ def forward(self, x: Tensor) -> Tensor:
83
+ def attn_residual_func(x: Tensor) -> Tensor:
84
+ return self.ls1(self.attn(self.norm1(x)))
85
+
86
+ def ffn_residual_func(x: Tensor) -> Tensor:
87
+ return self.ls2(self.mlp(self.norm2(x)))
88
+
89
+ if self.training and self.sample_drop_ratio > 0.1:
90
+ # the overhead is compensated only for a drop path rate larger than 0.1
91
+ x = drop_add_residual_stochastic_depth(
92
+ x,
93
+ residual_func=attn_residual_func,
94
+ sample_drop_ratio=self.sample_drop_ratio,
95
+ )
96
+ x = drop_add_residual_stochastic_depth(
97
+ x,
98
+ residual_func=ffn_residual_func,
99
+ sample_drop_ratio=self.sample_drop_ratio,
100
+ )
101
+ elif self.training and self.sample_drop_ratio > 0.0:
102
+ x = x + self.drop_path1(attn_residual_func(x))
103
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
104
+ else:
105
+ x = x + attn_residual_func(x)
106
+ x = x + ffn_residual_func(x)
107
+ return x
108
+
109
+
110
+ def drop_add_residual_stochastic_depth(
111
+ x: Tensor,
112
+ residual_func: Callable[[Tensor], Tensor],
113
+ sample_drop_ratio: float = 0.0,
114
+ ) -> Tensor:
115
+ # 1) extract subset using permutation
116
+ b, n, d = x.shape
117
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
118
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
119
+ x_subset = x[brange]
120
+
121
+ # 2) apply residual_func to get residual
122
+ residual = residual_func(x_subset)
123
+
124
+ x_flat = x.flatten(1)
125
+ residual = residual.flatten(1)
126
+
127
+ residual_scale_factor = b / sample_subset_size
128
+
129
+ # 3) add the residual
130
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
131
+ return x_plus_residual.view_as(x)
132
+
133
+
134
+ def get_branges_scales(x, sample_drop_ratio=0.0):
135
+ b, n, d = x.shape
136
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
137
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
138
+ residual_scale_factor = b / sample_subset_size
139
+ return brange, residual_scale_factor
140
+
141
+
142
+ def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
143
+ if scaling_vector is None:
144
+ x_flat = x.flatten(1)
145
+ residual = residual.flatten(1)
146
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
147
+ else:
148
+ x_plus_residual = scaled_index_add(
149
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
150
+ )
151
+ return x_plus_residual
152
+
153
+
154
+ attn_bias_cache: Dict[Tuple, Any] = {}
155
+
156
+
157
+ def get_attn_bias_and_cat(x_list, branges=None):
158
+ """
159
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
160
+ """
161
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
162
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
163
+ if all_shapes not in attn_bias_cache.keys():
164
+ seqlens = []
165
+ for b, x in zip(batch_sizes, x_list):
166
+ for _ in range(b):
167
+ seqlens.append(x.shape[1])
168
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
169
+ attn_bias._batch_sizes = batch_sizes
170
+ attn_bias_cache[all_shapes] = attn_bias
171
+
172
+ if branges is not None:
173
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
174
+ else:
175
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
176
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
177
+
178
+ return attn_bias_cache[all_shapes], cat_tensors
179
+
180
+
181
+ def drop_add_residual_stochastic_depth_list(
182
+ x_list: List[Tensor],
183
+ residual_func: Callable[[Tensor, Any], Tensor],
184
+ sample_drop_ratio: float = 0.0,
185
+ scaling_vector=None,
186
+ ) -> Tensor:
187
+ # 1) generate random set of indices for dropping samples in the batch
188
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
189
+ branges = [s[0] for s in branges_scales]
190
+ residual_scale_factors = [s[1] for s in branges_scales]
191
+
192
+ # 2) get attention bias and index+concat the tensors
193
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
194
+
195
+ # 3) apply residual_func to get residual, and split the result
196
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
197
+
198
+ outputs = []
199
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
200
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
201
+ return outputs
202
+
203
+
204
+ class NestedTensorBlock(Block):
205
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
206
+ """
207
+ x_list contains a list of tensors to nest together and run
208
+ """
209
+ assert isinstance(self.attn, MemEffAttention)
210
+
211
+ if self.training and self.sample_drop_ratio > 0.0:
212
+
213
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
214
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
215
+
216
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
217
+ return self.mlp(self.norm2(x))
218
+
219
+ x_list = drop_add_residual_stochastic_depth_list(
220
+ x_list,
221
+ residual_func=attn_residual_func,
222
+ sample_drop_ratio=self.sample_drop_ratio,
223
+ scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
224
+ )
225
+ x_list = drop_add_residual_stochastic_depth_list(
226
+ x_list,
227
+ residual_func=ffn_residual_func,
228
+ sample_drop_ratio=self.sample_drop_ratio,
229
+ scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
230
+ )
231
+ return x_list
232
+ else:
233
+
234
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
235
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
236
+
237
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
238
+ return self.ls2(self.mlp(self.norm2(x)))
239
+
240
+ attn_bias, x = get_attn_bias_and_cat(x_list)
241
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
242
+ x = x + ffn_residual_func(x)
243
+ return attn_bias.split(x)
244
+
245
+ def forward(self, x_or_x_list):
246
+ if isinstance(x_or_x_list, Tensor):
247
+ return super().forward(x_or_x_list)
248
+ elif isinstance(x_or_x_list, list):
249
+ assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
250
+ return self.forward_nested(x_or_x_list)
251
+ else:
252
+ raise AssertionError
depth_anything_v2/dinov2_layers/drop_path.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
10
+
11
+
12
+ from torch import nn
13
+
14
+
15
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
16
+ if drop_prob == 0.0 or not training:
17
+ return x
18
+ keep_prob = 1 - drop_prob
19
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
20
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
21
+ if keep_prob > 0.0:
22
+ random_tensor.div_(keep_prob)
23
+ output = x * random_tensor
24
+ return output
25
+
26
+
27
+ class DropPath(nn.Module):
28
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
29
+
30
+ def __init__(self, drop_prob=None):
31
+ super(DropPath, self).__init__()
32
+ self.drop_prob = drop_prob
33
+
34
+ def forward(self, x):
35
+ return drop_path(x, self.drop_prob, self.training)
depth_anything_v2/dinov2_layers/layer_scale.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
8
+
9
+ from typing import Union
10
+
11
+ import torch
12
+ from torch import Tensor
13
+ from torch import nn
14
+
15
+
16
+ class LayerScale(nn.Module):
17
+ def __init__(
18
+ self,
19
+ dim: int,
20
+ init_values: Union[float, Tensor] = 1e-5,
21
+ inplace: bool = False,
22
+ ) -> None:
23
+ super().__init__()
24
+ self.inplace = inplace
25
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
26
+
27
+ def forward(self, x: Tensor) -> Tensor:
28
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
depth_anything_v2/dinov2_layers/mlp.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
10
+
11
+
12
+ from typing import Callable, Optional
13
+
14
+ from torch import Tensor, nn
15
+
16
+
17
+ class Mlp(nn.Module):
18
+ def __init__(
19
+ self,
20
+ in_features: int,
21
+ hidden_features: Optional[int] = None,
22
+ out_features: Optional[int] = None,
23
+ act_layer: Callable[..., nn.Module] = nn.GELU,
24
+ drop: float = 0.0,
25
+ bias: bool = True,
26
+ ) -> None:
27
+ super().__init__()
28
+ out_features = out_features or in_features
29
+ hidden_features = hidden_features or in_features
30
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
31
+ self.act = act_layer()
32
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
33
+ self.drop = nn.Dropout(drop)
34
+
35
+ def forward(self, x: Tensor) -> Tensor:
36
+ x = self.fc1(x)
37
+ x = self.act(x)
38
+ x = self.drop(x)
39
+ x = self.fc2(x)
40
+ x = self.drop(x)
41
+ return x
depth_anything_v2/dinov2_layers/patch_embed.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
10
+
11
+ from typing import Callable, Optional, Tuple, Union
12
+
13
+ from torch import Tensor
14
+ import torch.nn as nn
15
+
16
+
17
+ def make_2tuple(x):
18
+ if isinstance(x, tuple):
19
+ assert len(x) == 2
20
+ return x
21
+
22
+ assert isinstance(x, int)
23
+ return (x, x)
24
+
25
+
26
+ class PatchEmbed(nn.Module):
27
+ """
28
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
29
+
30
+ Args:
31
+ img_size: Image size.
32
+ patch_size: Patch token size.
33
+ in_chans: Number of input image channels.
34
+ embed_dim: Number of linear projection output channels.
35
+ norm_layer: Normalization layer.
36
+ """
37
+
38
+ def __init__(
39
+ self,
40
+ img_size: Union[int, Tuple[int, int]] = 224,
41
+ patch_size: Union[int, Tuple[int, int]] = 16,
42
+ in_chans: int = 3,
43
+ embed_dim: int = 768,
44
+ norm_layer: Optional[Callable] = None,
45
+ flatten_embedding: bool = True,
46
+ ) -> None:
47
+ super().__init__()
48
+
49
+ image_HW = make_2tuple(img_size)
50
+ patch_HW = make_2tuple(patch_size)
51
+ patch_grid_size = (
52
+ image_HW[0] // patch_HW[0],
53
+ image_HW[1] // patch_HW[1],
54
+ )
55
+
56
+ self.img_size = image_HW
57
+ self.patch_size = patch_HW
58
+ self.patches_resolution = patch_grid_size
59
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
60
+
61
+ self.in_chans = in_chans
62
+ self.embed_dim = embed_dim
63
+
64
+ self.flatten_embedding = flatten_embedding
65
+
66
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
67
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
68
+
69
+ def forward(self, x: Tensor) -> Tensor:
70
+ _, _, H, W = x.shape
71
+ patch_H, patch_W = self.patch_size
72
+
73
+ assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
74
+ assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
75
+
76
+ x = self.proj(x) # B C H W
77
+ H, W = x.size(2), x.size(3)
78
+ x = x.flatten(2).transpose(1, 2) # B HW C
79
+ x = self.norm(x)
80
+ if not self.flatten_embedding:
81
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
82
+ return x
83
+
84
+ def flops(self) -> float:
85
+ Ho, Wo = self.patches_resolution
86
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
87
+ if self.norm is not None:
88
+ flops += Ho * Wo * self.embed_dim
89
+ return flops
depth_anything_v2/dinov2_layers/swiglu_ffn.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Callable, Optional
8
+
9
+ from torch import Tensor, nn
10
+ import torch.nn.functional as F
11
+
12
+
13
+ class SwiGLUFFN(nn.Module):
14
+ def __init__(
15
+ self,
16
+ in_features: int,
17
+ hidden_features: Optional[int] = None,
18
+ out_features: Optional[int] = None,
19
+ act_layer: Callable[..., nn.Module] = None,
20
+ drop: float = 0.0,
21
+ bias: bool = True,
22
+ ) -> None:
23
+ super().__init__()
24
+ out_features = out_features or in_features
25
+ hidden_features = hidden_features or in_features
26
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
27
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
28
+
29
+ def forward(self, x: Tensor) -> Tensor:
30
+ x12 = self.w12(x)
31
+ x1, x2 = x12.chunk(2, dim=-1)
32
+ hidden = F.silu(x1) * x2
33
+ return self.w3(hidden)
34
+
35
+
36
+ try:
37
+ from xformers.ops import SwiGLU
38
+
39
+ XFORMERS_AVAILABLE = True
40
+ except ImportError:
41
+ SwiGLU = SwiGLUFFN
42
+ XFORMERS_AVAILABLE = False
43
+
44
+
45
+ class SwiGLUFFNFused(SwiGLU):
46
+ def __init__(
47
+ self,
48
+ in_features: int,
49
+ hidden_features: Optional[int] = None,
50
+ out_features: Optional[int] = None,
51
+ act_layer: Callable[..., nn.Module] = None,
52
+ drop: float = 0.0,
53
+ bias: bool = True,
54
+ ) -> None:
55
+ out_features = out_features or in_features
56
+ hidden_features = hidden_features or in_features
57
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
58
+ super().__init__(
59
+ in_features=in_features,
60
+ hidden_features=hidden_features,
61
+ out_features=out_features,
62
+ bias=bias,
63
+ )
depth_anything_v2/dpt.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from torchvision.transforms import Compose
6
+
7
+ from .dinov2 import DINOv2
8
+ from .util.blocks import FeatureFusionBlock, _make_scratch
9
+ from .util.transform import Resize, NormalizeImage, PrepareForNet
10
+
11
+
12
+ def _make_fusion_block(features, use_bn, size=None):
13
+ return FeatureFusionBlock(
14
+ features,
15
+ nn.ReLU(False),
16
+ deconv=False,
17
+ bn=use_bn,
18
+ expand=False,
19
+ align_corners=True,
20
+ size=size,
21
+ )
22
+
23
+
24
+ class ConvBlock(nn.Module):
25
+ def __init__(self, in_feature, out_feature):
26
+ super().__init__()
27
+
28
+ self.conv_block = nn.Sequential(
29
+ nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=1, padding=1),
30
+ nn.BatchNorm2d(out_feature),
31
+ nn.ReLU(True)
32
+ )
33
+
34
+ def forward(self, x):
35
+ return self.conv_block(x)
36
+
37
+
38
+ class DPTHead(nn.Module):
39
+ def __init__(
40
+ self,
41
+ in_channels,
42
+ features=256,
43
+ use_bn=False,
44
+ out_channels=[256, 512, 1024, 1024],
45
+ use_clstoken=False
46
+ ):
47
+ super(DPTHead, self).__init__()
48
+
49
+ self.use_clstoken = use_clstoken
50
+
51
+ self.projects = nn.ModuleList([
52
+ nn.Conv2d(
53
+ in_channels=in_channels,
54
+ out_channels=out_channel,
55
+ kernel_size=1,
56
+ stride=1,
57
+ padding=0,
58
+ ) for out_channel in out_channels
59
+ ])
60
+
61
+ self.resize_layers = nn.ModuleList([
62
+ nn.ConvTranspose2d(
63
+ in_channels=out_channels[0],
64
+ out_channels=out_channels[0],
65
+ kernel_size=4,
66
+ stride=4,
67
+ padding=0),
68
+ nn.ConvTranspose2d(
69
+ in_channels=out_channels[1],
70
+ out_channels=out_channels[1],
71
+ kernel_size=2,
72
+ stride=2,
73
+ padding=0),
74
+ nn.Identity(),
75
+ nn.Conv2d(
76
+ in_channels=out_channels[3],
77
+ out_channels=out_channels[3],
78
+ kernel_size=3,
79
+ stride=2,
80
+ padding=1)
81
+ ])
82
+
83
+ if use_clstoken:
84
+ self.readout_projects = nn.ModuleList()
85
+ for _ in range(len(self.projects)):
86
+ self.readout_projects.append(
87
+ nn.Sequential(
88
+ nn.Linear(2 * in_channels, in_channels),
89
+ nn.GELU()))
90
+
91
+ self.scratch = _make_scratch(
92
+ out_channels,
93
+ features,
94
+ groups=1,
95
+ expand=False,
96
+ )
97
+
98
+ self.scratch.stem_transpose = None
99
+
100
+ self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
101
+ self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
102
+ self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
103
+ self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
104
+
105
+ head_features_1 = features
106
+ head_features_2 = 32
107
+
108
+ self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1)
109
+ self.scratch.output_conv2 = nn.Sequential(
110
+ nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
111
+ nn.ReLU(True),
112
+ nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
113
+ nn.ReLU(True),
114
+ nn.Identity(),
115
+ )
116
+
117
+ def forward(self, out_features, patch_h, patch_w):
118
+ out = []
119
+ for i, x in enumerate(out_features):
120
+ if self.use_clstoken:
121
+ x, cls_token = x[0], x[1]
122
+ readout = cls_token.unsqueeze(1).expand_as(x)
123
+ x = self.readout_projects[i](torch.cat((x, readout), -1))
124
+ else:
125
+ x = x[0]
126
+
127
+ x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
128
+
129
+ x = self.projects[i](x)
130
+ x = self.resize_layers[i](x)
131
+
132
+ out.append(x)
133
+
134
+ layer_1, layer_2, layer_3, layer_4 = out
135
+
136
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
137
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
138
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
139
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
140
+
141
+ path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
142
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
143
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
144
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
145
+
146
+ out = self.scratch.output_conv1(path_1)
147
+ out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
148
+ out = self.scratch.output_conv2(out)
149
+
150
+ return out
151
+
152
+
153
+ class DepthAnythingV2(nn.Module):
154
+ def __init__(
155
+ self,
156
+ encoder='vitl',
157
+ features=256,
158
+ out_channels=[256, 512, 1024, 1024],
159
+ use_bn=False,
160
+ use_clstoken=False
161
+ ):
162
+ super(DepthAnythingV2, self).__init__()
163
+
164
+ self.intermediate_layer_idx = {
165
+ 'vits': [2, 5, 8, 11],
166
+ 'vitb': [2, 5, 8, 11],
167
+ 'vitl': [4, 11, 17, 23],
168
+ 'vitg': [9, 19, 29, 39]
169
+ }
170
+
171
+ self.encoder = encoder
172
+ self.pretrained = DINOv2(model_name=encoder)
173
+
174
+ self.depth_head = DPTHead(self.pretrained.embed_dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken)
175
+
176
+ def forward(self, x):
177
+ patch_h, patch_w = x.shape[-2] // 14, x.shape[-1] // 14
178
+
179
+ features = self.pretrained.get_intermediate_layers(x, self.intermediate_layer_idx[self.encoder], return_class_token=True)
180
+
181
+ depth = self.depth_head(features, patch_h, patch_w)
182
+ depth = F.relu(depth)
183
+
184
+ return depth.squeeze(1)
185
+
186
+ @torch.no_grad()
187
+ def infer_image(self, raw_image, input_size=518):
188
+ image, (h, w) = self.image2tensor(raw_image, input_size)
189
+
190
+ depth = self.forward(image)
191
+
192
+ depth = F.interpolate(depth[:, None], (h, w), mode="bilinear", align_corners=True)[0, 0]
193
+
194
+ return depth.cpu().numpy()
195
+
196
+ def image2tensor(self, raw_image, input_size=518):
197
+ transform = Compose([
198
+ Resize(
199
+ width=input_size,
200
+ height=input_size,
201
+ resize_target=False,
202
+ keep_aspect_ratio=True,
203
+ ensure_multiple_of=14,
204
+ resize_method='lower_bound',
205
+ image_interpolation_method=cv2.INTER_CUBIC,
206
+ ),
207
+ NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
208
+ PrepareForNet(),
209
+ ])
210
+
211
+ h, w = raw_image.shape[:2]
212
+
213
+ image = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB) / 255.0
214
+
215
+ image = transform({'image': image})['image']
216
+ image = torch.from_numpy(image).unsqueeze(0)
217
+
218
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
219
+ image = image.to(DEVICE)
220
+
221
+ return image, (h, w)
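Illustrative usage sketch (not from the committed files): loading the DepthAnythingV2 model defined above and running single-image inference. The checkpoint path and input image are placeholders, and the device logic mirrors image2tensor().

import cv2
import torch
from depth_anything_v2.dpt import DepthAnythingV2

device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'

# ViT-L configuration matching the defaults above; the checkpoint path is hypothetical.
model = DepthAnythingV2(encoder='vitl', features=256, out_channels=[256, 512, 1024, 1024])
model.load_state_dict(torch.load('checkpoints/depth_anything_v2_vitl.pth', map_location='cpu'))
model = model.to(device).eval()

raw = cv2.imread('example.jpg')                 # BGR uint8, as image2tensor() expects
depth = model.infer_image(raw, input_size=518)  # (H, W) float32 numpy array of relative depth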
depth_anything_v2/util/blocks.py ADDED
@@ -0,0 +1,148 @@
1
+ import torch.nn as nn
2
+
3
+
4
+ def _make_scratch(in_shape, out_shape, groups=1, expand=False):
5
+ scratch = nn.Module()
6
+
7
+ out_shape1 = out_shape
8
+ out_shape2 = out_shape
9
+ out_shape3 = out_shape
10
+ if len(in_shape) >= 4:
11
+ out_shape4 = out_shape
12
+
13
+ if expand:
14
+ out_shape1 = out_shape
15
+ out_shape2 = out_shape * 2
16
+ out_shape3 = out_shape * 4
17
+ if len(in_shape) >= 4:
18
+ out_shape4 = out_shape * 8
19
+
20
+ scratch.layer1_rn = nn.Conv2d(in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
21
+ scratch.layer2_rn = nn.Conv2d(in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
22
+ scratch.layer3_rn = nn.Conv2d(in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
23
+ if len(in_shape) >= 4:
24
+ scratch.layer4_rn = nn.Conv2d(in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
25
+
26
+ return scratch
27
+
28
+
29
+ class ResidualConvUnit(nn.Module):
30
+ """Residual convolution module.
31
+ """
32
+
33
+ def __init__(self, features, activation, bn):
34
+ """Init.
35
+
36
+ Args:
37
+ features (int): number of features
38
+ """
39
+ super().__init__()
40
+
41
+ self.bn = bn
42
+
43
+ self.groups=1
44
+
45
+ self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
46
+
47
+ self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
48
+
49
+ if self.bn == True:
50
+ self.bn1 = nn.BatchNorm2d(features)
51
+ self.bn2 = nn.BatchNorm2d(features)
52
+
53
+ self.activation = activation
54
+
55
+ self.skip_add = nn.quantized.FloatFunctional()
56
+
57
+ def forward(self, x):
58
+ """Forward pass.
59
+
60
+ Args:
61
+ x (tensor): input
62
+
63
+ Returns:
64
+ tensor: output
65
+ """
66
+
67
+ out = self.activation(x)
68
+ out = self.conv1(out)
69
+ if self.bn == True:
70
+ out = self.bn1(out)
71
+
72
+ out = self.activation(out)
73
+ out = self.conv2(out)
74
+ if self.bn == True:
75
+ out = self.bn2(out)
76
+
77
+ if self.groups > 1:
78
+ out = self.conv_merge(out)
79
+
80
+ return self.skip_add.add(out, x)
81
+
82
+
83
+ class FeatureFusionBlock(nn.Module):
84
+ """Feature fusion block.
85
+ """
86
+
87
+ def __init__(
88
+ self,
89
+ features,
90
+ activation,
91
+ deconv=False,
92
+ bn=False,
93
+ expand=False,
94
+ align_corners=True,
95
+ size=None
96
+ ):
97
+ """Init.
98
+
99
+ Args:
100
+ features (int): number of features
101
+ """
102
+ super(FeatureFusionBlock, self).__init__()
103
+
104
+ self.deconv = deconv
105
+ self.align_corners = align_corners
106
+
107
+ self.groups=1
108
+
109
+ self.expand = expand
110
+ out_features = features
111
+ if self.expand == True:
112
+ out_features = features // 2
113
+
114
+ self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
115
+
116
+ self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
117
+ self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
118
+
119
+ self.skip_add = nn.quantized.FloatFunctional()
120
+
121
+ self.size=size
122
+
123
+ def forward(self, *xs, size=None):
124
+ """Forward pass.
125
+
126
+ Returns:
127
+ tensor: output
128
+ """
129
+ output = xs[0]
130
+
131
+ if len(xs) == 2:
132
+ res = self.resConfUnit1(xs[1])
133
+ output = self.skip_add.add(output, res)
134
+
135
+ output = self.resConfUnit2(output)
136
+
137
+ if (size is None) and (self.size is None):
138
+ modifier = {"scale_factor": 2}
139
+ elif size is None:
140
+ modifier = {"size": self.size}
141
+ else:
142
+ modifier = {"size": size}
143
+
144
+ output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
145
+
146
+ output = self.out_conv(output)
147
+
148
+ return output
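Illustrative sketch (not from the committed files) of how these helpers are wired together by DPTHead in dpt.py: _make_scratch builds the per-level 3x3 projection convs and FeatureFusionBlock refines and upsamples them. Shapes are arbitrary example values.

import torch
import torch.nn as nn
from depth_anything_v2.util.blocks import FeatureFusionBlock, _make_scratch

features = 64
scratch = _make_scratch([64, 128, 256, 256], features, groups=1, expand=False)
fusion = FeatureFusionBlock(features, nn.ReLU(False), deconv=False, bn=False, expand=False, align_corners=True)

x4 = torch.randn(1, 256, 8, 8)                             # coarsest feature map
x3 = torch.randn(1, 256, 16, 16)                           # next-finer feature map (only its size is used here)
path_4 = fusion(scratch.layer4_rn(x4), size=x3.shape[2:])  # project, refine, and upsample to the finer resolution
print(path_4.shape)                                        # torch.Size([1, 64, 16, 16])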
depth_anything_v2/util/transform.py ADDED
@@ -0,0 +1,158 @@
1
+ import numpy as np
2
+ import cv2
3
+
4
+
5
+ class Resize(object):
6
+ """Resize sample to given size (width, height).
7
+ """
8
+
9
+ def __init__(
10
+ self,
11
+ width,
12
+ height,
13
+ resize_target=True,
14
+ keep_aspect_ratio=False,
15
+ ensure_multiple_of=1,
16
+ resize_method="lower_bound",
17
+ image_interpolation_method=cv2.INTER_AREA,
18
+ ):
19
+ """Init.
20
+
21
+ Args:
22
+ width (int): desired output width
23
+ height (int): desired output height
24
+ resize_target (bool, optional):
25
+ True: Resize the full sample (image, mask, target).
26
+ False: Resize image only.
27
+ Defaults to True.
28
+ keep_aspect_ratio (bool, optional):
29
+ True: Keep the aspect ratio of the input sample.
30
+ Output sample might not have the given width and height, and
31
+ resize behaviour depends on the parameter 'resize_method'.
32
+ Defaults to False.
33
+ ensure_multiple_of (int, optional):
34
+ Output width and height is constrained to be multiple of this parameter.
35
+ Defaults to 1.
36
+ resize_method (str, optional):
37
+ "lower_bound": Output will be at least as large as the given size.
38
+ "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
39
+ "minimal": Scale as least as possible. (Output size might be smaller than given size.)
40
+ Defaults to "lower_bound".
41
+ """
42
+ self.__width = width
43
+ self.__height = height
44
+
45
+ self.__resize_target = resize_target
46
+ self.__keep_aspect_ratio = keep_aspect_ratio
47
+ self.__multiple_of = ensure_multiple_of
48
+ self.__resize_method = resize_method
49
+ self.__image_interpolation_method = image_interpolation_method
50
+
51
+ def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
52
+ y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
53
+
54
+ if max_val is not None and y > max_val:
55
+ y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
56
+
57
+ if y < min_val:
58
+ y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
59
+
60
+ return y
61
+
62
+ def get_size(self, width, height):
63
+ # determine new height and width
64
+ scale_height = self.__height / height
65
+ scale_width = self.__width / width
66
+
67
+ if self.__keep_aspect_ratio:
68
+ if self.__resize_method == "lower_bound":
69
+ # scale such that output size is lower bound
70
+ if scale_width > scale_height:
71
+ # fit width
72
+ scale_height = scale_width
73
+ else:
74
+ # fit height
75
+ scale_width = scale_height
76
+ elif self.__resize_method == "upper_bound":
77
+ # scale such that output size is upper bound
78
+ if scale_width < scale_height:
79
+ # fit width
80
+ scale_height = scale_width
81
+ else:
82
+ # fit height
83
+ scale_width = scale_height
84
+ elif self.__resize_method == "minimal":
85
+ # scale as little as possible
86
+ if abs(1 - scale_width) < abs(1 - scale_height):
87
+ # fit width
88
+ scale_height = scale_width
89
+ else:
90
+ # fit height
91
+ scale_width = scale_height
92
+ else:
93
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
94
+
95
+ if self.__resize_method == "lower_bound":
96
+ new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height)
97
+ new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width)
98
+ elif self.__resize_method == "upper_bound":
99
+ new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height)
100
+ new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width)
101
+ elif self.__resize_method == "minimal":
102
+ new_height = self.constrain_to_multiple_of(scale_height * height)
103
+ new_width = self.constrain_to_multiple_of(scale_width * width)
104
+ else:
105
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
106
+
107
+ return (new_width, new_height)
108
+
109
+ def __call__(self, sample):
110
+ width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0])
111
+
112
+ # resize sample
113
+ sample["image"] = cv2.resize(sample["image"], (width, height), interpolation=self.__image_interpolation_method)
114
+
115
+ if self.__resize_target:
116
+ if "depth" in sample:
117
+ sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST)
118
+
119
+ if "mask" in sample:
120
+ sample["mask"] = cv2.resize(sample["mask"].astype(np.float32), (width, height), interpolation=cv2.INTER_NEAREST)
121
+
122
+ return sample
123
+
124
+
125
+ class NormalizeImage(object):
126
+ """Normalize image by given mean and std.
127
+ """
128
+
129
+ def __init__(self, mean, std):
130
+ self.__mean = mean
131
+ self.__std = std
132
+
133
+ def __call__(self, sample):
134
+ sample["image"] = (sample["image"] - self.__mean) / self.__std
135
+
136
+ return sample
137
+
138
+
139
+ class PrepareForNet(object):
140
+ """Prepare sample for usage as network input.
141
+ """
142
+
143
+ def __init__(self):
144
+ pass
145
+
146
+ def __call__(self, sample):
147
+ image = np.transpose(sample["image"], (2, 0, 1))
148
+ sample["image"] = np.ascontiguousarray(image).astype(np.float32)
149
+
150
+ if "depth" in sample:
151
+ depth = sample["depth"].astype(np.float32)
152
+ sample["depth"] = np.ascontiguousarray(depth)
153
+
154
+ if "mask" in sample:
155
+ sample["mask"] = sample["mask"].astype(np.float32)
156
+ sample["mask"] = np.ascontiguousarray(sample["mask"])
157
+
158
+ return sample
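Sketch (not from the committed files) of composing the three transforms above the way DepthAnythingV2.image2tensor does. Compose here is assumed to be torchvision's; any callable chain works, since each transform just maps a sample dict to a sample dict.

import cv2
import numpy as np
from torchvision.transforms import Compose
from depth_anything_v2.util.transform import Resize, NormalizeImage, PrepareForNet

transform = Compose([
    Resize(width=518, height=518, resize_target=False, keep_aspect_ratio=True,
           ensure_multiple_of=14, resize_method='lower_bound',
           image_interpolation_method=cv2.INTER_CUBIC),
    NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    PrepareForNet(),
])

rgb = np.random.rand(480, 640, 3).astype(np.float32)  # H x W x 3 image scaled to [0, 1]
out = transform({'image': rgb})['image']              # C x H' x W', both spatial dims multiples of 14
print(out.shape)                                      # (3, 518, 686) for this input size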
flux-architecture.svg ADDED
flux/activations.py ADDED
@@ -0,0 +1,165 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 HuggingFace Inc.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch import nn
19
+
20
+ from diffusers.utils import deprecate
21
+ from diffusers.utils.import_utils import is_torch_npu_available
22
+
23
+
24
+ if is_torch_npu_available():
25
+ import torch_npu
26
+
27
+ ACTIVATION_FUNCTIONS = {
28
+ "swish": nn.SiLU(),
29
+ "silu": nn.SiLU(),
30
+ "mish": nn.Mish(),
31
+ "gelu": nn.GELU(),
32
+ "relu": nn.ReLU(),
33
+ }
34
+
35
+
36
+ def get_activation(act_fn: str) -> nn.Module:
37
+ """Helper function to get activation function from string.
38
+
39
+ Args:
40
+ act_fn (str): Name of activation function.
41
+
42
+ Returns:
43
+ nn.Module: Activation function.
44
+ """
45
+
46
+ act_fn = act_fn.lower()
47
+ if act_fn in ACTIVATION_FUNCTIONS:
48
+ return ACTIVATION_FUNCTIONS[act_fn]
49
+ else:
50
+ raise ValueError(f"Unsupported activation function: {act_fn}")
51
+
52
+
53
+ class FP32SiLU(nn.Module):
54
+ r"""
55
+ SiLU activation function with input upcasted to torch.float32.
56
+ """
57
+
58
+ def __init__(self):
59
+ super().__init__()
60
+
61
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
62
+ return F.silu(inputs.float(), inplace=False).to(inputs.dtype)
63
+
64
+
65
+ class GELU(nn.Module):
66
+ r"""
67
+ GELU activation function with tanh approximation support with `approximate="tanh"`.
68
+
69
+ Parameters:
70
+ dim_in (`int`): The number of channels in the input.
71
+ dim_out (`int`): The number of channels in the output.
72
+ approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
73
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
74
+ """
75
+
76
+ def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True):
77
+ super().__init__()
78
+ self.proj = nn.Linear(dim_in, dim_out, bias=bias)
79
+ self.approximate = approximate
80
+
81
+ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
82
+ if gate.device.type != "mps":
83
+ return F.gelu(gate, approximate=self.approximate)
84
+ # mps: gelu is not implemented for float16
85
+ return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
86
+
87
+ def forward(self, hidden_states):
88
+ hidden_states = self.proj(hidden_states)
89
+ hidden_states = self.gelu(hidden_states)
90
+ return hidden_states
91
+
92
+
93
+ class GEGLU(nn.Module):
94
+ r"""
95
+ A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function.
96
+
97
+ Parameters:
98
+ dim_in (`int`): The number of channels in the input.
99
+ dim_out (`int`): The number of channels in the output.
100
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
101
+ """
102
+
103
+ def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
104
+ super().__init__()
105
+ self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
106
+
107
+ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
108
+ if gate.device.type != "mps":
109
+ return F.gelu(gate)
110
+ # mps: gelu is not implemented for float16
111
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
112
+
113
+ def forward(self, hidden_states, *args, **kwargs):
114
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
115
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
116
+ deprecate("scale", "1.0.0", deprecation_message)
117
+ hidden_states = self.proj(hidden_states)
118
+ if is_torch_npu_available():
119
+ # using torch_npu.npu_geglu can run faster and save memory on NPU.
120
+ return torch_npu.npu_geglu(hidden_states, dim=-1, approximate=1)[0]
121
+ else:
122
+ hidden_states, gate = hidden_states.chunk(2, dim=-1)
123
+ return hidden_states * self.gelu(gate)
124
+
125
+
126
+ class SwiGLU(nn.Module):
127
+ r"""
128
+ A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function. It's similar to `GEGLU`
129
+ but uses SiLU / Swish instead of GeLU.
130
+
131
+ Parameters:
132
+ dim_in (`int`): The number of channels in the input.
133
+ dim_out (`int`): The number of channels in the output.
134
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
135
+ """
136
+
137
+ def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
138
+ super().__init__()
139
+ self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
140
+ self.activation = nn.SiLU()
141
+
142
+ def forward(self, hidden_states):
143
+ hidden_states = self.proj(hidden_states)
144
+ hidden_states, gate = hidden_states.chunk(2, dim=-1)
145
+ return hidden_states * self.activation(gate)
146
+
147
+
148
+ class ApproximateGELU(nn.Module):
149
+ r"""
150
+ The approximate form of the Gaussian Error Linear Unit (GELU). For more details, see section 2 of this
151
+ [paper](https://arxiv.org/abs/1606.08415).
152
+
153
+ Parameters:
154
+ dim_in (`int`): The number of channels in the input.
155
+ dim_out (`int`): The number of channels in the output.
156
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
157
+ """
158
+
159
+ def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
160
+ super().__init__()
161
+ self.proj = nn.Linear(dim_in, dim_out, bias=bias)
162
+
163
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
164
+ x = self.proj(x)
165
+ return x * torch.sigmoid(1.702 * x)
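Small sketch (not from the committed files) of how these activations are consumed: FeedForward in flux/attention.py below selects one of them by name via strings like "geglu". A standalone check of the two most common paths:

import torch
from flux.activations import get_activation, GEGLU

act = get_activation("silu")          # looks up the shared nn.SiLU() instance from ACTIVATION_FUNCTIONS
print(act(torch.tensor([1.0])))       # tensor([0.7311])

geglu = GEGLU(dim_in=8, dim_out=16)   # projects to 2 * dim_out, then gates one half with GELU of the other
x = torch.randn(2, 4, 8)
print(geglu(x).shape)                 # torch.Size([2, 4, 16])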
flux/attention.py ADDED
@@ -0,0 +1,843 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Any, Dict, Optional
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch import nn
19
+
20
+ from diffusers.utils import deprecate, logging
21
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
22
+ from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU
23
+ from .attention_processor import Attention, JointAttnProcessor2_0
24
+ from .embeddings import SinusoidalPositionalEmbedding
25
+ from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
26
+
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+
31
+ def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
32
+ # "feed_forward_chunk_size" can be used to save memory
33
+ if hidden_states.shape[chunk_dim] % chunk_size != 0:
34
+ raise ValueError(
35
+ f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
36
+ )
37
+
38
+ num_chunks = hidden_states.shape[chunk_dim] // chunk_size
39
+ ff_output = torch.cat(
40
+ [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
41
+ dim=chunk_dim,
42
+ )
43
+ return ff_output
44
+
45
+
46
+ @maybe_allow_in_graph
47
+ class GatedSelfAttentionDense(nn.Module):
48
+ r"""
49
+ A gated self-attention dense layer that combines visual features and object features.
50
+
51
+ Parameters:
52
+ query_dim (`int`): The number of channels in the query.
53
+ context_dim (`int`): The number of channels in the context.
54
+ n_heads (`int`): The number of heads to use for attention.
55
+ d_head (`int`): The number of channels in each head.
56
+ """
57
+
58
+ def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
59
+ super().__init__()
60
+
61
+ # we need a linear projection since we need cat visual feature and obj feature
62
+ self.linear = nn.Linear(context_dim, query_dim)
63
+
64
+ self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
65
+ self.ff = FeedForward(query_dim, activation_fn="geglu")
66
+
67
+ self.norm1 = nn.LayerNorm(query_dim)
68
+ self.norm2 = nn.LayerNorm(query_dim)
69
+
70
+ self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
71
+ self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
72
+
73
+ self.enabled = True
74
+
75
+ def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
76
+ if not self.enabled:
77
+ return x
78
+
79
+ n_visual = x.shape[1]
80
+ objs = self.linear(objs)
81
+
82
+ x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
83
+ x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
84
+
85
+ return x
86
+
87
+
88
+ @maybe_allow_in_graph
89
+ class JointTransformerBlock(nn.Module):
90
+ r"""
91
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
92
+
93
+ Reference: https://arxiv.org/abs/2403.03206
94
+
95
+ Parameters:
96
+ dim (`int`): The number of channels in the input and output.
97
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
98
+ attention_head_dim (`int`): The number of channels in each head.
99
+ context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
100
+ processing of `context` conditions.
101
+ """
102
+
103
+ def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_only=False):
104
+ super().__init__()
105
+
106
+ self.context_pre_only = context_pre_only
107
+ context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"
108
+
109
+ self.norm1 = AdaLayerNormZero(dim)
110
+
111
+ if context_norm_type == "ada_norm_continous":
112
+ self.norm1_context = AdaLayerNormContinuous(
113
+ dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm"
114
+ )
115
+ elif context_norm_type == "ada_norm_zero":
116
+ self.norm1_context = AdaLayerNormZero(dim)
117
+ else:
118
+ raise ValueError(
119
+ f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
120
+ )
121
+ if hasattr(F, "scaled_dot_product_attention"):
122
+ processor = JointAttnProcessor2_0()
123
+ else:
124
+ raise ValueError(
125
+ "The current PyTorch version does not support the `scaled_dot_product_attention` function."
126
+ )
127
+ self.attn = Attention(
128
+ query_dim=dim,
129
+ cross_attention_dim=None,
130
+ added_kv_proj_dim=dim,
131
+ dim_head=attention_head_dim,
132
+ heads=num_attention_heads,
133
+ out_dim=dim,
134
+ context_pre_only=context_pre_only,
135
+ bias=True,
136
+ processor=processor,
137
+ )
138
+
139
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
140
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
141
+
142
+ if not context_pre_only:
143
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
144
+ self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
145
+ else:
146
+ self.norm2_context = None
147
+ self.ff_context = None
148
+
149
+ # let chunk size default to None
150
+ self._chunk_size = None
151
+ self._chunk_dim = 0
152
+
153
+ # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
154
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
155
+ # Sets chunk feed-forward
156
+ self._chunk_size = chunk_size
157
+ self._chunk_dim = dim
158
+
159
+ def forward(
160
+ self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
161
+ ):
162
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
163
+
164
+ if self.context_pre_only:
165
+ norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
166
+ else:
167
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
168
+ encoder_hidden_states, emb=temb
169
+ )
170
+
171
+ # Attention.
172
+ attn_output, context_attn_output = self.attn(
173
+ hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
174
+ )
175
+
176
+ # Process attention outputs for the `hidden_states`.
177
+ attn_output = gate_msa.unsqueeze(1) * attn_output
178
+ hidden_states = hidden_states + attn_output
179
+
180
+ norm_hidden_states = self.norm2(hidden_states)
181
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
182
+ if self._chunk_size is not None:
183
+ # "feed_forward_chunk_size" can be used to save memory
184
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
185
+ else:
186
+ ff_output = self.ff(norm_hidden_states)
187
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
188
+
189
+ hidden_states = hidden_states + ff_output
190
+
191
+ # Process attention outputs for the `encoder_hidden_states`.
192
+ if self.context_pre_only:
193
+ encoder_hidden_states = None
194
+ else:
195
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
196
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
197
+
198
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
199
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
200
+ if self._chunk_size is not None:
201
+ # "feed_forward_chunk_size" can be used to save memory
202
+ context_ff_output = _chunked_feed_forward(
203
+ self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size
204
+ )
205
+ else:
206
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
207
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
208
+
209
+ return encoder_hidden_states, hidden_states
210
+
211
+
212
+ @maybe_allow_in_graph
213
+ class BasicTransformerBlock(nn.Module):
214
+ r"""
215
+ A basic Transformer block.
216
+
217
+ Parameters:
218
+ dim (`int`): The number of channels in the input and output.
219
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
220
+ attention_head_dim (`int`): The number of channels in each head.
221
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
222
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
223
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
224
+ num_embeds_ada_norm (:
225
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
226
+ attention_bias (:
227
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
228
+ only_cross_attention (`bool`, *optional*):
229
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
230
+ double_self_attention (`bool`, *optional*):
231
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
232
+ upcast_attention (`bool`, *optional*):
233
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
234
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
235
+ Whether to use learnable elementwise affine parameters for normalization.
236
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
237
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
238
+ final_dropout (`bool` *optional*, defaults to False):
239
+ Whether to apply a final dropout after the last feed-forward layer.
240
+ attention_type (`str`, *optional*, defaults to `"default"`):
241
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
242
+ positional_embeddings (`str`, *optional*, defaults to `None`):
243
+ The type of positional embeddings to apply to.
244
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
245
+ The maximum number of positional embeddings to apply.
246
+ """
247
+
248
+ def __init__(
249
+ self,
250
+ dim: int,
251
+ num_attention_heads: int,
252
+ attention_head_dim: int,
253
+ dropout=0.0,
254
+ cross_attention_dim: Optional[int] = None,
255
+ activation_fn: str = "geglu",
256
+ num_embeds_ada_norm: Optional[int] = None,
257
+ attention_bias: bool = False,
258
+ only_cross_attention: bool = False,
259
+ double_self_attention: bool = False,
260
+ upcast_attention: bool = False,
261
+ norm_elementwise_affine: bool = True,
262
+ norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
263
+ norm_eps: float = 1e-5,
264
+ final_dropout: bool = False,
265
+ attention_type: str = "default",
266
+ positional_embeddings: Optional[str] = None,
267
+ num_positional_embeddings: Optional[int] = None,
268
+ ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
269
+ ada_norm_bias: Optional[int] = None,
270
+ ff_inner_dim: Optional[int] = None,
271
+ ff_bias: bool = True,
272
+ attention_out_bias: bool = True,
273
+ ):
274
+ super().__init__()
275
+ self.only_cross_attention = only_cross_attention
276
+
277
+ # We keep these boolean flags for backward-compatibility.
278
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
279
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
280
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
281
+ self.use_layer_norm = norm_type == "layer_norm"
282
+ self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
283
+
284
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
285
+ raise ValueError(
286
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
287
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
288
+ )
289
+
290
+ self.norm_type = norm_type
291
+ self.num_embeds_ada_norm = num_embeds_ada_norm
292
+
293
+ if positional_embeddings and (num_positional_embeddings is None):
294
+ raise ValueError(
295
+ "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
296
+ )
297
+
298
+ if positional_embeddings == "sinusoidal":
299
+ self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
300
+ else:
301
+ self.pos_embed = None
302
+
303
+ # Define 3 blocks. Each block has its own normalization layer.
304
+ # 1. Self-Attn
305
+ if norm_type == "ada_norm":
306
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
307
+ elif norm_type == "ada_norm_zero":
308
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
309
+ elif norm_type == "ada_norm_continuous":
310
+ self.norm1 = AdaLayerNormContinuous(
311
+ dim,
312
+ ada_norm_continous_conditioning_embedding_dim,
313
+ norm_elementwise_affine,
314
+ norm_eps,
315
+ ada_norm_bias,
316
+ "rms_norm",
317
+ )
318
+ else:
319
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
320
+
321
+ self.attn1 = Attention(
322
+ query_dim=dim,
323
+ heads=num_attention_heads,
324
+ dim_head=attention_head_dim,
325
+ dropout=dropout,
326
+ bias=attention_bias,
327
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
328
+ upcast_attention=upcast_attention,
329
+ out_bias=attention_out_bias,
330
+ )
331
+
332
+ # 2. Cross-Attn
333
+ if cross_attention_dim is not None or double_self_attention:
334
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
335
+ # I.e. the number of returned modulation chunks from AdaLayerNormZero would not make sense if returned during
336
+ # the second cross attention block.
337
+ if norm_type == "ada_norm":
338
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
339
+ elif norm_type == "ada_norm_continuous":
340
+ self.norm2 = AdaLayerNormContinuous(
341
+ dim,
342
+ ada_norm_continous_conditioning_embedding_dim,
343
+ norm_elementwise_affine,
344
+ norm_eps,
345
+ ada_norm_bias,
346
+ "rms_norm",
347
+ )
348
+ else:
349
+ self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
350
+
351
+ self.attn2 = Attention(
352
+ query_dim=dim,
353
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
354
+ heads=num_attention_heads,
355
+ dim_head=attention_head_dim,
356
+ dropout=dropout,
357
+ bias=attention_bias,
358
+ upcast_attention=upcast_attention,
359
+ out_bias=attention_out_bias,
360
+ ) # is self-attn if encoder_hidden_states is none
361
+ else:
362
+ if norm_type == "ada_norm_single": # For Latte
363
+ self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
364
+ else:
365
+ self.norm2 = None
366
+ self.attn2 = None
367
+
368
+ # 3. Feed-forward
369
+ if norm_type == "ada_norm_continuous":
370
+ self.norm3 = AdaLayerNormContinuous(
371
+ dim,
372
+ ada_norm_continous_conditioning_embedding_dim,
373
+ norm_elementwise_affine,
374
+ norm_eps,
375
+ ada_norm_bias,
376
+ "layer_norm",
377
+ )
378
+
379
+ elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]:
380
+ self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
381
+ elif norm_type == "layer_norm_i2vgen":
382
+ self.norm3 = None
383
+
384
+ self.ff = FeedForward(
385
+ dim,
386
+ dropout=dropout,
387
+ activation_fn=activation_fn,
388
+ final_dropout=final_dropout,
389
+ inner_dim=ff_inner_dim,
390
+ bias=ff_bias,
391
+ )
392
+
393
+ # 4. Fuser
394
+ if attention_type == "gated" or attention_type == "gated-text-image":
395
+ self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
396
+
397
+ # 5. Scale-shift for PixArt-Alpha.
398
+ if norm_type == "ada_norm_single":
399
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
400
+
401
+ # let chunk size default to None
402
+ self._chunk_size = None
403
+ self._chunk_dim = 0
404
+
405
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
406
+ # Sets chunk feed-forward
407
+ self._chunk_size = chunk_size
408
+ self._chunk_dim = dim
409
+
410
+ def forward(
411
+ self,
412
+ hidden_states: torch.Tensor,
413
+ attention_mask: Optional[torch.Tensor] = None,
414
+ encoder_hidden_states: Optional[torch.Tensor] = None,
415
+ encoder_attention_mask: Optional[torch.Tensor] = None,
416
+ timestep: Optional[torch.LongTensor] = None,
417
+ cross_attention_kwargs: Dict[str, Any] = None,
418
+ class_labels: Optional[torch.LongTensor] = None,
419
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
420
+ ) -> torch.Tensor:
421
+ if cross_attention_kwargs is not None:
422
+ if cross_attention_kwargs.get("scale", None) is not None:
423
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
424
+
425
+ # Notice that normalization is always applied before the real computation in the following blocks.
426
+ # 0. Self-Attention
427
+ batch_size = hidden_states.shape[0]
428
+
429
+ if self.norm_type == "ada_norm":
430
+ norm_hidden_states = self.norm1(hidden_states, timestep)
431
+ elif self.norm_type == "ada_norm_zero":
432
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
433
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
434
+ )
435
+ elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]:
436
+ norm_hidden_states = self.norm1(hidden_states)
437
+ elif self.norm_type == "ada_norm_continuous":
438
+ norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
439
+ elif self.norm_type == "ada_norm_single":
440
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
441
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
442
+ ).chunk(6, dim=1)
443
+ norm_hidden_states = self.norm1(hidden_states)
444
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
445
+ else:
446
+ raise ValueError("Incorrect norm used")
447
+
448
+ if self.pos_embed is not None:
449
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
450
+
451
+ # 1. Prepare GLIGEN inputs
452
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
453
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
454
+
455
+ attn_output = self.attn1(
456
+ norm_hidden_states,
457
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
458
+ attention_mask=attention_mask,
459
+ **cross_attention_kwargs,
460
+ )
461
+
462
+ if self.norm_type == "ada_norm_zero":
463
+ attn_output = gate_msa.unsqueeze(1) * attn_output
464
+ elif self.norm_type == "ada_norm_single":
465
+ attn_output = gate_msa * attn_output
466
+
467
+ hidden_states = attn_output + hidden_states
468
+ if hidden_states.ndim == 4:
469
+ hidden_states = hidden_states.squeeze(1)
470
+
471
+ # 1.2 GLIGEN Control
472
+ if gligen_kwargs is not None:
473
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
474
+
475
+ # 3. Cross-Attention
476
+ if self.attn2 is not None:
477
+ if self.norm_type == "ada_norm":
478
+ norm_hidden_states = self.norm2(hidden_states, timestep)
479
+ elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
480
+ norm_hidden_states = self.norm2(hidden_states)
481
+ elif self.norm_type == "ada_norm_single":
482
+ # For PixArt norm2 isn't applied here:
483
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
484
+ norm_hidden_states = hidden_states
485
+ elif self.norm_type == "ada_norm_continuous":
486
+ norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
487
+ else:
488
+ raise ValueError("Incorrect norm")
489
+
490
+ if self.pos_embed is not None and self.norm_type != "ada_norm_single":
491
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
492
+
493
+ attn_output = self.attn2(
494
+ norm_hidden_states,
495
+ encoder_hidden_states=encoder_hidden_states,
496
+ attention_mask=encoder_attention_mask,
497
+ **cross_attention_kwargs,
498
+ )
499
+ hidden_states = attn_output + hidden_states
500
+
501
+ # 4. Feed-forward
502
+ # i2vgen doesn't have this norm 🤷‍♂️
503
+ if self.norm_type == "ada_norm_continuous":
504
+ norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
505
+ elif not self.norm_type == "ada_norm_single":
506
+ norm_hidden_states = self.norm3(hidden_states)
507
+
508
+ if self.norm_type == "ada_norm_zero":
509
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
510
+
511
+ if self.norm_type == "ada_norm_single":
512
+ norm_hidden_states = self.norm2(hidden_states)
513
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
514
+
515
+ if self._chunk_size is not None:
516
+ # "feed_forward_chunk_size" can be used to save memory
517
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
518
+ else:
519
+ ff_output = self.ff(norm_hidden_states)
520
+
521
+ if self.norm_type == "ada_norm_zero":
522
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
523
+ elif self.norm_type == "ada_norm_single":
524
+ ff_output = gate_mlp * ff_output
525
+
526
+ hidden_states = ff_output + hidden_states
527
+ if hidden_states.ndim == 4:
528
+ hidden_states = hidden_states.squeeze(1)
529
+
530
+ return hidden_states
531
+
532
+
533
+ class LuminaFeedForward(nn.Module):
534
+ r"""
535
+ A feed-forward layer.
536
+
537
+ Parameters:
538
+ hidden_size (`int`):
539
+ The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
540
+ hidden representations.
541
+ intermediate_size (`int`): The intermediate dimension of the feedforward layer.
542
+ multiple_of (`int`, *optional*): Value to ensure hidden dimension is a multiple
543
+ of this value.
544
+ ffn_dim_multiplier (float, *optional*): Custom multiplier for hidden
545
+ dimension. Defaults to None.
546
+ """
547
+
548
+ def __init__(
549
+ self,
550
+ dim: int,
551
+ inner_dim: int,
552
+ multiple_of: Optional[int] = 256,
553
+ ffn_dim_multiplier: Optional[float] = None,
554
+ ):
555
+ super().__init__()
556
+ inner_dim = int(2 * inner_dim / 3)
557
+ # custom hidden_size factor multiplier
558
+ if ffn_dim_multiplier is not None:
559
+ inner_dim = int(ffn_dim_multiplier * inner_dim)
560
+ inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)
561
+
562
+ self.linear_1 = nn.Linear(
563
+ dim,
564
+ inner_dim,
565
+ bias=False,
566
+ )
567
+ self.linear_2 = nn.Linear(
568
+ inner_dim,
569
+ dim,
570
+ bias=False,
571
+ )
572
+ self.linear_3 = nn.Linear(
573
+ dim,
574
+ inner_dim,
575
+ bias=False,
576
+ )
577
+ self.silu = FP32SiLU()
578
+
579
+ def forward(self, x):
580
+ return self.linear_2(self.silu(self.linear_1(x)) * self.linear_3(x))
581
+
582
+
583
+ @maybe_allow_in_graph
584
+ class TemporalBasicTransformerBlock(nn.Module):
585
+ r"""
586
+ A basic Transformer block for video like data.
587
+
588
+ Parameters:
589
+ dim (`int`): The number of channels in the input and output.
590
+ time_mix_inner_dim (`int`): The number of channels for temporal attention.
591
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
592
+ attention_head_dim (`int`): The number of channels in each head.
593
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
594
+ """
595
+
596
+ def __init__(
597
+ self,
598
+ dim: int,
599
+ time_mix_inner_dim: int,
600
+ num_attention_heads: int,
601
+ attention_head_dim: int,
602
+ cross_attention_dim: Optional[int] = None,
603
+ ):
604
+ super().__init__()
605
+ self.is_res = dim == time_mix_inner_dim
606
+
607
+ self.norm_in = nn.LayerNorm(dim)
608
+
609
+ # Define 3 blocks. Each block has its own normalization layer.
610
+ # 1. Self-Attn
611
+ self.ff_in = FeedForward(
612
+ dim,
613
+ dim_out=time_mix_inner_dim,
614
+ activation_fn="geglu",
615
+ )
616
+
617
+ self.norm1 = nn.LayerNorm(time_mix_inner_dim)
618
+ self.attn1 = Attention(
619
+ query_dim=time_mix_inner_dim,
620
+ heads=num_attention_heads,
621
+ dim_head=attention_head_dim,
622
+ cross_attention_dim=None,
623
+ )
624
+
625
+ # 2. Cross-Attn
626
+ if cross_attention_dim is not None:
627
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
628
+ # I.e. the number of returned modulation chunks from AdaLayerNormZero would not make sense if returned during
629
+ # the second cross attention block.
630
+ self.norm2 = nn.LayerNorm(time_mix_inner_dim)
631
+ self.attn2 = Attention(
632
+ query_dim=time_mix_inner_dim,
633
+ cross_attention_dim=cross_attention_dim,
634
+ heads=num_attention_heads,
635
+ dim_head=attention_head_dim,
636
+ ) # is self-attn if encoder_hidden_states is none
637
+ else:
638
+ self.norm2 = None
639
+ self.attn2 = None
640
+
641
+ # 3. Feed-forward
642
+ self.norm3 = nn.LayerNorm(time_mix_inner_dim)
643
+ self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu")
644
+
645
+ # let chunk size default to None
646
+ self._chunk_size = None
647
+ self._chunk_dim = None
648
+
649
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs):
650
+ # Sets chunk feed-forward
651
+ self._chunk_size = chunk_size
652
+ # chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off
653
+ self._chunk_dim = 1
654
+
655
+ def forward(
656
+ self,
657
+ hidden_states: torch.Tensor,
658
+ num_frames: int,
659
+ encoder_hidden_states: Optional[torch.Tensor] = None,
660
+ ) -> torch.Tensor:
661
+ # Notice that normalization is always applied before the real computation in the following blocks.
662
+ # 0. Self-Attention
663
+ batch_size = hidden_states.shape[0]
664
+
665
+ batch_frames, seq_length, channels = hidden_states.shape
666
+ batch_size = batch_frames // num_frames
667
+
668
+ hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels)
669
+ hidden_states = hidden_states.permute(0, 2, 1, 3)
670
+ hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels)
671
+
672
+ residual = hidden_states
673
+ hidden_states = self.norm_in(hidden_states)
674
+
675
+ if self._chunk_size is not None:
676
+ hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size)
677
+ else:
678
+ hidden_states = self.ff_in(hidden_states)
679
+
680
+ if self.is_res:
681
+ hidden_states = hidden_states + residual
682
+
683
+ norm_hidden_states = self.norm1(hidden_states)
684
+ attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
685
+ hidden_states = attn_output + hidden_states
686
+
687
+ # 3. Cross-Attention
688
+ if self.attn2 is not None:
689
+ norm_hidden_states = self.norm2(hidden_states)
690
+ attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
691
+ hidden_states = attn_output + hidden_states
692
+
693
+ # 4. Feed-forward
694
+ norm_hidden_states = self.norm3(hidden_states)
695
+
696
+ if self._chunk_size is not None:
697
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
698
+ else:
699
+ ff_output = self.ff(norm_hidden_states)
700
+
701
+ if self.is_res:
702
+ hidden_states = ff_output + hidden_states
703
+ else:
704
+ hidden_states = ff_output
705
+
706
+ hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels)
707
+ hidden_states = hidden_states.permute(0, 2, 1, 3)
708
+ hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels)
709
+
710
+ return hidden_states
711
+
712
+
713
+ class SkipFFTransformerBlock(nn.Module):
714
+ def __init__(
715
+ self,
716
+ dim: int,
717
+ num_attention_heads: int,
718
+ attention_head_dim: int,
719
+ kv_input_dim: int,
720
+ kv_input_dim_proj_use_bias: bool,
721
+ dropout=0.0,
722
+ cross_attention_dim: Optional[int] = None,
723
+ attention_bias: bool = False,
724
+ attention_out_bias: bool = True,
725
+ ):
726
+ super().__init__()
727
+ if kv_input_dim != dim:
728
+ self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias)
729
+ else:
730
+ self.kv_mapper = None
731
+
732
+ self.norm1 = RMSNorm(dim, 1e-06)
733
+
734
+ self.attn1 = Attention(
735
+ query_dim=dim,
736
+ heads=num_attention_heads,
737
+ dim_head=attention_head_dim,
738
+ dropout=dropout,
739
+ bias=attention_bias,
740
+ cross_attention_dim=cross_attention_dim,
741
+ out_bias=attention_out_bias,
742
+ )
743
+
744
+ self.norm2 = RMSNorm(dim, 1e-06)
745
+
746
+ self.attn2 = Attention(
747
+ query_dim=dim,
748
+ cross_attention_dim=cross_attention_dim,
749
+ heads=num_attention_heads,
750
+ dim_head=attention_head_dim,
751
+ dropout=dropout,
752
+ bias=attention_bias,
753
+ out_bias=attention_out_bias,
754
+ )
755
+
756
+ def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs):
757
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
758
+
759
+ if self.kv_mapper is not None:
760
+ encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states))
761
+
762
+ norm_hidden_states = self.norm1(hidden_states)
763
+
764
+ attn_output = self.attn1(
765
+ norm_hidden_states,
766
+ encoder_hidden_states=encoder_hidden_states,
767
+ **cross_attention_kwargs,
768
+ )
769
+
770
+ hidden_states = attn_output + hidden_states
771
+
772
+ norm_hidden_states = self.norm2(hidden_states)
773
+
774
+ attn_output = self.attn2(
775
+ norm_hidden_states,
776
+ encoder_hidden_states=encoder_hidden_states,
777
+ **cross_attention_kwargs,
778
+ )
779
+
780
+ hidden_states = attn_output + hidden_states
781
+
782
+ return hidden_states
783
+
784
+
785
+ class FeedForward(nn.Module):
786
+ r"""
787
+ A feed-forward layer.
788
+
789
+ Parameters:
790
+ dim (`int`): The number of channels in the input.
791
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
792
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
793
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
794
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
795
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
796
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
797
+ """
798
+
799
+ def __init__(
800
+ self,
801
+ dim: int,
802
+ dim_out: Optional[int] = None,
803
+ mult: int = 4,
804
+ dropout: float = 0.0,
805
+ activation_fn: str = "geglu",
806
+ final_dropout: bool = False,
807
+ inner_dim=None,
808
+ bias: bool = True,
809
+ ):
810
+ super().__init__()
811
+ if inner_dim is None:
812
+ inner_dim = int(dim * mult)
813
+ dim_out = dim_out if dim_out is not None else dim
814
+
815
+ if activation_fn == "gelu":
816
+ act_fn = GELU(dim, inner_dim, bias=bias)
817
+ if activation_fn == "gelu-approximate":
818
+ act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
819
+ elif activation_fn == "geglu":
820
+ act_fn = GEGLU(dim, inner_dim, bias=bias)
821
+ elif activation_fn == "geglu-approximate":
822
+ act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
823
+ elif activation_fn == "swiglu":
824
+ act_fn = SwiGLU(dim, inner_dim, bias=bias)
825
+
826
+ self.net = nn.ModuleList([])
827
+ # project in
828
+ self.net.append(act_fn)
829
+ # project dropout
830
+ self.net.append(nn.Dropout(dropout))
831
+ # project out
832
+ self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
833
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
834
+ if final_dropout:
835
+ self.net.append(nn.Dropout(dropout))
836
+
837
+ def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
838
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
839
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
840
+ deprecate("scale", "1.0.0", deprecation_message)
841
+ for module in self.net:
842
+ hidden_states = module(hidden_states)
843
+ return hidden_states
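Sketch (not from the committed files) of the chunked feed-forward path implemented by _chunked_feed_forward at the top of this file: splitting the sequence dimension lowers peak activation memory while producing the same output (dropout defaults to 0.0, so the comparison is deterministic). Assumes the package is importable as flux with its diffusers dependency installed.

import torch
from flux.attention import FeedForward, _chunked_feed_forward

ff = FeedForward(dim=64, activation_fn="geglu").eval()
x = torch.randn(2, 128, 64)                        # (batch, sequence length, dim)

full = ff(x)
chunked = _chunked_feed_forward(ff, x, chunk_dim=1, chunk_size=32)  # 128 must be divisible by 32
print(torch.allclose(full, chunked, atol=1e-6))    # True: identical result, smaller peak memory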
flux/attention_processor.py ADDED
The diff for this file is too large to render. See raw diff
 
flux/controlnet_flux.py ADDED
@@ -0,0 +1,617 @@
1
+ # Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+ from typing import Any, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from .lora.peft import PeftAdapterMixin
23
+ from diffusers.models.attention_processor import AttentionProcessor
24
+ from diffusers.models.modeling_utils import ModelMixin
25
+ from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
26
+ #from .controlnet import BaseOutput, zero_module
27
+ from diffusers.utils import BaseOutput
28
+ from .embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings
29
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
30
+ from .transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
31
+ import numpy as np
32
+
33
+
34
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
35
+
36
+ def zero_module(module):
37
+ for p in module.parameters():
38
+ nn.init.zeros_(p)
39
+ return module
40
+
41
+ def get_1d_rotary_pos_embed(
42
+ dim: int,
43
+ pos: Union[np.ndarray, int],
44
+ theta: float = 10000.0,
45
+ use_real=False,
46
+ linear_factor=1.0,
47
+ ntk_factor=1.0,
48
+ repeat_interleave_real=True,
49
+ freqs_dtype=torch.float32, # torch.float32, torch.float64 (flux)
50
+ ):
51
+ """
52
+ Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
53
+
54
+ This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the
55
+ position indices 'pos'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64
56
+ data type.
57
+
58
+ Args:
59
+ dim (`int`): Dimension of the frequency tensor.
60
+ pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
61
+ theta (`float`, *optional*, defaults to 10000.0):
62
+ Scaling factor for frequency computation. Defaults to 10000.0.
63
+ use_real (`bool`, *optional*):
64
+ If True, return real part and imaginary part separately. Otherwise, return complex numbers.
65
+ linear_factor (`float`, *optional*, defaults to 1.0):
66
+ Scaling factor for the context extrapolation. Defaults to 1.0.
67
+ ntk_factor (`float`, *optional*, defaults to 1.0):
68
+ Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
69
+ repeat_interleave_real (`bool`, *optional*, defaults to `True`):
70
+ If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
71
+ Otherwise, they are concatenated with themselves.
72
+ freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
73
+ the dtype of the frequency tensor.
74
+ Returns:
75
+ `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
76
+ """
77
+ assert dim % 2 == 0
78
+
79
+ if isinstance(pos, int):
80
+ pos = torch.arange(pos)
81
+ if isinstance(pos, np.ndarray):
82
+ pos = torch.from_numpy(pos) # type: ignore # [S]
83
+
84
+ theta = theta * ntk_factor
85
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor # [D/2]
86
+ freqs = freqs.to(pos.device)
87
+ freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
88
+ if use_real and repeat_interleave_real:
89
+ # flux, hunyuan-dit, cogvideox
90
+ freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
91
+ freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
92
+ return freqs_cos, freqs_sin
93
+ elif use_real:
94
+ # stable audio
95
+ freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D]
96
+ freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D]
97
+ return freqs_cos, freqs_sin
98
+ else:
99
+ # lumina
100
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
101
+ return freqs_cis
102
+
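As a quick illustration of what `get_1d_rotary_pos_embed` computes on the Flux branch above (`use_real=True`, `repeat_interleave_real=True`), here is a minimal sketch that reproduces the cos/sin tables by hand for a toy dimension; it only assumes `torch` and is not part of this repository.

import torch

# Toy reproduction of the Flux branch: per-pair frequencies theta**(-2i/dim),
# outer product with the positions, then each frequency repeated twice so the
# cos/sin tables line up with interleaved (real, imag) channels.
dim = 8
pos = torch.arange(4, dtype=torch.float64)                                        # [S]
freqs = 1.0 / (10000.0 ** (torch.arange(0, dim, 2, dtype=torch.float64) / dim))   # [D/2]
angles = torch.outer(pos, freqs)                                                  # [S, D/2]
freqs_cos = angles.cos().repeat_interleave(2, dim=1).float()                      # [S, D]
freqs_sin = angles.sin().repeat_interleave(2, dim=1).float()                      # [S, D]
print(freqs_cos.shape, freqs_sin.shape)                                           # [4, 8] and [4, 8]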
103
+ class FluxPosEmbed(nn.Module):
104
+ # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
105
+ def __init__(self, theta: int, axes_dim: List[int]):
106
+ super().__init__()
107
+ self.theta = theta
108
+ self.axes_dim = axes_dim
109
+
110
+ def forward(self, ids: torch.Tensor) -> torch.Tensor:
111
+ n_axes = ids.shape[-1]
112
+ cos_out = []
113
+ sin_out = []
114
+ pos = ids.squeeze().float()
115
+ is_mps = ids.device.type == "mps"
116
+ freqs_dtype = torch.float32 if is_mps else torch.float64
117
+ for i in range(n_axes):
118
+ cos, sin = get_1d_rotary_pos_embed(
119
+ self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True, freqs_dtype=freqs_dtype
120
+ )
121
+ cos_out.append(cos)
122
+ sin_out.append(sin)
123
+ freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
124
+ freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
125
+ return freqs_cos, freqs_sin
126
+
127
+
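A small usage sketch for `FluxPosEmbed`: the default `axes_dim=[16, 56, 56]` sums to the 128-dim attention head, and in the standard Flux pipelines the first id channel stays at zero while the last two carry the latent row and column. The `flux.controlnet_flux` import path is an assumption about how this file sits on disk.

import torch
from flux.controlnet_flux import FluxPosEmbed   # import path assumed

pos_embed = FluxPosEmbed(theta=10000, axes_dim=[16, 56, 56])
txt_ids = torch.zeros(4, 3)                              # text tokens sit at position 0
img_ids = torch.zeros(8 * 8, 3)                          # an 8x8 latent grid
img_ids[:, 1] = torch.arange(8).repeat_interleave(8)     # row index per token
img_ids[:, 2] = torch.arange(8).repeat(8)                # column index per token
ids = torch.cat([txt_ids, img_ids], dim=0)               # [S, 3], as built in forward()
freqs_cos, freqs_sin = pos_embed(ids)
print(freqs_cos.shape)                                   # [68, 128]: 16 + 56 + 56 per token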
128
+ @dataclass
129
+ class FluxControlNetOutput(BaseOutput):
130
+ controlnet_block_samples: Tuple[torch.Tensor]
131
+ controlnet_single_block_samples: Tuple[torch.Tensor]
132
+
133
+
134
+ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
135
+ _supports_gradient_checkpointing = True
136
+
137
+ @register_to_config
138
+ def __init__(
139
+ self,
140
+ patch_size: int = 1,
141
+ in_channels: int = 64,
142
+ num_layers: int = 19,
143
+ num_single_layers: int = 38,
144
+ attention_head_dim: int = 128,
145
+ num_attention_heads: int = 24,
146
+ joint_attention_dim: int = 4096,
147
+ pooled_projection_dim: int = 768,
148
+ guidance_embeds: bool = False,
149
+ axes_dims_rope: List[int] = [16, 56, 56],
150
+ num_mode: int = None,
151
+ ):
152
+ super().__init__()
153
+ self.out_channels = in_channels
154
+ self.inner_dim = num_attention_heads * attention_head_dim
155
+
156
+ self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
157
+ text_time_guidance_cls = (
158
+ CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
159
+ )
160
+ self.time_text_embed = text_time_guidance_cls(
161
+ embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
162
+ )
163
+
164
+ self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
165
+ self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim)
166
+
167
+ self.transformer_blocks = nn.ModuleList(
168
+ [
169
+ FluxTransformerBlock(
170
+ dim=self.inner_dim,
171
+ num_attention_heads=num_attention_heads,
172
+ attention_head_dim=attention_head_dim,
173
+ )
174
+ for i in range(num_layers)
175
+ ]
176
+ )
177
+
178
+ self.single_transformer_blocks = nn.ModuleList(
179
+ [
180
+ FluxSingleTransformerBlock(
181
+ dim=self.inner_dim,
182
+ num_attention_heads=num_attention_heads,
183
+ attention_head_dim=attention_head_dim,
184
+ )
185
+ for i in range(num_single_layers)
186
+ ]
187
+ )
188
+
189
+ # controlnet_blocks
190
+ self.controlnet_blocks = nn.ModuleList([])
191
+ for _ in range(len(self.transformer_blocks)):
192
+ self.controlnet_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
193
+
194
+ self.controlnet_single_blocks = nn.ModuleList([])
195
+ for _ in range(len(self.single_transformer_blocks)):
196
+ self.controlnet_single_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
197
+
198
+ self.union = num_mode is not None
199
+ if self.union:
200
+ self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim)
201
+
202
+ self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim))
203
+
204
+ self.gradient_checkpointing = False
205
+
206
+ @property
207
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
208
+ def attn_processors(self):
209
+ r"""
210
+ Returns:
211
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
212
+ indexed by their weight names.
213
+ """
214
+ # set recursively
215
+ processors = {}
216
+
217
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
218
+ if hasattr(module, "get_processor"):
219
+ processors[f"{name}.processor"] = module.get_processor()
220
+
221
+ for sub_name, child in module.named_children():
222
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
223
+
224
+ return processors
225
+
226
+ for name, module in self.named_children():
227
+ fn_recursive_add_processors(name, module, processors)
228
+
229
+ return processors
230
+
231
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
232
+ def set_attn_processor(self, processor):
233
+ r"""
234
+ Sets the attention processor to use to compute attention.
235
+
236
+ Parameters:
237
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
238
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
239
+ for **all** `Attention` layers.
240
+
241
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
242
+ processor. This is strongly recommended when setting trainable attention processors.
243
+
244
+ """
245
+ count = len(self.attn_processors.keys())
246
+
247
+ if isinstance(processor, dict) and len(processor) != count:
248
+ raise ValueError(
249
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
250
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
251
+ )
252
+
253
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
254
+ if hasattr(module, "set_processor"):
255
+ if not isinstance(processor, dict):
256
+ module.set_processor(processor)
257
+ else:
258
+ module.set_processor(processor.pop(f"{name}.processor"))
259
+
260
+ for sub_name, child in module.named_children():
261
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
262
+
263
+ for name, module in self.named_children():
264
+ fn_recursive_attn_processor(name, module, processor)
265
+
266
+ def _set_gradient_checkpointing(self, module, value=False):
267
+ if hasattr(module, "gradient_checkpointing"):
268
+ module.gradient_checkpointing = value
269
+
270
+ @classmethod
271
+ def from_transformer(
272
+ cls,
273
+ transformer,
274
+ num_layers: int = 4,
275
+ num_single_layers: int = 10,
276
+ attention_head_dim: int = 128,
277
+ num_attention_heads: int = 24,
278
+ load_weights_from_transformer=True,
279
+ ):
280
+ config = transformer.config
281
+ config["num_layers"] = num_layers
282
+ config["num_single_layers"] = num_single_layers
283
+ config["attention_head_dim"] = attention_head_dim
284
+ config["num_attention_heads"] = num_attention_heads
285
+
286
+ controlnet = cls(**config)
287
+
288
+ if load_weights_from_transformer:
289
+ controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
290
+ controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict())
291
+ controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict())
292
+ controlnet.x_embedder.load_state_dict(transformer.x_embedder.state_dict())
293
+ controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)
294
+ controlnet.single_transformer_blocks.load_state_dict(
295
+ transformer.single_transformer_blocks.state_dict(), strict=False
296
+ )
297
+
298
+ controlnet.controlnet_x_embedder = zero_module(controlnet.controlnet_x_embedder)
299
+
300
+ return controlnet
301
+
302
+ def forward(
303
+ self,
304
+ hidden_states: torch.Tensor,
305
+ controlnet_cond: torch.Tensor,
306
+ controlnet_mode: torch.Tensor = None,
307
+ conditioning_scale: float = 1.0,
308
+ encoder_hidden_states: torch.Tensor = None,
309
+ t5_encoder_hidden_states: torch.Tensor = None,
310
+ pooled_projections: torch.Tensor = None,
311
+ timestep: torch.LongTensor = None,
312
+ img_ids: torch.Tensor = None,
313
+ txt_ids: torch.Tensor = None,
314
+ guidance: torch.Tensor = None,
315
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
316
+ return_dict: bool = True,
317
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
318
+ """
319
+ The [`FluxTransformer2DModel`] forward method.
320
+
321
+ Args:
322
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
323
+ Input `hidden_states`.
324
+ controlnet_cond (`torch.Tensor`):
325
+ The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
326
+ controlnet_mode (`torch.Tensor`):
327
+ The mode tensor of shape `(batch_size, 1)`.
328
+ conditioning_scale (`float`, defaults to `1.0`):
329
+ The scale factor for ControlNet outputs.
330
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
331
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
332
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
333
+ from the embeddings of input conditions.
334
+ timestep ( `torch.LongTensor`):
335
+ Used to indicate denoising step.
336
+ block_controlnet_hidden_states: (`list` of `torch.Tensor`):
337
+ A list of tensors that if specified are added to the residuals of transformer blocks.
338
+ joint_attention_kwargs (`dict`, *optional*):
339
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
340
+ `self.processor` in
341
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
342
+ return_dict (`bool`, *optional*, defaults to `True`):
343
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
344
+ tuple.
345
+
346
+ Returns:
347
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
348
+ `tuple` where the first element is the sample tensor.
349
+ """
350
+ if joint_attention_kwargs is not None:
351
+ joint_attention_kwargs = joint_attention_kwargs.copy()
352
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
353
+ else:
354
+ lora_scale = 1.0
355
+
356
+ if USE_PEFT_BACKEND:
357
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
358
+ scale_lora_layers(self, lora_scale)
359
+ else:
360
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
361
+ logger.warning(
362
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
363
+ )
364
+ hidden_states = self.x_embedder(hidden_states)
365
+
366
+ # add
367
+ hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond)
368
+
369
+ timestep = timestep.to(hidden_states.dtype) * 1000
370
+ if guidance is not None:
371
+ guidance = guidance.to(hidden_states.dtype) * 1000
372
+ else:
373
+ guidance = None
374
+ temb = (
375
+ self.time_text_embed(timestep, pooled_projections)
376
+ if guidance is None
377
+ else self.time_text_embed(timestep, guidance, pooled_projections)
378
+ )
379
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
380
+ if t5_encoder_hidden_states is not None:
381
+ encoder_hidden_states = torch.cat([encoder_hidden_states, t5_encoder_hidden_states], dim=1)
382
+
383
+ if txt_ids.ndim == 3:
384
+ logger.warning(
385
+ "Passing `txt_ids` as a 3d torch.Tensor is deprecated. "
386
+ "Please remove the batch dimension and pass it as a 2d torch.Tensor."
387
+ )
388
+ txt_ids = txt_ids[0]
389
+
390
+ if self.union:
391
+ # union mode
392
+ if controlnet_mode is None:
393
+ raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union")
394
+ # union mode emb
395
+ controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
396
+ encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1)
397
+ txt_ids = torch.cat([txt_ids[:1], txt_ids], dim=0)
398
+
399
+ if img_ids.ndim == 3:
400
+ logger.warning(
401
+ "Passing `img_ids` as a 3d torch.Tensor is deprecated. "
402
+ "Please remove the batch dimension and pass it as a 2d torch.Tensor."
403
+ )
404
+ img_ids = img_ids[0]
405
+
406
+ ids = torch.cat((txt_ids, img_ids), dim=0)
407
+ image_rotary_emb = self.pos_embed(ids)
408
+
409
+ block_samples = ()
410
+ for index_block, block in enumerate(self.transformer_blocks):
411
+ if self.training and self.gradient_checkpointing:
412
+
413
+ def create_custom_forward(module, return_dict=None):
414
+ def custom_forward(*inputs):
415
+ if return_dict is not None:
416
+ return module(*inputs, return_dict=return_dict)
417
+ else:
418
+ return module(*inputs)
419
+
420
+ return custom_forward
421
+
422
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
423
+ encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
424
+ create_custom_forward(block),
425
+ hidden_states,
426
+ encoder_hidden_states,
427
+ temb,
428
+ image_rotary_emb,
429
+ **ckpt_kwargs,
430
+ )
431
+
432
+ else:
433
+ encoder_hidden_states, hidden_states = block(
434
+ hidden_states=hidden_states,
435
+ encoder_hidden_states=encoder_hidden_states,
436
+ temb=temb,
437
+ image_rotary_emb=image_rotary_emb,
438
+ )
439
+ block_samples = block_samples + (hidden_states,)
440
+
441
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
442
+
443
+ single_block_samples = ()
444
+ for index_block, block in enumerate(self.single_transformer_blocks):
445
+ if self.training and self.gradient_checkpointing:
446
+
447
+ def create_custom_forward(module, return_dict=None):
448
+ def custom_forward(*inputs):
449
+ if return_dict is not None:
450
+ return module(*inputs, return_dict=return_dict)
451
+ else:
452
+ return module(*inputs)
453
+
454
+ return custom_forward
455
+
456
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
457
+ hidden_states = torch.utils.checkpoint.checkpoint(
458
+ create_custom_forward(block),
459
+ hidden_states,
460
+ temb,
461
+ image_rotary_emb,
462
+ **ckpt_kwargs,
463
+ )
464
+
465
+ else:
466
+ hidden_states = block(
467
+ hidden_states=hidden_states,
468
+ temb=temb,
469
+ image_rotary_emb=image_rotary_emb,
470
+ )
471
+ single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],)
472
+
473
+ # controlnet block
474
+ controlnet_block_samples = ()
475
+ for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks):
476
+ block_sample = controlnet_block(block_sample)
477
+ controlnet_block_samples = controlnet_block_samples + (block_sample,)
478
+
479
+ controlnet_single_block_samples = ()
480
+ for single_block_sample, controlnet_block in zip(single_block_samples, self.controlnet_single_blocks):
481
+ single_block_sample = controlnet_block(single_block_sample)
482
+ controlnet_single_block_samples = controlnet_single_block_samples + (single_block_sample,)
483
+
484
+ # scaling
485
+ controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples]
486
+ controlnet_single_block_samples = [sample * conditioning_scale for sample in controlnet_single_block_samples]
487
+
488
+ controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples
489
+ controlnet_single_block_samples = (
490
+ None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples
491
+ )
492
+
493
+ if USE_PEFT_BACKEND:
494
+ # remove `lora_scale` from each PEFT layer
495
+ unscale_lora_layers(self, lora_scale)
496
+
497
+ if not return_dict:
498
+ return (controlnet_block_samples, controlnet_single_block_samples)
499
+
500
+ return FluxControlNetOutput(
501
+ controlnet_block_samples=controlnet_block_samples,
502
+ controlnet_single_block_samples=controlnet_single_block_samples,
503
+ )
504
+
505
+
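Below is a minimal construction sketch for `FluxControlNetModel`, assuming the repo's `flux` package is importable; the tiny config is hypothetical and only meant to show that the controlnet-specific projections start zero-initialized, so the conditioning residuals are a no-op before training.

import torch
from flux.controlnet_flux import FluxControlNetModel   # import path assumed

tiny = FluxControlNetModel(
    in_channels=16,
    num_layers=2,
    num_single_layers=2,
    attention_head_dim=32,
    num_attention_heads=2,
    joint_attention_dim=64,
    pooled_projection_dim=32,
    axes_dims_rope=[4, 14, 14],   # chosen to sum to attention_head_dim, as the default [16, 56, 56] does
)
# zero_module() above zeroes these, so the ControlNet starts out as an identity add
assert all((p == 0).all() for p in tiny.controlnet_x_embedder.parameters())
assert all((p == 0).all() for p in tiny.controlnet_blocks[0].parameters())
# from_transformer() instead copies the shared embedders and the first N blocks from a
# full Flux transformer before re-zeroing controlnet_x_embedder.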
506
+ class FluxMultiControlNetModel(ModelMixin):
507
+ r"""
508
+ `FluxMultiControlNetModel` wrapper class for Multi-FluxControlNetModel
509
+
510
+ This module is a wrapper for multiple instances of the `FluxControlNetModel`. The `forward()` API is designed to be
511
+ compatible with `FluxControlNetModel`.
512
+
513
+ Args:
514
+ controlnets (`List[FluxControlNetModel]`):
515
+ Provides additional conditioning to the transformer during the denoising process. You must pass multiple
516
+ `FluxControlNetModel` instances as a list.
517
+ """
518
+
519
+ def __init__(self, controlnets):
520
+ super().__init__()
521
+ self.nets = nn.ModuleList(controlnets)
522
+
523
+ def forward(
524
+ self,
525
+ hidden_states: torch.FloatTensor,
526
+ controlnet_cond: List[torch.Tensor],
527
+ controlnet_mode: List[torch.Tensor],
528
+ conditioning_scale: List[float],
529
+ encoder_hidden_states: torch.Tensor = None,
530
+ t5_encoder_hidden_states: torch.Tensor = None,
531
+ pooled_projections: torch.Tensor = None,
532
+ timestep: torch.LongTensor = None,
533
+ img_ids: torch.Tensor = None,
534
+ txt_ids: torch.Tensor = None,
535
+ guidance: torch.Tensor = None,
536
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
537
+ return_dict: bool = True,
538
+ ) -> Union[FluxControlNetOutput, Tuple]:
539
+ # ControlNet-Union with multiple conditions
540
+ # only load one ControlNet to save memory
541
+ if len(self.nets) == 1 and self.nets[0].union:
542
+ controlnet = self.nets[0]
543
+
544
+ for i, (image, mode, scale) in enumerate(zip(controlnet_cond, controlnet_mode, conditioning_scale)):
545
+ block_samples, single_block_samples = controlnet(
546
+ hidden_states=hidden_states,
547
+ controlnet_cond=image,
548
+ controlnet_mode=mode[:, None],
549
+ conditioning_scale=scale,
550
+ timestep=timestep,
551
+ guidance=guidance,
552
+ pooled_projections=pooled_projections,
553
+ encoder_hidden_states=encoder_hidden_states,
554
+ t5_encoder_hidden_states=t5_encoder_hidden_states,
555
+ txt_ids=txt_ids,
556
+ img_ids=img_ids,
557
+ joint_attention_kwargs=joint_attention_kwargs,
558
+ return_dict=return_dict,
559
+ )
560
+
561
+ # merge samples
562
+ if i == 0:
563
+ control_block_samples = block_samples
564
+ control_single_block_samples = single_block_samples
565
+ else:
566
+ control_block_samples = [
567
+ control_block_sample + block_sample
568
+ for control_block_sample, block_sample in zip(control_block_samples, block_samples)
569
+ ]
570
+
571
+ control_single_block_samples = [
572
+ control_single_block_sample + block_sample
573
+ for control_single_block_sample, block_sample in zip(
574
+ control_single_block_samples, single_block_samples
575
+ )
576
+ ]
577
+
578
+ # Regular Multi-ControlNets
579
+ # load all ControlNets into memory
580
+ else:
581
+ for i, (image, mode, scale, controlnet) in enumerate(
582
+ zip(controlnet_cond, controlnet_mode, conditioning_scale, self.nets)
583
+ ):
584
+ block_samples, single_block_samples = controlnet(
585
+ hidden_states=hidden_states,
586
+ controlnet_cond=image,
587
+ controlnet_mode=mode[:, None],
588
+ conditioning_scale=scale,
589
+ timestep=timestep,
590
+ guidance=guidance,
591
+ pooled_projections=pooled_projections,
592
+ encoder_hidden_states=encoder_hidden_states,
593
+ txt_ids=txt_ids,
594
+ img_ids=img_ids,
595
+ joint_attention_kwargs=joint_attention_kwargs,
596
+ return_dict=return_dict,
597
+ )
598
+
599
+ # merge samples
600
+ if i == 0:
601
+ control_block_samples = block_samples
602
+ control_single_block_samples = single_block_samples
603
+ else:
604
+ if block_samples is not None and control_block_samples is not None:
605
+ control_block_samples = [
606
+ control_block_sample + block_sample
607
+ for control_block_sample, block_sample in zip(control_block_samples, block_samples)
608
+ ]
609
+ if single_block_samples is not None and control_single_block_samples is not None:
610
+ control_single_block_samples = [
611
+ control_single_block_sample + block_sample
612
+ for control_single_block_sample, block_sample in zip(
613
+ control_single_block_samples, single_block_samples
614
+ )
615
+ ]
616
+
617
+ return control_block_samples, control_single_block_samples
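For context, a short sketch of how `FluxMultiControlNetModel` is assembled, again assuming the `flux` import path and reusing the hypothetical tiny config from the sketch above; in the pipeline, `forward()` receives `controlnet_cond`, `controlnet_mode`, and `conditioning_scale` as lists, one entry per net, and sums the per-block residuals.

from flux.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel  # import path assumed

def tiny_controlnet():
    # hypothetical small config, mirroring the earlier FluxControlNetModel sketch
    return FluxControlNetModel(
        in_channels=16, num_layers=1, num_single_layers=1,
        attention_head_dim=32, num_attention_heads=2,
        joint_attention_dim=64, pooled_projection_dim=32,
        axes_dims_rope=[4, 14, 14],
    )

multi = FluxMultiControlNetModel([tiny_controlnet(), tiny_controlnet()])
print(len(multi.nets))   # 2; each net handles one (image, mode, scale) triple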
flux/embeddings.py ADDED
@@ -0,0 +1,1469 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import math
15
+ from typing import List, Optional, Tuple, Union
16
+
17
+ import numpy as np
18
+ import torch
19
+ import torch.nn.functional as F
20
+ from torch import nn
21
+
22
+ from diffusers.utils import deprecate
23
+ from .activations import FP32SiLU, get_activation
24
+ from .attention_processor import Attention
25
+
26
+
27
+ def get_timestep_embedding(
28
+ timesteps: torch.Tensor,
29
+ embedding_dim: int,
30
+ flip_sin_to_cos: bool = False,
31
+ downscale_freq_shift: float = 1,
32
+ scale: float = 1,
33
+ max_period: int = 10000,
34
+ ):
35
+ """
36
+ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
37
+
38
+ Args:
39
+ timesteps (torch.Tensor):
40
+ a 1-D Tensor of N indices, one per batch element. These may be fractional.
41
+ embedding_dim (int):
42
+ the dimension of the output.
43
+ flip_sin_to_cos (bool):
44
+ Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
45
+ downscale_freq_shift (float):
46
+ Controls the delta between frequencies between dimensions
47
+ scale (float):
48
+ Scaling factor applied to the embeddings.
49
+ max_period (int):
50
+ Controls the maximum frequency of the embeddings
51
+ Returns:
52
+ torch.Tensor: an [N x dim] Tensor of positional embeddings.
53
+ """
54
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
55
+
56
+ half_dim = embedding_dim // 2
57
+ exponent = -math.log(max_period) * torch.arange(
58
+ start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
59
+ )
60
+ exponent = exponent / (half_dim - downscale_freq_shift)
61
+
62
+ emb = torch.exp(exponent)
63
+ emb = timesteps[:, None].float() * emb[None, :]
64
+
65
+ # scale embeddings
66
+ emb = scale * emb
67
+
68
+ # concat sine and cosine embeddings
69
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
70
+
71
+ # flip sine and cosine embeddings
72
+ if flip_sin_to_cos:
73
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
74
+
75
+ # zero pad
76
+ if embedding_dim % 2 == 1:
77
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
78
+ return emb
79
+
80
+
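A worked sketch of `get_timestep_embedding` with the settings used elsewhere in this file (`flip_sin_to_cos=True`, `downscale_freq_shift=0`); the `flux.embeddings` import path is an assumption.

import torch
from flux.embeddings import get_timestep_embedding   # import path assumed

# With dim=8 and downscale_freq_shift=0 the four frequencies are
# exp(-ln(10000) * i / 4) for i = 0..3; the output is [cos | sin] per timestep
# because flip_sin_to_cos=True swaps the two halves.
t = torch.tensor([0.0, 500.0, 999.0])
emb = get_timestep_embedding(t, embedding_dim=8, flip_sin_to_cos=True, downscale_freq_shift=0)
print(emb.shape)   # torch.Size([3, 8])
print(emb[0])      # timestep 0 -> [1, 1, 1, 1, 0, 0, 0, 0]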
81
+ def get_2d_sincos_pos_embed(
82
+ embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
83
+ ):
84
+ """
85
+ grid_size: int of the grid height and width.
86
+ Returns: pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
87
+ """
88
+ if isinstance(grid_size, int):
89
+ grid_size = (grid_size, grid_size)
90
+
91
+ grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
92
+ grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
93
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
94
+ grid = np.stack(grid, axis=0)
95
+
96
+ grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
97
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
98
+ if cls_token and extra_tokens > 0:
99
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
100
+ return pos_embed
101
+
102
+
103
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
104
+ if embed_dim % 2 != 0:
105
+ raise ValueError("embed_dim must be divisible by 2")
106
+
107
+ # use half of dimensions to encode grid_h
108
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
109
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
110
+
111
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
112
+ return emb
113
+
114
+
115
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
116
+ """
117
+ embed_dim: output dimension for each position. pos: a list of positions to be encoded, of size (M,). out: (M, D)
118
+ """
119
+ if embed_dim % 2 != 0:
120
+ raise ValueError("embed_dim must be divisible by 2")
121
+
122
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
123
+ omega /= embed_dim / 2.0
124
+ omega = 1.0 / 10000**omega # (D/2,)
125
+
126
+ pos = pos.reshape(-1) # (M,)
127
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
128
+
129
+ emb_sin = np.sin(out) # (M, D/2)
130
+ emb_cos = np.cos(out) # (M, D/2)
131
+
132
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
133
+ return emb
134
+
135
+
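The three helpers above combine into one lookup table: `get_2d_sincos_pos_embed` concatenates a 1D sin/cos table for the row grid with one for the column grid. A shape-only sketch, assuming the `flux.embeddings` import path:

from flux.embeddings import get_2d_sincos_pos_embed   # import path assumed

pos = get_2d_sincos_pos_embed(embed_dim=64, grid_size=8, base_size=8)
print(pos.shape)   # (64, 64): one 64-dim embedding per patch of an 8x8 grid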
136
+ class PatchEmbed(nn.Module):
137
+ """2D Image to Patch Embedding with support for SD3 cropping."""
138
+
139
+ def __init__(
140
+ self,
141
+ height=224,
142
+ width=224,
143
+ patch_size=16,
144
+ in_channels=3,
145
+ embed_dim=768,
146
+ layer_norm=False,
147
+ flatten=True,
148
+ bias=True,
149
+ interpolation_scale=1,
150
+ pos_embed_type="sincos",
151
+ pos_embed_max_size=None, # For SD3 cropping
152
+ ):
153
+ super().__init__()
154
+
155
+ num_patches = (height // patch_size) * (width // patch_size)
156
+ self.flatten = flatten
157
+ self.layer_norm = layer_norm
158
+ self.pos_embed_max_size = pos_embed_max_size
159
+
160
+ self.proj = nn.Conv2d(
161
+ in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
162
+ )
163
+ if layer_norm:
164
+ self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
165
+ else:
166
+ self.norm = None
167
+
168
+ self.patch_size = patch_size
169
+ self.height, self.width = height // patch_size, width // patch_size
170
+ self.base_size = height // patch_size
171
+ self.interpolation_scale = interpolation_scale
172
+
173
+ # Calculate positional embeddings based on max size or default
174
+ if pos_embed_max_size:
175
+ grid_size = pos_embed_max_size
176
+ else:
177
+ grid_size = int(num_patches**0.5)
178
+
179
+ if pos_embed_type is None:
180
+ self.pos_embed = None
181
+ elif pos_embed_type == "sincos":
182
+ pos_embed = get_2d_sincos_pos_embed(
183
+ embed_dim, grid_size, base_size=self.base_size, interpolation_scale=self.interpolation_scale
184
+ )
185
+ persistent = True if pos_embed_max_size else False
186
+ self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=persistent)
187
+ else:
188
+ raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}")
189
+
190
+ def cropped_pos_embed(self, height, width):
191
+ """Crops positional embeddings for SD3 compatibility."""
192
+ if self.pos_embed_max_size is None:
193
+ raise ValueError("`pos_embed_max_size` must be set for cropping.")
194
+
195
+ height = height // self.patch_size
196
+ width = width // self.patch_size
197
+ if height > self.pos_embed_max_size:
198
+ raise ValueError(
199
+ f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
200
+ )
201
+ if width > self.pos_embed_max_size:
202
+ raise ValueError(
203
+ f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
204
+ )
205
+
206
+ top = (self.pos_embed_max_size - height) // 2
207
+ left = (self.pos_embed_max_size - width) // 2
208
+ spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)
209
+ spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :]
210
+ spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
211
+ return spatial_pos_embed
212
+
213
+ def forward(self, latent):
214
+ if self.pos_embed_max_size is not None:
215
+ height, width = latent.shape[-2:]
216
+ else:
217
+ height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
218
+
219
+ latent = self.proj(latent)
220
+ if self.flatten:
221
+ latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
222
+ if self.layer_norm:
223
+ latent = self.norm(latent)
224
+ if self.pos_embed is None:
225
+ return latent.to(latent.dtype)
226
+ # Interpolate or crop positional embeddings as needed
227
+ if self.pos_embed_max_size:
228
+ pos_embed = self.cropped_pos_embed(height, width)
229
+ else:
230
+ if self.height != height or self.width != width:
231
+ pos_embed = get_2d_sincos_pos_embed(
232
+ embed_dim=self.pos_embed.shape[-1],
233
+ grid_size=(height, width),
234
+ base_size=self.base_size,
235
+ interpolation_scale=self.interpolation_scale,
236
+ )
237
+ pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device)
238
+ else:
239
+ pos_embed = self.pos_embed
240
+
241
+ return (latent + pos_embed).to(latent.dtype)
242
+
243
+
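A shape sketch for `PatchEmbed` on the SD3-style `pos_embed_max_size` cropping path; the sizes are hypothetical and the import path assumed.

import torch
from flux.embeddings import PatchEmbed   # import path assumed

patcher = PatchEmbed(
    height=64, width=64, patch_size=2, in_channels=4,
    embed_dim=32, pos_embed_max_size=64,
)
latent = torch.randn(1, 4, 64, 64)
tokens = patcher(latent)
print(tokens.shape)   # [1, 1024, 32]: (64/2)*(64/2) patches, centre-cropped pos-embed added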
244
+ class LuminaPatchEmbed(nn.Module):
245
+ """2D Image to Patch Embedding with support for Lumina-T2X"""
246
+
247
+ def __init__(self, patch_size=2, in_channels=4, embed_dim=768, bias=True):
248
+ super().__init__()
249
+ self.patch_size = patch_size
250
+ self.proj = nn.Linear(
251
+ in_features=patch_size * patch_size * in_channels,
252
+ out_features=embed_dim,
253
+ bias=bias,
254
+ )
255
+
256
+ def forward(self, x, freqs_cis):
257
+ """
258
+ Patchifies and embeds the input tensor(s).
259
+
260
+ Args:
261
+ x (List[torch.Tensor] | torch.Tensor): The input tensor(s) to be patchified and embedded.
262
+
263
+ Returns:
264
+ Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], torch.Tensor]: A tuple containing the patchified
265
+ and embedded tensor(s), the mask indicating the valid patches, the original image size(s), and the
266
+ frequency tensor(s).
267
+ """
268
+ freqs_cis = freqs_cis.to(x[0].device)
269
+ patch_height = patch_width = self.patch_size
270
+ batch_size, channel, height, width = x.size()
271
+ height_tokens, width_tokens = height // patch_height, width // patch_width
272
+
273
+ x = x.view(batch_size, channel, height_tokens, patch_height, width_tokens, patch_width).permute(
274
+ 0, 2, 4, 1, 3, 5
275
+ )
276
+ x = x.flatten(3)
277
+ x = self.proj(x)
278
+ x = x.flatten(1, 2)
279
+
280
+ mask = torch.ones(x.shape[0], x.shape[1], dtype=torch.int32, device=x.device)
281
+
282
+ return (
283
+ x,
284
+ mask,
285
+ [(height, width)] * batch_size,
286
+ freqs_cis[:height_tokens, :width_tokens].flatten(0, 1).unsqueeze(0),
287
+ )
288
+
289
+
290
+ def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
291
+ """
292
+ RoPE for image tokens with 2d structure.
293
+
294
+ Args:
295
+ embed_dim: (`int`):
296
+ The embedding dimension size
297
+ crops_coords (`Tuple[int]`)
298
+ The top-left and bottom-right coordinates of the crop.
299
+ grid_size (`Tuple[int]`):
300
+ The grid size of the positional embedding.
301
+ use_real (`bool`):
302
+ If True, return real part and imaginary part separately. Otherwise, return complex numbers.
303
+
304
+ Returns:
305
+ `torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`.
306
+ """
307
+ start, stop = crops_coords
308
+ grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
309
+ grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
310
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
311
+ grid = np.stack(grid, axis=0) # [2, W, H]
312
+
313
+ grid = grid.reshape([2, 1, *grid.shape[1:]])
314
+ pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
315
+ return pos_embed
316
+
317
+
318
+ def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
319
+ assert embed_dim % 4 == 0
320
+
321
+ # use half of dimensions to encode grid_h
322
+ emb_h = get_1d_rotary_pos_embed(
323
+ embed_dim // 2, grid[0].reshape(-1), use_real=use_real
324
+ ) # (H*W, D/2) if use_real else (H*W, D/4)
325
+ emb_w = get_1d_rotary_pos_embed(
326
+ embed_dim // 2, grid[1].reshape(-1), use_real=use_real
327
+ ) # (H*W, D/2) if use_real else (H*W, D/4)
328
+
329
+ if use_real:
330
+ cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D)
331
+ sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D)
332
+ return cos, sin
333
+ else:
334
+ emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2)
335
+ return emb
336
+
337
+
338
+ def get_2d_rotary_pos_embed_lumina(embed_dim, len_h, len_w, linear_factor=1.0, ntk_factor=1.0):
339
+ assert embed_dim % 4 == 0
340
+
341
+ emb_h = get_1d_rotary_pos_embed(
342
+ embed_dim // 2, len_h, linear_factor=linear_factor, ntk_factor=ntk_factor
343
+ ) # (H, D/4)
344
+ emb_w = get_1d_rotary_pos_embed(
345
+ embed_dim // 2, len_w, linear_factor=linear_factor, ntk_factor=ntk_factor
346
+ ) # (W, D/4)
347
+ emb_h = emb_h.view(len_h, 1, embed_dim // 4, 1).repeat(1, len_w, 1, 1) # (H, W, D/4, 1)
348
+ emb_w = emb_w.view(1, len_w, embed_dim // 4, 1).repeat(len_h, 1, 1, 1) # (H, W, D/4, 1)
349
+
350
+ emb = torch.cat([emb_h, emb_w], dim=-1).flatten(2) # (H, W, D/2)
351
+ return emb
352
+
353
+
354
+ def get_1d_rotary_pos_embed(
355
+ dim: int,
356
+ pos: Union[np.ndarray, int],
357
+ theta: float = 10000.0,
358
+ use_real=False,
359
+ linear_factor=1.0,
360
+ ntk_factor=1.0,
361
+ repeat_interleave_real=True,
362
+ ):
363
+ """
364
+ Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
365
+
366
+ This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end
367
+ index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64
368
+ data type.
369
+
370
+ Args:
371
+ dim (`int`): Dimension of the frequency tensor.
372
+ pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
373
+ theta (`float`, *optional*, defaults to 10000.0):
374
+ Scaling factor for frequency computation. Defaults to 10000.0.
375
+ use_real (`bool`, *optional*):
376
+ If True, return real part and imaginary part separately. Otherwise, return complex numbers.
377
+ linear_factor (`float`, *optional*, defaults to 1.0):
378
+ Scaling factor for the context extrapolation. Defaults to 1.0.
379
+ ntk_factor (`float`, *optional*, defaults to 1.0):
380
+ Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
381
+ repeat_interleave_real (`bool`, *optional*, defaults to `True`):
382
+ If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
383
+ Otherwise, they are concatenated with themselves.
384
+ Returns:
385
+ `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
386
+ """
387
+ assert dim % 2 == 0
388
+
389
+ if isinstance(pos, int):
390
+ pos = np.arange(pos)
391
+ theta = theta * ntk_factor
392
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) / linear_factor # [D/2]
393
+ t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S]
394
+ freqs = torch.outer(t, freqs).float() # type: ignore # [S, D/2]
395
+ if use_real and repeat_interleave_real:
396
+ freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
397
+ freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
398
+ return freqs_cos, freqs_sin
399
+ elif use_real:
400
+ freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1) # [S, D]
401
+ freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1) # [S, D]
402
+ return freqs_cos, freqs_sin
403
+ else:
404
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
405
+ return freqs_cis
406
+
407
+
408
+ def apply_rotary_emb(
409
+ x: torch.Tensor,
410
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
411
+ use_real: bool = True,
412
+ use_real_unbind_dim: int = -1,
413
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
414
+ """
415
+ Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
416
+ to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
417
+ reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
418
+ tensors contain rotary embeddings and are returned as real tensors.
419
+
420
+ Args:
421
+ x (`torch.Tensor`):
422
+ Query or key tensor to apply rotary embeddings to. [B, H, S, D]
423
+ freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
424
+
425
+ Returns:
426
+ Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
427
+ """
428
+ if use_real:
429
+ cos, sin = freqs_cis # [S, D]
430
+ cos = cos[None, None]
431
+ sin = sin[None, None]
432
+ cos, sin = cos.to(x.device), sin.to(x.device)
433
+
434
+ if use_real_unbind_dim == -1:
435
+ # Used, for example, in Lumina
436
+ x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
437
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
438
+ elif use_real_unbind_dim == -2:
439
+ # Used, for example, in Stable Audio
440
+ x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
441
+ x_rotated = torch.cat([-x_imag, x_real], dim=-1)
442
+ else:
443
+ raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
444
+
445
+ out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
446
+
447
+ return out
448
+ else:
449
+ x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
450
+ freqs_cis = freqs_cis.unsqueeze(2)
451
+ x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
452
+
453
+ return x_out.type_as(x)
454
+
455
+
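Putting the two rotary helpers together, a sketch that rotates a query tensor in the `[B, H, S, D]` layout used by the attention processors (import path assumed):

import torch
from flux.embeddings import get_1d_rotary_pos_embed, apply_rotary_emb   # import path assumed

seq_len, head_dim = 16, 8
cos, sin = get_1d_rotary_pos_embed(head_dim, seq_len, use_real=True)   # each [S, D]
q = torch.randn(1, 2, seq_len, head_dim)                               # [B, H, S, D]
q_rot = apply_rotary_emb(q, (cos, sin))
print(q_rot.shape)   # unchanged: torch.Size([1, 2, 16, 8])
# Each (real, imag) pair keeps its norm, since this is a pure rotation.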
456
+ class TimestepEmbedding(nn.Module):
457
+ def __init__(
458
+ self,
459
+ in_channels: int,
460
+ time_embed_dim: int,
461
+ act_fn: str = "silu",
462
+ out_dim: int = None,
463
+ post_act_fn: Optional[str] = None,
464
+ cond_proj_dim=None,
465
+ sample_proj_bias=True,
466
+ ):
467
+ super().__init__()
468
+
469
+ self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
470
+
471
+ if cond_proj_dim is not None:
472
+ self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
473
+ else:
474
+ self.cond_proj = None
475
+
476
+ self.act = get_activation(act_fn)
477
+
478
+ if out_dim is not None:
479
+ time_embed_dim_out = out_dim
480
+ else:
481
+ time_embed_dim_out = time_embed_dim
482
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)
483
+
484
+ if post_act_fn is None:
485
+ self.post_act = None
486
+ else:
487
+ self.post_act = get_activation(post_act_fn)
488
+
489
+ def forward(self, sample, condition=None):
490
+ if condition is not None:
491
+ sample = sample + self.cond_proj(condition)
492
+ sample = self.linear_1(sample)
493
+
494
+ if self.act is not None:
495
+ sample = self.act(sample)
496
+
497
+ sample = self.linear_2(sample)
498
+
499
+ if self.post_act is not None:
500
+ sample = self.post_act(sample)
501
+ return sample
502
+
503
+
504
+ class Timesteps(nn.Module):
505
+ def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
506
+ super().__init__()
507
+ self.num_channels = num_channels
508
+ self.flip_sin_to_cos = flip_sin_to_cos
509
+ self.downscale_freq_shift = downscale_freq_shift
510
+ self.scale = scale
511
+
512
+ def forward(self, timesteps):
513
+ t_emb = get_timestep_embedding(
514
+ timesteps,
515
+ self.num_channels,
516
+ flip_sin_to_cos=self.flip_sin_to_cos,
517
+ downscale_freq_shift=self.downscale_freq_shift,
518
+ scale=self.scale,
519
+ )
520
+ return t_emb
521
+
522
+
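The two modules above form the timestep path used by the combined embedders later in this file: a fixed sinusoidal projection to 256 channels followed by a small MLP. A sketch with an assumed import path:

import torch
from flux.embeddings import Timesteps, TimestepEmbedding   # import path assumed

time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
time_mlp = TimestepEmbedding(in_channels=256, time_embed_dim=1024)
t = torch.tensor([1.0, 250.0, 999.0])
temb = time_mlp(time_proj(t))
print(temb.shape)   # torch.Size([3, 1024])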
523
+ class GaussianFourierProjection(nn.Module):
524
+ """Gaussian Fourier embeddings for noise levels."""
525
+
526
+ def __init__(
527
+ self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
528
+ ):
529
+ super().__init__()
530
+ self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
531
+ self.log = log
532
+ self.flip_sin_to_cos = flip_sin_to_cos
533
+
534
+ if set_W_to_weight:
535
+ # to delete later
536
+ del self.weight
537
+ self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
538
+ self.weight = self.W
539
+ del self.W
540
+
541
+ def forward(self, x):
542
+ if self.log:
543
+ x = torch.log(x)
544
+
545
+ x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
546
+
547
+ if self.flip_sin_to_cos:
548
+ out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1)
549
+ else:
550
+ out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
551
+ return out
552
+
553
+
554
+ class SinusoidalPositionalEmbedding(nn.Module):
555
+ """Apply positional information to a sequence of embeddings.
556
+
557
+ Takes in a sequence of embeddings with shape (batch_size, seq_length, embed_dim) and adds positional embeddings to
558
+ them
559
+
560
+ Args:
561
+ embed_dim: (int): Dimension of the positional embedding.
562
+ max_seq_length: Maximum sequence length to apply positional embeddings
563
+
564
+ """
565
+
566
+ def __init__(self, embed_dim: int, max_seq_length: int = 32):
567
+ super().__init__()
568
+ position = torch.arange(max_seq_length).unsqueeze(1)
569
+ div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim))
570
+ pe = torch.zeros(1, max_seq_length, embed_dim)
571
+ pe[0, :, 0::2] = torch.sin(position * div_term)
572
+ pe[0, :, 1::2] = torch.cos(position * div_term)
573
+ self.register_buffer("pe", pe)
574
+
575
+ def forward(self, x):
576
+ _, seq_length, _ = x.shape
577
+ x = x + self.pe[:, :seq_length]
578
+ return x
579
+
580
+
581
+ class ImagePositionalEmbeddings(nn.Module):
582
+ """
583
+ Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
584
+ height and width of the latent space.
585
+
586
+ For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092
587
+
588
+ For VQ-diffusion:
589
+
590
+ Output vector embeddings are used as input for the transformer.
591
+
592
+ Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE.
593
+
594
+ Args:
595
+ num_embed (`int`):
596
+ Number of embeddings for the latent pixels embeddings.
597
+ height (`int`):
598
+ Height of the latent image i.e. the number of height embeddings.
599
+ width (`int`):
600
+ Width of the latent image i.e. the number of width embeddings.
601
+ embed_dim (`int`):
602
+ Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings.
603
+ """
604
+
605
+ def __init__(
606
+ self,
607
+ num_embed: int,
608
+ height: int,
609
+ width: int,
610
+ embed_dim: int,
611
+ ):
612
+ super().__init__()
613
+
614
+ self.height = height
615
+ self.width = width
616
+ self.num_embed = num_embed
617
+ self.embed_dim = embed_dim
618
+
619
+ self.emb = nn.Embedding(self.num_embed, embed_dim)
620
+ self.height_emb = nn.Embedding(self.height, embed_dim)
621
+ self.width_emb = nn.Embedding(self.width, embed_dim)
622
+
623
+ def forward(self, index):
624
+ emb = self.emb(index)
625
+
626
+ height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height))
627
+
628
+ # 1 x H x D -> 1 x H x 1 x D
629
+ height_emb = height_emb.unsqueeze(2)
630
+
631
+ width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width))
632
+
633
+ # 1 x W x D -> 1 x 1 x W x D
634
+ width_emb = width_emb.unsqueeze(1)
635
+
636
+ pos_emb = height_emb + width_emb
637
+
638
+ # 1 x H x W x D -> 1 x L x D
639
+ pos_emb = pos_emb.view(1, self.height * self.width, -1)
640
+
641
+ emb = emb + pos_emb[:, : emb.shape[1], :]
642
+
643
+ return emb
644
+
645
+
646
+ class LabelEmbedding(nn.Module):
647
+ """
648
+ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
649
+
650
+ Args:
651
+ num_classes (`int`): The number of classes.
652
+ hidden_size (`int`): The size of the vector embeddings.
653
+ dropout_prob (`float`): The probability of dropping a label.
654
+ """
655
+
656
+ def __init__(self, num_classes, hidden_size, dropout_prob):
657
+ super().__init__()
658
+ use_cfg_embedding = dropout_prob > 0
659
+ self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
660
+ self.num_classes = num_classes
661
+ self.dropout_prob = dropout_prob
662
+
663
+ def token_drop(self, labels, force_drop_ids=None):
664
+ """
665
+ Drops labels to enable classifier-free guidance.
666
+ """
667
+ if force_drop_ids is None:
668
+ drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
669
+ else:
670
+ drop_ids = torch.tensor(force_drop_ids == 1)
671
+ labels = torch.where(drop_ids, self.num_classes, labels)
672
+ return labels
673
+
674
+ def forward(self, labels: torch.LongTensor, force_drop_ids=None):
675
+ use_dropout = self.dropout_prob > 0
676
+ if (self.training and use_dropout) or (force_drop_ids is not None):
677
+ labels = self.token_drop(labels, force_drop_ids)
678
+ embeddings = self.embedding_table(labels)
679
+ return embeddings
680
+
681
+
682
+ class TextImageProjection(nn.Module):
683
+ def __init__(
684
+ self,
685
+ text_embed_dim: int = 1024,
686
+ image_embed_dim: int = 768,
687
+ cross_attention_dim: int = 768,
688
+ num_image_text_embeds: int = 10,
689
+ ):
690
+ super().__init__()
691
+
692
+ self.num_image_text_embeds = num_image_text_embeds
693
+ self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
694
+ self.text_proj = nn.Linear(text_embed_dim, cross_attention_dim)
695
+
696
+ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
697
+ batch_size = text_embeds.shape[0]
698
+
699
+ # image
700
+ image_text_embeds = self.image_embeds(image_embeds)
701
+ image_text_embeds = image_text_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
702
+
703
+ # text
704
+ text_embeds = self.text_proj(text_embeds)
705
+
706
+ return torch.cat([image_text_embeds, text_embeds], dim=1)
707
+
708
+
709
+ class ImageProjection(nn.Module):
710
+ def __init__(
711
+ self,
712
+ image_embed_dim: int = 768,
713
+ cross_attention_dim: int = 768,
714
+ num_image_text_embeds: int = 32,
715
+ ):
716
+ super().__init__()
717
+
718
+ self.num_image_text_embeds = num_image_text_embeds
719
+ self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
720
+ self.norm = nn.LayerNorm(cross_attention_dim)
721
+
722
+ def forward(self, image_embeds: torch.Tensor):
723
+ batch_size = image_embeds.shape[0]
724
+
725
+ # image
726
+ image_embeds = self.image_embeds(image_embeds)
727
+ image_embeds = image_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
728
+ image_embeds = self.norm(image_embeds)
729
+ return image_embeds
730
+
731
+
732
+ class IPAdapterFullImageProjection(nn.Module):
733
+ def __init__(self, image_embed_dim=1024, cross_attention_dim=1024):
734
+ super().__init__()
735
+ from .attention import FeedForward
736
+
737
+ self.ff = FeedForward(image_embed_dim, cross_attention_dim, mult=1, activation_fn="gelu")
738
+ self.norm = nn.LayerNorm(cross_attention_dim)
739
+
740
+ def forward(self, image_embeds: torch.Tensor):
741
+ return self.norm(self.ff(image_embeds))
742
+
743
+
744
+ class IPAdapterFaceIDImageProjection(nn.Module):
745
+ def __init__(self, image_embed_dim=1024, cross_attention_dim=1024, mult=1, num_tokens=1):
746
+ super().__init__()
747
+ from .attention import FeedForward
748
+
749
+ self.num_tokens = num_tokens
750
+ self.cross_attention_dim = cross_attention_dim
751
+ self.ff = FeedForward(image_embed_dim, cross_attention_dim * num_tokens, mult=mult, activation_fn="gelu")
752
+ self.norm = nn.LayerNorm(cross_attention_dim)
753
+
754
+ def forward(self, image_embeds: torch.Tensor):
755
+ x = self.ff(image_embeds)
756
+ x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
757
+ return self.norm(x)
758
+
759
+
760
+ class CombinedTimestepLabelEmbeddings(nn.Module):
761
+ def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
762
+ super().__init__()
763
+
764
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
765
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
766
+ self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob)
767
+
768
+ def forward(self, timestep, class_labels, hidden_dtype=None):
769
+ timesteps_proj = self.time_proj(timestep)
770
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
771
+
772
+ class_labels = self.class_embedder(class_labels) # (N, D)
773
+
774
+ conditioning = timesteps_emb + class_labels # (N, D)
775
+
776
+ return conditioning
777
+
778
+
779
+ class CombinedTimestepTextProjEmbeddings(nn.Module):
780
+ def __init__(self, embedding_dim, pooled_projection_dim):
781
+ super().__init__()
782
+
783
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
784
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
785
+ self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
786
+
787
+ def forward(self, timestep, pooled_projection):
788
+ timesteps_proj = self.time_proj(timestep)
789
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
790
+
791
+ pooled_projections = self.text_embedder(pooled_projection)
792
+
793
+ conditioning = timesteps_emb + pooled_projections
794
+
795
+ return conditioning
796
+
797
+
798
+ class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
799
+ def __init__(self, embedding_dim, pooled_projection_dim):
800
+ super().__init__()
801
+
802
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
803
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
804
+ self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
805
+ self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
806
+
807
+ def forward(self, timestep, guidance, pooled_projection):
808
+ timesteps_proj = self.time_proj(timestep)
809
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
810
+
811
+ guidance_proj = self.time_proj(guidance)
812
+ guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype)) # (N, D)
813
+
814
+ time_guidance_emb = timesteps_emb + guidance_emb
815
+
816
+ pooled_projections = self.text_embedder(pooled_projection)
817
+ conditioning = time_guidance_emb + pooled_projections
818
+
819
+ return conditioning
820
+
821
+
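This is the conditioning vector the Flux ControlNet above consumes when `guidance_embeds=True`: timestep, guidance scale, and the pooled text embedding are each projected to the model width and summed. A sketch with hypothetical sizes and an assumed import path:

import torch
from flux.embeddings import CombinedTimestepGuidanceTextProjEmbeddings   # import path assumed

cond = CombinedTimestepGuidanceTextProjEmbeddings(embedding_dim=128, pooled_projection_dim=64)
timestep = torch.tensor([250.0])
guidance = torch.tensor([3.5])
pooled = torch.randn(1, 64)
temb = cond(timestep, guidance, pooled)
print(temb.shape)   # torch.Size([1, 128]): one pooled conditioning vector per sample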
822
+ class HunyuanDiTAttentionPool(nn.Module):
823
+ # Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6
824
+
825
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
826
+ super().__init__()
827
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim**0.5)
828
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
829
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
830
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
831
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
832
+ self.num_heads = num_heads
833
+
834
+ def forward(self, x):
835
+ x = x.permute(1, 0, 2) # NLC -> LNC
836
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC
837
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (L+1)NC
838
+ x, _ = F.multi_head_attention_forward(
839
+ query=x[:1],
840
+ key=x,
841
+ value=x,
842
+ embed_dim_to_check=x.shape[-1],
843
+ num_heads=self.num_heads,
844
+ q_proj_weight=self.q_proj.weight,
845
+ k_proj_weight=self.k_proj.weight,
846
+ v_proj_weight=self.v_proj.weight,
847
+ in_proj_weight=None,
848
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
849
+ bias_k=None,
850
+ bias_v=None,
851
+ add_zero_attn=False,
852
+ dropout_p=0,
853
+ out_proj_weight=self.c_proj.weight,
854
+ out_proj_bias=self.c_proj.bias,
855
+ use_separate_proj_weight=True,
856
+ training=self.training,
857
+ need_weights=False,
858
+ )
859
+ return x.squeeze(0)
860
+
861
+
862
+ class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module):
863
+ def __init__(
864
+ self,
865
+ embedding_dim,
866
+ pooled_projection_dim=1024,
867
+ seq_len=256,
868
+ cross_attention_dim=2048,
869
+ use_style_cond_and_image_meta_size=True,
870
+ ):
871
+ super().__init__()
872
+
873
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
874
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
875
+
876
+ self.size_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
877
+
878
+ self.pooler = HunyuanDiTAttentionPool(
879
+ seq_len, cross_attention_dim, num_heads=8, output_dim=pooled_projection_dim
880
+ )
881
+
882
+ # Here we use a default learned embedder layer for future extension.
883
+ self.use_style_cond_and_image_meta_size = use_style_cond_and_image_meta_size
884
+ if use_style_cond_and_image_meta_size:
885
+ self.style_embedder = nn.Embedding(1, embedding_dim)
886
+ extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim
887
+ else:
888
+ extra_in_dim = pooled_projection_dim
889
+
890
+ self.extra_embedder = PixArtAlphaTextProjection(
891
+ in_features=extra_in_dim,
892
+ hidden_size=embedding_dim * 4,
893
+ out_features=embedding_dim,
894
+ act_fn="silu_fp32",
895
+ )
896
+
897
+ def forward(self, timestep, encoder_hidden_states, image_meta_size, style, hidden_dtype=None):
898
+ timesteps_proj = self.time_proj(timestep)
899
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, 256)
900
+
901
+ # extra condition1: text
902
+ pooled_projections = self.pooler(encoder_hidden_states) # (N, 1024)
903
+
904
+ if self.use_style_cond_and_image_meta_size:
905
+ # extra condition2: image meta size embedding
906
+ image_meta_size = self.size_proj(image_meta_size.view(-1))
907
+ image_meta_size = image_meta_size.to(dtype=hidden_dtype)
908
+ image_meta_size = image_meta_size.view(-1, 6 * 256) # (N, 1536)
909
+
910
+ # extra condition3: style embedding
911
+ style_embedding = self.style_embedder(style) # (N, embedding_dim)
912
+
913
+ # Concatenate all extra vectors
914
+ extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1)
915
+ else:
916
+ extra_cond = torch.cat([pooled_projections], dim=1)
917
+
918
+ conditioning = timesteps_emb + self.extra_embedder(extra_cond) # [B, D]
919
+
920
+ return conditioning
921
+
922
+
923
+ class LuminaCombinedTimestepCaptionEmbedding(nn.Module):
924
+ def __init__(self, hidden_size=4096, cross_attention_dim=2048, frequency_embedding_size=256):
925
+ super().__init__()
926
+ self.time_proj = Timesteps(
927
+ num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0
928
+ )
929
+
930
+ self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size)
931
+
932
+ self.caption_embedder = nn.Sequential(
933
+ nn.LayerNorm(cross_attention_dim),
934
+ nn.Linear(
935
+ cross_attention_dim,
936
+ hidden_size,
937
+ bias=True,
938
+ ),
939
+ )
940
+
941
+ def forward(self, timestep, caption_feat, caption_mask):
942
+ # timestep embedding:
943
+ time_freq = self.time_proj(timestep)
944
+ time_embed = self.timestep_embedder(time_freq.to(dtype=self.timestep_embedder.linear_1.weight.dtype))
945
+
946
+ # caption condition embedding:
947
+ caption_mask_float = caption_mask.float().unsqueeze(-1)
948
+ caption_feats_pool = (caption_feat * caption_mask_float).sum(dim=1) / caption_mask_float.sum(dim=1)
949
+ caption_feats_pool = caption_feats_pool.to(caption_feat)
950
+ caption_embed = self.caption_embedder(caption_feats_pool)
951
+
952
+ conditioning = time_embed + caption_embed
953
+
954
+ return conditioning
955
+
956
+
957
+ class TextTimeEmbedding(nn.Module):
958
+ def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
959
+ super().__init__()
960
+ self.norm1 = nn.LayerNorm(encoder_dim)
961
+ self.pool = AttentionPooling(num_heads, encoder_dim)
962
+ self.proj = nn.Linear(encoder_dim, time_embed_dim)
963
+ self.norm2 = nn.LayerNorm(time_embed_dim)
964
+
965
+ def forward(self, hidden_states):
966
+ hidden_states = self.norm1(hidden_states)
967
+ hidden_states = self.pool(hidden_states)
968
+ hidden_states = self.proj(hidden_states)
969
+ hidden_states = self.norm2(hidden_states)
970
+ return hidden_states
971
+
972
+
973
+ class TextImageTimeEmbedding(nn.Module):
974
+ def __init__(self, text_embed_dim: int = 768, image_embed_dim: int = 768, time_embed_dim: int = 1536):
975
+ super().__init__()
976
+ self.text_proj = nn.Linear(text_embed_dim, time_embed_dim)
977
+ self.text_norm = nn.LayerNorm(time_embed_dim)
978
+ self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
979
+
980
+ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
981
+ # text
982
+ time_text_embeds = self.text_proj(text_embeds)
983
+ time_text_embeds = self.text_norm(time_text_embeds)
984
+
985
+ # image
986
+ time_image_embeds = self.image_proj(image_embeds)
987
+
988
+ return time_image_embeds + time_text_embeds
989
+
990
+
991
+ class ImageTimeEmbedding(nn.Module):
992
+ def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
993
+ super().__init__()
994
+ self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
995
+ self.image_norm = nn.LayerNorm(time_embed_dim)
996
+
997
+ def forward(self, image_embeds: torch.Tensor):
998
+ # image
999
+ time_image_embeds = self.image_proj(image_embeds)
1000
+ time_image_embeds = self.image_norm(time_image_embeds)
1001
+ return time_image_embeds
1002
+
1003
+
1004
+ class ImageHintTimeEmbedding(nn.Module):
1005
+ def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
1006
+ super().__init__()
1007
+ self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
1008
+ self.image_norm = nn.LayerNorm(time_embed_dim)
1009
+ self.input_hint_block = nn.Sequential(
1010
+ nn.Conv2d(3, 16, 3, padding=1),
1011
+ nn.SiLU(),
1012
+ nn.Conv2d(16, 16, 3, padding=1),
1013
+ nn.SiLU(),
1014
+ nn.Conv2d(16, 32, 3, padding=1, stride=2),
1015
+ nn.SiLU(),
1016
+ nn.Conv2d(32, 32, 3, padding=1),
1017
+ nn.SiLU(),
1018
+ nn.Conv2d(32, 96, 3, padding=1, stride=2),
1019
+ nn.SiLU(),
1020
+ nn.Conv2d(96, 96, 3, padding=1),
1021
+ nn.SiLU(),
1022
+ nn.Conv2d(96, 256, 3, padding=1, stride=2),
1023
+ nn.SiLU(),
1024
+ nn.Conv2d(256, 4, 3, padding=1),
1025
+ )
1026
+
1027
+ def forward(self, image_embeds: torch.Tensor, hint: torch.Tensor):
1028
+ # image
1029
+ time_image_embeds = self.image_proj(image_embeds)
1030
+ time_image_embeds = self.image_norm(time_image_embeds)
1031
+ hint = self.input_hint_block(hint)
1032
+ return time_image_embeds, hint
1033
+
1034
+
1035
+ class AttentionPooling(nn.Module):
1036
+ # Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54
1037
+
1038
+ def __init__(self, num_heads, embed_dim, dtype=None):
1039
+ super().__init__()
1040
+ self.dtype = dtype
1041
+ self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim**0.5)
1042
+ self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
1043
+ self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
1044
+ self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
1045
+ self.num_heads = num_heads
1046
+ self.dim_per_head = embed_dim // self.num_heads
1047
+
1048
+ def forward(self, x):
1049
+ bs, length, width = x.size()
1050
+
1051
+ def shape(x):
1052
+ # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
1053
+ x = x.view(bs, -1, self.num_heads, self.dim_per_head)
1054
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
1055
+ x = x.transpose(1, 2)
1056
+ # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
1057
+ x = x.reshape(bs * self.num_heads, -1, self.dim_per_head)
1058
+ # (bs*n_heads, length, dim_per_head) --> (bs*n_heads, dim_per_head, length)
1059
+ x = x.transpose(1, 2)
1060
+ return x
1061
+
1062
+ class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype)
1063
+ x = torch.cat([class_token, x], dim=1) # (bs, length+1, width)
1064
+
1065
+ # (bs*n_heads, class_token_length, dim_per_head)
1066
+ q = shape(self.q_proj(class_token))
1067
+ # (bs*n_heads, length+class_token_length, dim_per_head)
1068
+ k = shape(self.k_proj(x))
1069
+ v = shape(self.v_proj(x))
1070
+
1071
+ # (bs*n_heads, class_token_length, length+class_token_length):
1072
+ scale = 1 / math.sqrt(math.sqrt(self.dim_per_head))
1073
+ weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards
1074
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
1075
+
1076
+ # (bs*n_heads, dim_per_head, class_token_length)
1077
+ a = torch.einsum("bts,bcs->bct", weight, v)
1078
+
1079
+ # (bs, 1, width)
1080
+ a = a.reshape(bs, -1, 1).transpose(1, 2)
1081
+
1082
+ return a[:, 0, :] # cls_token
1083
+
1084
+
1085
+ def get_fourier_embeds_from_boundingbox(embed_dim, box):
1086
+ """
1087
+ Args:
1088
+ embed_dim: int
1089
+ box: a 3-D tensor [B x N x 4] representing the bounding boxes for GLIGEN pipeline
1090
+ Returns:
1091
+ [B x N x embed_dim * 2 * 4] tensor of positional embeddings
1092
+ """
1093
+
1094
+ batch_size, num_boxes = box.shape[:2]
1095
+
1096
+ emb = 100 ** (torch.arange(embed_dim) / embed_dim)
1097
+ emb = emb[None, None, None].to(device=box.device, dtype=box.dtype)
1098
+ emb = emb * box.unsqueeze(-1)
1099
+
1100
+ emb = torch.stack((emb.sin(), emb.cos()), dim=-1)
1101
+ emb = emb.permute(0, 1, 3, 4, 2).reshape(batch_size, num_boxes, embed_dim * 2 * 4)
1102
+
1103
+ return emb
1104
+
1105
+
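A quick, illustrative shape check for the Fourier box embedding above (the numbers are arbitrary): with `embed_dim=8`, each box yields 8 frequencies x 2 (sin/cos) x 4 coordinates = 64 features.

import torch
from flux.embeddings import get_fourier_embeds_from_boundingbox

boxes = torch.rand(2, 3, 4)                          # (B, N, 4) normalized xyxy boxes
emb = get_fourier_embeds_from_boundingbox(8, boxes)  # per-coordinate sin/cos over 8 frequencies
assert emb.shape == (2, 3, 8 * 2 * 4)                # (B, N, 64)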
1106
+ class GLIGENTextBoundingboxProjection(nn.Module):
1107
+ def __init__(self, positive_len, out_dim, feature_type="text-only", fourier_freqs=8):
1108
+ super().__init__()
1109
+ self.positive_len = positive_len
1110
+ self.out_dim = out_dim
1111
+
1112
+ self.fourier_embedder_dim = fourier_freqs
1113
+ self.position_dim = fourier_freqs * 2 * 4 # 2: sin/cos, 4: xyxy
1114
+
1115
+ if isinstance(out_dim, tuple):
1116
+ out_dim = out_dim[0]
1117
+
1118
+ if feature_type == "text-only":
1119
+ self.linears = nn.Sequential(
1120
+ nn.Linear(self.positive_len + self.position_dim, 512),
1121
+ nn.SiLU(),
1122
+ nn.Linear(512, 512),
1123
+ nn.SiLU(),
1124
+ nn.Linear(512, out_dim),
1125
+ )
1126
+ self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
1127
+
1128
+ elif feature_type == "text-image":
1129
+ self.linears_text = nn.Sequential(
1130
+ nn.Linear(self.positive_len + self.position_dim, 512),
1131
+ nn.SiLU(),
1132
+ nn.Linear(512, 512),
1133
+ nn.SiLU(),
1134
+ nn.Linear(512, out_dim),
1135
+ )
1136
+ self.linears_image = nn.Sequential(
1137
+ nn.Linear(self.positive_len + self.position_dim, 512),
1138
+ nn.SiLU(),
1139
+ nn.Linear(512, 512),
1140
+ nn.SiLU(),
1141
+ nn.Linear(512, out_dim),
1142
+ )
1143
+ self.null_text_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
1144
+ self.null_image_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
1145
+
1146
+ self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))
1147
+
1148
+ def forward(
1149
+ self,
1150
+ boxes,
1151
+ masks,
1152
+ positive_embeddings=None,
1153
+ phrases_masks=None,
1154
+ image_masks=None,
1155
+ phrases_embeddings=None,
1156
+ image_embeddings=None,
1157
+ ):
1158
+ masks = masks.unsqueeze(-1)
1159
+
1160
+ # embedding position (it may include padding as a placeholder)
1161
+ xyxy_embedding = get_fourier_embeds_from_boundingbox(self.fourier_embedder_dim, boxes) # B*N*4 -> B*N*C
1162
+
1163
+ # learnable null embedding
1164
+ xyxy_null = self.null_position_feature.view(1, 1, -1)
1165
+
1166
+ # replace padding with learnable null embedding
1167
+ xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null
1168
+
1169
+ # positionet with text only information
1170
+ if positive_embeddings is not None:
1171
+ # learnable null embedding
1172
+ positive_null = self.null_positive_feature.view(1, 1, -1)
1173
+
1174
+ # replace padding with learnable null embedding
1175
+ positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null
1176
+
1177
+ objs = self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1))
1178
+
1179
+ # positionet with text and image information
1180
+ else:
1181
+ phrases_masks = phrases_masks.unsqueeze(-1)
1182
+ image_masks = image_masks.unsqueeze(-1)
1183
+
1184
+ # learnable null embedding
1185
+ text_null = self.null_text_feature.view(1, 1, -1)
1186
+ image_null = self.null_image_feature.view(1, 1, -1)
1187
+
1188
+ # replace padding with learnable null embedding
1189
+ phrases_embeddings = phrases_embeddings * phrases_masks + (1 - phrases_masks) * text_null
1190
+ image_embeddings = image_embeddings * image_masks + (1 - image_masks) * image_null
1191
+
1192
+ objs_text = self.linears_text(torch.cat([phrases_embeddings, xyxy_embedding], dim=-1))
1193
+ objs_image = self.linears_image(torch.cat([image_embeddings, xyxy_embedding], dim=-1))
1194
+ objs = torch.cat([objs_text, objs_image], dim=1)
1195
+
1196
+ return objs
1197
+
1198
+
1199
+ class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
1200
+ """
1201
+ For PixArt-Alpha.
1202
+
1203
+ Reference:
1204
+ https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
1205
+ """
1206
+
1207
+ def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False):
1208
+ super().__init__()
1209
+
1210
+ self.outdim = size_emb_dim
1211
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
1212
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
1213
+
1214
+ self.use_additional_conditions = use_additional_conditions
1215
+ if use_additional_conditions:
1216
+ self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
1217
+ self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
1218
+ self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
1219
+
1220
+ def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
1221
+ timesteps_proj = self.time_proj(timestep)
1222
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
1223
+
1224
+ if self.use_additional_conditions:
1225
+ resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype)
1226
+ resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1)
1227
+ aspect_ratio_emb = self.additional_condition_proj(aspect_ratio.flatten()).to(hidden_dtype)
1228
+ aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio_emb).reshape(batch_size, -1)
1229
+ conditioning = timesteps_emb + torch.cat([resolution_emb, aspect_ratio_emb], dim=1)
1230
+ else:
1231
+ conditioning = timesteps_emb
1232
+
1233
+ return conditioning
1234
+
1235
+
1236
+ class PixArtAlphaTextProjection(nn.Module):
1237
+ """
1238
+ Projects caption embeddings. Also handles dropout for classifier-free guidance.
1239
+
1240
+ Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
1241
+ """
1242
+
1243
+ def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"):
1244
+ super().__init__()
1245
+ if out_features is None:
1246
+ out_features = hidden_size
1247
+ self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
1248
+ if act_fn == "gelu_tanh":
1249
+ self.act_1 = nn.GELU(approximate="tanh")
1250
+ elif act_fn == "silu":
1251
+ self.act_1 = nn.SiLU()
1252
+ elif act_fn == "silu_fp32":
1253
+ self.act_1 = FP32SiLU()
1254
+ else:
1255
+ raise ValueError(f"Unknown activation function: {act_fn}")
1256
+ self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True)
1257
+
1258
+ def forward(self, caption):
1259
+ hidden_states = self.linear_1(caption)
1260
+ hidden_states = self.act_1(hidden_states)
1261
+ hidden_states = self.linear_2(hidden_states)
1262
+ return hidden_states
1263
+
1264
+
1265
+ class IPAdapterPlusImageProjectionBlock(nn.Module):
1266
+ def __init__(
1267
+ self,
1268
+ embed_dims: int = 768,
1269
+ dim_head: int = 64,
1270
+ heads: int = 16,
1271
+ ffn_ratio: float = 4,
1272
+ ) -> None:
1273
+ super().__init__()
1274
+ from .attention import FeedForward
1275
+
1276
+ self.ln0 = nn.LayerNorm(embed_dims)
1277
+ self.ln1 = nn.LayerNorm(embed_dims)
1278
+ self.attn = Attention(
1279
+ query_dim=embed_dims,
1280
+ dim_head=dim_head,
1281
+ heads=heads,
1282
+ out_bias=False,
1283
+ )
1284
+ self.ff = nn.Sequential(
1285
+ nn.LayerNorm(embed_dims),
1286
+ FeedForward(embed_dims, embed_dims, activation_fn="gelu", mult=ffn_ratio, bias=False),
1287
+ )
1288
+
1289
+ def forward(self, x, latents, residual):
1290
+ encoder_hidden_states = self.ln0(x)
1291
+ latents = self.ln1(latents)
1292
+ encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2)
1293
+ latents = self.attn(latents, encoder_hidden_states) + residual
1294
+ latents = self.ff(latents) + latents
1295
+ return latents
1296
+
1297
+
1298
+ class IPAdapterPlusImageProjection(nn.Module):
1299
+ """Resampler of IP-Adapter Plus.
1300
+
1301
+ Args:
1302
+ embed_dims (int): The feature dimension. Defaults to 768.
+ output_dims (int): The number of output channels, matching `unet.config.cross_attention_dim`. Defaults to 1024.
+ hidden_dims (int): The number of hidden channels. Defaults to 1280.
+ depth (int): The number of blocks. Defaults to 4.
+ dim_head (int): The number of head channels. Defaults to 64.
+ heads (int): Parallel attention heads. Defaults to 16.
+ num_queries (int): The number of queries. Defaults to 8.
+ ffn_ratio (float): The expansion ratio of the feedforward network hidden layer channels. Defaults to 4.
1312
+ """
1313
+
1314
+ def __init__(
1315
+ self,
1316
+ embed_dims: int = 768,
1317
+ output_dims: int = 1024,
1318
+ hidden_dims: int = 1280,
1319
+ depth: int = 4,
1320
+ dim_head: int = 64,
1321
+ heads: int = 16,
1322
+ num_queries: int = 8,
1323
+ ffn_ratio: float = 4,
1324
+ ) -> None:
1325
+ super().__init__()
1326
+ self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dims) / hidden_dims**0.5)
1327
+
1328
+ self.proj_in = nn.Linear(embed_dims, hidden_dims)
1329
+
1330
+ self.proj_out = nn.Linear(hidden_dims, output_dims)
1331
+ self.norm_out = nn.LayerNorm(output_dims)
1332
+
1333
+ self.layers = nn.ModuleList(
1334
+ [IPAdapterPlusImageProjectionBlock(hidden_dims, dim_head, heads, ffn_ratio) for _ in range(depth)]
1335
+ )
1336
+
1337
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
1338
+ """Forward pass.
1339
+
1340
+ Args:
1341
+ x (torch.Tensor): Input Tensor.
1342
+ Returns:
1343
+ torch.Tensor: Output Tensor.
1344
+ """
1345
+ latents = self.latents.repeat(x.size(0), 1, 1)
1346
+
1347
+ x = self.proj_in(x)
1348
+
1349
+ for block in self.layers:
1350
+ residual = latents
1351
+ latents = block(x, latents, residual)
1352
+
1353
+ latents = self.proj_out(latents)
1354
+ return self.norm_out(latents)
1355
+
1356
+
1357
+ class IPAdapterFaceIDPlusImageProjection(nn.Module):
1358
+ """FacePerceiverResampler of IP-Adapter Plus.
1359
+
1360
+ Args:
1361
+ embed_dims (int): The feature dimension. Defaults to 768.
+ output_dims (int): The number of output channels, matching `unet.config.cross_attention_dim`. Defaults to 768.
+ hidden_dims (int): The number of hidden channels. Defaults to 1280.
+ id_embeddings_dim (int): The dimension of the ID embeddings. Defaults to 512.
+ depth (int): The number of blocks. Defaults to 4.
+ dim_head (int): The number of head channels. Defaults to 64.
+ heads (int): Parallel attention heads. Defaults to 16.
+ num_tokens (int): The number of ID tokens. Defaults to 4.
+ num_queries (int): The number of queries. Defaults to 8.
+ ffn_ratio (float): The expansion ratio of the feedforward network hidden layer channels. Defaults to 4.
+ ffproj_ratio (int): The expansion ratio of the feedforward network hidden layer channels (for ID embeddings). Defaults to 2.
1372
+ """
1373
+
1374
+ def __init__(
1375
+ self,
1376
+ embed_dims: int = 768,
1377
+ output_dims: int = 768,
1378
+ hidden_dims: int = 1280,
1379
+ id_embeddings_dim: int = 512,
1380
+ depth: int = 4,
1381
+ dim_head: int = 64,
1382
+ heads: int = 16,
1383
+ num_tokens: int = 4,
1384
+ num_queries: int = 8,
1385
+ ffn_ratio: float = 4,
1386
+ ffproj_ratio: int = 2,
1387
+ ) -> None:
1388
+ super().__init__()
1389
+ from .attention import FeedForward
1390
+
1391
+ self.num_tokens = num_tokens
1392
+ self.embed_dim = embed_dims
1393
+ self.clip_embeds = None
1394
+ self.shortcut = False
1395
+ self.shortcut_scale = 1.0
1396
+
1397
+ self.proj = FeedForward(id_embeddings_dim, embed_dims * num_tokens, activation_fn="gelu", mult=ffproj_ratio)
1398
+ self.norm = nn.LayerNorm(embed_dims)
1399
+
1400
+ self.proj_in = nn.Linear(hidden_dims, embed_dims)
1401
+
1402
+ self.proj_out = nn.Linear(embed_dims, output_dims)
1403
+ self.norm_out = nn.LayerNorm(output_dims)
1404
+
1405
+ self.layers = nn.ModuleList(
1406
+ [IPAdapterPlusImageProjectionBlock(embed_dims, dim_head, heads, ffn_ratio) for _ in range(depth)]
1407
+ )
1408
+
1409
+ def forward(self, id_embeds: torch.Tensor) -> torch.Tensor:
1410
+ """Forward pass.
1411
+
1412
+ Args:
1413
+ id_embeds (torch.Tensor): Input Tensor (ID embeds).
1414
+ Returns:
1415
+ torch.Tensor: Output Tensor.
1416
+ """
1417
+ id_embeds = id_embeds.to(self.clip_embeds.dtype)
1418
+ id_embeds = self.proj(id_embeds)
1419
+ id_embeds = id_embeds.reshape(-1, self.num_tokens, self.embed_dim)
1420
+ id_embeds = self.norm(id_embeds)
1421
+ latents = id_embeds
1422
+
1423
+ clip_embeds = self.proj_in(self.clip_embeds)
1424
+ x = clip_embeds.reshape(-1, clip_embeds.shape[2], clip_embeds.shape[3])
1425
+
1426
+ for block in self.layers:
1427
+ residual = latents
1428
+ latents = block(x, latents, residual)
1429
+
1430
+ latents = self.proj_out(latents)
1431
+ out = self.norm_out(latents)
1432
+ if self.shortcut:
1433
+ out = id_embeds + self.shortcut_scale * out
1434
+ return out
1435
+
1436
+
1437
+ class MultiIPAdapterImageProjection(nn.Module):
1438
+ def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]):
1439
+ super().__init__()
1440
+ self.image_projection_layers = nn.ModuleList(IPAdapterImageProjectionLayers)
1441
+
1442
+ def forward(self, image_embeds: List[torch.Tensor]):
1443
+ projected_image_embeds = []
1444
+
1445
+ # currently, we accept `image_embeds` as
1446
+ # 1. a tensor (deprecated) with shape [batch_size, embed_dim] or [batch_size, sequence_length, embed_dim]
1447
+ # 2. list of `n` tensors where `n` is the number of ip-adapters, each tensor can have shape [batch_size, num_images, embed_dim] or [batch_size, num_images, sequence_length, embed_dim]
1448
+ if not isinstance(image_embeds, list):
1449
+ deprecation_message = (
1450
+ "You have passed a tensor as `image_embeds`.This is deprecated and will be removed in a future release."
1451
+ " Please make sure to update your script to pass `image_embeds` as a list of tensors to suppress this warning."
1452
+ )
1453
+ deprecate("image_embeds not a list", "1.0.0", deprecation_message, standard_warn=False)
1454
+ image_embeds = [image_embeds.unsqueeze(1)]
1455
+
1456
+ if len(image_embeds) != len(self.image_projection_layers):
1457
+ raise ValueError(
1458
+ f"image_embeds must have the same length as image_projection_layers, got {len(image_embeds)} and {len(self.image_projection_layers)}"
1459
+ )
1460
+
1461
+ for image_embed, image_projection_layer in zip(image_embeds, self.image_projection_layers):
1462
+ batch_size, num_images = image_embed.shape[0], image_embed.shape[1]
1463
+ image_embed = image_embed.reshape((batch_size * num_images,) + image_embed.shape[2:])
1464
+ image_embed = image_projection_layer(image_embed)
1465
+ image_embed = image_embed.reshape((batch_size, num_images) + image_embed.shape[1:])
1466
+
1467
+ projected_image_embeds.append(image_embed)
1468
+
1469
+ return projected_image_embeds
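A hedged usage sketch for the multi-adapter wrapper above with a single IP-Adapter-Plus resampler; the dimensions (257 CLIP patch tokens of width 1280, a 2048-dim cross-attention space) are assumptions for illustration only, and the sketch relies on the bundled `Attention` behaving like the diffusers implementation.

import torch
from flux.embeddings import IPAdapterPlusImageProjection, MultiIPAdapterImageProjection

proj = IPAdapterPlusImageProjection(embed_dims=1280, output_dims=2048, hidden_dims=1280, num_queries=8)
multi = MultiIPAdapterImageProjection([proj])

# one adapter, batch of 2, one reference image each, 257 CLIP patch tokens of width 1280
image_embeds = [torch.randn(2, 1, 257, 1280)]
out = multi(image_embeds)                 # list with one projected tensor per adapter
assert out[0].shape == (2, 1, 8, 2048)    # (batch, num_images, num_queries, output_dims)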
flux/flux_network.py ADDED
@@ -0,0 +1,183 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from .transformer_flux import FluxTransformer2DModel
4
+
5
+ class FluxNetwork(nn.Module):
6
+ TARGET_REPLACE_MODULE = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]  # module types eligible for partial training
7
+ FLUX_PREFIX = "flux"
8
+
9
+ def __init__(self, flux_model: FluxTransformer2DModel):
10
+ super().__init__()
11
+ self.flux_model = flux_model
12
+ self.trainable_component_names = []  # records the names of the trainable components
13
+
14
+ @staticmethod
15
+ def generate_trainable_components(layers, num_transformer_blocks=19):
16
+ transformer_components = [
17
+ "attn.to_q",
18
+ "attn.to_k",
19
+ "attn.to_v",
20
+ "attn.to_out",
21
+ "norm1",
22
+ "norm1_context",
23
+ ]
24
+
25
+ single_transformer_components = [
26
+ "attn.to_q",
27
+ "attn.to_k",
28
+ "attn.to_v",
29
+ "norm",
30
+ #"proj_mlp",
31
+ ]
32
+
33
+ components = ["context_embedder"] # 添加 context_embedder
34
+ for layer in layers:
35
+ if layer < num_transformer_blocks:
36
+ prefix = f"transformer_blocks.{layer}"
37
+ base_components = transformer_components
38
+ else:
39
+ prefix = f"single_transformer_blocks.{layer - num_transformer_blocks}"
40
+ base_components = single_transformer_components
41
+ components.extend([f"{prefix}.{comp}" for comp in base_components])
42
+
43
+ return components
44
+
45
+ #def apply_to(self, num_layers=1, additional_components=None):
46
+ # component_names = self.generate_trainable_components(num_layers)
47
+ #
48
+ # if additional_components:
49
+ # component_names.extend(additional_components)
50
+ #
51
+ # self.trainable_component_names = []  # reset
52
+ # for name in component_names:
53
+ # recursive_getattr(self.flux_model, name).requires_grad_(True)
54
+ # self.trainable_component_names.append(name)  # record the name
55
+
56
+ #def apply_to(self, num_layers=1, additional_components=None):
57
+ # component_names = self.generate_trainable_components(num_layers)
58
+ #
59
+ # if additional_components:
60
+ # component_names.extend(additional_components)
61
+ #
62
+ # self.trainable_component_names = []  # reset
63
+ # for name in component_names:
64
+ # component = recursive_getattr(self.flux_model, name)
65
+ # if isinstance(component, nn.Module):
66
+ # component.requires_grad_(True)
67
+ # self.trainable_component_names.append(name)
68
+ # else:
69
+ # print(f"Warning: {name} is not a Module, skipping.")
70
+
71
+ def apply_to(self, layers=None, additional_components=None):
72
+ if layers is None:
73
+ layers = list(range(57))  # default: include all 57 blocks (19 double-stream + 38 single-stream)
74
+
75
+ component_names = self.generate_trainable_components(layers)
76
+
77
+ if additional_components:
78
+ component_names.extend(additional_components)
79
+
80
+ self.trainable_component_names = []  # reset
81
+ for name in component_names:
82
+ try:
83
+ component = recursive_getattr(self.flux_model, name)
84
+ if isinstance(component, nn.Module):
85
+ component.requires_grad_(True)
86
+ self.trainable_component_names.append(name)
87
+ else:
88
+ print(f"Warning: {name} is not a Module, skipping.")
89
+ except AttributeError:
90
+ print(f"Warning: {name} not found in the model, skipping.")
91
+
92
+ def prepare_grad_etc(self):
93
+ # freezes the base flux_model and re-enables gradients for the selected components
94
+ self.flux_model.requires_grad_(False)
95
+ for name in self.trainable_component_names:
96
+ recursive_getattr(self.flux_model, name).requires_grad_(True)
97
+
98
+ def get_trainable_params(self):
99
+ # return the parameters that need to be trained
100
+ params = []
101
+ for name in self.trainable_component_names:
102
+ params.extend(recursive_getattr(self.flux_model, name).parameters())
103
+ return params
104
+
105
+ def print_trainable_params_info(self):
106
+ total_params = 0
107
+ for name in self.trainable_component_names:
108
+ module = recursive_getattr(self.flux_model, name)
109
+ module_params = sum(p.numel() for p in module.parameters() if p.requires_grad)
110
+ total_params += module_params
111
+ #print(f'{name}: {module_params} trainable parameters')
112
+ print(f'Total trainable params: {total_params}')
113
+
114
+ def save_weights(self, file, dtype=None):
115
+ # save the parameters of the trainable components
+ state_dict = {}
+ for name in self.trainable_component_names:
+ state_dict[name] = recursive_getattr(self.flux_model, name).state_dict()
+ if dtype is not None:
+ # cast to CPU and the requested dtype; the cast dicts must be written back into state_dict
+ state_dict = {
+ name: {k: t.detach().clone().to("cpu").to(dtype) for k, t in sd.items()}
+ for name, sd in state_dict.items()
+ }
+ torch.save(state_dict, file)
123
+
124
+ #def load_weights(self, file):
125
+ # # load the parameters of the trainable components
126
+ # state_dict = torch.load(file, weights_only=True)
127
+ # for name in state_dict:
128
+ # module = recursive_getattr(self.flux_model, name)
129
+ # module.load_state_dict(state_dict[name])
130
+ # print(f"加载参数: {name}")
131
+
132
+ def load_weights(self, file, device):
133
+ print(f"Loading weights from {file}")
134
+ try:
135
+ state_dict = torch.load(file, map_location=device, weights_only=True)
136
+ except Exception as e:
137
+ print(f"Failed to load weights from {file}: {str(e)}")
138
+ return False
139
+
140
+ successfully_loaded = []
141
+ failed_to_load = []
142
+
143
+ for name in state_dict:
144
+ try:
145
+ module = recursive_getattr(self.flux_model, name)
146
+ module_state_dict = module.state_dict()
147
+
148
+ # check that the state_dict keys match
149
+ if set(state_dict[name].keys()) != set(module_state_dict.keys()):
150
+ raise ValueError(f"State dict keys for {name} do not match")
151
+
152
+ # check that the tensor shapes match
153
+ for key in state_dict[name]:
154
+ if state_dict[name][key].shape != module_state_dict[key].shape:
155
+ raise ValueError(f"Shape mismatch for {name}.{key}")
156
+
157
+ module.load_state_dict(state_dict[name])
158
+ successfully_loaded.append(name)
159
+
160
+ except Exception as e:
161
+ print(f"Failed to load weights for {name}: {str(e)}")
162
+ failed_to_load.append(name)
163
+
164
+ if successfully_loaded:
165
+ print(f"Successfully loaded weights for: {', '.join(successfully_loaded)}")
166
+ if failed_to_load:
167
+ print(f"Failed to load weights for: {', '.join(failed_to_load)}")
168
+
169
+ return len(failed_to_load) == 0  # True only if every component loaded successfully
170
+
171
+ # helper: recursively get a nested attribute by dotted path, e.g. "transformer_blocks.0.attn.to_q"
172
+ def recursive_getattr(obj, attr):
173
+ attrs = attr.split(".")
174
+ for i in range(len(attrs)):
175
+ obj = getattr(obj, attrs[i])
176
+ return obj
177
+
178
+ # helper: recursively set a nested attribute by dotted path
179
+ def recursive_setattr(obj, attr, val):
180
+ attrs = attr.split(".")
181
+ for i in range(len(attrs)-1):
182
+ obj = getattr(obj, attrs[i])
183
+ setattr(obj, attrs[-1], val)
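A minimal training-setup sketch for FluxNetwork (hedged: it assumes the bundled FluxTransformer2DModel keeps diffusers' `from_pretrained` interface and that the FLUX.1-dev checkpoint is available; layer indices below 19 select double-stream blocks, 19 and above select single-stream blocks):

import torch
from flux.transformer_flux import FluxTransformer2DModel
from flux.flux_network import FluxNetwork

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)
network = FluxNetwork(transformer)

network.apply_to(layers=[0, 1, 19])    # attn/norm of double blocks 0-1 plus the first single block
network.prepare_grad_etc()             # freeze the base model, re-enable the selected components
network.print_trainable_params_info()

optimizer = torch.optim.AdamW(network.get_trainable_params(), lr=1e-5)
# ... training loop ...
network.save_weights("flux_components.pt", dtype=torch.bfloat16)
network.load_weights("flux_components.pt", device="cpu")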
flux/lora/lora_base.py ADDED
@@ -0,0 +1,752 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import copy
16
+ import inspect
17
+ import os
18
+ from pathlib import Path
19
+ from typing import Callable, Dict, List, Optional, Union
20
+
21
+ import safetensors
22
+ import torch
23
+ import torch.nn as nn
24
+ from huggingface_hub import model_info
25
+ from huggingface_hub.constants import HF_HUB_OFFLINE
26
+
27
+ from diffusers.models.modeling_utils import ModelMixin, load_state_dict
28
+ from diffusers.utils import (
29
+ USE_PEFT_BACKEND,
30
+ _get_model_file,
31
+ delete_adapter_layers,
32
+ deprecate,
33
+ is_accelerate_available,
34
+ is_peft_available,
35
+ is_transformers_available,
36
+ logging,
37
+ recurse_remove_peft_layers,
38
+ set_adapter_layers,
39
+ set_weights_and_activate_adapters,
40
+ )
41
+
42
+
43
+ if is_transformers_available():
44
+ from transformers import PreTrainedModel
45
+
46
+ if is_peft_available():
47
+ from peft.tuners.tuners_utils import BaseTunerLayer
48
+
49
+ if is_accelerate_available():
50
+ from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
51
+
52
+ logger = logging.get_logger(__name__)
53
+
54
+
55
+ def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
56
+ """
57
+ Fuses LoRAs for the text encoder.
58
+
59
+ Args:
60
+ text_encoder (`torch.nn.Module`):
61
+ The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder`
62
+ attribute.
63
+ lora_scale (`float`, defaults to 1.0):
64
+ Controls how much to influence the outputs with the LoRA parameters.
65
+ safe_fusing (`bool`, defaults to `False`):
66
+ Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
67
+ adapter_names (`List[str]` or `str`):
68
+ The names of the adapters to use.
69
+ """
70
+ merge_kwargs = {"safe_merge": safe_fusing}
71
+
72
+ for module in text_encoder.modules():
73
+ if isinstance(module, BaseTunerLayer):
74
+ if lora_scale != 1.0:
75
+ module.scale_layer(lora_scale)
76
+
77
+ # For BC with previous PEFT versions, we need to check the signature
78
+ # of the `merge` method to see if it supports the `adapter_names` argument.
79
+ supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
80
+ if "adapter_names" in supported_merge_kwargs:
81
+ merge_kwargs["adapter_names"] = adapter_names
82
+ elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
83
+ raise ValueError(
84
+ "The `adapter_names` argument is not supported with your PEFT version. "
85
+ "Please upgrade to the latest version of PEFT. `pip install -U peft`"
86
+ )
87
+
88
+ module.merge(**merge_kwargs)
89
+
90
+
91
+ def unfuse_text_encoder_lora(text_encoder):
92
+ """
93
+ Unfuses LoRAs for the text encoder.
94
+
95
+ Args:
96
+ text_encoder (`torch.nn.Module`):
97
+ The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder`
98
+ attribute.
99
+ """
100
+ for module in text_encoder.modules():
101
+ if isinstance(module, BaseTunerLayer):
102
+ module.unmerge()
103
+
104
+
105
+ def set_adapters_for_text_encoder(
106
+ adapter_names: Union[List[str], str],
107
+ text_encoder: Optional["PreTrainedModel"] = None, # noqa: F821
108
+ text_encoder_weights: Optional[Union[float, List[float], List[None]]] = None,
109
+ ):
110
+ """
111
+ Sets the adapter layers for the text encoder.
112
+
113
+ Args:
114
+ adapter_names (`List[str]` or `str`):
115
+ The names of the adapters to use.
116
+ text_encoder (`torch.nn.Module`, *optional*):
117
+ The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder`
118
+ attribute.
119
+ text_encoder_weights (`List[float]`, *optional*):
120
+ The weights to use for the text encoder. If `None`, the weights are set to `1.0` for all the adapters.
121
+ """
122
+ if text_encoder is None:
123
+ raise ValueError(
124
+ "The pipeline does not have a default `pipe.text_encoder` class. Please make sure to pass a `text_encoder` instead."
125
+ )
126
+
127
+ def process_weights(adapter_names, weights):
128
+ # Expand weights into a list, one entry per adapter
129
+ # e.g. for 2 adapters: 7 -> [7,7] ; [3, None] -> [3, None]
130
+ if not isinstance(weights, list):
131
+ weights = [weights] * len(adapter_names)
132
+
133
+ if len(adapter_names) != len(weights):
134
+ raise ValueError(
135
+ f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}"
136
+ )
137
+
138
+ # Set None values to default of 1.0
139
+ # e.g. [7,7] -> [7,7] ; [3, None] -> [3,1]
140
+ weights = [w if w is not None else 1.0 for w in weights]
141
+
142
+ return weights
143
+
144
+ adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
145
+ text_encoder_weights = process_weights(adapter_names, text_encoder_weights)
146
+ set_weights_and_activate_adapters(text_encoder, adapter_names, text_encoder_weights)
147
+
148
+
149
+ def disable_lora_for_text_encoder(text_encoder: Optional["PreTrainedModel"] = None):
150
+ """
151
+ Disables the LoRA layers for the text encoder.
152
+
153
+ Args:
154
+ text_encoder (`torch.nn.Module`, *optional*):
155
+ The text encoder module to disable the LoRA layers for. If `None`, it will try to get the `text_encoder`
156
+ attribute.
157
+ """
158
+ if text_encoder is None:
159
+ raise ValueError("Text Encoder not found.")
160
+ set_adapter_layers(text_encoder, enabled=False)
161
+
162
+
163
+ def enable_lora_for_text_encoder(text_encoder: Optional["PreTrainedModel"] = None):
164
+ """
165
+ Enables the LoRA layers for the text encoder.
166
+
167
+ Args:
168
+ text_encoder (`torch.nn.Module`, *optional*):
169
+ The text encoder module to enable the LoRA layers for. If `None`, it will try to get the `text_encoder`
170
+ attribute.
171
+ """
172
+ if text_encoder is None:
173
+ raise ValueError("Text Encoder not found.")
174
+ set_adapter_layers(text_encoder, enabled=True)
175
+
176
+
177
+ def _remove_text_encoder_monkey_patch(text_encoder):
178
+ recurse_remove_peft_layers(text_encoder)
179
+ if getattr(text_encoder, "peft_config", None) is not None:
180
+ del text_encoder.peft_config
181
+ text_encoder._hf_peft_config_loaded = None
182
+
183
+
184
+ class LoraBaseMixin:
185
+ """Utility class for handling LoRAs."""
186
+
187
+ _lora_loadable_modules = []
188
+ num_fused_loras = 0
189
+
190
+ def load_lora_weights(self, **kwargs):
191
+ raise NotImplementedError("`load_lora_weights()` is not implemented.")
192
+
193
+ @classmethod
194
+ def save_lora_weights(cls, **kwargs):
195
+ raise NotImplementedError("`save_lora_weights()` not implemented.")
196
+
197
+ @classmethod
198
+ def lora_state_dict(cls, **kwargs):
199
+ raise NotImplementedError("`lora_state_dict()` is not implemented.")
200
+
201
+ @classmethod
202
+ def _optionally_disable_offloading(cls, _pipeline):
203
+ """
204
+ Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
205
+
206
+ Args:
207
+ _pipeline (`DiffusionPipeline`):
208
+ The pipeline to disable offloading for.
209
+
210
+ Returns:
211
+ tuple:
212
+ A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
213
+ """
214
+ is_model_cpu_offload = False
215
+ is_sequential_cpu_offload = False
216
+
217
+ if _pipeline is not None and _pipeline.hf_device_map is None:
218
+ for _, component in _pipeline.components.items():
219
+ if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
220
+ if not is_model_cpu_offload:
221
+ is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
222
+ if not is_sequential_cpu_offload:
223
+ is_sequential_cpu_offload = (
224
+ isinstance(component._hf_hook, AlignDevicesHook)
225
+ or hasattr(component._hf_hook, "hooks")
226
+ and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
227
+ )
228
+
229
+ logger.info(
230
+ "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
231
+ )
232
+ remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
233
+
234
+ return (is_model_cpu_offload, is_sequential_cpu_offload)
235
+
236
+ @classmethod
237
+ def _fetch_state_dict(
238
+ cls,
239
+ pretrained_model_name_or_path_or_dict,
240
+ weight_name,
241
+ use_safetensors,
242
+ local_files_only,
243
+ cache_dir,
244
+ force_download,
245
+ proxies,
246
+ token,
247
+ revision,
248
+ subfolder,
249
+ user_agent,
250
+ allow_pickle,
251
+ ):
252
+ from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
253
+
254
+ model_file = None
255
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
256
+ # Let's first try to load .safetensors weights
257
+ if (use_safetensors and weight_name is None) or (
258
+ weight_name is not None and weight_name.endswith(".safetensors")
259
+ ):
260
+ try:
261
+ # Here we're relaxing the loading check to enable more Inference API
262
+ # friendliness where sometimes, it's not at all possible to automatically
263
+ # determine `weight_name`.
264
+ if weight_name is None:
265
+ weight_name = cls._best_guess_weight_name(
266
+ pretrained_model_name_or_path_or_dict,
267
+ file_extension=".safetensors",
268
+ local_files_only=local_files_only,
269
+ )
270
+ model_file = _get_model_file(
271
+ pretrained_model_name_or_path_or_dict,
272
+ weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
273
+ cache_dir=cache_dir,
274
+ force_download=force_download,
275
+ proxies=proxies,
276
+ local_files_only=local_files_only,
277
+ token=token,
278
+ revision=revision,
279
+ subfolder=subfolder,
280
+ user_agent=user_agent,
281
+ )
282
+ state_dict = safetensors.torch.load_file(model_file, device="cpu")
283
+ except (IOError, safetensors.SafetensorError) as e:
284
+ if not allow_pickle:
285
+ raise e
286
+ # try loading non-safetensors weights
287
+ model_file = None
288
+ pass
289
+
290
+ if model_file is None:
291
+ if weight_name is None:
292
+ weight_name = cls._best_guess_weight_name(
293
+ pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only
294
+ )
295
+ model_file = _get_model_file(
296
+ pretrained_model_name_or_path_or_dict,
297
+ weights_name=weight_name or LORA_WEIGHT_NAME,
298
+ cache_dir=cache_dir,
299
+ force_download=force_download,
300
+ proxies=proxies,
301
+ local_files_only=local_files_only,
302
+ token=token,
303
+ revision=revision,
304
+ subfolder=subfolder,
305
+ user_agent=user_agent,
306
+ )
307
+ state_dict = load_state_dict(model_file)
308
+ else:
309
+ state_dict = pretrained_model_name_or_path_or_dict
310
+
311
+ return state_dict
312
+
313
+ @classmethod
314
+ def _best_guess_weight_name(
315
+ cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False
316
+ ):
317
+ from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
318
+
319
+ if local_files_only or HF_HUB_OFFLINE:
320
+ raise ValueError("When using the offline mode, you must specify a `weight_name`.")
321
+
322
+ targeted_files = []
323
+
324
+ if os.path.isfile(pretrained_model_name_or_path_or_dict):
325
+ return
326
+ elif os.path.isdir(pretrained_model_name_or_path_or_dict):
327
+ targeted_files = [
328
+ f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension)
329
+ ]
330
+ else:
331
+ files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings
332
+ targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)]
333
+ if len(targeted_files) == 0:
334
+ return
335
+
336
+ # "scheduler" does not correspond to a LoRA checkpoint.
337
+ # "optimizer" does not correspond to a LoRA checkpoint
338
+ # only top-level checkpoints are considered and not the other ones, hence "checkpoint".
339
+ unallowed_substrings = {"scheduler", "optimizer", "checkpoint"}
340
+ targeted_files = list(
341
+ filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files)
342
+ )
343
+
344
+ if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files):
345
+ targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME), targeted_files))
346
+ elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files):
347
+ targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files))
348
+
349
+ if len(targeted_files) > 1:
350
+ raise ValueError(
351
+ f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}."
352
+ )
353
+ weight_name = targeted_files[0]
354
+ return weight_name
355
+
356
+ def unload_lora_weights(self):
357
+ """
358
+ Unloads the LoRA parameters.
359
+
360
+ Examples:
361
+
362
+ ```python
363
+ >>> # Assuming `pipeline` is already loaded with the LoRA parameters.
364
+ >>> pipeline.unload_lora_weights()
365
+ >>> ...
366
+ ```
367
+ """
368
+ if not USE_PEFT_BACKEND:
369
+ raise ValueError("PEFT backend is required for this method.")
370
+
371
+ for component in self._lora_loadable_modules:
372
+ model = getattr(self, component, None)
373
+ if model is not None:
374
+ if issubclass(model.__class__, ModelMixin):
375
+ model.unload_lora()
376
+ elif issubclass(model.__class__, PreTrainedModel):
377
+ _remove_text_encoder_monkey_patch(model)
378
+
379
+ def fuse_lora(
380
+ self,
381
+ components: List[str] = [],
382
+ lora_scale: float = 1.0,
383
+ safe_fusing: bool = False,
384
+ adapter_names: Optional[List[str]] = None,
385
+ **kwargs,
386
+ ):
387
+ r"""
388
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
389
+
390
+ <Tip warning={true}>
391
+
392
+ This is an experimental API.
393
+
394
+ </Tip>
395
+
396
+ Args:
397
+ components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
398
+ lora_scale (`float`, defaults to 1.0):
399
+ Controls how much to influence the outputs with the LoRA parameters.
400
+ safe_fusing (`bool`, defaults to `False`):
401
+ Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
402
+ adapter_names (`List[str]`, *optional*):
403
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
404
+
405
+ Example:
406
+
407
+ ```py
408
+ from diffusers import DiffusionPipeline
409
+ import torch
410
+
411
+ pipeline = DiffusionPipeline.from_pretrained(
412
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
413
+ ).to("cuda")
414
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
415
+ pipeline.fuse_lora(lora_scale=0.7)
416
+ ```
417
+ """
418
+ if "fuse_unet" in kwargs:
419
+ depr_message = "Passing `fuse_unet` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_unet` will be removed in a future version."
420
+ deprecate(
421
+ "fuse_unet",
422
+ "1.0.0",
423
+ depr_message,
424
+ )
425
+ if "fuse_transformer" in kwargs:
426
+ depr_message = "Passing `fuse_transformer` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_transformer` will be removed in a future version."
427
+ deprecate(
428
+ "fuse_transformer",
429
+ "1.0.0",
430
+ depr_message,
431
+ )
432
+ if "fuse_text_encoder" in kwargs:
433
+ depr_message = "Passing `fuse_text_encoder` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_text_encoder` will be removed in a future version."
434
+ deprecate(
435
+ "fuse_text_encoder",
436
+ "1.0.0",
437
+ depr_message,
438
+ )
439
+
440
+ if len(components) == 0:
441
+ raise ValueError("`components` cannot be an empty list.")
442
+
443
+ for fuse_component in components:
444
+ if fuse_component not in self._lora_loadable_modules:
445
+ raise ValueError(f"{fuse_component} is not found in {self._lora_loadable_modules=}.")
446
+
447
+ model = getattr(self, fuse_component, None)
448
+ if model is not None:
449
+ # check if diffusers model
450
+ if issubclass(model.__class__, ModelMixin):
451
+ model.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)
452
+ # handle transformers models.
453
+ if issubclass(model.__class__, PreTrainedModel):
454
+ fuse_text_encoder_lora(
455
+ model, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
456
+ )
457
+
458
+ self.num_fused_loras += 1
459
+
460
+ def unfuse_lora(self, components: List[str] = [], **kwargs):
461
+ r"""
462
+ Reverses the effect of
463
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
464
+
465
+ <Tip warning={true}>
466
+
467
+ This is an experimental API.
468
+
469
+ </Tip>
470
+
471
+ Args:
472
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
473
+ unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
474
+ unfuse_text_encoder (`bool`, defaults to `True`):
475
+ Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
476
+ LoRA parameters then it won't have any effect.
477
+ """
478
+ if "unfuse_unet" in kwargs:
479
+ depr_message = "Passing `unfuse_unet` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_unet` will be removed in a future version."
480
+ deprecate(
481
+ "unfuse_unet",
482
+ "1.0.0",
483
+ depr_message,
484
+ )
485
+ if "unfuse_transformer" in kwargs:
486
+ depr_message = "Passing `unfuse_transformer` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_transformer` will be removed in a future version."
487
+ deprecate(
488
+ "unfuse_transformer",
489
+ "1.0.0",
490
+ depr_message,
491
+ )
492
+ if "unfuse_text_encoder" in kwargs:
493
+ depr_message = "Passing `unfuse_text_encoder` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_text_encoder` will be removed in a future version."
494
+ deprecate(
495
+ "unfuse_text_encoder",
496
+ "1.0.0",
497
+ depr_message,
498
+ )
499
+
500
+ if len(components) == 0:
501
+ raise ValueError("`components` cannot be an empty list.")
502
+
503
+ for fuse_component in components:
504
+ if fuse_component not in self._lora_loadable_modules:
505
+ raise ValueError(f"{fuse_component} is not found in {self._lora_loadable_modules=}.")
506
+
507
+ model = getattr(self, fuse_component, None)
508
+ if model is not None:
509
+ if issubclass(model.__class__, (ModelMixin, PreTrainedModel)):
510
+ for module in model.modules():
511
+ if isinstance(module, BaseTunerLayer):
512
+ module.unmerge()
513
+
514
+ self.num_fused_loras -= 1
515
+
516
+ def set_adapters(
517
+ self,
518
+ adapter_names: Union[List[str], str],
519
+ adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None,
520
+ ):
521
+ adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
522
+
523
+ adapter_weights = copy.deepcopy(adapter_weights)
524
+
525
+ # Expand weights into a list, one entry per adapter
526
+ if not isinstance(adapter_weights, list):
527
+ adapter_weights = [adapter_weights] * len(adapter_names)
528
+
529
+ if len(adapter_names) != len(adapter_weights):
530
+ raise ValueError(
531
+ f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(adapter_weights)}"
532
+ )
533
+
534
+ list_adapters = self.get_list_adapters() # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]}
535
+ all_adapters = {
536
+ adapter for adapters in list_adapters.values() for adapter in adapters
537
+ } # eg ["adapter1", "adapter2"]
538
+ invert_list_adapters = {
539
+ adapter: [part for part, adapters in list_adapters.items() if adapter in adapters]
540
+ for adapter in all_adapters
541
+ } # eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]}
542
+
543
+ # Decompose weights into weights for denoiser and text encoders.
544
+ _component_adapter_weights = {}
545
+ for component in self._lora_loadable_modules:
546
+ model = getattr(self, component)
547
+
548
+ for adapter_name, weights in zip(adapter_names, adapter_weights):
549
+ if isinstance(weights, dict):
550
+ component_adapter_weights = weights.pop(component, None)
551
+
552
+ if component_adapter_weights is not None and not hasattr(self, component):
553
+ logger.warning(
554
+ f"Lora weight dict contains {component} weights but will be ignored because pipeline does not have {component}."
555
+ )
556
+
557
+ if component_adapter_weights is not None and component not in invert_list_adapters[adapter_name]:
558
+ logger.warning(
559
+ (
560
+ f"Lora weight dict for adapter '{adapter_name}' contains {component},"
561
+ f"but this will be ignored because {adapter_name} does not contain weights for {component}."
562
+ f"Valid parts for {adapter_name} are: {invert_list_adapters[adapter_name]}."
563
+ )
564
+ )
565
+
566
+ else:
567
+ component_adapter_weights = weights
568
+
569
+ _component_adapter_weights.setdefault(component, [])
570
+ _component_adapter_weights[component].append(component_adapter_weights)
571
+
572
+ if issubclass(model.__class__, ModelMixin):
573
+ model.set_adapters(adapter_names, _component_adapter_weights[component])
574
+ elif issubclass(model.__class__, PreTrainedModel):
575
+ set_adapters_for_text_encoder(adapter_names, model, _component_adapter_weights[component])
576
+
577
+ def disable_lora(self):
578
+ if not USE_PEFT_BACKEND:
579
+ raise ValueError("PEFT backend is required for this method.")
580
+
581
+ for component in self._lora_loadable_modules:
582
+ model = getattr(self, component, None)
583
+ if model is not None:
584
+ if issubclass(model.__class__, ModelMixin):
585
+ model.disable_lora()
586
+ elif issubclass(model.__class__, PreTrainedModel):
587
+ disable_lora_for_text_encoder(model)
588
+
589
+ def enable_lora(self):
590
+ if not USE_PEFT_BACKEND:
591
+ raise ValueError("PEFT backend is required for this method.")
592
+
593
+ for component in self._lora_loadable_modules:
594
+ model = getattr(self, component, None)
595
+ if model is not None:
596
+ if issubclass(model.__class__, ModelMixin):
597
+ model.enable_lora()
598
+ elif issubclass(model.__class__, PreTrainedModel):
599
+ enable_lora_for_text_encoder(model)
600
+
601
+ def delete_adapters(self, adapter_names: Union[List[str], str]):
602
+ """
603
+ Deletes the LoRA layers of `adapter_name` for the unet and text-encoder(s).
605
+ Args:
606
+ adapter_names (`Union[List[str], str]`):
607
+ The names of the adapters to delete. Can be a single string or a list of strings.
607
+ """
608
+ if not USE_PEFT_BACKEND:
609
+ raise ValueError("PEFT backend is required for this method.")
610
+
611
+ if isinstance(adapter_names, str):
612
+ adapter_names = [adapter_names]
613
+
614
+ for component in self._lora_loadable_modules:
615
+ model = getattr(self, component, None)
616
+ if model is not None:
617
+ if issubclass(model.__class__, ModelMixin):
618
+ model.delete_adapters(adapter_names)
619
+ elif issubclass(model.__class__, PreTrainedModel):
620
+ for adapter_name in adapter_names:
621
+ delete_adapter_layers(model, adapter_name)
622
+
623
+ def get_active_adapters(self) -> List[str]:
624
+ """
625
+ Gets the list of the current active adapters.
626
+
627
+ Example:
628
+
629
+ ```python
630
+ from diffusers import DiffusionPipeline
631
+
632
+ pipeline = DiffusionPipeline.from_pretrained(
633
+ "stabilityai/stable-diffusion-xl-base-1.0",
634
+ ).to("cuda")
635
+ pipeline.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
636
+ pipeline.get_active_adapters()
637
+ ```
638
+ """
639
+ if not USE_PEFT_BACKEND:
640
+ raise ValueError(
641
+ "PEFT backend is required for this method. Please install the latest version of PEFT `pip install -U peft`"
642
+ )
643
+
644
+ active_adapters = []
645
+
646
+ for component in self._lora_loadable_modules:
647
+ model = getattr(self, component, None)
648
+ if model is not None and issubclass(model.__class__, ModelMixin):
649
+ for module in model.modules():
650
+ if isinstance(module, BaseTunerLayer):
651
+ active_adapters = module.active_adapters
652
+ break
653
+
654
+ return active_adapters
655
+
656
+ def get_list_adapters(self) -> Dict[str, List[str]]:
657
+ """
658
+ Gets the current list of all available adapters in the pipeline.
659
+ """
660
+ if not USE_PEFT_BACKEND:
661
+ raise ValueError(
662
+ "PEFT backend is required for this method. Please install the latest version of PEFT `pip install -U peft`"
663
+ )
664
+
665
+ set_adapters = {}
666
+
667
+ for component in self._lora_loadable_modules:
668
+ model = getattr(self, component, None)
669
+ if (
670
+ model is not None
671
+ and issubclass(model.__class__, (ModelMixin, PreTrainedModel))
672
+ and hasattr(model, "peft_config")
673
+ ):
674
+ set_adapters[component] = list(model.peft_config.keys())
675
+
676
+ return set_adapters
677
+
678
+ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device, str, int]) -> None:
679
+ """
680
+ Moves the LoRAs listed in `adapter_names` to a target device. Useful for offloading the LoRA to the CPU in case
681
+ you want to load multiple adapters and free some GPU memory.
682
+
683
+ Args:
684
+ adapter_names (`List[str]`):
685
+ List of adapters to move to the target device.
686
+ device (`Union[torch.device, str, int]`):
687
+ Device to send the adapters to. Can be either a torch device, a str or an integer.
688
+ """
689
+ if not USE_PEFT_BACKEND:
690
+ raise ValueError("PEFT backend is required for this method.")
691
+
692
+ for component in self._lora_loadable_modules:
693
+ model = getattr(self, component, None)
694
+ if model is not None:
695
+ for module in model.modules():
696
+ if isinstance(module, BaseTunerLayer):
697
+ for adapter_name in adapter_names:
698
+ module.lora_A[adapter_name].to(device)
699
+ module.lora_B[adapter_name].to(device)
700
+ # this is a param, not a module, so device placement is not in-place -> re-assign
701
+ if hasattr(module, "lora_magnitude_vector") and module.lora_magnitude_vector is not None:
702
+ module.lora_magnitude_vector[adapter_name] = module.lora_magnitude_vector[
703
+ adapter_name
704
+ ].to(device)
705
+
706
+ @staticmethod
707
+ def pack_weights(layers, prefix):
708
+ layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
709
+ layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
710
+ return layers_state_dict
711
+
712
+ @staticmethod
713
+ def write_lora_layers(
714
+ state_dict: Dict[str, torch.Tensor],
715
+ save_directory: str,
716
+ is_main_process: bool,
717
+ weight_name: str,
718
+ save_function: Callable,
719
+ safe_serialization: bool,
720
+ ):
721
+ from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
722
+
723
+ if os.path.isfile(save_directory):
724
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
725
+ return
726
+
727
+ if save_function is None:
728
+ if safe_serialization:
729
+
730
+ def save_function(weights, filename):
731
+ return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
732
+
733
+ else:
734
+ save_function = torch.save
735
+
736
+ os.makedirs(save_directory, exist_ok=True)
737
+
738
+ if weight_name is None:
739
+ if safe_serialization:
740
+ weight_name = LORA_WEIGHT_NAME_SAFE
741
+ else:
742
+ weight_name = LORA_WEIGHT_NAME
743
+
744
+ save_path = Path(save_directory, weight_name).as_posix()
745
+ save_function(state_dict, save_path)
746
+ logger.info(f"Model weights saved in {save_path}")
747
+
748
+ @property
749
+ def lora_scale(self) -> float:
750
+ # property function that returns the lora scale which can be set at run time by the pipeline.
751
+ # if _lora_scale has not been set, return 1
752
+ return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
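A minimal usage sketch of the pipeline-level adapter management implemented above (`set_adapters` with per-component weight dicts, `get_active_adapters`, `get_list_adapters`, `set_lora_device`). It assumes `peft` is installed and a CUDA device is available; the base checkpoint and LoRA repositories are the same ones referenced in the docstrings above.
```py
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Load two adapters, then blend them. Passing a dict scales a single component
# (here the denoiser) differently, which is exactly what `set_adapters` above decomposes.
pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
pipe.set_adapters(["toy", "pixel"], adapter_weights=[{"unet": 0.8}, 0.5])

print(pipe.get_active_adapters())  # e.g. ["toy", "pixel"]
print(pipe.get_list_adapters())    # e.g. {"unet": ["toy", "pixel"], ...}

# Move the LoRA weights to the CPU to free VRAM, then bring them back before inference.
pipe.set_lora_device(["toy", "pixel"], "cpu")
pipe.set_lora_device(["toy", "pixel"], "cuda")
```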
flux/lora/lora_conversion_utils.py ADDED
@@ -0,0 +1,328 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import re
16
+
17
+ from diffusers.utils import is_peft_version, logging
18
+
19
+
20
+ logger = logging.get_logger(__name__)
21
+
22
+
23
+ def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", block_slice_pos=5):
24
+ # 1. get all state_dict_keys
25
+ all_keys = list(state_dict.keys())
26
+ sgm_patterns = ["input_blocks", "middle_block", "output_blocks"]
27
+
28
+ # 2. check if needs remapping, if not return original dict
29
+ is_in_sgm_format = False
30
+ for key in all_keys:
31
+ if any(p in key for p in sgm_patterns):
32
+ is_in_sgm_format = True
33
+ break
34
+
35
+ if not is_in_sgm_format:
36
+ return state_dict
37
+
38
+ # 3. Else remap from SGM patterns
39
+ new_state_dict = {}
40
+ inner_block_map = ["resnets", "attentions", "upsamplers"]
41
+
42
+ # Retrieves # of down, mid and up blocks
43
+ input_block_ids, middle_block_ids, output_block_ids = set(), set(), set()
44
+
45
+ for layer in all_keys:
46
+ if "text" in layer:
47
+ new_state_dict[layer] = state_dict.pop(layer)
48
+ else:
49
+ layer_id = int(layer.split(delimiter)[:block_slice_pos][-1])
50
+ if sgm_patterns[0] in layer:
51
+ input_block_ids.add(layer_id)
52
+ elif sgm_patterns[1] in layer:
53
+ middle_block_ids.add(layer_id)
54
+ elif sgm_patterns[2] in layer:
55
+ output_block_ids.add(layer_id)
56
+ else:
57
+ raise ValueError(f"Checkpoint not supported because layer {layer} not supported.")
58
+
59
+ input_blocks = {
60
+ layer_id: [key for key in state_dict if f"input_blocks{delimiter}{layer_id}" in key]
61
+ for layer_id in input_block_ids
62
+ }
63
+ middle_blocks = {
64
+ layer_id: [key for key in state_dict if f"middle_block{delimiter}{layer_id}" in key]
65
+ for layer_id in middle_block_ids
66
+ }
67
+ output_blocks = {
68
+ layer_id: [key for key in state_dict if f"output_blocks{delimiter}{layer_id}" in key]
69
+ for layer_id in output_block_ids
70
+ }
71
+
72
+ # Rename keys accordingly
73
+ for i in input_block_ids:
74
+ block_id = (i - 1) // (unet_config.layers_per_block + 1)
75
+ layer_in_block_id = (i - 1) % (unet_config.layers_per_block + 1)
76
+
77
+ for key in input_blocks[i]:
78
+ inner_block_id = int(key.split(delimiter)[block_slice_pos])
79
+ inner_block_key = inner_block_map[inner_block_id] if "op" not in key else "downsamplers"
80
+ inner_layers_in_block = str(layer_in_block_id) if "op" not in key else "0"
81
+ new_key = delimiter.join(
82
+ key.split(delimiter)[: block_slice_pos - 1]
83
+ + [str(block_id), inner_block_key, inner_layers_in_block]
84
+ + key.split(delimiter)[block_slice_pos + 1 :]
85
+ )
86
+ new_state_dict[new_key] = state_dict.pop(key)
87
+
88
+ for i in middle_block_ids:
89
+ key_part = None
90
+ if i == 0:
91
+ key_part = [inner_block_map[0], "0"]
92
+ elif i == 1:
93
+ key_part = [inner_block_map[1], "0"]
94
+ elif i == 2:
95
+ key_part = [inner_block_map[0], "1"]
96
+ else:
97
+ raise ValueError(f"Invalid middle block id {i}.")
98
+
99
+ for key in middle_blocks[i]:
100
+ new_key = delimiter.join(
101
+ key.split(delimiter)[: block_slice_pos - 1] + key_part + key.split(delimiter)[block_slice_pos:]
102
+ )
103
+ new_state_dict[new_key] = state_dict.pop(key)
104
+
105
+ for i in output_block_ids:
106
+ block_id = i // (unet_config.layers_per_block + 1)
107
+ layer_in_block_id = i % (unet_config.layers_per_block + 1)
108
+
109
+ for key in output_blocks[i]:
110
+ inner_block_id = int(key.split(delimiter)[block_slice_pos])
111
+ inner_block_key = inner_block_map[inner_block_id]
112
+ inner_layers_in_block = str(layer_in_block_id) if inner_block_id < 2 else "0"
113
+ new_key = delimiter.join(
114
+ key.split(delimiter)[: block_slice_pos - 1]
115
+ + [str(block_id), inner_block_key, inner_layers_in_block]
116
+ + key.split(delimiter)[block_slice_pos + 1 :]
117
+ )
118
+ new_state_dict[new_key] = state_dict.pop(key)
119
+
120
+ if len(state_dict) > 0:
121
+ raise ValueError("At this point all state dict entries have to be converted.")
122
+
123
+ return new_state_dict
124
+
125
+
126
+ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_name="text_encoder"):
127
+ """
128
+ Converts a non-Diffusers LoRA state dict to a Diffusers compatible state dict.
129
+
130
+ Args:
131
+ state_dict (`dict`): The state dict to convert.
132
+ unet_name (`str`, optional): The name of the U-Net module in the Diffusers model. Defaults to "unet".
133
+ text_encoder_name (`str`, optional): The name of the text encoder module in the Diffusers model. Defaults to
134
+ "text_encoder".
135
+
136
+ Returns:
137
+ `tuple`: A tuple containing the converted state dict and a dictionary of alphas.
138
+ """
139
+ unet_state_dict = {}
140
+ te_state_dict = {}
141
+ te2_state_dict = {}
142
+ network_alphas = {}
143
+
144
+ # Check for DoRA-enabled LoRAs.
145
+ dora_present_in_unet = any("dora_scale" in k and "lora_unet_" in k for k in state_dict)
146
+ dora_present_in_te = any("dora_scale" in k and ("lora_te_" in k or "lora_te1_" in k) for k in state_dict)
147
+ dora_present_in_te2 = any("dora_scale" in k and "lora_te2_" in k for k in state_dict)
148
+ if dora_present_in_unet or dora_present_in_te or dora_present_in_te2:
149
+ if is_peft_version("<", "0.9.0"):
150
+ raise ValueError(
151
+ "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
152
+ )
153
+
154
+ # Iterate over all LoRA weights.
155
+ all_lora_keys = list(state_dict.keys())
156
+ for key in all_lora_keys:
157
+ if not key.endswith("lora_down.weight"):
158
+ continue
159
+
160
+ # Extract LoRA name.
161
+ lora_name = key.split(".")[0]
162
+
163
+ # Find corresponding up weight and alpha.
164
+ lora_name_up = lora_name + ".lora_up.weight"
165
+ lora_name_alpha = lora_name + ".alpha"
166
+
167
+ # Handle U-Net LoRAs.
168
+ if lora_name.startswith("lora_unet_"):
169
+ diffusers_name = _convert_unet_lora_key(key)
170
+
171
+ # Store down and up weights.
172
+ unet_state_dict[diffusers_name] = state_dict.pop(key)
173
+ unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
174
+
175
+ # Store DoRA scale if present.
176
+ if dora_present_in_unet:
177
+ dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
178
+ unet_state_dict[
179
+ diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")
180
+ ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
181
+
182
+ # Handle text encoder LoRAs.
183
+ elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
184
+ diffusers_name = _convert_text_encoder_lora_key(key, lora_name)
185
+
186
+ # Store down and up weights for te or te2.
187
+ if lora_name.startswith(("lora_te_", "lora_te1_")):
188
+ te_state_dict[diffusers_name] = state_dict.pop(key)
189
+ te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
190
+ else:
191
+ te2_state_dict[diffusers_name] = state_dict.pop(key)
192
+ te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
193
+
194
+ # Store DoRA scale if present.
195
+ if dora_present_in_te or dora_present_in_te2:
196
+ dora_scale_key_to_replace_te = (
197
+ "_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer."
198
+ )
199
+ if lora_name.startswith(("lora_te_", "lora_te1_")):
200
+ te_state_dict[
201
+ diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
202
+ ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
203
+ elif lora_name.startswith("lora_te2_"):
204
+ te2_state_dict[
205
+ diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
206
+ ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
207
+
208
+ # Store alpha if present.
209
+ if lora_name_alpha in state_dict:
210
+ alpha = state_dict.pop(lora_name_alpha).item()
211
+ network_alphas.update(_get_alpha_name(lora_name_alpha, diffusers_name, alpha))
212
+
213
+ # Check if any keys remain.
214
+ if len(state_dict) > 0:
215
+ raise ValueError(f"The following keys have not been correctly renamed: \n\n {', '.join(state_dict.keys())}")
216
+
217
+ logger.info("Non-diffusers checkpoint detected.")
218
+
219
+ # Construct final state dict.
220
+ unet_state_dict = {f"{unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()}
221
+ te_state_dict = {f"{text_encoder_name}.{module_name}": params for module_name, params in te_state_dict.items()}
222
+ te2_state_dict = (
223
+ {f"text_encoder_2.{module_name}": params for module_name, params in te2_state_dict.items()}
224
+ if len(te2_state_dict) > 0
225
+ else None
226
+ )
227
+ if te2_state_dict is not None:
228
+ te_state_dict.update(te2_state_dict)
229
+
230
+ new_state_dict = {**unet_state_dict, **te_state_dict}
231
+ return new_state_dict, network_alphas
232
+
233
+
234
+ def _convert_unet_lora_key(key):
235
+ """
236
+ Converts a U-Net LoRA key to a Diffusers compatible key.
237
+ """
238
+ diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
239
+
240
+ # Replace common U-Net naming patterns.
241
+ diffusers_name = diffusers_name.replace("input.blocks", "down_blocks")
242
+ diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
243
+ diffusers_name = diffusers_name.replace("middle.block", "mid_block")
244
+ diffusers_name = diffusers_name.replace("mid.block", "mid_block")
245
+ diffusers_name = diffusers_name.replace("output.blocks", "up_blocks")
246
+ diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
247
+ diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
248
+ diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
249
+ diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
250
+ diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
251
+ diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
252
+ diffusers_name = diffusers_name.replace("proj.in", "proj_in")
253
+ diffusers_name = diffusers_name.replace("proj.out", "proj_out")
254
+ diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj")
255
+
256
+ # SDXL specific conversions.
257
+ if "emb" in diffusers_name and "time.emb.proj" not in diffusers_name:
258
+ pattern = r"\.\d+(?=\D*$)"
259
+ diffusers_name = re.sub(pattern, "", diffusers_name, count=1)
260
+ if ".in." in diffusers_name:
261
+ diffusers_name = diffusers_name.replace("in.layers.2", "conv1")
262
+ if ".out." in diffusers_name:
263
+ diffusers_name = diffusers_name.replace("out.layers.3", "conv2")
264
+ if "downsamplers" in diffusers_name or "upsamplers" in diffusers_name:
265
+ diffusers_name = diffusers_name.replace("op", "conv")
266
+ if "skip" in diffusers_name:
267
+ diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut")
268
+
269
+ # LyCORIS specific conversions.
270
+ if "time.emb.proj" in diffusers_name:
271
+ diffusers_name = diffusers_name.replace("time.emb.proj", "time_emb_proj")
272
+ if "conv.shortcut" in diffusers_name:
273
+ diffusers_name = diffusers_name.replace("conv.shortcut", "conv_shortcut")
274
+
275
+ # General conversions.
276
+ if "transformer_blocks" in diffusers_name:
277
+ if "attn1" in diffusers_name or "attn2" in diffusers_name:
278
+ diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
279
+ diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
280
+ elif "ff" in diffusers_name:
281
+ pass
282
+ elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
283
+ pass
284
+ else:
285
+ pass
286
+
287
+ return diffusers_name
288
+
289
+
290
+ def _convert_text_encoder_lora_key(key, lora_name):
291
+ """
292
+ Converts a text encoder LoRA key to a Diffusers compatible key.
293
+ """
294
+ if lora_name.startswith(("lora_te_", "lora_te1_")):
295
+ key_to_replace = "lora_te_" if lora_name.startswith("lora_te_") else "lora_te1_"
296
+ else:
297
+ key_to_replace = "lora_te2_"
298
+
299
+ diffusers_name = key.replace(key_to_replace, "").replace("_", ".")
300
+ diffusers_name = diffusers_name.replace("text.model", "text_model")
301
+ diffusers_name = diffusers_name.replace("self.attn", "self_attn")
302
+ diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
303
+ diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
304
+ diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
305
+ diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
306
+ diffusers_name = diffusers_name.replace("text.projection", "text_projection")
307
+
308
+ if "self_attn" in diffusers_name or "text_projection" in diffusers_name:
309
+ pass
310
+ elif "mlp" in diffusers_name:
311
+ # Be aware that this is the new diffusers convention and the rest of the code might
312
+ # not utilize it yet.
313
+ diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
314
+ return diffusers_name
315
+
316
+
317
+ def _get_alpha_name(lora_name_alpha, diffusers_name, alpha):
318
+ """
319
+ Gets the correct alpha name for the Diffusers model.
320
+ """
321
+ if lora_name_alpha.startswith("lora_unet_"):
322
+ prefix = "unet."
323
+ elif lora_name_alpha.startswith(("lora_te_", "lora_te1_")):
324
+ prefix = "text_encoder."
325
+ else:
326
+ prefix = "text_encoder_2."
327
+ new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha"
328
+ return {new_name: alpha}
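A small sketch of how the key-renaming helpers above behave, assuming this module is importable as `flux.lora.lora_conversion_utils` and `diffusers` is installed; the example keys are made-up Kohya/sd-scripts style names.
```py
from flux.lora.lora_conversion_utils import (
    _convert_text_encoder_lora_key,
    _convert_unet_lora_key,
)

# Hypothetical sd-scripts style keys, as produced by Kohya LoRA trainers.
unet_key = "lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_q.lora_down.weight"
te_key = "lora_te1_text_model_encoder_layers_0_self_attn_q_proj.lora_down.weight"

# Prints the diffusers-style module paths the converters map these keys to,
# e.g. "down_blocks.1.attentions.0.transformer_blocks.0.attn1.processor.to_q_lora.down.weight".
print(_convert_unet_lora_key(unet_key))
print(_convert_text_encoder_lora_key(te_key, lora_name=te_key.split(".")[0]))
```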
flux/lora/lora_pipeline.py ADDED
The diff for this file is too large to render. See raw diff
 
flux/lora/peft.py ADDED
@@ -0,0 +1,395 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import inspect
16
+ from functools import partial
17
+ from typing import Dict, List, Optional, Union
18
+
19
+ from diffusers.utils import (
20
+ MIN_PEFT_VERSION,
21
+ USE_PEFT_BACKEND,
22
+ check_peft_version,
23
+ delete_adapter_layers,
24
+ is_peft_available,
25
+ set_adapter_layers,
26
+ set_weights_and_activate_adapters,
27
+ )
28
+ #from .unet_loader_utils import _maybe_expand_lora_scales
29
+
30
+
31
+ _SET_ADAPTER_SCALE_FN_MAPPING = {
32
+ #"UNet2DConditionModel": _maybe_expand_lora_scales,
33
+ #"UNetMotionModel": _maybe_expand_lora_scales,
34
+ "SD3Transformer2DModel": lambda model_cls, weights: weights,
35
+ "FluxTransformer2DModel": lambda model_cls, weights: weights,
36
+ }
37
+
38
+
39
+ class PeftAdapterMixin:
40
+ """
41
+ A class containing all functions for loading and using adapter weights that are supported in the PEFT library. For
42
+ more details about adapters and injecting them in a base model, check out the PEFT
43
+ [documentation](https://huggingface.co/docs/peft/index).
44
+
45
+ Install the latest version of PEFT, and use this mixin to:
46
+
47
+ - Attach new adapters in the model.
48
+ - Attach multiple adapters and iteratively activate/deactivate them.
49
+ - Activate/deactivate all adapters from the model.
50
+ - Get a list of the active adapters.
51
+ """
52
+
53
+ _hf_peft_config_loaded = False
54
+
55
+ def set_adapters(
56
+ self,
57
+ adapter_names: Union[List[str], str],
58
+ weights: Optional[Union[float, Dict, List[float], List[Dict], List[None]]] = None,
59
+ ):
60
+ """
61
+ Set the currently active adapters for use in the UNet.
62
+
63
+ Args:
64
+ adapter_names (`List[str]` or `str`):
65
+ The names of the adapters to use.
66
+ weights (`Union[float, Dict, List[float], List[Dict]]`, *optional*):
67
+ The weight(s) to apply to the adapter(s). If `None`, the weights are set to `1.0` for all the
68
+ adapters.
69
+
70
+ Example:
71
+
72
+ ```py
73
+ from diffusers import AutoPipelineForText2Image
74
+ import torch
75
+
76
+ pipeline = AutoPipelineForText2Image.from_pretrained(
77
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
78
+ ).to("cuda")
79
+ pipeline.load_lora_weights(
80
+ "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
81
+ )
82
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
83
+ pipeline.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5])
84
+ ```
85
+ """
86
+ if not USE_PEFT_BACKEND:
87
+ raise ValueError("PEFT backend is required for `set_adapters()`.")
88
+
89
+ adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
90
+
91
+ # Expand weights into a list, one entry per adapter
92
+ # examples for e.g. 2 adapters: [{...}, 7] -> [7,7] ; None -> [None, None]
93
+ if not isinstance(weights, list):
94
+ weights = [weights] * len(adapter_names)
95
+
96
+ if len(adapter_names) != len(weights):
97
+ raise ValueError(
98
+ f"Length of adapter names {len(adapter_names)} is not equal to the length of their weights {len(weights)}."
99
+ )
100
+
101
+ # Set None values to default of 1.0
102
+ # e.g. [{...}, 7] -> [{...}, 7] ; [None, None] -> [1.0, 1.0]
103
+ weights = [w if w is not None else 1.0 for w in weights]
104
+
105
+ # e.g. [{...}, 7] -> [{expanded dict...}, 7]
106
+ scale_expansion_fn = _SET_ADAPTER_SCALE_FN_MAPPING[self.__class__.__name__]
107
+ weights = scale_expansion_fn(self, weights)
108
+
109
+ set_weights_and_activate_adapters(self, adapter_names, weights)
110
+
111
+ def add_adapter(self, adapter_config, adapter_name: str = "default") -> None:
112
+ r"""
113
+ Adds a new adapter to the current model for training. If no adapter name is passed, a default name is assigned
114
+ to the adapter to follow the convention of the PEFT library.
115
+
116
+ If you are not familiar with adapters and PEFT methods, we invite you to read more about them in the PEFT
117
+ [documentation](https://huggingface.co/docs/peft).
118
+
119
+ Args:
120
+ adapter_config (`[~peft.PeftConfig]`):
121
+ The configuration of the adapter to add; supported adapters are non-prefix tuning and adaption prompt
122
+ methods.
123
+ adapter_name (`str`, *optional*, defaults to `"default"`):
124
+ The name of the adapter to add. If no name is passed, a default name is assigned to the adapter.
125
+ """
126
+ check_peft_version(min_version=MIN_PEFT_VERSION)
127
+
128
+ if not is_peft_available():
129
+ raise ImportError("PEFT is not available. Please install PEFT to use this function: `pip install peft`.")
130
+
131
+ from peft import PeftConfig, inject_adapter_in_model
132
+
133
+ if not self._hf_peft_config_loaded:
134
+ self._hf_peft_config_loaded = True
135
+ elif adapter_name in self.peft_config:
136
+ raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.")
137
+
138
+ if not isinstance(adapter_config, PeftConfig):
139
+ raise ValueError(
140
+ f"adapter_config should be an instance of PeftConfig. Got {type(adapter_config)} instead."
141
+ )
142
+
143
+ # Unlike transformers, here we don't need to retrieve the name_or_path of the unet as the loading logic is
144
+ # handled by the `load_lora_layers` or `StableDiffusionLoraLoaderMixin`. Therefore we set it to `None` here.
145
+ adapter_config.base_model_name_or_path = None
146
+ inject_adapter_in_model(adapter_config, self, adapter_name)
147
+ self.set_adapter(adapter_name)
148
+
149
+ def set_adapter(self, adapter_name: Union[str, List[str]]) -> None:
150
+ """
151
+ Sets a specific adapter by forcing the model to only use that adapter and disables the other adapters.
152
+
153
+ If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
154
+ [documentation](https://huggingface.co/docs/peft).
155
+
156
+ Args:
157
+ adapter_name (`Union[str, List[str]]`):
158
+ The list of adapters to set or the adapter name in the case of a single adapter.
159
+ """
160
+ check_peft_version(min_version=MIN_PEFT_VERSION)
161
+
162
+ if not self._hf_peft_config_loaded:
163
+ raise ValueError("No adapter loaded. Please load an adapter first.")
164
+
165
+ if isinstance(adapter_name, str):
166
+ adapter_name = [adapter_name]
167
+
168
+ missing = set(adapter_name) - set(self.peft_config)
169
+ if len(missing) > 0:
170
+ raise ValueError(
171
+ f"Following adapter(s) could not be found: {', '.join(missing)}. Make sure you are passing the correct adapter name(s)."
172
+ f" current loaded adapters are: {list(self.peft_config.keys())}"
173
+ )
174
+
175
+ from peft.tuners.tuners_utils import BaseTunerLayer
176
+
177
+ _adapters_has_been_set = False
178
+
179
+ for _, module in self.named_modules():
180
+ if isinstance(module, BaseTunerLayer):
181
+ if hasattr(module, "set_adapter"):
182
+ module.set_adapter(adapter_name)
183
+ # Previous versions of PEFT do not support multi-adapter inference
184
+ elif not hasattr(module, "set_adapter") and len(adapter_name) != 1:
185
+ raise ValueError(
186
+ "You are trying to set multiple adapters and you have a PEFT version that does not support multi-adapter inference. Please upgrade to the latest version of PEFT."
187
+ " `pip install -U peft` or `pip install -U git+https://github.com/huggingface/peft.git`"
188
+ )
189
+ else:
190
+ module.active_adapter = adapter_name
191
+ _adapters_has_been_set = True
192
+
193
+ if not _adapters_has_been_set:
194
+ raise ValueError(
195
+ "Did not succeeded in setting the adapter. Please make sure you are using a model that supports adapters."
196
+ )
197
+
198
+ def disable_adapters(self) -> None:
199
+ r"""
200
+ Disable all adapters attached to the model and fallback to inference with the base model only.
201
+
202
+ If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
203
+ [documentation](https://huggingface.co/docs/peft).
204
+ """
205
+ check_peft_version(min_version=MIN_PEFT_VERSION)
206
+
207
+ if not self._hf_peft_config_loaded:
208
+ raise ValueError("No adapter loaded. Please load an adapter first.")
209
+
210
+ from peft.tuners.tuners_utils import BaseTunerLayer
211
+
212
+ for _, module in self.named_modules():
213
+ if isinstance(module, BaseTunerLayer):
214
+ if hasattr(module, "enable_adapters"):
215
+ module.enable_adapters(enabled=False)
216
+ else:
217
+ # support for older PEFT versions
218
+ module.disable_adapters = True
219
+
220
+ def enable_adapters(self) -> None:
221
+ """
222
+ Enable adapters that are attached to the model. The model uses `self.active_adapters()` to retrieve the list of
223
+ adapters to enable.
224
+
225
+ If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
226
+ [documentation](https://huggingface.co/docs/peft).
227
+ """
228
+ check_peft_version(min_version=MIN_PEFT_VERSION)
229
+
230
+ if not self._hf_peft_config_loaded:
231
+ raise ValueError("No adapter loaded. Please load an adapter first.")
232
+
233
+ from peft.tuners.tuners_utils import BaseTunerLayer
234
+
235
+ for _, module in self.named_modules():
236
+ if isinstance(module, BaseTunerLayer):
237
+ if hasattr(module, "enable_adapters"):
238
+ module.enable_adapters(enabled=True)
239
+ else:
240
+ # support for older PEFT versions
241
+ module.disable_adapters = False
242
+
243
+ def active_adapters(self) -> List[str]:
244
+ """
245
+ Gets the current list of active adapters of the model.
246
+
247
+ If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
248
+ [documentation](https://huggingface.co/docs/peft).
249
+ """
250
+ check_peft_version(min_version=MIN_PEFT_VERSION)
251
+
252
+ if not is_peft_available():
253
+ raise ImportError("PEFT is not available. Please install PEFT to use this function: `pip install peft`.")
254
+
255
+ if not self._hf_peft_config_loaded:
256
+ raise ValueError("No adapter loaded. Please load an adapter first.")
257
+
258
+ from peft.tuners.tuners_utils import BaseTunerLayer
259
+
260
+ for _, module in self.named_modules():
261
+ if isinstance(module, BaseTunerLayer):
262
+ return module.active_adapter
263
+
264
+ def fuse_lora(self, lora_scale=1.0, safe_fusing=False, adapter_names=None):
265
+ if not USE_PEFT_BACKEND:
266
+ raise ValueError("PEFT backend is required for `fuse_lora()`.")
267
+
268
+ self.lora_scale = lora_scale
269
+ self._safe_fusing = safe_fusing
270
+ self.apply(partial(self._fuse_lora_apply, adapter_names=adapter_names))
271
+
272
+ def _fuse_lora_apply(self, module, adapter_names=None):
273
+ from peft.tuners.tuners_utils import BaseTunerLayer
274
+
275
+ merge_kwargs = {"safe_merge": self._safe_fusing}
276
+
277
+ if isinstance(module, BaseTunerLayer):
278
+ if self.lora_scale != 1.0:
279
+ module.scale_layer(self.lora_scale)
280
+
281
+ # For BC with previous PEFT versions, we need to check the signature
282
+ # of the `merge` method to see if it supports the `adapter_names` argument.
283
+ supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
284
+ if "adapter_names" in supported_merge_kwargs:
285
+ merge_kwargs["adapter_names"] = adapter_names
286
+ elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
287
+ raise ValueError(
288
+ "The `adapter_names` argument is not supported with your PEFT version. Please upgrade"
289
+ " to the latest version of PEFT. `pip install -U peft`"
290
+ )
291
+
292
+ module.merge(**merge_kwargs)
293
+
294
+ def unfuse_lora(self):
295
+ if not USE_PEFT_BACKEND:
296
+ raise ValueError("PEFT backend is required for `unfuse_lora()`.")
297
+ self.apply(self._unfuse_lora_apply)
298
+
299
+ def _unfuse_lora_apply(self, module):
300
+ from peft.tuners.tuners_utils import BaseTunerLayer
301
+
302
+ if isinstance(module, BaseTunerLayer):
303
+ module.unmerge()
304
+
305
+ def unload_lora(self):
306
+ if not USE_PEFT_BACKEND:
307
+ raise ValueError("PEFT backend is required for `unload_lora()`.")
308
+
309
+ from diffusers.utils import recurse_remove_peft_layers
310
+
311
+ recurse_remove_peft_layers(self)
312
+ if hasattr(self, "peft_config"):
313
+ del self.peft_config
314
+
315
+ def disable_lora(self):
316
+ """
317
+ Disables the active LoRA layers of the underlying model.
318
+
319
+ Example:
320
+
321
+ ```py
322
+ from diffusers import AutoPipelineForText2Image
323
+ import torch
324
+
325
+ pipeline = AutoPipelineForText2Image.from_pretrained(
326
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
327
+ ).to("cuda")
328
+ pipeline.load_lora_weights(
329
+ "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
330
+ )
331
+ pipeline.disable_lora()
332
+ ```
333
+ """
334
+ if not USE_PEFT_BACKEND:
335
+ raise ValueError("PEFT backend is required for this method.")
336
+ set_adapter_layers(self, enabled=False)
337
+
338
+ def enable_lora(self):
339
+ """
340
+ Enables the active LoRA layers of the underlying model.
341
+
342
+ Example:
343
+
344
+ ```py
345
+ from diffusers import AutoPipelineForText2Image
346
+ import torch
347
+
348
+ pipeline = AutoPipelineForText2Image.from_pretrained(
349
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
350
+ ).to("cuda")
351
+ pipeline.load_lora_weights(
352
+ "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
353
+ )
354
+ pipeline.enable_lora()
355
+ ```
356
+ """
357
+ if not USE_PEFT_BACKEND:
358
+ raise ValueError("PEFT backend is required for this method.")
359
+ set_adapter_layers(self, enabled=True)
360
+
361
+ def delete_adapters(self, adapter_names: Union[List[str], str]):
362
+ """
363
+ Delete an adapter's LoRA layers from the underlying model.
364
+
365
+ Args:
366
+ adapter_names (`Union[List[str], str]`):
367
+ The names (single string or list of strings) of the adapter to delete.
368
+
369
+ Example:
370
+
371
+ ```py
372
+ from diffusers import AutoPipelineForText2Image
373
+ import torch
374
+
375
+ pipeline = AutoPipelineForText2Image.from_pretrained(
376
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
377
+ ).to("cuda")
378
+ pipeline.load_lora_weights(
379
+ "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_names="cinematic"
380
+ )
381
+ pipeline.delete_adapters("cinematic")
382
+ ```
383
+ """
384
+ if not USE_PEFT_BACKEND:
385
+ raise ValueError("PEFT backend is required for this method.")
386
+
387
+ if isinstance(adapter_names, str):
388
+ adapter_names = [adapter_names]
389
+
390
+ for adapter_name in adapter_names:
391
+ delete_adapter_layers(self, adapter_name)
392
+
393
+ # Pop also the corresponding adapter from the config
394
+ if hasattr(self, "peft_config"):
395
+ self.peft_config.pop(adapter_name, None)
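A hedged sketch of the adapter lifecycle that `PeftAdapterMixin` implements (add -> enable/disable -> fuse -> unfuse -> unload), using a hypothetical toy module rather than the real Flux transformer. It assumes `peft` >= 0.9 and `diffusers` are installed and that this module is importable as `flux.lora.peft`.
```py
import torch.nn as nn
from peft import LoraConfig

from flux.lora.peft import PeftAdapterMixin


class TinyModel(nn.Module, PeftAdapterMixin):
    """Hypothetical stand-in for a diffusers model that mixes in PeftAdapterMixin."""

    def __init__(self):
        super().__init__()
        self.to_q = nn.Linear(64, 64)
        self.to_k = nn.Linear(64, 64)

    def forward(self, x):
        return self.to_q(x) + self.to_k(x)


model = TinyModel()

# Inject LoRA layers into the two linear projections and activate the adapter.
model.add_adapter(LoraConfig(r=4, lora_alpha=4, target_modules=["to_q", "to_k"]), adapter_name="demo")
print(model.active_adapters())  # e.g. "demo" or ["demo"], depending on the peft version

model.disable_adapters()  # run with the base weights only
model.enable_adapters()

model.fuse_lora(lora_scale=0.5)  # merge the scaled LoRA deltas into the base weights
model.unfuse_lora()
model.unload_lora()  # strip the injected PEFT layers again
```
Note that `set_adapters` on the mixin would raise a `KeyError` for this toy class, because `_SET_ADAPTER_SCALE_FN_MAPPING` above only knows the SD3 and Flux transformer class names.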
flux/normalization.py ADDED
@@ -0,0 +1,393 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 HuggingFace Inc.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import numbers
17
+ from typing import Dict, Optional, Tuple
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+
23
+ from diffusers.utils import is_torch_version
24
+ from .embeddings import get_activation
25
+ from .embeddings import (
26
+ CombinedTimestepLabelEmbeddings,
27
+ PixArtAlphaCombinedTimestepSizeEmbeddings,
28
+ )
29
+
30
+
31
+ class AdaLayerNorm(nn.Module):
32
+ r"""
33
+ Norm layer modified to incorporate timestep embeddings.
34
+
35
+ Parameters:
36
+ embedding_dim (`int`): The size of each embedding vector.
37
+ num_embeddings (`int`): The size of the embeddings dictionary.
38
+ """
39
+
40
+ def __init__(self, embedding_dim: int, num_embeddings: int):
41
+ super().__init__()
42
+ self.emb = nn.Embedding(num_embeddings, embedding_dim)
43
+ self.silu = nn.SiLU()
44
+ self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
45
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
46
+
47
+ def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
48
+ emb = self.linear(self.silu(self.emb(timestep)))
49
+ scale, shift = torch.chunk(emb, 2)
50
+ x = self.norm(x) * (1 + scale) + shift
51
+ return x
52
+
53
+
54
+ class FP32LayerNorm(nn.LayerNorm):
55
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
56
+ origin_dtype = inputs.dtype
57
+ return F.layer_norm(
58
+ inputs.float(),
59
+ self.normalized_shape,
60
+ self.weight.float() if self.weight is not None else None,
61
+ self.bias.float() if self.bias is not None else None,
62
+ self.eps,
63
+ ).to(origin_dtype)
64
+
65
+
66
+ class AdaLayerNormZero(nn.Module):
67
+ r"""
68
+ Norm layer adaptive layer norm zero (adaLN-Zero).
69
+
70
+ Parameters:
71
+ embedding_dim (`int`): The size of each embedding vector.
72
+ num_embeddings (`int`): The size of the embeddings dictionary.
73
+ """
74
+
75
+ def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None, norm_type="layer_norm", bias=True):
76
+ super().__init__()
77
+ if num_embeddings is not None:
78
+ self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
79
+ else:
80
+ self.emb = None
81
+
82
+ self.silu = nn.SiLU()
83
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
84
+ if norm_type == "layer_norm":
85
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
86
+ elif norm_type == "fp32_layer_norm":
87
+ self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False)
88
+ else:
89
+ raise ValueError(
90
+ f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
91
+ )
92
+
93
+ def forward(
94
+ self,
95
+ x: torch.Tensor,
96
+ timestep: Optional[torch.Tensor] = None,
97
+ class_labels: Optional[torch.LongTensor] = None,
98
+ hidden_dtype: Optional[torch.dtype] = None,
99
+ emb: Optional[torch.Tensor] = None,
100
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
101
+ if self.emb is not None:
102
+ emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
103
+ emb = self.linear(self.silu(emb))
104
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
105
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
106
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
107
+
108
+
109
+ class AdaLayerNormZeroSingle(nn.Module):
110
+ r"""
111
+ Norm layer adaptive layer norm zero (adaLN-Zero).
112
+
113
+ Parameters:
114
+ embedding_dim (`int`): The size of each embedding vector.
115
+ num_embeddings (`int`): The size of the embeddings dictionary.
116
+ """
117
+
118
+ def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
119
+ super().__init__()
120
+
121
+ self.silu = nn.SiLU()
122
+ self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
123
+ if norm_type == "layer_norm":
124
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
125
+ else:
126
+ raise ValueError(
127
+ f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
128
+ )
129
+
130
+ def forward(
131
+ self,
132
+ x: torch.Tensor,
133
+ emb: Optional[torch.Tensor] = None,
134
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
135
+ emb = self.linear(self.silu(emb))
136
+ shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
137
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
138
+ return x, gate_msa
139
+
140
+
141
+ class LuminaRMSNormZero(nn.Module):
142
+ """
143
+ Norm layer adaptive RMS normalization zero.
144
+
145
+ Parameters:
146
+ embedding_dim (`int`): The size of each embedding vector.
147
+ """
148
+
149
+ def __init__(self, embedding_dim: int, norm_eps: float, norm_elementwise_affine: bool):
150
+ super().__init__()
151
+ self.silu = nn.SiLU()
152
+ self.linear = nn.Linear(
153
+ min(embedding_dim, 1024),
154
+ 4 * embedding_dim,
155
+ bias=True,
156
+ )
157
+ self.norm = RMSNorm(embedding_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
158
+
159
+ def forward(
160
+ self,
161
+ x: torch.Tensor,
162
+ emb: Optional[torch.Tensor] = None,
163
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
164
+ # emb = self.emb(timestep, encoder_hidden_states, encoder_mask)
165
+ emb = self.linear(self.silu(emb))
166
+ scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
167
+ x = self.norm(x) * (1 + scale_msa[:, None])
168
+
169
+ return x, gate_msa, scale_mlp, gate_mlp
170
+
171
+
172
+ class AdaLayerNormSingle(nn.Module):
173
+ r"""
174
+ Norm layer adaptive layer norm single (adaLN-single).
175
+
176
+ As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
177
+
178
+ Parameters:
179
+ embedding_dim (`int`): The size of each embedding vector.
180
+ use_additional_conditions (`bool`): To use additional conditions for normalization or not.
181
+ """
182
+
183
+ def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
184
+ super().__init__()
185
+
186
+ self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
187
+ embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
188
+ )
189
+
190
+ self.silu = nn.SiLU()
191
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
192
+
193
+ def forward(
194
+ self,
195
+ timestep: torch.Tensor,
196
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
197
+ batch_size: Optional[int] = None,
198
+ hidden_dtype: Optional[torch.dtype] = None,
199
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
200
+ # No modulation happening here.
201
+ embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
202
+ return self.linear(self.silu(embedded_timestep)), embedded_timestep
203
+
204
+
205
+ class AdaGroupNorm(nn.Module):
206
+ r"""
207
+ GroupNorm layer modified to incorporate timestep embeddings.
208
+
209
+ Parameters:
210
+ embedding_dim (`int`): The size of each embedding vector.
211
+ out_dim (`int`): The number of channels in the input to modulate.
212
+ num_groups (`int`): The number of groups to separate the channels into.
213
+ act_fn (`str`, *optional*, defaults to `None`): The activation function to use.
214
+ eps (`float`, *optional*, defaults to `1e-5`): The epsilon value to use for numerical stability.
215
+ """
216
+
217
+ def __init__(
218
+ self, embedding_dim: int, out_dim: int, num_groups: int, act_fn: Optional[str] = None, eps: float = 1e-5
219
+ ):
220
+ super().__init__()
221
+ self.num_groups = num_groups
222
+ self.eps = eps
223
+
224
+ if act_fn is None:
225
+ self.act = None
226
+ else:
227
+ self.act = get_activation(act_fn)
228
+
229
+ self.linear = nn.Linear(embedding_dim, out_dim * 2)
230
+
231
+ def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
232
+ if self.act:
233
+ emb = self.act(emb)
234
+ emb = self.linear(emb)
235
+ emb = emb[:, :, None, None]
236
+ scale, shift = emb.chunk(2, dim=1)
237
+
238
+ x = F.group_norm(x, self.num_groups, eps=self.eps)
239
+ x = x * (1 + scale) + shift
240
+ return x
241
+
242
+
243
+ class AdaLayerNormContinuous(nn.Module):
244
+ def __init__(
245
+ self,
246
+ embedding_dim: int,
247
+ conditioning_embedding_dim: int,
248
+ # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
249
+ # because the output is immediately scaled and shifted by the projected conditioning embeddings.
250
+ # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
251
+ # However, this is how it was implemented in the original code, and it's rather likely you should
252
+ # set `elementwise_affine` to False.
253
+ elementwise_affine=True,
254
+ eps=1e-5,
255
+ bias=True,
256
+ norm_type="layer_norm",
257
+ ):
258
+ super().__init__()
259
+ self.silu = nn.SiLU()
260
+ self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
261
+ if norm_type == "layer_norm":
262
+ self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
263
+ elif norm_type == "rms_norm":
264
+ self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
265
+ else:
266
+ raise ValueError(f"unknown norm_type {norm_type}")
267
+
268
+ def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
269
+ # convert back to the original dtype in case `conditioning_embedding` is upcast to float32 (needed for HunyuanDiT)
270
+ emb = self.linear(self.silu(conditioning_embedding).to(x.dtype))
271
+ scale, shift = torch.chunk(emb, 2, dim=1)
272
+ x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
273
+ return x
274
+
275
+
276
+ class LuminaLayerNormContinuous(nn.Module):
277
+ def __init__(
278
+ self,
279
+ embedding_dim: int,
280
+ conditioning_embedding_dim: int,
281
+ # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
282
+ # because the output is immediately scaled and shifted by the projected conditioning embeddings.
283
+ # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
284
+ # However, this is how it was implemented in the original code, and it's rather likely you should
285
+ # set `elementwise_affine` to False.
286
+ elementwise_affine=True,
287
+ eps=1e-5,
288
+ bias=True,
289
+ norm_type="layer_norm",
290
+ out_dim: Optional[int] = None,
291
+ ):
292
+ super().__init__()
293
+ # AdaLN
294
+ self.silu = nn.SiLU()
295
+ self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)
296
+ if norm_type == "layer_norm":
297
+ self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
298
+ else:
299
+ raise ValueError(f"unknown norm_type {norm_type}")
300
+ # linear_2
301
+ if out_dim is not None:
302
+ self.linear_2 = nn.Linear(
303
+ embedding_dim,
304
+ out_dim,
305
+ bias=bias,
306
+ )
307
+
308
+ def forward(
309
+ self,
310
+ x: torch.Tensor,
311
+ conditioning_embedding: torch.Tensor,
312
+ ) -> torch.Tensor:
313
+ # convert back to the original dtype in case `conditioning_embedding` is upcast to float32 (needed for HunyuanDiT)
314
+ emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
315
+ scale = emb
316
+ x = self.norm(x) * (1 + scale)[:, None, :]
317
+
318
+ if self.linear_2 is not None:
319
+ x = self.linear_2(x)
320
+
321
+ return x
322
+
323
+
324
+ if is_torch_version(">=", "2.1.0"):
325
+ LayerNorm = nn.LayerNorm
326
+ else:
327
+ # Has optional bias parameter compared to torch layer norm
328
+ # TODO: replace with torch layernorm once min required torch version >= 2.1
329
+ class LayerNorm(nn.Module):
330
+ def __init__(self, dim, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True):
331
+ super().__init__()
332
+
333
+ self.eps = eps
334
+
335
+ if isinstance(dim, numbers.Integral):
336
+ dim = (dim,)
337
+
338
+ self.dim = torch.Size(dim)
339
+
340
+ if elementwise_affine:
341
+ self.weight = nn.Parameter(torch.ones(dim))
342
+ self.bias = nn.Parameter(torch.zeros(dim)) if bias else None
343
+ else:
344
+ self.weight = None
345
+ self.bias = None
346
+
347
+ def forward(self, input):
348
+ return F.layer_norm(input, self.dim, self.weight, self.bias, self.eps)
349
+
350
+
351
+ class RMSNorm(nn.Module):
352
+ def __init__(self, dim, eps: float, elementwise_affine: bool = True):
353
+ super().__init__()
354
+
355
+ self.eps = eps
356
+
357
+ if isinstance(dim, numbers.Integral):
358
+ dim = (dim,)
359
+
360
+ self.dim = torch.Size(dim)
361
+
362
+ if elementwise_affine:
363
+ self.weight = nn.Parameter(torch.ones(dim))
364
+ else:
365
+ self.weight = None
366
+
367
+ def forward(self, hidden_states):
368
+ input_dtype = hidden_states.dtype
369
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
370
+ hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
371
+
372
+ if self.weight is not None:
373
+ # convert into half-precision if necessary
374
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
375
+ hidden_states = hidden_states.to(self.weight.dtype)
376
+ hidden_states = hidden_states * self.weight
377
+ else:
378
+ hidden_states = hidden_states.to(input_dtype)
379
+
380
+ return hidden_states
381
+
382
+
383
+ class GlobalResponseNorm(nn.Module):
384
+ # Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105
385
+ def __init__(self, dim):
386
+ super().__init__()
387
+ self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
388
+ self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
389
+
390
+ def forward(self, x):
391
+ gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
392
+ nx = gx / (gx.mean(dim=-1, keepdim=True) + 1e-6)
393
+ return self.gamma * (x * nx) + self.beta + x
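A quick numerical check of the `RMSNorm` module above against the formula `x * rsqrt(mean(x**2, dim=-1) + eps) * weight`; a sketch that assumes this file is importable as `flux.normalization`.
```py
import torch

from flux.normalization import RMSNorm  # assumes this repo is on PYTHONPATH

x = torch.randn(2, 5, 16)
norm = RMSNorm(dim=16, eps=1e-6)  # elementwise_affine=True -> weight initialised to ones

# Reference computation matching the forward pass above.
expected = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6) * norm.weight
print(torch.allclose(norm(x), expected, atol=1e-6))  # True
```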
flux/pipeline_flux.py ADDED
@@ -0,0 +1,749 @@
1
+ # Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
21
+
22
+ from diffusers.image_processor import VaeImageProcessor
23
+ from .lora.lora_pipeline import FluxLoraLoaderMixin
24
+ from diffusers.models.autoencoders import AutoencoderKL
25
+ from .transformer_flux import FluxTransformer2DModel
26
+ from .scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
27
+ from diffusers.utils import (
28
+ USE_PEFT_BACKEND,
29
+ is_torch_xla_available,
30
+ logging,
31
+ replace_example_docstring,
32
+ scale_lora_layers,
33
+ unscale_lora_layers,
34
+ )
35
+ from diffusers.utils.torch_utils import randn_tensor
36
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
37
+ from .pipeline_output import FluxPipelineOutput
38
+
39
+
40
+ if is_torch_xla_available():
41
+ import torch_xla.core.xla_model as xm
42
+
43
+ XLA_AVAILABLE = True
44
+ else:
45
+ XLA_AVAILABLE = False
46
+
47
+
48
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
49
+
50
+ EXAMPLE_DOC_STRING = """
51
+ Examples:
52
+ ```py
53
+ >>> import torch
54
+ >>> from diffusers import FluxPipeline
55
+
56
+ >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
57
+ >>> pipe.to("cuda")
58
+ >>> prompt = "A cat holding a sign that says hello world"
59
+ >>> # Depending on the variant being used, the pipeline call will slightly vary.
60
+ >>> # Refer to the pipeline documentation for more details.
61
+ >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
62
+ >>> image.save("flux.png")
63
+ ```
64
+ """
65
+
66
+
67
+ def calculate_shift(
68
+ image_seq_len,
69
+ base_seq_len: int = 256,
70
+ max_seq_len: int = 4096,
71
+ base_shift: float = 0.5,
72
+ max_shift: float = 1.16,
73
+ ):
74
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
75
+ b = base_shift - m * base_seq_len
76
+ mu = image_seq_len * m + b
77
+ return mu
78
+
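A minimal arithmetic sketch (editorial addition, not part of the committed file) of what `calculate_shift` yields under the default `base_seq_len`/`max_seq_len`/`base_shift`/`max_shift` values above; the resolutions are examples only.

```py
m = (1.16 - 0.5) / (4096 - 256)   # slope, ~0.000171875
b = 0.5 - m * 256                 # intercept, ~0.456

# 1024x1024 output -> 128x128 latent grid -> 64*64 = 4096 packed tokens
print(4096 * m + b)               # 1.16, i.e. exactly max_shift
# 512x512 output -> 64x64 latent grid -> 32*32 = 1024 packed tokens
print(1024 * m + b)               # ~0.632
```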
79
+
80
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
81
+ def retrieve_timesteps(
82
+ scheduler,
83
+ num_inference_steps: Optional[int] = None,
84
+ device: Optional[Union[str, torch.device]] = None,
85
+ timesteps: Optional[List[int]] = None,
86
+ sigmas: Optional[List[float]] = None,
87
+ **kwargs,
88
+ ):
89
+ """
90
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
91
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
92
+
93
+ Args:
94
+ scheduler (`SchedulerMixin`):
95
+ The scheduler to get timesteps from.
96
+ num_inference_steps (`int`):
97
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
98
+ must be `None`.
99
+ device (`str` or `torch.device`, *optional*):
100
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
101
+ timesteps (`List[int]`, *optional*):
102
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
103
+ `num_inference_steps` and `sigmas` must be `None`.
104
+ sigmas (`List[float]`, *optional*):
105
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
106
+ `num_inference_steps` and `timesteps` must be `None`.
107
+
108
+ Returns:
109
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
110
+ second element is the number of inference steps.
111
+ """
112
+ if timesteps is not None and sigmas is not None:
113
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
114
+ if timesteps is not None:
115
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
116
+ if not accepts_timesteps:
117
+ raise ValueError(
118
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
119
+ f" timestep schedules. Please check whether you are using the correct scheduler."
120
+ )
121
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
122
+ timesteps = scheduler.timesteps
123
+ num_inference_steps = len(timesteps)
124
+ elif sigmas is not None:
125
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
126
+ if not accept_sigmas:
127
+ raise ValueError(
128
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
129
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
130
+ )
131
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
132
+ timesteps = scheduler.timesteps
133
+ num_inference_steps = len(timesteps)
134
+ else:
135
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
136
+ timesteps = scheduler.timesteps
137
+ return timesteps, num_inference_steps
138
+
139
+
140
+ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
141
+ r"""
142
+ The Flux pipeline for text-to-image generation.
143
+
144
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
145
+
146
+ Args:
147
+ transformer ([`FluxTransformer2DModel`]):
148
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
149
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
150
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
151
+ vae ([`AutoencoderKL`]):
152
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
153
+ text_encoder ([`CLIPTextModel`]):
154
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
155
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
156
+ text_encoder_2 ([`T5EncoderModel`]):
157
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
158
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
159
+ tokenizer (`CLIPTokenizer`):
160
+ Tokenizer of class
161
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
162
+ tokenizer_2 (`T5TokenizerFast`):
163
+ Second Tokenizer of class
164
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
165
+ """
166
+
167
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
168
+ _optional_components = []
169
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
170
+
171
+ def __init__(
172
+ self,
173
+ scheduler: FlowMatchEulerDiscreteScheduler,
174
+ vae: AutoencoderKL,
175
+ text_encoder: CLIPTextModel,
176
+ tokenizer: CLIPTokenizer,
177
+ text_encoder_2: T5EncoderModel,
178
+ tokenizer_2: T5TokenizerFast,
179
+ transformer: FluxTransformer2DModel,
180
+ ):
181
+ super().__init__()
182
+
183
+ self.register_modules(
184
+ vae=vae,
185
+ text_encoder=text_encoder,
186
+ text_encoder_2=text_encoder_2,
187
+ tokenizer=tokenizer,
188
+ tokenizer_2=tokenizer_2,
189
+ transformer=transformer,
190
+ scheduler=scheduler,
191
+ )
192
+ self.vae_scale_factor = (
193
+ 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
194
+ )
195
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
196
+ self.tokenizer_max_length = (
197
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
198
+ )
199
+ self.default_sample_size = 64
200
+
201
+ def _get_t5_prompt_embeds(
202
+ self,
203
+ prompt: Union[str, List[str]] = None,
204
+ num_images_per_prompt: int = 1,
205
+ max_sequence_length: int = 512,
206
+ device: Optional[torch.device] = None,
207
+ dtype: Optional[torch.dtype] = None,
208
+ ):
209
+ device = device or self._execution_device
210
+ dtype = dtype or self.text_encoder.dtype
211
+
212
+ prompt = [prompt] if isinstance(prompt, str) else prompt
213
+ batch_size = len(prompt)
214
+
215
+ text_inputs = self.tokenizer_2(
216
+ prompt,
217
+ padding="max_length",
218
+ max_length=max_sequence_length,
219
+ truncation=True,
220
+ return_length=False,
221
+ return_overflowing_tokens=False,
222
+ return_tensors="pt",
223
+ )
224
+ text_input_ids = text_inputs.input_ids
225
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
226
+
227
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
228
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
229
+ logger.warning(
230
+ "The following part of your input was truncated because `max_sequence_length` is set to "
231
+ f" {max_sequence_length} tokens: {removed_text}"
232
+ )
233
+
234
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
235
+
236
+ dtype = self.text_encoder_2.dtype
237
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
238
+
239
+ _, seq_len, _ = prompt_embeds.shape
240
+
241
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
242
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
243
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
244
+
245
+ return prompt_embeds
246
+
247
+ def _get_clip_prompt_embeds(
248
+ self,
249
+ prompt: Union[str, List[str]],
250
+ num_images_per_prompt: int = 1,
251
+ device: Optional[torch.device] = None,
252
+ ):
253
+ device = device or self._execution_device
254
+
255
+ prompt = [prompt] if isinstance(prompt, str) else prompt
256
+ batch_size = len(prompt)
257
+
258
+ text_inputs = self.tokenizer(
259
+ prompt,
260
+ padding="max_length",
261
+ max_length=self.tokenizer_max_length,
262
+ truncation=True,
263
+ return_overflowing_tokens=False,
264
+ return_length=False,
265
+ return_tensors="pt",
266
+ )
267
+
268
+ text_input_ids = text_inputs.input_ids
269
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
270
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
271
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
272
+ logger.warning(
273
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
274
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
275
+ )
276
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
277
+
278
+ # Use pooled output of CLIPTextModel
279
+ prompt_embeds = prompt_embeds.pooler_output
280
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
281
+
282
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
283
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
284
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
285
+
286
+ return prompt_embeds
287
+
288
+ def encode_prompt(
289
+ self,
290
+ prompt: Union[str, List[str]],
291
+ prompt_2: Union[str, List[str]],
292
+ device: Optional[torch.device] = None,
293
+ num_images_per_prompt: int = 1,
294
+ prompt_embeds: Optional[torch.FloatTensor] = None,
295
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
296
+ max_sequence_length: int = 512,
297
+ lora_scale: Optional[float] = None,
298
+ ):
299
+ r"""
300
+
301
+ Args:
302
+ prompt (`str` or `List[str]`, *optional*):
303
+ prompt to be encoded
304
+ prompt_2 (`str` or `List[str]`, *optional*):
305
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
306
+ used in all text-encoders
307
+ device: (`torch.device`):
308
+ torch device
309
+ num_images_per_prompt (`int`):
310
+ number of images that should be generated per prompt
311
+ prompt_embeds (`torch.FloatTensor`, *optional*):
312
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
313
+ provided, text embeddings will be generated from `prompt` input argument.
314
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
315
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
316
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
317
+ lora_scale (`float`, *optional*):
318
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
319
+ """
320
+ device = device or self._execution_device
321
+
322
+ # set lora scale so that monkey patched LoRA
323
+ # function of text encoder can correctly access it
324
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
325
+ self._lora_scale = lora_scale
326
+
327
+ # dynamically adjust the LoRA scale
328
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
329
+ scale_lora_layers(self.text_encoder, lora_scale)
330
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
331
+ scale_lora_layers(self.text_encoder_2, lora_scale)
332
+
333
+ prompt = [prompt] if isinstance(prompt, str) else prompt
334
+ if prompt is not None:
335
+ batch_size = len(prompt)
336
+ else:
337
+ batch_size = prompt_embeds.shape[0]
338
+
339
+ if prompt_embeds is None:
340
+ prompt_2 = prompt_2 or prompt
341
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
342
+
343
+ # We only use the pooled prompt output from the CLIPTextModel
344
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
345
+ prompt=prompt,
346
+ device=device,
347
+ num_images_per_prompt=num_images_per_prompt,
348
+ )
349
+ prompt_embeds = self._get_t5_prompt_embeds(
350
+ prompt=prompt_2,
351
+ num_images_per_prompt=num_images_per_prompt,
352
+ max_sequence_length=max_sequence_length,
353
+ device=device,
354
+ )
355
+
356
+ if self.text_encoder is not None:
357
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
358
+ # Retrieve the original scale by scaling back the LoRA layers
359
+ unscale_lora_layers(self.text_encoder, lora_scale)
360
+
361
+ if self.text_encoder_2 is not None:
362
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
363
+ # Retrieve the original scale by scaling back the LoRA layers
364
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
365
+
366
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
367
+ text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
368
+ text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
369
+
370
+ return prompt_embeds, pooled_prompt_embeds, text_ids
371
+
372
+ def check_inputs(
373
+ self,
374
+ prompt,
375
+ prompt_2,
376
+ height,
377
+ width,
378
+ prompt_embeds=None,
379
+ pooled_prompt_embeds=None,
380
+ callback_on_step_end_tensor_inputs=None,
381
+ max_sequence_length=None,
382
+ ):
383
+ if height % 8 != 0 or width % 8 != 0:
384
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
385
+
386
+ if callback_on_step_end_tensor_inputs is not None and not all(
387
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
388
+ ):
389
+ raise ValueError(
390
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
391
+ )
392
+
393
+ if prompt is not None and prompt_embeds is not None:
394
+ raise ValueError(
395
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
396
+ " only forward one of the two."
397
+ )
398
+ elif prompt_2 is not None and prompt_embeds is not None:
399
+ raise ValueError(
400
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
401
+ " only forward one of the two."
402
+ )
403
+ elif prompt is None and prompt_embeds is None:
404
+ raise ValueError(
405
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
406
+ )
407
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
408
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
409
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
410
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
411
+
412
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
413
+ raise ValueError(
414
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
415
+ )
416
+
417
+ if max_sequence_length is not None and max_sequence_length > 512:
418
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
419
+
420
+ @staticmethod
421
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
422
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
423
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
424
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
425
+
426
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
427
+
428
+ latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
429
+ latent_image_ids = latent_image_ids.reshape(
430
+ batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
431
+ )
432
+
433
+ return latent_image_ids.to(device=device, dtype=dtype)
434
+
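A short sketch (editorial addition, not part of the committed file) of what `_prepare_latent_image_ids` produces: one `(0, row, column)` triple per packed 2x2 patch, with the same 3-wide last dimension as the `text_ids` built in `encode_prompt`.

```py
import torch

h = w = 128                                        # latent grid for a 1024x1024 image
ids = torch.zeros(h // 2, w // 2, 3)
ids[..., 1] = ids[..., 1] + torch.arange(h // 2)[:, None]   # row index of each 2x2 patch
ids[..., 2] = ids[..., 2] + torch.arange(w // 2)[None, :]   # column index of each 2x2 patch
ids = ids.reshape(1, (h // 2) * (w // 2), 3)       # (1, 4096, 3), same last dim as text_ids
print(ids[0, :3])                                  # tensor([[0., 0., 0.], [0., 0., 1.], [0., 0., 2.]])
```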
435
+ @staticmethod
436
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
437
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
438
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
439
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
440
+
441
+ return latents
442
+
443
+ @staticmethod
444
+ def _unpack_latents(latents, height, width, vae_scale_factor):
445
+ batch_size, num_patches, channels = latents.shape
446
+
447
+ height = height // vae_scale_factor
448
+ width = width // vae_scale_factor
449
+
450
+ latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
451
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
452
+
453
+ latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
454
+
455
+ return latents
456
+
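A self-contained round-trip check (editorial addition, not part of the committed file) mirroring `_pack_latents` / `_unpack_latents` above: every 2x2 latent patch is flattened into one token of width `4*C`, and the inverse reshape restores the original tensor exactly.

```py
import torch

B, C, H, W = 1, 16, 128, 128                      # latent grid for a 1024x1024 image
x = torch.randn(B, C, H, W)

packed = (                                         # same ops as _pack_latents
    x.view(B, C, H // 2, 2, W // 2, 2)
     .permute(0, 2, 4, 1, 3, 5)
     .reshape(B, (H // 2) * (W // 2), C * 4)       # (1, 4096, 64)
)
restored = (                                       # same ops as _unpack_latents
    packed.view(B, H // 2, W // 2, C, 2, 2)
          .permute(0, 3, 1, 4, 2, 5)
          .reshape(B, C, H, W)
)
assert torch.equal(restored, x)
```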
457
+ def prepare_latents(
458
+ self,
459
+ batch_size,
460
+ num_channels_latents,
461
+ height,
462
+ width,
463
+ dtype,
464
+ device,
465
+ generator,
466
+ latents=None,
467
+ ):
468
+ height = 2 * (int(height) // self.vae_scale_factor)
469
+ width = 2 * (int(width) // self.vae_scale_factor)
470
+
471
+ shape = (batch_size, num_channels_latents, height, width)
472
+
473
+ if latents is not None:
474
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
475
+ return latents.to(device=device, dtype=dtype), latent_image_ids
476
+
477
+ if isinstance(generator, list) and len(generator) != batch_size:
478
+ raise ValueError(
479
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
480
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
481
+ )
482
+
483
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
484
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
485
+
486
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
487
+
488
+ return latents, latent_image_ids
489
+
490
+ @property
491
+ def guidance_scale(self):
492
+ return self._guidance_scale
493
+
494
+ @property
495
+ def joint_attention_kwargs(self):
496
+ return self._joint_attention_kwargs
497
+
498
+ @property
499
+ def num_timesteps(self):
500
+ return self._num_timesteps
501
+
502
+ @property
503
+ def interrupt(self):
504
+ return self._interrupt
505
+
506
+ @torch.no_grad()
507
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
508
+ def __call__(
509
+ self,
510
+ prompt: Union[str, List[str]] = None,
511
+ prompt_2: Optional[Union[str, List[str]]] = None,
512
+ height: Optional[int] = None,
513
+ width: Optional[int] = None,
514
+ num_inference_steps: int = 28,
515
+ timesteps: List[int] = None,
516
+ guidance_scale: float = 7.0,
517
+ num_images_per_prompt: Optional[int] = 1,
518
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
519
+ latents: Optional[torch.FloatTensor] = None,
520
+ prompt_embeds: Optional[torch.FloatTensor] = None,
521
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
522
+ output_type: Optional[str] = "pil",
523
+ return_dict: bool = True,
524
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
525
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
526
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
527
+ max_sequence_length: int = 512,
528
+ ):
529
+ r"""
530
+ Function invoked when calling the pipeline for generation.
531
+
532
+ Args:
533
+ prompt (`str` or `List[str]`, *optional*):
534
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
535
+ instead.
536
+ prompt_2 (`str` or `List[str]`, *optional*):
537
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
538
+ will be used instead.
539
+ height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
540
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
541
+ width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
542
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
543
+ num_inference_steps (`int`, *optional*, defaults to 28):
544
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
545
+ expense of slower inference.
546
+ timesteps (`List[int]`, *optional*):
547
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
548
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
549
+ passed will be used. Must be in descending order.
550
+ guidance_scale (`float`, *optional*, defaults to 7.0):
551
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
552
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
553
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
554
+ 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
555
+ usually at the expense of lower image quality.
556
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
557
+ The number of images to generate per prompt.
558
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
559
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
560
+ to make generation deterministic.
561
+ latents (`torch.FloatTensor`, *optional*):
562
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
563
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
564
+ tensor will be generated by sampling using the supplied random `generator`.
565
+ prompt_embeds (`torch.FloatTensor`, *optional*):
566
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
567
+ provided, text embeddings will be generated from `prompt` input argument.
568
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
569
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
570
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
571
+ output_type (`str`, *optional*, defaults to `"pil"`):
572
+ The output format of the generated image. Choose between
573
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
574
+ return_dict (`bool`, *optional*, defaults to `True`):
575
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
576
+ joint_attention_kwargs (`dict`, *optional*):
577
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
578
+ `self.processor` in
579
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
580
+ callback_on_step_end (`Callable`, *optional*):
581
+ A function that is called at the end of each denoising step during inference. The function is called
582
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
583
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
584
+ `callback_on_step_end_tensor_inputs`.
585
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
586
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
587
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
588
+ `._callback_tensor_inputs` attribute of your pipeline class.
589
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
590
+
591
+ Examples:
592
+
593
+ Returns:
594
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
595
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
596
+ images.
597
+ """
598
+
599
+ height = height or self.default_sample_size * self.vae_scale_factor
600
+ width = width or self.default_sample_size * self.vae_scale_factor
601
+
602
+ # 1. Check inputs. Raise error if not correct
603
+ self.check_inputs(
604
+ prompt,
605
+ prompt_2,
606
+ height,
607
+ width,
608
+ prompt_embeds=prompt_embeds,
609
+ pooled_prompt_embeds=pooled_prompt_embeds,
610
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
611
+ max_sequence_length=max_sequence_length,
612
+ )
613
+
614
+ self._guidance_scale = guidance_scale
615
+ self._joint_attention_kwargs = joint_attention_kwargs
616
+ self._interrupt = False
617
+
618
+ # 2. Define call parameters
619
+ if prompt is not None and isinstance(prompt, str):
620
+ batch_size = 1
621
+ elif prompt is not None and isinstance(prompt, list):
622
+ batch_size = len(prompt)
623
+ else:
624
+ batch_size = prompt_embeds.shape[0]
625
+
626
+ device = self._execution_device
627
+
628
+ lora_scale = (
629
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
630
+ )
631
+ (
632
+ prompt_embeds,
633
+ pooled_prompt_embeds,
634
+ text_ids,
635
+ ) = self.encode_prompt(
636
+ prompt=prompt,
637
+ prompt_2=prompt_2,
638
+ prompt_embeds=prompt_embeds,
639
+ pooled_prompt_embeds=pooled_prompt_embeds,
640
+ device=device,
641
+ num_images_per_prompt=num_images_per_prompt,
642
+ max_sequence_length=max_sequence_length,
643
+ lora_scale=lora_scale,
644
+ )
645
+
646
+ # 4. Prepare latent variables
647
+ num_channels_latents = self.transformer.config.in_channels // 4
648
+ latents, latent_image_ids = self.prepare_latents(
649
+ batch_size * num_images_per_prompt,
650
+ num_channels_latents,
651
+ height,
652
+ width,
653
+ prompt_embeds.dtype,
654
+ device,
655
+ generator,
656
+ latents,
657
+ )
658
+
659
+ # 5. Prepare timesteps
660
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
661
+ image_seq_len = latents.shape[1]
662
+ mu = calculate_shift(
663
+ image_seq_len,
664
+ self.scheduler.config.base_image_seq_len,
665
+ self.scheduler.config.max_image_seq_len,
666
+ self.scheduler.config.base_shift,
667
+ self.scheduler.config.max_shift,
668
+ )
669
+ timesteps, num_inference_steps = retrieve_timesteps(
670
+ self.scheduler,
671
+ num_inference_steps,
672
+ device,
673
+ timesteps,
674
+ sigmas,
675
+ mu=mu,
676
+ )
677
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
678
+ self._num_timesteps = len(timesteps)
679
+
680
+ # 6. Denoising loop
681
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
682
+ for i, t in enumerate(timesteps):
683
+ if self.interrupt:
684
+ continue
685
+
686
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
687
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
688
+
689
+ # handle guidance
690
+ if self.transformer.config.guidance_embeds:
691
+ guidance = torch.tensor([guidance_scale], device=device)
692
+ guidance = guidance.expand(latents.shape[0])
693
+ else:
694
+ guidance = None
695
+
696
+ noise_pred = self.transformer(
697
+ hidden_states=latents,
698
+ # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs the same for the model for testing)
699
+ timestep=timestep / 1000,
700
+ guidance=guidance,
701
+ pooled_projections=pooled_prompt_embeds,
702
+ encoder_hidden_states=prompt_embeds,
703
+ txt_ids=text_ids,
704
+ img_ids=latent_image_ids,
705
+ joint_attention_kwargs=self.joint_attention_kwargs,
706
+ return_dict=False,
707
+ )[0]
708
+
709
+ # compute the previous noisy sample x_t -> x_t-1
710
+ latents_dtype = latents.dtype
711
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
712
+
713
+ if latents.dtype != latents_dtype:
714
+ if torch.backends.mps.is_available():
715
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
716
+ latents = latents.to(latents_dtype)
717
+
718
+ if callback_on_step_end is not None:
719
+ callback_kwargs = {}
720
+ for k in callback_on_step_end_tensor_inputs:
721
+ callback_kwargs[k] = locals()[k]
722
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
723
+
724
+ latents = callback_outputs.pop("latents", latents)
725
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
726
+
727
+ # call the callback, if provided
728
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
729
+ progress_bar.update()
730
+
731
+ if XLA_AVAILABLE:
732
+ xm.mark_step()
733
+
734
+ if output_type == "latent":
735
+ image = latents
736
+
737
+ else:
738
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
739
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
740
+ image = self.vae.decode(latents, return_dict=False)[0]
741
+ image = self.image_processor.postprocess(image, output_type=output_type)
742
+
743
+ # Offload all models
744
+ self.maybe_free_model_hooks()
745
+
746
+ if not return_dict:
747
+ return (image,)
748
+
749
+ return FluxPipelineOutput(images=image)
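Before the next file, a short trace (editorial addition, not part of either committed file) of the shape bookkeeping `__call__` performs for the default output size; every number follows from the defaults set in `__init__` and the helpers above.

```py
default_sample_size, vae_scale_factor = 64, 16             # values set in __init__
height = width = default_sample_size * vae_scale_factor    # 1024 x 1024 pixels

# prepare_latents: latent grid, then 2x2 packing into image tokens
latent_h = 2 * (height // vae_scale_factor)                 # 128
latent_w = 2 * (width // vae_scale_factor)                  # 128
image_seq_len = (latent_h // 2) * (latent_w // 2)           # 4096 packed image tokens

# encode_prompt: up to max_sequence_length = 512 T5 tokens plus (batch, 512, 3) zero text ids;
# after the denoising loop, _unpack_latents restores (batch, channels, 128, 128) for the VAE decoder.
```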
flux/pipeline_flux_chameleon.py ADDED
@@ -0,0 +1,758 @@
 
1
+ # Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
21
+
22
+ from diffusers.image_processor import VaeImageProcessor
23
+ from .lora.lora_pipeline import FluxLoraLoaderMixin
24
+ from diffusers.models.autoencoders import AutoencoderKL
25
+ from .transformer_flux import FluxTransformer2DModel
26
+ from .scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
27
+ from diffusers.utils import (
28
+ USE_PEFT_BACKEND,
29
+ is_torch_xla_available,
30
+ logging,
31
+ replace_example_docstring,
32
+ scale_lora_layers,
33
+ unscale_lora_layers,
34
+ )
35
+ from diffusers.utils.torch_utils import randn_tensor
36
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
37
+ from .pipeline_output import FluxPipelineOutput
38
+
39
+
40
+ if is_torch_xla_available():
41
+ import torch_xla.core.xla_model as xm
42
+
43
+ XLA_AVAILABLE = True
44
+ else:
45
+ XLA_AVAILABLE = False
46
+
47
+
48
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
49
+
50
+ EXAMPLE_DOC_STRING = """
51
+ Examples:
52
+ ```py
53
+ >>> import torch
54
+ >>> from diffusers import FluxPipeline
55
+
56
+ >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
57
+ >>> pipe.to("cuda")
58
+ >>> prompt = "A cat holding a sign that says hello world"
59
+ >>> # Depending on the variant being used, the pipeline call will slightly vary.
60
+ >>> # Refer to the pipeline documentation for more details.
61
+ >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
62
+ >>> image.save("flux.png")
63
+ ```
64
+ """
65
+
66
+
67
+ def calculate_shift(
68
+ image_seq_len,
69
+ base_seq_len: int = 256,
70
+ max_seq_len: int = 4096,
71
+ base_shift: float = 0.5,
72
+ max_shift: float = 1.16,
73
+ ):
74
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
75
+ b = base_shift - m * base_seq_len
76
+ mu = image_seq_len * m + b
77
+ return mu
78
+
79
+
80
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
81
+ def retrieve_timesteps(
82
+ scheduler,
83
+ num_inference_steps: Optional[int] = None,
84
+ device: Optional[Union[str, torch.device]] = None,
85
+ timesteps: Optional[List[int]] = None,
86
+ sigmas: Optional[List[float]] = None,
87
+ **kwargs,
88
+ ):
89
+ """
90
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
91
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
92
+
93
+ Args:
94
+ scheduler (`SchedulerMixin`):
95
+ The scheduler to get timesteps from.
96
+ num_inference_steps (`int`):
97
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
98
+ must be `None`.
99
+ device (`str` or `torch.device`, *optional*):
100
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
101
+ timesteps (`List[int]`, *optional*):
102
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
103
+ `num_inference_steps` and `sigmas` must be `None`.
104
+ sigmas (`List[float]`, *optional*):
105
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
106
+ `num_inference_steps` and `timesteps` must be `None`.
107
+
108
+ Returns:
109
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
110
+ second element is the number of inference steps.
111
+ """
112
+ if timesteps is not None and sigmas is not None:
113
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
114
+ if timesteps is not None:
115
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
116
+ if not accepts_timesteps:
117
+ raise ValueError(
118
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
119
+ f" timestep schedules. Please check whether you are using the correct scheduler."
120
+ )
121
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
122
+ timesteps = scheduler.timesteps
123
+ num_inference_steps = len(timesteps)
124
+ elif sigmas is not None:
125
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
126
+ if not accept_sigmas:
127
+ raise ValueError(
128
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
129
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
130
+ )
131
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
132
+ timesteps = scheduler.timesteps
133
+ num_inference_steps = len(timesteps)
134
+ else:
135
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
136
+ timesteps = scheduler.timesteps
137
+ return timesteps, num_inference_steps
138
+
139
+
140
+ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
141
+ r"""
142
+ The Flux pipeline for text-to-image generation.
143
+
144
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
145
+
146
+ Args:
147
+ transformer ([`FluxTransformer2DModel`]):
148
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
149
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
150
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
151
+ vae ([`AutoencoderKL`]):
152
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
153
+ text_encoder ([`CLIPTextModel`]):
154
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
155
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
156
+ text_encoder_2 ([`T5EncoderModel`]):
157
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
158
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
159
+ tokenizer (`CLIPTokenizer`):
160
+ Tokenizer of class
161
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
162
+ tokenizer_2 (`T5TokenizerFast`):
163
+ Second Tokenizer of class
164
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
165
+ """
166
+
167
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
168
+ _optional_components = []
169
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
170
+
171
+ def __init__(
172
+ self,
173
+ scheduler: FlowMatchEulerDiscreteScheduler,
174
+ vae: AutoencoderKL,
175
+ text_encoder: CLIPTextModel,
176
+ tokenizer: CLIPTokenizer,
177
+ transformer: FluxTransformer2DModel,
178
+ text_encoder_2: T5EncoderModel | None = None,
179
+ tokenizer_2: T5TokenizerFast | None = None,
180
+ ):
181
+ super().__init__()
182
+
183
+ self.register_modules(
184
+ vae=vae,
185
+ text_encoder=text_encoder,
186
+ #text_encoder_2=text_encoder_2,
187
+ tokenizer=tokenizer,
188
+ #tokenizer_2=tokenizer_2,
189
+ transformer=transformer,
190
+ scheduler=scheduler,
191
+ )
192
+ self.vae_scale_factor = (
193
+ 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
194
+ )
195
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
196
+ self.tokenizer_max_length = (
197
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
198
+ )
199
+ self.default_sample_size = 64
200
+
201
+ def _get_t5_prompt_embeds(
202
+ self,
203
+ prompt: Union[str, List[str]] = None,
204
+ num_images_per_prompt: int = 1,
205
+ max_sequence_length: int = 512,
206
+ device: Optional[torch.device] = None,
207
+ dtype: Optional[torch.dtype] = None,
208
+ ):
209
+ device = device or self._execution_device
210
+ dtype = dtype or self.text_encoder.dtype
211
+
212
+ prompt = [prompt] if isinstance(prompt, str) else prompt
213
+ batch_size = len(prompt)
214
+
215
+ text_inputs = self.tokenizer_2(
216
+ prompt,
217
+ padding="max_length",
218
+ max_length=max_sequence_length,
219
+ truncation=True,
220
+ return_length=False,
221
+ return_overflowing_tokens=False,
222
+ return_tensors="pt",
223
+ )
224
+ text_input_ids = text_inputs.input_ids
225
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
226
+
227
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
228
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
229
+ logger.warning(
230
+ "The following part of your input was truncated because `max_sequence_length` is set to "
231
+ f" {max_sequence_length} tokens: {removed_text}"
232
+ )
233
+
234
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
235
+
236
+ dtype = self.text_encoder_2.dtype
237
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
238
+
239
+ _, seq_len, _ = prompt_embeds.shape
240
+
241
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
242
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
243
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
244
+
245
+ return prompt_embeds
246
+
247
+ def _get_clip_prompt_embeds(
248
+ self,
249
+ prompt: Union[str, List[str]],
250
+ num_images_per_prompt: int = 1,
251
+ device: Optional[torch.device] = None,
252
+ ):
253
+ device = device or self._execution_device
254
+
255
+ prompt = [prompt] if isinstance(prompt, str) else prompt
256
+ batch_size = len(prompt)
257
+
258
+ text_inputs = self.tokenizer(
259
+ prompt,
260
+ padding="max_length",
261
+ max_length=self.tokenizer_max_length,
262
+ truncation=True,
263
+ return_overflowing_tokens=False,
264
+ return_length=False,
265
+ return_tensors="pt",
266
+ )
267
+
268
+ text_input_ids = text_inputs.input_ids
269
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
270
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
271
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
272
+ logger.warning(
273
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
274
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
275
+ )
276
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
277
+
278
+ # Use pooled output of CLIPTextModel
279
+ prompt_embeds = prompt_embeds.pooler_output
280
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
281
+
282
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
283
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
284
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
285
+
286
+ return prompt_embeds
287
+
288
+ def encode_prompt(
289
+ self,
290
+ prompt: Union[str, List[str]],
291
+ prompt_2: Union[str, List[str]],
292
+ device: Optional[torch.device] = None,
293
+ num_images_per_prompt: int = 1,
294
+ prompt_embeds: Optional[torch.FloatTensor] = None,
295
+ t5_prompt_embeds: Optional[torch.FloatTensor] = None,
296
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
297
+ max_sequence_length: int = 512,
298
+ lora_scale: Optional[float] = None,
299
+ ):
300
+ r"""
301
+
302
+ Args:
303
+ prompt (`str` or `List[str]`, *optional*):
304
+ prompt to be encoded
305
+ prompt_2 (`str` or `List[str]`, *optional*):
306
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
307
+ used in all text-encoders
308
+ device: (`torch.device`):
309
+ torch device
310
+ num_images_per_prompt (`int`):
311
+ number of images that should be generated per prompt
312
+ prompt_embeds (`torch.FloatTensor`, *optional*):
313
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
314
+ provided, text embeddings will be generated from `prompt` input argument.
315
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
316
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
317
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
318
+ lora_scale (`float`, *optional*):
319
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
320
+ """
321
+ device = device or self._execution_device
322
+
323
+ # set lora scale so that monkey patched LoRA
324
+ # function of text encoder can correctly access it
325
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
326
+ self._lora_scale = lora_scale
327
+
328
+ # dynamically adjust the LoRA scale
329
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
330
+ scale_lora_layers(self.text_encoder, lora_scale)
331
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
332
+ scale_lora_layers(self.text_encoder_2, lora_scale)
333
+
334
+ prompt = [prompt] if isinstance(prompt, str) else prompt
335
+ if prompt is not None:
336
+ batch_size = len(prompt)
337
+ else:
338
+ batch_size = prompt_embeds.shape[0]
339
+
340
+ if prompt_embeds is None:
341
+ prompt_2 = prompt_2 or prompt
342
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
343
+
344
+ # We only use the pooled prompt output from the CLIPTextModel
345
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
346
+ prompt=prompt,
347
+ device=device,
348
+ num_images_per_prompt=num_images_per_prompt,
349
+ )
350
+ prompt_embeds = self._get_t5_prompt_embeds(
351
+ prompt=prompt_2,
352
+ num_images_per_prompt=num_images_per_prompt,
353
+ max_sequence_length=max_sequence_length,
354
+ device=device,
355
+ )
356
+
357
+ if self.text_encoder is not None:
358
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
359
+ # Retrieve the original scale by scaling back the LoRA layers
360
+ unscale_lora_layers(self.text_encoder, lora_scale)
361
+
362
+ #if self.text_encoder_2 is not None:
363
+ # if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
364
+ # # Retrieve the original scale by scaling back the LoRA layers
365
+ # unscale_lora_layers(self.text_encoder_2, lora_scale)
366
+
367
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
368
+ if t5_prompt_embeds is not None:
369
+ text_ids = torch.zeros(batch_size, prompt_embeds.shape[1] + t5_prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
370
+ else:
371
+ text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
372
+
373
+ text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
374
+
375
+ return prompt_embeds, pooled_prompt_embeds, text_ids
376
+
377
+ def check_inputs(
378
+ self,
379
+ prompt,
380
+ prompt_2,
381
+ height,
382
+ width,
383
+ prompt_embeds=None,
384
+ pooled_prompt_embeds=None,
385
+ callback_on_step_end_tensor_inputs=None,
386
+ max_sequence_length=None,
387
+ ):
388
+ if height % 8 != 0 or width % 8 != 0:
389
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
390
+
391
+ if callback_on_step_end_tensor_inputs is not None and not all(
392
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
393
+ ):
394
+ raise ValueError(
395
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
396
+ )
397
+
398
+ if prompt is not None and prompt_embeds is not None:
399
+ raise ValueError(
400
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
401
+ " only forward one of the two."
402
+ )
403
+ elif prompt_2 is not None and prompt_embeds is not None:
404
+ raise ValueError(
405
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
406
+ " only forward one of the two."
407
+ )
408
+ elif prompt is None and prompt_embeds is None:
409
+ raise ValueError(
410
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
411
+ )
412
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
413
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
414
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
415
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
416
+
417
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
418
+ raise ValueError(
419
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
420
+ )
421
+
422
+ if max_sequence_length is not None and max_sequence_length > 512:
423
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
424
+
425
+ @staticmethod
426
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
427
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
428
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
429
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
430
+
431
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
432
+
433
+ latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
434
+ latent_image_ids = latent_image_ids.reshape(
435
+ batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
436
+ )
437
+
438
+ return latent_image_ids.to(device=device, dtype=dtype)
439
+
440
+ @staticmethod
441
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
442
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
443
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
444
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
445
+
446
+ return latents
447
+
448
+ @staticmethod
449
+ def _unpack_latents(latents, height, width, vae_scale_factor):
450
+ batch_size, num_patches, channels = latents.shape
451
+
452
+ height = height // vae_scale_factor
453
+ width = width // vae_scale_factor
454
+
455
+ latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
456
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
457
+
458
+ latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
459
+
460
+ return latents
461
+
462
+ def prepare_latents(
463
+ self,
464
+ batch_size,
465
+ num_channels_latents,
466
+ height,
467
+ width,
468
+ dtype,
469
+ device,
470
+ generator,
471
+ latents=None,
472
+ ):
473
+ height = 2 * (int(height) // self.vae_scale_factor)
474
+ width = 2 * (int(width) // self.vae_scale_factor)
475
+
476
+ shape = (batch_size, num_channels_latents, height, width)
477
+
478
+ if latents is not None:
479
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
480
+ return latents.to(device=device, dtype=dtype), latent_image_ids
481
+
482
+ if isinstance(generator, list) and len(generator) != batch_size:
483
+ raise ValueError(
484
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
485
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
486
+ )
487
+
488
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
489
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
490
+
491
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
492
+
493
+ return latents, latent_image_ids
494
+
495
+ @property
496
+ def guidance_scale(self):
497
+ return self._guidance_scale
498
+
499
+ @property
500
+ def joint_attention_kwargs(self):
501
+ return self._joint_attention_kwargs
502
+
503
+ @property
504
+ def num_timesteps(self):
505
+ return self._num_timesteps
506
+
507
+ @property
508
+ def interrupt(self):
509
+ return self._interrupt
510
+
511
+ #@torch.inference_mode()
512
+ @torch.no_grad()
513
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
514
+ def __call__(
515
+ self,
516
+ prompt: Union[str, List[str]] = None,
517
+ prompt_2: Optional[Union[str, List[str]]] = None,
518
+ height: Optional[int] = None,
519
+ width: Optional[int] = None,
520
+ num_inference_steps: int = 28,
521
+ timesteps: List[int] = None,
522
+ guidance_scale: float = 7.0,
523
+ num_images_per_prompt: Optional[int] = 1,
524
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
525
+ latents: Optional[torch.FloatTensor] = None,
526
+ prompt_embeds: Optional[torch.FloatTensor] = None,
527
+ t5_prompt_embeds: Optional[torch.FloatTensor] = None,
528
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
529
+ output_type: Optional[str] = "pil",
530
+ return_dict: bool = True,
531
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
532
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
533
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
534
+ max_sequence_length: int = 512,
535
+ ):
536
+ r"""
537
+ Function invoked when calling the pipeline for generation.
538
+
539
+ Args:
540
+ prompt (`str` or `List[str]`, *optional*):
541
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
542
+ instead.
543
+ prompt_2 (`str` or `List[str]`, *optional*):
544
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
545
+ will be used instead
546
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
547
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
548
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
549
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
550
+ num_inference_steps (`int`, *optional*, defaults to 50):
551
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
552
+ expense of slower inference.
553
+ timesteps (`List[int]`, *optional*):
554
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
555
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
556
+ passed will be used. Must be in descending order.
557
+ guidance_scale (`float`, *optional*, defaults to 7.0):
558
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
559
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
560
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
561
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
562
+ usually at the expense of lower image quality.
563
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
564
+ The number of images to generate per prompt.
565
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
566
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
567
+ to make generation deterministic.
568
+ latents (`torch.FloatTensor`, *optional*):
569
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
570
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
571
+ tensor will ge generated by sampling using the supplied random `generator`.
572
+ prompt_embeds (`torch.FloatTensor`, *optional*):
573
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
574
+ provided, text embeddings will be generated from `prompt` input argument.
575
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
576
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
577
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
578
+ output_type (`str`, *optional*, defaults to `"pil"`):
579
+ The output format of the generated image. Choose between
580
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
581
+ return_dict (`bool`, *optional*, defaults to `True`):
582
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
583
+ joint_attention_kwargs (`dict`, *optional*):
584
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
585
+ `self.processor` in
586
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
587
+ callback_on_step_end (`Callable`, *optional*):
588
+ A function that is called at the end of each denoising step during inference. The function is called
589
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
590
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
591
+ `callback_on_step_end_tensor_inputs`.
592
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
593
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
594
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
595
+ `._callback_tensor_inputs` attribute of your pipeline class.
596
+ max_sequence_length (`int`, defaults to 512): Maximum sequence length to use with the `prompt`.
597
+
598
+ Examples:
599
+
600
+ Returns:
601
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
602
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
603
+ images.
604
+ """
605
+
606
+ height = height or self.default_sample_size * self.vae_scale_factor
607
+ width = width or self.default_sample_size * self.vae_scale_factor
608
+
609
+ # 1. Check inputs. Raise error if not correct
610
+ self.check_inputs(
611
+ prompt,
612
+ prompt_2,
613
+ height,
614
+ width,
615
+ prompt_embeds=prompt_embeds,
616
+ pooled_prompt_embeds=pooled_prompt_embeds,
617
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
618
+ max_sequence_length=max_sequence_length,
619
+ )
620
+
621
+ self._guidance_scale = guidance_scale
622
+ self._joint_attention_kwargs = joint_attention_kwargs
623
+ self._interrupt = False
624
+
625
+ # 2. Define call parameters
626
+ if prompt is not None and isinstance(prompt, str):
627
+ batch_size = 1
628
+ elif prompt is not None and isinstance(prompt, list):
629
+ batch_size = len(prompt)
630
+ else:
631
+ batch_size = prompt_embeds.shape[0]
632
+
633
+ device = self._execution_device
634
+
635
+ lora_scale = (
636
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
637
+ )
638
+ (
639
+ prompt_embeds,
640
+ pooled_prompt_embeds,
641
+ text_ids,
642
+ ) = self.encode_prompt(
643
+ prompt=prompt,
644
+ prompt_2=prompt_2,
645
+ prompt_embeds=prompt_embeds,
646
+ t5_prompt_embeds=t5_prompt_embeds,
647
+ pooled_prompt_embeds=pooled_prompt_embeds,
648
+ device=device,
649
+ num_images_per_prompt=num_images_per_prompt,
650
+ max_sequence_length=max_sequence_length,
651
+ lora_scale=lora_scale,
652
+ )
653
+
654
+ # 4. Prepare latent variables
655
+ num_channels_latents = self.transformer.config.in_channels // 4
656
+ latents, latent_image_ids = self.prepare_latents(
657
+ batch_size * num_images_per_prompt,
658
+ num_channels_latents,
659
+ height,
660
+ width,
661
+ prompt_embeds.dtype,
662
+ device,
663
+ generator,
664
+ latents,
665
+ )
666
+
667
+ # 5. Prepare timesteps
668
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
669
+ image_seq_len = latents.shape[1]
670
+ mu = calculate_shift(
671
+ image_seq_len,
672
+ self.scheduler.config.base_image_seq_len,
673
+ self.scheduler.config.max_image_seq_len,
674
+ self.scheduler.config.base_shift,
675
+ self.scheduler.config.max_shift,
676
+ )
677
+ timesteps, num_inference_steps = retrieve_timesteps(
678
+ self.scheduler,
679
+ num_inference_steps,
680
+ device,
681
+ timesteps,
682
+ sigmas,
683
+ mu=mu,
684
+ )
685
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
686
+ self._num_timesteps = len(timesteps)
687
+
688
+ # 6. Denoising loop
689
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
690
+ for i, t in enumerate(timesteps):
691
+ if self.interrupt:
692
+ continue
693
+
694
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
695
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
696
+
697
+ # handle guidance
698
+ if self.transformer.config.guidance_embeds:
699
+ guidance = torch.tensor([guidance_scale], device=device)
700
+ guidance = guidance.expand(latents.shape[0])
701
+ else:
702
+ guidance = None
703
+
704
+ noise_pred = self.transformer(
705
+ hidden_states=latents,
706
+ # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs the same for the model for testing)
707
+ timestep=timestep / 1000,
708
+ guidance=guidance,
709
+ pooled_projections=pooled_prompt_embeds,
710
+ encoder_hidden_states=prompt_embeds,
711
+ t5_encoder_hidden_states=t5_prompt_embeds,
712
+ txt_ids=text_ids,
713
+ img_ids=latent_image_ids,
714
+ joint_attention_kwargs=self.joint_attention_kwargs,
715
+ return_dict=False,
716
+ )[0]
717
+
718
+ # compute the previous noisy sample x_t -> x_t-1
719
+ latents_dtype = latents.dtype
720
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
721
+
722
+ if latents.dtype != latents_dtype:
723
+ if torch.backends.mps.is_available():
724
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
725
+ latents = latents.to(latents_dtype)
726
+
727
+ if callback_on_step_end is not None:
728
+ callback_kwargs = {}
729
+ for k in callback_on_step_end_tensor_inputs:
730
+ callback_kwargs[k] = locals()[k]
731
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
732
+
733
+ latents = callback_outputs.pop("latents", latents)
734
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
735
+
736
+ # call the callback, if provided
737
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
738
+ progress_bar.update()
739
+
740
+ if XLA_AVAILABLE:
741
+ xm.mark_step()
742
+
743
+ if output_type == "latent":
744
+ image = latents
745
+
746
+ else:
747
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
748
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
749
+ image = self.vae.decode(latents, return_dict=False)[0]
750
+ image = self.image_processor.postprocess(image, output_type=output_type)
751
+
752
+ # Offload all models
753
+ self.maybe_free_model_hooks()
754
+
755
+ if not return_dict:
756
+ return (image,)
757
+
758
+ return FluxPipelineOutput(images=image)
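
The tail of `__call__` above unpacks the latents, undoes the VAE scaling/shift, decodes with the VAE, and postprocesses. As a minimal sketch of reusing that same decode path outside the pipeline, e.g. after running with `output_type="latent"` (the helper name `decode_packed_latents` and the `pipe` variable are ours and assume a loaded instance of this pipeline, not something defined in the repo):

```py
import torch

def decode_packed_latents(pipe, latents, height, width, output_type="pil"):
    # unpack (B, seq, 64) -> (B, 16, H/8, W/8) and undo the VAE scaling/shift,
    # exactly as the end of `__call__` does
    latents = pipe._unpack_latents(latents, height, width, pipe.vae_scale_factor)
    latents = (latents / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor
    with torch.no_grad():
        image = pipe.vae.decode(latents, return_dict=False)[0]
    return pipe.image_processor.postprocess(image, output_type=output_type)
```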
flux/pipeline_flux_controlnet.py ADDED
@@ -0,0 +1,945 @@
1
+ # Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ from transformers import (
21
+ CLIPTextModel,
22
+ CLIPTokenizer,
23
+ T5EncoderModel,
24
+ T5TokenizerFast,
25
+ )
26
+
27
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
28
+ from .lora.lora_pipeline import FluxLoraLoaderMixin
29
+ from diffusers.loaders import FromSingleFileMixin
30
+ from diffusers.models.autoencoders import AutoencoderKL
31
+ from .controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
32
+ from .transformer_flux import FluxTransformer2DModel
33
+ from .scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
34
+ from diffusers.utils import (
35
+ USE_PEFT_BACKEND,
36
+ is_torch_xla_available,
37
+ logging,
38
+ replace_example_docstring,
39
+ scale_lora_layers,
40
+ unscale_lora_layers,
41
+ )
42
+ from diffusers.utils.torch_utils import randn_tensor
43
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
44
+ from .pipeline_output import FluxPipelineOutput
45
+
46
+
47
+ if is_torch_xla_available():
48
+ import torch_xla.core.xla_model as xm
49
+
50
+ XLA_AVAILABLE = True
51
+ else:
52
+ XLA_AVAILABLE = False
53
+
54
+
55
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
56
+
57
+ EXAMPLE_DOC_STRING = """
58
+ Examples:
59
+ ```py
60
+ >>> import torch
61
+ >>> from diffusers.utils import load_image
62
+ >>> from diffusers import FluxControlNetPipeline
63
+ >>> from diffusers import FluxControlNetModel
64
+
65
+ >>> base_model = "black-forest-labs/FLUX.1-dev"
+ >>> controlnet_model = "InstantX/FLUX.1-dev-controlnet-canny"
66
+ >>> controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
67
+ >>> pipe = FluxControlNetPipeline.from_pretrained(
68
+ ... base_model, controlnet=controlnet, torch_dtype=torch.bfloat16
69
+ ... )
70
+ >>> pipe.to("cuda")
71
+ >>> control_image = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg")
72
+ >>> prompt = "A girl in city, 25 years old, cool, futuristic"
73
+ >>> image = pipe(
74
+ ... prompt,
75
+ ... control_image=control_image,
76
+ ... controlnet_conditioning_scale=0.6,
77
+ ... num_inference_steps=28,
78
+ ... guidance_scale=3.5,
79
+ ... ).images[0]
80
+ >>> image.save("flux.png")
81
+ ```
82
+ """
83
+
84
+
85
+ # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
86
+ def calculate_shift(
87
+ image_seq_len,
88
+ base_seq_len: int = 256,
89
+ max_seq_len: int = 4096,
90
+ base_shift: float = 0.5,
91
+ max_shift: float = 1.16,
92
+ ):
93
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
94
+ b = base_shift - m * base_seq_len
95
+ mu = image_seq_len * m + b
96
+ return mu
97
+
98
+
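
`calculate_shift` linearly interpolates the flow-matching shift `mu` between `base_shift` and `max_shift` according to how many packed image tokens the latents contain; the pipeline later forwards `mu` to `retrieve_timesteps`. A small hand-checked sketch of the values it produces (the `side // 16` factor assumes the default `vae_scale_factor` of 16 used in this file):

```py
# Hand-checked values for `calculate_shift` with its defaults
# (base_seq_len=256, max_seq_len=4096, base_shift=0.5, max_shift=1.16).
for side in (512, 1024):
    image_seq_len = (side // 16) ** 2        # packed latent sequence length
    mu = calculate_shift(image_seq_len)
    print(side, image_seq_len, round(mu, 3))
# 512  -> 1024 tokens, mu ≈ 0.632
# 1024 -> 4096 tokens, mu = 1.16
```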
99
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
100
+ def retrieve_timesteps(
101
+ scheduler,
102
+ num_inference_steps: Optional[int] = None,
103
+ device: Optional[Union[str, torch.device]] = None,
104
+ timesteps: Optional[List[int]] = None,
105
+ sigmas: Optional[List[float]] = None,
106
+ **kwargs,
107
+ ):
108
+ """
109
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
110
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
111
+
112
+ Args:
113
+ scheduler (`SchedulerMixin`):
114
+ The scheduler to get timesteps from.
115
+ num_inference_steps (`int`):
116
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
117
+ must be `None`.
118
+ device (`str` or `torch.device`, *optional*):
119
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
120
+ timesteps (`List[int]`, *optional*):
121
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
122
+ `num_inference_steps` and `sigmas` must be `None`.
123
+ sigmas (`List[float]`, *optional*):
124
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
125
+ `num_inference_steps` and `timesteps` must be `None`.
126
+
127
+ Returns:
128
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
129
+ second element is the number of inference steps.
130
+ """
131
+ if timesteps is not None and sigmas is not None:
132
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
133
+ if timesteps is not None:
134
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
135
+ if not accepts_timesteps:
136
+ raise ValueError(
137
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
138
+ f" timestep schedules. Please check whether you are using the correct scheduler."
139
+ )
140
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
141
+ timesteps = scheduler.timesteps
142
+ num_inference_steps = len(timesteps)
143
+ elif sigmas is not None:
144
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
145
+ if not accept_sigmas:
146
+ raise ValueError(
147
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
148
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
149
+ )
150
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
151
+ timesteps = scheduler.timesteps
152
+ num_inference_steps = len(timesteps)
153
+ else:
154
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
155
+ timesteps = scheduler.timesteps
156
+ return timesteps, num_inference_steps
157
+
158
+
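
`retrieve_timesteps` lets a caller drive the scheduler with explicit `timesteps`, explicit `sigmas`, or just a step count; this pipeline always takes the `sigmas` route and forwards `mu` (from `calculate_shift`) through `**kwargs` to `set_timesteps`. A small sketch of the schedule it is given, kept to pure NumPy so it runs without a scheduler:

```py
import numpy as np

num_inference_steps = 28
# Descending sigma schedule from 1.0 down to 1/num_inference_steps, as built in
# step 5 of `__call__`; `mu` then shifts this schedule inside the scheduler.
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
print(sigmas[:3])   # [1.         0.96428571 0.92857143]
print(sigmas[-1])   # 0.03571428571428571  (= 1/28)
```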
159
+ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
160
+ r"""
161
+ The Flux pipeline for text-to-image generation with ControlNet guidance.
162
+
163
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
164
+
165
+ Args:
166
+ transformer ([`FluxTransformer2DModel`]):
167
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
168
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
169
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
170
+ vae ([`AutoencoderKL`]):
171
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
172
+ text_encoder ([`CLIPTextModel`]):
173
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
174
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
175
+ text_encoder_2 ([`T5EncoderModel`]):
176
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
177
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
178
+ tokenizer (`CLIPTokenizer`):
179
+ Tokenizer of class
180
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
181
+ tokenizer_2 (`T5TokenizerFast`):
182
+ Second Tokenizer of class
183
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
184
+ """
185
+
186
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
187
+ _optional_components = []
188
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
189
+
190
+ def __init__(
191
+ self,
192
+ scheduler: FlowMatchEulerDiscreteScheduler,
193
+ vae: AutoencoderKL,
194
+ text_encoder: CLIPTextModel,
195
+ tokenizer: CLIPTokenizer,
196
+ transformer: FluxTransformer2DModel,
197
+ controlnet: Union[
198
+ FluxControlNetModel, List[FluxControlNetModel], Tuple[FluxControlNetModel], FluxMultiControlNetModel
199
+ ],
200
+ text_encoder_2: T5EncoderModel | None = None,
201
+ tokenizer_2: T5TokenizerFast | None = None,
202
+ ):
203
+ super().__init__()
204
+
205
+ self.register_modules(
206
+ vae=vae,
207
+ text_encoder=text_encoder,
208
+ #text_encoder_2=text_encoder_2,
209
+ tokenizer=tokenizer,
210
+ #tokenizer_2=tokenizer_2,
211
+ transformer=transformer,
212
+ scheduler=scheduler,
213
+ controlnet=controlnet,
214
+ )
215
+ self.vae_scale_factor = (
216
+ 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
217
+ )
218
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
219
+ self.tokenizer_max_length = (
220
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
221
+ )
222
+ self.default_sample_size = 64
223
+
224
+ def _get_t5_prompt_embeds(
225
+ self,
226
+ prompt: Union[str, List[str]] = None,
227
+ num_images_per_prompt: int = 1,
228
+ max_sequence_length: int = 512,
229
+ device: Optional[torch.device] = None,
230
+ dtype: Optional[torch.dtype] = None,
231
+ ):
232
+ device = device or self._execution_device
233
+ dtype = dtype or self.text_encoder.dtype
234
+
235
+ prompt = [prompt] if isinstance(prompt, str) else prompt
236
+ batch_size = len(prompt)
237
+
238
+ text_inputs = self.tokenizer_2(
239
+ prompt,
240
+ padding="max_length",
241
+ max_length=max_sequence_length,
242
+ truncation=True,
243
+ return_length=False,
244
+ return_overflowing_tokens=False,
245
+ return_tensors="pt",
246
+ )
247
+ text_input_ids = text_inputs.input_ids
248
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
249
+
250
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
251
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
252
+ logger.warning(
253
+ "The following part of your input was truncated because `max_sequence_length` is set to "
254
+ f" {max_sequence_length} tokens: {removed_text}"
255
+ )
256
+
257
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
258
+
259
+ dtype = self.text_encoder_2.dtype
260
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
261
+
262
+ _, seq_len, _ = prompt_embeds.shape
263
+
264
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
265
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
266
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
267
+
268
+ return prompt_embeds
269
+
270
+ def _get_clip_prompt_embeds(
271
+ self,
272
+ prompt: Union[str, List[str]],
273
+ num_images_per_prompt: int = 1,
274
+ device: Optional[torch.device] = None,
275
+ ):
276
+ device = device or self._execution_device
277
+
278
+ prompt = [prompt] if isinstance(prompt, str) else prompt
279
+ batch_size = len(prompt)
280
+
281
+ text_inputs = self.tokenizer(
282
+ prompt,
283
+ padding="max_length",
284
+ max_length=self.tokenizer_max_length,
285
+ truncation=True,
286
+ return_overflowing_tokens=False,
287
+ return_length=False,
288
+ return_tensors="pt",
289
+ )
290
+
291
+ text_input_ids = text_inputs.input_ids
292
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
293
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
294
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
295
+ logger.warning(
296
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
297
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
298
+ )
299
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
300
+
301
+ # Use pooled output of CLIPTextModel
302
+ prompt_embeds = prompt_embeds.pooler_output
303
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
304
+
305
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
306
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
307
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
308
+
309
+ return prompt_embeds
310
+
311
+ def encode_prompt(
312
+ self,
313
+ prompt: Union[str, List[str]],
314
+ prompt_2: Union[str, List[str]],
315
+ device: Optional[torch.device] = None,
316
+ num_images_per_prompt: int = 1,
317
+ prompt_embeds: Optional[torch.FloatTensor] = None,
318
+ t5_prompt_embeds: Optional[torch.FloatTensor] = None,
319
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
320
+ max_sequence_length: int = 512,
321
+ lora_scale: Optional[float] = None,
322
+ ):
323
+ r"""
324
+
325
+ Args:
326
+ prompt (`str` or `List[str]`, *optional*):
327
+ prompt to be encoded
328
+ prompt_2 (`str` or `List[str]`, *optional*):
329
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
330
+ used in all text-encoders
331
+ device: (`torch.device`):
332
+ torch device
333
+ num_images_per_prompt (`int`):
334
+ number of images that should be generated per prompt
335
+ prompt_embeds (`torch.FloatTensor`, *optional*):
336
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
337
+ provided, text embeddings will be generated from `prompt` input argument.
338
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
339
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
340
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
341
+ clip_skip (`int`, *optional*):
342
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
343
+ the output of the pre-final layer will be used for computing the prompt embeddings.
344
+ lora_scale (`float`, *optional*):
345
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
346
+ """
347
+ device = device or self._execution_device
348
+
349
+ # set lora scale so that monkey patched LoRA
350
+ # function of text encoder can correctly access it
351
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
352
+ self._lora_scale = lora_scale
353
+
354
+ # dynamically adjust the LoRA scale
355
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
356
+ scale_lora_layers(self.text_encoder, lora_scale)
357
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
358
+ scale_lora_layers(self.text_encoder_2, lora_scale)
359
+
360
+ prompt = [prompt] if isinstance(prompt, str) else prompt
361
+
362
+ if prompt_embeds is None:
363
+ prompt_2 = prompt_2 or prompt
364
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
365
+
366
+ # We only use the pooled prompt output from the CLIPTextModel
367
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
368
+ prompt=prompt,
369
+ device=device,
370
+ num_images_per_prompt=num_images_per_prompt,
371
+ )
372
+ prompt_embeds = self._get_t5_prompt_embeds(
373
+ prompt=prompt_2,
374
+ num_images_per_prompt=num_images_per_prompt,
375
+ max_sequence_length=max_sequence_length,
376
+ device=device,
377
+ )
378
+
379
+ if self.text_encoder is not None:
380
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
381
+ # Retrieve the original scale by scaling back the LoRA layers
382
+ unscale_lora_layers(self.text_encoder, lora_scale)
383
+
384
+ #if self.text_encoder_2 is not None:
385
+ # if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
386
+ # # Retrieve the original scale by scaling back the LoRA layers
387
+ # unscale_lora_layers(self.text_encoder_2, lora_scale)
388
+
389
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
390
+ if t5_prompt_embeds is not None:
391
+ text_ids = torch.zeros(prompt_embeds.shape[1] + t5_prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
392
+ else:
393
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
394
+
395
+ return prompt_embeds, pooled_prompt_embeds, text_ids
396
+
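
Because `encode_prompt` is a public method, the text encoders can be run once and their outputs reused across calls through the `prompt_embeds` / `pooled_prompt_embeds` arguments documented below. A small sketch (the helper name is ours; it assumes a loaded pipeline instance that was constructed with both the CLIP and T5 encoders):

```py
import torch

def precompute_prompt_embeddings(pipe, prompt, max_sequence_length=512):
    # Runs CLIP (pooled) and T5 (sequence) encoders once; the returned embeddings
    # can be passed back into pipe(...) with prompt=None to skip re-encoding.
    with torch.no_grad():
        prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt(
            prompt=prompt,
            prompt_2=None,
            num_images_per_prompt=1,
            max_sequence_length=max_sequence_length,
        )
    return prompt_embeds, pooled_prompt_embeds
```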
397
+ def check_inputs(
398
+ self,
399
+ prompt,
400
+ prompt_2,
401
+ height,
402
+ width,
403
+ prompt_embeds=None,
404
+ pooled_prompt_embeds=None,
405
+ callback_on_step_end_tensor_inputs=None,
406
+ max_sequence_length=None,
407
+ ):
408
+ if height % 8 != 0 or width % 8 != 0:
409
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
410
+
411
+ if callback_on_step_end_tensor_inputs is not None and not all(
412
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
413
+ ):
414
+ raise ValueError(
415
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
416
+ )
417
+
418
+ if prompt is not None and prompt_embeds is not None:
419
+ raise ValueError(
420
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
421
+ " only forward one of the two."
422
+ )
423
+ elif prompt_2 is not None and prompt_embeds is not None:
424
+ raise ValueError(
425
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
426
+ " only forward one of the two."
427
+ )
428
+ elif prompt is None and prompt_embeds is None:
429
+ raise ValueError(
430
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
431
+ )
432
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
433
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
434
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
435
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
436
+
437
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
438
+ raise ValueError(
439
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
440
+ )
441
+
442
+ if max_sequence_length is not None and max_sequence_length > 512:
443
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
444
+
445
+ @staticmethod
446
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
447
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
448
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
449
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
450
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
451
+
452
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
453
+
454
+ latent_image_ids = latent_image_ids.reshape(
455
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
456
+ )
457
+
458
+ return latent_image_ids.to(device=device, dtype=dtype)
459
+
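
`_prepare_latent_image_ids` builds one (id, row, col) triple per 2x2 latent patch; these are the positional `img_ids` handed to the transformer alongside the text ids. A hand-checked example on a tiny 4x4 latent grid (the import path assumes the repo root is on `sys.path`):

```py
import torch
from flux.pipeline_flux_controlnet import FluxControlNetPipeline

ids = FluxControlNetPipeline._prepare_latent_image_ids(
    batch_size=1, height=4, width=4, device="cpu", dtype=torch.float32
)
print(ids.shape)  # torch.Size([4, 3])
print(ids)        # [[0., 0., 0.], [0., 0., 1.], [0., 1., 0.], [0., 1., 1.]]
                  # column 0 is reserved (stays 0); columns 1/2 are the patch row/col
```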
460
+ @staticmethod
461
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
462
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
463
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
464
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
465
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
466
+
467
+ return latents
468
+
469
+ @staticmethod
470
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
471
+ def _unpack_latents(latents, height, width, vae_scale_factor):
472
+ batch_size, num_patches, channels = latents.shape
473
+
474
+ height = height // vae_scale_factor
475
+ width = width // vae_scale_factor
476
+
477
+ latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
478
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
479
+
480
+ latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
481
+
482
+ return latents
483
+
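
`_pack_latents` folds each 2x2 patch of the 16-channel VAE latent into a single 64-dimensional token, and `_unpack_latents` inverts it; for a 1024x1024 image (vae_scale_factor 16) this yields a 4096-token sequence. A shape round-trip check, under the same import assumption as above:

```py
import torch
from flux.pipeline_flux_controlnet import FluxControlNetPipeline

latents = torch.randn(1, 16, 128, 128)  # latent for a 1024x1024 image
packed = FluxControlNetPipeline._pack_latents(
    latents, batch_size=1, num_channels_latents=16, height=128, width=128
)
print(packed.shape)  # torch.Size([1, 4096, 64])

unpacked = FluxControlNetPipeline._unpack_latents(
    packed, height=1024, width=1024, vae_scale_factor=16
)
print(unpacked.shape)                   # torch.Size([1, 16, 128, 128])
print(torch.equal(unpacked, latents))   # True — packing is a pure reshape/permute
```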
484
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents
485
+ def prepare_latents(
486
+ self,
487
+ batch_size,
488
+ num_channels_latents,
489
+ height,
490
+ width,
491
+ dtype,
492
+ device,
493
+ generator,
494
+ latents=None,
495
+ ):
496
+ height = 2 * (int(height) // self.vae_scale_factor)
497
+ width = 2 * (int(width) // self.vae_scale_factor)
498
+
499
+ shape = (batch_size, num_channels_latents, height, width)
500
+
501
+ if latents is not None:
502
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
503
+ return latents.to(device=device, dtype=dtype), latent_image_ids
504
+
505
+ if isinstance(generator, list) and len(generator) != batch_size:
506
+ raise ValueError(
507
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
508
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
509
+ )
510
+
511
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
512
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
513
+
514
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
515
+
516
+ return latents, latent_image_ids
517
+
518
+ # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image
519
+ def prepare_image(
520
+ self,
521
+ image,
522
+ width,
523
+ height,
524
+ batch_size,
525
+ num_images_per_prompt,
526
+ device,
527
+ dtype,
528
+ do_classifier_free_guidance=False,
529
+ guess_mode=False,
530
+ ):
531
+ if isinstance(image, torch.Tensor):
532
+ pass
533
+ else:
534
+ image = self.image_processor.preprocess(image, height=height, width=width)
535
+
536
+ image_batch_size = image.shape[0]
537
+
538
+ if image_batch_size == 1:
539
+ repeat_by = batch_size
540
+ else:
541
+ # image batch size is the same as prompt batch size
542
+ repeat_by = num_images_per_prompt
543
+
544
+ image = image.repeat_interleave(repeat_by, dim=0)
545
+
546
+ image = image.to(device=device, dtype=dtype)
547
+
548
+ if do_classifier_free_guidance and not guess_mode:
549
+ image = torch.cat([image] * 2)
550
+
551
+ return image
552
+
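
Inside `__call__` (step 3 below), the conditioning image returned by `prepare_image` is additionally VAE-encoded and packed before it reaches the ControlNet. A sketch of that end-to-end preparation as a standalone helper (the function name is ours; `pipe` is assumed to be a loaded instance of this pipeline):

```py
import torch

def prepare_packed_control_latents(pipe, control_image, height, width, batch_size=1):
    # preprocess + broadcast the conditioning image to the batch size
    image = pipe.prepare_image(
        image=control_image, width=width, height=height,
        batch_size=batch_size, num_images_per_prompt=1,
        device=pipe._execution_device, dtype=pipe.vae.dtype,
    )
    # VAE-encode and apply the Flux latent shift/scale, then pack to (B, seq, 64)
    with torch.no_grad():
        latents = pipe.vae.encode(image).latent_dist.sample()
    latents = (latents - pipe.vae.config.shift_factor) * pipe.vae.config.scaling_factor
    h, w = latents.shape[2:]
    return pipe._pack_latents(
        latents, batch_size, pipe.transformer.config.in_channels // 4, h, w
    )
```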
553
+ @property
554
+ def guidance_scale(self):
555
+ return self._guidance_scale
556
+
557
+ @property
558
+ def joint_attention_kwargs(self):
559
+ return self._joint_attention_kwargs
560
+
561
+ @property
562
+ def num_timesteps(self):
563
+ return self._num_timesteps
564
+
565
+ @property
566
+ def interrupt(self):
567
+ return self._interrupt
568
+
569
+ @torch.no_grad()
570
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
571
+ def __call__(
572
+ self,
573
+ prompt: Union[str, List[str]] = None,
574
+ prompt_2: Optional[Union[str, List[str]]] = None,
575
+ height: Optional[int] = None,
576
+ width: Optional[int] = None,
577
+ num_inference_steps: int = 28,
578
+ timesteps: List[int] = None,
579
+ guidance_scale: float = 7.0,
580
+ control_image: PipelineImageInput = None,
581
+ control_mode: Optional[Union[int, List[int]]] = None,
582
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
583
+ num_images_per_prompt: Optional[int] = 1,
584
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
585
+ latents: Optional[torch.FloatTensor] = None,
586
+ prompt_embeds: Optional[torch.FloatTensor] = None,
587
+ t5_prompt_embeds: Optional[torch.FloatTensor] = None,
588
+ prompt_embeds_control: Optional[torch.FloatTensor] = None,
589
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
590
+ output_type: Optional[str] = "pil",
591
+ return_dict: bool = True,
592
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
593
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
594
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
595
+ max_sequence_length: int = 512,
596
+ ):
597
+ r"""
598
+ Function invoked when calling the pipeline for generation.
599
+
600
+ Args:
601
+ prompt (`str` or `List[str]`, *optional*):
602
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
603
+ instead.
604
+ prompt_2 (`str` or `List[str]`, *optional*):
605
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
606
+ will be used instead
607
+ height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
608
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
609
+ width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
610
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
611
+ num_inference_steps (`int`, *optional*, defaults to 28):
612
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
613
+ expense of slower inference.
614
+ timesteps (`List[int]`, *optional*):
615
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
616
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
617
+ passed will be used. Must be in descending order.
618
+ guidance_scale (`float`, *optional*, defaults to 7.0):
619
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
620
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
621
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
622
+ 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
623
+ usually at the expense of lower image quality.
624
+ control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
625
+ `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
626
+ The ControlNet input condition to provide guidance to the `transformer` for generation. If the type is
627
+ specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
628
+ as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
629
+ width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
630
+ images must be passed as a list such that each element of the list can be correctly batched for input
631
+ to a single ControlNet.
632
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
633
+ The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
634
+ to the residual in the original `transformer`. If multiple ControlNets are specified in `init`, you can set
635
+ the corresponding scale as a list.
636
+ control_mode (`int` or `List[int]`, *optional*, defaults to None):
637
+ The control mode when applying ControlNet-Union.
638
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
639
+ The number of images to generate per prompt.
640
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
641
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
642
+ to make generation deterministic.
643
+ latents (`torch.FloatTensor`, *optional*):
644
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
645
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
646
+ tensor will be generated by sampling using the supplied random `generator`.
647
+ prompt_embeds (`torch.FloatTensor`, *optional*):
648
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
649
+ provided, text embeddings will be generated from `prompt` input argument.
650
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
651
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
652
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
653
+ output_type (`str`, *optional*, defaults to `"pil"`):
654
+ The output format of the generated image. Choose between
655
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
656
+ return_dict (`bool`, *optional*, defaults to `True`):
657
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
658
+ joint_attention_kwargs (`dict`, *optional*):
659
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
660
+ `self.processor` in
661
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
662
+ callback_on_step_end (`Callable`, *optional*):
663
+ A function that calls at the end of each denoising steps during the inference. The function is called
664
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
665
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
666
+ `callback_on_step_end_tensor_inputs`.
667
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
668
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
669
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
670
+ `._callback_tensor_inputs` attribute of your pipeline class.
671
+ max_sequence_length (`int`, defaults to 512): Maximum sequence length to use with the `prompt`.
672
+
673
+ Examples:
674
+
675
+ Returns:
676
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
677
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
678
+ images.
679
+ """
680
+
681
+ height = height or self.default_sample_size * self.vae_scale_factor
682
+ width = width or self.default_sample_size * self.vae_scale_factor
683
+
684
+ # 1. Check inputs. Raise error if not correct
685
+ self.check_inputs(
686
+ prompt,
687
+ prompt_2,
688
+ height,
689
+ width,
690
+ prompt_embeds=prompt_embeds,
691
+ pooled_prompt_embeds=pooled_prompt_embeds,
692
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
693
+ max_sequence_length=max_sequence_length,
694
+ )
695
+
696
+ self._guidance_scale = guidance_scale
697
+ self._joint_attention_kwargs = joint_attention_kwargs
698
+ self._interrupt = False
699
+
700
+ # 2. Define call parameters
701
+ if prompt is not None and isinstance(prompt, str):
702
+ batch_size = 1
703
+ elif prompt is not None and isinstance(prompt, list):
704
+ batch_size = len(prompt)
705
+ else:
706
+ batch_size = prompt_embeds.shape[0]
707
+
708
+ device = self._execution_device
709
+ dtype = self.transformer.dtype
710
+
711
+ lora_scale = (
712
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
713
+ )
714
+ (
715
+ prompt_embeds,
716
+ pooled_prompt_embeds,
717
+ text_ids,
718
+ ) = self.encode_prompt(
719
+ prompt=prompt,
720
+ prompt_2=prompt_2,
721
+ prompt_embeds=prompt_embeds,
722
+ t5_prompt_embeds=t5_prompt_embeds,
723
+ pooled_prompt_embeds=pooled_prompt_embeds,
724
+ device=device,
725
+ num_images_per_prompt=num_images_per_prompt,
726
+ max_sequence_length=max_sequence_length,
727
+ lora_scale=lora_scale,
728
+ )
729
+
730
+ # 3. Prepare control image
731
+ num_channels_latents = self.transformer.config.in_channels // 4
732
+ if isinstance(self.controlnet, FluxControlNetModel):
733
+ control_image = self.prepare_image(
734
+ image=control_image,
735
+ width=width,
736
+ height=height,
737
+ batch_size=batch_size * num_images_per_prompt,
738
+ num_images_per_prompt=num_images_per_prompt,
739
+ device=device,
740
+ dtype=self.vae.dtype,
741
+ )
742
+ height, width = control_image.shape[-2:]
743
+
744
+ # vae encode
745
+ control_image = self.vae.encode(control_image).latent_dist.sample()
746
+ control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
747
+
748
+ # pack
749
+ height_control_image, width_control_image = control_image.shape[2:]
750
+ control_image = self._pack_latents(
751
+ control_image,
752
+ batch_size * num_images_per_prompt,
753
+ num_channels_latents,
754
+ height_control_image,
755
+ width_control_image,
756
+ )
757
+
758
+ # Here we ensure that `control_mode` has the same length as the control_image.
759
+ if control_mode is not None:
760
+ if not isinstance(control_mode, int):
761
+ raise ValueError(" For `FluxControlNet`, `control_mode` should be an `int` or `None`")
762
+ control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
763
+ control_mode = control_mode.view(-1, 1).expand(control_image.shape[0], 1)
764
+
765
+ elif isinstance(self.controlnet, FluxMultiControlNetModel):
766
+ control_images = []
767
+
768
+ for control_image_ in control_image:
769
+ control_image_ = self.prepare_image(
770
+ image=control_image_,
771
+ width=width,
772
+ height=height,
773
+ batch_size=batch_size * num_images_per_prompt,
774
+ num_images_per_prompt=num_images_per_prompt,
775
+ device=device,
776
+ dtype=self.vae.dtype,
777
+ )
778
+ height, width = control_image_.shape[-2:]
779
+
780
+ # vae encode
781
+ control_image_ = self.vae.encode(control_image_).latent_dist.sample()
782
+ control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
783
+
784
+ # pack
785
+ height_control_image, width_control_image = control_image_.shape[2:]
786
+ control_image_ = self._pack_latents(
787
+ control_image_,
788
+ batch_size * num_images_per_prompt,
789
+ num_channels_latents,
790
+ height_control_image,
791
+ width_control_image,
792
+ )
793
+
794
+ control_images.append(control_image_)
795
+
796
+ control_image = control_images
797
+
798
+ # Here we ensure that `control_mode` has the same length as the control_image.
799
+ if isinstance(control_mode, list) and len(control_mode) != len(control_image):
800
+ raise ValueError(
801
+ "For Multi-ControlNet, `control_mode` must be a list of the same "
802
+ + " length as the number of controlnets (control images) specified"
803
+ )
804
+ if not isinstance(control_mode, list):
805
+ control_mode = [control_mode] * len(control_image)
806
+ # set control mode
807
+ control_modes = []
808
+ for cmode in control_mode:
809
+ if cmode is None:
810
+ cmode = -1
811
+ control_mode = torch.tensor(cmode).expand(control_images[0].shape[0]).to(device, dtype=torch.long)
812
+ control_modes.append(control_mode)
813
+ control_mode = control_modes
814
+
815
+ # 4. Prepare latent variables
816
+ num_channels_latents = self.transformer.config.in_channels // 4
817
+ latents, latent_image_ids = self.prepare_latents(
818
+ batch_size * num_images_per_prompt,
819
+ num_channels_latents,
820
+ height,
821
+ width,
822
+ prompt_embeds.dtype,
823
+ device,
824
+ generator,
825
+ latents,
826
+ )
827
+
828
+ # 5. Prepare timesteps
829
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
830
+ image_seq_len = latents.shape[1]
831
+ mu = calculate_shift(
832
+ image_seq_len,
833
+ self.scheduler.config.base_image_seq_len,
834
+ self.scheduler.config.max_image_seq_len,
835
+ self.scheduler.config.base_shift,
836
+ self.scheduler.config.max_shift,
837
+ )
838
+ timesteps, num_inference_steps = retrieve_timesteps(
839
+ self.scheduler,
840
+ num_inference_steps,
841
+ device,
842
+ timesteps,
843
+ sigmas,
844
+ mu=mu,
845
+ )
846
+
847
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
848
+ self._num_timesteps = len(timesteps)
849
+
850
+ # 6. Denoising loop
851
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
852
+ for i, t in enumerate(timesteps):
853
+ if self.interrupt:
854
+ continue
855
+
856
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
857
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
858
+
859
+ if isinstance(self.controlnet, FluxMultiControlNetModel):
860
+ use_guidance = self.controlnet.nets[0].config.guidance_embeds
861
+ else:
862
+ use_guidance = self.controlnet.config.guidance_embeds
863
+
864
+ guidance = torch.tensor([guidance_scale], device=device) if use_guidance else None
865
+ guidance = guidance.expand(latents.shape[0]) if guidance is not None else None
866
+
867
+ # controlnet
868
+ controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
869
+ hidden_states=latents,
870
+ controlnet_cond=control_image,
871
+ controlnet_mode=control_mode,
872
+ conditioning_scale=controlnet_conditioning_scale,
873
+ timestep=timestep / 1000,
874
+ guidance=guidance,
875
+ pooled_projections=pooled_prompt_embeds,
876
+ encoder_hidden_states=prompt_embeds_control,
877
+ t5_encoder_hidden_states=t5_prompt_embeds,
878
+ txt_ids=text_ids,
879
+ img_ids=latent_image_ids,
880
+ joint_attention_kwargs=self.joint_attention_kwargs,
881
+ return_dict=False,
882
+ )
883
+
884
+ guidance = (
885
+ torch.tensor([guidance_scale], device=device) if self.transformer.config.guidance_embeds else None
886
+ )
887
+ guidance = guidance.expand(latents.shape[0]) if guidance is not None else None
888
+
889
+ noise_pred = self.transformer(
890
+ hidden_states=latents,
891
+ timestep=timestep / 1000,
892
+ guidance=guidance,
893
+ pooled_projections=pooled_prompt_embeds,
894
+ encoder_hidden_states=prompt_embeds,
895
+ t5_encoder_hidden_states=t5_prompt_embeds,
896
+ controlnet_block_samples=controlnet_block_samples,
897
+ controlnet_single_block_samples=controlnet_single_block_samples,
898
+ txt_ids=text_ids,
899
+ img_ids=latent_image_ids,
900
+ joint_attention_kwargs=self.joint_attention_kwargs,
901
+ return_dict=False,
902
+ )[0]
903
+
904
+ # compute the previous noisy sample x_t -> x_t-1
905
+ latents_dtype = latents.dtype
906
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
907
+
908
+ if latents.dtype != latents_dtype:
909
+ if torch.backends.mps.is_available():
910
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
911
+ latents = latents.to(latents_dtype)
912
+
913
+ if callback_on_step_end is not None:
914
+ callback_kwargs = {}
915
+ for k in callback_on_step_end_tensor_inputs:
916
+ callback_kwargs[k] = locals()[k]
917
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
918
+
919
+ latents = callback_outputs.pop("latents", latents)
920
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
921
+
922
+ # call the callback, if provided
923
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
924
+ progress_bar.update()
925
+
926
+ if XLA_AVAILABLE:
927
+ xm.mark_step()
928
+
929
+ if output_type == "latent":
930
+ image = latents
931
+
932
+ else:
933
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
934
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
935
+
936
+ image = self.vae.decode(latents, return_dict=False)[0]
937
+ image = self.image_processor.postprocess(image, output_type=output_type)
938
+
939
+ # Offload all models
940
+ self.maybe_free_model_hooks()
941
+
942
+ if not return_dict:
943
+ return (image,)
944
+
945
+ return FluxPipelineOutput(images=image)
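
The `__call__` above also accepts a `FluxMultiControlNetModel`, in which case `control_image`, `control_mode`, and `controlnet_conditioning_scale` become per-controlnet lists. A hedged sketch of that branch, assuming (as in upstream diffusers) that `FluxMultiControlNetModel` simply wraps a list of `FluxControlNetModel` instances; the second checkpoint id, the image paths, and the union `control_mode` index are illustrative placeholders only:

```py
import torch
from diffusers.utils import load_image
from flux.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
from flux.pipeline_flux_controlnet import FluxControlNetPipeline

controlnet = FluxMultiControlNetModel(
    [
        FluxControlNetModel.from_pretrained(
            "InstantX/FLUX.1-dev-controlnet-canny", torch_dtype=torch.bfloat16
        ),
        FluxControlNetModel.from_pretrained(
            "repo-id-of-a-union-controlnet", torch_dtype=torch.bfloat16
        ),
    ]
)
pipe = FluxControlNetPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", controlnet=controlnet, torch_dtype=torch.bfloat16
).to("cuda")

image = pipe(
    "A girl in city, 25 years old, cool, futuristic",
    control_image=[load_image("canny.png"), load_image("depth.png")],  # one per controlnet
    control_mode=[None, 2],            # plain controlnets take None; union models take an int mode
    controlnet_conditioning_scale=[0.6, 0.4],
    num_inference_steps=28,
    guidance_scale=3.5,
).images[0]
image.save("flux_multi_controlnet.png")
```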
flux/pipeline_flux_controlnet_img2img.py ADDED
@@ -0,0 +1,1002 @@
1
+ # Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ from transformers import (
21
+ CLIPTextModel,
22
+ CLIPTokenizer,
23
+ T5EncoderModel,
24
+ T5TokenizerFast,
25
+ )
26
+
27
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
28
+ from .lora.lora_pipeline import FluxLoraLoaderMixin
29
+ from diffusers.loaders import FromSingleFileMixin
30
+ from diffusers.models.autoencoders import AutoencoderKL
31
+ from .controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
32
+ from .transformer_flux import FluxTransformer2DModel
33
+ from .scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
34
+ from diffusers.utils import (
35
+ USE_PEFT_BACKEND,
36
+ is_torch_xla_available,
37
+ logging,
38
+ replace_example_docstring,
39
+ scale_lora_layers,
40
+ unscale_lora_layers,
41
+ )
42
+ from diffusers.utils.torch_utils import randn_tensor
43
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
44
+ from .pipeline_output import FluxPipelineOutput
45
+
46
+
47
+ if is_torch_xla_available():
48
+ import torch_xla.core.xla_model as xm
49
+
50
+ XLA_AVAILABLE = True
51
+ else:
52
+ XLA_AVAILABLE = False
53
+
54
+
55
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
56
+
57
+ EXAMPLE_DOC_STRING = """
58
+ Examples:
59
+ ```py
60
+ >>> import torch
61
+ >>> from diffusers.utils import load_image
62
+ >>> from diffusers import FluxControlNetPipeline
63
+ >>> from diffusers import FluxControlNetModel
64
+
65
+ >>> controlnet_model = "InstantX/FLUX.1-dev-controlnet-canny"
66
+ >>> controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
67
+ >>> pipe = FluxControlNetPipeline.from_pretrained(
68
+ ... base_model, controlnet=controlnet, torch_dtype=torch.bfloat16
69
+ ... )
70
+ >>> pipe.to("cuda")
71
+ >>> control_image = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg")
72
+ >>> prompt = "A girl in city, 25 years old, cool, futuristic"
73
+ >>> image = pipe(
74
+ ... prompt,
75
+ ... control_image=control_image,
76
+ ... controlnet_conditioning_scale=0.6,
77
+ ... num_inference_steps=28,
78
+ ... guidance_scale=3.5,
79
+ ... ).images[0]
80
+ >>> image.save("flux.png")
81
+ ```
82
+ """
83
+
84
+
85
+ # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
86
+ def calculate_shift(
87
+ image_seq_len,
88
+ base_seq_len: int = 256,
89
+ max_seq_len: int = 4096,
90
+ base_shift: float = 0.5,
91
+ max_shift: float = 1.16,
92
+ ):
93
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
94
+ b = base_shift - m * base_seq_len
95
+ mu = image_seq_len * m + b
96
+ return mu
97
+
98
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
99
+ def retrieve_latents(
100
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
101
+ ):
102
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
103
+ return encoder_output.latent_dist.sample(generator)
104
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
105
+ return encoder_output.latent_dist.mode()
106
+ elif hasattr(encoder_output, "latents"):
107
+ return encoder_output.latents
108
+ else:
109
+ raise AttributeError("Could not access latents of provided encoder_output")
110
+
111
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
112
+ def retrieve_timesteps(
113
+ scheduler,
114
+ num_inference_steps: Optional[int] = None,
115
+ device: Optional[Union[str, torch.device]] = None,
116
+ timesteps: Optional[List[int]] = None,
117
+ sigmas: Optional[List[float]] = None,
118
+ **kwargs,
119
+ ):
120
+ """
121
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
122
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
123
+
124
+ Args:
125
+ scheduler (`SchedulerMixin`):
126
+ The scheduler to get timesteps from.
127
+ num_inference_steps (`int`):
128
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
129
+ must be `None`.
130
+ device (`str` or `torch.device`, *optional*):
131
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
132
+ timesteps (`List[int]`, *optional*):
133
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
134
+ `num_inference_steps` and `sigmas` must be `None`.
135
+ sigmas (`List[float]`, *optional*):
136
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
137
+ `num_inference_steps` and `timesteps` must be `None`.
138
+
139
+ Returns:
140
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
141
+ second element is the number of inference steps.
142
+ """
143
+ if timesteps is not None and sigmas is not None:
144
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
145
+ if timesteps is not None:
146
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
147
+ if not accepts_timesteps:
148
+ raise ValueError(
149
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
150
+ f" timestep schedules. Please check whether you are using the correct scheduler."
151
+ )
152
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
153
+ timesteps = scheduler.timesteps
154
+ num_inference_steps = len(timesteps)
155
+ elif sigmas is not None:
156
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
157
+ if not accept_sigmas:
158
+ raise ValueError(
159
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
160
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
161
+ )
162
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
163
+ timesteps = scheduler.timesteps
164
+ num_inference_steps = len(timesteps)
165
+ else:
166
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
167
+ timesteps = scheduler.timesteps
168
+ return timesteps, num_inference_steps
169
+
170
+
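+ # Usage sketch (assumed values): pass at most one of `timesteps` or `sigmas`; extra kwargs such as `mu`
+ # are forwarded to `scheduler.set_timesteps`, e.g.
+ # timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=28, device=device, sigmas=sigmas, mu=mu)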
171
+ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
172
+ r"""
173
+ The Flux pipeline for text-to-image generation.
174
+
175
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
176
+
177
+ Args:
178
+ transformer ([`FluxTransformer2DModel`]):
179
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
180
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
181
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
182
+ vae ([`AutoencoderKL`]):
183
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
184
+ text_encoder ([`CLIPTextModel`]):
185
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
186
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
187
+ text_encoder_2 ([`T5EncoderModel`]):
188
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
189
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
190
+ tokenizer (`CLIPTokenizer`):
191
+ Tokenizer of class
192
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
193
+ tokenizer_2 (`T5TokenizerFast`):
194
+ Second Tokenizer of class
195
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
196
+ """
197
+
198
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
199
+ _optional_components = []
200
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
201
+
202
+ def __init__(
203
+ self,
204
+ scheduler: FlowMatchEulerDiscreteScheduler,
205
+ vae: AutoencoderKL,
206
+ text_encoder: CLIPTextModel,
207
+ tokenizer: CLIPTokenizer,
208
+ transformer: FluxTransformer2DModel,
209
+ controlnet: Union[
210
+ FluxControlNetModel, List[FluxControlNetModel], Tuple[FluxControlNetModel], FluxMultiControlNetModel
211
+ ],
212
+ text_encoder_2: T5EncoderModel | None = None,
213
+ tokenizer_2: T5TokenizerFast | None = None,
214
+ ):
215
+ super().__init__()
216
+
217
+ self.register_modules(
218
+ vae=vae,
219
+ text_encoder=text_encoder,
220
+ #text_encoder_2=text_encoder_2,
221
+ tokenizer=tokenizer,
222
+ #tokenizer_2=tokenizer_2,
223
+ transformer=transformer,
224
+ scheduler=scheduler,
225
+ controlnet=controlnet,
226
+ )
227
+ self.vae_scale_factor = (
228
+ 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
229
+ )
230
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
231
+ self.tokenizer_max_length = (
232
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
233
+ )
234
+ self.default_sample_size = 64
235
+
236
+ def _get_t5_prompt_embeds(
237
+ self,
238
+ prompt: Union[str, List[str]] = None,
239
+ num_images_per_prompt: int = 1,
240
+ max_sequence_length: int = 512,
241
+ device: Optional[torch.device] = None,
242
+ dtype: Optional[torch.dtype] = None,
243
+ ):
244
+ device = device or self._execution_device
245
+ dtype = dtype or self.text_encoder.dtype
246
+
247
+ prompt = [prompt] if isinstance(prompt, str) else prompt
248
+ batch_size = len(prompt)
249
+
250
+ text_inputs = self.tokenizer_2(
251
+ prompt,
252
+ padding="max_length",
253
+ max_length=max_sequence_length,
254
+ truncation=True,
255
+ return_length=False,
256
+ return_overflowing_tokens=False,
257
+ return_tensors="pt",
258
+ )
259
+ text_input_ids = text_inputs.input_ids
260
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
261
+
262
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
263
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
264
+ logger.warning(
265
+ "The following part of your input was truncated because `max_sequence_length` is set to "
266
+ f" {max_sequence_length} tokens: {removed_text}"
267
+ )
268
+
269
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
270
+
271
+ dtype = self.text_encoder_2.dtype
272
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
273
+
274
+ _, seq_len, _ = prompt_embeds.shape
275
+
276
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
277
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
278
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
279
+
280
+ return prompt_embeds
281
+
282
+ def _get_clip_prompt_embeds(
283
+ self,
284
+ prompt: Union[str, List[str]],
285
+ num_images_per_prompt: int = 1,
286
+ device: Optional[torch.device] = None,
287
+ ):
288
+ device = device or self._execution_device
289
+
290
+ prompt = [prompt] if isinstance(prompt, str) else prompt
291
+ batch_size = len(prompt)
292
+
293
+ text_inputs = self.tokenizer(
294
+ prompt,
295
+ padding="max_length",
296
+ max_length=self.tokenizer_max_length,
297
+ truncation=True,
298
+ return_overflowing_tokens=False,
299
+ return_length=False,
300
+ return_tensors="pt",
301
+ )
302
+
303
+ text_input_ids = text_inputs.input_ids
304
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
305
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
306
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
307
+ logger.warning(
308
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
309
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
310
+ )
311
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
312
+
313
+ # Use pooled output of CLIPTextModel
314
+ prompt_embeds = prompt_embeds.pooler_output
315
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
316
+
317
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
318
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
319
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
320
+
321
+ return prompt_embeds
322
+
323
+ def encode_prompt(
324
+ self,
325
+ prompt: Union[str, List[str]],
326
+ prompt_2: Union[str, List[str]],
327
+ device: Optional[torch.device] = None,
328
+ num_images_per_prompt: int = 1,
329
+ prompt_embeds: Optional[torch.FloatTensor] = None,
330
+ t5_prompt_embeds: Optional[torch.FloatTensor] = None,
331
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
332
+ max_sequence_length: int = 512,
333
+ lora_scale: Optional[float] = None,
334
+ ):
335
+ r"""
336
+
337
+ Args:
338
+ prompt (`str` or `List[str]`, *optional*):
339
+ prompt to be encoded
340
+ prompt_2 (`str` or `List[str]`, *optional*):
341
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
342
+ used in all text-encoders
343
+ device: (`torch.device`):
344
+ torch device
345
+ num_images_per_prompt (`int`):
346
+ number of images that should be generated per prompt
347
+ prompt_embeds (`torch.FloatTensor`, *optional*):
348
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
349
+ provided, text embeddings will be generated from `prompt` input argument.
350
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
351
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
352
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
353
+ t5_prompt_embeds (`torch.FloatTensor`, *optional*):
354
+ Pre-generated additional text embeddings (e.g. from the T5 encoder). When provided, the returned
355
+ `text_ids` are sized to cover both `prompt_embeds` and these embeddings.
356
+ lora_scale (`float`, *optional*):
357
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
358
+ """
359
+ device = device or self._execution_device
360
+
361
+ # set lora scale so that monkey patched LoRA
362
+ # function of text encoder can correctly access it
363
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
364
+ self._lora_scale = lora_scale
365
+
366
+ # dynamically adjust the LoRA scale
367
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
368
+ scale_lora_layers(self.text_encoder, lora_scale)
369
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
370
+ scale_lora_layers(self.text_encoder_2, lora_scale)
371
+
372
+ prompt = [prompt] if isinstance(prompt, str) else prompt
373
+
374
+ if prompt_embeds is None:
375
+ prompt_2 = prompt_2 or prompt
376
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
377
+
378
+ # We only use the pooled prompt output from the CLIPTextModel
379
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
380
+ prompt=prompt,
381
+ device=device,
382
+ num_images_per_prompt=num_images_per_prompt,
383
+ )
384
+ prompt_embeds = self._get_t5_prompt_embeds(
385
+ prompt=prompt_2,
386
+ num_images_per_prompt=num_images_per_prompt,
387
+ max_sequence_length=max_sequence_length,
388
+ device=device,
389
+ )
390
+
391
+ if self.text_encoder is not None:
392
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
393
+ # Retrieve the original scale by scaling back the LoRA layers
394
+ unscale_lora_layers(self.text_encoder, lora_scale)
395
+
396
+ #if self.text_encoder_2 is not None:
397
+ # if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
398
+ # # Retrieve the original scale by scaling back the LoRA layers
399
+ # unscale_lora_layers(self.text_encoder_2, lora_scale)
400
+
401
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
402
+ if t5_prompt_embeds is not None:
403
+ text_ids = torch.zeros(prompt_embeds.shape[1] + t5_prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
404
+ else:
405
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
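+ # `text_ids` are all-zero (sequence_length, 3) position ids covering the text tokens (and the extra
+ # embeddings when `t5_prompt_embeds` is supplied); only the image tokens carry non-trivial 2D positions.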
406
+
407
+ return prompt_embeds, pooled_prompt_embeds, text_ids
408
+
409
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
410
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
411
+ if isinstance(generator, list):
412
+ image_latents = [
413
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
414
+ for i in range(image.shape[0])
415
+ ]
416
+ image_latents = torch.cat(image_latents, dim=0)
417
+ else:
418
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
419
+
420
+ image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
421
+
422
+ return image_latents
423
+
424
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
425
+ def get_timesteps(self, num_inference_steps, strength, device):
426
+ # get the original timestep using init_timestep
427
+ init_timestep = min(num_inference_steps * strength, num_inference_steps)
428
+
429
+ t_start = int(max(num_inference_steps - init_timestep, 0))
430
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
431
+ if hasattr(self.scheduler, "set_begin_index"):
432
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
433
+
434
+ return timesteps, num_inference_steps - t_start
435
+
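+ # Worked example: with num_inference_steps=28 and strength=0.6, init_timestep = 16.8 and t_start = 11,
+ # so the final 17 timesteps of the schedule are used and 17 denoising steps are reported back.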
436
+ def check_inputs(
437
+ self,
438
+ prompt,
439
+ prompt_2,
440
+ height,
441
+ width,
442
+ prompt_embeds=None,
443
+ pooled_prompt_embeds=None,
444
+ callback_on_step_end_tensor_inputs=None,
445
+ max_sequence_length=None,
446
+ ):
447
+ if height % 8 != 0 or width % 8 != 0:
448
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
449
+
450
+ if callback_on_step_end_tensor_inputs is not None and not all(
451
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
452
+ ):
453
+ raise ValueError(
454
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
455
+ )
456
+
457
+ if prompt is not None and prompt_embeds is not None:
458
+ raise ValueError(
459
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
460
+ " only forward one of the two."
461
+ )
462
+ elif prompt_2 is not None and prompt_embeds is not None:
463
+ raise ValueError(
464
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
465
+ " only forward one of the two."
466
+ )
467
+ elif prompt is None and prompt_embeds is None:
468
+ raise ValueError(
469
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
470
+ )
471
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
472
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
473
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
474
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
475
+
476
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
477
+ raise ValueError(
478
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
479
+ )
480
+
481
+ if max_sequence_length is not None and max_sequence_length > 512:
482
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
483
+
484
+ @staticmethod
485
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
486
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
487
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
488
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
489
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
490
+
491
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
492
+
493
+ latent_image_ids = latent_image_ids.reshape(
494
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
495
+ )
496
+
497
+ return latent_image_ids.to(device=device, dtype=dtype)
498
+
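+ # Shape note: for packed latents of size (height, width) this returns ((height // 2) * (width // 2), 3)
+ # ids, where channel 1 holds the row index and channel 2 the column index of each 2x2 patch.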
499
+ @staticmethod
500
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
501
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
502
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
503
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
504
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
505
+
506
+ return latents
507
+
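+ # Shape example (assumed sizes): (B, 16, 128, 128) latents become (B, 64 * 64, 64) packed tokens;
+ # `_unpack_latents` below inverts this transformation.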
508
+ @staticmethod
509
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
510
+ def _unpack_latents(latents, height, width, vae_scale_factor):
511
+ batch_size, num_patches, channels = latents.shape
512
+
513
+ height = height // vae_scale_factor
514
+ width = width // vae_scale_factor
515
+
516
+ latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
517
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
518
+
519
+ latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
520
+
521
+ return latents
522
+
523
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents
524
+ def prepare_latents(
525
+ self,
526
+ image,
527
+ timestep,
528
+ batch_size,
529
+ num_channels_latents,
530
+ height,
531
+ width,
532
+ dtype,
533
+ device,
534
+ generator,
535
+ latents=None,
536
+ ):
537
+ height = 2 * (int(height) // self.vae_scale_factor)
538
+ width = 2 * (int(width) // self.vae_scale_factor)
539
+
540
+ shape = (batch_size, num_channels_latents, height, width)
541
+
542
+ if latents is not None:
543
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
544
+ return latents.to(device=device, dtype=dtype), latent_image_ids
545
+
546
+ if isinstance(generator, list) and len(generator) != batch_size:
547
+ raise ValueError(
548
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
549
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
550
+ )
551
+ image = image.to(device=device, dtype=dtype)
552
+ image_latents = self._encode_vae_image(image=image, generator=generator)
553
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
554
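+ # NOTE: the encoded image latents are only noised into the start latents when the initial timestep
+ # equals 28; for any other value the denoising loop starts from pure noise and `image_latents` is unused.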
+ if timestep == 28:
555
+ latents = self.scheduler.scale_noise(image_latents, timestep, noise)
556
+ else:
557
+ latents = noise
558
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
559
+
560
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
561
+
562
+ return latents, latent_image_ids
563
+
564
+ # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image
565
+ def prepare_image(
566
+ self,
567
+ image,
568
+ width,
569
+ height,
570
+ batch_size,
571
+ num_images_per_prompt,
572
+ device,
573
+ dtype,
574
+ do_classifier_free_guidance=False,
575
+ guess_mode=False,
576
+ ):
577
+ if isinstance(image, torch.Tensor):
578
+ pass
579
+ else:
580
+ image = self.image_processor.preprocess(image, height=height, width=width)
581
+
582
+ image_batch_size = image.shape[0]
583
+
584
+ if image_batch_size == 1:
585
+ repeat_by = batch_size
586
+ else:
587
+ # image batch size is the same as prompt batch size
588
+ repeat_by = num_images_per_prompt
589
+
590
+ image = image.repeat_interleave(repeat_by, dim=0)
591
+
592
+ image = image.to(device=device, dtype=dtype)
593
+
594
+ if do_classifier_free_guidance and not guess_mode:
595
+ image = torch.cat([image] * 2)
596
+
597
+ return image
598
+
599
+ @property
600
+ def guidance_scale(self):
601
+ return self._guidance_scale
602
+
603
+ @property
604
+ def joint_attention_kwargs(self):
605
+ return self._joint_attention_kwargs
606
+
607
+ @property
608
+ def num_timesteps(self):
609
+ return self._num_timesteps
610
+
611
+ @property
612
+ def interrupt(self):
613
+ return self._interrupt
614
+
615
+ @torch.no_grad()
616
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
617
+ def __call__(
618
+ self,
619
+ prompt: Union[str, List[str]] = None,
620
+ prompt_2: Optional[Union[str, List[str]]] = None,
621
+ image: PipelineImageInput = None,
622
+ height: Optional[int] = None,
623
+ width: Optional[int] = None,
624
+ strength: float = 0.6,
625
+ num_inference_steps: int = 28,
626
+ timesteps: List[int] = None,
627
+ guidance_scale: float = 7.0,
628
+ control_image: PipelineImageInput = None,
629
+ control_mode: Optional[Union[int, List[int]]] = None,
630
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
631
+ num_images_per_prompt: Optional[int] = 1,
632
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
633
+ latents: Optional[torch.FloatTensor] = None,
634
+ prompt_embeds: Optional[torch.FloatTensor] = None,
635
+ t5_prompt_embeds: Optional[torch.FloatTensor] = None,
636
+ prompt_embeds_control: Optional[torch.FloatTensor] = None,
637
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
638
+ output_type: Optional[str] = "pil",
639
+ return_dict: bool = True,
640
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
641
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
642
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
643
+ max_sequence_length: int = 512,
644
+ ):
645
+ r"""
646
+ Function invoked when calling the pipeline for generation.
647
+
648
+ Args:
649
+ prompt (`str` or `List[str]`, *optional*):
650
+ The prompt or prompts to guide the image generation. If not defined, `prompt_embeds` must be passed
651
+ instead.
652
+ prompt_2 (`str` or `List[str]`, *optional*):
653
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
654
+ will be used instead.
655
+ height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
657
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
658
+ width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
658
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
659
+ num_inference_steps (`int`, *optional*, defaults to 28):
660
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
661
+ expense of slower inference.
662
+ timesteps (`List[int]`, *optional*):
663
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
664
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
665
+ passed will be used. Must be in descending order.
666
+ guidance_scale (`float`, *optional*, defaults to 7.0):
667
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
668
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
669
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
670
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
671
+ usually at the expense of lower image quality.
672
+ control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
673
+ `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
674
+ The ControlNet input condition to provide guidance to the `transformer` for generation. If the type is
675
+ specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
676
+ as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
677
+ width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
678
+ images must be passed as a list such that each element of the list can be correctly batched for input
679
+ to a single ControlNet.
680
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
681
+ The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
682
+ to the residual in the original `transformer`. If multiple ControlNets are specified in `init`, you can set
683
+ the corresponding scale as a list.
684
+ control_mode (`int` or `List[int]`, *optional*, defaults to None):
685
+ The control mode when applying ControlNet-Union.
686
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
687
+ The number of images to generate per prompt.
688
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
689
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
690
+ to make generation deterministic.
691
+ latents (`torch.FloatTensor`, *optional*):
692
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
693
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
694
+ tensor will be generated by sampling using the supplied random `generator`.
695
+ prompt_embeds (`torch.FloatTensor`, *optional*):
696
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
697
+ provided, text embeddings will be generated from `prompt` input argument.
698
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
699
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
700
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
701
+ output_type (`str`, *optional*, defaults to `"pil"`):
702
+ The output format of the generated image. Choose between
703
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
704
+ return_dict (`bool`, *optional*, defaults to `True`):
705
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
706
+ joint_attention_kwargs (`dict`, *optional*):
707
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
708
+ `self.processor` in
709
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
710
+ callback_on_step_end (`Callable`, *optional*):
711
+ A function called at the end of each denoising step during inference. The function is called
712
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
713
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
714
+ `callback_on_step_end_tensor_inputs`.
715
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
716
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
717
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
718
+ `._callback_tensor_inputs` attribute of your pipeline class.
719
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
720
+
721
+ Examples:
722
+
723
+ Returns:
724
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
725
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
726
+ images.
727
+ """
728
+
729
+ height = height or self.default_sample_size * self.vae_scale_factor
730
+ width = width or self.default_sample_size * self.vae_scale_factor
731
+
732
+ # 1. Check inputs. Raise error if not correct
733
+ self.check_inputs(
734
+ prompt,
735
+ prompt_2,
736
+ height,
737
+ width,
738
+ prompt_embeds=prompt_embeds,
739
+ pooled_prompt_embeds=pooled_prompt_embeds,
740
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
741
+ max_sequence_length=max_sequence_length,
742
+ )
743
+
744
+ self._guidance_scale = guidance_scale
745
+ self._joint_attention_kwargs = joint_attention_kwargs
746
+ self._interrupt = False
747
+
748
+ # 2. Preprocess image
749
+ init_image = self.image_processor.preprocess(image, height=height, width=width)
750
+ init_image = init_image.to(dtype=torch.float32)
751
+
752
+ # 3. Define call parameters
753
+ if prompt is not None and isinstance(prompt, str):
754
+ batch_size = 1
755
+ elif prompt is not None and isinstance(prompt, list):
756
+ batch_size = len(prompt)
757
+ else:
758
+ batch_size = prompt_embeds.shape[0]
759
+
760
+ device = self._execution_device
761
+ dtype = self.transformer.dtype
762
+
763
+ lora_scale = (
764
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
765
+ )
766
+ (
767
+ prompt_embeds,
768
+ pooled_prompt_embeds,
769
+ text_ids,
770
+ ) = self.encode_prompt(
771
+ prompt=prompt,
772
+ prompt_2=prompt_2,
773
+ prompt_embeds=prompt_embeds,
774
+ t5_prompt_embeds=t5_prompt_embeds,
775
+ pooled_prompt_embeds=pooled_prompt_embeds,
776
+ device=device,
777
+ num_images_per_prompt=num_images_per_prompt,
778
+ max_sequence_length=max_sequence_length,
779
+ lora_scale=lora_scale,
780
+ )
781
+
782
+ # 4. Prepare control image
783
+ num_channels_latents = self.transformer.config.in_channels // 4
784
+ if isinstance(self.controlnet, FluxControlNetModel):
785
+ control_image = self.prepare_image(
786
+ image=control_image,
787
+ width=width,
788
+ height=height,
789
+ batch_size=batch_size * num_images_per_prompt,
790
+ num_images_per_prompt=num_images_per_prompt,
791
+ device=device,
792
+ dtype=self.vae.dtype,
793
+ )
794
+ height, width = control_image.shape[-2:]
795
+
796
+ # vae encode
797
+ control_image = self.vae.encode(control_image).latent_dist.sample()
798
+ control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
799
+
800
+ # pack
801
+ height_control_image, width_control_image = control_image.shape[2:]
802
+ control_image = self._pack_latents(
803
+ control_image,
804
+ batch_size * num_images_per_prompt,
805
+ num_channels_latents,
806
+ height_control_image,
807
+ width_control_image,
808
+ )
809
+
810
+ # Here we ensure that `control_mode` has the same length as the control_image.
811
+ if control_mode is not None:
812
+ if not isinstance(control_mode, int):
813
+ raise ValueError(" For `FluxControlNet`, `control_mode` should be an `int` or `None`")
814
+ control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
815
+ control_mode = control_mode.view(-1, 1).expand(control_image.shape[0], 1)
816
+
817
+ elif isinstance(self.controlnet, FluxMultiControlNetModel):
818
+ control_images = []
819
+
820
+ for control_image_ in control_image:
821
+ control_image_ = self.prepare_image(
822
+ image=control_image_,
823
+ width=width,
824
+ height=height,
825
+ batch_size=batch_size * num_images_per_prompt,
826
+ num_images_per_prompt=num_images_per_prompt,
827
+ device=device,
828
+ dtype=self.vae.dtype,
829
+ )
830
+ height, width = control_image_.shape[-2:]
831
+
832
+ # vae encode
833
+ control_image_ = self.vae.encode(control_image_).latent_dist.sample()
834
+ control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
835
+
836
+ # pack
837
+ height_control_image, width_control_image = control_image_.shape[2:]
838
+ control_image_ = self._pack_latents(
839
+ control_image_,
840
+ batch_size * num_images_per_prompt,
841
+ num_channels_latents,
842
+ height_control_image,
843
+ width_control_image,
844
+ )
845
+
846
+ control_images.append(control_image_)
847
+
848
+ control_image = control_images
849
+
850
+ # Here we ensure that `control_mode` has the same length as the control_image.
851
+ if isinstance(control_mode, list) and len(control_mode) != len(control_image):
852
+ raise ValueError(
853
+ "For Multi-ControlNet, `control_mode` must be a list of the same "
854
+ + " length as the number of controlnets (control images) specified"
855
+ )
856
+ if not isinstance(control_mode, list):
857
+ control_mode = [control_mode] * len(control_image)
858
+ # set control mode
859
+ control_modes = []
860
+ for cmode in control_mode:
861
+ if cmode is None:
862
+ cmode = -1
863
+ control_mode = torch.tensor(cmode).expand(control_images[0].shape[0]).to(device, dtype=torch.long)
864
+ control_modes.append(control_mode)
865
+ control_mode = control_modes
866
+
867
+ # 5. Prepare timesteps
868
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
869
+ #image_seq_len = latents.shape[1]
870
+ image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor)
871
+ mu = calculate_shift(
872
+ image_seq_len,
873
+ self.scheduler.config.base_image_seq_len,
874
+ self.scheduler.config.max_image_seq_len,
875
+ self.scheduler.config.base_shift,
876
+ self.scheduler.config.max_shift,
877
+ )
878
+ timesteps, num_inference_steps = retrieve_timesteps(
879
+ self.scheduler,
880
+ num_inference_steps,
881
+ device,
882
+ timesteps,
883
+ sigmas,
884
+ mu=mu,
885
+ )
886
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
887
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
888
+
889
+ # 6. Prepare latent variables
890
+ num_channels_latents = self.transformer.config.in_channels // 4
891
+ latents, latent_image_ids = self.prepare_latents(
892
+ init_image,
893
+ latent_timestep,
894
+ batch_size * num_images_per_prompt,
895
+ num_channels_latents,
896
+ height,
897
+ width,
898
+ prompt_embeds.dtype,
899
+ device,
900
+ generator,
901
+ latents,
902
+ )
903
+
904
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
905
+ self._num_timesteps = len(timesteps)
906
+
907
+ # 7. Denoising loop
908
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
909
+ for i, t in enumerate(timesteps):
910
+ if self.interrupt:
911
+ continue
912
+
913
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
914
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
915
+
916
+ if isinstance(self.controlnet, FluxMultiControlNetModel):
917
+ use_guidance = self.controlnet.nets[0].config.guidance_embeds
918
+ else:
919
+ use_guidance = self.controlnet.config.guidance_embeds
920
+
921
+ guidance = torch.tensor([guidance_scale], device=device) if use_guidance else None
922
+ guidance = guidance.expand(latents.shape[0]) if guidance is not None else None
923
+
924
+ # controlnet
925
+ controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
926
+ hidden_states=latents,
927
+ controlnet_cond=control_image,
928
+ controlnet_mode=control_mode,
929
+ conditioning_scale=controlnet_conditioning_scale,
930
+ timestep=timestep / 1000,
931
+ guidance=guidance,
932
+ pooled_projections=pooled_prompt_embeds,
933
+ encoder_hidden_states=prompt_embeds_control,
934
+ t5_encoder_hidden_states=t5_prompt_embeds,
935
+ txt_ids=text_ids,
936
+ img_ids=latent_image_ids,
937
+ joint_attention_kwargs=self.joint_attention_kwargs,
938
+ return_dict=False,
939
+ )
940
+
941
+ guidance = (
942
+ torch.tensor([guidance_scale], device=device) if self.transformer.config.guidance_embeds else None
943
+ )
944
+ guidance = guidance.expand(latents.shape[0]) if guidance is not None else None
945
+
946
+ noise_pred = self.transformer(
947
+ hidden_states=latents,
948
+ timestep=timestep / 1000,
949
+ guidance=guidance,
950
+ pooled_projections=pooled_prompt_embeds,
951
+ encoder_hidden_states=prompt_embeds,
952
+ t5_encoder_hidden_states=t5_prompt_embeds,
953
+ controlnet_block_samples=controlnet_block_samples,
954
+ controlnet_single_block_samples=controlnet_single_block_samples,
955
+ txt_ids=text_ids,
956
+ img_ids=latent_image_ids,
957
+ joint_attention_kwargs=self.joint_attention_kwargs,
958
+ return_dict=False,
959
+ )[0]
960
+
961
+ # compute the previous noisy sample x_t -> x_t-1
962
+ latents_dtype = latents.dtype
963
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
964
+
965
+ if latents.dtype != latents_dtype:
966
+ if torch.backends.mps.is_available():
967
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
968
+ latents = latents.to(latents_dtype)
969
+
970
+ if callback_on_step_end is not None:
971
+ callback_kwargs = {}
972
+ for k in callback_on_step_end_tensor_inputs:
973
+ callback_kwargs[k] = locals()[k]
974
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
975
+
976
+ latents = callback_outputs.pop("latents", latents)
977
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
978
+
979
+ # call the callback, if provided
980
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
981
+ progress_bar.update()
982
+
983
+ if XLA_AVAILABLE:
984
+ xm.mark_step()
985
+
986
+ if output_type == "latent":
987
+ image = latents
988
+
989
+ else:
990
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
991
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
992
+
993
+ image = self.vae.decode(latents, return_dict=False)[0]
994
+ image = self.image_processor.postprocess(image, output_type=output_type)
995
+
996
+ # Offload all models
997
+ self.maybe_free_model_hooks()
998
+
999
+ if not return_dict:
1000
+ return (image,)
1001
+
1002
+ return FluxPipelineOutput(images=image)
flux/pipeline_flux_controlnet_inpainting.py ADDED
@@ -0,0 +1,1199 @@
1
+ import inspect
2
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
3
+
4
+ import numpy as np
5
+ import PIL
6
+ import torch
7
+ from transformers import (
8
+ CLIPTextModel,
9
+ CLIPTokenizer,
10
+ T5EncoderModel,
11
+ T5TokenizerFast,
12
+ )
13
+
14
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
15
+ from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
16
+ from diffusers.models.autoencoders import AutoencoderKL
17
+ from .controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
18
+ from .transformer_flux import FluxTransformer2DModel
19
+ from .scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
20
+ from diffusers.utils import (
21
+ USE_PEFT_BACKEND,
22
+ is_torch_xla_available,
23
+ logging,
24
+ replace_example_docstring,
25
+ scale_lora_layers,
26
+ unscale_lora_layers,
27
+ )
28
+ from diffusers.utils.torch_utils import randn_tensor
29
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
30
+ from .pipeline_output import FluxPipelineOutput
31
+
32
+
33
+ if is_torch_xla_available():
34
+ import torch_xla.core.xla_model as xm
35
+
36
+ XLA_AVAILABLE = True
37
+ else:
38
+ XLA_AVAILABLE = False
39
+
40
+ logger = logging.get_logger(__name__)
41
+
42
+ EXAMPLE_DOC_STRING = """
43
+ Examples:
44
+ ```py
45
+ >>> import torch
46
+ >>> from diffusers import FluxControlNetInpaintPipeline
47
+ >>> from diffusers.models import FluxControlNetModel
48
+ >>> from diffusers.utils import load_image
49
+
50
+ >>> controlnet = FluxControlNetModel.from_pretrained(
51
+ ... "InstantX/FLUX.1-dev-controlnet-canny", torch_dtype=torch.float16
52
+ ... )
53
+ >>> pipe = FluxControlNetInpaintPipeline.from_pretrained(
54
+ ... "black-forest-labs/FLUX.1-schnell", controlnet=controlnet, torch_dtype=torch.float16
55
+ ... )
56
+ >>> pipe.to("cuda")
57
+
58
+ >>> control_image = load_image(
59
+ ... "https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Canny-alpha/resolve/main/canny.jpg"
60
+ ... )
61
+ >>> init_image = load_image(
62
+ ... "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
63
+ ... )
64
+ >>> mask_image = load_image(
65
+ ... "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
66
+ ... )
67
+
68
+ >>> prompt = "A girl holding a sign that says InstantX"
69
+ >>> image = pipe(
70
+ ... prompt,
71
+ ... image=init_image,
72
+ ... mask_image=mask_image,
73
+ ... control_image=control_image,
74
+ ... control_guidance_start=0.2,
75
+ ... control_guidance_end=0.8,
76
+ ... controlnet_conditioning_scale=0.7,
77
+ ... strength=0.7,
78
+ ... num_inference_steps=28,
79
+ ... guidance_scale=3.5,
80
+ ... ).images[0]
81
+ >>> image.save("flux_controlnet_inpaint.png")
82
+ ```
83
+ """
84
+
85
+
86
+ # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
87
+ def calculate_shift(
88
+ image_seq_len,
89
+ base_seq_len: int = 256,
90
+ max_seq_len: int = 4096,
91
+ base_shift: float = 0.5,
92
+ max_shift: float = 1.16,
93
+ ):
94
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
95
+ b = base_shift - m * base_seq_len
96
+ mu = image_seq_len * m + b
97
+ return mu
98
+
99
+
100
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
101
+ def retrieve_latents(
102
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
103
+ ):
104
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
105
+ return encoder_output.latent_dist.sample(generator)
106
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
107
+ return encoder_output.latent_dist.mode()
108
+ elif hasattr(encoder_output, "latents"):
109
+ return encoder_output.latents
110
+ else:
111
+ raise AttributeError("Could not access latents of provided encoder_output")
112
+
113
+
114
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
115
+ def retrieve_timesteps(
116
+ scheduler,
117
+ num_inference_steps: Optional[int] = None,
118
+ device: Optional[Union[str, torch.device]] = None,
119
+ timesteps: Optional[List[int]] = None,
120
+ sigmas: Optional[List[float]] = None,
121
+ **kwargs,
122
+ ):
123
+ r"""
124
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
125
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
126
+
127
+ Args:
128
+ scheduler (`SchedulerMixin`):
129
+ The scheduler to get timesteps from.
130
+ num_inference_steps (`int`):
131
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
132
+ must be `None`.
133
+ device (`str` or `torch.device`, *optional*):
134
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
135
+ timesteps (`List[int]`, *optional*):
136
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
137
+ `num_inference_steps` and `sigmas` must be `None`.
138
+ sigmas (`List[float]`, *optional*):
139
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
140
+ `num_inference_steps` and `timesteps` must be `None`.
141
+
142
+ Returns:
143
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
144
+ second element is the number of inference steps.
145
+ """
146
+ if timesteps is not None and sigmas is not None:
147
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
148
+ if timesteps is not None:
149
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
150
+ if not accepts_timesteps:
151
+ raise ValueError(
152
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
153
+ f" timestep schedules. Please check whether you are using the correct scheduler."
154
+ )
155
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
156
+ timesteps = scheduler.timesteps
157
+ num_inference_steps = len(timesteps)
158
+ elif sigmas is not None:
159
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
160
+ if not accept_sigmas:
161
+ raise ValueError(
162
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
163
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
164
+ )
165
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
166
+ timesteps = scheduler.timesteps
167
+ num_inference_steps = len(timesteps)
168
+ else:
169
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
170
+ timesteps = scheduler.timesteps
171
+ return timesteps, num_inference_steps
172
+
173
+
174
+ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
175
+ r"""
176
+ The Flux controlnet pipeline for inpainting.
177
+
178
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
179
+
180
+ Args:
181
+ transformer ([`FluxTransformer2DModel`]):
182
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
183
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
184
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
185
+ vae ([`AutoencoderKL`]):
186
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
187
+ text_encoder ([`CLIPTextModel`]):
188
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
189
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
190
+ text_encoder_2 ([`T5EncoderModel`]):
191
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
192
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
193
+ tokenizer (`CLIPTokenizer`):
194
+ Tokenizer of class
195
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
196
+ tokenizer_2 (`T5TokenizerFast`):
197
+ Second Tokenizer of class
198
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
199
+ """
200
+
201
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
202
+ _optional_components = []
203
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
204
+
205
+ def __init__(
206
+ self,
207
+ scheduler: FlowMatchEulerDiscreteScheduler,
208
+ vae: AutoencoderKL,
209
+ text_encoder: CLIPTextModel,
210
+ tokenizer: CLIPTokenizer,
211
+ transformer: FluxTransformer2DModel,
212
+ controlnet: Union[
213
+ FluxControlNetModel, List[FluxControlNetModel], Tuple[FluxControlNetModel], FluxMultiControlNetModel
214
+ ],
215
+ text_encoder_2: T5EncoderModel | None = None,
216
+ tokenizer_2: T5TokenizerFast | None = None,
217
+ ):
218
+ super().__init__()
219
+ if isinstance(controlnet, (list, tuple)):
220
+ controlnet = FluxMultiControlNetModel(controlnet)
221
+
222
+ self.register_modules(
223
+ scheduler=scheduler,
224
+ vae=vae,
225
+ text_encoder=text_encoder,
226
+ tokenizer=tokenizer,
227
+ text_encoder_2=text_encoder_2,
228
+ tokenizer_2=tokenizer_2,
229
+ transformer=transformer,
230
+ controlnet=controlnet,
231
+ )
232
+
233
+ self.vae_scale_factor = (
234
+ 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
235
+ )
236
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
237
+ self.mask_processor = VaeImageProcessor(
238
+ vae_scale_factor=self.vae_scale_factor,
239
+ vae_latent_channels=self.vae.config.latent_channels,
240
+ do_normalize=False,
241
+ do_binarize=True,
242
+ do_convert_grayscale=True,
243
+ )
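+ # The mask processor yields single-channel, binarized masks in [0, 1] and skips the [-1, 1]
+ # normalization applied to RGB inputs.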
244
+ self.tokenizer_max_length = (
245
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
246
+ )
247
+ self.default_sample_size = 64
248
+
249
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
250
+ def _get_t5_prompt_embeds(
251
+ self,
252
+ prompt: Union[str, List[str]] = None,
253
+ num_images_per_prompt: int = 1,
254
+ max_sequence_length: int = 512,
255
+ device: Optional[torch.device] = None,
256
+ dtype: Optional[torch.dtype] = None,
257
+ ):
258
+ device = device or self._execution_device
259
+ dtype = dtype or self.text_encoder.dtype
260
+
261
+ prompt = [prompt] if isinstance(prompt, str) else prompt
262
+ batch_size = len(prompt)
263
+
264
+ if isinstance(self, TextualInversionLoaderMixin):
265
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
266
+
267
+ text_inputs = self.tokenizer_2(
268
+ prompt,
269
+ padding="max_length",
270
+ max_length=max_sequence_length,
271
+ truncation=True,
272
+ return_length=False,
273
+ return_overflowing_tokens=False,
274
+ return_tensors="pt",
275
+ )
276
+ text_input_ids = text_inputs.input_ids
277
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
278
+
279
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
280
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
281
+ logger.warning(
282
+ "The following part of your input was truncated because `max_sequence_length` is set to "
283
+ f" {max_sequence_length} tokens: {removed_text}"
284
+ )
285
+
286
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
287
+
288
+ dtype = self.text_encoder_2.dtype
289
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
290
+
291
+ _, seq_len, _ = prompt_embeds.shape
292
+
293
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
294
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
295
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
296
+
297
+ return prompt_embeds
298
+
299
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
300
+ def _get_clip_prompt_embeds(
301
+ self,
302
+ prompt: Union[str, List[str]],
303
+ num_images_per_prompt: int = 1,
304
+ device: Optional[torch.device] = None,
305
+ ):
306
+ device = device or self._execution_device
307
+
308
+ prompt = [prompt] if isinstance(prompt, str) else prompt
309
+ batch_size = len(prompt)
310
+
311
+ if isinstance(self, TextualInversionLoaderMixin):
312
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
313
+
314
+ text_inputs = self.tokenizer(
315
+ prompt,
316
+ padding="max_length",
317
+ max_length=self.tokenizer_max_length,
318
+ truncation=True,
319
+ return_overflowing_tokens=False,
320
+ return_length=False,
321
+ return_tensors="pt",
322
+ )
323
+
324
+ text_input_ids = text_inputs.input_ids
325
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
326
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
327
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
328
+ logger.warning(
329
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
330
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
331
+ )
332
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
333
+
334
+ # Use pooled output of CLIPTextModel
335
+ prompt_embeds = prompt_embeds.pooler_output
336
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
337
+
338
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
339
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
340
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
341
+
342
+ return prompt_embeds
343
+
344
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
345
+ def encode_prompt(
346
+ self,
347
+ prompt: Union[str, List[str]],
348
+ prompt_2: Union[str, List[str]],
349
+ device: Optional[torch.device] = None,
350
+ num_images_per_prompt: int = 1,
351
+ prompt_embeds: Optional[torch.FloatTensor] = None,
352
+ t5_prompt_embeds: Optional[torch.FloatTensor] = None,
353
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
354
+ max_sequence_length: int = 512,
355
+ lora_scale: Optional[float] = None,
356
+ ):
357
+ r"""
358
+
359
+ Args:
360
+ prompt (`str` or `List[str]`, *optional*):
361
+ prompt to be encoded
362
+ prompt_2 (`str` or `List[str]`, *optional*):
363
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
364
+ used in all text-encoders
365
+ device: (`torch.device`):
366
+ torch device
367
+ num_images_per_prompt (`int`):
368
+ number of images that should be generated per prompt
369
+ prompt_embeds (`torch.FloatTensor`, *optional*):
370
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
371
+ provided, text embeddings will be generated from `prompt` input argument.
372
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
373
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
374
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
375
+ lora_scale (`float`, *optional*):
376
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
377
+ """
378
+ device = device or self._execution_device
379
+
380
+ # set lora scale so that monkey patched LoRA
381
+ # function of text encoder can correctly access it
382
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
383
+ self._lora_scale = lora_scale
384
+
385
+ # dynamically adjust the LoRA scale
386
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
387
+ scale_lora_layers(self.text_encoder, lora_scale)
388
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
389
+ scale_lora_layers(self.text_encoder_2, lora_scale)
390
+
391
+ prompt = [prompt] if isinstance(prompt, str) else prompt
392
+
393
+ if prompt_embeds is None:
394
+ prompt_2 = prompt_2 or prompt
395
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
396
+
397
+ # We only use the pooled prompt output from the CLIPTextModel
398
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
399
+ prompt=prompt,
400
+ device=device,
401
+ num_images_per_prompt=num_images_per_prompt,
402
+ )
403
+ prompt_embeds = self._get_t5_prompt_embeds(
404
+ prompt=prompt_2,
405
+ num_images_per_prompt=num_images_per_prompt,
406
+ max_sequence_length=max_sequence_length,
407
+ device=device,
408
+ )
409
+
410
+ if self.text_encoder is not None:
411
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
412
+ # Retrieve the original scale by scaling back the LoRA layers
413
+ unscale_lora_layers(self.text_encoder, lora_scale)
414
+
415
+ if self.text_encoder_2 is not None:
416
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
417
+ # Retrieve the original scale by scaling back the LoRA layers
418
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
419
+
420
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
421
+ if t5_prompt_embeds is not None:
422
+ text_ids = torch.zeros(prompt_embeds.shape[1] + t5_prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
423
+ else:
424
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
425
+
426
+
427
+ return prompt_embeds, pooled_prompt_embeds, text_ids
428
+
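# Shape sketch (illustrative; assumes the stock FLUX text encoders, i.e. clip-vit-large-patch14 and
# t5-v1_1-xxl, are loaded and that the prompts are encoded here rather than passed in precomputed):
#   pooled_prompt_embeds: [batch_size * num_images_per_prompt, 768]               -- CLIP pooled output
#   prompt_embeds:        [batch_size * num_images_per_prompt, max_seq_len, 4096] -- T5 hidden states
#   text_ids:             [text_seq_len, 3] zero tensor used as position ids for the text tokens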
429
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
430
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
431
+ if isinstance(generator, list):
432
+ image_latents = [
433
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
434
+ for i in range(image.shape[0])
435
+ ]
436
+ image_latents = torch.cat(image_latents, dim=0)
437
+ else:
438
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
439
+
440
+ image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
441
+
442
+ return image_latents
443
+
444
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
445
+ def get_timesteps(self, num_inference_steps, strength, device):
446
+ # get the original timestep using init_timestep
447
+ init_timestep = min(num_inference_steps * strength, num_inference_steps)
448
+
449
+ t_start = int(max(num_inference_steps - init_timestep, 0))
450
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
451
+ if hasattr(self.scheduler, "set_begin_index"):
452
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
453
+
454
+ return timesteps, num_inference_steps - t_start
455
+
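# Worked example (illustrative): with num_inference_steps=28 and strength=0.6,
#   init_timestep = min(28 * 0.6, 28) = 16.8
#   t_start       = int(max(28 - 16.8, 0)) = 11
# so the first 11 scheduler steps are skipped and the denoising loop runs 28 - 11 = 17 steps.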
456
+ def check_inputs(
457
+ self,
458
+ prompt,
459
+ prompt_2,
460
+ image,
461
+ mask_image,
462
+ strength,
463
+ height,
464
+ width,
465
+ output_type,
466
+ prompt_embeds=None,
467
+ pooled_prompt_embeds=None,
468
+ callback_on_step_end_tensor_inputs=None,
469
+ padding_mask_crop=None,
470
+ max_sequence_length=None,
471
+ ):
472
+ if strength < 0 or strength > 1:
473
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
474
+
475
+ if height % 8 != 0 or width % 8 != 0:
476
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
477
+
478
+ if callback_on_step_end_tensor_inputs is not None and not all(
479
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
480
+ ):
481
+ raise ValueError(
482
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
483
+ )
484
+
485
+ if prompt is not None and prompt_embeds is not None:
486
+ raise ValueError(
487
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
488
+ " only forward one of the two."
489
+ )
490
+ elif prompt_2 is not None and prompt_embeds is not None:
491
+ raise ValueError(
492
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
493
+ " only forward one of the two."
494
+ )
495
+ elif prompt is None and prompt_embeds is None:
496
+ raise ValueError(
497
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
498
+ )
499
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
500
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
501
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
502
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
503
+
504
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
505
+ raise ValueError(
506
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
507
+ )
508
+
509
+ if padding_mask_crop is not None:
510
+ if not isinstance(image, PIL.Image.Image):
511
+ raise ValueError(
512
+ f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
513
+ )
514
+ if not isinstance(mask_image, PIL.Image.Image):
515
+ raise ValueError(
516
+ f"The mask image should be a PIL image when inpainting mask crop, but is of type"
517
+ f" {type(mask_image)}."
518
+ )
519
+ if output_type != "pil":
520
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
521
+
522
+ if max_sequence_length is not None and max_sequence_length > 512:
523
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
524
+
525
+ @staticmethod
526
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
527
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
528
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
529
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
530
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
531
+
532
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
533
+
534
+ latent_image_ids = latent_image_ids.reshape(
535
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
536
+ )
537
+
538
+ return latent_image_ids.to(device=device, dtype=dtype)
539
+
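# Illustrative note: for a packed latent grid of height = width = 128 (a 1024x1024 image, assuming the
# vae_scale_factor of 16 used for the FLUX VAE), this returns a [64 * 64, 3] = [4096, 3] tensor whose
# second and third columns hold the row and column index of each 2x2 latent patch; these ids are the
# per-token positions consumed by the transformer's rotary embeddings for the image stream.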
540
+ @staticmethod
541
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
542
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
543
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
544
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
545
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
546
+
547
+ return latents
548
+
549
+ @staticmethod
550
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
551
+ def _unpack_latents(latents, height, width, vae_scale_factor):
552
+ batch_size, num_patches, channels = latents.shape
553
+
554
+ height = height // vae_scale_factor
555
+ width = width // vae_scale_factor
556
+
557
+ latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
558
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
559
+
560
+ latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
561
+
562
+ return latents
563
+
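# Shape sketch (illustrative): for a 1024x1024 image the latent grid handed to _pack_latents by
# prepare_latents below is [B, 16, 128, 128] (assuming num_channels_latents = 16 and a vae_scale_factor
# of 16). Packing folds every 2x2 spatial patch into the channel dimension:
#   [B, 16, 128, 128] -> view [B, 16, 64, 2, 64, 2] -> permute/reshape [B, 64 * 64, 16 * 4] = [B, 4096, 64]
# _unpack_latents inverts the transform back to [B, 16, 128, 128] before the VAE decode.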
564
+ # Copied from diffusers.pipelines.flux.pipeline_flux_inpaint.FluxInpaintPipeline.prepare_latents
565
+ def prepare_latents(
566
+ self,
567
+ image,
568
+ timestep,
569
+ batch_size,
570
+ num_channels_latents,
571
+ height,
572
+ width,
573
+ dtype,
574
+ device,
575
+ generator,
576
+ latents=None,
577
+ ):
578
+ if isinstance(generator, list) and len(generator) != batch_size:
579
+ raise ValueError(
580
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
581
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
582
+ )
583
+
584
+ height = 2 * (int(height) // self.vae_scale_factor)
585
+ width = 2 * (int(width) // self.vae_scale_factor)
586
+
587
+ shape = (batch_size, num_channels_latents, height, width)
588
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
589
+
590
+ image = image.to(device=device, dtype=dtype)
591
+ image_latents = self._encode_vae_image(image=image, generator=generator)
592
+
593
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
594
+ # expand init_latents for batch_size
595
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
596
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
597
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
598
+ raise ValueError(
599
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
600
+ )
601
+ else:
602
+ image_latents = torch.cat([image_latents], dim=0)
603
+
604
+ if latents is None:
605
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
606
+ latents = self.scheduler.scale_noise(image_latents, timestep, noise)
607
+ else:
608
+ noise = latents.to(device)
609
+ latents = noise
610
+
611
+ noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width)
612
+ image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width)
613
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
614
+ return latents, noise, image_latents, latent_image_ids
615
+
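# Illustrative note: when no `latents` are passed in, the starting point is the VAE encoding of the
# input image noised to the first kept timestep via scheduler.scale_noise(image_latents, timestep, noise).
# With this flow-matching scheduler that blend is sigma * noise + (1 - sigma) * image_latents, so
# strength=1.0 starts from (almost) pure noise while smaller values preserve more of the input image.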
616
+ # Copied from diffusers.pipelines.flux.pipeline_flux_inpaint.FluxInpaintPipeline.prepare_mask_latents
617
+ def prepare_mask_latents(
618
+ self,
619
+ mask,
620
+ masked_image,
621
+ batch_size,
622
+ num_channels_latents,
623
+ num_images_per_prompt,
624
+ height,
625
+ width,
626
+ dtype,
627
+ device,
628
+ generator,
629
+ ):
630
+ height = 2 * (int(height) // self.vae_scale_factor)
631
+ width = 2 * (int(width) // self.vae_scale_factor)
632
+ # resize the mask to latents shape as we concatenate the mask to the latents
633
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
634
+ # and half precision
635
+ mask = torch.nn.functional.interpolate(mask, size=(height, width))
636
+ mask = mask.to(device=device, dtype=dtype)
637
+
638
+ batch_size = batch_size * num_images_per_prompt
639
+
640
+ masked_image = masked_image.to(device=device, dtype=dtype)
641
+
642
+ if masked_image.shape[1] == 16:
643
+ masked_image_latents = masked_image
644
+ else:
645
+ masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)
646
+
647
+ masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
648
+
649
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
650
+ if mask.shape[0] < batch_size:
651
+ if not batch_size % mask.shape[0] == 0:
652
+ raise ValueError(
653
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
654
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
655
+ " of masks that you pass is divisible by the total requested batch size."
656
+ )
657
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
658
+ if masked_image_latents.shape[0] < batch_size:
659
+ if not batch_size % masked_image_latents.shape[0] == 0:
660
+ raise ValueError(
661
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
662
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
663
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
664
+ )
665
+ masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
666
+
667
+ # aligning device to prevent device errors when concatenating it with the latent model input
668
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
669
+
670
+ masked_image_latents = self._pack_latents(
671
+ masked_image_latents,
672
+ batch_size,
673
+ num_channels_latents,
674
+ height,
675
+ width,
676
+ )
677
+ mask = self._pack_latents(
678
+ mask.repeat(1, num_channels_latents, 1, 1),
679
+ batch_size,
680
+ num_channels_latents,
681
+ height,
682
+ width,
683
+ )
684
+
685
+ return mask, masked_image_latents
686
+
687
+ # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image
688
+ def prepare_image(
689
+ self,
690
+ image,
691
+ width,
692
+ height,
693
+ batch_size,
694
+ num_images_per_prompt,
695
+ device,
696
+ dtype,
697
+ do_classifier_free_guidance=False,
698
+ guess_mode=False,
699
+ ):
700
+ if isinstance(image, torch.Tensor):
701
+ pass
702
+ else:
703
+ image = self.image_processor.preprocess(image, height=height, width=width)
704
+
705
+ image_batch_size = image.shape[0]
706
+
707
+ if image_batch_size == 1:
708
+ repeat_by = batch_size
709
+ else:
710
+ # image batch size is the same as prompt batch size
711
+ repeat_by = num_images_per_prompt
712
+
713
+ image = image.repeat_interleave(repeat_by, dim=0)
714
+
715
+ image = image.to(device=device, dtype=dtype)
716
+
717
+ if do_classifier_free_guidance and not guess_mode:
718
+ image = torch.cat([image] * 2)
719
+
720
+ return image
721
+
722
+ @property
723
+ def guidance_scale(self):
724
+ return self._guidance_scale
725
+
726
+ @property
727
+ def joint_attention_kwargs(self):
728
+ return self._joint_attention_kwargs
729
+
730
+ @property
731
+ def num_timesteps(self):
732
+ return self._num_timesteps
733
+
734
+ @property
735
+ def interrupt(self):
736
+ return self._interrupt
737
+
738
+ @torch.no_grad()
739
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
740
+ def __call__(
741
+ self,
742
+ prompt: Union[str, List[str]] = None,
743
+ prompt_2: Optional[Union[str, List[str]]] = None,
744
+ image: PipelineImageInput = None,
745
+ mask_image: PipelineImageInput = None,
746
+ masked_image_latents: PipelineImageInput = None,
747
+ control_image: PipelineImageInput = None,
748
+ height: Optional[int] = None,
749
+ width: Optional[int] = None,
750
+ strength: float = 0.6,
751
+ padding_mask_crop: Optional[int] = None,
752
+ timesteps: List[int] = None,
753
+ num_inference_steps: int = 28,
754
+ guidance_scale: float = 7.0,
755
+ control_guidance_start: Union[float, List[float]] = 0.0,
756
+ control_guidance_end: Union[float, List[float]] = 1.0,
757
+ control_mode: Optional[Union[int, List[int]]] = None,
758
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
759
+ num_images_per_prompt: Optional[int] = 1,
760
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
761
+ latents: Optional[torch.FloatTensor] = None,
762
+ prompt_embeds: Optional[torch.FloatTensor] = None,
763
+ t5_prompt_embeds: Optional[torch.FloatTensor] = None,
764
+ prompt_embeds_control: Optional[torch.FloatTensor] = None,
765
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
766
+ output_type: Optional[str] = "pil",
767
+ return_dict: bool = True,
768
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
769
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
770
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
771
+ max_sequence_length: int = 512,
772
+ ):
773
+ """
774
+ Function invoked when calling the pipeline for generation.
775
+
776
+ Args:
777
+ prompt (`str` or `List[str]`, *optional*):
778
+ The prompt or prompts to guide the image generation.
779
+ prompt_2 (`str` or `List[str]`, *optional*):
780
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`.
781
+ image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
782
+ The image(s) to inpaint.
783
+ mask_image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
784
+ The mask image(s) to use for inpainting. White pixels in the mask will be repainted, while black pixels
785
+ will be preserved.
786
+ masked_image_latents (`torch.FloatTensor`, *optional*):
787
+ Pre-generated masked image latents.
788
+ control_image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
789
+ The ControlNet input condition. Image to control the generation.
790
+ height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
791
+ The height in pixels of the generated image.
792
+ width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
793
+ The width in pixels of the generated image.
794
+ strength (`float`, *optional*, defaults to 0.6):
795
+ Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1.
796
+ padding_mask_crop (`int`, *optional*):
797
+ The size of the padding to use when cropping the mask.
798
+ num_inference_steps (`int`, *optional*, defaults to 28):
799
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
800
+ expense of slower inference.
801
+ timesteps (`List[int]`, *optional*):
802
+ Custom timesteps to use for the denoising process.
803
+ guidance_scale (`float`, *optional*, defaults to 7.0):
804
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
805
+ control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
806
+ The percentage of total steps at which the ControlNet starts applying.
807
+ control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
808
+ The percentage of total steps at which the ControlNet stops applying.
809
+ control_mode (`int` or `List[int]`, *optional*):
810
+ The mode for the ControlNet. If multiple ControlNets are used, this should be a list.
811
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
812
+ The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
813
+ to the residual in the original transformer.
814
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
815
+ The number of images to generate per prompt.
816
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
817
+ One or more [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to
818
+ make generation deterministic.
819
+ latents (`torch.FloatTensor`, *optional*):
820
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
821
+ generation. Can be used to tweak the same generation with different prompts.
822
+ prompt_embeds (`torch.FloatTensor`, *optional*):
823
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
824
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
825
+ Pre-generated pooled text embeddings.
826
+ output_type (`str`, *optional*, defaults to `"pil"`):
827
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
828
+ return_dict (`bool`, *optional*, defaults to `True`):
829
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
830
+ joint_attention_kwargs (`dict`, *optional*):
831
+ Additional keyword arguments to be passed to the joint attention mechanism.
832
+ callback_on_step_end (`Callable`, *optional*):
833
+ A function that calls at the end of each denoising step during the inference.
834
+ callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
835
+ The list of tensor inputs for the `callback_on_step_end` function.
836
+ max_sequence_length (`int`, *optional*, defaults to 512):
837
+ The maximum length of the sequence to be generated.
838
+
839
+ Examples:
840
+
841
+ Returns:
842
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
843
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
844
+ images.
845
+ """
846
+ height = height or self.default_sample_size * self.vae_scale_factor
847
+ width = width or self.default_sample_size * self.vae_scale_factor
848
+
849
+ global_height = height
850
+ global_width = width
851
+
852
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
853
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
854
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
855
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
856
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
857
+ mult = len(self.controlnet.nets) if isinstance(self.controlnet, FluxMultiControlNetModel) else 1
858
+ control_guidance_start, control_guidance_end = (
859
+ mult * [control_guidance_start],
860
+ mult * [control_guidance_end],
861
+ )
862
+
863
+ # 1. Check inputs
864
+ self.check_inputs(
865
+ prompt,
866
+ prompt_2,
867
+ image,
868
+ mask_image,
869
+ strength,
870
+ height,
871
+ width,
872
+ output_type=output_type,
873
+ prompt_embeds=prompt_embeds,
874
+ pooled_prompt_embeds=pooled_prompt_embeds,
875
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
876
+ padding_mask_crop=padding_mask_crop,
877
+ max_sequence_length=max_sequence_length,
878
+ )
879
+
880
+ self._guidance_scale = guidance_scale
881
+ self._joint_attention_kwargs = joint_attention_kwargs
882
+ self._interrupt = False
883
+
884
+ # 2. Define call parameters
885
+ if prompt is not None and isinstance(prompt, str):
886
+ batch_size = 1
887
+ elif prompt is not None and isinstance(prompt, list):
888
+ batch_size = len(prompt)
889
+ else:
890
+ batch_size = prompt_embeds.shape[0]
891
+
892
+ device = self._execution_device
893
+ dtype = self.transformer.dtype
894
+
895
+ # 3. Encode input prompt
896
+ lora_scale = (
897
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
898
+ )
899
+ prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
900
+ prompt=prompt,
901
+ prompt_2=prompt_2,
902
+ prompt_embeds=prompt_embeds,
903
+ t5_prompt_embeds=t5_prompt_embeds,
904
+ pooled_prompt_embeds=pooled_prompt_embeds,
905
+ device=device,
906
+ num_images_per_prompt=num_images_per_prompt,
907
+ max_sequence_length=max_sequence_length,
908
+ lora_scale=lora_scale,
909
+ )
910
+
911
+ # 4. Preprocess mask and image
912
+ if padding_mask_crop is not None:
913
+ crops_coords = self.mask_processor.get_crop_region(
914
+ mask_image, global_width, global_height, pad=padding_mask_crop
915
+ )
916
+ resize_mode = "fill"
917
+ else:
918
+ crops_coords = None
919
+ resize_mode = "default"
920
+
921
+ original_image = image
922
+ init_image = self.image_processor.preprocess(
923
+ image, height=global_height, width=global_width, crops_coords=crops_coords, resize_mode=resize_mode
924
+ )
925
+ init_image = init_image.to(dtype=torch.float32)
926
+
927
+ # 5. Prepare control image
928
+ num_channels_latents = self.transformer.config.in_channels // 4
929
+ if isinstance(self.controlnet, FluxControlNetModel):
930
+ control_image = self.prepare_image(
931
+ image=control_image,
932
+ width=width,
933
+ height=height,
934
+ batch_size=batch_size * num_images_per_prompt,
935
+ num_images_per_prompt=num_images_per_prompt,
936
+ device=device,
937
+ dtype=self.vae.dtype,
938
+ )
939
+ height, width = control_image.shape[-2:]
940
+
941
+ # vae encode
942
+ control_image = self.vae.encode(control_image).latent_dist.sample()
943
+ control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
944
+
945
+ # pack
946
+ height_control_image, width_control_image = control_image.shape[2:]
947
+ control_image = self._pack_latents(
948
+ control_image,
949
+ batch_size * num_images_per_prompt,
950
+ num_channels_latents,
951
+ height_control_image,
952
+ width_control_image,
953
+ )
954
+
955
+ # set control mode
956
+ if control_mode is not None:
957
+ control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
958
+ control_mode = control_mode.reshape([-1, 1])
959
+
960
+ elif isinstance(self.controlnet, FluxMultiControlNetModel):
961
+ control_images = []
962
+
963
+ for control_image_ in control_image:
964
+ control_image_ = self.prepare_image(
965
+ image=control_image_,
966
+ width=width,
967
+ height=height,
968
+ batch_size=batch_size * num_images_per_prompt,
969
+ num_images_per_prompt=num_images_per_prompt,
970
+ device=device,
971
+ dtype=self.vae.dtype,
972
+ )
973
+ height, width = control_image_.shape[-2:]
974
+
975
+ # vae encode
976
+ control_image_ = self.vae.encode(control_image_).latent_dist.sample()
977
+ control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
978
+
979
+ # pack
980
+ height_control_image, width_control_image = control_image_.shape[2:]
981
+ control_image_ = self._pack_latents(
982
+ control_image_,
983
+ batch_size * num_images_per_prompt,
984
+ num_channels_latents,
985
+ height_control_image,
986
+ width_control_image,
987
+ )
988
+
989
+ control_images.append(control_image_)
990
+
991
+ control_image = control_images
992
+
993
+ ## set control mode
994
+ #control_mode_ = []
995
+ #if isinstance(control_mode, list):
996
+ # for cmode in control_mode:
997
+ # if cmode is None:
998
+ # control_mode_.append(-1)
999
+ # else:
1000
+ # control_mode_.append(cmode)
1001
+ #control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long)
1002
+ #control_mode = control_mode.reshape([-1, 1])
1003
+ control_modes = []
1004
+ for cmode in control_mode:
1005
+ if cmode is None:
1006
+ cmode = -1
1007
+ control_mode = torch.tensor(cmode).expand(control_images[0].shape[0]).to(device, dtype=torch.long)
1008
+ control_modes.append(control_mode)
1009
+ control_mode = control_modes
1010
+
1011
+ # 6. Prepare timesteps
1012
+
1013
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
1014
+ image_seq_len = (int(global_height) // self.vae_scale_factor) * (int(global_width) // self.vae_scale_factor)
1015
+ mu = calculate_shift(
1016
+ image_seq_len,
1017
+ self.scheduler.config.base_image_seq_len,
1018
+ self.scheduler.config.max_image_seq_len,
1019
+ self.scheduler.config.base_shift,
1020
+ self.scheduler.config.max_shift,
1021
+ )
1022
+ timesteps, num_inference_steps = retrieve_timesteps(
1023
+ self.scheduler,
1024
+ num_inference_steps,
1025
+ device,
1026
+ timesteps,
1027
+ sigmas,
1028
+ mu=mu,
1029
+ )
1030
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
1031
+
1032
+ if num_inference_steps < 1:
1033
+ raise ValueError(
1034
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
1035
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
1036
+ )
1037
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
1038
+
1039
+ # 7. Prepare latent variables
1040
+
1041
+ latents, noise, image_latents, latent_image_ids = self.prepare_latents(
1042
+ init_image,
1043
+ latent_timestep,
1044
+ batch_size * num_images_per_prompt,
1045
+ num_channels_latents,
1046
+ global_height,
1047
+ global_width,
1048
+ prompt_embeds.dtype,
1049
+ device,
1050
+ generator,
1051
+ latents,
1052
+ )
1053
+
1054
+ # 8. Prepare mask latents
1055
+ mask_condition = self.mask_processor.preprocess(
1056
+ mask_image, height=global_height, width=global_width, resize_mode=resize_mode, crops_coords=crops_coords
1057
+ )
1058
+ if masked_image_latents is None:
1059
+ masked_image = init_image * (mask_condition < 0.5)
1060
+ else:
1061
+ masked_image = masked_image_latents
1062
+
1063
+ mask, masked_image_latents = self.prepare_mask_latents(
1064
+ mask_condition,
1065
+ masked_image,
1066
+ batch_size,
1067
+ num_channels_latents,
1068
+ num_images_per_prompt,
1069
+ global_height,
1070
+ global_width,
1071
+ prompt_embeds.dtype,
1072
+ device,
1073
+ generator,
1074
+ )
1075
+
1076
+ controlnet_keep = []
1077
+ for i in range(len(timesteps)):
1078
+ keeps = [
1079
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
1080
+ for s, e in zip(control_guidance_start, control_guidance_end)
1081
+ ]
1082
+ controlnet_keep.append(keeps[0] if isinstance(self.controlnet, FluxControlNetModel) else keeps)
1083
+
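# Illustrative example: with control_guidance_start=0.0 and control_guidance_end=0.5 over 20 kept
# timesteps, keeps is 1.0 for steps 0-9 and 0.0 for steps 10-19, i.e. the ControlNet residuals are
# effectively switched off (conditioning scale multiplied by 0) for the second half of the loop.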
1084
+ # 9. Denoising loop
1085
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
1086
+ self._num_timesteps = len(timesteps)
1087
+
1088
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1089
+ for i, t in enumerate(timesteps):
1090
+ if self.interrupt:
1091
+ continue
1092
+
1093
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
1094
+
1095
+ # predict the noise residual
1096
+ #if self.controlnet.config.guidance_embeds:
1097
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
1098
+ guidance = guidance.expand(latents.shape[0])
1099
+ #else:
1100
+ # guidance = None
1101
+
1102
+ if isinstance(controlnet_keep[i], list):
1103
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
1104
+ else:
1105
+ controlnet_cond_scale = controlnet_conditioning_scale
1106
+ if isinstance(controlnet_cond_scale, list):
1107
+ controlnet_cond_scale = controlnet_cond_scale[0]
1108
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
1109
+
1110
+ controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
1111
+ hidden_states=latents,
1112
+ controlnet_cond=control_image,
1113
+ controlnet_mode=control_mode,
1114
+ conditioning_scale=cond_scale,
1115
+ timestep=timestep / 1000,
1116
+ guidance=guidance,
1117
+ pooled_projections=pooled_prompt_embeds,
1118
+ encoder_hidden_states=prompt_embeds_control,
1119
+ t5_encoder_hidden_states=t5_prompt_embeds,
1120
+ txt_ids=text_ids,
1121
+ img_ids=latent_image_ids,
1122
+ joint_attention_kwargs=self.joint_attention_kwargs,
1123
+ return_dict=False,
1124
+ )
1125
+
1126
+ if self.transformer.config.guidance_embeds:
1127
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
1128
+ guidance = guidance.expand(latents.shape[0])
1129
+ else:
1130
+ guidance = None
1131
+
1132
+ noise_pred = self.transformer(
1133
+ hidden_states=latents,
1134
+ timestep=timestep / 1000,
1135
+ guidance=guidance,
1136
+ pooled_projections=pooled_prompt_embeds,
1137
+ encoder_hidden_states=prompt_embeds,
1138
+ t5_encoder_hidden_states=t5_prompt_embeds,
1139
+ controlnet_block_samples=controlnet_block_samples,
1140
+ controlnet_single_block_samples=controlnet_single_block_samples,
1141
+ txt_ids=text_ids,
1142
+ img_ids=latent_image_ids,
1143
+ joint_attention_kwargs=self.joint_attention_kwargs,
1144
+ return_dict=False,
1145
+ )[0]
1146
+
1147
+ # compute the previous noisy sample x_t -> x_t-1
1148
+ latents_dtype = latents.dtype
1149
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
1150
+
1151
+ # For inpainting, we need to apply the mask and add the masked image latents
1152
+ init_latents_proper = image_latents
1153
+ init_mask = mask
1154
+
1155
+ if i < len(timesteps) - 1:
1156
+ noise_timestep = timesteps[i + 1]
1157
+ init_latents_proper = self.scheduler.scale_noise(
1158
+ init_latents_proper, torch.tensor([noise_timestep]), noise
1159
+ )
1160
+
1161
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
1162
+
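# Illustrative note: this is the usual latent-space inpainting blend -- where the mask is 1 the freshly
# denoised latents are kept, where it is 0 the tokens are reset to the original image latents re-noised
# to the *next* timestep, so preserved regions always match the current noise level of the loop.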
1163
+ if latents.dtype != latents_dtype:
1164
+ if torch.backends.mps.is_available():
1165
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
1166
+ latents = latents.to(latents_dtype)
1167
+
1168
+ # call the callback, if provided
1169
+ if callback_on_step_end is not None:
1170
+ callback_kwargs = {}
1171
+ for k in callback_on_step_end_tensor_inputs:
1172
+ callback_kwargs[k] = locals()[k]
1173
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1174
+
1175
+ latents = callback_outputs.pop("latents", latents)
1176
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1177
+
1178
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1179
+ progress_bar.update()
1180
+
1181
+ if XLA_AVAILABLE:
1182
+ xm.mark_step()
1183
+
1184
+ # Post-processing
1185
+ if output_type == "latent":
1186
+ image = latents
1187
+ else:
1188
+ latents = self._unpack_latents(latents, global_height, global_width, self.vae_scale_factor)
1189
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
1190
+ image = self.vae.decode(latents, return_dict=False)[0]
1191
+ image = self.image_processor.postprocess(image, output_type=output_type)
1192
+
1193
+ # Offload all models
1194
+ self.maybe_free_model_hooks()
1195
+
1196
+ if not return_dict:
1197
+ return (image,)
1198
+
1199
+ return FluxPipelineOutput(images=image)
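A minimal call sketch for the ControlNet inpainting pipeline defined above. Pipeline construction is omitted because it depends on the checkpoints wired up elsewhere in this Space; `pipe` is assumed to be an instance of the class above, and the file paths, prompt, and parameter values are placeholders.

```py
>>> import torch
>>> from diffusers.utils import load_image

>>> # `pipe`: an already-constructed instance of the ControlNet inpainting pipeline above
>>> init_image = load_image("source.png")        # image to edit (placeholder path)
>>> mask_image = load_image("mask.png")          # white = repaint, black = keep
>>> control_image = load_image("condition.png")  # ControlNet conditioning image (placeholder)

>>> result = pipe(
...     prompt="a red brick fireplace",
...     image=init_image,
...     mask_image=mask_image,
...     control_image=control_image,
...     strength=0.6,
...     num_inference_steps=28,
...     guidance_scale=7.0,
...     controlnet_conditioning_scale=1.0,
...     generator=torch.Generator("cpu").manual_seed(0),
... ).images[0]
>>> result.save("inpainted.png")
```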
flux/pipeline_flux_img2img.py ADDED
@@ -0,0 +1,856 @@
1
+ # Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
21
+
22
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
23
+ from .lora.lora_pipeline import FluxLoraLoaderMixin
24
+ from diffusers.models.autoencoders import AutoencoderKL
25
+ from .transformer_flux import FluxTransformer2DModel
26
+ from .scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
27
+ from diffusers.utils import (
28
+ USE_PEFT_BACKEND,
29
+ is_torch_xla_available,
30
+ logging,
31
+ replace_example_docstring,
32
+ scale_lora_layers,
33
+ unscale_lora_layers,
34
+ )
35
+ from diffusers.utils.torch_utils import randn_tensor
36
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
37
+ from .pipeline_output import FluxPipelineOutput
38
+
39
+
40
+ if is_torch_xla_available():
41
+ import torch_xla.core.xla_model as xm
42
+
43
+ XLA_AVAILABLE = True
44
+ else:
45
+ XLA_AVAILABLE = False
46
+
47
+
48
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
49
+
50
+ EXAMPLE_DOC_STRING = """
51
+ Examples:
52
+ ```py
53
+ >>> import torch
54
+
55
+ >>> from diffusers import FluxImg2ImgPipeline
56
+ >>> from diffusers.utils import load_image
57
+
58
+ >>> device = "cuda"
59
+ >>> pipe = FluxImg2ImgPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
60
+ >>> pipe = pipe.to(device)
61
+
62
+ >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
63
+ >>> init_image = load_image(url).resize((1024, 1024))
64
+
65
+ >>> prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k"
66
+
67
+ >>> images = pipe(
68
+ ... prompt=prompt, image=init_image, num_inference_steps=4, strength=0.95, guidance_scale=0.0
69
+ ... ).images[0]
70
+ ```
71
+ """
72
+
73
+
74
+ # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
75
+ def calculate_shift(
76
+ image_seq_len,
77
+ base_seq_len: int = 256,
78
+ max_seq_len: int = 4096,
79
+ base_shift: float = 0.5,
80
+ max_shift: float = 1.16,
81
+ ):
82
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
83
+ b = base_shift - m * base_seq_len
84
+ mu = image_seq_len * m + b
85
+ return mu
86
+
87
+
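# Worked example (illustrative, using the defaults above): the pipelines in this repo pass
# image_seq_len = (height // vae_scale_factor) * (width // vae_scale_factor), e.g. (1024 // 16) ** 2 = 4096
# for a 1024x1024 image, which gives
#   m  = (1.16 - 0.5) / (4096 - 256) ≈ 1.72e-4
#   mu = 4096 * m + (0.5 - 256 * m) ≈ 1.16
# so larger images push the flow-matching schedule toward max_shift and smaller ones toward base_shift.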
88
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
89
+ def retrieve_latents(
90
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
91
+ ):
92
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
93
+ return encoder_output.latent_dist.sample(generator)
94
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
95
+ return encoder_output.latent_dist.mode()
96
+ elif hasattr(encoder_output, "latents"):
97
+ return encoder_output.latents
98
+ else:
99
+ raise AttributeError("Could not access latents of provided encoder_output")
100
+
101
+
102
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
103
+ def retrieve_timesteps(
104
+ scheduler,
105
+ num_inference_steps: Optional[int] = None,
106
+ device: Optional[Union[str, torch.device]] = None,
107
+ timesteps: Optional[List[int]] = None,
108
+ sigmas: Optional[List[float]] = None,
109
+ **kwargs,
110
+ ):
111
+ """
112
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
113
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
114
+
115
+ Args:
116
+ scheduler (`SchedulerMixin`):
117
+ The scheduler to get timesteps from.
118
+ num_inference_steps (`int`):
119
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
120
+ must be `None`.
121
+ device (`str` or `torch.device`, *optional*):
122
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
123
+ timesteps (`List[int]`, *optional*):
124
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
125
+ `num_inference_steps` and `sigmas` must be `None`.
126
+ sigmas (`List[float]`, *optional*):
127
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
128
+ `num_inference_steps` and `timesteps` must be `None`.
129
+
130
+ Returns:
131
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
132
+ second element is the number of inference steps.
133
+ """
134
+ if timesteps is not None and sigmas is not None:
135
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
136
+ if timesteps is not None:
137
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
138
+ if not accepts_timesteps:
139
+ raise ValueError(
140
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
141
+ f" timestep schedules. Please check whether you are using the correct scheduler."
142
+ )
143
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
144
+ timesteps = scheduler.timesteps
145
+ num_inference_steps = len(timesteps)
146
+ elif sigmas is not None:
147
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
148
+ if not accept_sigmas:
149
+ raise ValueError(
150
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
151
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
152
+ )
153
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
154
+ timesteps = scheduler.timesteps
155
+ num_inference_steps = len(timesteps)
156
+ else:
157
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
158
+ timesteps = scheduler.timesteps
159
+ return timesteps, num_inference_steps
160
+
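# Usage sketch (illustrative): the __call__ methods in this repo pass custom sigmas plus the dynamic
# shift from calculate_shift, e.g.
#   sigmas = np.linspace(1.0, 1 / 28, 28)
#   timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, 28, device, None, sigmas, mu=mu)
# which routes through scheduler.set_timesteps(sigmas=sigmas, mu=mu) and returns the 28 timesteps.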
161
+
162
+ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
163
+ r"""
164
+ The Flux pipeline for image-to-image generation.
165
+
166
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
167
+
168
+ Args:
169
+ transformer ([`FluxTransformer2DModel`]):
170
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
171
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
172
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
173
+ vae ([`AutoencoderKL`]):
174
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
175
+ text_encoder ([`CLIPTextModel`]):
176
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
177
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
178
+ text_encoder_2 ([`T5EncoderModel`]):
179
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
180
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
181
+ tokenizer (`CLIPTokenizer`):
182
+ Tokenizer of class
183
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
184
+ tokenizer_2 (`T5TokenizerFast`):
185
+ Second Tokenizer of class
186
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
187
+ """
188
+
189
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
190
+ _optional_components = []
191
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
192
+
193
+ def __init__(
194
+ self,
195
+ scheduler: FlowMatchEulerDiscreteScheduler,
196
+ vae: AutoencoderKL,
197
+ text_encoder: CLIPTextModel,
198
+ tokenizer: CLIPTokenizer,
199
+ transformer: FluxTransformer2DModel,
200
+ text_encoder_2: T5EncoderModel | None = None,
201
+ tokenizer_2: T5TokenizerFast | None = None,
202
+ ):
203
+ super().__init__()
204
+
205
+ self.register_modules(
206
+ vae=vae,
207
+ text_encoder=text_encoder,
208
+ #text_encoder_2=text_encoder_2,
209
+ tokenizer=tokenizer,
210
+ #tokenizer_2=tokenizer_2,
211
+ transformer=transformer,
212
+ scheduler=scheduler,
213
+ )
214
+ self.vae_scale_factor = (
215
+ 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
216
+ )
217
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
218
+ self.tokenizer_max_length = (
219
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
220
+ )
221
+ self.default_sample_size = 64
222
+
223
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
224
+ def _get_t5_prompt_embeds(
225
+ self,
226
+ prompt: Union[str, List[str]] = None,
227
+ num_images_per_prompt: int = 1,
228
+ max_sequence_length: int = 512,
229
+ device: Optional[torch.device] = None,
230
+ dtype: Optional[torch.dtype] = None,
231
+ ):
232
+ device = device or self._execution_device
233
+ dtype = dtype or self.text_encoder.dtype
234
+
235
+ prompt = [prompt] if isinstance(prompt, str) else prompt
236
+ batch_size = len(prompt)
237
+
238
+ text_inputs = self.tokenizer_2(
239
+ prompt,
240
+ padding="max_length",
241
+ max_length=max_sequence_length,
242
+ truncation=True,
243
+ return_length=False,
244
+ return_overflowing_tokens=False,
245
+ return_tensors="pt",
246
+ )
247
+ text_input_ids = text_inputs.input_ids
248
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
249
+
250
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
251
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
252
+ logger.warning(
253
+ "The following part of your input was truncated because `max_sequence_length` is set to "
254
+ f" {max_sequence_length} tokens: {removed_text}"
255
+ )
256
+
257
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
258
+
259
+ dtype = self.text_encoder_2.dtype
260
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
261
+
262
+ _, seq_len, _ = prompt_embeds.shape
263
+
264
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
265
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
266
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
267
+
268
+ return prompt_embeds
269
+
270
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
271
+ def _get_clip_prompt_embeds(
272
+ self,
273
+ prompt: Union[str, List[str]],
274
+ num_images_per_prompt: int = 1,
275
+ device: Optional[torch.device] = None,
276
+ ):
277
+ device = device or self._execution_device
278
+
279
+ prompt = [prompt] if isinstance(prompt, str) else prompt
280
+ batch_size = len(prompt)
281
+
282
+ text_inputs = self.tokenizer(
283
+ prompt,
284
+ padding="max_length",
285
+ max_length=self.tokenizer_max_length,
286
+ truncation=True,
287
+ return_overflowing_tokens=False,
288
+ return_length=False,
289
+ return_tensors="pt",
290
+ )
291
+
292
+ text_input_ids = text_inputs.input_ids
293
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
294
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
295
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
296
+ logger.warning(
297
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
298
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
299
+ )
300
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
301
+
302
+ # Use pooled output of CLIPTextModel
303
+ prompt_embeds = prompt_embeds.pooler_output
304
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
305
+
306
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
307
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
308
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
309
+
310
+ return prompt_embeds
311
+
312
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
313
+ def encode_prompt(
314
+ self,
315
+ prompt: Union[str, List[str]],
316
+ prompt_2: Union[str, List[str]],
317
+ device: Optional[torch.device] = None,
318
+ num_images_per_prompt: int = 1,
319
+ prompt_embeds: Optional[torch.FloatTensor] = None,
320
+ t5_prompt_embeds: Optional[torch.FloatTensor] = None,
321
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
322
+ max_sequence_length: int = 512,
323
+ lora_scale: Optional[float] = None,
324
+ ):
325
+ r"""
326
+
327
+ Args:
328
+ prompt (`str` or `List[str]`, *optional*):
329
+ prompt to be encoded
330
+ prompt_2 (`str` or `List[str]`, *optional*):
331
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
332
+ used in all text-encoders
333
+ device: (`torch.device`):
334
+ torch device
335
+ num_images_per_prompt (`int`):
336
+ number of images that should be generated per prompt
337
+ prompt_embeds (`torch.FloatTensor`, *optional*):
338
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
339
+ provided, text embeddings will be generated from `prompt` input argument.
340
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
341
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
342
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
343
+ lora_scale (`float`, *optional*):
344
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
345
+ """
346
+ device = device or self._execution_device
347
+
348
+ # set lora scale so that monkey patched LoRA
349
+ # function of text encoder can correctly access it
350
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
351
+ self._lora_scale = lora_scale
352
+
353
+ # dynamically adjust the LoRA scale
354
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
355
+ scale_lora_layers(self.text_encoder, lora_scale)
356
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
357
+ scale_lora_layers(self.text_encoder_2, lora_scale)
358
+
359
+ prompt = [prompt] if isinstance(prompt, str) else prompt
360
+ if prompt is not None:
361
+ batch_size = len(prompt)
362
+ else:
363
+ batch_size = prompt_embeds.shape[0]
364
+
365
+ if prompt_embeds is None:
366
+ prompt_2 = prompt_2 or prompt
367
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
368
+
369
+ # We only use the pooled prompt output from the CLIPTextModel
370
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
371
+ prompt=prompt,
372
+ device=device,
373
+ num_images_per_prompt=num_images_per_prompt,
374
+ )
375
+ prompt_embeds = self._get_t5_prompt_embeds(
376
+ prompt=prompt_2,
377
+ num_images_per_prompt=num_images_per_prompt,
378
+ max_sequence_length=max_sequence_length,
379
+ device=device,
380
+ )
381
+
382
+ if self.text_encoder is not None:
383
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
384
+ # Retrieve the original scale by scaling back the LoRA layers
385
+ unscale_lora_layers(self.text_encoder, lora_scale)
386
+
387
+ #if self.text_encoder_2 is not None:
388
+ # if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
389
+ # # Retrieve the original scale by scaling back the LoRA layers
390
+ # unscale_lora_layers(self.text_encoder_2, lora_scale)
391
+
392
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
393
+ if t5_prompt_embeds is not None:
394
+ text_ids = torch.zeros(batch_size, prompt_embeds.shape[1] + t5_prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
395
+ else:
396
+ text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
397
+ #text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
398
+
399
+ return prompt_embeds, pooled_prompt_embeds, text_ids
400
+
401
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
402
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
403
+ if isinstance(generator, list):
404
+ image_latents = [
405
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
406
+ for i in range(image.shape[0])
407
+ ]
408
+ image_latents = torch.cat(image_latents, dim=0)
409
+ else:
410
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
411
+
412
+ image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
413
+
414
+ return image_latents
415
+
416
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
417
+ def get_timesteps(self, num_inference_steps, strength, device):
418
+ # get the original timestep using init_timestep
419
+ init_timestep = min(num_inference_steps * strength, num_inference_steps)
420
+
421
+ t_start = int(max(num_inference_steps - init_timestep, 0))
422
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
423
+ if hasattr(self.scheduler, "set_begin_index"):
424
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
425
+
426
+ return timesteps, num_inference_steps - t_start
427
+
428
+ def check_inputs(
429
+ self,
430
+ prompt,
431
+ prompt_2,
432
+ strength,
433
+ height,
434
+ width,
435
+ prompt_embeds=None,
436
+ pooled_prompt_embeds=None,
437
+ callback_on_step_end_tensor_inputs=None,
438
+ max_sequence_length=None,
439
+ ):
440
+ if strength < 0 or strength > 1:
441
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
442
+
443
+ if height % 8 != 0 or width % 8 != 0:
444
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
445
+
446
+ if callback_on_step_end_tensor_inputs is not None and not all(
447
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
448
+ ):
449
+ raise ValueError(
450
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
451
+ )
452
+
453
+ if prompt is not None and prompt_embeds is not None:
454
+ raise ValueError(
455
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
456
+ " only forward one of the two."
457
+ )
458
+ elif prompt_2 is not None and prompt_embeds is not None:
459
+ raise ValueError(
460
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
461
+ " only forward one of the two."
462
+ )
463
+ elif prompt is None and prompt_embeds is None:
464
+ raise ValueError(
465
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
466
+ )
467
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
468
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
469
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
470
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
471
+
472
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
473
+ raise ValueError(
474
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
475
+ )
476
+
477
+ if max_sequence_length is not None and max_sequence_length > 512:
478
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
479
+
480
+ @staticmethod
481
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
482
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
483
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
484
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
485
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
486
+
487
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
488
+
489
+ latent_image_ids = latent_image_ids.reshape(
490
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
491
+ )
492
+
493
+ return latent_image_ids.to(device=device, dtype=dtype)
494
+
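+ # Example: for a packed 2x2 grid (height=4, width=4 latent pixels) the ids are
+ # [[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1]]: channel 0 is unused, channels 1 and 2 hold the row
+ # and column of each packed latent token, used as positional ids for the transformer's rotary embeddings.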
495
+ @staticmethod
496
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
497
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
498
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
499
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
500
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
501
+
502
+ return latents
503
+
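+ # Shape example: a 1024x1024 image gives (B, 16, 128, 128) latents, which _pack_latents folds into
+ # (B, 4096, 64), one token per 2x2 patch of 16-channel latents; _unpack_latents below inverts this
+ # before VAE decoding.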
504
+ @staticmethod
505
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
506
+ def _unpack_latents(latents, height, width, vae_scale_factor):
507
+ batch_size, num_patches, channels = latents.shape
508
+
509
+ height = height // vae_scale_factor
510
+ width = width // vae_scale_factor
511
+
512
+ latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
513
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
514
+
515
+ latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
516
+
517
+ return latents
518
+
519
+ def prepare_latents(
520
+ self,
521
+ image,
522
+ timestep,
523
+ batch_size,
524
+ num_channels_latents,
525
+ height,
526
+ width,
527
+ dtype,
528
+ device,
529
+ generator,
530
+ latents=None,
531
+ ):
532
+ if isinstance(generator, list) and len(generator) != batch_size:
533
+ raise ValueError(
534
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
535
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
536
+ )
537
+
538
+ height = 2 * (int(height) // self.vae_scale_factor)
539
+ width = 2 * (int(width) // self.vae_scale_factor)
540
+
541
+ shape = (batch_size, num_channels_latents, height, width)
542
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
543
+
544
+ if latents is not None:
545
+ return latents.to(device=device, dtype=dtype), latent_image_ids
546
+
547
+ image = image.to(device=device, dtype=dtype)
548
+ image_latents = self._encode_vae_image(image=image, generator=generator)
549
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
550
+ # expand init_latents for batch_size
551
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
552
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
553
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
554
+ raise ValueError(
555
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
556
+ )
557
+ else:
558
+ image_latents = torch.cat([image_latents], dim=0)
559
+
560
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
561
+ latents = self.scheduler.scale_noise(image_latents, timestep, noise)
562
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
563
+ return latents, latent_image_ids
564
+
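+ # For image-to-image the starting latents are not pure noise: the encoded reference image is blended
+ # with Gaussian noise by scheduler.scale_noise at the first kept timestep, so lower strength values
+ # start the denoising loop closer to the original image.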
565
+ @property
566
+ def guidance_scale(self):
567
+ return self._guidance_scale
568
+
569
+ @property
570
+ def joint_attention_kwargs(self):
571
+ return self._joint_attention_kwargs
572
+
573
+ @property
574
+ def num_timesteps(self):
575
+ return self._num_timesteps
576
+
577
+ @property
578
+ def interrupt(self):
579
+ return self._interrupt
580
+
581
+ @torch.no_grad()
582
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
583
+ def __call__(
584
+ self,
585
+ prompt: Union[str, List[str]] = None,
586
+ prompt_2: Optional[Union[str, List[str]]] = None,
587
+ image: PipelineImageInput = None,
588
+ height: Optional[int] = None,
589
+ width: Optional[int] = None,
590
+ strength: float = 0.6,
591
+ num_inference_steps: int = 28,
592
+ timesteps: List[int] = None,
593
+ guidance_scale: float = 7.0,
594
+ num_images_per_prompt: Optional[int] = 1,
595
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
596
+ latents: Optional[torch.FloatTensor] = None,
597
+ prompt_embeds: Optional[torch.FloatTensor] = None,
598
+ t5_prompt_embeds: Optional[torch.FloatTensor] = None,
599
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
600
+ output_type: Optional[str] = "pil",
601
+ return_dict: bool = True,
602
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
603
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
604
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
605
+ max_sequence_length: int = 512,
606
+ ):
607
+ r"""
608
+ Function invoked when calling the pipeline for generation.
609
+
610
+ Args:
611
+ prompt (`str` or `List[str]`, *optional*):
612
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
613
+ instead.
614
+ prompt_2 (`str` or `List[str]`, *optional*):
615
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
616
+ will be used instead.
617
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
618
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
619
+ numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a list
620
+ of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
621
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
622
+ latents as `image`, but if passing latents directly it is not encoded again.
623
+ height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
624
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
625
+ width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
626
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
627
+ strength (`float`, *optional*, defaults to 0.6):
628
+ Indicates the extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
629
+ starting point and more noise is added the higher the `strength`. The number of denoising steps depends
630
+ on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
631
+ process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
632
+ essentially ignores `image`.
633
+ num_inference_steps (`int`, *optional*, defaults to 28):
634
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
635
+ expense of slower inference.
636
+ timesteps (`List[int]`, *optional*):
637
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
638
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
639
+ passed will be used. Must be in descending order.
640
+ guidance_scale (`float`, *optional*, defaults to 7.0):
641
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
642
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
643
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
644
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
645
+ usually at the expense of lower image quality.
646
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
647
+ The number of images to generate per prompt.
648
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
649
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
650
+ to make generation deterministic.
651
+ latents (`torch.FloatTensor`, *optional*):
652
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
653
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
654
+ tensor will be generated by sampling using the supplied random `generator`.
655
+ prompt_embeds (`torch.FloatTensor`, *optional*):
656
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
657
+ provided, text embeddings will be generated from `prompt` input argument.
658
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
659
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
660
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
661
+ output_type (`str`, *optional*, defaults to `"pil"`):
662
+ The output format of the generate image. Choose between
663
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
664
+ return_dict (`bool`, *optional*, defaults to `True`):
665
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
666
+ joint_attention_kwargs (`dict`, *optional*):
667
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
668
+ `self.processor` in
669
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
670
+ callback_on_step_end (`Callable`, *optional*):
671
+ A function that is called at the end of each denoising step during inference. The function is called
672
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
673
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
674
+ `callback_on_step_end_tensor_inputs`.
675
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
676
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
677
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
678
+ `._callback_tensor_inputs` attribute of your pipeline class.
679
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
680
+
681
+ Examples:
682
+
683
+ Returns:
684
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
685
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
686
+ images.
687
+ """
688
+
689
+ height = height or self.default_sample_size * self.vae_scale_factor
690
+ width = width or self.default_sample_size * self.vae_scale_factor
691
+
692
+ # 1. Check inputs. Raise error if not correct
693
+ self.check_inputs(
694
+ prompt,
695
+ prompt_2,
696
+ strength,
697
+ height,
698
+ width,
699
+ prompt_embeds=prompt_embeds,
700
+ pooled_prompt_embeds=pooled_prompt_embeds,
701
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
702
+ max_sequence_length=max_sequence_length,
703
+ )
704
+
705
+ self._guidance_scale = guidance_scale
706
+ self._joint_attention_kwargs = joint_attention_kwargs
707
+ self._interrupt = False
708
+
709
+ # 2. Preprocess image
710
+ init_image = self.image_processor.preprocess(image, height=height, width=width)
711
+ init_image = init_image.to(dtype=torch.float32)
712
+
713
+ # 3. Define call parameters
714
+ if prompt is not None and isinstance(prompt, str):
715
+ batch_size = 1
716
+ elif prompt is not None and isinstance(prompt, list):
717
+ batch_size = len(prompt)
718
+ else:
719
+ batch_size = prompt_embeds.shape[0]
720
+
721
+ device = self._execution_device
722
+
723
+ lora_scale = (
724
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
725
+ )
726
+ (
727
+ prompt_embeds,
728
+ pooled_prompt_embeds,
729
+ text_ids,
730
+ ) = self.encode_prompt(
731
+ prompt=prompt,
732
+ prompt_2=prompt_2,
733
+ prompt_embeds=prompt_embeds,
734
+ t5_prompt_embeds=t5_prompt_embeds,
735
+ pooled_prompt_embeds=pooled_prompt_embeds,
736
+ device=device,
737
+ num_images_per_prompt=num_images_per_prompt,
738
+ max_sequence_length=max_sequence_length,
739
+ lora_scale=lora_scale,
740
+ )
741
+
742
+ # 4. Prepare timesteps
743
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
744
+ image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor)
745
+ mu = calculate_shift(
746
+ image_seq_len,
747
+ self.scheduler.config.base_image_seq_len,
748
+ self.scheduler.config.max_image_seq_len,
749
+ self.scheduler.config.base_shift,
750
+ self.scheduler.config.max_shift,
751
+ )
752
+ timesteps, num_inference_steps = retrieve_timesteps(
753
+ self.scheduler,
754
+ num_inference_steps,
755
+ device,
756
+ timesteps,
757
+ sigmas,
758
+ mu=mu,
759
+ )
760
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
761
+
762
+ if num_inference_steps < 1:
763
+ raise ValueError(
764
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
765
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
766
+ )
767
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
768
+
769
+ # 5. Prepare latent variables
770
+ num_channels_latents = self.transformer.config.in_channels // 4
771
+
772
+ latents, latent_image_ids = self.prepare_latents(
773
+ init_image,
774
+ latent_timestep,
775
+ batch_size * num_images_per_prompt,
776
+ num_channels_latents,
777
+ height,
778
+ width,
779
+ prompt_embeds.dtype,
780
+ device,
781
+ generator,
782
+ latents,
783
+ )
784
+
785
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
786
+ self._num_timesteps = len(timesteps)
787
+
788
+ # handle guidance
789
+ if self.transformer.config.guidance_embeds:
790
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
791
+ guidance = guidance.expand(latents.shape[0])
792
+ else:
793
+ guidance = None
794
+
795
+ # 6. Denoising loop
796
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
797
+ for i, t in enumerate(timesteps):
798
+ if self.interrupt:
799
+ continue
800
+
801
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
802
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
803
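+ # The transformer is conditioned on t / 1000, i.e. the flow-matching sigma in [0, 1] rather than
+ # the raw scheduler timestep.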
+ noise_pred = self.transformer(
804
+ hidden_states=latents,
805
+ timestep=timestep / 1000,
806
+ guidance=guidance,
807
+ pooled_projections=pooled_prompt_embeds,
808
+ encoder_hidden_states=prompt_embeds,
809
+ t5_encoder_hidden_states=t5_prompt_embeds,
810
+ txt_ids=text_ids,
811
+ img_ids=latent_image_ids,
812
+ joint_attention_kwargs=self.joint_attention_kwargs,
813
+ return_dict=False,
814
+ )[0]
815
+
816
+ # compute the previous noisy sample x_t -> x_t-1
817
+ latents_dtype = latents.dtype
818
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
819
+
820
+ if latents.dtype != latents_dtype:
821
+ if torch.backends.mps.is_available():
822
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
823
+ latents = latents.to(latents_dtype)
824
+
825
+ if callback_on_step_end is not None:
826
+ callback_kwargs = {}
827
+ for k in callback_on_step_end_tensor_inputs:
828
+ callback_kwargs[k] = locals()[k]
829
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
830
+
831
+ latents = callback_outputs.pop("latents", latents)
832
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
833
+
834
+ # call the callback, if provided
835
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
836
+ progress_bar.update()
837
+
838
+ if XLA_AVAILABLE:
839
+ xm.mark_step()
840
+
841
+ if output_type == "latent":
842
+ image = latents
843
+
844
+ else:
845
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
846
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
847
+ image = self.vae.decode(latents, return_dict=False)[0]
848
+ image = self.image_processor.postprocess(image, output_type=output_type)
849
+
850
+ # Offload all models
851
+ self.maybe_free_model_hooks()
852
+
853
+ if not return_dict:
854
+ return (image,)
855
+
856
+ return FluxPipelineOutput(images=image)
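For reference, the `__call__` signature above mirrors the upstream `FluxImg2ImgPipeline` in diffusers; this modified copy additionally accepts `t5_prompt_embeds` and wires its own text encoders elsewhere in the repo, so its loading path is not shown here. A minimal sketch of the equivalent upstream usage, assuming a CUDA device and the public FLUX.1-schnell checkpoint:

```py
import torch
from diffusers import FluxImg2ImgPipeline
from diffusers.utils import load_image

# Load the public checkpoint; this only illustrates the call signature used above.
pipe = FluxImg2ImgPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

init_image = load_image(
    "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
).resize((1024, 1024))

image = pipe(
    prompt="Face of a yellow cat, high resolution, sitting on a park bench",
    image=init_image,
    strength=0.6,            # re-noise 60% of the schedule; 1.0 ignores the input image
    num_inference_steps=28,
    guidance_scale=7.0,
).images[0]
image.save("flux_img2img.png")
```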
flux/pipeline_flux_inpaint.py ADDED
@@ -0,0 +1,1021 @@
1
+ # Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Union
17
+
18
+ import numpy as np
19
+ import PIL.Image
20
+ import torch
21
+ from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
22
+
23
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
24
+ from .lora.lora_pipeline import FluxLoraLoaderMixin
25
+ from diffusers.models.autoencoders import AutoencoderKL
26
+ from .transformer_flux import FluxTransformer2DModel
27
+ from .scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
28
+ from diffusers.utils import (
29
+ USE_PEFT_BACKEND,
30
+ is_torch_xla_available,
31
+ logging,
32
+ replace_example_docstring,
33
+ scale_lora_layers,
34
+ unscale_lora_layers,
35
+ )
36
+ from diffusers.utils.torch_utils import randn_tensor
37
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
38
+ from .pipeline_output import FluxPipelineOutput
39
+
40
+
41
+ if is_torch_xla_available():
42
+ import torch_xla.core.xla_model as xm
43
+
44
+ XLA_AVAILABLE = True
45
+ else:
46
+ XLA_AVAILABLE = False
47
+
48
+
49
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
50
+
51
+ EXAMPLE_DOC_STRING = """
52
+ Examples:
53
+ ```py
54
+ >>> import torch
55
+ >>> from diffusers import FluxInpaintPipeline
56
+ >>> from diffusers.utils import load_image
57
+
58
+ >>> pipe = FluxInpaintPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
59
+ >>> pipe.to("cuda")
60
+ >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
61
+ >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
62
+ >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
63
+ >>> source = load_image(img_url)
64
+ >>> mask = load_image(mask_url)
65
+ >>> image = pipe(prompt=prompt, image=source, mask_image=mask).images[0]
66
+ >>> image.save("flux_inpainting.png")
67
+ ```
68
+ """
69
+
70
+
71
+ # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
72
+ def calculate_shift(
73
+ image_seq_len,
74
+ base_seq_len: int = 256,
75
+ max_seq_len: int = 4096,
76
+ base_shift: float = 0.5,
77
+ max_shift: float = 1.16,
78
+ ):
79
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
80
+ b = base_shift - m * base_seq_len
81
+ mu = image_seq_len * m + b
82
+ return mu
83
+
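+ # calculate_shift linearly interpolates the scheduler's shift parameter mu between base_shift
+ # (at 256 image tokens) and max_shift (at 4096 tokens). For a 1024x1024 image there are
+ # (1024 // 16) ** 2 = 4096 tokens, so mu = 1.16; for 512x512 there are 1024 tokens and
+ # mu = 0.5 + (1.16 - 0.5) * (1024 - 256) / (4096 - 256), i.e. about 0.63.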
84
+
85
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
86
+ def retrieve_latents(
87
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
88
+ ):
89
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
90
+ return encoder_output.latent_dist.sample(generator)
91
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
92
+ return encoder_output.latent_dist.mode()
93
+ elif hasattr(encoder_output, "latents"):
94
+ return encoder_output.latents
95
+ else:
96
+ raise AttributeError("Could not access latents of provided encoder_output")
97
+
98
+
99
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
100
+ def retrieve_timesteps(
101
+ scheduler,
102
+ num_inference_steps: Optional[int] = None,
103
+ device: Optional[Union[str, torch.device]] = None,
104
+ timesteps: Optional[List[int]] = None,
105
+ sigmas: Optional[List[float]] = None,
106
+ **kwargs,
107
+ ):
108
+ """
109
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
110
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
111
+
112
+ Args:
113
+ scheduler (`SchedulerMixin`):
114
+ The scheduler to get timesteps from.
115
+ num_inference_steps (`int`):
116
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
117
+ must be `None`.
118
+ device (`str` or `torch.device`, *optional*):
119
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
120
+ timesteps (`List[int]`, *optional*):
121
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
122
+ `num_inference_steps` and `sigmas` must be `None`.
123
+ sigmas (`List[float]`, *optional*):
124
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
125
+ `num_inference_steps` and `timesteps` must be `None`.
126
+
127
+ Returns:
128
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
129
+ second element is the number of inference steps.
130
+ """
131
+ if timesteps is not None and sigmas is not None:
132
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
133
+ if timesteps is not None:
134
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
135
+ if not accepts_timesteps:
136
+ raise ValueError(
137
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
138
+ f" timestep schedules. Please check whether you are using the correct scheduler."
139
+ )
140
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
141
+ timesteps = scheduler.timesteps
142
+ num_inference_steps = len(timesteps)
143
+ elif sigmas is not None:
144
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
145
+ if not accept_sigmas:
146
+ raise ValueError(
147
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
148
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
149
+ )
150
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
151
+ timesteps = scheduler.timesteps
152
+ num_inference_steps = len(timesteps)
153
+ else:
154
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
155
+ timesteps = scheduler.timesteps
156
+ return timesteps, num_inference_steps
157
+
158
+
159
+ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
160
+ r"""
161
+ The Flux pipeline for image inpainting.
162
+
163
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
164
+
165
+ Args:
166
+ transformer ([`FluxTransformer2DModel`]):
167
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
168
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
169
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
170
+ vae ([`AutoencoderKL`]):
171
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
172
+ text_encoder ([`CLIPTextModel`]):
173
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
174
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
175
+ text_encoder_2 ([`T5EncoderModel`]):
176
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
177
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
178
+ tokenizer (`CLIPTokenizer`):
179
+ Tokenizer of class
180
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
181
+ tokenizer_2 (`T5TokenizerFast`):
182
+ Second Tokenizer of class
183
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
184
+ """
185
+
186
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
187
+ _optional_components = []
188
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
189
+
190
+ def __init__(
191
+ self,
192
+ scheduler: FlowMatchEulerDiscreteScheduler,
193
+ vae: AutoencoderKL,
194
+ text_encoder: CLIPTextModel,
195
+ tokenizer: CLIPTokenizer,
196
+ transformer: FluxTransformer2DModel,
197
+ text_encoder_2: T5EncoderModel | None = None,
198
+ tokenizer_2: T5TokenizerFast | None = None,
199
+ ):
200
+ super().__init__()
201
+
202
+ self.register_modules(
203
+ vae=vae,
204
+ text_encoder=text_encoder,
205
+ #text_encoder_2=text_encoder_2,
206
+ tokenizer=tokenizer,
207
+ #tokenizer_2=tokenizer_2,
208
+ transformer=transformer,
209
+ scheduler=scheduler,
210
+ )
211
+ self.vae_scale_factor = (
212
+ 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
213
+ )
214
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
215
+ self.mask_processor = VaeImageProcessor(
216
+ vae_scale_factor=self.vae_scale_factor,
217
+ vae_latent_channels=self.vae.config.latent_channels,
218
+ do_normalize=False,
219
+ do_binarize=True,
220
+ do_convert_grayscale=True,
221
+ )
222
+ self.tokenizer_max_length = (
223
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
224
+ )
225
+ self.default_sample_size = 64
226
+
227
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
228
+ def _get_t5_prompt_embeds(
229
+ self,
230
+ prompt: Union[str, List[str]] = None,
231
+ num_images_per_prompt: int = 1,
232
+ max_sequence_length: int = 512,
233
+ device: Optional[torch.device] = None,
234
+ dtype: Optional[torch.dtype] = None,
235
+ ):
236
+ device = device or self._execution_device
237
+ dtype = dtype or self.text_encoder.dtype
238
+
239
+ prompt = [prompt] if isinstance(prompt, str) else prompt
240
+ batch_size = len(prompt)
241
+
242
+ text_inputs = self.tokenizer_2(
243
+ prompt,
244
+ padding="max_length",
245
+ max_length=max_sequence_length,
246
+ truncation=True,
247
+ return_length=False,
248
+ return_overflowing_tokens=False,
249
+ return_tensors="pt",
250
+ )
251
+ text_input_ids = text_inputs.input_ids
252
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
253
+
254
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
255
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
256
+ logger.warning(
257
+ "The following part of your input was truncated because `max_sequence_length` is set to "
258
+ f" {max_sequence_length} tokens: {removed_text}"
259
+ )
260
+
261
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
262
+
263
+ dtype = self.text_encoder_2.dtype
264
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
265
+
266
+ _, seq_len, _ = prompt_embeds.shape
267
+
268
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
269
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
270
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
271
+
272
+ return prompt_embeds
273
+
274
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
275
+ def _get_clip_prompt_embeds(
276
+ self,
277
+ prompt: Union[str, List[str]],
278
+ num_images_per_prompt: int = 1,
279
+ device: Optional[torch.device] = None,
280
+ ):
281
+ device = device or self._execution_device
282
+
283
+ prompt = [prompt] if isinstance(prompt, str) else prompt
284
+ batch_size = len(prompt)
285
+
286
+ text_inputs = self.tokenizer(
287
+ prompt,
288
+ padding="max_length",
289
+ max_length=self.tokenizer_max_length,
290
+ truncation=True,
291
+ return_overflowing_tokens=False,
292
+ return_length=False,
293
+ return_tensors="pt",
294
+ )
295
+
296
+ text_input_ids = text_inputs.input_ids
297
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
298
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
299
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
300
+ logger.warning(
301
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
302
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
303
+ )
304
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
305
+
306
+ # Use pooled output of CLIPTextModel
307
+ prompt_embeds = prompt_embeds.pooler_output
308
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
309
+
310
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
311
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
312
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
313
+
314
+ return prompt_embeds
315
+
316
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
317
+ def encode_prompt(
318
+ self,
319
+ prompt: Union[str, List[str]],
320
+ prompt_2: Union[str, List[str]],
321
+ device: Optional[torch.device] = None,
322
+ num_images_per_prompt: int = 1,
323
+ prompt_embeds: Optional[torch.FloatTensor] = None,
324
+ t5_prompt_embeds: Optional[torch.FloatTensor] = None,
325
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
326
+ max_sequence_length: int = 512,
327
+ lora_scale: Optional[float] = None,
328
+ ):
329
+ r"""
330
+
331
+ Args:
332
+ prompt (`str` or `List[str]`, *optional*):
333
+ prompt to be encoded
334
+ prompt_2 (`str` or `List[str]`, *optional*):
335
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
336
+ used in all text-encoders
337
+ device: (`torch.device`):
338
+ torch device
339
+ num_images_per_prompt (`int`):
340
+ number of images that should be generated per prompt
341
+ prompt_embeds (`torch.FloatTensor`, *optional*):
342
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
343
+ provided, text embeddings will be generated from `prompt` input argument.
344
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
345
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
346
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
347
+ lora_scale (`float`, *optional*):
348
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
349
+ """
350
+ device = device or self._execution_device
351
+
352
+ # set lora scale so that monkey patched LoRA
353
+ # function of text encoder can correctly access it
354
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
355
+ self._lora_scale = lora_scale
356
+
357
+ # dynamically adjust the LoRA scale
358
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
359
+ scale_lora_layers(self.text_encoder, lora_scale)
360
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
361
+ scale_lora_layers(self.text_encoder_2, lora_scale)
362
+
363
+ prompt = [prompt] if isinstance(prompt, str) else prompt
364
+ if prompt is not None:
365
+ batch_size = len(prompt)
366
+ else:
367
+ batch_size = prompt_embeds.shape[0]
368
+
369
+ if prompt_embeds is None:
370
+ prompt_2 = prompt_2 or prompt
371
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
372
+
373
+ # We only use the pooled prompt output from the CLIPTextModel
374
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
375
+ prompt=prompt,
376
+ device=device,
377
+ num_images_per_prompt=num_images_per_prompt,
378
+ )
379
+ prompt_embeds = self._get_t5_prompt_embeds(
380
+ prompt=prompt_2,
381
+ num_images_per_prompt=num_images_per_prompt,
382
+ max_sequence_length=max_sequence_length,
383
+ device=device,
384
+ )
385
+
386
+ if self.text_encoder is not None:
387
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
388
+ # Retrieve the original scale by scaling back the LoRA layers
389
+ unscale_lora_layers(self.text_encoder, lora_scale)
390
+
391
+ #if self.text_encoder_2 is not None:
392
+ # if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
393
+ # # Retrieve the original scale by scaling back the LoRA layers
394
+ # unscale_lora_layers(self.text_encoder_2, lora_scale)
395
+
396
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
397
+ if t5_prompt_embeds is not None:
398
+ text_ids = torch.zeros(batch_size, prompt_embeds.shape[1] + t5_prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
399
+ else:
400
+ text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
401
+ #text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
402
+
403
+ return prompt_embeds, pooled_prompt_embeds, text_ids
404
+
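+ # text_ids are all zeros: text tokens carry no spatial position (only image tokens do); the sequence
+ # length covers prompt_embeds plus the optional t5_prompt_embeds so the rotary position embeddings
+ # stay aligned with the concatenated text stream.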
405
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
406
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
407
+ if isinstance(generator, list):
408
+ image_latents = [
409
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
410
+ for i in range(image.shape[0])
411
+ ]
412
+ image_latents = torch.cat(image_latents, dim=0)
413
+ else:
414
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
415
+
416
+ image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
417
+
418
+ return image_latents
419
+
420
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
421
+ def get_timesteps(self, num_inference_steps, strength, device):
422
+ # get the original timestep using init_timestep
423
+ init_timestep = min(num_inference_steps * strength, num_inference_steps)
424
+
425
+ t_start = int(max(num_inference_steps - init_timestep, 0))
426
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
427
+ if hasattr(self.scheduler, "set_begin_index"):
428
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
429
+
430
+ return timesteps, num_inference_steps - t_start
431
+
432
+ def check_inputs(
433
+ self,
434
+ prompt,
435
+ prompt_2,
436
+ image,
437
+ mask_image,
438
+ strength,
439
+ height,
440
+ width,
441
+ output_type,
442
+ prompt_embeds=None,
443
+ pooled_prompt_embeds=None,
444
+ callback_on_step_end_tensor_inputs=None,
445
+ padding_mask_crop=None,
446
+ max_sequence_length=None,
447
+ ):
448
+ if strength < 0 or strength > 1:
449
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
450
+
451
+ if height % 8 != 0 or width % 8 != 0:
452
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
453
+
454
+ if callback_on_step_end_tensor_inputs is not None and not all(
455
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
456
+ ):
457
+ raise ValueError(
458
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
459
+ )
460
+
461
+ if prompt is not None and prompt_embeds is not None:
462
+ raise ValueError(
463
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
464
+ " only forward one of the two."
465
+ )
466
+ elif prompt_2 is not None and prompt_embeds is not None:
467
+ raise ValueError(
468
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
469
+ " only forward one of the two."
470
+ )
471
+ elif prompt is None and prompt_embeds is None:
472
+ raise ValueError(
473
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
474
+ )
475
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
476
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
477
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
478
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
479
+
480
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
481
+ raise ValueError(
482
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
483
+ )
484
+
485
+ if padding_mask_crop is not None:
486
+ if not isinstance(image, PIL.Image.Image):
487
+ raise ValueError(
488
+ f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
489
+ )
490
+ if not isinstance(mask_image, PIL.Image.Image):
491
+ raise ValueError(
492
+ f"The mask image should be a PIL image when inpainting mask crop, but is of type"
493
+ f" {type(mask_image)}."
494
+ )
495
+ if output_type != "pil":
496
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
497
+
498
+ if max_sequence_length is not None and max_sequence_length > 512:
499
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
500
+
501
+ @staticmethod
502
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
503
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
504
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
505
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
506
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
507
+
508
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
509
+
510
+ latent_image_ids = latent_image_ids.reshape(
511
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
512
+ )
513
+
514
+ return latent_image_ids.to(device=device, dtype=dtype)
515
+
516
+ @staticmethod
517
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
518
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
519
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
520
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
521
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
522
+
523
+ return latents
524
+
525
+ @staticmethod
526
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
527
+ def _unpack_latents(latents, height, width, vae_scale_factor):
528
+ batch_size, num_patches, channels = latents.shape
529
+
530
+ height = height // vae_scale_factor
531
+ width = width // vae_scale_factor
532
+
533
+ latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
534
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
535
+
536
+ latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
537
+
538
+ return latents
539
+
540
+ def prepare_latents(
541
+ self,
542
+ image,
543
+ timestep,
544
+ batch_size,
545
+ num_channels_latents,
546
+ height,
547
+ width,
548
+ dtype,
549
+ device,
550
+ generator,
551
+ latents=None,
552
+ ):
553
+ if isinstance(generator, list) and len(generator) != batch_size:
554
+ raise ValueError(
555
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
556
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
557
+ )
558
+
559
+ height = 2 * (int(height) // self.vae_scale_factor)
560
+ width = 2 * (int(width) // self.vae_scale_factor)
561
+
562
+ shape = (batch_size, num_channels_latents, height, width)
563
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
564
+
565
+ image = image.to(device=device, dtype=dtype)
566
+ image_latents = self._encode_vae_image(image=image, generator=generator)
567
+
568
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
569
+ # expand init_latents for batch_size
570
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
571
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
572
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
573
+ raise ValueError(
574
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
575
+ )
576
+ else:
577
+ image_latents = torch.cat([image_latents], dim=0)
578
+
579
+ if latents is None:
580
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
581
+ latents = self.scheduler.scale_noise(image_latents, timestep, noise)
582
+ else:
583
+ noise = latents.to(device)
584
+ latents = noise
585
+
586
+ noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width)
587
+ image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width)
588
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
589
+ return latents, noise, image_latents, latent_image_ids
590
+
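+ # Unlike the image-to-image variant, this also returns the packed noise and the clean packed image
+ # latents so the denoising loop can keep the unmasked region anchored to the original image.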
591
+ def prepare_mask_latents(
592
+ self,
593
+ mask,
594
+ masked_image,
595
+ batch_size,
596
+ num_channels_latents,
597
+ num_images_per_prompt,
598
+ height,
599
+ width,
600
+ dtype,
601
+ device,
602
+ generator,
603
+ ):
604
+ height = 2 * (int(height) // self.vae_scale_factor)
605
+ width = 2 * (int(width) // self.vae_scale_factor)
606
+ # resize the mask to latents shape as we concatenate the mask to the latents
607
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
608
+ # and half precision
609
+ mask = torch.nn.functional.interpolate(mask, size=(height, width))
610
+ mask = mask.to(device=device, dtype=dtype)
611
+
612
+ batch_size = batch_size * num_images_per_prompt
613
+
614
+ masked_image = masked_image.to(device=device, dtype=dtype)
615
+
616
+ if masked_image.shape[1] == 16:
617
+ masked_image_latents = masked_image
618
+ else:
619
+ masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)
620
+
621
+ masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
622
+
623
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
624
+ if mask.shape[0] < batch_size:
625
+ if not batch_size % mask.shape[0] == 0:
626
+ raise ValueError(
627
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
628
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
629
+ " of masks that you pass is divisible by the total requested batch size."
630
+ )
631
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
632
+ if masked_image_latents.shape[0] < batch_size:
633
+ if not batch_size % masked_image_latents.shape[0] == 0:
634
+ raise ValueError(
635
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
636
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
637
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
638
+ )
639
+ masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
640
+
641
+ # aligning device to prevent device errors when concating it with the latent model input
642
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
643
+
644
+ masked_image_latents = self._pack_latents(
645
+ masked_image_latents,
646
+ batch_size,
647
+ num_channels_latents,
648
+ height,
649
+ width,
650
+ )
651
+ mask = self._pack_latents(
652
+ mask.repeat(1, num_channels_latents, 1, 1),
653
+ batch_size,
654
+ num_channels_latents,
655
+ height,
656
+ width,
657
+ )
658
+
659
+ return mask, masked_image_latents
660
+
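+ # Shape example: for a 1024x1024 input the single-channel mask is resized to (B, 1, 128, 128),
+ # repeated to 16 channels, and packed to (B, 4096, 64) so it aligns token-for-token with the packed
+ # image latents; masked_image_latents goes through the same packing.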
661
+ @property
662
+ def guidance_scale(self):
663
+ return self._guidance_scale
664
+
665
+ @property
666
+ def joint_attention_kwargs(self):
667
+ return self._joint_attention_kwargs
668
+
669
+ @property
670
+ def num_timesteps(self):
671
+ return self._num_timesteps
672
+
673
+ @property
674
+ def interrupt(self):
675
+ return self._interrupt
676
+
677
+ @torch.no_grad()
678
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
679
+ def __call__(
680
+ self,
681
+ prompt: Union[str, List[str]] = None,
682
+ prompt_2: Optional[Union[str, List[str]]] = None,
683
+ image: PipelineImageInput = None,
684
+ mask_image: PipelineImageInput = None,
685
+ masked_image_latents: PipelineImageInput = None,
686
+ height: Optional[int] = None,
687
+ width: Optional[int] = None,
688
+ padding_mask_crop: Optional[int] = None,
689
+ strength: float = 0.6,
690
+ num_inference_steps: int = 28,
691
+ timesteps: List[int] = None,
692
+ guidance_scale: float = 7.0,
693
+ num_images_per_prompt: Optional[int] = 1,
694
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
695
+ latents: Optional[torch.FloatTensor] = None,
696
+ prompt_embeds: Optional[torch.FloatTensor] = None,
697
+ t5_prompt_embeds: Optional[torch.FloatTensor] = None,
698
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
699
+ output_type: Optional[str] = "pil",
700
+ return_dict: bool = True,
701
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
702
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
703
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
704
+ max_sequence_length: int = 512,
705
+ ):
706
+ r"""
707
+ Function invoked when calling the pipeline for generation.
708
+
709
+ Args:
710
+ prompt (`str` or `List[str]`, *optional*):
711
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
712
+ instead.
713
+ prompt_2 (`str` or `List[str]`, *optional*):
714
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
715
+ will be used instead.
716
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
717
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
718
+ numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a list
719
+ of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
720
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
721
+ latents as `image`, but if passing latents directly it is not encoded again.
722
+ mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
723
+ `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask
724
+ are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a
725
+ single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one
726
+ color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B,
727
+ H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W,
728
+ 1)`, or `(H, W)`.
729
+ masked_image_latents (`torch.Tensor`, `List[torch.Tensor]`):
730
+ `Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask
731
+ latents tensor will be generated from `mask_image`.
732
+ height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
733
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
734
+ width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
735
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
736
+ padding_mask_crop (`int`, *optional*, defaults to `None`):
737
+ The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to
738
+ image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region
739
+ with the same aspect ratio as the image that contains all the masked area, and then expand that area based
740
+ on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before
741
+ resizing to the original image size for inpainting. This is useful when the masked area is small while
742
+ the image is large and contains information irrelevant to inpainting, such as the background.
743
+ strength (`float`, *optional*, defaults to 0.6):
744
+ Indicates the extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
745
+ starting point and more noise is added the higher the `strength`. The number of denoising steps depends
746
+ on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
747
+ process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
748
+ essentially ignores `image`.
749
+ num_inference_steps (`int`, *optional*, defaults to 28):
750
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
751
+ expense of slower inference.
752
+ timesteps (`List[int]`, *optional*):
753
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
754
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
755
+ passed will be used. Must be in descending order.
756
+ guidance_scale (`float`, *optional*, defaults to 7.0):
757
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
758
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
759
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
760
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
761
+ usually at the expense of lower image quality.
762
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
763
+ The number of images to generate per prompt.
764
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
765
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
766
+ to make generation deterministic.
767
+ latents (`torch.FloatTensor`, *optional*):
768
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
769
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
770
+ tensor will be generated by sampling using the supplied random `generator`.
771
+ prompt_embeds (`torch.FloatTensor`, *optional*):
772
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
773
+ provided, text embeddings will be generated from `prompt` input argument.
774
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
775
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
776
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
777
+ output_type (`str`, *optional*, defaults to `"pil"`):
778
+ The output format of the generated image. Choose between
779
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
780
+ return_dict (`bool`, *optional*, defaults to `True`):
781
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
782
+ joint_attention_kwargs (`dict`, *optional*):
783
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
784
+ `self.processor` in
785
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
786
+ callback_on_step_end (`Callable`, *optional*):
787
+ A function that is called at the end of each denoising step during inference. The function is called
788
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
789
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
790
+ `callback_on_step_end_tensor_inputs`.
791
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
792
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
793
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
794
+ `._callback_tensor_inputs` attribute of your pipeline class.
795
+ max_sequence_length (`int`, *optional*, defaults to 512): Maximum sequence length to use with the `prompt`.
796
+
797
+ Examples:
798
+
799
+ Returns:
800
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
801
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
802
+ images.
803
+ """
804
+
805
+ height = height or self.default_sample_size * self.vae_scale_factor
806
+ width = width or self.default_sample_size * self.vae_scale_factor
807
+
808
+ # 1. Check inputs. Raise error if not correct
809
+ self.check_inputs(
810
+ prompt,
811
+ prompt_2,
812
+ image,
813
+ mask_image,
814
+ strength,
815
+ height,
816
+ width,
817
+ output_type=output_type,
818
+ prompt_embeds=prompt_embeds,
819
+ pooled_prompt_embeds=pooled_prompt_embeds,
820
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
821
+ padding_mask_crop=padding_mask_crop,
822
+ max_sequence_length=max_sequence_length,
823
+ )
824
+
825
+ self._guidance_scale = guidance_scale
826
+ self._joint_attention_kwargs = joint_attention_kwargs
827
+ self._interrupt = False
828
+
829
+ # 2. Preprocess mask and image
830
+ if padding_mask_crop is not None:
831
+ crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
832
+ resize_mode = "fill"
833
+ else:
834
+ crops_coords = None
835
+ resize_mode = "default"
836
+
837
+ original_image = image
838
+ init_image = self.image_processor.preprocess(
839
+ image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
840
+ )
841
+ init_image = init_image.to(dtype=torch.float32)
842
+
843
+ # 3. Define call parameters
844
+ if prompt is not None and isinstance(prompt, str):
845
+ batch_size = 1
846
+ elif prompt is not None and isinstance(prompt, list):
847
+ batch_size = len(prompt)
848
+ else:
849
+ batch_size = prompt_embeds.shape[0]
850
+
851
+ device = self._execution_device
852
+
853
+ lora_scale = (
854
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
855
+ )
856
+ (
857
+ prompt_embeds,
858
+ pooled_prompt_embeds,
859
+ text_ids,
860
+ ) = self.encode_prompt(
861
+ prompt=prompt,
862
+ prompt_2=prompt_2,
863
+ prompt_embeds=prompt_embeds,
864
+ t5_prompt_embeds=t5_prompt_embeds,
865
+ pooled_prompt_embeds=pooled_prompt_embeds,
866
+ device=device,
867
+ num_images_per_prompt=num_images_per_prompt,
868
+ max_sequence_length=max_sequence_length,
869
+ lora_scale=lora_scale,
870
+ )
871
+
872
+ # 4. Prepare timesteps
873
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
874
+ image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor)
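+ # `image_seq_len` is the number of packed latent tokens at this resolution; `mu` (computed below) is the
+ # dynamic-shift parameter for the flow-matching scheduler, which shifts the sigma schedule toward higher
+ # noise levels for larger images.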
875
+ mu = calculate_shift(
876
+ image_seq_len,
877
+ self.scheduler.config.base_image_seq_len,
878
+ self.scheduler.config.max_image_seq_len,
879
+ self.scheduler.config.base_shift,
880
+ self.scheduler.config.max_shift,
881
+ )
882
+ timesteps, num_inference_steps = retrieve_timesteps(
883
+ self.scheduler,
884
+ num_inference_steps,
885
+ device,
886
+ timesteps,
887
+ sigmas,
888
+ mu=mu,
889
+ )
890
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
891
+
892
+ if num_inference_steps < 1:
893
+ raise ValueError(
894
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
895
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
896
+ )
897
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
898
+
899
+ # 5. Prepare latent variables
900
+ num_channels_latents = self.transformer.config.in_channels // 4
901
+ num_channels_transformer = self.transformer.config.in_channels
902
+
903
+ latents, noise, image_latents, latent_image_ids = self.prepare_latents(
904
+ init_image,
905
+ latent_timestep,
906
+ batch_size * num_images_per_prompt,
907
+ num_channels_latents,
908
+ height,
909
+ width,
910
+ prompt_embeds.dtype,
911
+ device,
912
+ generator,
913
+ latents,
914
+ )
915
+
916
+ mask_condition = self.mask_processor.preprocess(
917
+ mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
918
+ )
919
+
920
+ if masked_image_latents is None:
921
+ masked_image = init_image * (mask_condition < 0.5)
922
+ else:
923
+ masked_image = masked_image_latents
924
+
925
+ mask, masked_image_latents = self.prepare_mask_latents(
926
+ mask_condition,
927
+ masked_image,
928
+ batch_size,
929
+ num_channels_latents,
930
+ num_images_per_prompt,
931
+ height,
932
+ width,
933
+ prompt_embeds.dtype,
934
+ device,
935
+ generator,
936
+ )
937
+
938
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
939
+ self._num_timesteps = len(timesteps)
940
+
941
+ # handle guidance
942
+ if self.transformer.config.guidance_embeds:
943
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
944
+ guidance = guidance.expand(latents.shape[0])
945
+ else:
946
+ guidance = None
947
+
948
+ # 6. Denoising loop
949
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
950
+ for i, t in enumerate(timesteps):
951
+ if self.interrupt:
952
+ continue
953
+
954
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
955
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
956
+ noise_pred = self.transformer(
957
+ hidden_states=latents,
958
+ timestep=timestep / 1000,
959
+ guidance=guidance,
960
+ pooled_projections=pooled_prompt_embeds,
961
+ encoder_hidden_states=prompt_embeds,
962
+ t5_encoder_hidden_states=t5_prompt_embeds,
963
+ txt_ids=text_ids,
964
+ img_ids=latent_image_ids,
965
+ joint_attention_kwargs=self.joint_attention_kwargs,
966
+ return_dict=False,
967
+ )[0]
968
+
969
+ # compute the previous noisy sample x_t -> x_t-1
970
+ latents_dtype = latents.dtype
971
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
972
+
973
+ # for 64 channel transformer only.
974
+ init_latents_proper = image_latents
975
+ init_mask = mask
976
+
977
+ if i < len(timesteps) - 1:
978
+ noise_timestep = timesteps[i + 1]
979
+ init_latents_proper = self.scheduler.scale_noise(
980
+ init_latents_proper, torch.tensor([noise_timestep]), noise
981
+ )
982
+
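+ # Blend: outside the mask (init_mask == 0) keep the re-noised original latents; inside the mask
+ # (init_mask == 1) keep the freshly denoised latents, so only the masked region is repainted.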
983
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
984
+
985
+ if latents.dtype != latents_dtype:
986
+ if torch.backends.mps.is_available():
987
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
988
+ latents = latents.to(latents_dtype)
989
+
990
+ if callback_on_step_end is not None:
991
+ callback_kwargs = {}
992
+ for k in callback_on_step_end_tensor_inputs:
993
+ callback_kwargs[k] = locals()[k]
994
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
995
+
996
+ latents = callback_outputs.pop("latents", latents)
997
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
998
+
999
+ # update the progress bar
1000
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1001
+ progress_bar.update()
1002
+
1003
+ if XLA_AVAILABLE:
1004
+ xm.mark_step()
1005
+
1006
+ if output_type == "latent":
1007
+ image = latents
1008
+
1009
+ else:
1010
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
1011
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
1012
+ image = self.vae.decode(latents, return_dict=False)[0]
1013
+ image = self.image_processor.postprocess(image, output_type=output_type)
1014
+
1015
+ # Offload all models
1016
+ self.maybe_free_model_hooks()
1017
+
1018
+ if not return_dict:
1019
+ return (image,)
1020
+
1021
+ return FluxPipelineOutput(images=image)
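A minimal usage sketch for the inpainting call documented above. `pipe` is assumed to be an already-assembled instance of this inpainting pipeline (in this repo it is built inside `model.py`); file names and parameter values are illustrative only.

    from PIL import Image

    init_image = Image.open("input.png").convert("RGB")
    mask = Image.open("mask.png").convert("L")  # white = region to repaint

    result = pipe(
        prompt="a wooden table with a vase of tulips",
        image=init_image,
        mask_image=mask,
        height=1024,
        width=1024,
        strength=0.85,          # 1.0 would ignore the input image entirely
        num_inference_steps=28,
        guidance_scale=7.0,
        padding_mask_crop=32,   # crop around the mask before resizing, then paste back
    )
    result.images[0].save("inpainted.png")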
flux/pipeline_output.py ADDED
@@ -0,0 +1,21 @@
1
+ from dataclasses import dataclass
2
+ from typing import List, Union
3
+
4
+ import numpy as np
5
+ import PIL.Image
6
+
7
+ from diffusers.utils import BaseOutput
8
+
9
+
10
+ @dataclass
11
+ class FluxPipelineOutput(BaseOutput):
12
+ """
13
+ Output class for Flux pipelines.
14
+
15
+ Args:
16
+ images (`List[PIL.Image.Image]` or `np.ndarray`):
17
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
18
+ num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
19
+ """
20
+
21
+ images: Union[List[PIL.Image.Image], np.ndarray]
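Because `FluxPipelineOutput` is a `BaseOutput` dataclass, results can be read by attribute or, with `return_dict=False`, unpacked as a plain tuple; a short sketch (the `pipe` object is assumed to exist):

    out = pipe(prompt="a photo of a cat")          # return_dict=True (default)
    first = out.images[0]                          # attribute access on the dataclass

    (images,) = pipe(prompt="a photo of a cat", return_dict=False)  # plain tuple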
flux/scheduling_flow_match_euler_discrete.py ADDED
@@ -0,0 +1,325 @@
1
+ # Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from dataclasses import dataclass
17
+ from typing import List, Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+ import torch
21
+
22
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
23
+ from diffusers.utils import BaseOutput, logging
24
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
25
+
26
+
27
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
28
+
29
+
30
+ @dataclass
31
+ class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
32
+ """
33
+ Output class for the scheduler's `step` function output.
34
+
35
+ Args:
36
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
37
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
38
+ denoising loop.
39
+ """
40
+
41
+ prev_sample: torch.FloatTensor
42
+
43
+
44
+ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
45
+ """
46
+ Euler scheduler.
47
+
48
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
49
+ methods the library implements for all schedulers such as loading and saving.
50
+
51
+ Args:
52
+ num_train_timesteps (`int`, defaults to 1000):
53
+ The number of diffusion steps to train the model.
54
+ timestep_spacing (`str`, defaults to `"linspace"`):
55
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
56
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
57
+ shift (`float`, defaults to 1.0):
58
+ The shift value for the timestep schedule.
59
+ """
60
+
61
+ _compatibles = []
62
+ order = 1
63
+
64
+ @register_to_config
65
+ def __init__(
66
+ self,
67
+ num_train_timesteps: int = 1000,
68
+ shift: float = 1.0,
69
+ use_dynamic_shifting=False,
70
+ base_shift: Optional[float] = 0.5,
71
+ max_shift: Optional[float] = 1.15,
72
+ base_image_seq_len: Optional[int] = 256,
73
+ max_image_seq_len: Optional[int] = 4096,
74
+ ):
75
+ timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
76
+ timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
77
+
78
+ sigmas = timesteps / num_train_timesteps
79
+ if not use_dynamic_shifting:
80
+ # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
81
+ sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
82
+
83
+ self.timesteps = sigmas * num_train_timesteps
84
+
85
+ self._step_index = None
86
+ self._begin_index = None
87
+
88
+ self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
89
+ self.sigma_min = self.sigmas[-1].item()
90
+ self.sigma_max = self.sigmas[0].item()
91
+
92
+ @property
93
+ def step_index(self):
94
+ """
95
+ The index counter for the current timestep. It will increase by 1 after each scheduler step.
96
+ """
97
+ return self._step_index
98
+
99
+ @property
100
+ def begin_index(self):
101
+ """
102
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
103
+ """
104
+ return self._begin_index
105
+
106
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
107
+ def set_begin_index(self, begin_index: int = 0):
108
+ """
109
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
110
+
111
+ Args:
112
+ begin_index (`int`):
113
+ The begin index for the scheduler.
114
+ """
115
+ self._begin_index = begin_index
116
+
117
+ def scale_noise(
118
+ self,
119
+ sample: torch.FloatTensor,
120
+ timestep: Union[float, torch.FloatTensor],
121
+ noise: Optional[torch.FloatTensor] = None,
122
+ ) -> torch.FloatTensor:
123
+ """
124
+ Forward process in flow-matching
125
+
126
+ Args:
127
+ sample (`torch.FloatTensor`):
128
+ The input sample.
129
+ timestep (`float` or `torch.FloatTensor`):
130
+ The current timestep in the diffusion chain.
131
+
132
+ Returns:
133
+ `torch.FloatTensor`:
134
+ A scaled input sample.
135
+ """
136
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
137
+ sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype)
138
+
139
+ if sample.device.type == "mps" and torch.is_floating_point(timestep):
140
+ # mps does not support float64
141
+ schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32)
142
+ timestep = timestep.to(sample.device, dtype=torch.float32)
143
+ else:
144
+ schedule_timesteps = self.timesteps.to(sample.device)
145
+ timestep = timestep.to(sample.device)
146
+
147
+ # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
148
+ if self.begin_index is None:
149
+ step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep]
150
+ elif self.step_index is not None:
151
+ # add_noise is called after first denoising step (for inpainting)
152
+ step_indices = [self.step_index] * timestep.shape[0]
153
+ else:
154
+ # add noise is called before the first denoising step to create the initial latent (img2img)
155
+ step_indices = [self.begin_index] * timestep.shape[0]
156
+
157
+ sigma = sigmas[step_indices].flatten()
158
+ while len(sigma.shape) < len(sample.shape):
159
+ sigma = sigma.unsqueeze(-1)
160
+
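+ # Flow-matching forward process: linearly interpolate between the clean sample and noise,
+ # x_sigma = sigma * noise + (1 - sigma) * x_0, so sigma = 0 is the clean sample and sigma = 1 is pure noise.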
161
+ sample = sigma * noise + (1.0 - sigma) * sample
162
+
163
+ return sample
164
+
165
+ def _sigma_to_t(self, sigma):
166
+ return sigma * self.config.num_train_timesteps
167
+
168
+ def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
169
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
170
+
171
+ def set_timesteps(
172
+ self,
173
+ num_inference_steps: int = None,
174
+ device: Union[str, torch.device] = None,
175
+ sigmas: Optional[List[float]] = None,
176
+ mu: Optional[float] = None,
177
+ ):
178
+ """
179
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
180
+
181
+ Args:
182
+ num_inference_steps (`int`):
183
+ The number of diffusion steps used when generating samples with a pre-trained model.
184
+ device (`str` or `torch.device`, *optional*):
185
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
186
+ """
187
+
188
+ if self.config.use_dynamic_shifting and mu is None:
189
+ raise ValueError("You have to pass a value for `mu` when `use_dynamic_shifting` is set to `True`.")
190
+
191
+ if sigmas is None:
192
+ self.num_inference_steps = num_inference_steps
193
+ timesteps = np.linspace(
194
+ self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
195
+ )
196
+
197
+ sigmas = timesteps / self.config.num_train_timesteps
198
+
199
+ if self.config.use_dynamic_shifting:
200
+ sigmas = self.time_shift(mu, 1.0, sigmas)
201
+ else:
202
+ sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
203
+
204
+ sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
205
+ timesteps = sigmas * self.config.num_train_timesteps
206
+
207
+ self.timesteps = timesteps.to(device=device)
208
+ self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
209
+
210
+ self._step_index = None
211
+ self._begin_index = None
212
+
213
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
214
+ if schedule_timesteps is None:
215
+ schedule_timesteps = self.timesteps
216
+
217
+ indices = (schedule_timesteps == timestep).nonzero()
218
+
219
+ # The sigma index that is taken for the **very** first `step`
220
+ # is always the second index (or the last index if there is only 1)
221
+ # This way we can ensure we don't accidentally skip a sigma in
222
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
223
+ pos = 1 if len(indices) > 1 else 0
224
+
225
+ return indices[pos].item()
226
+
227
+ def _init_step_index(self, timestep):
228
+ if self.begin_index is None:
229
+ if isinstance(timestep, torch.Tensor):
230
+ timestep = timestep.to(self.timesteps.device)
231
+ self._step_index = self.index_for_timestep(timestep)
232
+ else:
233
+ self._step_index = self._begin_index
234
+
235
+ def step(
236
+ self,
237
+ model_output: torch.FloatTensor,
238
+ timestep: Union[float, torch.FloatTensor],
239
+ sample: torch.FloatTensor,
240
+ s_churn: float = 0.0,
241
+ s_tmin: float = 0.0,
242
+ s_tmax: float = float("inf"),
243
+ s_noise: float = 1.0,
244
+ generator: Optional[torch.Generator] = None,
245
+ return_dict: bool = True,
246
+ ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
247
+ """
248
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
249
+ process from the learned model outputs (most often the predicted noise).
250
+
251
+ Args:
252
+ model_output (`torch.FloatTensor`):
253
+ The direct output from learned diffusion model.
254
+ timestep (`float`):
255
+ The current discrete timestep in the diffusion chain.
256
+ sample (`torch.FloatTensor`):
257
+ A current instance of a sample created by the diffusion process.
258
+ s_churn (`float`):
259
+ s_tmin (`float`):
260
+ s_tmax (`float`):
261
+ s_noise (`float`, defaults to 1.0):
262
+ Scaling factor for noise added to the sample.
263
+ generator (`torch.Generator`, *optional*):
264
+ A random number generator.
265
+ return_dict (`bool`):
266
+ Whether or not to return a [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or
267
+ tuple.
268
+
269
+ Returns:
270
+ [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or `tuple`:
271
+ If return_dict is `True`, [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] is
272
+ returned, otherwise a tuple is returned where the first element is the sample tensor.
273
+ """
274
+
275
+ if (
276
+ isinstance(timestep, int)
277
+ or isinstance(timestep, torch.IntTensor)
278
+ or isinstance(timestep, torch.LongTensor)
279
+ ):
280
+ raise ValueError(
281
+ (
282
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
283
+ " `FlowMatchEulerDiscreteScheduler.step()` is not supported. Make sure to pass"
284
+ " one of the `scheduler.timesteps` as a timestep."
285
+ ),
286
+ )
287
+
288
+ if self.step_index is None:
289
+ self._init_step_index(timestep)
290
+
291
+ # Upcast to avoid precision issues when computing prev_sample
292
+ sample = sample.to(torch.float32)
293
+
294
+ sigma = self.sigmas[self.step_index]
295
+ sigma_next = self.sigmas[self.step_index + 1]
296
+
297
+ prev_sample = sample + (sigma_next - sigma) * model_output
298
+
299
+ # Cast sample back to model compatible dtype
300
+ prev_sample = prev_sample.to(model_output.dtype)
301
+
302
+ # upon completion increase step index by one
303
+ self._step_index += 1
304
+
305
+ if not return_dict:
306
+ return (prev_sample,)
307
+
308
+ return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
309
+
310
+ def __len__(self):
311
+ return self.config.num_train_timesteps
312
+
313
+ def step_to_x0(self, model_output: torch.FloatTensor, timestep: Union[float, torch.FloatTensor], sample: torch.FloatTensor) -> torch.FloatTensor:
314
+ """
315
+ Compute the predicted x_0 given the model output and current sample at timestep t.
316
+ """
317
+ if self.step_index is None:
318
+ self._init_step_index(timestep)
319
+
320
+ sigma = self.sigmas[self.step_index]
321
+ sigma_from = sigma
322
+ sigma_to = self.sigmas[-1] # This corresponds to x_0
323
+
324
+ x0 = sample + (sigma_to - sigma_from) * model_output
325
+ return x0
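A toy sketch of driving this scheduler end to end; the constant model output and the tensor shapes are placeholders for the transformer's velocity prediction, and `shift=3.0` is just an example value.

    import torch

    scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0)
    scheduler.set_timesteps(num_inference_steps=4, device="cpu")

    sample = torch.randn(1, 4096, 64)               # stand-in for packed image latents
    for t in scheduler.timesteps:
        model_output = torch.zeros_like(sample)     # placeholder for the transformer's prediction
        sample = scheduler.step(model_output, t, sample, return_dict=False)[0]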
flux/transformer_flux.py ADDED
@@ -0,0 +1,572 @@
1
+ # Copyright 2024 Black Forest Labs, The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import Any, Dict, List, Optional, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+
22
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
23
+ from .lora.peft import PeftAdapterMixin
24
+ from diffusers.models.attention import FeedForward
25
+ from .attention_processor import Attention, FluxAttnProcessor2_0, FluxSingleAttnProcessor2_0
26
+ from diffusers.models.modeling_utils import ModelMixin
27
+ from .normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
28
+ from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
29
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
30
+ from .embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings
31
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
32
+ import numpy as np
33
+
34
+
35
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
36
+
37
+ def get_1d_rotary_pos_embed(
38
+ dim: int,
39
+ pos: Union[np.ndarray, int],
40
+ theta: float = 10000.0,
41
+ use_real=False,
42
+ linear_factor=1.0,
43
+ ntk_factor=1.0,
44
+ repeat_interleave_real=True,
45
+ freqs_dtype=torch.float32, # torch.float32, torch.float64 (flux)
46
+ ):
47
+ """
48
+ Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
49
+
50
+ This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end
51
+ index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64
52
+ data type.
53
+
54
+ Args:
55
+ dim (`int`): Dimension of the frequency tensor.
56
+ pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
57
+ theta (`float`, *optional*, defaults to 10000.0):
58
+ Scaling factor for frequency computation. Defaults to 10000.0.
59
+ use_real (`bool`, *optional*):
60
+ If True, return real part and imaginary part separately. Otherwise, return complex numbers.
61
+ linear_factor (`float`, *optional*, defaults to 1.0):
62
+ Scaling factor for the context extrapolation. Defaults to 1.0.
63
+ ntk_factor (`float`, *optional*, defaults to 1.0):
64
+ Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
65
+ repeat_interleave_real (`bool`, *optional*, defaults to `True`):
66
+ If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
67
+ Otherwise, they are concatenated with themselves.
68
+ freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
69
+ the dtype of the frequency tensor.
70
+ Returns:
71
+ `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
72
+ """
73
+ assert dim % 2 == 0
74
+
75
+ if isinstance(pos, int):
76
+ pos = torch.arange(pos)
77
+ if isinstance(pos, np.ndarray):
78
+ pos = torch.from_numpy(pos) # type: ignore # [S]
79
+
80
+ theta = theta * ntk_factor
81
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor # [D/2]
82
+ freqs = freqs.to(pos.device)
83
+ freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
84
+ if use_real and repeat_interleave_real:
85
+ # flux, hunyuan-dit, cogvideox
86
+ freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
87
+ freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
88
+ return freqs_cos, freqs_sin
89
+ elif use_real:
90
+ # stable audio
91
+ freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D]
92
+ freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D]
93
+ return freqs_cos, freqs_sin
94
+ else:
95
+ # lumina
96
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
97
+ return freqs_cis
98
+
99
+ class FluxPosEmbed(nn.Module):
100
+ # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
101
+ def __init__(self, theta: int, axes_dim: List[int]):
102
+ super().__init__()
103
+ self.theta = theta
104
+ self.axes_dim = axes_dim
105
+
106
+ def forward(self, ids: torch.Tensor) -> torch.Tensor:
107
+ n_axes = ids.shape[-1]
108
+ cos_out = []
109
+ sin_out = []
110
+ pos = ids.squeeze().float()
111
+ is_mps = ids.device.type == "mps"
112
+ freqs_dtype = torch.float32 if is_mps else torch.float64
113
+ for i in range(n_axes):
114
+ cos, sin = get_1d_rotary_pos_embed(
115
+ self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True, freqs_dtype=freqs_dtype
116
+ )
117
+ cos_out.append(cos)
118
+ sin_out.append(sin)
119
+ freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
120
+ freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
121
+ return freqs_cos, freqs_sin
122
+
123
+ # YiYi to-do: refactor rope related functions/classes
124
+ def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
125
+ assert dim % 2 == 0, "The dimension must be even."
126
+
127
+ scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
128
+ omega = 1.0 / (theta**scale)
129
+
130
+ batch_size, seq_length = pos.shape
131
+ out = torch.einsum("...n,d->...nd", pos, omega)
132
+ cos_out = torch.cos(out)
133
+ sin_out = torch.sin(out)
134
+
135
+ stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
136
+ out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
137
+ return out.float()
138
+
139
+
140
+ # YiYi to-do: refactor rope related functions/classes
141
+ class EmbedND(nn.Module):
142
+ def __init__(self, dim: int, theta: int, axes_dim: List[int]):
143
+ super().__init__()
144
+ self.dim = dim
145
+ self.theta = theta
146
+ self.axes_dim = axes_dim
147
+
148
+ def forward(self, ids: torch.Tensor) -> torch.Tensor:
149
+ n_axes = ids.shape[-1]
150
+ emb = torch.cat(
151
+ [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
152
+ dim=-3,
153
+ )
154
+ return emb.unsqueeze(1)
155
+
156
+
157
+ @maybe_allow_in_graph
158
+ class FluxSingleTransformerBlock(nn.Module):
159
+ r"""
160
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
161
+
162
+ Reference: https://arxiv.org/abs/2403.03206
163
+
164
+ Parameters:
165
+ dim (`int`): The number of channels in the input and output.
166
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
167
+ attention_head_dim (`int`): The number of channels in each head.
168
+ context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
169
+ processing of `context` conditions.
170
+ """
171
+
172
+ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
173
+ super().__init__()
174
+ self.mlp_hidden_dim = int(dim * mlp_ratio)
175
+
176
+ self.norm = AdaLayerNormZeroSingle(dim)
177
+ self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
178
+ self.act_mlp = nn.GELU(approximate="tanh")
179
+ self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
180
+
181
+ processor = FluxSingleAttnProcessor2_0()
182
+ self.attn = Attention(
183
+ query_dim=dim,
184
+ cross_attention_dim=None,
185
+ dim_head=attention_head_dim,
186
+ heads=num_attention_heads,
187
+ out_dim=dim,
188
+ bias=True,
189
+ processor=processor,
190
+ qk_norm="rms_norm",
191
+ eps=1e-6,
192
+ pre_only=True,
193
+ )
194
+
195
+ def forward(
196
+ self,
197
+ hidden_states: torch.FloatTensor,
198
+ temb: torch.FloatTensor,
199
+ image_rotary_emb=None,
200
+ ):
201
+ residual = hidden_states
202
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
203
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
204
+
205
+ attn_output = self.attn(
206
+ hidden_states=norm_hidden_states,
207
+ image_rotary_emb=image_rotary_emb,
208
+ )
209
+
210
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
211
+ gate = gate.unsqueeze(1)
212
+ hidden_states = gate * self.proj_out(hidden_states)
213
+ hidden_states = residual + hidden_states
214
+
215
+ return hidden_states
216
+
217
+
218
+ @maybe_allow_in_graph
219
+ class FluxTransformerBlock(nn.Module):
220
+ r"""
221
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
222
+
223
+ Reference: https://arxiv.org/abs/2403.03206
224
+
225
+ Parameters:
226
+ dim (`int`): The number of channels in the input and output.
227
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
228
+ attention_head_dim (`int`): The number of channels in each head.
229
+ context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
230
+ processing of `context` conditions.
231
+ """
232
+
233
+ def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6):
234
+ super().__init__()
235
+
236
+ self.norm1 = AdaLayerNormZero(dim)
237
+
238
+ self.norm1_context = AdaLayerNormZero(dim)
239
+
240
+ if hasattr(F, "scaled_dot_product_attention"):
241
+ processor = FluxAttnProcessor2_0()
242
+ else:
243
+ raise ValueError(
244
+ "The current PyTorch version does not support the `scaled_dot_product_attention` function."
245
+ )
246
+ self.attn = Attention(
247
+ query_dim=dim,
248
+ cross_attention_dim=None,
249
+ added_kv_proj_dim=dim,
250
+ dim_head=attention_head_dim,
251
+ heads=num_attention_heads,
252
+ out_dim=dim,
253
+ context_pre_only=False,
254
+ bias=True,
255
+ processor=processor,
256
+ qk_norm=qk_norm,
257
+ eps=eps,
258
+ )
259
+
260
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
261
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
262
+
263
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
264
+ self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
265
+
266
+ # let chunk size default to None
267
+ self._chunk_size = None
268
+ self._chunk_dim = 0
269
+
270
+ def forward(
271
+ self,
272
+ hidden_states: torch.FloatTensor,
273
+ encoder_hidden_states: torch.FloatTensor,
274
+ temb: torch.FloatTensor,
275
+ image_rotary_emb=None,
276
+ ):
277
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
278
+
279
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
280
+ encoder_hidden_states, emb=temb
281
+ )
282
+
283
+ # Attention.
284
+ attn_output, context_attn_output = self.attn(
285
+ hidden_states=norm_hidden_states,
286
+ encoder_hidden_states=norm_encoder_hidden_states,
287
+ image_rotary_emb=image_rotary_emb,
288
+ )
289
+
290
+ # Process attention outputs for the `hidden_states`.
291
+ attn_output = gate_msa.unsqueeze(1) * attn_output
292
+ hidden_states = hidden_states + attn_output
293
+
294
+ norm_hidden_states = self.norm2(hidden_states)
295
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
296
+
297
+ ff_output = self.ff(norm_hidden_states)
298
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
299
+
300
+ hidden_states = hidden_states + ff_output
301
+
302
+ # Process attention outputs for the `encoder_hidden_states`.
303
+
304
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
305
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
306
+
307
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
308
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
309
+
310
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
311
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
312
+
313
+ return encoder_hidden_states, hidden_states
314
+
315
+
316
+ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
317
+ """
318
+ The Transformer model introduced in Flux.
319
+
320
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
321
+
322
+ Parameters:
323
+ patch_size (`int`): Patch size to turn the input data into small patches.
324
+ in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
325
+ num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
326
+ num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
327
+ attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
328
+ num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
329
+ joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
330
+ pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
331
+ guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
332
+ """
333
+
334
+ _supports_gradient_checkpointing = True
335
+
336
+ @register_to_config
337
+ def __init__(
338
+ self,
339
+ patch_size: int = 1,
340
+ in_channels: int = 64,
341
+ num_layers: int = 19,
342
+ num_single_layers: int = 38,
343
+ attention_head_dim: int = 128,
344
+ num_attention_heads: int = 24,
345
+ joint_attention_dim: int = 4096,
346
+ pooled_projection_dim: int = 768,
347
+ guidance_embeds: bool = False,
348
+ axes_dims_rope: List[int] = [16, 56, 56],
349
+ ):
350
+ super().__init__()
351
+ self.out_channels = in_channels
352
+ self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
353
+
354
+ #self.pos_embed = EmbedND(dim=self.inner_dim, theta=10000, axes_dim=axes_dims_rope)
355
+ self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
356
+ text_time_guidance_cls = (
357
+ CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
358
+ )
359
+ self.time_text_embed = text_time_guidance_cls(
360
+ embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
361
+ )
362
+
363
+ self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
364
+ self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
365
+
366
+ self.transformer_blocks = nn.ModuleList(
367
+ [
368
+ FluxTransformerBlock(
369
+ dim=self.inner_dim,
370
+ num_attention_heads=self.config.num_attention_heads,
371
+ attention_head_dim=self.config.attention_head_dim,
372
+ )
373
+ for i in range(self.config.num_layers)
374
+ ]
375
+ )
376
+
377
+ self.single_transformer_blocks = nn.ModuleList(
378
+ [
379
+ FluxSingleTransformerBlock(
380
+ dim=self.inner_dim,
381
+ num_attention_heads=self.config.num_attention_heads,
382
+ attention_head_dim=self.config.attention_head_dim,
383
+ )
384
+ for i in range(self.config.num_single_layers)
385
+ ]
386
+ )
387
+
388
+ self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
389
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
390
+
391
+ self.gradient_checkpointing = True
392
+
393
+ def _set_gradient_checkpointing(self, module, value=False):
394
+ if hasattr(module, "gradient_checkpointing"):
395
+ module.gradient_checkpointing = value
396
+
397
+ def forward(
398
+ self,
399
+ hidden_states: torch.Tensor,
400
+ encoder_hidden_states: torch.Tensor = None,
401
+ t5_encoder_hidden_states: torch.Tensor = None,
402
+ pooled_projections: torch.Tensor = None,
403
+ timestep: torch.LongTensor = None,
404
+ img_ids: torch.Tensor = None,
405
+ txt_ids: torch.Tensor = None,
406
+ guidance: torch.Tensor = None,
407
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
408
+ controlnet_block_samples=None,
409
+ controlnet_single_block_samples=None,
410
+ return_dict: bool = True,
411
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
412
+ """
413
+ The [`FluxTransformer2DModel`] forward method.
414
+
415
+ Args:
416
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
417
+ Input `hidden_states`.
418
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
419
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
420
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
421
+ from the embeddings of input conditions.
422
+ timestep ( `torch.LongTensor`):
423
+ Used to indicate denoising step.
424
+ block_controlnet_hidden_states: (`list` of `torch.Tensor`):
425
+ A list of tensors that if specified are added to the residuals of transformer blocks.
426
+ joint_attention_kwargs (`dict`, *optional*):
427
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
428
+ `self.processor` in
429
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
430
+ return_dict (`bool`, *optional*, defaults to `True`):
431
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
432
+ tuple.
433
+
434
+ Returns:
435
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
436
+ `tuple` where the first element is the sample tensor.
437
+ """
438
+ if joint_attention_kwargs is not None:
439
+ joint_attention_kwargs = joint_attention_kwargs.copy()
440
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
441
+ else:
442
+ lora_scale = 1.0
443
+
444
+ if USE_PEFT_BACKEND:
445
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
446
+ scale_lora_layers(self, lora_scale)
447
+ else:
448
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
449
+ logger.warning(
450
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
451
+ )
452
+ hidden_states = self.x_embedder(hidden_states)
453
+
454
+ timestep = timestep.to(hidden_states.dtype) * 1000
455
+ if guidance is not None:
456
+ guidance = guidance.to(hidden_states.dtype) * 1000
457
+ else:
458
+ guidance = None
459
+ temb = (
460
+ self.time_text_embed(timestep, pooled_projections)
461
+ if guidance is None
462
+ else self.time_text_embed(timestep, guidance, pooled_projections)
463
+ )
464
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
465
+ if t5_encoder_hidden_states is not None:
466
+ encoder_hidden_states = torch.cat([encoder_hidden_states, t5_encoder_hidden_states], dim=1)
467
+
468
+ #ids = torch.cat((txt_ids, img_ids), dim=1)
469
+ if txt_ids.ndim == 3:
470
+ #logger.warning(
471
+ # "Passing `txt_ids` 3d torch.Tensor is deprecated."
472
+ # "Please remove the batch dimension and pass it as a 2d torch Tensor"
473
+ #)
474
+ txt_ids = txt_ids[0]
475
+ if img_ids.ndim == 3:
476
+ #logger.warning(
477
+ # "Passing `img_ids` 3d torch.Tensor is deprecated."
478
+ # "Please remove the batch dimension and pass it as a 2d torch Tensor"
479
+ #)
480
+ img_ids = img_ids[0]
481
+ ids = torch.cat((txt_ids, img_ids), dim=0)
482
+ image_rotary_emb = self.pos_embed(ids)
483
+
484
+ for index_block, block in enumerate(self.transformer_blocks):
485
+
486
+ if self.training and self.gradient_checkpointing:
487
+
488
+ def create_custom_forward(module, return_dict=None):
489
+ def custom_forward(*inputs):
490
+ if return_dict is not None:
491
+ return module(*inputs, return_dict=return_dict)
492
+ else:
493
+ return module(*inputs)
494
+
495
+ return custom_forward
496
+
497
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
498
+ encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
499
+ create_custom_forward(block),
500
+ hidden_states,
501
+ encoder_hidden_states,
502
+ temb,
503
+ image_rotary_emb,
504
+ **ckpt_kwargs,
505
+ )
506
+
507
+ else:
508
+ encoder_hidden_states, hidden_states = block(
509
+ hidden_states=hidden_states,
510
+ encoder_hidden_states=encoder_hidden_states,
511
+ temb=temb,
512
+ image_rotary_emb=image_rotary_emb,
513
+ )
514
+
515
+ # controlnet residual
516
+ if controlnet_block_samples is not None:
517
+ interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
518
+ interval_control = int(np.ceil(interval_control))
519
+ hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
520
+
521
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
522
+
523
+ for index_block, block in enumerate(self.single_transformer_blocks):
524
+ if self.training and self.gradient_checkpointing:
525
+
526
+ def create_custom_forward(module, return_dict=None):
527
+ def custom_forward(*inputs):
528
+ if return_dict is not None:
529
+ return module(*inputs, return_dict=return_dict)
530
+ else:
531
+ return module(*inputs)
532
+
533
+ return custom_forward
534
+
535
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
536
+ hidden_states = torch.utils.checkpoint.checkpoint(
537
+ create_custom_forward(block),
538
+ hidden_states,
539
+ temb,
540
+ image_rotary_emb,
541
+ **ckpt_kwargs,
542
+ )
543
+
544
+ else:
545
+ hidden_states = block(
546
+ hidden_states=hidden_states,
547
+ temb=temb,
548
+ image_rotary_emb=image_rotary_emb,
549
+ )
550
+
551
+ # controlnet residual
552
+ if controlnet_single_block_samples is not None:
553
+ interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
554
+ interval_control = int(np.ceil(interval_control))
555
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
556
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...]
557
+ + controlnet_single_block_samples[index_block // interval_control]
558
+ )
559
+
560
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
561
+
562
+ hidden_states = self.norm_out(hidden_states, temb)
563
+ output = self.proj_out(hidden_states)
564
+
565
+ if USE_PEFT_BACKEND:
566
+ # remove `lora_scale` from each PEFT layer
567
+ unscale_lora_layers(self, lora_scale)
568
+
569
+ if not return_dict:
570
+ return (output,)
571
+
572
+ return Transformer2DModelOutput(sample=output)
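A shape sketch of the forward pass above using a deliberately tiny configuration. All dimensions here are placeholders (the shipped model uses in_channels=64, 19 dual-stream and 38 single-stream blocks, and attention_head_dim=128), and it assumes this repo's custom attention/normalization/embedding modules behave like their diffusers counterparts.

    import torch

    model = FluxTransformer2DModel(
        in_channels=16, num_layers=1, num_single_layers=1,
        attention_head_dim=32, num_attention_heads=2,
        joint_attention_dim=64, pooled_projection_dim=32,
        axes_dims_rope=[4, 14, 14],   # must sum to attention_head_dim
    )
    model.eval()  # avoid the gradient-checkpointing branch in this sketch

    batch, img_seq, txt_seq = 1, 256, 8   # 256 packed latent tokens ~ a 32x32 latent grid packed 2x2
    with torch.no_grad():
        out = model(
            hidden_states=torch.randn(batch, img_seq, 16),
            encoder_hidden_states=torch.randn(batch, txt_seq, 64),
            pooled_projections=torch.randn(batch, 32),
            timestep=torch.tensor([1.0]),
            img_ids=torch.zeros(img_seq, 3),
            txt_ids=torch.zeros(txt_seq, 3),
        )
    print(out.sample.shape)   # (1, 256, 16): one prediction per packed latent token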
main.py ADDED
@@ -0,0 +1,154 @@
1
+ import argparse
2
+ import torch
3
+ from PIL import Image
4
+ import numpy as np
5
+ from model import FluxModel
6
+
7
+ def parse_args():
8
+ parser = argparse.ArgumentParser(description='Flux Image Generation Tool')
9
+
10
+ # Required arguments
11
+ parser.add_argument('--mode', type=str, required=True,
12
+ choices=['variation', 'img2img', 'inpaint', 'controlnet', 'controlnet-inpaint'],
13
+ help='Generation mode')
14
+ parser.add_argument('--input_image', type=str, required=True,
15
+ help='Path to the input image')
16
+
17
+ # Optional arguments
18
+ parser.add_argument('--prompt', type=str, default="",
19
+ help='Text prompt to guide the generation')
20
+ parser.add_argument('--reference_image', type=str, default=None,
21
+ help='Path to the reference image (for img2img/controlnet modes)')
22
+ parser.add_argument('--mask_image', type=str, default=None,
23
+ help='Path to the mask image (for inpainting modes)')
24
+ parser.add_argument('--output_dir', type=str, default='outputs',
25
+ help='Directory to save generated images')
26
+ parser.add_argument('--image_count', type=int, default=1,
27
+ help='Number of images to generate')
28
+ parser.add_argument('--aspect_ratio', type=str, default='1:1',
29
+ choices=['1:1', '16:9', '9:16', '2.4:1', '3:4', '4:3'],
30
+ help='Output image aspect ratio')
31
+ parser.add_argument('--steps', type=int, default=28,
32
+ help='Number of inference steps')
33
+ parser.add_argument('--guidance_scale', type=float, default=7.5,
34
+ help='Guidance scale for generation')
35
+ parser.add_argument('--denoise_strength', type=float, default=0.8,
36
+ help='Denoising strength for img2img/inpaint')
37
+
38
+ # Attention related arguments
39
+ parser.add_argument('--center_x', type=float, default=None,
40
+ help='X coordinate of attention center (0-1)')
41
+ parser.add_argument('--center_y', type=float, default=None,
42
+ help='Y coordinate of attention center (0-1)')
43
+ parser.add_argument('--radius', type=float, default=None,
44
+ help='Radius of attention circle (0-1)')
45
+
46
+ # ControlNet related arguments
47
+ parser.add_argument('--line_mode', action='store_true',
48
+ help='Enable line detection mode for ControlNet')
49
+ parser.add_argument('--depth_mode', action='store_true',
50
+ help='Enable depth mode for ControlNet')
51
+ parser.add_argument('--line_strength', type=float, default=0.4,
52
+ help='Strength of line guidance')
53
+ parser.add_argument('--depth_strength', type=float, default=0.2,
54
+ help='Strength of depth guidance')
55
+
56
+ # Device selection
57
+ parser.add_argument('--device', type=str, default='cuda',
58
+ choices=['cuda', 'cpu'],
59
+ help='Device to run the model on')
60
+ parser.add_argument('--turbo', action='store_true',
61
+ help='Enable turbo mode for faster inference')
62
+
63
+ return parser.parse_args()
64
+
65
+ def load_image(image_path):
66
+ """Load and return a PIL Image."""
67
+ try:
68
+ return Image.open(image_path).convert('RGB')
69
+ except Exception as e:
70
+ raise ValueError(f"Error loading image {image_path}: {str(e)}")
71
+
72
+ def save_images(images, output_dir, prefix="generated"):
73
+ """Save generated images with sequential numbering."""
74
+ import os
75
+ os.makedirs(output_dir, exist_ok=True)
76
+
77
+ for i, image in enumerate(images):
78
+ output_path = os.path.join(output_dir, f"{prefix}_{i+1}.png")
79
+ image.save(output_path)
80
+ print(f"Saved image to {output_path}")
81
+
82
+ def get_required_features(args):
83
+ """Determine which model features are required based on the arguments."""
84
+ features = []
85
+
86
+ if args.mode in ['controlnet', 'controlnet-inpaint']:
87
+ features.append('controlnet')
88
+ if args.depth_mode:
89
+ features.append('depth')
90
+ if args.line_mode:
91
+ features.append('line')
92
+
93
+ if args.mode in ['inpaint', 'controlnet-inpaint']:
94
+ features.append('sam') # If you're using SAM for mask generation
95
+
96
+ return features
97
+
98
+
99
+ def main():
100
+ args = parse_args()
101
+
102
+ # Check CUDA availability if requested
103
+ if args.device == 'cuda' and not torch.cuda.is_available():
104
+ print("CUDA requested but not available. Falling back to CPU.")
105
+ args.device = 'cpu'
106
+
107
+ # Determine required features based on mode and arguments
108
+ required_features = get_required_features(args)
109
+
110
+ # Initialize model with only required features
111
+ print(f"Initializing model on {args.device} with features: {required_features}")
112
+ model = FluxModel(
113
+ is_turbo=args.turbo,
114
+ device=args.device,
115
+ required_features=required_features
116
+ )
117
+
118
+ # Load input images
119
+ input_image = load_image(args.input_image)
120
+ reference_image = load_image(args.reference_image) if args.reference_image else None
121
+ mask_image = load_image(args.mask_image) if args.mask_image else None
122
+
123
+ # Validate inputs based on mode
124
+ if args.mode in ['inpaint', 'controlnet-inpaint'] and mask_image is None:
125
+ raise ValueError(f"{args.mode} mode requires a mask image")
126
+
127
+ # Generate images
128
+ print(f"Generating {args.image_count} images in {args.mode} mode...")
129
+ generated_images = model.generate(
130
+ input_image_a=input_image,
131
+ input_image_b=reference_image,
132
+ prompt=args.prompt,
133
+ mask_image=mask_image,
134
+ mode=args.mode,
135
+ imageCount=args.image_count,
136
+ aspect_ratio=args.aspect_ratio,
137
+ num_inference_steps=args.steps,
138
+ guidance_scale=args.guidance_scale,
139
+ denoise_strength=args.denoise_strength,
140
+ center_x=args.center_x,
141
+ center_y=args.center_y,
142
+ radius=args.radius,
143
+ line_mode=args.line_mode,
144
+ depth_mode=args.depth_mode,
145
+ line_strength=args.line_strength,
146
+ depth_strength=args.depth_strength
147
+ )
148
+
149
+ # Save generated images
150
+ save_images(generated_images, args.output_dir)
151
+ print("Generation completed successfully!")
152
+
153
+ if __name__ == "__main__":
154
+ main()
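Example invocations of this CLI (paths and prompts are placeholders; every flag maps directly onto an option defined in `parse_args` above):

    # Image variation guided by a text prompt
    python main.py --mode variation --input_image photo.jpg \
        --prompt "a watercolor painting" --aspect_ratio 3:4 --steps 28 --guidance_scale 7.5

    # ControlNet inpainting with line and depth conditioning
    python main.py --mode controlnet-inpaint --input_image room.jpg --mask_image room_mask.png \
        --line_mode --depth_mode --line_strength 0.4 --depth_strength 0.2 --turbo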
model.py ADDED
@@ -0,0 +1,644 @@
1
+ import torch
2
+ from torch import nn
3
+ from PIL import Image
4
+ from transformers import CLIPTokenizer, CLIPTextModel, AutoProcessor, T5EncoderModel, T5TokenizerFast
5
+ from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
6
+ from flux.transformer_flux import FluxTransformer2DModel
7
+
8
+ from flux.pipeline_flux_chameleon import FluxPipeline
9
+ from flux.pipeline_flux_img2img import FluxImg2ImgPipeline
10
+ from flux.pipeline_flux_inpaint import FluxInpaintPipeline
11
+ from flux.pipeline_flux_controlnet import FluxControlNetPipeline, FluxControlNetModel
12
+ from flux.pipeline_flux_controlnet_img2img import FluxControlNetImg2ImgPipeline
13
+ from flux.controlnet_flux import FluxMultiControlNetModel
14
+ from flux.pipeline_flux_controlnet_inpainting import FluxControlNetInpaintPipeline
15
+
16
+ from qwen2_vl.modeling_qwen2_vl import Qwen2VLSimplifiedModel
17
+ import os
18
+ import cv2
19
+ import numpy as np
20
+ import math
21
+
22
+ def get_model_path(model_name):
23
+ """Get the full path for a model based on the checkpoints directory."""
24
+ base_dir = os.getenv('CHECKPOINT_DIR', 'checkpoints') # Allow environment variable override
25
+ return os.path.join(base_dir, model_name)
26
+
27
+ # Model paths configuration
28
+ MODEL_PATHS = {
29
+ 'flux': get_model_path('flux'),
30
+ 'qwen2vl': get_model_path('qwen2-vl'),
31
+ 'controlnet': get_model_path('controlnet'),
32
+ 'depth_anything': {
33
+ 'path': get_model_path('depth-anything-v2'),
34
+ 'weights': 'depth_anything_v2_vitl.pth'
35
+ },
36
+ 'anyline': {
37
+ 'path': get_model_path('anyline'),
38
+ 'weights': 'MTEED.pth'
39
+ },
40
+ 'sam2': {
41
+ 'path': get_model_path('segment-anything-2'),
42
+ 'weights': 'sam2_hiera_large.pt',
43
+ 'config': 'sam2_hiera_l.yaml'
44
+ }
45
+ }
46
+
47
+
48
+ ASPECT_RATIOS = {
49
+ "1:1": (1024, 1024),
50
+ "16:9": (1344, 768),
51
+ "9:16": (768, 1344),
52
+ "2.4:1": (1536, 640),
53
+ "3:4": (896, 1152),
54
+ "4:3": (1152, 896),
55
+ }
56
+
57
+ class Qwen2Connector(nn.Module):
58
+ def __init__(self, input_dim=3584, output_dim=4096):
59
+ super().__init__()
60
+ self.linear = nn.Linear(input_dim, output_dim)
61
+
62
+ def forward(self, x):
63
+ return self.linear(x)
64
+
65
+ class FluxModel:
66
+ def __init__(self, is_turbo=False, device="cuda", required_features=None):
67
+ """
68
+ Initialize FluxModel with specified features
69
+ Args:
70
+ is_turbo: Enable turbo mode for faster inference
71
+ device: Device to run the model on
72
+ required_features: List of required features ['controlnet', 'depth', 'line', 'sam']
73
+ """
74
+ self.device = torch.device(device)
75
+ self.dtype = torch.bfloat16
76
+ if required_features is None:
77
+ required_features = []
78
+
79
+ self._line_detector_imported = False
80
+ self._depth_model_imported = False
81
+ self._sam_imported = False
82
+ self._turbo_imported = False
83
+
84
+ # Initialize base models (always required)
85
+ self._init_base_models()
86
+
87
+ # Initialize optional models based on requirements
88
+ if 'controlnet' in required_features or any(f in required_features for f in ['depth', 'line']):
89
+ self._init_controlnet()
90
+
91
+ if 'depth' in required_features:
92
+ self._init_depth_model()
93
+
94
+ if 'line' in required_features:
95
+ self._init_line_detector()
96
+
97
+ if 'sam' in required_features:
98
+ self._init_sam()
99
+
100
+ if is_turbo:
101
+ self._enable_turbo()
102
+
103
+ def _init_base_models(self):
104
+ """Initialize the core models that are always needed"""
105
+ # Qwen2VL and connector initialization
106
+ self.qwen2vl = Qwen2VLSimplifiedModel.from_pretrained(
107
+ MODEL_PATHS['qwen2vl'],
108
+ torch_dtype=self.dtype
109
+ )
110
+ self.qwen2vl.requires_grad_(False).to(self.device)
111
+
112
+ self.connector = Qwen2Connector(input_dim=3584, output_dim=4096)
113
+ connector_path = os.path.join(MODEL_PATHS['qwen2vl'], "connector.pt")
114
+ if os.path.exists(connector_path):
115
+ connector_state_dict = torch.load(connector_path, map_location=self.device, weights_only=True)
116
+ connector_state_dict = {k.replace('module.', ''): v for k, v in connector_state_dict.items()}
117
+ self.connector.load_state_dict(connector_state_dict)
118
+ self.connector.to(self.dtype).to(self.device)
119
+
120
+ # Text encoders initialization
121
+ self.tokenizer = CLIPTokenizer.from_pretrained(MODEL_PATHS['flux'], subfolder="tokenizer")
122
+ self.text_encoder = CLIPTextModel.from_pretrained(MODEL_PATHS['flux'], subfolder="text_encoder")
123
+ self.text_encoder_two = T5EncoderModel.from_pretrained(MODEL_PATHS['flux'], subfolder="text_encoder_2")
124
+ self.tokenizer_two = T5TokenizerFast.from_pretrained(MODEL_PATHS['flux'], subfolder="tokenizer_2")
125
+
126
+ self.text_encoder.requires_grad_(False).to(self.dtype).to(self.device)
127
+ self.text_encoder_two.requires_grad_(False).to(self.dtype).to(self.device)
128
+
129
+ # T5 context embedder
130
+ self.t5_context_embedder = nn.Linear(4096, 3072)
131
+ t5_embedder_path = os.path.join(MODEL_PATHS['qwen2vl'], "t5_embedder.pt")
132
+ t5_embedder_state_dict = torch.load(t5_embedder_path, map_location=self.device, weights_only=True)
133
+ self.t5_context_embedder.load_state_dict(t5_embedder_state_dict)
134
+ self.t5_context_embedder.to(self.dtype).to(self.device)
135
+
136
+ # Basic components
137
+ self.noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(MODEL_PATHS['flux'], subfolder="scheduler", shift=1)
138
+ self.vae = AutoencoderKL.from_pretrained(MODEL_PATHS['flux'], subfolder="vae")
139
+ self.transformer = FluxTransformer2DModel.from_pretrained(MODEL_PATHS['flux'], subfolder="transformer")
140
+
141
+ self.vae.requires_grad_(False).to(self.dtype).to(self.device)
142
+ self.transformer.requires_grad_(False).to(self.dtype).to(self.device)
143
+
144
+ def _init_controlnet(self):
145
+ """Initialize ControlNet model"""
146
+ self.controlnet_union = FluxControlNetModel.from_pretrained(
147
+ MODEL_PATHS['controlnet'],
148
+ torch_dtype=torch.bfloat16
149
+ )
150
+ self.controlnet_union.requires_grad_(False).to(self.device)
151
+ self.controlnet = FluxMultiControlNetModel([self.controlnet_union])
152
+
153
+ def _init_depth_model(self):
154
+ """Initialize Depth Anything V2 model"""
155
+ if not self._depth_model_imported:
156
+ from depth_anything_v2.dpt import DepthAnythingV2
157
+ self._depth_model_imported = True
158
+
159
+ self.depth_model = DepthAnythingV2(
160
+ encoder='vitl',
161
+ features=256,
162
+ out_channels=[256, 512, 1024, 1024]
163
+ )
164
+ depth_weights = os.path.join(MODEL_PATHS['depth_anything']['path'],
165
+ MODEL_PATHS['depth_anything']['weights'])
166
+ self.depth_model.load_state_dict(torch.load(depth_weights, map_location=self.device))
167
+ self.depth_model.requires_grad_(False).to(self.device)
168
+
169
+ def _init_line_detector(self):
170
+ """Initialize line detection model"""
171
+ if not self._line_detector_imported:
172
+ from controlnet_aux import AnylineDetector
173
+ self._line_detector_imported = True
174
+
175
+ self.anyline = AnylineDetector.from_pretrained(
176
+ MODEL_PATHS['anyline']['path'],
177
+ filename=MODEL_PATHS['anyline']['weights']
178
+ )
179
+ self.anyline.to(self.device)
180
+
181
+ def _init_sam(self):
182
+ """Initialize SAM2 model"""
183
+ if not self._sam_imported:
184
+ from sam2.build_sam import build_sam2
185
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
186
+ self._sam_imported = True
187
+
188
+ sam2_checkpoint = os.path.join(MODEL_PATHS['sam2']['path'],
189
+ MODEL_PATHS['sam2']['weights'])
190
+ model_cfg = os.path.join(MODEL_PATHS['sam2']['path'],
191
+ MODEL_PATHS['sam2']['config'])
192
+ self.sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=self.device)
193
+ self.sam2_predictor = SAM2ImagePredictor(self.sam2_model)
194
+
195
+ def _enable_turbo(self):
196
+ """Enable turbo mode for faster inference"""
197
+ if not self._turbo_imported:
198
+ from optimum.quanto import freeze, qfloat8, quantize
199
+ self._turbo_imported = True
200
+
201
+ quantize(
202
+ self.transformer,
203
+ weights=qfloat8,
204
+ exclude=[
205
+ "*.norm", "*.norm1", "*.norm2", "*.norm2_context",
206
+ "proj_out", "x_embedder", "norm_out", "context_embedder",
207
+ ],
208
+ )
209
+ freeze(self.transformer)
210
+
211
+ def generate_mask(self, image, input_points, input_labels):
212
+ """
213
+ Generate a segmentation mask with SAM2
214
+
215
+ Args:
216
+ image: PIL Image or numpy array
217
+ input_points: numpy array of shape (N, 2) containing point coordinates
218
+ input_labels: numpy array of shape (N,); 1 marks a foreground point, 0 a background point
219
+
220
+ Returns:
221
+ PIL Image: the highest-scoring mask
222
+ """
223
+ try:
224
+ # Make sure the image is a numpy array
225
+ if isinstance(image, Image.Image):
226
+ image_array = np.array(image)
227
+ else:
228
+ image_array = image
229
+
230
+ # Set the image for the predictor
231
+ self.sam2_predictor.set_image(image_array)
232
+
233
+ # Run prediction
234
+ with torch.inference_mode():
235
+ masks, scores, logits = self.sam2_predictor.predict(
236
+ point_coords=input_points,
237
+ point_labels=input_labels,
238
+ multimask_output=True,
239
+ )
240
+
241
+ # Return the highest-scoring mask
242
+ best_mask_idx = scores.argmax()
243
+ mask = masks[best_mask_idx]
244
+ mask_image = Image.fromarray((mask * 255).astype(np.uint8))
245
+ return mask_image
246
+
247
+ except Exception as e:
248
+ print(f"Mask generation failed: {str(e)}")
249
+ raise
250
+
251
+ def recover_2d_shape(self, image_hidden_state, grid_thw):
252
+ batch_size, num_tokens, hidden_dim = image_hidden_state.shape
253
+ _, h, w = grid_thw
254
+ h_out = h // 2
255
+ w_out = w // 2
256
+ # Reshape to (batch_size, height, width, hidden_dim)
257
+ reshaped = image_hidden_state.view(batch_size, h_out, w_out, hidden_dim)
258
+ return reshaped
259
+
260
+ def generate_attention_matrix(self, center_x, center_y, radius, image_shape):
261
+ height, width = image_shape
262
+ y, x = np.ogrid[:height, :width]
263
+ center_y, center_x = center_y * height, center_x * width
264
+ distances = np.sqrt((x - center_x)**2 + (y - center_y)**2)
265
+ attention = np.clip(1 - distances / (radius * min(height, width)), 0, 1)
266
+ return attention
267
+
268
+ def apply_attention(self, image_hidden_state, image_grid_thw, center_x, center_y, radius):
269
+ qwen2_2d_image_embedding = self.recover_2d_shape(image_hidden_state, tuple(image_grid_thw.tolist()[0]))
270
+ attention_matrix = self.generate_attention_matrix(
271
+ center_x, center_y, radius,
272
+ (qwen2_2d_image_embedding.size(1), qwen2_2d_image_embedding.size(2))
273
+ )
274
+ attention_tensor = torch.from_numpy(attention_matrix).to(self.dtype).unsqueeze(0).unsqueeze(-1)
275
+ qwen2_2d_image_embedding = qwen2_2d_image_embedding * attention_tensor.to(self.device)
276
+ return qwen2_2d_image_embedding.view(1, -1, qwen2_2d_image_embedding.size(3))
277
+
278
+ def compute_text_embeddings(self, prompt):
279
+ with torch.no_grad():
280
+ text_inputs = self.tokenizer(prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt")
281
+ text_input_ids = text_inputs.input_ids.to(self.device)
282
+ prompt_embeds = self.text_encoder(text_input_ids, output_hidden_states=False)
283
+ pooled_prompt_embeds = prompt_embeds.pooler_output
284
+ return pooled_prompt_embeds.to(self.dtype)
285
+
286
+ def compute_t5_text_embeddings(
287
+ self,
288
+ max_sequence_length=256,
289
+ prompt=None,
290
+ num_images_per_prompt=1,
291
+ device=None,
292
+ ):
293
+ prompt = [prompt] if isinstance(prompt, str) else prompt
294
+ batch_size = len(prompt)
295
+
296
+ text_inputs = self.tokenizer_two(
297
+ prompt,
298
+ padding="max_length",
299
+ max_length=max_sequence_length,
300
+ truncation=True,
301
+ return_length=False,
302
+ return_overflowing_tokens=False,
303
+ return_tensors="pt",
304
+ )
305
+ text_input_ids = text_inputs.input_ids
306
+ prompt_embeds = self.text_encoder_two(text_input_ids.to(device))[0]
307
+
308
+ dtype = self.text_encoder_two.dtype
309
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
310
+
311
+ _, seq_len, _ = prompt_embeds.shape
312
+
313
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
314
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
315
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
316
+
317
+ return prompt_embeds
318
+
319
+ def process_image(self, image):
320
+ message = [
321
+ {
322
+ "role": "user",
323
+ "content": [
324
+ {"type": "image", "image": image},
325
+ {"type": "text", "text": "Describe this image."},
326
+ ]
327
+ }
328
+ ]
329
+ text = self.qwen2vl_processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
330
+
331
+ with torch.no_grad():
332
+ inputs = self.qwen2vl_processor(text=[text], images=[image], padding=True, return_tensors="pt").to(self.device)
333
+ output_hidden_state, image_token_mask, image_grid_thw = self.qwen2vl(**inputs)
334
+ image_hidden_state = output_hidden_state[image_token_mask].view(1, -1, output_hidden_state.size(-1))
335
+
336
+ return image_hidden_state, image_grid_thw
337
+
338
+ def resize_image(self, img, max_pixels=1050000):
339
+ # Make sure the input is a PIL Image
340
+ if not isinstance(img, Image.Image):
341
+ img = Image.fromarray(img)
342
+
343
+ width, height = img.size
344
+ num_pixels = width * height
345
+
346
+ if num_pixels > max_pixels:
347
+ scale = math.sqrt(max_pixels / num_pixels)
348
+ new_width = int(width * scale)
349
+ new_height = int(height * scale)
350
+ # Round width and height down so they are divisible by 8
351
+ new_width = new_width - (new_width % 8)
352
+ new_height = new_height - (new_height % 8)
353
+ img = img.resize((new_width, new_height), Image.LANCZOS)
354
+ else:
355
+ # Even if the image needs no downscaling, its dimensions must still be divisible by 8
356
+ new_width = width - (width % 8)
357
+ new_height = height - (height % 8)
358
+ if new_width != width or new_height != height:
359
+ img = img.resize((new_width, new_height), Image.LANCZOS)
360
+
361
+ return img
362
+
363
+ def generate_depth_map(self, image):
364
+ """Generate depth map using Depth Anything V2"""
365
+ # Convert PIL to numpy array
366
+ image_np = np.array(image)
367
+
368
+ # Convert RGB to BGR for cv2
369
+ image_bgr = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
370
+
371
+ # Generate depth map
372
+ with torch.no_grad():
373
+ depth = self.depth_model.infer_image(image_bgr)
374
+
375
+ # Normalize depth to 0-1 range
376
+ depth_norm = (depth - depth.min()) / (depth.max() - depth.min())
377
+
378
+ # Convert to RGB image
379
+ depth_rgb = (depth_norm * 255).astype(np.uint8)
380
+ depth_rgb = cv2.cvtColor(depth_rgb, cv2.COLOR_GRAY2RGB)
381
+
382
+ return Image.fromarray(depth_rgb)
383
+
384
+
385
+ def generate(self, input_image_a, input_image_b=None, prompt="", guidance_scale=3.5, num_inference_steps=28,
386
+ aspect_ratio="1:1", center_x=None, center_y=None, radius=None, mode="variation",
387
+ denoise_strength=0.8, mask_image=None, imageCount=2,
388
+ line_mode=True, depth_mode=True, line_strength=0.4, depth_strength=0.2):
389
+
390
+ batch_size = imageCount
391
+ if aspect_ratio not in ASPECT_RATIOS:
392
+ raise ValueError(f"Invalid aspect ratio. Choose from {list(ASPECT_RATIOS.keys())}")
393
+
394
+ width, height = ASPECT_RATIOS[aspect_ratio]
395
+
396
+ pooled_prompt_embeds = self.compute_text_embeddings(prompt="")
397
+ t5_prompt_embeds = None
398
+ if prompt != "":
399
+ self.qwen2vl_processor = AutoProcessor.from_pretrained(MODEL_PATHS['qwen2vl'], min_pixels=256*28*28, max_pixels=256*28*28)
400
+ t5_prompt_embeds = self.compute_t5_text_embeddings(prompt=prompt, device=self.device)
401
+ t5_prompt_embeds = self.t5_context_embedder(t5_prompt_embeds)
402
+ else:
403
+ self.qwen2vl_processor = AutoProcessor.from_pretrained(MODEL_PATHS['qwen2vl'], min_pixels=512*28*28, max_pixels=512*28*28)
404
+
405
+ qwen2_hidden_state_a, image_grid_thw_a = self.process_image(input_image_a)
406
+ # Apply the attention mechanism only when all attention parameters are provided
407
+ if mode == "variation":
408
+ if center_x is not None and center_y is not None and radius is not None:
409
+ qwen2_hidden_state_a = self.apply_attention(qwen2_hidden_state_a, image_grid_thw_a, center_x, center_y, radius)
410
+ qwen2_hidden_state_a = self.connector(qwen2_hidden_state_a)
411
+
412
+ if mode == "img2img" or mode == "inpaint":
413
+ if input_image_b:
414
+ qwen2_hidden_state_b, image_grid_thw_b = self.process_image(input_image_b)
415
+ if center_x is not None and center_y is not None and radius is not None:
416
+ qwen2_hidden_state_b = self.apply_attention(qwen2_hidden_state_b, image_grid_thw_b, center_x, center_y, radius)
417
+ qwen2_hidden_state_b = self.connector(qwen2_hidden_state_b)
418
+ else:
419
+ qwen2_hidden_state_a = self.connector(qwen2_hidden_state_a)
420
+ qwen2_hidden_state_b = None
421
+
422
+ if mode == "controlnet" or mode == "controlnet-inpaint":
423
+ qwen2_hidden_state_b = None
424
+ if input_image_b:
425
+ qwen2_hidden_state_b, image_grid_thw_b = self.process_image(input_image_b)
426
+ if center_x is not None and center_y is not None and radius is not None:
427
+ qwen2_hidden_state_b = self.apply_attention(qwen2_hidden_state_b, image_grid_thw_b, center_x, center_y, radius)
428
+ qwen2_hidden_state_b = self.connector(qwen2_hidden_state_b)
429
+ qwen2_hidden_state_a = self.connector(qwen2_hidden_state_a)
430
+
431
+ #############################
432
+ # IMAGE GENERATION
433
+ #############################
434
+ if mode == "variation":
435
+ # Initialize different pipelines
436
+ pipeline = FluxPipeline(
437
+ transformer=self.transformer,
438
+ scheduler=self.noise_scheduler,
439
+ vae=self.vae,
440
+ text_encoder=self.text_encoder,
441
+ tokenizer=self.tokenizer,
442
+ )
443
+
444
+ gen_images = pipeline(
445
+ prompt_embeds=qwen2_hidden_state_a.repeat(batch_size, 1, 1),
446
+ t5_prompt_embeds=t5_prompt_embeds.repeat(batch_size, 1, 1) if t5_prompt_embeds is not None else None,
447
+ pooled_prompt_embeds=pooled_prompt_embeds,
448
+ num_inference_steps=num_inference_steps,
449
+ guidance_scale=guidance_scale,
450
+ height=height,
451
+ width=width,
452
+ ).images
453
+
454
+
455
+ #############################
456
+ # IMAGE-TO-IMAGE
457
+ #############################
458
+ elif mode == "img2img":
459
+ input_image_a = self.resize_image(input_image_a)
460
+ width, height = input_image_a.size
461
+
462
+ img2img_pipeline = FluxImg2ImgPipeline(
463
+ transformer=self.transformer,
464
+ scheduler=self.noise_scheduler,
465
+ vae=self.vae,
466
+ text_encoder=self.text_encoder,
467
+ tokenizer=self.tokenizer,
468
+ )
469
+
470
+ gen_images = img2img_pipeline(
471
+ image=input_image_a,
472
+ strength=denoise_strength,
473
+ prompt_embeds=qwen2_hidden_state_b.repeat(batch_size, 1, 1) if qwen2_hidden_state_b is not None else qwen2_hidden_state_a.repeat(batch_size, 1, 1),
474
+ t5_prompt_embeds=t5_prompt_embeds.repeat(batch_size, 1, 1) if t5_prompt_embeds is not None else None,
475
+ pooled_prompt_embeds=pooled_prompt_embeds,
476
+ num_inference_steps=num_inference_steps,
477
+ guidance_scale=guidance_scale,
478
+ height=height,
479
+ width=width,
480
+ ).images
481
+
482
+
483
+ #############################
484
+ # INPAINTING
485
+ #############################
486
+ elif mode == "inpaint":
487
+ if mask_image is None:
488
+ raise ValueError("Mask image is required for inpainting mode")
489
+
490
+ input_image_a = self.resize_image(input_image_a)
491
+ mask_image = self.resize_image(mask_image)
492
+ width, height = input_image_a.size
493
+
494
+ inpaint_pipeline = FluxInpaintPipeline(
495
+ transformer=self.transformer,
496
+ scheduler=self.noise_scheduler,
497
+ vae=self.vae,
498
+ text_encoder=self.text_encoder,
499
+ tokenizer=self.tokenizer,
500
+ )
501
+
502
+ gen_images = inpaint_pipeline(
503
+ image=input_image_a,
504
+ mask_image=mask_image,
505
+ strength=denoise_strength,
506
+ prompt_embeds=qwen2_hidden_state_b.repeat(batch_size, 1, 1) if qwen2_hidden_state_b is not None else qwen2_hidden_state_a.repeat(batch_size, 1, 1),
507
+ t5_prompt_embeds=t5_prompt_embeds.repeat(batch_size, 1, 1) if t5_prompt_embeds is not None else None,
508
+ pooled_prompt_embeds=pooled_prompt_embeds,
509
+ num_inference_steps=num_inference_steps,
510
+ guidance_scale=guidance_scale,
511
+ height=height,
512
+ width=width,
513
+ ).images
514
+
515
+ #############################
516
+ # CONTROLNET
517
+ #############################
518
+ elif mode == "controlnet":
519
+ input_image_a = self.resize_image(input_image_a)
520
+ width, height = input_image_a.size
521
+
522
+ controlnet_pipeline = FluxControlNetImg2ImgPipeline(
523
+ transformer=self.transformer,
524
+ scheduler=self.noise_scheduler,
525
+ vae=self.vae,
526
+ text_encoder=self.text_encoder,
527
+ tokenizer=self.tokenizer,
528
+ controlnet=self.controlnet,
529
+ )
530
+
531
+ # Prepare the lists of control images and control modes
532
+ control_images = []
533
+ control_modes = []
534
+ conditioning_scales = []
535
+
536
+ # Add control modes according to the user's selection
537
+ if depth_mode:
538
+ control_image_depth = self.generate_depth_map(input_image_a)
539
+ control_images.append(control_image_depth)
540
+ control_modes.append(2) # depth mode
541
+ conditioning_scales.append(depth_strength)
542
+
543
+ if line_mode:
544
+ control_image_canny = self.anyline(input_image_a, detect_resolution=1280)
545
+ control_images.append(control_image_canny)
546
+ control_modes.append(0) # line mode
547
+ conditioning_scales.append(line_strength)
548
+
549
+ # If neither mode is enabled, fall back to line+depth by default
550
+ if not line_mode and not depth_mode:
551
+ control_image_depth = self.generate_depth_map(input_image_a)
552
+ control_image_canny = self.anyline(input_image_a, detect_resolution=1280)
553
+ control_images = [control_image_depth, control_image_canny]
554
+ control_modes = [2, 0]
555
+ conditioning_scales = [0.2, 0.4]
556
+
557
+ if qwen2_hidden_state_b is not None:
558
+ qwen2_hidden_state_b = qwen2_hidden_state_b[:, :qwen2_hidden_state_a.shape[1], :]
559
+ qwen2_hidden_state_a = qwen2_hidden_state_a[:, :qwen2_hidden_state_b.shape[1], :]
560
+
561
+ gen_images = controlnet_pipeline(
562
+ image=input_image_a,
563
+ strength=denoise_strength,
564
+ control_image=control_images,
565
+ control_mode=control_modes,
566
+ controlnet_conditioning_scale=conditioning_scales,
567
+ prompt_embeds=qwen2_hidden_state_b.repeat(batch_size, 1, 1) if qwen2_hidden_state_b is not None else qwen2_hidden_state_a.repeat(batch_size, 1, 1),
568
+ t5_prompt_embeds=t5_prompt_embeds.repeat(batch_size, 1, 1) if t5_prompt_embeds is not None else None,
569
+ prompt_embeds_control=qwen2_hidden_state_a.repeat(batch_size, 1, 1),
570
+ pooled_prompt_embeds=pooled_prompt_embeds,
571
+ num_inference_steps=num_inference_steps,
572
+ guidance_scale=guidance_scale,
573
+ height=height,
574
+ width=width,
575
+ ).images
576
+
577
+ #############################
578
+ # CONTROLNET INPAINT
579
+ #############################
580
+ elif mode == "controlnet-inpaint":
581
+ input_image_a = self.resize_image(input_image_a)
582
+ mask_image = self.resize_image(mask_image)
583
+ width, height = input_image_a.size
584
+
585
+ controlnet_pipeline = FluxControlNetInpaintPipeline(
586
+ transformer=self.transformer,
587
+ scheduler=self.noise_scheduler,
588
+ vae=self.vae,
589
+ text_encoder=self.text_encoder,
590
+ tokenizer=self.tokenizer,
591
+ controlnet=self.controlnet,
592
+ )
593
+
594
+ # Prepare the lists of control images and control modes
595
+ control_images = []
596
+ control_modes = []
597
+ conditioning_scales = []
598
+
599
+ # Add control modes according to the user's selection
600
+ if depth_mode:
601
+ control_image_depth = self.generate_depth_map(input_image_a)
602
+ control_images.append(control_image_depth)
603
+ control_modes.append(2) # depth mode
604
+ conditioning_scales.append(depth_strength)
605
+
606
+ if line_mode:
607
+ control_image_canny = self.anyline(input_image_a, detect_resolution=1280)
608
+ control_images.append(control_image_canny)
609
+ control_modes.append(0) # line mode
610
+ conditioning_scales.append(line_strength)
611
+
612
+ # If neither mode is enabled, fall back to line+depth by default
613
+ if not line_mode and not depth_mode:
614
+ control_image_depth = self.generate_depth_map(input_image_a)
615
+ control_image_canny = self.anyline(input_image_a, detect_resolution=1280)
616
+ control_images = [control_image_depth, control_image_canny]
617
+ control_modes = [2, 0]
618
+ conditioning_scales = [0.2, 0.4]
619
+
620
+ if qwen2_hidden_state_b is not None:
621
+ qwen2_hidden_state_b = qwen2_hidden_state_b[:, :qwen2_hidden_state_a.shape[1], :]
622
+ qwen2_hidden_state_a = qwen2_hidden_state_a[:, :qwen2_hidden_state_b.shape[1], :]
623
+
624
+ gen_images = controlnet_pipeline(
625
+ image=input_image_a,
626
+ mask_image=mask_image,
627
+ control_image=control_images,
628
+ control_mode=control_modes,
629
+ controlnet_conditioning_scale=conditioning_scales,
630
+ strength=denoise_strength,
631
+ prompt_embeds=qwen2_hidden_state_b.repeat(batch_size, 1, 1) if qwen2_hidden_state_b is not None else qwen2_hidden_state_a.repeat(batch_size, 1, 1),
632
+ t5_prompt_embeds=t5_prompt_embeds.repeat(batch_size, 1, 1) if t5_prompt_embeds is not None else None,
633
+ prompt_embeds_control=qwen2_hidden_state_a.repeat(batch_size, 1, 1),
634
+ pooled_prompt_embeds=pooled_prompt_embeds,
635
+ num_inference_steps=num_inference_steps,
636
+ guidance_scale=guidance_scale,
637
+ height=height,
638
+ width=width,
639
+ ).images
640
+
641
+ else:
642
+ raise ValueError(f"Invalid mode: {mode}")
643
+
644
+ return gen_images
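A minimal usage sketch for the FluxModel wrapper defined above, assuming the checkpoints are already laid out under CHECKPOINT_DIR as described by MODEL_PATHS; the file names input.jpg and generated_*.png are placeholders:

from PIL import Image
from model import FluxModel

# Only the requested features are loaded; 'depth' or 'line' also pulls in the ControlNet union model.
model = FluxModel(is_turbo=False, device="cuda", required_features=["controlnet", "depth", "line"])

source = Image.open("input.jpg").convert("RGB")
images = model.generate(
    input_image_a=source,
    prompt="a watercolor version of this scene",
    mode="controlnet",  # variation | img2img | inpaint | controlnet | controlnet-inpaint
    num_inference_steps=28,
    guidance_scale=3.5,
    denoise_strength=0.8,
    imageCount=2,
    line_mode=True, depth_mode=True,
    line_strength=0.4, depth_strength=0.2,
)
for i, image in enumerate(images):
    image.save(f"generated_{i}.png")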
modelmod.py ADDED
@@ -0,0 +1,650 @@
1
+ import torch
2
+ from torch import nn
3
+ from PIL import Image
4
+ from transformers import CLIPTokenizer, CLIPTextModel, AutoProcessor, T5EncoderModel, T5TokenizerFast, BitsAndBytesConfig
5
+ from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
6
+ from flux.transformer_flux import FluxTransformer2DModel
7
+
8
+ from flux.pipeline_flux_chameleon import FluxPipeline
9
+ from flux.pipeline_flux_img2img import FluxImg2ImgPipeline
10
+ from flux.pipeline_flux_inpaint import FluxInpaintPipeline
11
+ from flux.pipeline_flux_controlnet import FluxControlNetPipeline, FluxControlNetModel
12
+ from flux.pipeline_flux_controlnet_img2img import FluxControlNetImg2ImgPipeline
13
+ from flux.controlnet_flux import FluxMultiControlNetModel
14
+ from flux.pipeline_flux_controlnet_inpainting import FluxControlNetInpaintPipeline
15
+
16
+ from qwen2_vl.modeling_qwen2_vl import Qwen2VLSimplifiedModel
17
+ import os
18
+ import cv2
19
+ import numpy as np
20
+ import math
21
+ nf4_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16)
22
+
23
+
24
+ def get_model_path(model_name):
25
+ """Get the full path for a model based on the checkpoints directory."""
26
+ base_dir = os.getenv('CHECKPOINT_DIR', 'checkpoints') # Allow environment variable override
27
+ return os.path.join(base_dir, model_name)
28
+
29
+ # Model paths configuration
30
+ MODEL_PATHS = {
31
+ 'flux': get_model_path('flux'),
32
+ 'qwen2vl': get_model_path('qwen2-vl'),
33
+ 'controlnet': get_model_path('controlnet'),
34
+ 'depth_anything': {
35
+ 'path': get_model_path('depth-anything-v2'),
36
+ 'weights': 'depth_anything_v2_vitl.pth'
37
+ },
38
+ 'anyline': {
39
+ 'path': get_model_path('anyline'),
40
+ 'weights': 'MTEED.pth'
41
+ },
42
+ 'sam2': {
43
+ 'path': get_model_path('segment-anything-2'),
44
+ 'weights': 'sam2_hiera_large.pt',
45
+ 'config': 'sam2_hiera_l.yaml'
46
+ }
47
+ }
48
+
49
+
50
+ ASPECT_RATIOS = {
51
+ "1:1": (1024, 1024),
52
+ "16:9": (1344, 768),
53
+ "9:16": (768, 1344),
54
+ "2.4:1": (1536, 640),
55
+ "3:4": (896, 1152),
56
+ "4:3": (1152, 896),
57
+ }
58
+
59
+ class Qwen2Connector(nn.Module):
60
+ def __init__(self, input_dim=3584, output_dim=4096):
61
+ super().__init__()
62
+ self.linear = nn.Linear(input_dim, output_dim)
63
+
64
+ def forward(self, x):
65
+ return self.linear(x)
66
+
67
+ class FluxModel:
68
+ def __init__(self, is_turbo=False, device="cuda", required_features=None, is_quantization=True):
69
+ """
70
+ Initialize FluxModel with specified features
71
+ Args:
72
+ is_turbo: Enable turbo mode for faster inference
73
+ device: Device to run the model on
74
+ required_features: List of required features ['controlnet', 'depth', 'line', 'sam']
75
+ """
76
+ self.device = torch.device(device)
77
+ self.qkwargs = {"quantization_config": nf4_config} if is_quantization else {}
78
+ self.dtype = torch.bfloat16
79
+ if required_features is None:
80
+ required_features = []
81
+
82
+ self._line_detector_imported = False
83
+ self._depth_model_imported = False
84
+ self._sam_imported = False
85
+ self._turbo_imported = False
86
+
87
+ # Initialize base models (always required)
88
+ self._init_base_models()
89
+
90
+ # Initialize optional models based on requirements
91
+ if 'controlnet' in required_features or any(f in required_features for f in ['depth', 'line']):
92
+ self._init_controlnet()
93
+
94
+ if 'depth' in required_features:
95
+ self._init_depth_model()
96
+
97
+ if 'line' in required_features:
98
+ self._init_line_detector()
99
+
100
+ if 'sam' in required_features:
101
+ self._init_sam()
102
+
103
+ if is_turbo:
104
+ self._enable_turbo()
105
+
106
+ def _init_base_models(self):
107
+ """Initialize the core models that are always needed"""
108
+ # Qwen2VL and connector initialization
109
+ self.qwen2vl = Qwen2VLSimplifiedModel.from_pretrained(
110
+ MODEL_PATHS['qwen2vl'],
111
+ torch_dtype=self.dtype,
112
+ **self.qkwargs
113
+ )
114
+ self.qwen2vl.requires_grad_(False).to(self.device)
115
+
116
+ self.connector = Qwen2Connector(input_dim=3584, output_dim=4096)
117
+ connector_path = os.path.join(MODEL_PATHS['qwen2vl'], "connector.pt")
118
+ if os.path.exists(connector_path):
119
+ connector_state_dict = torch.load(connector_path, map_location=self.device, weights_only=True)
120
+ connector_state_dict = {k.replace('module.', ''): v for k, v in connector_state_dict.items()}
121
+ self.connector.load_state_dict(connector_state_dict)
122
+ self.connector.to(self.dtype).to(self.device)
123
+
124
+ # Text encoders initialization
125
+ self.tokenizer = CLIPTokenizer.from_pretrained(MODEL_PATHS['flux'], subfolder="tokenizer")
126
+ self.text_encoder = CLIPTextModel.from_pretrained(MODEL_PATHS['flux'], subfolder="text_encoder")
127
+ self.text_encoder_two = T5EncoderModel.from_pretrained(MODEL_PATHS['flux'], subfolder="text_encoder_2", **self.qkwargs)
128
+ self.tokenizer_two = T5TokenizerFast.from_pretrained(MODEL_PATHS['flux'], subfolder="tokenizer_2")
129
+
130
+ self.text_encoder.requires_grad_(False).to(self.dtype).to(self.device)
131
+ #self.text_encoder_two.requires_grad_(False).to(self.dtype).to(self.device)
132
+ self.text_encoder_two.requires_grad_(False).to(self.device)
133
+
134
+ # T5 context embedder
135
+ self.t5_context_embedder = nn.Linear(4096, 3072)
136
+ t5_embedder_path = os.path.join(MODEL_PATHS['qwen2vl'], "t5_embedder.pt")
137
+ t5_embedder_state_dict = torch.load(t5_embedder_path, map_location=self.device, weights_only=True)
138
+ self.t5_context_embedder.load_state_dict(t5_embedder_state_dict)
139
+ self.t5_context_embedder.to(self.dtype).to(self.device)
140
+
141
+ # Basic components
142
+ self.noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(MODEL_PATHS['flux'], subfolder="scheduler", shift=1)
143
+ self.vae = AutoencoderKL.from_pretrained(MODEL_PATHS['flux'], subfolder="vae")
144
+ self.transformer = FluxTransformer2DModel.from_pretrained(MODEL_PATHS['flux'], subfolder="transformer", **self.qkwargs)
145
+
146
+ self.vae.requires_grad_(False).to(self.dtype).to(self.device)
147
+ #self.transformer.requires_grad_(False).to(self.dtype).to(self.device)
148
+ self.transformer.requires_grad_(False).to(self.device)
149
+
150
+ def _init_controlnet(self):
151
+ """Initialize ControlNet model"""
152
+ self.controlnet_union = FluxControlNetModel.from_pretrained(
153
+ MODEL_PATHS['controlnet'],
154
+ torch_dtype=torch.bfloat16
155
+ )
156
+ self.controlnet_union.requires_grad_(False).to(self.device)
157
+ self.controlnet = FluxMultiControlNetModel([self.controlnet_union])
158
+
159
+ def _init_depth_model(self):
160
+ """Initialize Depth Anything V2 model"""
161
+ if not self._depth_model_imported:
162
+ from depth_anything_v2.dpt import DepthAnythingV2
163
+ self._depth_model_imported = True
164
+
165
+ self.depth_model = DepthAnythingV2(
166
+ encoder='vitl',
167
+ features=256,
168
+ out_channels=[256, 512, 1024, 1024]
169
+ )
170
+ depth_weights = os.path.join(MODEL_PATHS['depth_anything']['path'],
171
+ MODEL_PATHS['depth_anything']['weights'])
172
+ self.depth_model.load_state_dict(torch.load(depth_weights, map_location=self.device))
173
+ self.depth_model.requires_grad_(False).to(self.device)
174
+
175
+ def _init_line_detector(self):
176
+ """Initialize line detection model"""
177
+ if not self._line_detector_imported:
178
+ from controlnet_aux import AnylineDetector
179
+ self._line_detector_imported = True
180
+
181
+ self.anyline = AnylineDetector.from_pretrained(
182
+ MODEL_PATHS['anyline']['path'],
183
+ filename=MODEL_PATHS['anyline']['weights']
184
+ )
185
+ self.anyline.to(self.device)
186
+
187
+ def _init_sam(self):
188
+ """Initialize SAM2 model"""
189
+ if not self._sam_imported:
190
+ from sam2.build_sam import build_sam2
191
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
192
+ self._sam_imported = True
193
+
194
+ sam2_checkpoint = os.path.join(MODEL_PATHS['sam2']['path'],
195
+ MODEL_PATHS['sam2']['weights'])
196
+ model_cfg = os.path.join(MODEL_PATHS['sam2']['path'],
197
+ MODEL_PATHS['sam2']['config'])
198
+ self.sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=self.device)
199
+ self.sam2_predictor = SAM2ImagePredictor(self.sam2_model)
200
+
201
+ def _enable_turbo(self):
202
+ """Enable turbo mode for faster inference"""
203
+ if not self._turbo_imported:
204
+ from optimum.quanto import freeze, qfloat8, quantize
205
+ self._turbo_imported = True
206
+
207
+ quantize(
208
+ self.transformer,
209
+ weights=qfloat8,
210
+ exclude=[
211
+ "*.norm", "*.norm1", "*.norm2", "*.norm2_context",
212
+ "proj_out", "x_embedder", "norm_out", "context_embedder",
213
+ ],
214
+ )
215
+ freeze(self.transformer)
216
+
217
+ def generate_mask(self, image, input_points, input_labels):
218
+ """
219
+ 使用SAM2生成分割mask
220
+
221
+ Args:
222
+ image: PIL Image或numpy数组
223
+ input_points: numpy数组,形状为(N, 2),包含点的坐标
224
+ input_labels: numpy数组,形状为(N,),1表示前景点,0表示背景点
225
+
226
+ Returns:
227
+ PIL Image: 最高分数的mask
228
+ """
229
+ try:
230
+ # 确保图像是numpy数组
231
+ if isinstance(image, Image.Image):
232
+ image_array = np.array(image)
233
+ else:
234
+ image_array = image
235
+
236
+ # 设置图像
237
+ self.sam2_predictor.set_image(image_array)
238
+
239
+ # 进行预测
240
+ with torch.inference_mode():
241
+ masks, scores, logits = self.sam2_predictor.predict(
242
+ point_coords=input_points,
243
+ point_labels=input_labels,
244
+ multimask_output=True,
245
+ )
246
+
247
+ # 返回得分最高的mask
248
+ best_mask_idx = scores.argmax()
249
+ mask = masks[best_mask_idx]
250
+ mask_image = Image.fromarray((mask * 255).astype(np.uint8))
251
+ return mask_image
252
+
253
+ except Exception as e:
254
+ print(f"Mask generation failed: {str(e)}")
255
+ raise
256
+
257
+ def recover_2d_shape(self, image_hidden_state, grid_thw):
258
+ batch_size, num_tokens, hidden_dim = image_hidden_state.shape
259
+ _, h, w = grid_thw
260
+ h_out = h // 2
261
+ w_out = w // 2
262
+ # 重塑为 (batch_size, height, width, hidden_dim)
263
+ reshaped = image_hidden_state.view(batch_size, h_out, w_out, hidden_dim)
264
+ return reshaped
265
+
266
+ def generate_attention_matrix(self, center_x, center_y, radius, image_shape):
267
+ height, width = image_shape
268
+ y, x = np.ogrid[:height, :width]
269
+ center_y, center_x = center_y * height, center_x * width
270
+ distances = np.sqrt((x - center_x)**2 + (y - center_y)**2)
271
+ attention = np.clip(1 - distances / (radius * min(height, width)), 0, 1)
272
+ return attention
273
+
274
+ def apply_attention(self, image_hidden_state, image_grid_thw, center_x, center_y, radius):
275
+ qwen2_2d_image_embedding = self.recover_2d_shape(image_hidden_state, tuple(image_grid_thw.tolist()[0]))
276
+ attention_matrix = self.generate_attention_matrix(
277
+ center_x, center_y, radius,
278
+ (qwen2_2d_image_embedding.size(1), qwen2_2d_image_embedding.size(2))
279
+ )
280
+ attention_tensor = torch.from_numpy(attention_matrix).to(self.dtype).unsqueeze(0).unsqueeze(-1)
281
+ qwen2_2d_image_embedding = qwen2_2d_image_embedding * attention_tensor.to(self.device)
282
+ return qwen2_2d_image_embedding.view(1, -1, qwen2_2d_image_embedding.size(3))
283
+
284
+ def compute_text_embeddings(self, prompt):
285
+ with torch.no_grad():
286
+ text_inputs = self.tokenizer(prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt")
287
+ text_input_ids = text_inputs.input_ids.to(self.device)
288
+ prompt_embeds = self.text_encoder(text_input_ids, output_hidden_states=False)
289
+ pooled_prompt_embeds = prompt_embeds.pooler_output
290
+ return pooled_prompt_embeds.to(self.dtype)
291
+
292
+ def compute_t5_text_embeddings(
293
+ self,
294
+ max_sequence_length=256,
295
+ prompt=None,
296
+ num_images_per_prompt=1,
297
+ device=None,
298
+ ):
299
+ prompt = [prompt] if isinstance(prompt, str) else prompt
300
+ batch_size = len(prompt)
301
+
302
+ text_inputs = self.tokenizer_two(
303
+ prompt,
304
+ padding="max_length",
305
+ max_length=max_sequence_length,
306
+ truncation=True,
307
+ return_length=False,
308
+ return_overflowing_tokens=False,
309
+ return_tensors="pt",
310
+ )
311
+ text_input_ids = text_inputs.input_ids
312
+ prompt_embeds = self.text_encoder_two(text_input_ids.to(device))[0]
313
+
314
+ dtype = self.text_encoder_two.dtype
315
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
316
+
317
+ _, seq_len, _ = prompt_embeds.shape
318
+
319
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
320
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
321
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
322
+
323
+ return prompt_embeds
324
+
325
+ def process_image(self, image):
326
+ message = [
327
+ {
328
+ "role": "user",
329
+ "content": [
330
+ {"type": "image", "image": image},
331
+ {"type": "text", "text": "Describe this image."},
332
+ ]
333
+ }
334
+ ]
335
+ text = self.qwen2vl_processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
336
+
337
+ with torch.no_grad():
338
+ inputs = self.qwen2vl_processor(text=[text], images=[image], padding=True, return_tensors="pt").to(self.device)
339
+ output_hidden_state, image_token_mask, image_grid_thw = self.qwen2vl(**inputs)
340
+ image_hidden_state = output_hidden_state[image_token_mask].view(1, -1, output_hidden_state.size(-1))
341
+
342
+ return image_hidden_state, image_grid_thw
343
+
344
+ def resize_image(self, img, max_pixels=1050000):
345
+ # 确保输入是 PIL Image
346
+ if not isinstance(img, Image.Image):
347
+ img = Image.fromarray(img)
348
+
349
+ width, height = img.size
350
+ num_pixels = width * height
351
+
352
+ if num_pixels > max_pixels:
353
+ scale = math.sqrt(max_pixels / num_pixels)
354
+ new_width = int(width * scale)
355
+ new_height = int(height * scale)
356
+ # 调整宽度和高度,使其能被8整除
357
+ new_width = new_width - (new_width % 8)
358
+ new_height = new_height - (new_height % 8)
359
+ img = img.resize((new_width, new_height), Image.LANCZOS)
360
+ else:
361
+ # 如果图片不需要缩小,仍然需要确保尺寸能被8整除
362
+ new_width = width - (width % 8)
363
+ new_height = height - (height % 8)
364
+ if new_width != width or new_height != height:
365
+ img = img.resize((new_width, new_height), Image.LANCZOS)
366
+
367
+ return img
368
+
369
+ def generate_depth_map(self, image):
370
+ """Generate depth map using Depth Anything V2"""
371
+ # Convert PIL to numpy array
372
+ image_np = np.array(image)
373
+
374
+ # Convert RGB to BGR for cv2
375
+ image_bgr = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
376
+
377
+ # Generate depth map
378
+ with torch.no_grad():
379
+ depth = self.depth_model.infer_image(image_bgr)
380
+
381
+ # Normalize depth to 0-1 range
382
+ depth_norm = (depth - depth.min()) / (depth.max() - depth.min())
383
+
384
+ # Convert to RGB image
385
+ depth_rgb = (depth_norm * 255).astype(np.uint8)
386
+ depth_rgb = cv2.cvtColor(depth_rgb, cv2.COLOR_GRAY2RGB)
387
+
388
+ return Image.fromarray(depth_rgb)
389
+
390
+
391
+ def generate(self, input_image_a, input_image_b=None, prompt="", guidance_scale=3.5, num_inference_steps=28,
392
+ aspect_ratio="1:1", center_x=None, center_y=None, radius=None, mode="variation",
393
+ denoise_strength=0.8, mask_image=None, imageCount=2,
394
+ line_mode=True, depth_mode=True, line_strength=0.4, depth_strength=0.2):
395
+
396
+ batch_size = imageCount
397
+ if aspect_ratio not in ASPECT_RATIOS:
398
+ raise ValueError(f"Invalid aspect ratio. Choose from {list(ASPECT_RATIOS.keys())}")
399
+
400
+ width, height = ASPECT_RATIOS[aspect_ratio]
401
+
402
+ pooled_prompt_embeds = self.compute_text_embeddings(prompt="")
403
+ t5_prompt_embeds = None
404
+ if prompt != "":
405
+ self.qwen2vl_processor = AutoProcessor.from_pretrained(MODEL_PATHS['qwen2vl'], min_pixels=256*28*28, max_pixels=256*28*28)
406
+ t5_prompt_embeds = self.compute_t5_text_embeddings(prompt=prompt, device=self.device).to(self.dtype)
407
+ t5_prompt_embeds = self.t5_context_embedder(t5_prompt_embeds)
408
+ else:
409
+ self.qwen2vl_processor = AutoProcessor.from_pretrained(MODEL_PATHS['qwen2vl'], min_pixels=512*28*28, max_pixels=512*28*28)
410
+
411
+ qwen2_hidden_state_a, image_grid_thw_a = self.process_image(input_image_a)
412
+ # 只有当所有注意力参数都被提供时,才应用注意力机制
413
+ if mode == "variation":
414
+ if center_x is not None and center_y is not None and radius is not None:
415
+ qwen2_hidden_state_a = self.apply_attention(qwen2_hidden_state_a, image_grid_thw_a, center_x, center_y, radius)
416
+ qwen2_hidden_state_a = self.connector(qwen2_hidden_state_a)
417
+
418
+ if mode == "img2img" or mode == "inpaint":
419
+ if input_image_b:
420
+ qwen2_hidden_state_b, image_grid_thw_b = self.process_image(input_image_b)
421
+ if center_x is not None and center_y is not None and radius is not None:
422
+ qwen2_hidden_state_b = self.apply_attention(qwen2_hidden_state_b, image_grid_thw_b, center_x, center_y, radius)
423
+ qwen2_hidden_state_b = self.connector(qwen2_hidden_state_b)
424
+ else:
425
+ qwen2_hidden_state_a = self.connector(qwen2_hidden_state_a)
426
+ qwen2_hidden_state_b = None
427
+
428
+ if mode == "controlnet" or mode == "controlnet-inpaint":
429
+ qwen2_hidden_state_b = None
430
+ if input_image_b:
431
+ qwen2_hidden_state_b, image_grid_thw_b = self.process_image(input_image_b)
432
+ if center_x is not None and center_y is not None and radius is not None:
433
+ qwen2_hidden_state_b = self.apply_attention(qwen2_hidden_state_b, image_grid_thw_b, center_x, center_y, radius)
434
+ qwen2_hidden_state_b = self.connector(qwen2_hidden_state_b)
435
+ qwen2_hidden_state_a = self.connector(qwen2_hidden_state_a)
436
+
437
+ #############################
438
+ # IMAGE GENERATION
439
+ #############################
440
+ if mode == "variation":
441
+ # Initialize different pipelines
442
+ pipeline = FluxPipeline(
443
+ transformer=self.transformer,
444
+ scheduler=self.noise_scheduler,
445
+ vae=self.vae,
446
+ text_encoder=self.text_encoder,
447
+ tokenizer=self.tokenizer,
448
+ )
449
+
450
+ gen_images = pipeline(
451
+ prompt_embeds=qwen2_hidden_state_a.repeat(batch_size, 1, 1),
452
+ t5_prompt_embeds=t5_prompt_embeds.repeat(batch_size, 1, 1) if t5_prompt_embeds is not None else None,
453
+ pooled_prompt_embeds=pooled_prompt_embeds,
454
+ num_inference_steps=num_inference_steps,
455
+ guidance_scale=guidance_scale,
456
+ height=height,
457
+ width=width,
458
+ ).images
459
+
460
+
461
+ #############################
462
+ # IMAGE-TO-IMAGE
463
+ #############################
464
+ elif mode == "img2img":
465
+ input_image_a = self.resize_image(input_image_a)
466
+ width, height = input_image_a.size
467
+
468
+ img2img_pipeline = FluxImg2ImgPipeline(
469
+ transformer=self.transformer,
470
+ scheduler=self.noise_scheduler,
471
+ vae=self.vae,
472
+ text_encoder=self.text_encoder,
473
+ tokenizer=self.tokenizer,
474
+ )
475
+
476
+ gen_images = img2img_pipeline(
477
+ image=input_image_a,
478
+ strength=denoise_strength,
479
+ prompt_embeds=qwen2_hidden_state_b.repeat(batch_size, 1, 1) if qwen2_hidden_state_b is not None else qwen2_hidden_state_a.repeat(batch_size, 1, 1),
480
+ t5_prompt_embeds=t5_prompt_embeds.repeat(batch_size, 1, 1) if t5_prompt_embeds is not None else None,
481
+ pooled_prompt_embeds=pooled_prompt_embeds,
482
+ num_inference_steps=num_inference_steps,
483
+ guidance_scale=guidance_scale,
484
+ height=height,
485
+ width=width,
486
+ ).images
487
+
488
+
489
+ #############################
490
+ # INPAINTING
491
+ #############################
492
+ elif mode == "inpaint":
493
+ if mask_image is None:
494
+ raise ValueError("Mask image is required for inpainting mode")
495
+
496
+ input_image_a = self.resize_image(input_image_a)
497
+ mask_image = self.resize_image(mask_image)
498
+ width, height = input_image_a.size
499
+
500
+ inpaint_pipeline = FluxInpaintPipeline(
501
+ transformer=self.transformer,
502
+ scheduler=self.noise_scheduler,
503
+ vae=self.vae,
504
+ text_encoder=self.text_encoder,
505
+ tokenizer=self.tokenizer,
506
+ )
507
+
508
+ gen_images = inpaint_pipeline(
509
+ image=input_image_a,
510
+ mask_image=mask_image,
511
+ strength=denoise_strength,
512
+ prompt_embeds=qwen2_hidden_state_b.repeat(batch_size, 1, 1) if qwen2_hidden_state_b is not None else qwen2_hidden_state_a.repeat(batch_size, 1, 1),
513
+ t5_prompt_embeds=t5_prompt_embeds.repeat(batch_size, 1, 1) if t5_prompt_embeds is not None else None,
514
+ pooled_prompt_embeds=pooled_prompt_embeds,
515
+ num_inference_steps=num_inference_steps,
516
+ guidance_scale=guidance_scale,
517
+ height=height,
518
+ width=width,
519
+ ).images
520
+
521
+ #############################
522
+ # CONTROLNET
523
+ #############################
524
+ elif mode == "controlnet":
525
+ input_image_a = self.resize_image(input_image_a)
526
+ width, height = input_image_a.size
527
+
528
+ controlnet_pipeline = FluxControlNetImg2ImgPipeline(
529
+ transformer=self.transformer,
530
+ scheduler=self.noise_scheduler,
531
+ vae=self.vae,
532
+ text_encoder=self.text_encoder,
533
+ tokenizer=self.tokenizer,
534
+ controlnet=self.controlnet,
535
+ )
536
+
537
+ # 准备控制图像和模式列表
538
+ control_images = []
539
+ control_modes = []
540
+ conditioning_scales = []
541
+
542
+ # 根据用户选择添加控制模式
543
+ if depth_mode:
544
+ control_image_depth = self.generate_depth_map(input_image_a)
545
+ control_images.append(control_image_depth)
546
+ control_modes.append(2) # depth mode
547
+ conditioning_scales.append(depth_strength)
548
+
549
+ if line_mode:
550
+ control_image_canny = self.anyline(input_image_a, detect_resolution=1280)
551
+ control_images.append(control_image_canny)
552
+ control_modes.append(0) # line mode
553
+ conditioning_scales.append(line_strength)
554
+
555
+ # 如果没有启用任何模式,默认使用line+depth模式
556
+ if not line_mode and not depth_mode:
557
+ control_image_depth = self.generate_depth_map(input_image_a)
558
+ control_image_canny = self.anyline(input_image_a, detect_resolution=1280)
559
+ control_images = [control_image_depth, control_image_canny]
560
+ control_modes = [2, 0]
561
+ conditioning_scales = [0.2, 0.4]
562
+
563
+ if qwen2_hidden_state_b is not None:
564
+ qwen2_hidden_state_b = qwen2_hidden_state_b[:, :qwen2_hidden_state_a.shape[1], :]
565
+ qwen2_hidden_state_a = qwen2_hidden_state_a[:, :qwen2_hidden_state_b.shape[1], :]
566
+
567
+ gen_images = controlnet_pipeline(
568
+ image=input_image_a,
569
+ strength=denoise_strength,
570
+ control_image=control_images,
571
+ control_mode=control_modes,
572
+ controlnet_conditioning_scale=conditioning_scales,
573
+ prompt_embeds=qwen2_hidden_state_b.repeat(batch_size, 1, 1) if qwen2_hidden_state_b is not None else qwen2_hidden_state_a.repeat(batch_size, 1, 1),
574
+ t5_prompt_embeds=t5_prompt_embeds.repeat(batch_size, 1, 1) if t5_prompt_embeds is not None else None,
575
+ prompt_embeds_control=qwen2_hidden_state_a.repeat(batch_size, 1, 1),
576
+ pooled_prompt_embeds=pooled_prompt_embeds,
577
+ num_inference_steps=num_inference_steps,
578
+ guidance_scale=guidance_scale,
579
+ height=height,
580
+ width=width,
581
+ ).images
582
+
583
+ #############################
584
+ # CONTROLNET INPAINT
585
+ #############################
586
+ elif mode == "controlnet-inpaint":
587
+ input_image_a = self.resize_image(input_image_a)
588
+ mask_image = self.resize_image(mask_image)
589
+ width, height = input_image_a.size
590
+
591
+ controlnet_pipeline = FluxControlNetInpaintPipeline(
592
+ transformer=self.transformer,
593
+ scheduler=self.noise_scheduler,
594
+ vae=self.vae,
595
+ text_encoder=self.text_encoder,
596
+ tokenizer=self.tokenizer,
597
+ controlnet=self.controlnet,
598
+ )
599
+
600
+ # 准备控制图像和模式列表
601
+ control_images = []
602
+ control_modes = []
603
+ conditioning_scales = []
604
+
605
+ # 根据用户选择添加控制模式
606
+ if depth_mode:
607
+ control_image_depth = self.generate_depth_map(input_image_a)
608
+ control_images.append(control_image_depth)
609
+ control_modes.append(2) # depth mode
610
+ conditioning_scales.append(depth_strength)
611
+
612
+ if line_mode:
613
+ control_image_canny = self.anyline(input_image_a, detect_resolution=1280)
614
+ control_images.append(control_image_canny)
615
+ control_modes.append(0) # line mode
616
+ conditioning_scales.append(line_strength)
617
+
618
+ # 如果没有启用任何模式,默认使用line+depth模式
619
+ if not line_mode and not depth_mode:
620
+ control_image_depth = self.generate_depth_map(input_image_a)
621
+ control_image_canny = self.anyline(input_image_a, detect_resolution=1280)
622
+ control_images = [control_image_depth, control_image_canny]
623
+ control_modes = [2, 0]
624
+ conditioning_scales = [0.2, 0.4]
625
+
626
+ if qwen2_hidden_state_b is not None:
627
+ qwen2_hidden_state_b = qwen2_hidden_state_b[:, :qwen2_hidden_state_a.shape[1], :]
628
+ qwen2_hidden_state_a = qwen2_hidden_state_a[:, :qwen2_hidden_state_b.shape[1], :]
629
+
630
+ gen_images = controlnet_pipeline(
631
+ image=input_image_a,
632
+ mask_image=mask_image,
633
+ control_image=control_images,
634
+ control_mode=control_modes,
635
+ controlnet_conditioning_scale=conditioning_scales,
636
+ strength=denoise_strength,
637
+ prompt_embeds=qwen2_hidden_state_b.repeat(batch_size, 1, 1) if qwen2_hidden_state_b is not None else qwen2_hidden_state_a.repeat(batch_size, 1, 1),
638
+ t5_prompt_embeds=t5_prompt_embeds.repeat(batch_size, 1, 1) if t5_prompt_embeds is not None else None,
639
+ prompt_embeds_control=qwen2_hidden_state_a.repeat(batch_size, 1, 1),
640
+ pooled_prompt_embeds=pooled_prompt_embeds,
641
+ num_inference_steps=num_inference_steps,
642
+ guidance_scale=guidance_scale,
643
+ height=height,
644
+ width=width,
645
+ ).images
646
+
647
+ else:
648
+ raise ValueError(f"Invalid mode: {mode}")
649
+
650
+ return gen_images
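modelmod.py mirrors model.py but can load the heaviest components in 4-bit NF4 via bitsandbytes. A small sketch of that difference, under the same checkpoint-layout assumptions (nf4_config here restates the one built at the top of the file):

import torch
from transformers import BitsAndBytesConfig
from modelmod import FluxModel

# The config modelmod.py builds at import time and forwards as
# {"quantization_config": nf4_config} to Qwen2VLSimplifiedModel, the T5
# text encoder and the Flux transformer whenever is_quantization=True.
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# is_quantization defaults to True; pass False to load full bf16 weights as model.py does.
model = FluxModel(device="cuda", required_features=["sam"], is_quantization=True)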
qwen2_vl/configuration_qwen2_vl.py ADDED
@@ -0,0 +1,206 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Qwen2VL model configuration"""
16
+
17
+ import os
18
+ from typing import Union
19
+
20
+ from transformers.configuration_utils import PretrainedConfig
21
+ from transformers.utils import logging
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+
27
+ class Qwen2VLVisionConfig(PretrainedConfig):
28
+ model_type = "qwen2_vl"
29
+
30
+ def __init__(
31
+ self,
32
+ depth=32,
33
+ embed_dim=1280,
34
+ hidden_size=3584,
35
+ hidden_act="quick_gelu",
36
+ mlp_ratio=4,
37
+ num_heads=16,
38
+ in_channels=3,
39
+ patch_size=14,
40
+ spatial_merge_size=2,
41
+ temporal_patch_size=2,
42
+ **kwargs,
43
+ ):
44
+ super().__init__(**kwargs)
45
+
46
+ self.depth = depth
47
+ self.embed_dim = embed_dim
48
+ self.hidden_size = hidden_size
49
+ self.hidden_act = hidden_act
50
+ self.mlp_ratio = mlp_ratio
51
+ self.num_heads = num_heads
52
+ self.in_channels = in_channels
53
+ self.patch_size = patch_size
54
+ self.spatial_merge_size = spatial_merge_size
55
+ self.temporal_patch_size = temporal_patch_size
56
+
57
+ @classmethod
58
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
59
+ cls._set_token_in_kwargs(kwargs)
60
+
61
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
62
+
63
+ if config_dict.get("model_type") == "qwen2_vl":
64
+ config_dict = config_dict["vision_config"]
65
+
66
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
67
+ logger.warning(
68
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
69
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
70
+ )
71
+
72
+ return cls.from_dict(config_dict, **kwargs)
73
+
74
+
75
+ class Qwen2VLConfig(PretrainedConfig):
76
+ r"""
77
+ This is the configuration class to store the configuration of a [`Qwen2VLModel`]. It is used to instantiate a
78
+ Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
79
+ with the defaults will yield a similar configuration to that of
80
+ Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct).
81
+
82
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
83
+ documentation from [`PretrainedConfig`] for more information.
84
+
85
+
86
+ Args:
87
+ vocab_size (`int`, *optional*, defaults to 152064):
88
+ Vocabulary size of the Qwen2VL model. Defines the number of different tokens that can be represented by the
89
+ `inputs_ids` passed when calling [`Qwen2VLModel`]
90
+ hidden_size (`int`, *optional*, defaults to 8192):
91
+ Dimension of the hidden representations.
92
+ intermediate_size (`int`, *optional*, defaults to 29568):
93
+ Dimension of the MLP representations.
94
+ num_hidden_layers (`int`, *optional*, defaults to 80):
95
+ Number of hidden layers in the Transformer encoder.
96
+ num_attention_heads (`int`, *optional*, defaults to 64):
97
+ Number of attention heads for each attention layer in the Transformer encoder.
98
+ num_key_value_heads (`int`, *optional*, defaults to 8):
99
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
100
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
101
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
102
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
103
+ by meanpooling all the original heads within that group. For more details checkout [this
104
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
105
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
106
+ The non-linear activation function (function or string) in the decoder.
107
+ max_position_embeddings (`int`, *optional*, defaults to 32768):
108
+ The maximum sequence length that this model might ever be used with.
109
+ initializer_range (`float`, *optional*, defaults to 0.02):
110
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
111
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
112
+ The epsilon used by the rms normalization layers.
113
+ use_cache (`bool`, *optional*, defaults to `True`):
114
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
115
+ relevant if `config.is_decoder=True`.
116
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
117
+ Whether the model's input and output word embeddings should be tied.
118
+ rope_theta (`float`, *optional*, defaults to 1000000.0):
119
+ The base period of the RoPE embeddings.
120
+ use_sliding_window (`bool`, *optional*, defaults to `False`):
121
+ Whether to use sliding window attention.
122
+ sliding_window (`int`, *optional*, defaults to 4096):
123
+ Sliding window attention (SWA) window size. If not specified, will default to `4096`.
124
+ max_window_layers (`int`, *optional*, defaults to 80):
125
+ The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
126
+ attention_dropout (`float`, *optional*, defaults to 0.0):
127
+ The dropout ratio for the attention probabilities.
128
+ vision_config (`Dict`, *optional*):
129
+ The config for the visual encoder initialization.
130
+ rope_scaling (`Dict`, *optional*):
131
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
132
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
133
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
134
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
135
+ these scaling strategies behave:
136
+ https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
137
+ experimental feature, subject to breaking API changes in future versions.
138
+
139
+ ```python
140
+ >>> from transformers import Qwen2VLForConditionalGeneration, Qwen2VLConfig
141
+
142
+ >>> # Initializing a Qwen2VL style configuration
143
+ >>> configuration = Qwen2VLConfig()
144
+
145
+ >>> # Initializing a model from the Qwen2-VL-7B style configuration
146
+ >>> model = Qwen2VLForConditionalGeneration(configuration)
147
+
148
+ >>> # Accessing the model configuration
149
+ >>> configuration = model.config
150
+ ```"""
151
+
152
+ model_type = "qwen2_vl"
153
+ keys_to_ignore_at_inference = ["past_key_values"]
154
+
155
+ def __init__(
156
+ self,
157
+ vocab_size=152064,
158
+ hidden_size=8192,
159
+ intermediate_size=29568,
160
+ num_hidden_layers=80,
161
+ num_attention_heads=64,
162
+ num_key_value_heads=8,
163
+ hidden_act="silu",
164
+ max_position_embeddings=32768,
165
+ initializer_range=0.02,
166
+ rms_norm_eps=1e-05,
167
+ use_cache=True,
168
+ tie_word_embeddings=False,
169
+ rope_theta=1000000.0,
170
+ use_sliding_window=False,
171
+ sliding_window=4096,
172
+ max_window_layers=80,
173
+ attention_dropout=0.0,
174
+ vision_config=None,
175
+ rope_scaling=None,
176
+ **kwargs,
177
+ ):
178
+ if isinstance(vision_config, dict):
179
+ self.vision_config = Qwen2VLVisionConfig(**vision_config)
180
+ elif vision_config is None:
181
+ self.vision_config = Qwen2VLVisionConfig()
182
+
183
+ self.vocab_size = vocab_size
184
+ self.max_position_embeddings = max_position_embeddings
185
+ self.hidden_size = hidden_size
186
+ self.intermediate_size = intermediate_size
187
+ self.num_hidden_layers = num_hidden_layers
188
+ self.num_attention_heads = num_attention_heads
189
+ self.use_sliding_window = use_sliding_window
190
+ self.sliding_window = sliding_window
191
+ self.max_window_layers = max_window_layers
192
+
193
+ # for backward compatibility
194
+ if num_key_value_heads is None:
195
+ num_key_value_heads = num_attention_heads
196
+
197
+ self.num_key_value_heads = num_key_value_heads
198
+ self.hidden_act = hidden_act
199
+ self.initializer_range = initializer_range
200
+ self.rms_norm_eps = rms_norm_eps
201
+ self.use_cache = use_cache
202
+ self.rope_theta = rope_theta
203
+ self.attention_dropout = attention_dropout
204
+ self.rope_scaling = rope_scaling
205
+
206
+ super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
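
For orientation, here is a minimal, hypothetical usage sketch (not part of the uploaded files) of the `Qwen2VLConfig` class above, assuming the `qwen2_vl/` package from this commit is importable. The small dimensions and the `mrope_section` split are illustrative assumptions; per `apply_multimodal_rotary_pos_emb` in `modeling_qwen2_vl.py`, `sum(mrope_section)` should equal half the head dimension so the cos/sin tables split cleanly.

```python
# Hypothetical usage sketch, not part of this commit.
# Assumes the qwen2_vl/ package uploaded here is on the import path.
from qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig

config = Qwen2VLConfig(
    hidden_size=2048,            # illustrative small model: head_dim = 2048 / 16 = 128
    intermediate_size=5632,
    num_hidden_layers=4,
    num_attention_heads=16,
    num_key_value_heads=2,       # grouped-query attention (GQA)
    # mrope_section covers temporal/height/width channels; here it sums to head_dim // 2.
    rope_scaling={"type": "mrope", "mrope_section": [16, 24, 24]},
)

print(config.model_type)                     # "qwen2_vl"
print(type(config.vision_config).__name__)   # Qwen2VLVisionConfig (default vision config)
```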
qwen2_vl/image_processing_qwen2_vl.py ADDED
@@ -0,0 +1,458 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """Image processor class for Qwen2-VL."""
21
+
22
+ import math
23
+ from typing import Dict, List, Optional, Union
24
+
25
+ import numpy as np
26
+
27
+ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
28
+ from transformers.image_transforms import (
29
+ convert_to_rgb,
30
+ resize,
31
+ to_channel_dimension_format,
32
+ )
33
+ from transformers.image_utils import (
34
+ OPENAI_CLIP_MEAN,
35
+ OPENAI_CLIP_STD,
36
+ ChannelDimension,
37
+ ImageInput,
38
+ PILImageResampling,
39
+ VideoInput,
40
+ get_image_size,
41
+ infer_channel_dimension_format,
42
+ is_scaled_image,
43
+ is_valid_image,
44
+ make_list_of_images,
45
+ to_numpy_array,
46
+ valid_images,
47
+ validate_preprocess_arguments,
48
+ )
49
+ from transformers.utils import TensorType, is_vision_available, logging
50
+
51
+
52
+ logger = logging.get_logger(__name__)
53
+
54
+
55
+ if is_vision_available():
56
+ from PIL import Image
57
+
58
+
59
+ def make_batched_images(images) -> List[List[ImageInput]]:
60
+ """
61
+ Accepts images in list or nested list format, and makes a list of images for preprocessing.
62
+
63
+ Args:
64
+ images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`):
65
+ The input image.
66
+
67
+ Returns:
68
+ list: A list of images.
69
+ """
70
+ if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]):
71
+ return [img for img_list in images for img in img_list]
72
+
73
+ elif isinstance(images, (list, tuple)) and is_valid_image(images[0]):
74
+ return images
75
+
76
+ elif is_valid_image(images):
77
+ return [images]
78
+
79
+ raise ValueError(f"Could not make batched images from {images}")
80
+
81
+
82
+ # Copied from transformers.models.llava_next_video.image_processing_llava_next_video.make_batched_videos
83
+ def make_batched_videos(videos) -> List[VideoInput]:
84
+ if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]):
85
+ return videos
86
+
87
+ elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]):
88
+ if isinstance(videos[0], Image.Image):
89
+ return [videos]
90
+ elif len(videos[0].shape) == 4:
91
+ return [list(video) for video in videos]
92
+
93
+ elif is_valid_image(videos) and len(videos.shape) == 4:
94
+ return [list(videos)]
95
+
96
+ raise ValueError(f"Could not make batched video from {videos}")
97
+
98
+
99
+ def smart_resize(
100
+ height: int, width: int, factor: int = 28, min_pixels: int = 56 * 56, max_pixels: int = 14 * 14 * 4 * 1280
101
+ ):
102
+ """Rescales the image so that the following conditions are met:
103
+
104
+ 1. Both dimensions (height and width) are divisible by 'factor'.
105
+
106
+ 2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
107
+
108
+ 3. The aspect ratio of the image is maintained as closely as possible.
109
+
110
+ """
111
+ if height < factor or width < factor:
112
+ raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}")
113
+ elif max(height, width) / min(height, width) > 200:
114
+ raise ValueError(
115
+ f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
116
+ )
117
+ h_bar = round(height / factor) * factor
118
+ w_bar = round(width / factor) * factor
119
+ if h_bar * w_bar > max_pixels:
120
+ beta = math.sqrt((height * width) / max_pixels)
121
+ h_bar = math.floor(height / beta / factor) * factor
122
+ w_bar = math.floor(width / beta / factor) * factor
123
+ elif h_bar * w_bar < min_pixels:
124
+ beta = math.sqrt(min_pixels / (height * width))
125
+ h_bar = math.ceil(height * beta / factor) * factor
126
+ w_bar = math.ceil(width * beta / factor) * factor
127
+ return h_bar, w_bar
128
+
129
+
130
+ class Qwen2VLImageProcessor(BaseImageProcessor):
131
+ r"""
132
+ Constructs a Qwen2-VL image processor that dynamically resizes images based on the original images.
133
+
134
+ Args:
135
+ do_resize (`bool`, *optional*, defaults to `True`):
136
+ Whether to resize the image's (height, width) dimensions.
137
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
138
+ Resampling filter to use when resizing the image.
139
+ do_rescale (`bool`, *optional*, defaults to `True`):
140
+ Whether to rescale the image by the specified scale `rescale_factor`.
141
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
142
+ Scale factor to use if rescaling the image.
143
+ do_normalize (`bool`, *optional*, defaults to `True`):
144
+ Whether to normalize the image.
145
+ image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
146
+ Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
147
+ image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
148
+ Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image.
149
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
150
+ Whether to convert the image to RGB.
151
+ min_pixels (`int`, *optional*, defaults to `56 * 56`):
152
+ The min pixels of the image to resize the image.
153
+ max_pixels (`int`, *optional*, defaults to `28 * 28 * 1280`):
154
+ The max pixels of the image to resize the image.
155
+ patch_size (`int`, *optional*, defaults to 14):
156
+ The spacial patch size of the vision encoder.
157
+ temporal_patch_size (`int`, *optional*, defaults to 2):
158
+ The temporal patch size of the vision encoder.
159
+ merge_size (`int`, *optional*, defaults to 2):
160
+ The merge size of the vision encoder to llm encoder.
161
+ """
162
+
163
+ model_input_names = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw"]
164
+
165
+ def __init__(
166
+ self,
167
+ do_resize: bool = True,
168
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
169
+ do_rescale: bool = True,
170
+ rescale_factor: Union[int, float] = 1 / 255,
171
+ do_normalize: bool = True,
172
+ image_mean: Optional[Union[float, List[float]]] = None,
173
+ image_std: Optional[Union[float, List[float]]] = None,
174
+ do_convert_rgb: bool = True,
175
+ min_pixels: int = 56 * 56,
176
+ max_pixels: int = 28 * 28 * 1280,
177
+ patch_size: int = 14,
178
+ temporal_patch_size: int = 2,
179
+ merge_size: int = 2,
180
+ **kwargs,
181
+ ) -> None:
182
+ super().__init__(**kwargs)
183
+ self.do_resize = do_resize
184
+ self.resample = resample
185
+ self.do_rescale = do_rescale
186
+ self.rescale_factor = rescale_factor
187
+ self.do_normalize = do_normalize
188
+ self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
189
+ self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
190
+ self.min_pixels = min_pixels
191
+ self.max_pixels = max_pixels
192
+ self.patch_size = patch_size
193
+ self.temporal_patch_size = temporal_patch_size
194
+ self.merge_size = merge_size
195
+ self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels}
196
+ self.do_convert_rgb = do_convert_rgb
197
+
198
+ def _preprocess(
199
+ self,
200
+ images: Union[ImageInput, VideoInput],
201
+ do_resize: bool = None,
202
+ resample: PILImageResampling = None,
203
+ do_rescale: bool = None,
204
+ rescale_factor: float = None,
205
+ do_normalize: bool = None,
206
+ image_mean: Optional[Union[float, List[float]]] = None,
207
+ image_std: Optional[Union[float, List[float]]] = None,
208
+ do_convert_rgb: bool = None,
209
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
210
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
211
+ ):
212
+ """
213
+ Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.
214
+
215
+ Args:
216
+ images (`ImageInput`):
217
+ Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`.
218
+ vision_info (`List[Dict]`, *optional*):
219
+ Optional list of dictionaries containing additional information about vision inputs.
220
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
221
+ Whether to resize the image.
222
+ resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
223
+ Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums.
224
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
225
+ Whether to rescale the image.
226
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
227
+ Scale factor to use if rescaling the image.
228
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
229
+ Whether to normalize the image.
230
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
231
+ Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
232
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
233
+ Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
234
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
235
+ Whether to convert the image to RGB.
236
+ data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
237
+ The channel dimension format for the output image. Can be one of:
238
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
239
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
240
+ - Unset: Use the channel dimension format of the input image.
241
+ input_data_format (`ChannelDimension` or `str`, *optional*):
242
+ The channel dimension format for the input image. Can be one of:
243
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
244
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
245
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
246
+ """
247
+ images = make_list_of_images(images)
248
+
249
+ if do_convert_rgb:
250
+ images = [convert_to_rgb(image) for image in images]
251
+
252
+ # All transformations expect numpy arrays.
253
+ images = [to_numpy_array(image) for image in images]
254
+
255
+ if is_scaled_image(images[0]) and do_rescale:
256
+ logger.warning_once(
257
+ "It looks like you are trying to rescale already rescaled images. If the input"
258
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
259
+ )
260
+ if input_data_format is None:
261
+ # We assume that all images have the same channel dimension format.
262
+ input_data_format = infer_channel_dimension_format(images[0])
263
+
264
+ height, width = get_image_size(images[0], channel_dim=input_data_format)
265
+ resized_height, resized_width = height, width
266
+ processed_images = []
267
+ for image in images:
268
+ if do_resize:
269
+ resized_height, resized_width = smart_resize(
270
+ height,
271
+ width,
272
+ factor=self.patch_size * self.merge_size,
273
+ min_pixels=self.min_pixels,
274
+ max_pixels=self.max_pixels,
275
+ )
276
+ image = resize(
277
+ image, size=(resized_height, resized_width), resample=resample, input_data_format=input_data_format
278
+ )
279
+
280
+ if do_rescale:
281
+ image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format)
282
+
283
+ if do_normalize:
284
+ image = self.normalize(
285
+ image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
286
+ )
287
+
288
+ image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
289
+ processed_images.append(image)
290
+
291
+ patches = np.array(processed_images)
292
+ if data_format == ChannelDimension.LAST:
293
+ patches = patches.transpose(0, 3, 1, 2)
294
+ if patches.shape[0] == 1:
295
+ patches = np.tile(patches, (self.temporal_patch_size, 1, 1, 1))
296
+ channel = patches.shape[1]
297
+ grid_t = patches.shape[0] // self.temporal_patch_size
298
+ grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
299
+ patches = patches.reshape(
300
+ grid_t,
301
+ self.temporal_patch_size,
302
+ channel,
303
+ grid_h // self.merge_size,
304
+ self.merge_size,
305
+ self.patch_size,
306
+ grid_w // self.merge_size,
307
+ self.merge_size,
308
+ self.patch_size,
309
+ )
310
+ patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8)
311
+ flatten_patches = patches.reshape(
312
+ grid_t * grid_h * grid_w, channel * self.temporal_patch_size * self.patch_size * self.patch_size
313
+ )
314
+
315
+ return flatten_patches, (grid_t, grid_h, grid_w)
316
+
317
+ def preprocess(
318
+ self,
319
+ images: ImageInput,
320
+ videos: VideoInput = None,
321
+ do_resize: bool = None,
322
+ size: Dict[str, int] = None,
323
+ resample: PILImageResampling = None,
324
+ do_rescale: bool = None,
325
+ rescale_factor: float = None,
326
+ do_normalize: bool = None,
327
+ image_mean: Optional[Union[float, List[float]]] = None,
328
+ image_std: Optional[Union[float, List[float]]] = None,
329
+ do_convert_rgb: bool = None,
330
+ return_tensors: Optional[Union[str, TensorType]] = None,
331
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
332
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
333
+ ):
334
+ """
335
+ Args:
336
+ images (`ImageInput`):
337
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
338
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
339
+ videos (`VideoInput`):
340
+ Video to preprocess. Expects a single or batch of videos with pixel values ranging from 0 to 255. If
341
+ passing in videos with pixel values between 0 and 1, set `do_rescale=False`.
342
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
343
+ Whether to resize the image.
344
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
345
+ Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
346
+ the longest edge resized to keep the input aspect ratio.
347
+ resample (`int`, *optional*, defaults to `self.resample`):
348
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
349
+ has an effect if `do_resize` is set to `True`.
350
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
351
+ Whether to rescale the image.
352
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
353
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
354
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
355
+ Whether to normalize the image.
356
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
357
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
358
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
359
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
360
+ `True`.
361
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
362
+ Whether to convert the image to RGB.
363
+ return_tensors (`str` or `TensorType`, *optional*):
364
+ The type of tensors to return. Can be one of:
365
+ - Unset: Return a list of `np.ndarray`.
366
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
367
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
368
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
369
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
370
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
371
+ The channel dimension format for the output image. Can be one of:
372
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
373
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
374
+ - Unset: Use the channel dimension format of the input image.
375
+ input_data_format (`ChannelDimension` or `str`, *optional*):
376
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
377
+ from the input image. Can be one of:
378
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
379
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
380
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
381
+
382
+ """
383
+ do_resize = do_resize if do_resize is not None else self.do_resize
384
+ size = size if size is not None else self.size
385
+ resample = resample if resample is not None else self.resample
386
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
387
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
388
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
389
+ image_mean = image_mean if image_mean is not None else self.image_mean
390
+ image_std = image_std if image_std is not None else self.image_std
391
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
392
+
393
+ if images is not None:
394
+ images = make_batched_images(images)
395
+ if videos is not None:
396
+ videos = make_batched_videos(videos)
397
+
398
+ if images is not None and not valid_images(images):
399
+ raise ValueError(
400
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
401
+ "torch.Tensor, tf.Tensor or jax.ndarray."
402
+ )
403
+
404
+ validate_preprocess_arguments(
405
+ rescale_factor=rescale_factor,
406
+ do_normalize=do_normalize,
407
+ image_mean=image_mean,
408
+ image_std=image_std,
409
+ do_resize=do_resize,
410
+ size=size,
411
+ resample=resample,
412
+ )
413
+
414
+ if images is not None:
415
+ pixel_values, vision_grid_thws = [], []
416
+ for image in images:
417
+ patches, image_grid_thw = self._preprocess(
418
+ image,
419
+ do_resize=do_resize,
420
+ resample=resample,
421
+ do_rescale=do_rescale,
422
+ rescale_factor=rescale_factor,
423
+ do_normalize=do_normalize,
424
+ image_mean=image_mean,
425
+ image_std=image_std,
426
+ data_format=data_format,
427
+ do_convert_rgb=do_convert_rgb,
428
+ input_data_format=input_data_format,
429
+ )
430
+ pixel_values.extend(patches)
431
+ vision_grid_thws.append(image_grid_thw)
432
+ pixel_values = np.array(pixel_values)
433
+ vision_grid_thws = np.array(vision_grid_thws)
434
+ data = {"pixel_values": pixel_values, "image_grid_thw": vision_grid_thws}
435
+
436
+ if videos is not None:
437
+ pixel_values, vision_grid_thws = [], []
438
+ for images in videos:
439
+ patches, video_grid_thw = self._preprocess(
440
+ images,
441
+ do_resize=do_resize,
442
+ resample=resample,
443
+ do_rescale=do_rescale,
444
+ rescale_factor=rescale_factor,
445
+ do_normalize=do_normalize,
446
+ image_mean=image_mean,
447
+ image_std=image_std,
448
+ data_format=data_format,
449
+ do_convert_rgb=do_convert_rgb,
450
+ input_data_format=input_data_format,
451
+ )
452
+ pixel_values.extend(patches)
453
+ vision_grid_thws.append(video_grid_thw)
454
+ pixel_values = np.array(pixel_values)
455
+ vision_grid_thws = np.array(vision_grid_thws)
456
+ data = {"pixel_values_videos": pixel_values, "video_grid_thw": vision_grid_thws}
457
+
458
+ return BatchFeature(data=data, tensor_type=return_tensors)
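
Before the modeling file, a short, hypothetical usage sketch (not part of the uploaded files) of the image processor above, assuming Pillow and NumPy are installed and `qwen2_vl/` is importable. The printed shapes follow from `smart_resize` and `_preprocess` with the default `patch_size=14`, `temporal_patch_size=2`, `merge_size=2`.

```python
# Hypothetical usage sketch, not part of this commit.
from PIL import Image
from qwen2_vl.image_processing_qwen2_vl import Qwen2VLImageProcessor, smart_resize

# smart_resize snaps (height, width) to multiples of patch_size * merge_size = 28
# while keeping the total pixel count within [min_pixels, max_pixels].
print(smart_resize(480, 640))          # (476, 644)

processor = Qwen2VLImageProcessor()
image = Image.new("RGB", (640, 480), color=(127, 127, 127))
out = processor.preprocess(images=image, return_tensors="np")

# A single image is tiled to temporal_patch_size frames, then flattened into patches.
print(out["pixel_values"].shape)       # (1 * 34 * 46, 3 * 2 * 14 * 14) == (1564, 1176)
print(out["image_grid_thw"])           # [[ 1 34 46]] -> (grid_t, grid_h, grid_w)
```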
qwen2_vl/modeling_qwen2_vl.py ADDED
@@ -0,0 +1,1952 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """PyTorch Qwen2-VL model."""
21
+
22
+ import math
23
+ from dataclasses import dataclass
24
+ from typing import Any, Dict, List, Optional, Tuple, Union
25
+
26
+ import torch
27
+ import torch.nn as nn
28
+ import torch.nn.functional as F
29
+ import torch.utils.checkpoint
30
+ from torch.nn import CrossEntropyLoss, LayerNorm
31
+
32
+ from transformers.activations import ACT2FN
33
+ from transformers.cache_utils import Cache, StaticCache
34
+ from transformers.modeling_attn_mask_utils import (
35
+ AttentionMaskConverter,
36
+ )
37
+ from transformers.modeling_outputs import (
38
+ BaseModelOutputWithPast,
39
+ ModelOutput,
40
+ )
41
+ from transformers.modeling_utils import PreTrainedModel
42
+ from transformers.utils import (
43
+ add_start_docstrings,
44
+ add_start_docstrings_to_model_forward,
45
+ is_flash_attn_2_available,
46
+ is_flash_attn_greater_or_equal_2_10,
47
+ logging,
48
+ replace_return_docstrings,
49
+ )
50
+ from qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLVisionConfig
51
+
52
+ import traceback
53
+
54
+
55
+ if is_flash_attn_2_available():
56
+ from flash_attn import flash_attn_varlen_func
57
+
58
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
59
+ else:
60
+ flash_attn_varlen_func = None
61
+
62
+
63
+ logger = logging.get_logger(__name__)
64
+
65
+ _CONFIG_FOR_DOC = "Qwen2VLConfig"
66
+
67
+
68
+ @dataclass
69
+ class Qwen2VLCausalLMOutputWithPast(ModelOutput):
70
+ """
71
+ Base class for Qwen2VL causal language model (or autoregressive) outputs.
72
+
73
+ Args:
74
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
75
+ Language modeling loss (for next-token prediction).
76
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
77
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
78
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
79
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
80
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
81
+
82
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
83
+ `past_key_values` input) to speed up sequential decoding.
84
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
85
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
86
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
87
+
88
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
89
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
90
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
91
+ sequence_length)`.
92
+
93
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
94
+ heads.
95
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
96
+ The rope index difference between sequence length and multimodal rope.
97
+ """
98
+
99
+ loss: Optional[torch.FloatTensor] = None
100
+ logits: torch.FloatTensor = None
101
+ past_key_values: Optional[List[torch.FloatTensor]] = None
102
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
103
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
104
+ rope_deltas: Optional[torch.LongTensor] = None
105
+
106
+
107
+ # Copied from transformers.models.qwen2.modeling_qwen2.Qwen2RotaryEmbedding
108
+ class Qwen2RotaryEmbedding(nn.Module):
109
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
110
+ super().__init__()
111
+
112
+ self.dim = dim
113
+ self.max_position_embeddings = max_position_embeddings
114
+ self.base = base
115
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
116
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
117
+
118
+ # Build here to make `torch.jit.trace` work.
119
+ self._set_cos_sin_cache(
120
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
121
+ )
122
+
123
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
124
+ self.max_seq_len_cached = seq_len
125
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
126
+
127
+ freqs = torch.outer(t, self.inv_freq)
128
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
129
+ emb = torch.cat((freqs, freqs), dim=-1)
130
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
131
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
132
+
133
+ def forward(self, x, seq_len=None):
134
+ # x: [bs, num_attention_heads, seq_len, head_size]
135
+ if seq_len > self.max_seq_len_cached:
136
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
137
+
138
+ return (
139
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
140
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
141
+ )
142
+
143
+
144
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
145
+ def rotate_half(x):
146
+ """Rotates half the hidden dims of the input."""
147
+ x1 = x[..., : x.shape[-1] // 2]
148
+ x2 = x[..., x.shape[-1] // 2 :]
149
+ return torch.cat((-x2, x1), dim=-1)
150
+
151
+
152
+ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, position_ids, mrope_section, unsqueeze_dim=1):
153
+ """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/).
154
+
155
+ Explanation:
156
+ Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding
157
+ sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For
158
+ vision embedding part, we apply rotary position embedding on the temporal, height and width dimensions separately.
159
+ Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding.
160
+ For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal,
161
+ height and width) of text embedding is always the same, so the text embedding rotary position embedding has no
162
+ difference with modern LLMs.
163
+
164
+ Args:
165
+ q (`torch.Tensor`): The query tensor.
166
+ k (`torch.Tensor`): The key tensor.
167
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
168
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
169
+ position_ids (`torch.Tensor`):
170
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
171
+ used to pass offsetted position ids when working with a KV-cache.
172
+ mrope_section(`List(int)`):
173
+ Multimodal rope section is for channel dimension of temporal, height and width in rope calculation.
174
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
175
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
176
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
177
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
178
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
179
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
180
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
181
+ Returns:
182
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
183
+ """
184
+ cos = cos[position_ids]
185
+ sin = sin[position_ids]
186
+ mrope_section = mrope_section * 2
187
+ cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
188
+ unsqueeze_dim
189
+ )
190
+ sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
191
+ unsqueeze_dim
192
+ )
193
+
194
+ q_embed = (q * cos) + (rotate_half(q) * sin)
195
+ k_embed = (k * cos) + (rotate_half(k) * sin)
196
+ return q_embed, k_embed
197
+
198
+
199
+ def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
200
+ orig_dtype = tensor.dtype
201
+ tensor = tensor.float()
202
+ cos = freqs.cos()
203
+ sin = freqs.sin()
204
+ cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
205
+ sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
206
+ output = (tensor * cos) + (rotate_half(tensor) * sin)
207
+ output = output.to(orig_dtype)
208
+ return output
209
+
210
+
211
+ class VisionRotaryEmbedding(nn.Module):
212
+ def __init__(self, dim: int, theta: float = 10000.0) -> None:
213
+ super().__init__()
214
+ inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
215
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
216
+
217
+ def forward(self, seqlen: int) -> torch.Tensor:
218
+ seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
219
+ freqs = torch.outer(seq, self.inv_freq)
220
+ return freqs
221
+
222
+
223
+ class PatchEmbed(nn.Module):
224
+ def __init__(
225
+ self,
226
+ patch_size: int = 14,
227
+ temporal_patch_size: int = 2,
228
+ in_channels: int = 3,
229
+ embed_dim: int = 1152,
230
+ ) -> None:
231
+ super().__init__()
232
+ self.patch_size = patch_size
233
+ self.temporal_patch_size = temporal_patch_size
234
+ self.in_channels = in_channels
235
+ self.embed_dim = embed_dim
236
+
237
+ kernel_size = [temporal_patch_size, patch_size, patch_size]
238
+ self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False)
239
+
240
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
241
+ target_dtype = self.proj.weight.dtype
242
+ hidden_states = hidden_states.view(
243
+ -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
244
+ )
245
+ hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
246
+ return hidden_states
247
+
248
+
249
+ class PatchMerger(nn.Module):
250
+ def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
251
+ super().__init__()
252
+ self.hidden_size = context_dim * (spatial_merge_size**2)
253
+ self.ln_q = LayerNorm(context_dim, eps=1e-6)
254
+ self.mlp = nn.Sequential(
255
+ nn.Linear(self.hidden_size, self.hidden_size),
256
+ nn.GELU(),
257
+ nn.Linear(self.hidden_size, dim),
258
+ )
259
+
260
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
261
+ x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
262
+ return x
263
+
264
+
265
+ class VisionMlp(nn.Module):
266
+ def __init__(self, dim: int, hidden_dim: int, hidden_act: str) -> None:
267
+ super().__init__()
268
+ self.fc1 = nn.Linear(dim, hidden_dim)
269
+ self.act = ACT2FN[hidden_act]
270
+ self.fc2 = nn.Linear(hidden_dim, dim)
271
+
272
+ def forward(self, x) -> torch.Tensor:
273
+ return self.fc2(self.act(self.fc1(x)))
274
+
275
+
276
+ class VisionAttention(nn.Module):
277
+ def __init__(self, dim: int, num_heads: int = 16) -> None:
278
+ super().__init__()
279
+ self.num_heads = num_heads
280
+ self.head_dim = dim // num_heads
281
+ self.qkv = nn.Linear(dim, dim * 3, bias=True)
282
+ self.proj = nn.Linear(dim, dim)
283
+
284
+ def forward(
285
+ self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
286
+ ) -> torch.Tensor:
287
+ seq_length = hidden_states.shape[0]
288
+ q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
289
+ q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
290
+ k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
291
+
292
+ attention_mask = torch.full(
293
+ [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
294
+ )
295
+ for i in range(1, len(cu_seqlens)):
296
+ attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
297
+
298
+ q = q.transpose(0, 1)
299
+ k = k.transpose(0, 1)
300
+ v = v.transpose(0, 1)
301
+ attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
302
+ attn_weights = attn_weights + attention_mask
303
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
304
+ attn_output = torch.matmul(attn_weights, v)
305
+ attn_output = attn_output.transpose(0, 1)
306
+ attn_output = attn_output.reshape(seq_length, -1)
307
+ attn_output = self.proj(attn_output)
308
+ return attn_output
309
+
310
+
311
+ class VisionFlashAttention2(nn.Module):
312
+ def __init__(self, dim: int, num_heads: int = 16) -> None:
313
+ super().__init__()
314
+ self.num_heads = num_heads
315
+ self.qkv = nn.Linear(dim, dim * 3, bias=True)
316
+ self.proj = nn.Linear(dim, dim)
317
+
318
+ def forward(
319
+ self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
320
+ ) -> torch.Tensor:
321
+ seq_length = hidden_states.shape[0]
322
+ q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
323
+ q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
324
+ k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
325
+
326
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
327
+ attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
328
+ seq_length, -1
329
+ )
330
+ attn_output = self.proj(attn_output)
331
+ return attn_output
332
+
333
+
334
+ class VisionSdpaAttention(nn.Module):
335
+ def __init__(self, dim: int, num_heads: int = 16) -> None:
336
+ super().__init__()
337
+ self.num_heads = num_heads
338
+ self.qkv = nn.Linear(dim, dim * 3, bias=True)
339
+ self.proj = nn.Linear(dim, dim)
340
+
341
+ def forward(
342
+ self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
343
+ ) -> torch.Tensor:
344
+ seq_length = hidden_states.shape[0]
345
+ q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
346
+ q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
347
+ k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
348
+
349
+ attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
350
+ for i in range(1, len(cu_seqlens)):
351
+ attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
352
+ q = q.transpose(0, 1)
353
+ k = k.transpose(0, 1)
354
+ v = v.transpose(0, 1)
355
+ attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
356
+ attn_output = attn_output.transpose(0, 1)
357
+ attn_output = attn_output.reshape(seq_length, -1)
358
+ attn_output = self.proj(attn_output)
359
+ return attn_output
360
+
361
+
362
+ QWEN2_VL_VISION_ATTENTION_CLASSES = {
363
+ "eager": VisionAttention,
364
+ "flash_attention_2": VisionFlashAttention2,
365
+ "sdpa": VisionSdpaAttention,
366
+ }
367
+
368
+
369
+ class Qwen2VLVisionBlock(nn.Module):
370
+ def __init__(self, config, attn_implementation: str = "sdpa") -> None:
371
+ super().__init__()
372
+ self.norm1 = LayerNorm(config.embed_dim, eps=1e-6)
373
+ self.norm2 = LayerNorm(config.embed_dim, eps=1e-6)
374
+ mlp_hidden_dim = int(config.embed_dim * config.mlp_ratio)
375
+
376
+ self.attn = QWEN2_VL_VISION_ATTENTION_CLASSES[attn_implementation](
377
+ config.embed_dim, num_heads=config.num_heads
378
+ )
379
+ self.mlp = VisionMlp(dim=config.embed_dim, hidden_dim=mlp_hidden_dim, hidden_act=config.hidden_act)
380
+
381
+ def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor:
382
+ hidden_states = hidden_states + self.attn(
383
+ self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
384
+ )
385
+ hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
386
+ return hidden_states
387
+
388
+
389
+ # Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
390
+ def _prepare_4d_causal_attention_mask_with_cache_position(
391
+ attention_mask: torch.Tensor,
392
+ sequence_length: int,
393
+ target_length: int,
394
+ dtype: torch.dtype,
395
+ device: torch.device,
396
+ min_dtype: float,
397
+ cache_position: torch.Tensor,
398
+ batch_size: int,
399
+ ):
400
+ """
401
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
402
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
403
+
404
+ Args:
405
+ attention_mask (`torch.Tensor`):
406
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
407
+ sequence_length (`int`):
408
+ The sequence length being processed.
409
+ target_length (`int`):
410
+ The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
411
+ dtype (`torch.dtype`):
412
+ The dtype to use for the 4D attention mask.
413
+ device (`torch.device`):
414
+ The device to place the 4D attention mask on.
415
+ min_dtype (`float`):
416
+ The minimum value representable with the dtype `dtype`.
417
+ cache_position (`torch.Tensor`):
418
+ Indices depicting the position of the input sequence tokens in the sequence.
419
+ batch_size (`torch.Tensor`):
420
+ Batch size.
421
+ """
422
+ if attention_mask is not None and attention_mask.dim() == 4:
423
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
424
+ causal_mask = attention_mask
425
+ else:
426
+ causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
427
+ if sequence_length != 1:
428
+ causal_mask = torch.triu(causal_mask, diagonal=1)
429
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
430
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
431
+ if attention_mask is not None:
432
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
433
+ mask_length = attention_mask.shape[-1]
434
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
435
+ padding_mask = padding_mask == 0
436
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
437
+ padding_mask, min_dtype
438
+ )
439
+
440
+ return causal_mask
441
+
442
+
443
+ # Copied from transformers.models.qwen2.modeling_qwen2.Qwen2RMSNorm
444
+ class Qwen2RMSNorm(nn.Module):
445
+ def __init__(self, hidden_size, eps=1e-6):
446
+ """
447
+ Qwen2RMSNorm is equivalent to T5LayerNorm
448
+ """
449
+ super().__init__()
450
+ self.weight = nn.Parameter(torch.ones(hidden_size))
451
+ self.variance_epsilon = eps
452
+
453
+ def forward(self, hidden_states):
454
+ input_dtype = hidden_states.dtype
455
+ hidden_states = hidden_states.to(torch.float32)
456
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
457
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
458
+ return self.weight * hidden_states.to(input_dtype)
459
+
460
+ def extra_repr(self):
461
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
462
+
463
+
464
+ # Copied from transformers.models.qwen2.modeling_qwen2.Qwen2MLP
465
+ class Qwen2MLP(nn.Module):
466
+ def __init__(self, config):
467
+ super().__init__()
468
+ self.hidden_size = config.hidden_size
469
+ self.intermediate_size = config.intermediate_size
470
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
471
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
472
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
473
+ self.act_fn = ACT2FN[config.hidden_act]
474
+
475
+ def forward(self, hidden_state):
476
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
477
+
478
+
479
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv
480
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
481
+ """
482
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
483
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
484
+ """
485
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
486
+ if n_rep == 1:
487
+ return hidden_states
488
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
489
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
490
+
491
+
492
+ class Qwen2VLAttention(nn.Module):
493
+ """
494
+ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
495
+ and "Generating Long Sequences with Sparse Transformers".
496
+ """
497
+
498
+ def __init__(self, config: Qwen2VLConfig, layer_idx: Optional[int] = None):
499
+ super().__init__()
500
+ self.config = config
501
+ self.layer_idx = layer_idx
502
+ if layer_idx is None:
503
+ logger.warning_once(
504
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
505
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
506
+ "when creating this class."
507
+ )
508
+
509
+ self.hidden_size = config.hidden_size
510
+ self.num_heads = config.num_attention_heads
511
+ self.head_dim = self.hidden_size // self.num_heads
512
+ self.num_key_value_heads = config.num_key_value_heads
513
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
514
+ self.max_position_embeddings = config.max_position_embeddings
515
+ self.rope_theta = config.rope_theta
516
+ self.is_causal = True
517
+ self.attention_dropout = config.attention_dropout
518
+ self.rope_scaling = config.rope_scaling
519
+
520
+ if (self.head_dim * self.num_heads) != self.hidden_size:
521
+ raise ValueError(
522
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
523
+ f" and `num_heads`: {self.num_heads})."
524
+ )
525
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
526
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
527
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
528
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
529
+
530
+ self.rotary_emb = Qwen2RotaryEmbedding(
531
+ self.head_dim,
532
+ max_position_embeddings=self.max_position_embeddings,
533
+ base=self.rope_theta,
534
+ )
535
+
536
+ def forward(
537
+ self,
538
+ hidden_states: torch.Tensor,
539
+ attention_mask: Optional[torch.Tensor] = None,
540
+ position_ids: Optional[torch.LongTensor] = None,
541
+ past_key_value: Optional[Cache] = None,
542
+ output_attentions: bool = False,
543
+ use_cache: bool = False,
544
+ cache_position: Optional[torch.LongTensor] = None,
545
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
546
+ bsz, q_len, _ = hidden_states.size()
547
+
548
+ query_states = self.q_proj(hidden_states)
549
+ key_states = self.k_proj(hidden_states)
550
+ value_states = self.v_proj(hidden_states)
551
+
552
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
553
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
554
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
555
+
556
+ kv_seq_len = key_states.shape[-2]
557
+ if past_key_value is not None:
558
+ if self.layer_idx is None:
559
+ raise ValueError(
560
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
561
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
562
+ "with a layer index."
563
+ )
564
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
565
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
566
+ query_states, key_states = apply_multimodal_rotary_pos_emb(
567
+ query_states, key_states, cos, sin, position_ids, self.rope_scaling["mrope_section"]
568
+ )
569
+
570
+ if past_key_value is not None:
571
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
572
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
573
+
574
+ # repeat k/v heads if n_kv_heads < n_heads
575
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
576
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
577
+
578
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
579
+
580
+ if attention_mask is not None: # no matter the length, we just slice it
581
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
582
+ attn_weights = attn_weights + causal_mask
583
+
584
+ # upcast attention to fp32
585
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
586
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
587
+ attn_output = torch.matmul(attn_weights, value_states)
588
+
589
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
590
+ raise ValueError(
591
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
592
+ f" {attn_output.size()}"
593
+ )
594
+
595
+ attn_output = attn_output.transpose(1, 2).contiguous()
596
+ attn_output = attn_output.reshape(bsz, q_len, -1)
597
+
598
+ attn_output = self.o_proj(attn_output)
599
+
600
+ if not output_attentions:
601
+ attn_weights = None
602
+
603
+ return attn_output, attn_weights, past_key_value
604
+
605
+
606
+ class Qwen2VLFlashAttention2(Qwen2VLAttention):
607
+ """
608
+ Qwen2VL flash attention module, mirroring the Qwen2VL eager attention module. This module inherits from `Qwen2VLAttention`
609
+ as the weights of the module stay untouched. The only required change is in the forward pass,
610
+ where it needs to correctly call the public API of flash attention and deal with padding tokens
611
+ in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom
612
+ config.max_window_layers layers.
613
+ """
614
+
615
+ def __init__(self, *args, **kwargs):
616
+ super().__init__(*args, **kwargs)
617
+
618
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
619
+ # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which became the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
620
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
621
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
622
+
623
+ def forward(
624
+ self,
625
+ hidden_states: torch.Tensor,
626
+ attention_mask: Optional[torch.Tensor] = None,
627
+ position_ids: Optional[torch.LongTensor] = None,
628
+ past_key_value: Optional[Cache] = None,
629
+ output_attentions: bool = False,
630
+ use_cache: bool = False,
631
+ cache_position: Optional[torch.LongTensor] = None,
632
+ ):
633
+ bsz, q_len, _ = hidden_states.size()
634
+
635
+ query_states = self.q_proj(hidden_states)
636
+ key_states = self.k_proj(hidden_states)
637
+ value_states = self.v_proj(hidden_states)
638
+
639
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
640
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
641
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
642
+
643
+ kv_seq_len = key_states.shape[-2]
644
+ if past_key_value is not None:
645
+ if self.layer_idx is None:
646
+ raise ValueError(
647
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
648
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
649
+ "with a layer index."
650
+ )
651
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
652
+
653
+ # Because the input can be padded, the absolute sequence length depends on the max position id.
654
+ rotary_seq_len = (
655
+ max(kv_seq_len, position_ids[:, -1].max().item() + 1) if position_ids is not None else kv_seq_len
656
+ )
657
+
658
+ cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
659
+
660
+ query_states, key_states = apply_multimodal_rotary_pos_emb(
661
+ query_states, key_states, cos, sin, position_ids, self.rope_scaling["mrope_section"]
662
+ )
663
+
664
+ if past_key_value is not None:
665
+ # Activate cache slicing only if the config has a `sliding_window` attribute
666
+ cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
667
+ if (
668
+ getattr(self.config, "sliding_window", None) is not None
669
+ and kv_seq_len > self.config.sliding_window
670
+ and cache_has_contents
671
+ ):
672
+ slicing_tokens = 1 - self.config.sliding_window
673
+
674
+ past_key = past_key_value[self.layer_idx][0]
675
+ past_value = past_key_value[self.layer_idx][1]
676
+
677
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
678
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
679
+
680
+ if past_key.shape[-2] != self.config.sliding_window - 1:
681
+ raise ValueError(
682
+ f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
683
+ f" {past_key.shape}"
684
+ )
685
+
686
+ if attention_mask is not None:
687
+ attention_mask = attention_mask[:, slicing_tokens:]
688
+ attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
689
+
690
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
691
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
692
+
693
+ # repeat k/v heads if n_kv_heads < n_heads
694
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
695
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
696
+ dropout_rate = 0.0 if not self.training else self.attention_dropout
697
+
698
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
699
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
700
+ # cast them back to float16 (or the appropriate half-precision dtype) just to be sure everything works as expected.
701
+ input_dtype = query_states.dtype
702
+ if input_dtype == torch.float32:
703
+ if torch.is_autocast_enabled():
704
+ target_dtype = torch.get_autocast_gpu_dtype()
705
+ # Handle the case where the model is quantized
706
+ elif hasattr(self.config, "_pre_quantization_dtype"):
707
+ target_dtype = self.config._pre_quantization_dtype
708
+ else:
709
+ target_dtype = self.q_proj.weight.dtype
710
+
711
+ logger.warning_once(
712
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
713
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
714
+ f" {target_dtype}."
715
+ )
716
+
717
+ query_states = query_states.to(target_dtype)
718
+ key_states = key_states.to(target_dtype)
719
+ value_states = value_states.to(target_dtype)
720
+
721
+ # Reshape to the expected shape for Flash Attention
722
+ query_states = query_states.transpose(1, 2)
723
+ key_states = key_states.transpose(1, 2)
724
+ value_states = value_states.transpose(1, 2)
725
+
726
+ if (
727
+ self.config.use_sliding_window
728
+ and getattr(self.config, "sliding_window", None) is not None
729
+ and self.layer_idx >= self.config.max_window_layers
730
+ ):
731
+ sliding_window = self.config.sliding_window
732
+ else:
733
+ sliding_window = None
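+ # Note: with the default Qwen2-VL configs `use_sliding_window` is typically False, so `sliding_window`
+ # stays None here and flash attention runs over the full context.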
734
+
735
+ attn_output = _flash_attention_forward(
736
+ query_states,
737
+ key_states,
738
+ value_states,
739
+ attention_mask,
740
+ q_len,
741
+ dropout=dropout_rate,
742
+ sliding_window=sliding_window,
743
+ is_causal=self.is_causal,
744
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
745
+ )
746
+
747
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
748
+ attn_output = self.o_proj(attn_output)
749
+
750
+ if not output_attentions:
751
+ attn_weights = None
752
+
753
+ return attn_output, attn_weights, past_key_value
754
+
755
+
756
+ class Qwen2VLSdpaAttention(Qwen2VLAttention):
757
+ """
758
+ Qwen2VL attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
759
+ `Qwen2VLAttention` as the weights of the module stay untouched. The only changes are on the forward pass, to adapt to the
760
+ SDPA API.
761
+ """
762
+
763
+ # Adapted from Qwen2Attention.forward
764
+ def forward(
765
+ self,
766
+ hidden_states: torch.Tensor,
767
+ attention_mask: Optional[torch.Tensor] = None,
768
+ position_ids: Optional[torch.LongTensor] = None,
769
+ past_key_value: Optional[Cache] = None,
770
+ output_attentions: bool = False,
771
+ use_cache: bool = False,
772
+ cache_position: Optional[torch.LongTensor] = None,
773
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
774
+ if output_attentions:
775
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
776
+ logger.warning_once(
777
+ "Qwen2VLModel is using Qwen2VLSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
778
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
779
+ )
780
+ return super().forward(
781
+ hidden_states=hidden_states,
782
+ attention_mask=attention_mask,
783
+ position_ids=position_ids,
784
+ past_key_value=past_key_value,
785
+ output_attentions=output_attentions,
786
+ use_cache=use_cache,
787
+ )
788
+
789
+ bsz, q_len, _ = hidden_states.size()
790
+
791
+ query_states = self.q_proj(hidden_states)
792
+ key_states = self.k_proj(hidden_states)
793
+ value_states = self.v_proj(hidden_states)
794
+
795
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
796
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
797
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
798
+
799
+ kv_seq_len = key_states.shape[-2]
800
+ if past_key_value is not None:
801
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
802
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
803
+ query_states, key_states = apply_multimodal_rotary_pos_emb(
804
+ query_states, key_states, cos, sin, position_ids, self.rope_scaling["mrope_section"]
805
+ )
806
+
807
+ if past_key_value is not None:
808
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
809
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
810
+
811
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
812
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
813
+
814
+ causal_mask = attention_mask
815
+ if attention_mask is not None: # no matter the length, we just slice it
816
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
817
+
818
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
819
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
820
+ if query_states.device.type == "cuda" and attention_mask is not None:
821
+ query_states = query_states.contiguous()
822
+ key_states = key_states.contiguous()
823
+ value_states = value_states.contiguous()
824
+
825
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
826
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
827
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
828
+ is_causal = True if causal_mask is None and q_len > 1 else False
829
+
830
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
831
+ query_states,
832
+ key_states,
833
+ value_states,
834
+ attn_mask=causal_mask,
835
+ dropout_p=self.attention_dropout if self.training else 0.0,
836
+ is_causal=is_causal,
837
+ )
838
+
839
+ attn_output = attn_output.transpose(1, 2).contiguous()
840
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
841
+
842
+ attn_output = self.o_proj(attn_output)
843
+
844
+ return attn_output, None, past_key_value
845
+
846
+
847
+ QWEN2_VL_ATTENTION_CLASSES = {
848
+ "eager": Qwen2VLAttention,
849
+ "flash_attention_2": Qwen2VLFlashAttention2,
850
+ "sdpa": Qwen2VLSdpaAttention,
851
+ }
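+ # The decoder layer below indexes this mapping with `config._attn_implementation`; for example, loading
+ # the model with `attn_implementation="sdpa"` makes every layer instantiate Qwen2VLSdpaAttention.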
852
+
853
+
854
+ class Qwen2VLDecoderLayer(nn.Module):
855
+ def __init__(self, config: Qwen2VLConfig, layer_idx: int):
856
+ super().__init__()
857
+ self.hidden_size = config.hidden_size
858
+
859
+ if config.use_sliding_window and config._attn_implementation != "flash_attention_2":
860
+ logger.warning_once(
861
+ f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
862
+ "unexpected results may be encountered."
863
+ )
864
+ self.self_attn = QWEN2_VL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
865
+
866
+ self.mlp = Qwen2MLP(config)
867
+ self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
868
+ self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
869
+
870
+ def forward(
871
+ self,
872
+ hidden_states: torch.Tensor,
873
+ attention_mask: Optional[torch.Tensor] = None,
874
+ position_ids: Optional[torch.LongTensor] = None,
875
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
876
+ output_attentions: Optional[bool] = False,
877
+ use_cache: Optional[bool] = False,
878
+ cache_position: Optional[torch.LongTensor] = None,
879
+ **kwargs,
880
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
881
+ """
882
+ Args:
883
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
884
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
885
+ `(batch, sequence_length)` where padding elements are indicated by 0.
886
+ output_attentions (`bool`, *optional*):
887
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
888
+ returned tensors for more detail.
889
+ use_cache (`bool`, *optional*):
890
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
891
+ (see `past_key_values`).
892
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
893
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
894
+ Indices depicting the position of the input sequence tokens in the sequence.
895
+ kwargs (`dict`, *optional*):
896
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
897
+ into the model
898
+ """
899
+
900
+ residual = hidden_states
901
+
902
+ hidden_states = self.input_layernorm(hidden_states)
903
+
904
+ # Self Attention
905
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
906
+ hidden_states=hidden_states,
907
+ attention_mask=attention_mask,
908
+ position_ids=position_ids,
909
+ past_key_value=past_key_value,
910
+ output_attentions=output_attentions,
911
+ use_cache=use_cache,
912
+ cache_position=cache_position,
913
+ )
914
+ hidden_states = residual + hidden_states
915
+
916
+ # Fully Connected
917
+ residual = hidden_states
918
+ hidden_states = self.post_attention_layernorm(hidden_states)
919
+ hidden_states = self.mlp(hidden_states)
920
+ hidden_states = residual + hidden_states
921
+
922
+ outputs = (hidden_states,)
923
+
924
+ if output_attentions:
925
+ outputs += (self_attn_weights,)
926
+
927
+ if use_cache:
928
+ outputs += (present_key_value,)
929
+
930
+ return outputs
931
+
932
+
933
+ QWEN2VL_START_DOCSTRING = r"""
934
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
935
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
936
+ etc.)
937
+
938
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
939
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
940
+ and behavior.
941
+
942
+ Parameters:
943
+ config ([`Qwen2VLConfig`]):
944
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
945
+ load the weights associated with the model, only the configuration. Check out the
946
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
947
+ """
948
+
949
+
950
+ @add_start_docstrings(
951
+ "The bare Qwen2VL Model outputting raw hidden-states without any specific head on top.",
952
+ QWEN2VL_START_DOCSTRING,
953
+ )
954
+ class Qwen2VLPreTrainedModel(PreTrainedModel):
955
+ config_class = Qwen2VLConfig
956
+ base_model_prefix = "model"
957
+ supports_gradient_checkpointing = True
958
+ _no_split_modules = ["Qwen2VLDecoderLayer", "Qwen2VLVisionBlock"]
959
+ _skip_keys_device_placement = "past_key_values"
960
+ _supports_flash_attn_2 = True
961
+ _supports_sdpa = True
962
+ _supports_cache_class = True
963
+ _supports_static_cache = True
964
+
965
+ def _init_weights(self, module):
966
+ std = self.config.initializer_range
967
+ if isinstance(module, (nn.Linear, nn.Conv3d)):
968
+ module.weight.data.normal_(mean=0.0, std=std)
969
+ if module.bias is not None:
970
+ module.bias.data.zero_()
971
+ elif isinstance(module, nn.Embedding):
972
+ module.weight.data.normal_(mean=0.0, std=std)
973
+ if module.padding_idx is not None:
974
+ module.weight.data[module.padding_idx].zero_()
975
+
976
+
977
+ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
978
+ config_class = Qwen2VLVisionConfig
979
+ _no_split_modules = ["Qwen2VLVisionBlock"]
980
+
981
+ def __init__(self, config) -> None:
982
+ super().__init__(config)
983
+ self.spatial_merge_size = config.spatial_merge_size
984
+
985
+ self.patch_embed = PatchEmbed(
986
+ patch_size=config.patch_size,
987
+ temporal_patch_size=config.temporal_patch_size,
988
+ in_channels=config.in_channels,
989
+ embed_dim=config.embed_dim,
990
+ )
991
+
992
+ head_dim = config.embed_dim // config.num_heads
993
+ self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
994
+
995
+ self.blocks = nn.ModuleList(
996
+ [Qwen2VLVisionBlock(config, config._attn_implementation) for _ in range(config.depth)]
997
+ )
998
+ self.merger = PatchMerger(dim=config.hidden_size, context_dim=config.embed_dim)
999
+
1000
+ def get_dtype(self) -> torch.dtype:
1001
+ return self.blocks[0].mlp.fc2.weight.dtype
1002
+
1003
+ def get_device(self) -> torch.device:
1004
+ return self.blocks[0].mlp.fc2.weight.device
1005
+
1006
+ def rot_pos_emb(self, grid_thw):
1007
+ pos_ids = []
1008
+ for t, h, w in grid_thw:
1009
+ hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
1010
+ hpos_ids = hpos_ids.reshape(
1011
+ h // self.spatial_merge_size,
1012
+ self.spatial_merge_size,
1013
+ w // self.spatial_merge_size,
1014
+ self.spatial_merge_size,
1015
+ )
1016
+ hpos_ids = hpos_ids.permute(0, 2, 1, 3)
1017
+ hpos_ids = hpos_ids.flatten()
1018
+
1019
+ wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
1020
+ wpos_ids = wpos_ids.reshape(
1021
+ h // self.spatial_merge_size,
1022
+ self.spatial_merge_size,
1023
+ w // self.spatial_merge_size,
1024
+ self.spatial_merge_size,
1025
+ )
1026
+ wpos_ids = wpos_ids.permute(0, 2, 1, 3)
1027
+ wpos_ids = wpos_ids.flatten()
1028
+ pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
1029
+ pos_ids = torch.cat(pos_ids, dim=0)
1030
+ max_grid_size = grid_thw[:, 1:].max()
1031
+ rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
1032
+ rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
1033
+ return rotary_pos_emb
1034
+
1035
+ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
1036
+ hidden_states = self.patch_embed(hidden_states)
1037
+ rotary_pos_emb = self.rot_pos_emb(grid_thw)
1038
+
1039
+ cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
1040
+ dim=0, dtype=torch.int32
1041
+ )
1042
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
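+ # cu_seqlens holds cumulative patch counts per image/frame so the vision blocks attend within each one.
+ # Illustrative example (hypothetical grids): grid_thw = [[1, 4, 6], [1, 2, 2]] gives cu_seqlens = [0, 24, 28].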
1043
+
1044
+ for blk in self.blocks:
1045
+ hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
1046
+
1047
+ return self.merger(hidden_states)
1048
+
1049
+
1050
+ @add_start_docstrings(
1051
+ "The bare Qwen2VL Model outputting raw hidden-states without any specific head on top.",
1052
+ QWEN2VL_START_DOCSTRING,
1053
+ )
1054
+ class Qwen2VLModel(Qwen2VLPreTrainedModel):
1055
+ def __init__(self, config: Qwen2VLConfig):
1056
+ super().__init__(config)
1057
+ self.padding_idx = config.pad_token_id
1058
+ self.vocab_size = config.vocab_size
1059
+
1060
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1061
+ self.layers = nn.ModuleList(
1062
+ [Qwen2VLDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
1063
+ )
1064
+ self._attn_implementation = config._attn_implementation
1065
+ self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1066
+
1067
+ self.gradient_checkpointing = False
1068
+ # Initialize weights and apply final processing
1069
+ self.post_init()
1070
+
1071
+ def get_input_embeddings(self):
1072
+ return self.embed_tokens
1073
+
1074
+ def set_input_embeddings(self, value):
1075
+ self.embed_tokens = value
1076
+
1077
+ def forward(
1078
+ self,
1079
+ input_ids: torch.LongTensor = None,
1080
+ attention_mask: Optional[torch.Tensor] = None,
1081
+ position_ids: Optional[torch.LongTensor] = None,
1082
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1083
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1084
+ use_cache: Optional[bool] = None,
1085
+ output_attentions: Optional[bool] = None,
1086
+ output_hidden_states: Optional[bool] = None,
1087
+ return_dict: Optional[bool] = None,
1088
+ cache_position: Optional[torch.LongTensor] = None,
1089
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
1090
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1091
+ output_hidden_states = (
1092
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1093
+ )
1094
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1095
+
1096
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1097
+
1098
+ if (input_ids is None) ^ (inputs_embeds is not None):
1099
+ raise ValueError(
1100
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
1101
+ )
1102
+
1103
+ if self.gradient_checkpointing and self.training:
1104
+ if use_cache:
1105
+ logger.warning_once(
1106
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1107
+ )
1108
+ use_cache = False
1109
+
1110
+ if inputs_embeds is None:
1111
+ inputs_embeds = self.embed_tokens(input_ids)
1112
+
1113
+ if cache_position is None:
1114
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1115
+ cache_position = torch.arange(
1116
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
1117
+ )
1118
+ if position_ids is None:
1119
+ # the hard coded `3` is for temporal, height and width.
1120
+ position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
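+ # Resulting shape is (3, batch_size, seq_len): one identical row of positions for each of the
+ # temporal/height/width rotary axes when no explicit multimodal position_ids are supplied.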
1121
+
1122
+ causal_mask = self._update_causal_mask(
1123
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
1124
+ )
1125
+
1126
+ hidden_states = inputs_embeds
1127
+
1128
+ # decoder layers
1129
+ all_hidden_states = () if output_hidden_states else None
1130
+ all_self_attns = () if output_attentions else None
1131
+ next_decoder_cache = None
1132
+
1133
+ for decoder_layer in self.layers:
1134
+ if output_hidden_states:
1135
+ all_hidden_states += (hidden_states,)
1136
+
1137
+ if self.gradient_checkpointing and self.training:
1138
+ layer_outputs = self._gradient_checkpointing_func(
1139
+ decoder_layer.__call__,
1140
+ hidden_states,
1141
+ causal_mask,
1142
+ position_ids,
1143
+ past_key_values,
1144
+ output_attentions,
1145
+ use_cache,
1146
+ cache_position,
1147
+ )
1148
+ else:
1149
+ layer_outputs = decoder_layer(
1150
+ hidden_states,
1151
+ attention_mask=causal_mask,
1152
+ position_ids=position_ids,
1153
+ past_key_value=past_key_values,
1154
+ output_attentions=output_attentions,
1155
+ use_cache=use_cache,
1156
+ cache_position=cache_position,
1157
+ )
1158
+
1159
+ hidden_states = layer_outputs[0]
1160
+
1161
+ if use_cache:
1162
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1163
+
1164
+ if output_attentions:
1165
+ all_self_attns += (layer_outputs[1],)
1166
+
1167
+ hidden_states = self.norm(hidden_states)
1168
+
1169
+ # add hidden states from the last decoder layer
1170
+ if output_hidden_states:
1171
+ all_hidden_states += (hidden_states,)
1172
+
1173
+ next_cache = next_decoder_cache if use_cache else None
1174
+
1175
+ if not return_dict:
1176
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1177
+ return BaseModelOutputWithPast(
1178
+ last_hidden_state=hidden_states,
1179
+ past_key_values=next_cache,
1180
+ hidden_states=all_hidden_states,
1181
+ attentions=all_self_attns,
1182
+ )
1183
+
1184
+ # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
1185
+ def _update_causal_mask(
1186
+ self,
1187
+ attention_mask: torch.Tensor,
1188
+ input_tensor: torch.Tensor,
1189
+ cache_position: torch.Tensor,
1190
+ past_key_values: Cache,
1191
+ output_attentions: bool,
1192
+ ):
1193
+ if self.config._attn_implementation == "flash_attention_2":
1194
+ if attention_mask is not None and 0.0 in attention_mask:
1195
+ return attention_mask
1196
+ return None
1197
+
1198
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
1199
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
1200
+ # to infer the attention mask.
1201
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1202
+ using_static_cache = isinstance(past_key_values, StaticCache)
1203
+
1204
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
1205
+ if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
1206
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
1207
+ attention_mask,
1208
+ inputs_embeds=input_tensor,
1209
+ past_key_values_length=past_seen_tokens,
1210
+ is_training=self.training,
1211
+ ):
1212
+ return None
1213
+
1214
+ dtype, device = input_tensor.dtype, input_tensor.device
1215
+ min_dtype = torch.finfo(dtype).min
1216
+ sequence_length = input_tensor.shape[1]
1217
+ if using_static_cache:
1218
+ target_length = past_key_values.get_max_length()
1219
+ else:
1220
+ target_length = (
1221
+ attention_mask.shape[-1]
1222
+ if isinstance(attention_mask, torch.Tensor)
1223
+ else past_seen_tokens + sequence_length + 1
1224
+ )
1225
+
1226
+ # In case the provided `attention_mask` is 2D, we generate a causal mask here (4D).
1227
+ causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
1228
+ attention_mask,
1229
+ sequence_length=sequence_length,
1230
+ target_length=target_length,
1231
+ dtype=dtype,
1232
+ device=device,
1233
+ min_dtype=min_dtype,
1234
+ cache_position=cache_position,
1235
+ batch_size=input_tensor.shape[0],
1236
+ )
1237
+
1238
+ if (
1239
+ self.config._attn_implementation == "sdpa"
1240
+ and attention_mask is not None
1241
+ and attention_mask.device.type == "cuda"
1242
+ and not output_attentions
1243
+ ):
1244
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1245
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1246
+ # Details: https://github.com/pytorch/pytorch/issues/110213
1247
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
1248
+
1249
+ return causal_mask
1250
+
1251
+
1252
+ QWEN2_VL_INPUTS_DOCSTRING = r"""
1253
+ Args:
1254
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1255
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
1256
+ it.
1257
+
1258
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1259
+ [`PreTrainedTokenizer.__call__`] for details.
1260
+
1261
+ [What are input IDs?](../glossary#input-ids)
1262
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1263
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1264
+
1265
+ - 1 for tokens that are **not masked**,
1266
+ - 0 for tokens that are **masked**.
1267
+
1268
+ [What are attention masks?](../glossary#attention-mask)
1269
+
1270
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1271
+ [`PreTrainedTokenizer.__call__`] for details.
1272
+
1273
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
1274
+ `past_key_values`).
1275
+
1276
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
1277
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
1278
+ information on the default strategy.
1279
+
1280
+ - 1 indicates the head is **not masked**,
1281
+ - 0 indicates the head is **masked**.
1282
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1283
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
1284
+ config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
1285
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
1286
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
1287
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
1288
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
1289
+
1290
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
1291
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
1292
+
1293
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
1294
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
1295
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
1296
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1297
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1298
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1299
+ model's internal embedding lookup matrix.
1300
+ use_cache (`bool`, *optional*):
1301
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1302
+ `past_key_values`).
1303
+ output_attentions (`bool`, *optional*):
1304
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1305
+ tensors for more detail.
1306
+ output_hidden_states (`bool`, *optional*):
1307
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1308
+ more detail.
1309
+ return_dict (`bool`, *optional*):
1310
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1311
+ pixel_values (`torch.FloatTensor` of shape `(seq_length, num_channels * image_size * image_size)):
1312
+ The tensors corresponding to the input images. Pixel values can be obtained using
1313
+ [`AutoImageProcessor`]. See [`Qwen2VLImageProcessor.__call__`] for details. [`Qwen2VLProcessor`] uses
1314
+ [`Qwen2VLImageProcessor`] for processing images.
1315
+ pixel_values_videos (`torch.FloatTensor` of shape `(seq_length, num_channels * temporal_size * image_size * image_size)):
1316
+ The tensors corresponding to the input videos. Pixel values can be obtained using
1317
+ [`AutoImageProcessor`]. See [`Qwen2VLImageProcessor.__call__`] for details. [`Qwen2VLProcessor`] uses
1318
+ [`Qwen2VLImageProcessor`] for processing videos.
1319
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
1320
+ The temporal, height and width of feature shape of each image in LLM.
1321
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
1322
+ The temporal, height and width of feature shape of each video in LLM.
1323
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
1324
+ The rope index difference between sequence length and multimodal rope.
1325
+ """
1326
+
1327
+
1328
+ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel):
1329
+ _tied_weights_keys = ["lm_head.weight"]
1330
+
1331
+ def __init__(self, config):
1332
+ super().__init__(config)
1333
+ self.visual = Qwen2VisionTransformerPretrainedModel._from_config(
1334
+ config.vision_config, attn_implementation=config._attn_implementation
1335
+ )
1336
+ self.model = Qwen2VLModel(config)
1337
+ self.vocab_size = config.vocab_size
1338
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1339
+ self.padding_side = "left" # set it to left by default, user can use setter to change padding_sides
1340
+
1341
+ # Initialize weights and apply final processing
1342
+ self.post_init()
1343
+
1344
+ def get_input_embeddings(self):
1345
+ return self.model.embed_tokens
1346
+
1347
+ def set_input_embeddings(self, value):
1348
+ self.model.embed_tokens = value
1349
+
1350
+ def get_output_embeddings(self):
1351
+ return self.lm_head
1352
+
1353
+ def set_output_embeddings(self, new_embeddings):
1354
+ self.lm_head = new_embeddings
1355
+
1356
+ def set_decoder(self, decoder):
1357
+ self.model = decoder
1358
+
1359
+ def get_decoder(self):
1360
+ return self.model
1361
+
1362
+ def get_rope_index(
1363
+ self,
1364
+ input_ids: torch.LongTensor,
1365
+ image_grid_thw: Optional[torch.LongTensor] = None,
1366
+ video_grid_thw: Optional[torch.LongTensor] = None,
1367
+ attention_mask: Optional[torch.Tensor] = None,
1368
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1369
+ """
1370
+ Calculate the 3D rope index based on image and video's temporal, height and width in LLM.
1371
+
1372
+ Explanation:
1373
+ Each embedding sequence contains vision embedding and text embedding or just contains text embedding.
1374
+
1375
+ For a pure text embedding sequence, the rotary position embedding is no different from modern LLMs.
1376
+ Examples:
1377
+ input_ids: [T T T T T], here T is for text.
1378
+ temporal position_ids: [0, 1, 2, 3, 4]
1379
+ height position_ids: [0, 1, 2, 3, 4]
1380
+ width position_ids: [0, 1, 2, 3, 4]
1381
+
1382
+ For a mixed vision and text embedding sequence, we calculate a 3D rotary position embedding for the vision part
1383
+ and a 1D rotary position embedding for the text part.
1384
+ Examples:
1385
+ Assume we have a video input with 3 temporal patches, 2 height patches and 2 width patches.
1386
+ input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
1387
+ vision temporal position_ids: [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
1388
+ vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
1389
+ vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
1390
+ text temporal position_ids: [3, 4, 5, 6, 7]
1391
+ text height position_ids: [3, 4, 5, 6, 7]
1392
+ text width position_ids: [3, 4, 5, 6, 7]
1393
+ Here we calculate the text start position_ids as the max vision position_ids plus 1.
1394
+
1395
+ Args:
1396
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1397
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
1398
+ it.
1399
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
1400
+ The temporal, height and width of feature shape of each image in LLM.
1401
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
1402
+ The temporal, height and width of feature shape of each video in LLM.
1403
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1404
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1405
+
1406
+ - 1 for tokens that are **not masked**,
1407
+ - 0 for tokens that are **masked**.
1408
+
1409
+ Returns:
1410
+ position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`)
1411
+ mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`)
1412
+ """
1413
+ spatial_merge_size = self.config.vision_config.spatial_merge_size
1414
+ image_token_id = self.config.image_token_id
1415
+ video_token_id = self.config.video_token_id
1416
+ vision_start_token_id = self.config.vision_start_token_id
1417
+ mrope_position_deltas = []
1418
+ if image_grid_thw is not None or video_grid_thw is not None:
1419
+ total_input_ids = input_ids
1420
+ position_ids = torch.ones(
1421
+ 3, input_ids.shape[0], input_ids.shape[1], dtype=input_ids.dtype, device=input_ids.device
1422
+ )
1423
+ image_index, video_index = 0, 0
1424
+ for i, input_ids in enumerate(total_input_ids):
1425
+ if attention_mask is not None:
1426
+ input_ids = input_ids[attention_mask[i] == 1]
1427
+ image_nums, video_nums = 0, 0
1428
+ vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
1429
+ vision_tokens = input_ids[vision_start_indices + 1]
1430
+ image_nums = (vision_tokens == image_token_id).sum()
1431
+ video_nums = (vision_tokens == video_token_id).sum()
1432
+ input_tokens = input_ids.tolist()
1433
+ llm_pos_ids_list: list = []
1434
+ st = 0
1435
+ remain_images, remain_videos = image_nums, video_nums
1436
+ for _ in range(image_nums + video_nums):
1437
+ if image_token_id in input_tokens and remain_images > 0:
1438
+ ed_image = input_tokens.index(image_token_id, st)
1439
+ else:
1440
+ ed_image = len(input_tokens) + 1
1441
+ if video_token_id in input_tokens and remain_videos > 0:
1442
+ ed_video = input_tokens.index(video_token_id, st)
1443
+ else:
1444
+ ed_video = len(input_tokens) + 1
1445
+ if ed_image < ed_video:
1446
+ t, h, w = (
1447
+ image_grid_thw[image_index][0],
1448
+ image_grid_thw[image_index][1],
1449
+ image_grid_thw[image_index][2],
1450
+ )
1451
+ image_index += 1
1452
+ remain_images -= 1
1453
+ ed = ed_image
1454
+ else:
1455
+ t, h, w = (
1456
+ video_grid_thw[video_index][0],
1457
+ video_grid_thw[video_index][1],
1458
+ video_grid_thw[video_index][2],
1459
+ )
1460
+ video_index += 1
1461
+ remain_videos -= 1
1462
+ ed = ed_video
1463
+ llm_grid_t, llm_grid_h, llm_grid_w = (
1464
+ t.item(),
1465
+ h.item() // spatial_merge_size,
1466
+ w.item() // spatial_merge_size,
1467
+ )
1468
+ text_len = ed - st
1469
+
1470
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
1471
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
1472
+
1473
+ t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
1474
+ h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
1475
+ w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
1476
+ llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
1477
+ st = ed + llm_grid_t * llm_grid_h * llm_grid_w
1478
+
1479
+ if st < len(input_tokens):
1480
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
1481
+ text_len = len(input_tokens) - st
1482
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
1483
+
1484
+ llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
1485
+ position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
1486
+ mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
1487
+ mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
1488
+ return position_ids, mrope_position_deltas
1489
+ else:
1490
+ if attention_mask is not None:
1491
+ position_ids = attention_mask.long().cumsum(-1) - 1
1492
+ position_ids.masked_fill_(attention_mask == 0, 1)
1493
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device)
1494
+ max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
1495
+ mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
1496
+ else:
1497
+ position_ids = (
1498
+ torch.arange(input_ids.shape[1], device=input_ids.device)
1499
+ .view(1, 1, -1)
1500
+ .expand(3, input_ids.shape[0], -1)
1501
+ )
1502
+ mrope_position_deltas = torch.zeros(
1503
+ [input_ids.shape[0], 1],
1504
+ device=input_ids.device,
1505
+ dtype=input_ids.dtype,
1506
+ )
1507
+
1508
+ return position_ids, mrope_position_deltas
1509
+
1510
+ def _update_model_kwargs_for_generation(
1511
+ self,
1512
+ outputs: ModelOutput,
1513
+ model_kwargs: Dict[str, Any],
1514
+ is_encoder_decoder: bool = False,
1515
+ num_new_tokens: int = 1,
1516
+ ) -> Dict[str, Any]:
1517
+ model_kwargs = super()._update_model_kwargs_for_generation(
1518
+ outputs=outputs,
1519
+ model_kwargs=model_kwargs,
1520
+ is_encoder_decoder=is_encoder_decoder,
1521
+ num_new_tokens=num_new_tokens,
1522
+ )
1523
+
1524
+ if getattr(outputs, "rope_deltas", None) is not None:
1525
+ model_kwargs["rope_deltas"] = outputs.rope_deltas
1526
+
1527
+ return model_kwargs
1528
+
1529
+ @add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING)
1530
+ @replace_return_docstrings(output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1531
+ def forward(
1532
+ self,
1533
+ input_ids: torch.LongTensor = None,
1534
+ attention_mask: Optional[torch.Tensor] = None,
1535
+ position_ids: Optional[torch.LongTensor] = None,
1536
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1537
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1538
+ labels: Optional[torch.LongTensor] = None,
1539
+ use_cache: Optional[bool] = None,
1540
+ output_attentions: Optional[bool] = None,
1541
+ output_hidden_states: Optional[bool] = None,
1542
+ return_dict: Optional[bool] = None,
1543
+ pixel_values: Optional[torch.Tensor] = None,
1544
+ pixel_values_videos: Optional[torch.FloatTensor] = None,
1545
+ image_grid_thw: Optional[torch.LongTensor] = None,
1546
+ video_grid_thw: Optional[torch.LongTensor] = None,
1547
+ rope_deltas: Optional[torch.LongTensor] = None,
1548
+ ) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
1549
+ r"""
1550
+ Args:
1551
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1552
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1553
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1554
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1555
+
1556
+ Returns:
1557
+
1558
+ Example:
1559
+
1560
+ ```python
1561
+ >>> from PIL import Image
1562
+ >>> import requests
1563
+ >>> from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
1564
+
1565
+ >>> model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
1566
+ >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
1567
+
1568
+ >>> messages = [
1569
+ {
1570
+ "role": "user",
1571
+ "content": [
1572
+ {"type": "image"},
1573
+ {"type": "text", "text": "What is shown in this image?"},
1574
+ ],
1575
+ },
1576
+ ]
1577
+ >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
1578
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1579
+
1580
+ >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
1581
+ >>> inputs = processor(text=[text], images=[image], return_tensors="pt")
1582
+
1583
+ >>> # Generate
1584
+ >>> generate_ids = model.generate(**inputs, max_length=30)
1585
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1586
+ "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
1587
+ ```"""
1588
+
1589
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1590
+ output_hidden_states = (
1591
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1592
+ )
1593
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1594
+
1595
+ if inputs_embeds is None:
1596
+ inputs_embeds = self.model.embed_tokens(input_ids)
1597
+ if pixel_values is not None:
1598
+ pixel_values = pixel_values.type(self.visual.get_dtype())
1599
+ image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).to(inputs_embeds.device)
1600
+ image_mask = input_ids == self.config.image_token_id
1601
+ if self.training:
1602
+ inputs_embeds = inputs_embeds.clone()
1603
+ inputs_embeds[image_mask] = image_embeds
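+ # The boolean mask selects every image placeholder token, so the number of such tokens in input_ids
+ # must equal image_embeds.shape[0] (one merged vision patch per placeholder) for this masked
+ # assignment to succeed.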
1604
+ if pixel_values_videos is not None:
1605
+ pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
1606
+ video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw).to(inputs_embeds.device)
1607
+ video_mask = input_ids == self.config.video_token_id
1608
+ inputs_embeds[video_mask] = video_embeds
1609
+ if attention_mask is not None:
1610
+ attention_mask = attention_mask.to(inputs_embeds.device)
1611
+
1612
+ outputs = self.model(
1613
+ input_ids=None,
1614
+ position_ids=position_ids,
1615
+ attention_mask=attention_mask,
1616
+ past_key_values=past_key_values,
1617
+ inputs_embeds=inputs_embeds,
1618
+ use_cache=use_cache,
1619
+ output_attentions=output_attentions,
1620
+ output_hidden_states=output_hidden_states,
1621
+ return_dict=return_dict,
1622
+ )
1623
+
1624
+ hidden_states = outputs[0]
1625
+ logits = self.lm_head(hidden_states)
1626
+ logits = logits.float()
1627
+
1628
+ loss = None
1629
+ if labels is not None:
1630
+ # Shift so that tokens < n predict n
1631
+ shift_logits = logits[..., :-1, :].contiguous()
1632
+ shift_labels = labels[..., 1:].contiguous()
1633
+ # Flatten the tokens
1634
+ loss_fct = CrossEntropyLoss()
1635
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1636
+ shift_labels = shift_labels.view(-1)
1637
+ # Enable model parallelism
1638
+ shift_labels = shift_labels.to(shift_logits.device)
1639
+ loss = loss_fct(shift_logits, shift_labels)
1640
+
1641
+ if not return_dict:
1642
+ output = (logits,) + outputs[1:]
1643
+ return (loss,) + output if loss is not None else output
1644
+
1645
+ return Qwen2VLCausalLMOutputWithPast(
1646
+ loss=loss,
1647
+ logits=logits,
1648
+ past_key_values=outputs.past_key_values,
1649
+ hidden_states=outputs.hidden_states,
1650
+ attentions=outputs.attentions,
1651
+ rope_deltas=rope_deltas,
1652
+ )
1653
+
1654
+ def prepare_inputs_for_generation(
1655
+ self,
1656
+ input_ids,
1657
+ past_key_values=None,
1658
+ attention_mask=None,
1659
+ inputs_embeds=None,
1660
+ cache_position=None,
1661
+ position_ids=None,
1662
+ use_cache=True,
1663
+ pixel_values=None,
1664
+ pixel_values_videos=None,
1665
+ image_grid_thw=None,
1666
+ video_grid_thw=None,
1667
+ **kwargs,
1668
+ ):
1669
+ # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
1670
+ # Exception 1: when passing input_embeds, input_ids may be missing entries
1671
+ # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
1672
+ if past_key_values is not None:
1673
+ if inputs_embeds is not None: # Exception 1
1674
+ input_ids = input_ids[:, -cache_position.shape[0] :]
1675
+ elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
1676
+ input_ids = input_ids[:, cache_position]
1677
+
1678
+ rope_deltas = kwargs.get("rope_deltas", None)
1679
+ if attention_mask is not None and position_ids is None:
1680
+ if cache_position is None or (cache_position is not None and cache_position[0] == 0):
1681
+ position_ids, rope_deltas = self.get_rope_index(
1682
+ input_ids, image_grid_thw, video_grid_thw, attention_mask
1683
+ )
1684
+ else:
1685
+ batch_size, seq_length = input_ids.shape
1686
+ delta = (
1687
+ cache_position[0] + rope_deltas if cache_position is not None and rope_deltas is not None else 0
1688
+ )
1689
+ position_ids = torch.arange(seq_length, device=input_ids.device)
1690
+ position_ids = position_ids.view(1, -1).expand(batch_size, -1)
1691
+ position_ids = position_ids.add(delta)
1692
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
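+ # During cached decoding all three mrope axes advance together: the 1D positions are shifted by the
+ # per-sample rope delta computed at prefill and then broadcast to the (3, batch, seq) layout.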
1693
+
1694
+ if cache_position[0] != 0:
1695
+ pixel_values = None
1696
+ pixel_values_videos = None
1697
+
1698
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1699
+ if inputs_embeds is not None and cache_position[0] == 0:
1700
+ model_inputs = {"inputs_embeds": inputs_embeds}
1701
+ else:
1702
+ model_inputs = {"input_ids": input_ids}
1703
+
1704
+ if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
1705
+ if inputs_embeds is not None:
1706
+ batch_size, sequence_length = inputs_embeds.shape[:2]  # inputs_embeds is (batch, seq_len, hidden)
1707
+ device = inputs_embeds.device
1708
+ else:
1709
+ batch_size, sequence_length = input_ids.shape
1710
+ device = input_ids.device
1711
+
1712
+ dtype = self.lm_head.weight.dtype
1713
+ min_dtype = torch.finfo(dtype).min
1714
+
1715
+ attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
1716
+ attention_mask,
1717
+ sequence_length=sequence_length,
1718
+ target_length=past_key_values.get_max_length(),
1719
+ dtype=dtype,
1720
+ device=device,
1721
+ min_dtype=min_dtype,
1722
+ cache_position=cache_position,
1723
+ batch_size=batch_size,
1724
+ )
1725
+
1726
+ model_inputs.update(
1727
+ {
1728
+ "position_ids": position_ids,
1729
+ "past_key_values": past_key_values,
1730
+ "use_cache": use_cache,
1731
+ "attention_mask": attention_mask,
1732
+ "pixel_values": pixel_values,
1733
+ "pixel_values_videos": pixel_values_videos,
1734
+ "image_grid_thw": image_grid_thw,
1735
+ "video_grid_thw": video_grid_thw,
1736
+ "rope_deltas": rope_deltas,
1737
+ }
1738
+ )
1739
+ return model_inputs
1740
+
1741
+
1742
+ class Qwen2VLSimplifiedModel(Qwen2VLPreTrainedModel):
1743
+ def __init__(self, config):
1744
+ super().__init__(config)
1745
+ self.visual = Qwen2VisionTransformerPretrainedModel._from_config(
1746
+ config.vision_config, attn_implementation=config._attn_implementation
1747
+ )
1748
+ self.model = Qwen2VLModel(config)
1749
+ self.hidden_size = config.hidden_size
1750
+
1751
+ # Initialize weights and apply final processing
1752
+ self.post_init()
1753
+
1754
+ def get_input_embeddings(self):
1755
+ return self.model.embed_tokens
1756
+
1757
+ def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder=False, num_new_tokens=1):
1758
+ # Generation-related kwargs updates are not needed here; return model_kwargs unchanged
1759
+ return model_kwargs
1760
+
1761
+ def forward(
1762
+ self,
1763
+ input_ids: torch.LongTensor = None,
1764
+ attention_mask: Optional[torch.Tensor] = None,
1765
+ position_ids: Optional[torch.LongTensor] = None,
1766
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1767
+ output_attentions: Optional[bool] = None,
1768
+ output_hidden_states: Optional[bool] = None,
1769
+ return_dict: Optional[bool] = None,
1770
+ pixel_values: Optional[torch.Tensor] = None,
1771
+ pixel_values_videos: Optional[torch.FloatTensor] = None,
1772
+ image_grid_thw: Optional[torch.LongTensor] = None,
1773
+ video_grid_thw: Optional[torch.LongTensor] = None
1774
+ ):
1775
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1776
+ output_hidden_states = (
1777
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1778
+ )
1779
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1780
+
1781
+ image_mask = None  # defined up front so the return below is safe even when no image pixels are provided
+ if inputs_embeds is None:
1782
+ inputs_embeds = self.model.embed_tokens(input_ids)
1783
+ if pixel_values is not None:
1784
+ image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).to(inputs_embeds.device)
1785
+ image_mask = input_ids == self.config.image_token_id
1786
+ inputs_embeds[image_mask] = image_embeds
1787
+ if attention_mask is not None:
1788
+ attention_mask = attention_mask.to(inputs_embeds.device)
1789
+
1790
+ if position_ids is None:
1791
+ position_ids, _ = self.get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask)
1792
+
1793
+ outputs = self.model(
1794
+ input_ids=None,
1795
+ position_ids=position_ids,
1796
+ attention_mask=attention_mask,
1797
+ inputs_embeds=inputs_embeds,
1798
+ output_attentions=output_attentions,
1799
+ output_hidden_states=output_hidden_states,
1800
+ return_dict=return_dict,
1801
+ )
1802
+ hidden_states = outputs[0]
1803
+
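+ # Returns the final decoder hidden states together with the boolean mask of image-token positions and
+ # the image grid sizes, so a downstream connector can extract the vision-token features it needs.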
1804
+ return hidden_states, image_mask, image_grid_thw
1805
+
1806
+ def get_rope_index(
1807
+ self,
1808
+ input_ids: torch.LongTensor,
1809
+ image_grid_thw: Optional[torch.LongTensor] = None,
1810
+ video_grid_thw: Optional[torch.LongTensor] = None,
1811
+ attention_mask: Optional[torch.Tensor] = None,
1812
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1813
+ """
1814
+ Calculate the 3D rope index based on image and video's temporal, height and width in LLM.
1815
+
1816
+ Explanation:
1817
+ Each embedding sequence contains vision embedding and text embedding or just contains text embedding.
1818
+
1819
+ For a pure text embedding sequence, the rotary position embedding is no different from modern LLMs.
1820
+ Examples:
1821
+ input_ids: [T T T T T], here T is for text.
1822
+ temporal position_ids: [0, 1, 2, 3, 4]
1823
+ height position_ids: [0, 1, 2, 3, 4]
1824
+ width position_ids: [0, 1, 2, 3, 4]
1825
+
1826
+ For a mixed vision and text embedding sequence, we calculate a 3D rotary position embedding for the vision part
1827
+ and a 1D rotary position embedding for the text part.
1828
+ Examples:
1829
+ Assume we have a video input with 3 temporal patches, 2 height patches and 2 width patches.
1830
+ input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
1831
+ vision temporal position_ids: [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
1832
+ vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
1833
+ vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
1834
+ text temporal position_ids: [3, 4, 5, 6, 7]
1835
+ text height position_ids: [3, 4, 5, 6, 7]
1836
+ text width position_ids: [3, 4, 5, 6, 7]
1837
+ Here we calculate the text start position_ids as the max vision position_ids plus 1.
1838
+
1839
+ Args:
1840
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1841
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
1842
+ it.
1843
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
1844
+ The temporal, height and width of feature shape of each image in LLM.
1845
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
1846
+ The temporal, height and width of feature shape of each video in LLM.
1847
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1848
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1849
+
1850
+ - 1 for tokens that are **not masked**,
1851
+ - 0 for tokens that are **masked**.
1852
+
1853
+ Returns:
1854
+ position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`)
1855
+ mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`)
1856
+ """
1857
+ spatial_merge_size = self.config.vision_config.spatial_merge_size
1858
+ image_token_id = self.config.image_token_id
1859
+ video_token_id = self.config.video_token_id
1860
+ vision_start_token_id = self.config.vision_start_token_id
1861
+ mrope_position_deltas = []
1862
+ if image_grid_thw is not None or video_grid_thw is not None:
1863
+ total_input_ids = input_ids
1864
+ position_ids = torch.ones(
1865
+ 3, input_ids.shape[0], input_ids.shape[1], dtype=input_ids.dtype, device=input_ids.device
1866
+ )
1867
+ image_index, video_index = 0, 0
1868
+ for i, input_ids in enumerate(total_input_ids):
1869
+ if attention_mask is not None:
1870
+ input_ids = input_ids[attention_mask[i] == 1]
1871
+ image_nums, video_nums = 0, 0
1872
+ vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
1873
+ vision_tokens = input_ids[vision_start_indices + 1]
1874
+ image_nums = (vision_tokens == image_token_id).sum()
1875
+ video_nums = (vision_tokens == video_token_id).sum()
1876
+ input_tokens = input_ids.tolist()
1877
+ llm_pos_ids_list: list = []
1878
+ st = 0
1879
+ remain_images, remain_videos = image_nums, video_nums
1880
+ for _ in range(image_nums + video_nums):
1881
+ if image_token_id in input_tokens and remain_images > 0:
1882
+ ed_image = input_tokens.index(image_token_id, st)
1883
+ else:
1884
+ ed_image = len(input_tokens) + 1
1885
+ if video_token_id in input_tokens and remain_videos > 0:
1886
+ ed_video = input_tokens.index(video_token_id, st)
1887
+ else:
1888
+ ed_video = len(input_tokens) + 1
1889
+ if ed_image < ed_video:
1890
+ t, h, w = (
1891
+ image_grid_thw[image_index][0],
1892
+ image_grid_thw[image_index][1],
1893
+ image_grid_thw[image_index][2],
1894
+ )
1895
+ image_index += 1
1896
+ remain_images -= 1
1897
+ ed = ed_image
1898
+ else:
1899
+ t, h, w = (
1900
+ video_grid_thw[video_index][0],
1901
+ video_grid_thw[video_index][1],
1902
+ video_grid_thw[video_index][2],
1903
+ )
1904
+ video_index += 1
1905
+ remain_videos -= 1
1906
+ ed = ed_video
1907
+ llm_grid_t, llm_grid_h, llm_grid_w = (
1908
+ t.item(),
1909
+ h.item() // spatial_merge_size,
1910
+ w.item() // spatial_merge_size,
1911
+ )
1912
+ text_len = ed - st
1913
+
1914
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
1915
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
1916
+
1917
+ t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
1918
+ h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
1919
+ w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
1920
+ llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
1921
+ st = ed + llm_grid_t * llm_grid_h * llm_grid_w
1922
+
1923
+ if st < len(input_tokens):
1924
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
1925
+ text_len = len(input_tokens) - st
1926
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
1927
+
1928
+ llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
1929
+ position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
1930
+ mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
1931
+ mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
1932
+ return position_ids, mrope_position_deltas
1933
+ else:
1934
+ if attention_mask is not None:
1935
+ position_ids = attention_mask.long().cumsum(-1) - 1
1936
+ position_ids.masked_fill_(attention_mask == 0, 1)
1937
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device)
1938
+ max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
1939
+ mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
1940
+ else:
1941
+ position_ids = (
1942
+ torch.arange(input_ids.shape[1], device=input_ids.device)
1943
+ .view(1, 1, -1)
1944
+ .expand(3, input_ids.shape[0], -1)
1945
+ )
1946
+ mrope_position_deltas = torch.zeros(
1947
+ [input_ids.shape[0], 1],
1948
+ device=input_ids.device,
1949
+ dtype=input_ids.dtype,
1950
+ )
1951
+
1952
+ return position_ids, mrope_position_deltas
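Note (not part of the uploaded files): the docstring example above can be reproduced with a minimal standalone sketch that uses the same index construction as get_rope_index. The helper name rope_index_example and the default sizes (t=3, h=2, w=2, 5 text tokens) are assumptions chosen to match the example, not code from this repository.

import torch

def rope_index_example(t: int = 3, h: int = 2, w: int = 2, text_len: int = 5) -> torch.Tensor:
    # Vision part: each of the t*h*w patches gets a (temporal, height, width) id.
    t_index = torch.arange(t).view(-1, 1).expand(-1, h * w).flatten()
    h_index = torch.arange(h).view(1, -1, 1).expand(t, -1, w).flatten()
    w_index = torch.arange(w).view(1, 1, -1).expand(t, h, -1).flatten()
    vision_pos = torch.stack([t_index, h_index, w_index])            # shape (3, t*h*w)
    # Text part: 1D positions starting at max(vision position) + 1, shared across all three axes.
    start = vision_pos.max() + 1
    text_pos = (torch.arange(text_len) + start).view(1, -1).expand(3, -1)
    return torch.cat([vision_pos, text_pos], dim=1)                  # shape (3, t*h*w + text_len)

pos = rope_index_example()
print(pos[0].tolist())  # temporal: [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 4, 5, 6, 7]
print(pos[1].tolist())  # height:   [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 3, 4, 5, 6, 7]
print(pos[2].tolist())  # width:    [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 3, 4, 5, 6, 7]

The mrope_position_deltas value returned by get_rope_index is simply max(position) + 1 minus the prompt length, which downstream generation code can use to continue the position ids past the prompt.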
qwen2_vl/processing_qwen2_vl.py ADDED
@@ -0,0 +1,183 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """
21
+ Processor class for Qwen2-VL.
22
+ """
23
+
24
+ from typing import List, Optional, Union
25
+
26
+ from transformers.feature_extraction_utils import BatchFeature
27
+ from transformers.image_utils import ImageInput, VideoInput
28
+ from transformers.processing_utils import ProcessorMixin
29
+ from transformers.tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
30
+ from transformers.utils import TensorType, logging
31
+
32
+
33
+ logger = logging.get_logger(__name__)
34
+
35
+
36
+ class Qwen2VLProcessor(ProcessorMixin):
37
+ r"""
38
+ Constructs a Qwen2-VL processor which wraps a Qwen2-VL image processor and a Qwen2 tokenizer into a single processor.
39
+
40
+ [`Qwen2VLProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the
41
+ [`~Qwen2VLProcessor.__call__`] and [`~Qwen2VLProcessor.decode`] for more information.
42
+
43
+ Args:
44
+ image_processor ([`Qwen2VLImageProcessor`], *optional*):
45
+ The image processor is a required input.
46
+ tokenizer ([`Qwen2TokenizerFast`], *optional*):
47
+ The tokenizer is a required input.
48
+ chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
49
+ in a chat into a tokenizable string.
50
+ """
51
+
52
+ attributes = ["image_processor", "tokenizer"]
53
+ valid_kwargs = ["chat_template"]
54
+ image_processor_class = "Qwen2VLImageProcessor"
55
+ tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
56
+
57
+ def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
58
+ super().__init__(image_processor, tokenizer, chat_template=chat_template)
59
+
60
+ def __call__(
61
+ self,
62
+ images: ImageInput = None,
63
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
64
+ videos: VideoInput = None,
65
+ padding: Union[bool, str, PaddingStrategy] = False,
66
+ truncation: Union[bool, str, TruncationStrategy] = None,
67
+ max_length: int = None,
68
+ return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
69
+ ) -> BatchFeature:
70
+ """
71
+        Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
72
+ and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
73
+        the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwargs` arguments to
74
+ Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`.
75
+
76
+ Args:
77
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
78
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
79
+ tensor. Both channels-first and channels-last formats are supported.
80
+ text (`str`, `List[str]`, `List[List[str]]`):
81
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
82
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
83
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
84
+ videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
85
+                The video or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
86
+ tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported.
87
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
88
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding
89
+ index) among:
90
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
91
+                  sequence is provided).
92
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
93
+ acceptable input length for the model if that argument is not provided.
94
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
95
+ lengths).
96
+ max_length (`int`, *optional*):
97
+ Maximum length of the returned list and optionally padding length (see above).
98
+ truncation (`bool`, *optional*):
99
+ Activates truncation to cut input sequences longer than `max_length` to `max_length`.
100
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
101
+ If set, will return tensors of a particular framework. Acceptable values are:
102
+
103
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
104
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
105
+ - `'np'`: Return NumPy `np.ndarray` objects.
106
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
107
+
108
+ Returns:
109
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
110
+
111
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
112
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
113
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
114
+ `None`).
115
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
116
+ - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
117
+ - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`.
118
+ - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`.
119
+ """
120
+ if images is not None:
121
+ image_inputs = self.image_processor(images=images, videos=None, return_tensors=return_tensors)
122
+ image_grid_thw = image_inputs["image_grid_thw"]
123
+ else:
124
+ image_inputs = {}
125
+ image_grid_thw = None
126
+
127
+ if videos is not None:
128
+ videos_inputs = self.image_processor(images=None, videos=videos, return_tensors=return_tensors)
129
+ video_grid_thw = videos_inputs["video_grid_thw"]
130
+ else:
131
+ videos_inputs = {}
132
+ video_grid_thw = None
133
+
134
+ if not isinstance(text, list):
135
+ text = [text]
136
+
137
+ if image_grid_thw is not None:
138
+ merge_length = self.image_processor.merge_size**2
139
+ index = 0
140
+ for i in range(len(text)):
141
+ while "<|image_pad|>" in text[i]:
142
+ text[i] = text[i].replace(
143
+ "<|image_pad|>", "<|placeholder|>" * (image_grid_thw[index].prod() // merge_length), 1
144
+ )
145
+ index += 1
146
+ text[i] = text[i].replace("<|placeholder|>", "<|image_pad|>")
147
+
148
+ if video_grid_thw is not None:
149
+ merge_length = self.image_processor.merge_size**2
150
+ index = 0
151
+ for i in range(len(text)):
152
+ while "<|video_pad|>" in text[i]:
153
+ text[i] = text[i].replace(
154
+ "<|video_pad|>", "<|placeholder|>" * (video_grid_thw[index].prod() // merge_length), 1
155
+ )
156
+ index += 1
157
+ text[i] = text[i].replace("<|placeholder|>", "<|video_pad|>")
158
+
159
+ text_inputs = self.tokenizer(
160
+ text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length
161
+ )
162
+
163
+ return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs})
164
+
165
+ def batch_decode(self, *args, **kwargs):
166
+ """
167
+ This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
168
+ refer to the docstring of this method for more information.
169
+ """
170
+ return self.tokenizer.batch_decode(*args, **kwargs)
171
+
172
+ def decode(self, *args, **kwargs):
173
+ """
174
+ This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
175
+ the docstring of this method for more information.
176
+ """
177
+ return self.tokenizer.decode(*args, **kwargs)
178
+
179
+ @property
180
+ def model_input_names(self):
181
+ tokenizer_input_names = self.tokenizer.model_input_names
182
+ image_processor_input_names = self.image_processor.model_input_names
183
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
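Note (not part of the uploaded files): the key step in __call__ above is expanding each <|image_pad|> (or <|video_pad|>) placeholder so the tokenized prompt reserves one token per merged vision patch. A minimal sketch of that arithmetic follows; the grid values, merge size and prompt string are assumptions (the real grid comes from the image processor output, the merge size from its config, and the prompt format follows the Qwen2-VL chat convention).

import torch

image_grid_thw = torch.tensor([[1, 6, 8]])   # assumed: 1 temporal x 6 x 8 patches for one image
merge_size = 2                               # assumed spatial merge size of the image processor

merge_length = merge_size ** 2               # 2x2 patches are merged into one vision token
num_vision_tokens = int(image_grid_thw[0].prod()) // merge_length   # 48 // 4 = 12

prompt = "<|vision_start|><|image_pad|><|vision_end|>Describe this image."
expanded = prompt.replace("<|image_pad|>", "<|image_pad|>" * num_vision_tokens, 1)
print(num_vision_tokens)                     # 12
print(expanded.count("<|image_pad|>"))       # 12 placeholder tokens handed to the tokenizer

The processor then tokenizes the expanded text, so the resulting input_ids line up one-to-one with the merged vision embeddings that the model later inserts at those positions.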
requirements.txt ADDED
@@ -0,0 +1,22 @@
1
+ spaces
2
+ huggingface_hub
3
+ gradio_imageslider
4
+ requests
5
+ numpy<2
6
+ torch
7
+ torchvision
8
+ transformers
9
+ diffusers>=0.30.0
10
+ accelerate
11
+ Pillow
12
+ opencv-python-headless>=4.8.0
13
+ protobuf
14
+ sentencepiece
15
+ git+https://github.com/huggingface/controlnet_aux
16
+ mediapipe
17
+ sam2
18
+ optimum
19
+ optimum-quanto
20
+ matplotlib
21
+ xformers
22
+ bitsandbytes
technical-report.pdf ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d9cdff3966330d9a59053cd1c2163421d803e9802533061d3e6bca23a62054f8
3
+ size 66144700