Upload 46 files
- .gitattributes +1 -0
- LICENSE +21 -0
- README.md +13 -12
- app.py +229 -0
- depth_anything_v2/dinov2.py +415 -0
- depth_anything_v2/dinov2_layers/__init__.py +11 -0
- depth_anything_v2/dinov2_layers/attention.py +83 -0
- depth_anything_v2/dinov2_layers/block.py +252 -0
- depth_anything_v2/dinov2_layers/drop_path.py +35 -0
- depth_anything_v2/dinov2_layers/layer_scale.py +28 -0
- depth_anything_v2/dinov2_layers/mlp.py +41 -0
- depth_anything_v2/dinov2_layers/patch_embed.py +89 -0
- depth_anything_v2/dinov2_layers/swiglu_ffn.py +63 -0
- depth_anything_v2/dpt.py +221 -0
- depth_anything_v2/util/blocks.py +148 -0
- depth_anything_v2/util/transform.py +158 -0
- flux-architecture.svg +169 -0
- flux/activations.py +165 -0
- flux/attention.py +843 -0
- flux/attention_processor.py +0 -0
- flux/controlnet_flux.py +617 -0
- flux/embeddings.py +1469 -0
- flux/flux_network.py +183 -0
- flux/lora/lora_base.py +752 -0
- flux/lora/lora_conversion_utils.py +328 -0
- flux/lora/lora_pipeline.py +0 -0
- flux/lora/peft.py +395 -0
- flux/normalization.py +393 -0
- flux/pipeline_flux.py +749 -0
- flux/pipeline_flux_chameleon.py +758 -0
- flux/pipeline_flux_controlnet.py +945 -0
- flux/pipeline_flux_controlnet_img2img.py +1002 -0
- flux/pipeline_flux_controlnet_inpainting.py +1199 -0
- flux/pipeline_flux_img2img.py +856 -0
- flux/pipeline_flux_inpaint.py +1021 -0
- flux/pipeline_output.py +21 -0
- flux/scheduling_flow_match_euler_discrete.py +325 -0
- flux/transformer_flux.py +572 -0
- main.py +154 -0
- model.py +644 -0
- modelmod.py +650 -0
- qwen2_vl/configuration_qwen2_vl.py +206 -0
- qwen2_vl/image_processing_qwen2_vl.py +458 -0
- qwen2_vl/modeling_qwen2_vl.py +1952 -0
- qwen2_vl/processing_qwen2_vl.py +183 -0
- requirements.txt +22 -0
- technical-report.pdf +3 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+technical-report.pdf filter=lfs diff=lfs merge=lfs -text

LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 erwold

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

README.md
CHANGED
@@ -1,12 +1,13 @@
----
-title:
-emoji:
-colorFrom:
-colorTo:
-sdk: gradio
-sdk_version:
-app_file: app.py
-pinned: false
-
-
-
+---
+title: Qwen2VL-Flux Zero (failure)
+emoji: 🎨
+colorFrom: yellow
+colorTo: pink
+sdk: gradio
+sdk_version: 4.44.0
+app_file: app.py
+pinned: false
+license: mit
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference

app.py
ADDED
@@ -0,0 +1,229 @@
from typing import Tuple

import requests
import random
import numpy as np
import gradio as gr
import spaces
import torch
from PIL import Image
from huggingface_hub import login
import os
import time
from gradio_imageslider import ImageSlider

from io import BytesIO
import PIL.Image
import shutil
import glob
from huggingface_hub import snapshot_download, hf_hub_download

MAX_SEED = np.iinfo(np.int32).max
IMAGE_SIZE = 1024
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN: login(token=HF_TOKEN)

# Download the checkpoints needed at startup.
cp_dir = os.getenv('CHECKPOINT_DIR', 'checkpoints')
snapshot_download("Djrango/Qwen2vl-Flux", local_dir=cp_dir)
hf_hub_download(repo_id="TheMistoAI/MistoLine", filename="MTEED.pth", subfolder="Anyline", local_dir=f"{cp_dir}/anyline")
shutil.move(f"{cp_dir}/anyline/Anyline/MTEED.pth", f"{cp_dir}/anyline")
snapshot_download("depth-anything/Depth-Anything-V2-Large", local_dir=f"{cp_dir}/depth-anything-v2")
snapshot_download("facebook/sam2-hiera-large", local_dir=f"{cp_dir}/segment-anything-2")
# https://github.com/facebookresearch/sam2/issues/26
os.makedirs("sam2_configs", exist_ok=True)
for p in glob.glob(f"{cp_dir}/segment-anything-2/*.yaml"):
    shutil.copy(p, "sam2_configs")

from modelmod import FluxModel
model = FluxModel(device=DEVICE, is_turbo=False, required_features=['controlnet', 'depth', 'line'], is_quantization=True) # , 'sam'

QWEN2VLFLUX_MODES = ["variation", "img2img", "inpaint", "controlnet", "controlnet-inpaint"]
QWEN2VLFLUX_ASPECT_RATIO = ["1:1", "16:9", "9:16", "2.4:1", "3:4", "4:3"]

class calculateDuration:
    def __init__(self, activity_name=""):
        self.activity_name = activity_name

    def __enter__(self):
        self.start_time = time.time()
        self.start_time_formatted = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(self.start_time))
        print(f"Activity: {self.activity_name}, Start time: {self.start_time_formatted}")
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.end_time = time.time()
        self.elapsed_time = self.end_time - self.start_time
        self.end_time_formatted = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(self.end_time))

        if self.activity_name:
            print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
        else:
            print(f"Elapsed time: {self.elapsed_time:.6f} seconds")

        print(f"Activity: {self.activity_name}, End time: {self.end_time_formatted}")

def resize_image_dimensions(
    original_resolution_wh: Tuple[int, int],
    maximum_dimension: int = IMAGE_SIZE
) -> Tuple[int, int]:
    width, height = original_resolution_wh

    # if width <= maximum_dimension and height <= maximum_dimension:
    #     width = width - (width % 32)
    #     height = height - (height % 32)
    #     return width, height

    if width > height:
        scaling_factor = maximum_dimension / width
    else:
        scaling_factor = maximum_dimension / height

    new_width = int(width * scaling_factor)
    new_height = int(height * scaling_factor)

    new_width = new_width - (new_width % 32)
    new_height = new_height - (new_height % 32)

    return new_width, new_height

def fetch_from_url(url: str, name: str):
    try:
        print(f"start to fetch {name} from url", url)
        response = requests.get(url)
        response.raise_for_status()
        image = PIL.Image.open(BytesIO(response.content))
        print(f"fetch {name} success")
        return image
    except Exception as e:
        print(e)
        return None

@spaces.GPU(duration=100)
@torch.inference_mode()
def process(
    mode: str,
    input_image_editor: dict,
    ref_image: Image.Image,
    image_url: str,
    mask_url: str,
    ref_url: str,
    input_text: str,
    strength: float,
    num_inference_steps: int,
    guidance_scale: float,
    aspect_ratio: str,
    attn_mode: bool,
    center_x: float,
    center_y: float,
    radius: float,
    line_mode: bool,
    line_strength: float,
    depth_mode: bool,
    depth_strength: float,
    progress=gr.Progress(track_tqdm=True)
):
    #if not input_text:
    #    gr.Info("Please enter a text prompt.")
    #    return None

    kwargs = {}

    image = input_image_editor['background']
    mask = input_image_editor['layers'][0]

    if image_url: image = fetch_from_url(image_url, "image")
    if mask_url: mask = fetch_from_url(mask_url, "mask")
    if ref_url: ref_image = fetch_from_url(ref_url, "reference image")

    if not image:
        gr.Info("Please upload an image.")
        return None

    if ref_image: kwargs["input_image_b"] = ref_image
    if mode == "inpaint" or mode == "controlnet-inpaint":
        if not mask:
            gr.Info("Please draw a mask on the image.")
            return None
        kwargs["mask_image"] = mask

    if attn_mode:
        kwargs["center_x"] = center_x
        kwargs["center_y"] = center_y
        kwargs["radius"] = radius

    with calculateDuration("run inference"):
        result = model.generate(
            input_image_a=image,
            prompt=input_text,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            aspect_ratio=aspect_ratio,
            mode=mode,
            denoise_strength=strength,
            line_mode=line_mode,
            line_strength=line_strength,
            depth_mode=depth_mode,
            depth_strength=depth_strength,
            imageCount=1,
            **kwargs
        )[0]

    #return result
    return [image, result]

CSS = """
.title { text-align: center; }
"""

with gr.Blocks(fill_width=True, css=CSS) as demo:
    gr.Markdown("# Qwen2VL-Flux", elem_classes="title")
    with gr.Row():
        with gr.Column():
            gen_mode = gr.Radio(label="Generation mode", choices=QWEN2VLFLUX_MODES, value="variation")
            with gr.Row():
                input_image_editor = gr.ImageEditor(label='Image', type='pil', sources=["upload", "webcam", "clipboard"], image_mode='RGB',
                                                    layers=False, brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"))
                ref_image = gr.Image(label='Reference image', type='pil', sources=["upload", "webcam", "clipboard"], image_mode='RGB')
            with gr.Accordion("Image from URL", open=False):
                image_url = gr.Textbox(label="Image url", show_label=True, max_lines=1, placeholder="Enter your image url (Optional)")
                mask_url = gr.Textbox(label="Mask image url", show_label=True, max_lines=1, placeholder="Enter your mask image url (Optional)")
                ref_url = gr.Textbox(label="Reference image url", show_label=True, max_lines=1, placeholder="Enter your reference image url (Optional)")

            with gr.Accordion("Prompt Settings", open=True):
                input_text = gr.Textbox(label="Prompt", show_label=True, max_lines=1, placeholder="Enter your prompt")
                submit_button = gr.Button(value='Submit', variant='primary')

            with gr.Accordion("Advanced Settings", open=True):
                with gr.Row():
                    denoise_strength = gr.Slider(label="Denoise strength", minimum=0, maximum=1, step=0.01, value=0.75)
                    aspect_ratio = gr.Radio(label="Output image ratio", choices=QWEN2VLFLUX_ASPECT_RATIO, value="1:1")
                    num_inference_steps = gr.Slider(label="Number of inference steps", minimum=1, maximum=50, step=1, value=28)
                    guidance_scale = gr.Slider(label="Guidance scale", minimum=0, maximum=20, step=0.5, value=3.5)
                with gr.Accordion("Attention Control", open=True):
                    with gr.Row():
                        attn_mode = gr.Checkbox(label="Attention Control", value=False)
                        center_x = gr.Slider(label="X coordinate of attention center", minimum=0, maximum=1, step=0.01, value=0.5)
                        center_y = gr.Slider(label="Y coordinate of attention center", minimum=0, maximum=1, step=0.01, value=0.5)
                        radius = gr.Slider(label="Radius of attention circle", minimum=0, maximum=1, step=0.01, value=0.5)
                with gr.Accordion("ControlNet Settings", open=True):
                    with gr.Row():
                        line_mode = gr.Checkbox(label="Line mode", value=True)
                        line_strength = gr.Slider(label="Line strength", minimum=0, maximum=1, step=0.01, value=0.4)
                        depth_mode = gr.Checkbox(label="Depth mode", value=True)
                        depth_strength = gr.Slider(label="Depth strength", minimum=0, maximum=1, step=0.01, value=0.2)

        with gr.Column():
            #output_image = gr.Image(label="Generated image", type="pil", format="png", show_download_button=True, show_share_button=False)
            output_image = ImageSlider(label="Generated image", type="pil")

    gr.on(triggers=[submit_button.click, input_text.submit], fn=process,
          inputs=[gen_mode, input_image_editor, ref_image, image_url, mask_url, ref_url,
                  input_text, denoise_strength, num_inference_steps, guidance_scale, aspect_ratio,
                  attn_mode, center_x, center_y, radius, line_mode, line_strength, depth_mode, depth_strength],
          outputs=[output_image], queue=True)

demo.queue().launch(debug=True, show_error=True)
#demo.queue().launch(debug=True, show_error=True, ssr_mode=False) # Gradio 5

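Note that resize_image_dimensions above is defined but not wired into process in this version; it caps the longer side at IMAGE_SIZE and snaps both sides down to multiples of 32. As a minimal standalone sketch of that arithmetic (the helper name and the hard-coded 1024 cap below are illustrative, not part of the Space):

from typing import Tuple

def resize_to_multiple_of_32(wh: Tuple[int, int], maximum_dimension: int = 1024) -> Tuple[int, int]:
    # Same arithmetic as app.py's resize_image_dimensions, restated standalone:
    # scale the longer side to the cap, then round both sides down to a multiple of 32.
    width, height = wh
    scale = maximum_dimension / (width if width > height else height)
    return int(width * scale) - int(width * scale) % 32, int(height * scale) - int(height * scale) % 32

print(resize_to_multiple_of_32((1920, 1080)))  # (1024, 576)
print(resize_to_multiple_of_32((800, 1200)))   # (672, 1024)
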
depth_anything_v2/dinov2.py
ADDED
@@ -0,0 +1,415 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py

from functools import partial
import math
import logging
from typing import Sequence, Tuple, Union, Callable

import torch
import torch.nn as nn
import torch.utils.checkpoint
from torch.nn.init import trunc_normal_

from .dinov2_layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block


logger = logging.getLogger("dinov2")


def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
    if not depth_first and include_root:
        fn(module=module, name=name)
    for child_name, child_module in module.named_children():
        child_name = ".".join((name, child_name)) if name else child_name
        named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
    if depth_first and include_root:
        fn(module=module, name=name)
    return module


class BlockChunk(nn.ModuleList):
    def forward(self, x):
        for b in self:
            x = b(x)
        return x


class DinoVisionTransformer(nn.Module):
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        ffn_bias=True,
        proj_bias=True,
        drop_path_rate=0.0,
        drop_path_uniform=False,
        init_values=None,  # for layerscale: None or 0 => no layerscale
        embed_layer=PatchEmbed,
        act_layer=nn.GELU,
        block_fn=Block,
        ffn_layer="mlp",
        block_chunks=1,
        num_register_tokens=0,
        interpolate_antialias=False,
        interpolate_offset=0.1,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            proj_bias (bool): enable bias for proj in attn if True
            ffn_bias (bool): enable bias for ffn if True
            drop_path_rate (float): stochastic depth rate
            drop_path_uniform (bool): apply uniform drop rate across blocks
            weight_init (str): weight init scheme
            init_values (float): layer-scale init values
            embed_layer (nn.Module): patch embedding layer
            act_layer (nn.Module): MLP activation layer
            block_fn (nn.Module): transformer block class
            ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
            block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
            num_register_tokens: (int) number of extra cls tokens (so-called "registers")
            interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
            interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
        """
        super().__init__()
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 1
        self.n_blocks = depth
        self.num_heads = num_heads
        self.patch_size = patch_size
        self.num_register_tokens = num_register_tokens
        self.interpolate_antialias = interpolate_antialias
        self.interpolate_offset = interpolate_offset

        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        assert num_register_tokens >= 0
        self.register_tokens = (
            nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
        )

        if drop_path_uniform is True:
            dpr = [drop_path_rate] * depth
        else:
            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

        if ffn_layer == "mlp":
            logger.info("using MLP layer as FFN")
            ffn_layer = Mlp
        elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
            logger.info("using SwiGLU layer as FFN")
            ffn_layer = SwiGLUFFNFused
        elif ffn_layer == "identity":
            logger.info("using Identity layer as FFN")

            def f(*args, **kwargs):
                return nn.Identity()

            ffn_layer = f
        else:
            raise NotImplementedError

        blocks_list = [
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                ffn_layer=ffn_layer,
                init_values=init_values,
            )
            for i in range(depth)
        ]
        if block_chunks > 0:
            self.chunked_blocks = True
            chunked_blocks = []
            chunksize = depth // block_chunks
            for i in range(0, depth, chunksize):
                # this is to keep the block index consistent if we chunk the block list
                chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
            self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
        else:
            self.chunked_blocks = False
            self.blocks = nn.ModuleList(blocks_list)

        self.norm = norm_layer(embed_dim)
        self.head = nn.Identity()

        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))

        self.init_weights()

    def init_weights(self):
        trunc_normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.cls_token, std=1e-6)
        if self.register_tokens is not None:
            nn.init.normal_(self.register_tokens, std=1e-6)
        named_apply(init_weights_vit_timm, self)

    def interpolate_pos_encoding(self, x, w, h):
        previous_dtype = x.dtype
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        pos_embed = self.pos_embed.float()
        class_pos_embed = pos_embed[:, 0]
        patch_pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_size
        h0 = h // self.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        # DINOv2 with register modify the interpolate_offset from 0.1 to 0.0
        w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
        # w0, h0 = w0 + 0.1, h0 + 0.1

        sqrt_N = math.sqrt(N)
        sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
            scale_factor=(sx, sy),
            # (int(w0), int(h0)), # to solve the upsampling shape issue
            mode="bicubic",
            antialias=self.interpolate_antialias
        )

        assert int(w0) == patch_pos_embed.shape[-2]
        assert int(h0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)

    def prepare_tokens_with_masks(self, x, masks=None):
        B, nc, w, h = x.shape
        x = self.patch_embed(x)
        if masks is not None:
            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)

        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = x + self.interpolate_pos_encoding(x, w, h)

        if self.register_tokens is not None:
            x = torch.cat(
                (
                    x[:, :1],
                    self.register_tokens.expand(x.shape[0], -1, -1),
                    x[:, 1:],
                ),
                dim=1,
            )

        return x

    def forward_features_list(self, x_list, masks_list):
        x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
        for blk in self.blocks:
            x = blk(x)

        all_x = x
        output = []
        for x, masks in zip(all_x, masks_list):
            x_norm = self.norm(x)
            output.append(
                {
                    "x_norm_clstoken": x_norm[:, 0],
                    "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
                    "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
                    "x_prenorm": x,
                    "masks": masks,
                }
            )
        return output

    def forward_features(self, x, masks=None):
        if isinstance(x, list):
            return self.forward_features_list(x, masks)

        x = self.prepare_tokens_with_masks(x, masks)

        for blk in self.blocks:
            x = blk(x)

        x_norm = self.norm(x)
        return {
            "x_norm_clstoken": x_norm[:, 0],
            "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
            "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
            "x_prenorm": x,
            "masks": masks,
        }

    def _get_intermediate_layers_not_chunked(self, x, n=1):
        x = self.prepare_tokens_with_masks(x)
        # If n is an int, take the n last blocks. If it's a list, take them
        output, total_block_len = [], len(self.blocks)
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in blocks_to_take:
                output.append(x)
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def _get_intermediate_layers_chunked(self, x, n=1):
        x = self.prepare_tokens_with_masks(x)
        output, i, total_block_len = [], 0, len(self.blocks[-1])
        # If n is an int, take the n last blocks. If it's a list, take them
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for block_chunk in self.blocks:
            for blk in block_chunk[i:]:  # Passing the nn.Identity()
                x = blk(x)
                if i in blocks_to_take:
                    output.append(x)
                i += 1
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def get_intermediate_layers(
        self,
        x: torch.Tensor,
        n: Union[int, Sequence] = 1,  # Layers or n last layers to take
        reshape: bool = False,
        return_class_token: bool = False,
        norm=True
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
        if self.chunked_blocks:
            outputs = self._get_intermediate_layers_chunked(x, n)
        else:
            outputs = self._get_intermediate_layers_not_chunked(x, n)
        if norm:
            outputs = [self.norm(out) for out in outputs]
        class_tokens = [out[:, 0] for out in outputs]
        outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs]
        if reshape:
            B, _, w, h = x.shape
            outputs = [
                out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
                for out in outputs
            ]
        if return_class_token:
            return tuple(zip(outputs, class_tokens))
        return tuple(outputs)

    def forward(self, *args, is_training=False, **kwargs):
        ret = self.forward_features(*args, **kwargs)
        if is_training:
            return ret
        else:
            return self.head(ret["x_norm_clstoken"])


def init_weights_vit_timm(module: nn.Module, name: str = ""):
    """ViT weight initialization, original timm impl (for reproducibility)"""
    if isinstance(module, nn.Linear):
        trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)


def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model


def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model


def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model


def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
    """
    Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
    """
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1536,
        depth=40,
        num_heads=24,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model


def DINOv2(model_name):
    model_zoo = {
        "vits": vit_small,
        "vitb": vit_base,
        "vitl": vit_large,
        "vitg": vit_giant2
    }

    return model_zoo[model_name](
        img_size=518,
        patch_size=14,
        init_values=1.0,
        ffn_layer="mlp" if model_name != "vitg" else "swiglufused",
        block_chunks=0,
        num_register_tokens=0,
        interpolate_antialias=False,
        interpolate_offset=0.1
    )

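The DINOv2 factory at the bottom builds the backbone at 518x518 with a 14-pixel patch, so a full-resolution input yields a 37x37 token grid. A minimal usage sketch with random weights (the real checkpoint is loaded elsewhere, e.g. by dpt.py); it assumes the repository root is on sys.path and that either xFormers is absent or a CUDA device is used, since MemEffAttention prefers xFormers kernels when they are installed:

import torch
from depth_anything_v2.dinov2 import DINOv2

encoder = DINOv2("vitl").eval()      # ViT-L/14: embed_dim=1024, depth=24, block_chunks=0
x = torch.randn(1, 3, 518, 518)      # 518 / 14 = 37 patches per side

with torch.no_grad():
    feats = encoder.get_intermediate_layers(x, n=4, reshape=True)

print(len(feats), feats[0].shape)    # 4 torch.Size([1, 1024, 37, 37])
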
depth_anything_v2/dinov2_layers/__init__.py
ADDED
@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from .mlp import Mlp
from .patch_embed import PatchEmbed
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
from .block import NestedTensorBlock
from .attention import MemEffAttention

depth_anything_v2/dinov2_layers/attention.py
ADDED
@@ -0,0 +1,83 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py

import logging

from torch import Tensor
from torch import nn


logger = logging.getLogger("dinov2")


try:
    from xformers.ops import memory_efficient_attention, unbind, fmha

    XFORMERS_AVAILABLE = True
except ImportError:
    logger.warning("xFormers not available")
    XFORMERS_AVAILABLE = False


class Attention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: Tensor) -> Tensor:
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)

        q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
        attn = q @ k.transpose(-2, -1)

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class MemEffAttention(Attention):
    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        if not XFORMERS_AVAILABLE:
            assert attn_bias is None, "xFormers is required for nested tensors usage"
            return super().forward(x)

        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)

        q, k, v = unbind(qkv, 2)

        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        x = x.reshape([B, N, C])

        x = self.proj(x)
        x = self.proj_drop(x)
        return x

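Both attention variants keep the (B, N, C) token layout; MemEffAttention only swaps the kernel and falls back to the plain Attention forward when xFormers is missing. A quick shape check on dummy data (the dimensions below are illustrative only):

import torch
from depth_anything_v2.dinov2_layers.attention import Attention

x = torch.randn(2, 197, 384)             # (batch, tokens, dim): e.g. a 14x14 patch grid plus the CLS token
attn = Attention(dim=384, num_heads=6)   # head_dim = 384 / 6 = 64, scale = 64 ** -0.5

print(attn(x).shape)                     # torch.Size([2, 197, 384])
# MemEffAttention(dim=384, num_heads=6)(x) returns the same shape; without xFormers
# it asserts attn_bias is None and runs this exact forward pass.
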
depth_anything_v2/dinov2_layers/block.py
ADDED
@@ -0,0 +1,252 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py

import logging
from typing import Callable, List, Any, Tuple, Dict

import torch
from torch import nn, Tensor

from .attention import Attention, MemEffAttention
from .drop_path import DropPath
from .layer_scale import LayerScale
from .mlp import Mlp


logger = logging.getLogger("dinov2")


try:
    from xformers.ops import fmha
    from xformers.ops import scaled_index_add, index_select_cat

    XFORMERS_AVAILABLE = True
except ImportError:
    logger.warning("xFormers not available")
    XFORMERS_AVAILABLE = False


class Block(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
    ) -> None:
        super().__init__()
        # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor) -> Tensor:
        def attn_residual_func(x: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x)))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x))
            x = x + self.drop_path1(ffn_residual_func(x))  # FIXME: drop_path2
        else:
            x = x + attn_residual_func(x)
            x = x + ffn_residual_func(x)
        return x


def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
) -> Tensor:
    # 1) extract subset using permutation
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    x_subset = x[brange]

    # 2) apply residual_func to get residual
    residual = residual_func(x_subset)

    x_flat = x.flatten(1)
    residual = residual.flatten(1)

    residual_scale_factor = b / sample_subset_size

    # 3) add the residual
    x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
    return x_plus_residual.view_as(x)


def get_branges_scales(x, sample_drop_ratio=0.0):
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    residual_scale_factor = b / sample_subset_size
    return brange, residual_scale_factor


def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    if scaling_vector is None:
        x_flat = x.flatten(1)
        residual = residual.flatten(1)
        x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
    else:
        x_plus_residual = scaled_index_add(
            x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
        )
    return x_plus_residual


attn_bias_cache: Dict[Tuple, Any] = {}


def get_attn_bias_and_cat(x_list, branges=None):
    """
    this will perform the index select, cat the tensors, and provide the attn_bias from cache
    """
    batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
    if all_shapes not in attn_bias_cache.keys():
        seqlens = []
        for b, x in zip(batch_sizes, x_list):
            for _ in range(b):
                seqlens.append(x.shape[1])
        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
        attn_bias._batch_sizes = batch_sizes
        attn_bias_cache[all_shapes] = attn_bias

    if branges is not None:
        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
    else:
        tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
        cat_tensors = torch.cat(tensors_bs1, dim=1)

    return attn_bias_cache[all_shapes], cat_tensors


def drop_add_residual_stochastic_depth_list(
    x_list: List[Tensor],
    residual_func: Callable[[Tensor, Any], Tensor],
    sample_drop_ratio: float = 0.0,
    scaling_vector=None,
) -> Tensor:
    # 1) generate random set of indices for dropping samples in the batch
    branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
    branges = [s[0] for s in branges_scales]
    residual_scale_factors = [s[1] for s in branges_scales]

    # 2) get attention bias and index+concat the tensors
    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)

    # 3) apply residual_func to get residual, and split the result
    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore

    outputs = []
    for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
        outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
    return outputs


class NestedTensorBlock(Block):
    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
        """
        x_list contains a list of tensors to nest together and run
        """
        assert isinstance(self.attn, MemEffAttention)

        if self.training and self.sample_drop_ratio > 0.0:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.attn(self.norm1(x), attn_bias=attn_bias)

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.mlp(self.norm2(x))

            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            return x_list
        else:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            attn_bias, x = get_attn_bias_and_cat(x_list)
            x = x + attn_residual_func(x, attn_bias=attn_bias)
            x = x + ffn_residual_func(x)
            return attn_bias.split(x)

    def forward(self, x_or_x_list):
        if isinstance(x_or_x_list, Tensor):
            return super().forward(x_or_x_list)
        elif isinstance(x_or_x_list, list):
            assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
            return self.forward_nested(x_or_x_list)
        else:
            raise AssertionError

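drop_add_residual_stochastic_depth above saves compute by running the residual branch on a random subset of the batch and rescaling the result by b / subset_size, so its expectation matches a full-batch residual. A small numeric sketch of that bookkeeping, independent of the transformer block (toy shapes, with the residual branch replaced by a constant):

import torch

b, n, d = 8, 4, 16
sample_drop_ratio = 0.25
x = torch.zeros(b, n, d)

# Same subset-and-rescale logic as drop_add_residual_stochastic_depth.
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)    # 6 of 8 samples run the branch
brange = torch.randperm(b)[:sample_subset_size]
residual = torch.ones(sample_subset_size, n, d)                  # stand-in for the attn/ffn output
scale = b / sample_subset_size                                   # 8 / 6

out = torch.index_add(x.flatten(1), 0, brange, residual.flatten(1), alpha=scale).view_as(x)
print(out.sum() / (b * n * d))   # 1.0: per-element expectation matches adding the residual to every sample
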
depth_anything_v2/dinov2_layers/drop_path.py
ADDED
@@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py


from torch import nn


def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        random_tensor.div_(keep_prob)
    output = x * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

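At train time the kept samples are rescaled by 1 / keep_prob so the expected activation is unchanged; in eval mode the module is an identity. A tiny sketch (the kept/dropped split is random per call, but with drop_prob=0.5 the values are 0 and 2):

import torch
from depth_anything_v2.dinov2_layers.drop_path import DropPath

dp = DropPath(drop_prob=0.5)
x = torch.ones(1000, 3, 8)

dp.train()
print(dp(x).unique())          # tensor([0., 2.]): dropped samples zeroed, kept ones scaled by 1 / 0.5

dp.eval()
print(torch.equal(dp(x), x))   # True: identity at inference time
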
depth_anything_v2/dinov2_layers/layer_scale.py
ADDED
@@ -0,0 +1,28 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110

from typing import Union

import torch
from torch import Tensor
from torch import nn


class LayerScale(nn.Module):
    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        return x.mul_(self.gamma) if self.inplace else x * self.gamma

depth_anything_v2/dinov2_layers/mlp.py
ADDED
@@ -0,0 +1,41 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py


from typing import Callable, Optional

from torch import Tensor, nn


class Mlp(nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

depth_anything_v2/dinov2_layers/patch_embed.py
ADDED
@@ -0,0 +1,89 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py

from typing import Callable, Optional, Tuple, Union

from torch import Tensor
import torch.nn as nn


def make_2tuple(x):
    if isinstance(x, tuple):
        assert len(x) == 2
        return x

    assert isinstance(x, int)
    return (x, x)


class PatchEmbed(nn.Module):
    """
    2D image to patch embedding: (B,C,H,W) -> (B,N,D)

    Args:
        img_size: Image size.
        patch_size: Patch token size.
        in_chans: Number of input image channels.
        embed_dim: Number of linear projection output channels.
        norm_layer: Normalization layer.
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()

        image_HW = make_2tuple(img_size)
        patch_HW = make_2tuple(patch_size)
        patch_grid_size = (
            image_HW[0] // patch_HW[0],
            image_HW[1] // patch_HW[1],
        )

        self.img_size = image_HW
        self.patch_size = patch_HW
        self.patches_resolution = patch_grid_size
        self.num_patches = patch_grid_size[0] * patch_grid_size[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.flatten_embedding = flatten_embedding

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        _, _, H, W = x.shape
        patch_H, patch_W = self.patch_size

        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"

        x = self.proj(x)  # B C H W
        H, W = x.size(2), x.size(3)
        x = x.flatten(2).transpose(1, 2)  # B HW C
        x = self.norm(x)
        if not self.flatten_embedding:
            x = x.reshape(-1, H, W, self.embed_dim)  # B H W C
        return x

    def flops(self) -> float:
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            flops += Ho * Wo * self.embed_dim
        return flops

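The projection is a strided Conv2d, so the token count is simply (H / patch) * (W / patch). With the 518 / 14 configuration used by the DINOv2 factory above that is 37 * 37 = 1369 tokens; a minimal shape check:

import torch
from depth_anything_v2.dinov2_layers.patch_embed import PatchEmbed

embed = PatchEmbed(img_size=518, patch_size=14, in_chans=3, embed_dim=1024)
x = torch.randn(2, 3, 518, 518)

print(embed.num_patches)   # 1369  (= 37 * 37)
print(embed(x).shape)      # torch.Size([2, 1369, 1024])
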
depth_anything_v2/dinov2_layers/swiglu_ffn.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from typing import Callable, Optional
|
8 |
+
|
9 |
+
from torch import Tensor, nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
|
12 |
+
|
13 |
+
class SwiGLUFFN(nn.Module):
|
14 |
+
def __init__(
|
15 |
+
self,
|
16 |
+
in_features: int,
|
17 |
+
hidden_features: Optional[int] = None,
|
18 |
+
out_features: Optional[int] = None,
|
19 |
+
act_layer: Callable[..., nn.Module] = None,
|
20 |
+
drop: float = 0.0,
|
21 |
+
bias: bool = True,
|
22 |
+
) -> None:
|
23 |
+
super().__init__()
|
24 |
+
out_features = out_features or in_features
|
25 |
+
hidden_features = hidden_features or in_features
|
26 |
+
self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
|
27 |
+
self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
|
28 |
+
|
29 |
+
def forward(self, x: Tensor) -> Tensor:
|
30 |
+
x12 = self.w12(x)
|
31 |
+
x1, x2 = x12.chunk(2, dim=-1)
|
32 |
+
hidden = F.silu(x1) * x2
|
33 |
+
return self.w3(hidden)
|
34 |
+
|
35 |
+
|
36 |
+
try:
|
37 |
+
from xformers.ops import SwiGLU
|
38 |
+
|
39 |
+
XFORMERS_AVAILABLE = True
|
40 |
+
except ImportError:
|
41 |
+
SwiGLU = SwiGLUFFN
|
42 |
+
XFORMERS_AVAILABLE = False
|
43 |
+
|
44 |
+
|
45 |
+
class SwiGLUFFNFused(SwiGLU):
|
46 |
+
def __init__(
|
47 |
+
self,
|
48 |
+
in_features: int,
|
49 |
+
hidden_features: Optional[int] = None,
|
50 |
+
out_features: Optional[int] = None,
|
51 |
+
act_layer: Callable[..., nn.Module] = None,
|
52 |
+
drop: float = 0.0,
|
53 |
+
bias: bool = True,
|
54 |
+
) -> None:
|
55 |
+
out_features = out_features or in_features
|
56 |
+
hidden_features = hidden_features or in_features
|
57 |
+
hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
|
58 |
+
super().__init__(
|
59 |
+
in_features=in_features,
|
60 |
+
hidden_features=hidden_features,
|
61 |
+
out_features=out_features,
|
62 |
+
bias=bias,
|
63 |
+
)
|
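A short sketch of how these FFN variants are called (illustrative, not part of the upload). SwiGLUFFN computes silu(x·W1) * (x·W2) from the single fused w12 projection; SwiGLUFFNFused additionally shrinks the hidden width to roughly 2/3, rounded up to a multiple of 8, so the three-matrix SwiGLU keeps a parameter count comparable to a plain two-matrix MLP of the requested hidden size:

import torch
from depth_anything_v2.dinov2_layers.swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused

x = torch.randn(2, 196, 768)
ffn = SwiGLUFFN(in_features=768, hidden_features=3072)
fused = SwiGLUFFNFused(in_features=768, hidden_features=3072)  # hidden width becomes 2048
print(ffn(x).shape, fused(x).shape)                            # both (2, 196, 768)

When xformers is installed, SwiGLUFFNFused inherits the fused xformers.ops.SwiGLU implementation; otherwise it falls back to the pure-PyTorch SwiGLUFFN above.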
depth_anything_v2/dpt.py
ADDED
@@ -0,0 +1,221 @@
1 |
+
import cv2
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torchvision.transforms import Compose
|
6 |
+
|
7 |
+
from .dinov2 import DINOv2
|
8 |
+
from .util.blocks import FeatureFusionBlock, _make_scratch
|
9 |
+
from .util.transform import Resize, NormalizeImage, PrepareForNet
|
10 |
+
|
11 |
+
|
12 |
+
def _make_fusion_block(features, use_bn, size=None):
|
13 |
+
return FeatureFusionBlock(
|
14 |
+
features,
|
15 |
+
nn.ReLU(False),
|
16 |
+
deconv=False,
|
17 |
+
bn=use_bn,
|
18 |
+
expand=False,
|
19 |
+
align_corners=True,
|
20 |
+
size=size,
|
21 |
+
)
|
22 |
+
|
23 |
+
|
24 |
+
class ConvBlock(nn.Module):
|
25 |
+
def __init__(self, in_feature, out_feature):
|
26 |
+
super().__init__()
|
27 |
+
|
28 |
+
self.conv_block = nn.Sequential(
|
29 |
+
nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=1, padding=1),
|
30 |
+
nn.BatchNorm2d(out_feature),
|
31 |
+
nn.ReLU(True)
|
32 |
+
)
|
33 |
+
|
34 |
+
def forward(self, x):
|
35 |
+
return self.conv_block(x)
|
36 |
+
|
37 |
+
|
38 |
+
class DPTHead(nn.Module):
|
39 |
+
def __init__(
|
40 |
+
self,
|
41 |
+
in_channels,
|
42 |
+
features=256,
|
43 |
+
use_bn=False,
|
44 |
+
out_channels=[256, 512, 1024, 1024],
|
45 |
+
use_clstoken=False
|
46 |
+
):
|
47 |
+
super(DPTHead, self).__init__()
|
48 |
+
|
49 |
+
self.use_clstoken = use_clstoken
|
50 |
+
|
51 |
+
self.projects = nn.ModuleList([
|
52 |
+
nn.Conv2d(
|
53 |
+
in_channels=in_channels,
|
54 |
+
out_channels=out_channel,
|
55 |
+
kernel_size=1,
|
56 |
+
stride=1,
|
57 |
+
padding=0,
|
58 |
+
) for out_channel in out_channels
|
59 |
+
])
|
60 |
+
|
61 |
+
self.resize_layers = nn.ModuleList([
|
62 |
+
nn.ConvTranspose2d(
|
63 |
+
in_channels=out_channels[0],
|
64 |
+
out_channels=out_channels[0],
|
65 |
+
kernel_size=4,
|
66 |
+
stride=4,
|
67 |
+
padding=0),
|
68 |
+
nn.ConvTranspose2d(
|
69 |
+
in_channels=out_channels[1],
|
70 |
+
out_channels=out_channels[1],
|
71 |
+
kernel_size=2,
|
72 |
+
stride=2,
|
73 |
+
padding=0),
|
74 |
+
nn.Identity(),
|
75 |
+
nn.Conv2d(
|
76 |
+
in_channels=out_channels[3],
|
77 |
+
out_channels=out_channels[3],
|
78 |
+
kernel_size=3,
|
79 |
+
stride=2,
|
80 |
+
padding=1)
|
81 |
+
])
|
82 |
+
|
83 |
+
if use_clstoken:
|
84 |
+
self.readout_projects = nn.ModuleList()
|
85 |
+
for _ in range(len(self.projects)):
|
86 |
+
self.readout_projects.append(
|
87 |
+
nn.Sequential(
|
88 |
+
nn.Linear(2 * in_channels, in_channels),
|
89 |
+
nn.GELU()))
|
90 |
+
|
91 |
+
self.scratch = _make_scratch(
|
92 |
+
out_channels,
|
93 |
+
features,
|
94 |
+
groups=1,
|
95 |
+
expand=False,
|
96 |
+
)
|
97 |
+
|
98 |
+
self.scratch.stem_transpose = None
|
99 |
+
|
100 |
+
self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
|
101 |
+
self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
|
102 |
+
self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
|
103 |
+
self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
|
104 |
+
|
105 |
+
head_features_1 = features
|
106 |
+
head_features_2 = 32
|
107 |
+
|
108 |
+
self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1)
|
109 |
+
self.scratch.output_conv2 = nn.Sequential(
|
110 |
+
nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
|
111 |
+
nn.ReLU(True),
|
112 |
+
nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
|
113 |
+
nn.ReLU(True),
|
114 |
+
nn.Identity(),
|
115 |
+
)
|
116 |
+
|
117 |
+
def forward(self, out_features, patch_h, patch_w):
|
118 |
+
out = []
|
119 |
+
for i, x in enumerate(out_features):
|
120 |
+
if self.use_clstoken:
|
121 |
+
x, cls_token = x[0], x[1]
|
122 |
+
readout = cls_token.unsqueeze(1).expand_as(x)
|
123 |
+
x = self.readout_projects[i](torch.cat((x, readout), -1))
|
124 |
+
else:
|
125 |
+
x = x[0]
|
126 |
+
|
127 |
+
x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
|
128 |
+
|
129 |
+
x = self.projects[i](x)
|
130 |
+
x = self.resize_layers[i](x)
|
131 |
+
|
132 |
+
out.append(x)
|
133 |
+
|
134 |
+
layer_1, layer_2, layer_3, layer_4 = out
|
135 |
+
|
136 |
+
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
137 |
+
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
138 |
+
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
139 |
+
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
140 |
+
|
141 |
+
path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
|
142 |
+
path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
|
143 |
+
path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
|
144 |
+
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
145 |
+
|
146 |
+
out = self.scratch.output_conv1(path_1)
|
147 |
+
out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
|
148 |
+
out = self.scratch.output_conv2(out)
|
149 |
+
|
150 |
+
return out
|
151 |
+
|
152 |
+
|
153 |
+
class DepthAnythingV2(nn.Module):
|
154 |
+
def __init__(
|
155 |
+
self,
|
156 |
+
encoder='vitl',
|
157 |
+
features=256,
|
158 |
+
out_channels=[256, 512, 1024, 1024],
|
159 |
+
use_bn=False,
|
160 |
+
use_clstoken=False
|
161 |
+
):
|
162 |
+
super(DepthAnythingV2, self).__init__()
|
163 |
+
|
164 |
+
self.intermediate_layer_idx = {
|
165 |
+
'vits': [2, 5, 8, 11],
|
166 |
+
'vitb': [2, 5, 8, 11],
|
167 |
+
'vitl': [4, 11, 17, 23],
|
168 |
+
'vitg': [9, 19, 29, 39]
|
169 |
+
}
|
170 |
+
|
171 |
+
self.encoder = encoder
|
172 |
+
self.pretrained = DINOv2(model_name=encoder)
|
173 |
+
|
174 |
+
self.depth_head = DPTHead(self.pretrained.embed_dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken)
|
175 |
+
|
176 |
+
def forward(self, x):
|
177 |
+
patch_h, patch_w = x.shape[-2] // 14, x.shape[-1] // 14
|
178 |
+
|
179 |
+
features = self.pretrained.get_intermediate_layers(x, self.intermediate_layer_idx[self.encoder], return_class_token=True)
|
180 |
+
|
181 |
+
depth = self.depth_head(features, patch_h, patch_w)
|
182 |
+
depth = F.relu(depth)
|
183 |
+
|
184 |
+
return depth.squeeze(1)
|
185 |
+
|
186 |
+
@torch.no_grad()
|
187 |
+
def infer_image(self, raw_image, input_size=518):
|
188 |
+
image, (h, w) = self.image2tensor(raw_image, input_size)
|
189 |
+
|
190 |
+
depth = self.forward(image)
|
191 |
+
|
192 |
+
depth = F.interpolate(depth[:, None], (h, w), mode="bilinear", align_corners=True)[0, 0]
|
193 |
+
|
194 |
+
return depth.cpu().numpy()
|
195 |
+
|
196 |
+
def image2tensor(self, raw_image, input_size=518):
|
197 |
+
transform = Compose([
|
198 |
+
Resize(
|
199 |
+
width=input_size,
|
200 |
+
height=input_size,
|
201 |
+
resize_target=False,
|
202 |
+
keep_aspect_ratio=True,
|
203 |
+
ensure_multiple_of=14,
|
204 |
+
resize_method='lower_bound',
|
205 |
+
image_interpolation_method=cv2.INTER_CUBIC,
|
206 |
+
),
|
207 |
+
NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
208 |
+
PrepareForNet(),
|
209 |
+
])
|
210 |
+
|
211 |
+
h, w = raw_image.shape[:2]
|
212 |
+
|
213 |
+
image = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB) / 255.0
|
214 |
+
|
215 |
+
image = transform({'image': image})['image']
|
216 |
+
image = torch.from_numpy(image).unsqueeze(0)
|
217 |
+
|
218 |
+
DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
|
219 |
+
image = image.to(DEVICE)
|
220 |
+
|
221 |
+
return image, (h, w)
|
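End-to-end inference with the model above, as a hedged sketch (the checkpoint path and image name are placeholders; infer_image() expects a BGR uint8 array straight from cv2.imread and moves it to CUDA/MPS when available, so the model should live on the same device):

import cv2
import torch
from depth_anything_v2.dpt import DepthAnythingV2

device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
model = DepthAnythingV2(encoder='vitl', features=256, out_channels=[256, 512, 1024, 1024])
model.load_state_dict(torch.load('checkpoints/depth_anything_v2_vitl.pth', map_location='cpu'))  # placeholder path
model = model.to(device).eval()

raw = cv2.imread('example.jpg')                   # placeholder image, BGR uint8
depth = model.infer_image(raw, input_size=518)    # numpy float32, shape (H, W), relative depth
print(depth.shape, float(depth.min()), float(depth.max()))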
depth_anything_v2/util/blocks.py
ADDED
@@ -0,0 +1,148 @@
1 |
+
import torch.nn as nn
|
2 |
+
|
3 |
+
|
4 |
+
def _make_scratch(in_shape, out_shape, groups=1, expand=False):
|
5 |
+
scratch = nn.Module()
|
6 |
+
|
7 |
+
out_shape1 = out_shape
|
8 |
+
out_shape2 = out_shape
|
9 |
+
out_shape3 = out_shape
|
10 |
+
if len(in_shape) >= 4:
|
11 |
+
out_shape4 = out_shape
|
12 |
+
|
13 |
+
if expand:
|
14 |
+
out_shape1 = out_shape
|
15 |
+
out_shape2 = out_shape * 2
|
16 |
+
out_shape3 = out_shape * 4
|
17 |
+
if len(in_shape) >= 4:
|
18 |
+
out_shape4 = out_shape * 8
|
19 |
+
|
20 |
+
scratch.layer1_rn = nn.Conv2d(in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
|
21 |
+
scratch.layer2_rn = nn.Conv2d(in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
|
22 |
+
scratch.layer3_rn = nn.Conv2d(in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
|
23 |
+
if len(in_shape) >= 4:
|
24 |
+
scratch.layer4_rn = nn.Conv2d(in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
|
25 |
+
|
26 |
+
return scratch
|
27 |
+
|
28 |
+
|
29 |
+
class ResidualConvUnit(nn.Module):
|
30 |
+
"""Residual convolution module.
|
31 |
+
"""
|
32 |
+
|
33 |
+
def __init__(self, features, activation, bn):
|
34 |
+
"""Init.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
features (int): number of features
|
38 |
+
"""
|
39 |
+
super().__init__()
|
40 |
+
|
41 |
+
self.bn = bn
|
42 |
+
|
43 |
+
self.groups=1
|
44 |
+
|
45 |
+
self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
|
46 |
+
|
47 |
+
self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
|
48 |
+
|
49 |
+
if self.bn == True:
|
50 |
+
self.bn1 = nn.BatchNorm2d(features)
|
51 |
+
self.bn2 = nn.BatchNorm2d(features)
|
52 |
+
|
53 |
+
self.activation = activation
|
54 |
+
|
55 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
56 |
+
|
57 |
+
def forward(self, x):
|
58 |
+
"""Forward pass.
|
59 |
+
|
60 |
+
Args:
|
61 |
+
x (tensor): input
|
62 |
+
|
63 |
+
Returns:
|
64 |
+
tensor: output
|
65 |
+
"""
|
66 |
+
|
67 |
+
out = self.activation(x)
|
68 |
+
out = self.conv1(out)
|
69 |
+
if self.bn == True:
|
70 |
+
out = self.bn1(out)
|
71 |
+
|
72 |
+
out = self.activation(out)
|
73 |
+
out = self.conv2(out)
|
74 |
+
if self.bn == True:
|
75 |
+
out = self.bn2(out)
|
76 |
+
|
77 |
+
if self.groups > 1:
|
78 |
+
out = self.conv_merge(out)
|
79 |
+
|
80 |
+
return self.skip_add.add(out, x)
|
81 |
+
|
82 |
+
|
83 |
+
class FeatureFusionBlock(nn.Module):
|
84 |
+
"""Feature fusion block.
|
85 |
+
"""
|
86 |
+
|
87 |
+
def __init__(
|
88 |
+
self,
|
89 |
+
features,
|
90 |
+
activation,
|
91 |
+
deconv=False,
|
92 |
+
bn=False,
|
93 |
+
expand=False,
|
94 |
+
align_corners=True,
|
95 |
+
size=None
|
96 |
+
):
|
97 |
+
"""Init.
|
98 |
+
|
99 |
+
Args:
|
100 |
+
features (int): number of features
|
101 |
+
"""
|
102 |
+
super(FeatureFusionBlock, self).__init__()
|
103 |
+
|
104 |
+
self.deconv = deconv
|
105 |
+
self.align_corners = align_corners
|
106 |
+
|
107 |
+
self.groups=1
|
108 |
+
|
109 |
+
self.expand = expand
|
110 |
+
out_features = features
|
111 |
+
if self.expand == True:
|
112 |
+
out_features = features // 2
|
113 |
+
|
114 |
+
self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
|
115 |
+
|
116 |
+
self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
|
117 |
+
self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
|
118 |
+
|
119 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
120 |
+
|
121 |
+
self.size=size
|
122 |
+
|
123 |
+
def forward(self, *xs, size=None):
|
124 |
+
"""Forward pass.
|
125 |
+
|
126 |
+
Returns:
|
127 |
+
tensor: output
|
128 |
+
"""
|
129 |
+
output = xs[0]
|
130 |
+
|
131 |
+
if len(xs) == 2:
|
132 |
+
res = self.resConfUnit1(xs[1])
|
133 |
+
output = self.skip_add.add(output, res)
|
134 |
+
|
135 |
+
output = self.resConfUnit2(output)
|
136 |
+
|
137 |
+
if (size is None) and (self.size is None):
|
138 |
+
modifier = {"scale_factor": 2}
|
139 |
+
elif size is None:
|
140 |
+
modifier = {"size": self.size}
|
141 |
+
else:
|
142 |
+
modifier = {"size": size}
|
143 |
+
|
144 |
+
output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
|
145 |
+
|
146 |
+
output = self.out_conv(output)
|
147 |
+
|
148 |
+
return output
|
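To make the wiring in DPTHead easier to follow, here is a small sketch (illustrative only; the feature-map shapes are hypothetical) of how _make_scratch and FeatureFusionBlock cooperate: the scratch convs project backbone features to a common width, then each fusion block adds the next skip connection, refines it, and upsamples:

import torch
import torch.nn as nn
from depth_anything_v2.util.blocks import FeatureFusionBlock, _make_scratch

scratch = _make_scratch([256, 512, 1024, 1024], 256, groups=1, expand=False)
refine4 = FeatureFusionBlock(256, nn.ReLU(False), bn=False, align_corners=True)
refine3 = FeatureFusionBlock(256, nn.ReLU(False), bn=False, align_corners=True)

deep = scratch.layer4_rn(torch.randn(1, 1024, 16, 16))   # coarsest feature map
skip = scratch.layer3_rn(torch.randn(1, 1024, 32, 32))   # next-finer feature map

path4 = refine4(deep, size=skip.shape[2:])   # refine, then resize to the skip's resolution
path3 = refine3(path4, skip)                 # add the skip, refine, upsample 2x
print(path4.shape, path3.shape)              # (1, 256, 32, 32) and (1, 256, 64, 64)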
depth_anything_v2/util/transform.py
ADDED
@@ -0,0 +1,158 @@
1 |
+
import numpy as np
|
2 |
+
import cv2
|
3 |
+
|
4 |
+
|
5 |
+
class Resize(object):
|
6 |
+
"""Resize sample to given size (width, height).
|
7 |
+
"""
|
8 |
+
|
9 |
+
def __init__(
|
10 |
+
self,
|
11 |
+
width,
|
12 |
+
height,
|
13 |
+
resize_target=True,
|
14 |
+
keep_aspect_ratio=False,
|
15 |
+
ensure_multiple_of=1,
|
16 |
+
resize_method="lower_bound",
|
17 |
+
image_interpolation_method=cv2.INTER_AREA,
|
18 |
+
):
|
19 |
+
"""Init.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
width (int): desired output width
|
23 |
+
height (int): desired output height
|
24 |
+
resize_target (bool, optional):
|
25 |
+
True: Resize the full sample (image, mask, target).
|
26 |
+
False: Resize image only.
|
27 |
+
Defaults to True.
|
28 |
+
keep_aspect_ratio (bool, optional):
|
29 |
+
True: Keep the aspect ratio of the input sample.
|
30 |
+
Output sample might not have the given width and height, and
|
31 |
+
resize behaviour depends on the parameter 'resize_method'.
|
32 |
+
Defaults to False.
|
33 |
+
ensure_multiple_of (int, optional):
|
34 |
+
Output width and height is constrained to be multiple of this parameter.
|
35 |
+
Defaults to 1.
|
36 |
+
resize_method (str, optional):
|
37 |
+
"lower_bound": Output will be at least as large as the given size.
|
38 |
+
"upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
|
39 |
+
"minimal": Scale as least as possible. (Output size might be smaller than given size.)
|
40 |
+
Defaults to "lower_bound".
|
41 |
+
"""
|
42 |
+
self.__width = width
|
43 |
+
self.__height = height
|
44 |
+
|
45 |
+
self.__resize_target = resize_target
|
46 |
+
self.__keep_aspect_ratio = keep_aspect_ratio
|
47 |
+
self.__multiple_of = ensure_multiple_of
|
48 |
+
self.__resize_method = resize_method
|
49 |
+
self.__image_interpolation_method = image_interpolation_method
|
50 |
+
|
51 |
+
def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
|
52 |
+
y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
53 |
+
|
54 |
+
if max_val is not None and y > max_val:
|
55 |
+
y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
56 |
+
|
57 |
+
if y < min_val:
|
58 |
+
y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
59 |
+
|
60 |
+
return y
|
61 |
+
|
62 |
+
def get_size(self, width, height):
|
63 |
+
# determine new height and width
|
64 |
+
scale_height = self.__height / height
|
65 |
+
scale_width = self.__width / width
|
66 |
+
|
67 |
+
if self.__keep_aspect_ratio:
|
68 |
+
if self.__resize_method == "lower_bound":
|
69 |
+
# scale such that output size is lower bound
|
70 |
+
if scale_width > scale_height:
|
71 |
+
# fit width
|
72 |
+
scale_height = scale_width
|
73 |
+
else:
|
74 |
+
# fit height
|
75 |
+
scale_width = scale_height
|
76 |
+
elif self.__resize_method == "upper_bound":
|
77 |
+
# scale such that output size is upper bound
|
78 |
+
if scale_width < scale_height:
|
79 |
+
# fit width
|
80 |
+
scale_height = scale_width
|
81 |
+
else:
|
82 |
+
# fit height
|
83 |
+
scale_width = scale_height
|
84 |
+
elif self.__resize_method == "minimal":
|
85 |
+
# scale as little as possible
|
86 |
+
if abs(1 - scale_width) < abs(1 - scale_height):
|
87 |
+
# fit width
|
88 |
+
scale_height = scale_width
|
89 |
+
else:
|
90 |
+
# fit height
|
91 |
+
scale_width = scale_height
|
92 |
+
else:
|
93 |
+
raise ValueError(f"resize_method {self.__resize_method} not implemented")
|
94 |
+
|
95 |
+
if self.__resize_method == "lower_bound":
|
96 |
+
new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height)
|
97 |
+
new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width)
|
98 |
+
elif self.__resize_method == "upper_bound":
|
99 |
+
new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height)
|
100 |
+
new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width)
|
101 |
+
elif self.__resize_method == "minimal":
|
102 |
+
new_height = self.constrain_to_multiple_of(scale_height * height)
|
103 |
+
new_width = self.constrain_to_multiple_of(scale_width * width)
|
104 |
+
else:
|
105 |
+
raise ValueError(f"resize_method {self.__resize_method} not implemented")
|
106 |
+
|
107 |
+
return (new_width, new_height)
|
108 |
+
|
109 |
+
def __call__(self, sample):
|
110 |
+
width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0])
|
111 |
+
|
112 |
+
# resize sample
|
113 |
+
sample["image"] = cv2.resize(sample["image"], (width, height), interpolation=self.__image_interpolation_method)
|
114 |
+
|
115 |
+
if self.__resize_target:
|
116 |
+
if "depth" in sample:
|
117 |
+
sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST)
|
118 |
+
|
119 |
+
if "mask" in sample:
|
120 |
+
sample["mask"] = cv2.resize(sample["mask"].astype(np.float32), (width, height), interpolation=cv2.INTER_NEAREST)
|
121 |
+
|
122 |
+
return sample
|
123 |
+
|
124 |
+
|
125 |
+
class NormalizeImage(object):
|
126 |
+
"""Normlize image by given mean and std.
|
127 |
+
"""
|
128 |
+
|
129 |
+
def __init__(self, mean, std):
|
130 |
+
self.__mean = mean
|
131 |
+
self.__std = std
|
132 |
+
|
133 |
+
def __call__(self, sample):
|
134 |
+
sample["image"] = (sample["image"] - self.__mean) / self.__std
|
135 |
+
|
136 |
+
return sample
|
137 |
+
|
138 |
+
|
139 |
+
class PrepareForNet(object):
|
140 |
+
"""Prepare sample for usage as network input.
|
141 |
+
"""
|
142 |
+
|
143 |
+
def __init__(self):
|
144 |
+
pass
|
145 |
+
|
146 |
+
def __call__(self, sample):
|
147 |
+
image = np.transpose(sample["image"], (2, 0, 1))
|
148 |
+
sample["image"] = np.ascontiguousarray(image).astype(np.float32)
|
149 |
+
|
150 |
+
if "depth" in sample:
|
151 |
+
depth = sample["depth"].astype(np.float32)
|
152 |
+
sample["depth"] = np.ascontiguousarray(depth)
|
153 |
+
|
154 |
+
if "mask" in sample:
|
155 |
+
sample["mask"] = sample["mask"].astype(np.float32)
|
156 |
+
sample["mask"] = np.ascontiguousarray(sample["mask"])
|
157 |
+
|
158 |
+
return sample
|
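The three transforms are meant to be chained. A sketch mirroring DepthAnythingV2.image2tensor() (the image path is a placeholder): keep the aspect ratio, make both sides at least 518 and multiples of 14 (the ViT patch size), then apply ImageNet normalization and CHW layout:

import cv2
import torch
from torchvision.transforms import Compose
from depth_anything_v2.util.transform import Resize, NormalizeImage, PrepareForNet

transform = Compose([
    Resize(width=518, height=518, resize_target=False, keep_aspect_ratio=True,
           ensure_multiple_of=14, resize_method='lower_bound',
           image_interpolation_method=cv2.INTER_CUBIC),
    NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    PrepareForNet(),
])

raw = cv2.imread('example.jpg')                        # placeholder input, BGR uint8
image = cv2.cvtColor(raw, cv2.COLOR_BGR2RGB) / 255.0   # float RGB in [0, 1]
tensor = torch.from_numpy(transform({'image': image})['image']).unsqueeze(0)
print(tensor.shape)                                    # (1, 3, H', W'), H' and W' multiples of 14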
flux-architecture.svg
ADDED
flux/activations.py
ADDED
@@ -0,0 +1,165 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 HuggingFace Inc.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn.functional as F
|
18 |
+
from torch import nn
|
19 |
+
|
20 |
+
from diffusers.utils import deprecate
|
21 |
+
from diffusers.utils.import_utils import is_torch_npu_available
|
22 |
+
|
23 |
+
|
24 |
+
if is_torch_npu_available():
|
25 |
+
import torch_npu
|
26 |
+
|
27 |
+
ACTIVATION_FUNCTIONS = {
|
28 |
+
"swish": nn.SiLU(),
|
29 |
+
"silu": nn.SiLU(),
|
30 |
+
"mish": nn.Mish(),
|
31 |
+
"gelu": nn.GELU(),
|
32 |
+
"relu": nn.ReLU(),
|
33 |
+
}
|
34 |
+
|
35 |
+
|
36 |
+
def get_activation(act_fn: str) -> nn.Module:
|
37 |
+
"""Helper function to get activation function from string.
|
38 |
+
|
39 |
+
Args:
|
40 |
+
act_fn (str): Name of activation function.
|
41 |
+
|
42 |
+
Returns:
|
43 |
+
nn.Module: Activation function.
|
44 |
+
"""
|
45 |
+
|
46 |
+
act_fn = act_fn.lower()
|
47 |
+
if act_fn in ACTIVATION_FUNCTIONS:
|
48 |
+
return ACTIVATION_FUNCTIONS[act_fn]
|
49 |
+
else:
|
50 |
+
raise ValueError(f"Unsupported activation function: {act_fn}")
|
51 |
+
|
52 |
+
|
53 |
+
class FP32SiLU(nn.Module):
|
54 |
+
r"""
|
55 |
+
SiLU activation function with input upcasted to torch.float32.
|
56 |
+
"""
|
57 |
+
|
58 |
+
def __init__(self):
|
59 |
+
super().__init__()
|
60 |
+
|
61 |
+
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
62 |
+
return F.silu(inputs.float(), inplace=False).to(inputs.dtype)
|
63 |
+
|
64 |
+
|
65 |
+
class GELU(nn.Module):
|
66 |
+
r"""
|
67 |
+
GELU activation function with tanh approximation support with `approximate="tanh"`.
|
68 |
+
|
69 |
+
Parameters:
|
70 |
+
dim_in (`int`): The number of channels in the input.
|
71 |
+
dim_out (`int`): The number of channels in the output.
|
72 |
+
approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
|
73 |
+
bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
|
74 |
+
"""
|
75 |
+
|
76 |
+
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True):
|
77 |
+
super().__init__()
|
78 |
+
self.proj = nn.Linear(dim_in, dim_out, bias=bias)
|
79 |
+
self.approximate = approximate
|
80 |
+
|
81 |
+
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
|
82 |
+
if gate.device.type != "mps":
|
83 |
+
return F.gelu(gate, approximate=self.approximate)
|
84 |
+
# mps: gelu is not implemented for float16
|
85 |
+
return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
|
86 |
+
|
87 |
+
def forward(self, hidden_states):
|
88 |
+
hidden_states = self.proj(hidden_states)
|
89 |
+
hidden_states = self.gelu(hidden_states)
|
90 |
+
return hidden_states
|
91 |
+
|
92 |
+
|
93 |
+
class GEGLU(nn.Module):
|
94 |
+
r"""
|
95 |
+
A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function.
|
96 |
+
|
97 |
+
Parameters:
|
98 |
+
dim_in (`int`): The number of channels in the input.
|
99 |
+
dim_out (`int`): The number of channels in the output.
|
100 |
+
bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
|
101 |
+
"""
|
102 |
+
|
103 |
+
def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
|
104 |
+
super().__init__()
|
105 |
+
self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
|
106 |
+
|
107 |
+
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
|
108 |
+
if gate.device.type != "mps":
|
109 |
+
return F.gelu(gate)
|
110 |
+
# mps: gelu is not implemented for float16
|
111 |
+
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
|
112 |
+
|
113 |
+
def forward(self, hidden_states, *args, **kwargs):
|
114 |
+
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
115 |
+
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
116 |
+
deprecate("scale", "1.0.0", deprecation_message)
|
117 |
+
hidden_states = self.proj(hidden_states)
|
118 |
+
if is_torch_npu_available():
|
119 |
+
# using torch_npu.npu_geglu can run faster and save memory on NPU.
|
120 |
+
return torch_npu.npu_geglu(hidden_states, dim=-1, approximate=1)[0]
|
121 |
+
else:
|
122 |
+
hidden_states, gate = hidden_states.chunk(2, dim=-1)
|
123 |
+
return hidden_states * self.gelu(gate)
|
124 |
+
|
125 |
+
|
126 |
+
class SwiGLU(nn.Module):
|
127 |
+
r"""
|
128 |
+
A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function. It's similar to `GEGLU`
|
129 |
+
but uses SiLU / Swish instead of GeLU.
|
130 |
+
|
131 |
+
Parameters:
|
132 |
+
dim_in (`int`): The number of channels in the input.
|
133 |
+
dim_out (`int`): The number of channels in the output.
|
134 |
+
bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
|
135 |
+
"""
|
136 |
+
|
137 |
+
def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
|
138 |
+
super().__init__()
|
139 |
+
self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
|
140 |
+
self.activation = nn.SiLU()
|
141 |
+
|
142 |
+
def forward(self, hidden_states):
|
143 |
+
hidden_states = self.proj(hidden_states)
|
144 |
+
hidden_states, gate = hidden_states.chunk(2, dim=-1)
|
145 |
+
return hidden_states * self.activation(gate)
|
146 |
+
|
147 |
+
|
148 |
+
class ApproximateGELU(nn.Module):
|
149 |
+
r"""
|
150 |
+
The approximate form of the Gaussian Error Linear Unit (GELU). For more details, see section 2 of this
|
151 |
+
[paper](https://arxiv.org/abs/1606.08415).
|
152 |
+
|
153 |
+
Parameters:
|
154 |
+
dim_in (`int`): The number of channels in the input.
|
155 |
+
dim_out (`int`): The number of channels in the output.
|
156 |
+
bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
|
157 |
+
"""
|
158 |
+
|
159 |
+
def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
|
160 |
+
super().__init__()
|
161 |
+
self.proj = nn.Linear(dim_in, dim_out, bias=bias)
|
162 |
+
|
163 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
164 |
+
x = self.proj(x)
|
165 |
+
return x * torch.sigmoid(1.702 * x)
|
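A quick sketch of the two usage patterns in this module (illustrative, not part of the upload; it assumes the repo's diffusers dependency is installed so flux.activations imports cleanly): get_activation() is a plain name-to-module lookup, while the gated variants (GEGLU / SwiGLU) project to twice the output width and multiply a value branch by an activated gate branch:

import torch
from flux.activations import get_activation, GEGLU, SwiGLU

silu = get_activation("silu")                  # returns nn.SiLU(); unknown names raise ValueError
print(silu(torch.tensor([-1.0, 0.0, 1.0])))

x = torch.randn(2, 16, 768)
geglu = GEGLU(dim_in=768, dim_out=3072)        # internally projects to 2 * 3072, then splits
swiglu = SwiGLU(dim_in=768, dim_out=3072)
print(geglu(x).shape, swiglu(x).shape)         # both (2, 16, 3072)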
flux/attention.py
ADDED
@@ -0,0 +1,843 @@
1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from typing import Any, Dict, Optional
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn.functional as F
|
18 |
+
from torch import nn
|
19 |
+
|
20 |
+
from diffusers.utils import deprecate, logging
|
21 |
+
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
22 |
+
from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU
|
23 |
+
from .attention_processor import Attention, JointAttnProcessor2_0
|
24 |
+
from .embeddings import SinusoidalPositionalEmbedding
|
25 |
+
from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
|
26 |
+
|
27 |
+
|
28 |
+
logger = logging.get_logger(__name__)
|
29 |
+
|
30 |
+
|
31 |
+
def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
|
32 |
+
# "feed_forward_chunk_size" can be used to save memory
|
33 |
+
if hidden_states.shape[chunk_dim] % chunk_size != 0:
|
34 |
+
raise ValueError(
|
35 |
+
f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
|
36 |
+
)
|
37 |
+
|
38 |
+
num_chunks = hidden_states.shape[chunk_dim] // chunk_size
|
39 |
+
ff_output = torch.cat(
|
40 |
+
[ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
|
41 |
+
dim=chunk_dim,
|
42 |
+
)
|
43 |
+
return ff_output
|
44 |
+
|
45 |
+
|
46 |
+
@maybe_allow_in_graph
|
47 |
+
class GatedSelfAttentionDense(nn.Module):
|
48 |
+
r"""
|
49 |
+
A gated self-attention dense layer that combines visual features and object features.
|
50 |
+
|
51 |
+
Parameters:
|
52 |
+
query_dim (`int`): The number of channels in the query.
|
53 |
+
context_dim (`int`): The number of channels in the context.
|
54 |
+
n_heads (`int`): The number of heads to use for attention.
|
55 |
+
d_head (`int`): The number of channels in each head.
|
56 |
+
"""
|
57 |
+
|
58 |
+
def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
|
59 |
+
super().__init__()
|
60 |
+
|
61 |
+
# we need a linear projection since we need cat visual feature and obj feature
|
62 |
+
self.linear = nn.Linear(context_dim, query_dim)
|
63 |
+
|
64 |
+
self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
|
65 |
+
self.ff = FeedForward(query_dim, activation_fn="geglu")
|
66 |
+
|
67 |
+
self.norm1 = nn.LayerNorm(query_dim)
|
68 |
+
self.norm2 = nn.LayerNorm(query_dim)
|
69 |
+
|
70 |
+
self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
|
71 |
+
self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
|
72 |
+
|
73 |
+
self.enabled = True
|
74 |
+
|
75 |
+
def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
|
76 |
+
if not self.enabled:
|
77 |
+
return x
|
78 |
+
|
79 |
+
n_visual = x.shape[1]
|
80 |
+
objs = self.linear(objs)
|
81 |
+
|
82 |
+
x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
|
83 |
+
x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
|
84 |
+
|
85 |
+
return x
|
86 |
+
|
87 |
+
|
88 |
+
@maybe_allow_in_graph
|
89 |
+
class JointTransformerBlock(nn.Module):
|
90 |
+
r"""
|
91 |
+
A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
|
92 |
+
|
93 |
+
Reference: https://arxiv.org/abs/2403.03206
|
94 |
+
|
95 |
+
Parameters:
|
96 |
+
dim (`int`): The number of channels in the input and output.
|
97 |
+
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
98 |
+
attention_head_dim (`int`): The number of channels in each head.
|
99 |
+
context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
|
100 |
+
processing of `context` conditions.
|
101 |
+
"""
|
102 |
+
|
103 |
+
def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_only=False):
|
104 |
+
super().__init__()
|
105 |
+
|
106 |
+
self.context_pre_only = context_pre_only
|
107 |
+
context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"
|
108 |
+
|
109 |
+
self.norm1 = AdaLayerNormZero(dim)
|
110 |
+
|
111 |
+
if context_norm_type == "ada_norm_continous":
|
112 |
+
self.norm1_context = AdaLayerNormContinuous(
|
113 |
+
dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm"
|
114 |
+
)
|
115 |
+
elif context_norm_type == "ada_norm_zero":
|
116 |
+
self.norm1_context = AdaLayerNormZero(dim)
|
117 |
+
else:
|
118 |
+
raise ValueError(
|
119 |
+
f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
|
120 |
+
)
|
121 |
+
if hasattr(F, "scaled_dot_product_attention"):
|
122 |
+
processor = JointAttnProcessor2_0()
|
123 |
+
else:
|
124 |
+
raise ValueError(
|
125 |
+
"The current PyTorch version does not support the `scaled_dot_product_attention` function."
|
126 |
+
)
|
127 |
+
self.attn = Attention(
|
128 |
+
query_dim=dim,
|
129 |
+
cross_attention_dim=None,
|
130 |
+
added_kv_proj_dim=dim,
|
131 |
+
dim_head=attention_head_dim,
|
132 |
+
heads=num_attention_heads,
|
133 |
+
out_dim=dim,
|
134 |
+
context_pre_only=context_pre_only,
|
135 |
+
bias=True,
|
136 |
+
processor=processor,
|
137 |
+
)
|
138 |
+
|
139 |
+
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
140 |
+
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
141 |
+
|
142 |
+
if not context_pre_only:
|
143 |
+
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
144 |
+
self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
145 |
+
else:
|
146 |
+
self.norm2_context = None
|
147 |
+
self.ff_context = None
|
148 |
+
|
149 |
+
# let chunk size default to None
|
150 |
+
self._chunk_size = None
|
151 |
+
self._chunk_dim = 0
|
152 |
+
|
153 |
+
# Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
|
154 |
+
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
|
155 |
+
# Sets chunk feed-forward
|
156 |
+
self._chunk_size = chunk_size
|
157 |
+
self._chunk_dim = dim
|
158 |
+
|
159 |
+
def forward(
|
160 |
+
self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
|
161 |
+
):
|
162 |
+
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
|
163 |
+
|
164 |
+
if self.context_pre_only:
|
165 |
+
norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
|
166 |
+
else:
|
167 |
+
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
|
168 |
+
encoder_hidden_states, emb=temb
|
169 |
+
)
|
170 |
+
|
171 |
+
# Attention.
|
172 |
+
attn_output, context_attn_output = self.attn(
|
173 |
+
hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
|
174 |
+
)
|
175 |
+
|
176 |
+
# Process attention outputs for the `hidden_states`.
|
177 |
+
attn_output = gate_msa.unsqueeze(1) * attn_output
|
178 |
+
hidden_states = hidden_states + attn_output
|
179 |
+
|
180 |
+
norm_hidden_states = self.norm2(hidden_states)
|
181 |
+
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
182 |
+
if self._chunk_size is not None:
|
183 |
+
# "feed_forward_chunk_size" can be used to save memory
|
184 |
+
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
|
185 |
+
else:
|
186 |
+
ff_output = self.ff(norm_hidden_states)
|
187 |
+
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
188 |
+
|
189 |
+
hidden_states = hidden_states + ff_output
|
190 |
+
|
191 |
+
# Process attention outputs for the `encoder_hidden_states`.
|
192 |
+
if self.context_pre_only:
|
193 |
+
encoder_hidden_states = None
|
194 |
+
else:
|
195 |
+
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
|
196 |
+
encoder_hidden_states = encoder_hidden_states + context_attn_output
|
197 |
+
|
198 |
+
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
|
199 |
+
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
|
200 |
+
if self._chunk_size is not None:
|
201 |
+
# "feed_forward_chunk_size" can be used to save memory
|
202 |
+
context_ff_output = _chunked_feed_forward(
|
203 |
+
self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size
|
204 |
+
)
|
205 |
+
else:
|
206 |
+
context_ff_output = self.ff_context(norm_encoder_hidden_states)
|
207 |
+
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
|
208 |
+
|
209 |
+
return encoder_hidden_states, hidden_states
|
210 |
+
|
211 |
+
|
212 |
+
@maybe_allow_in_graph
|
213 |
+
class BasicTransformerBlock(nn.Module):
|
214 |
+
r"""
|
215 |
+
A basic Transformer block.
|
216 |
+
|
217 |
+
Parameters:
|
218 |
+
dim (`int`): The number of channels in the input and output.
|
219 |
+
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
220 |
+
attention_head_dim (`int`): The number of channels in each head.
|
221 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
222 |
+
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
|
223 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
224 |
+
num_embeds_ada_norm (:
|
225 |
+
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
|
226 |
+
attention_bias (:
|
227 |
+
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
|
228 |
+
only_cross_attention (`bool`, *optional*):
|
229 |
+
Whether to use only cross-attention layers. In this case two cross attention layers are used.
|
230 |
+
double_self_attention (`bool`, *optional*):
|
231 |
+
Whether to use two self-attention layers. In this case no cross attention layers are used.
|
232 |
+
upcast_attention (`bool`, *optional*):
|
233 |
+
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
|
234 |
+
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
|
235 |
+
Whether to use learnable elementwise affine parameters for normalization.
|
236 |
+
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
|
237 |
+
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
|
238 |
+
final_dropout (`bool` *optional*, defaults to False):
|
239 |
+
Whether to apply a final dropout after the last feed-forward layer.
|
240 |
+
attention_type (`str`, *optional*, defaults to `"default"`):
|
241 |
+
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
|
242 |
+
positional_embeddings (`str`, *optional*, defaults to `None`):
|
243 |
+
The type of positional embeddings to apply to.
|
244 |
+
num_positional_embeddings (`int`, *optional*, defaults to `None`):
|
245 |
+
The maximum number of positional embeddings to apply.
|
246 |
+
"""
|
247 |
+
|
248 |
+
def __init__(
|
249 |
+
self,
|
250 |
+
dim: int,
|
251 |
+
num_attention_heads: int,
|
252 |
+
attention_head_dim: int,
|
253 |
+
dropout=0.0,
|
254 |
+
cross_attention_dim: Optional[int] = None,
|
255 |
+
activation_fn: str = "geglu",
|
256 |
+
num_embeds_ada_norm: Optional[int] = None,
|
257 |
+
attention_bias: bool = False,
|
258 |
+
only_cross_attention: bool = False,
|
259 |
+
double_self_attention: bool = False,
|
260 |
+
upcast_attention: bool = False,
|
261 |
+
norm_elementwise_affine: bool = True,
|
262 |
+
norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
|
263 |
+
norm_eps: float = 1e-5,
|
264 |
+
final_dropout: bool = False,
|
265 |
+
attention_type: str = "default",
|
266 |
+
positional_embeddings: Optional[str] = None,
|
267 |
+
num_positional_embeddings: Optional[int] = None,
|
268 |
+
ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
|
269 |
+
ada_norm_bias: Optional[int] = None,
|
270 |
+
ff_inner_dim: Optional[int] = None,
|
271 |
+
ff_bias: bool = True,
|
272 |
+
attention_out_bias: bool = True,
|
273 |
+
):
|
274 |
+
super().__init__()
|
275 |
+
self.only_cross_attention = only_cross_attention
|
276 |
+
|
277 |
+
# We keep these boolean flags for backward-compatibility.
|
278 |
+
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
|
279 |
+
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
|
280 |
+
self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
|
281 |
+
self.use_layer_norm = norm_type == "layer_norm"
|
282 |
+
self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
|
283 |
+
|
284 |
+
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
|
285 |
+
raise ValueError(
|
286 |
+
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
|
287 |
+
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
|
288 |
+
)
|
289 |
+
|
290 |
+
self.norm_type = norm_type
|
291 |
+
self.num_embeds_ada_norm = num_embeds_ada_norm
|
292 |
+
|
293 |
+
if positional_embeddings and (num_positional_embeddings is None):
|
294 |
+
raise ValueError(
|
295 |
+
"If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
|
296 |
+
)
|
297 |
+
|
298 |
+
if positional_embeddings == "sinusoidal":
|
299 |
+
self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
|
300 |
+
else:
|
301 |
+
self.pos_embed = None
|
302 |
+
|
303 |
+
# Define 3 blocks. Each block has its own normalization layer.
|
304 |
+
# 1. Self-Attn
|
305 |
+
if norm_type == "ada_norm":
|
306 |
+
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
|
307 |
+
elif norm_type == "ada_norm_zero":
|
308 |
+
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
|
309 |
+
elif norm_type == "ada_norm_continuous":
|
310 |
+
self.norm1 = AdaLayerNormContinuous(
|
311 |
+
dim,
|
312 |
+
ada_norm_continous_conditioning_embedding_dim,
|
313 |
+
norm_elementwise_affine,
|
314 |
+
norm_eps,
|
315 |
+
ada_norm_bias,
|
316 |
+
"rms_norm",
|
317 |
+
)
|
318 |
+
else:
|
319 |
+
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
|
320 |
+
|
321 |
+
self.attn1 = Attention(
|
322 |
+
query_dim=dim,
|
323 |
+
heads=num_attention_heads,
|
324 |
+
dim_head=attention_head_dim,
|
325 |
+
dropout=dropout,
|
326 |
+
bias=attention_bias,
|
327 |
+
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
|
328 |
+
upcast_attention=upcast_attention,
|
329 |
+
out_bias=attention_out_bias,
|
330 |
+
)
|
331 |
+
|
332 |
+
# 2. Cross-Attn
|
333 |
+
if cross_attention_dim is not None or double_self_attention:
|
334 |
+
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
|
335 |
+
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
|
336 |
+
# the second cross attention block.
|
337 |
+
if norm_type == "ada_norm":
|
338 |
+
self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
|
339 |
+
elif norm_type == "ada_norm_continuous":
|
340 |
+
self.norm2 = AdaLayerNormContinuous(
|
341 |
+
dim,
|
342 |
+
ada_norm_continous_conditioning_embedding_dim,
|
343 |
+
norm_elementwise_affine,
|
344 |
+
norm_eps,
|
345 |
+
ada_norm_bias,
|
346 |
+
"rms_norm",
|
347 |
+
)
|
348 |
+
else:
|
349 |
+
self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
350 |
+
|
351 |
+
self.attn2 = Attention(
|
352 |
+
query_dim=dim,
|
353 |
+
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
|
354 |
+
heads=num_attention_heads,
|
355 |
+
dim_head=attention_head_dim,
|
356 |
+
dropout=dropout,
|
357 |
+
bias=attention_bias,
|
358 |
+
upcast_attention=upcast_attention,
|
359 |
+
out_bias=attention_out_bias,
|
360 |
+
) # is self-attn if encoder_hidden_states is none
|
361 |
+
else:
|
362 |
+
if norm_type == "ada_norm_single": # For Latte
|
363 |
+
self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
364 |
+
else:
|
365 |
+
self.norm2 = None
|
366 |
+
self.attn2 = None
|
367 |
+
|
368 |
+
# 3. Feed-forward
|
369 |
+
if norm_type == "ada_norm_continuous":
|
370 |
+
self.norm3 = AdaLayerNormContinuous(
|
371 |
+
dim,
|
372 |
+
ada_norm_continous_conditioning_embedding_dim,
|
373 |
+
norm_elementwise_affine,
|
374 |
+
norm_eps,
|
375 |
+
ada_norm_bias,
|
376 |
+
"layer_norm",
|
377 |
+
)
|
378 |
+
|
379 |
+
elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]:
|
380 |
+
self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
381 |
+
elif norm_type == "layer_norm_i2vgen":
|
382 |
+
self.norm3 = None
|
383 |
+
|
384 |
+
self.ff = FeedForward(
|
385 |
+
dim,
|
386 |
+
dropout=dropout,
|
387 |
+
activation_fn=activation_fn,
|
388 |
+
final_dropout=final_dropout,
|
389 |
+
inner_dim=ff_inner_dim,
|
390 |
+
bias=ff_bias,
|
391 |
+
)
|
392 |
+
|
393 |
+
# 4. Fuser
|
394 |
+
if attention_type == "gated" or attention_type == "gated-text-image":
|
395 |
+
self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
|
396 |
+
|
397 |
+
# 5. Scale-shift for PixArt-Alpha.
|
398 |
+
if norm_type == "ada_norm_single":
|
399 |
+
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
|
400 |
+
|
401 |
+
# let chunk size default to None
|
402 |
+
self._chunk_size = None
|
403 |
+
self._chunk_dim = 0
|
404 |
+
|
405 |
+
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
|
406 |
+
# Sets chunk feed-forward
|
407 |
+
self._chunk_size = chunk_size
|
408 |
+
self._chunk_dim = dim
|
409 |
+
|
410 |
+
def forward(
|
411 |
+
self,
|
412 |
+
hidden_states: torch.Tensor,
|
413 |
+
attention_mask: Optional[torch.Tensor] = None,
|
414 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
415 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
416 |
+
timestep: Optional[torch.LongTensor] = None,
|
417 |
+
cross_attention_kwargs: Dict[str, Any] = None,
|
418 |
+
class_labels: Optional[torch.LongTensor] = None,
|
419 |
+
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
420 |
+
) -> torch.Tensor:
|
421 |
+
if cross_attention_kwargs is not None:
|
422 |
+
if cross_attention_kwargs.get("scale", None) is not None:
|
423 |
+
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
|
424 |
+
|
425 |
+
# Notice that normalization is always applied before the real computation in the following blocks.
|
426 |
+
# 0. Self-Attention
|
427 |
+
batch_size = hidden_states.shape[0]
|
428 |
+
|
429 |
+
if self.norm_type == "ada_norm":
|
430 |
+
norm_hidden_states = self.norm1(hidden_states, timestep)
|
431 |
+
elif self.norm_type == "ada_norm_zero":
|
432 |
+
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
|
433 |
+
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
|
434 |
+
)
|
435 |
+
elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]:
|
436 |
+
norm_hidden_states = self.norm1(hidden_states)
|
437 |
+
elif self.norm_type == "ada_norm_continuous":
|
438 |
+
            norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
        elif self.norm_type == "ada_norm_single":
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
                self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
            ).chunk(6, dim=1)
            norm_hidden_states = self.norm1(hidden_states)
            norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
        else:
            raise ValueError("Incorrect norm used")

        if self.pos_embed is not None:
            norm_hidden_states = self.pos_embed(norm_hidden_states)

        # 1. Prepare GLIGEN inputs
        cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
        gligen_kwargs = cross_attention_kwargs.pop("gligen", None)

        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )

        if self.norm_type == "ada_norm_zero":
            attn_output = gate_msa.unsqueeze(1) * attn_output
        elif self.norm_type == "ada_norm_single":
            attn_output = gate_msa * attn_output

        hidden_states = attn_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        # 1.2 GLIGEN Control
        if gligen_kwargs is not None:
            hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])

        # 3. Cross-Attention
        if self.attn2 is not None:
            if self.norm_type == "ada_norm":
                norm_hidden_states = self.norm2(hidden_states, timestep)
            elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
                norm_hidden_states = self.norm2(hidden_states)
            elif self.norm_type == "ada_norm_single":
                # For PixArt norm2 isn't applied here:
                # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
                norm_hidden_states = hidden_states
            elif self.norm_type == "ada_norm_continuous":
                norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
            else:
                raise ValueError("Incorrect norm")

            if self.pos_embed is not None and self.norm_type != "ada_norm_single":
                norm_hidden_states = self.pos_embed(norm_hidden_states)

            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states

        # 4. Feed-forward
        # i2vgen doesn't have this norm 🤷♂️
        if self.norm_type == "ada_norm_continuous":
            norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
        elif not self.norm_type == "ada_norm_single":
            norm_hidden_states = self.norm3(hidden_states)

        if self.norm_type == "ada_norm_zero":
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

        if self.norm_type == "ada_norm_single":
            norm_hidden_states = self.norm2(hidden_states)
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp

        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
        else:
            ff_output = self.ff(norm_hidden_states)

        if self.norm_type == "ada_norm_zero":
            ff_output = gate_mlp.unsqueeze(1) * ff_output
        elif self.norm_type == "ada_norm_single":
            ff_output = gate_mlp * ff_output

        hidden_states = ff_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        return hidden_states

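The `_chunked_feed_forward` helper called above is defined earlier in flux/attention.py; only its call site appears here. A minimal sketch of the idea, assuming a callable feed-forward module and a chunk size that divides the chunked dimension (names below are illustrative, not the shipped helper):

# Sketch of the chunked feed-forward trade-off used above: run the MLP on
# slices so the intermediate activations never exist for the full sequence.
import torch

def chunked_ff_sketch(ff, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int) -> torch.Tensor:
    if hidden_states.shape[chunk_dim] % chunk_size != 0:
        raise ValueError("chunk_size must divide the chunked dimension")
    num_chunks = hidden_states.shape[chunk_dim] // chunk_size
    return torch.cat(
        [ff(chunk) for chunk in hidden_states.chunk(num_chunks, dim=chunk_dim)],
        dim=chunk_dim,
    )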
class LuminaFeedForward(nn.Module):
    r"""
    A feed-forward layer.

    Parameters:
        dim (`int`):
            The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
            hidden representations.
        inner_dim (`int`): The intermediate dimension of the feedforward layer.
        multiple_of (`int`, *optional*): Value to ensure the hidden dimension is a multiple of this value.
        ffn_dim_multiplier (`float`, *optional*): Custom multiplier for the hidden dimension. Defaults to None.
    """

    def __init__(
        self,
        dim: int,
        inner_dim: int,
        multiple_of: Optional[int] = 256,
        ffn_dim_multiplier: Optional[float] = None,
    ):
        super().__init__()
        inner_dim = int(2 * inner_dim / 3)
        # custom hidden_size factor multiplier
        if ffn_dim_multiplier is not None:
            inner_dim = int(ffn_dim_multiplier * inner_dim)
        inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)

        self.linear_1 = nn.Linear(
            dim,
            inner_dim,
            bias=False,
        )
        self.linear_2 = nn.Linear(
            inner_dim,
            dim,
            bias=False,
        )
        self.linear_3 = nn.Linear(
            dim,
            inner_dim,
            bias=False,
        )
        self.silu = FP32SiLU()

    def forward(self, x):
        return self.linear_2(self.silu(self.linear_1(x)) * self.linear_3(x))

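The constructor first shrinks `inner_dim` by 2/3 (the usual SwiGLU convention), applies the optional multiplier, then rounds up to a multiple of `multiple_of`. A small worked example with illustrative values:

# Worked example of the inner_dim arithmetic above (illustrative values only):
# inner_dim=4096, multiple_of=256, ffn_dim_multiplier=None
inner_dim = int(2 * 4096 / 3)                     # 2730
inner_dim = 256 * ((inner_dim + 256 - 1) // 256)  # rounded up to 2816
print(inner_dim)  # 2816 -> the actual width of linear_1 / linear_3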
@maybe_allow_in_graph
class TemporalBasicTransformerBlock(nn.Module):
    r"""
    A basic Transformer block for video like data.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        time_mix_inner_dim (`int`): The number of channels for temporal attention.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
    """

    def __init__(
        self,
        dim: int,
        time_mix_inner_dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        cross_attention_dim: Optional[int] = None,
    ):
        super().__init__()
        self.is_res = dim == time_mix_inner_dim

        self.norm_in = nn.LayerNorm(dim)

        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        self.ff_in = FeedForward(
            dim,
            dim_out=time_mix_inner_dim,
            activation_fn="geglu",
        )

        self.norm1 = nn.LayerNorm(time_mix_inner_dim)
        self.attn1 = Attention(
            query_dim=time_mix_inner_dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            cross_attention_dim=None,
        )

        # 2. Cross-Attn
        if cross_attention_dim is not None:
            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
            # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
            # the second cross attention block.
            self.norm2 = nn.LayerNorm(time_mix_inner_dim)
            self.attn2 = Attention(
                query_dim=time_mix_inner_dim,
                cross_attention_dim=cross_attention_dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
            )  # is self-attn if encoder_hidden_states is none
        else:
            self.norm2 = None
            self.attn2 = None

        # 3. Feed-forward
        self.norm3 = nn.LayerNorm(time_mix_inner_dim)
        self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu")

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = None

    def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs):
        # Sets chunk feed-forward
        self._chunk_size = chunk_size
        # chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off
        self._chunk_dim = 1

    def forward(
        self,
        hidden_states: torch.Tensor,
        num_frames: int,
        encoder_hidden_states: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Notice that normalization is always applied before the real computation in the following blocks.
        # 0. Self-Attention
        batch_size = hidden_states.shape[0]

        batch_frames, seq_length, channels = hidden_states.shape
        batch_size = batch_frames // num_frames

        hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels)
        hidden_states = hidden_states.permute(0, 2, 1, 3)
        hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels)

        residual = hidden_states
        hidden_states = self.norm_in(hidden_states)

        if self._chunk_size is not None:
            hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size)
        else:
            hidden_states = self.ff_in(hidden_states)

        if self.is_res:
            hidden_states = hidden_states + residual

        norm_hidden_states = self.norm1(hidden_states)
        attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
        hidden_states = attn_output + hidden_states

        # 3. Cross-Attention
        if self.attn2 is not None:
            norm_hidden_states = self.norm2(hidden_states)
            attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
            hidden_states = attn_output + hidden_states

        # 4. Feed-forward
        norm_hidden_states = self.norm3(hidden_states)

        if self._chunk_size is not None:
            ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
        else:
            ff_output = self.ff(norm_hidden_states)

        if self.is_res:
            hidden_states = ff_output + hidden_states
        else:
            hidden_states = ff_output

        hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels)
        hidden_states = hidden_states.permute(0, 2, 1, 3)
        hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels)

        return hidden_states

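The temporal block folds the frame axis out of the batch so that `attn1` mixes information across time for each spatial token. A shape sketch of the reshuffle at the top of `forward` (toy sizes, not from any model config):

# Shape sketch of the frame/sequence reshuffle in TemporalBasicTransformerBlock.forward
import torch

batch_size, num_frames, seq_length, channels = 2, 8, 16, 320
x = torch.randn(batch_size * num_frames, seq_length, channels)        # (16, 16, 320)

x = x[None, :].reshape(batch_size, num_frames, seq_length, channels)  # (2, 8, 16, 320)
x = x.permute(0, 2, 1, 3)                                             # (2, 16, 8, 320)
x = x.reshape(batch_size * seq_length, num_frames, channels)          # (32, 8, 320)
# attn1 now attends over the num_frames axis for every spatial position;
# the inverse reshape at the end of forward() restores (batch*frames, seq, channels).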
class SkipFFTransformerBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        kv_input_dim: int,
        kv_input_dim_proj_use_bias: bool,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        attention_out_bias: bool = True,
    ):
        super().__init__()
        if kv_input_dim != dim:
            self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias)
        else:
            self.kv_mapper = None

        self.norm1 = RMSNorm(dim, 1e-06)

        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim,
            out_bias=attention_out_bias,
        )

        self.norm2 = RMSNorm(dim, 1e-06)

        self.attn2 = Attention(
            query_dim=dim,
            cross_attention_dim=cross_attention_dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            out_bias=attention_out_bias,
        )

    def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs):
        cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}

        if self.kv_mapper is not None:
            encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states))

        norm_hidden_states = self.norm1(hidden_states)

        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            **cross_attention_kwargs,
        )

        hidden_states = attn_output + hidden_states

        norm_hidden_states = self.norm2(hidden_states)

        attn_output = self.attn2(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            **cross_attention_kwargs,
        )

        hidden_states = attn_output + hidden_states

        return hidden_states

class FeedForward(nn.Module):
    r"""
    A feed-forward layer.

    Parameters:
        dim (`int`): The number of channels in the input.
        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        final_dropout (`bool`, *optional*, defaults to False): Apply a final dropout.
        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        final_dropout: bool = False,
        inner_dim=None,
        bias: bool = True,
    ):
        super().__init__()
        if inner_dim is None:
            inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim, bias=bias)
        if activation_fn == "gelu-approximate":
            act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
        elif activation_fn == "geglu":
            act_fn = GEGLU(dim, inner_dim, bias=bias)
        elif activation_fn == "geglu-approximate":
            act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
        elif activation_fn == "swiglu":
            act_fn = SwiGLU(dim, inner_dim, bias=bias)

        self.net = nn.ModuleList([])
        # project in
        self.net.append(act_fn)
        # project dropout
        self.net.append(nn.Dropout(dropout))
        # project out
        self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
        if final_dropout:
            self.net.append(nn.Dropout(dropout))

    def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states

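A minimal usage sketch of the `FeedForward` block above with its GEGLU default (sizes are illustrative only; the GELU/GEGLU/SwiGLU activations come from flux/activations.py earlier in this file's imports):

import torch

ff = FeedForward(dim=320, mult=4, activation_fn="geglu", dropout=0.0)
tokens = torch.randn(1, 77, 320)
out = ff(tokens)  # GEGLU projects 320 -> 1280 (gated), then the final Linear maps back to 320
assert out.shape == tokens.shape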
flux/attention_processor.py
ADDED
The diff for this file is too large to render.
See raw diff
flux/controlnet_flux.py
ADDED
@@ -0,0 +1,617 @@
# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from .lora.peft import PeftAdapterMixin
from diffusers.models.attention_processor import AttentionProcessor
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
# from .controlnet import BaseOutput, zero_module
from diffusers.utils import BaseOutput
from .embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from .transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
import numpy as np


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def zero_module(module):
    for p in module.parameters():
        nn.init.zeros_(p)
    return module


def get_1d_rotary_pos_embed(
    dim: int,
    pos: Union[np.ndarray, int],
    theta: float = 10000.0,
    use_real=False,
    linear_factor=1.0,
    ntk_factor=1.0,
    repeat_interleave_real=True,
    freqs_dtype=torch.float32,  # torch.float32, torch.float64 (flux)
):
    """
    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.

    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end
    index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64
    data type.

    Args:
        dim (`int`): Dimension of the frequency tensor.
        pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
        theta (`float`, *optional*, defaults to 10000.0):
            Scaling factor for frequency computation. Defaults to 10000.0.
        use_real (`bool`, *optional*):
            If True, return real part and imaginary part separately. Otherwise, return complex numbers.
        linear_factor (`float`, *optional*, defaults to 1.0):
            Scaling factor for the context extrapolation. Defaults to 1.0.
        ntk_factor (`float`, *optional*, defaults to 1.0):
            Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
        repeat_interleave_real (`bool`, *optional*, defaults to `True`):
            If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
            Otherwise, they are concatenated with themselves.
        freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
            the dtype of the frequency tensor.
    Returns:
        `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
    """
    assert dim % 2 == 0

    if isinstance(pos, int):
        pos = torch.arange(pos)
    if isinstance(pos, np.ndarray):
        pos = torch.from_numpy(pos)  # type: ignore  # [S]

    theta = theta * ntk_factor
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor  # [D/2]
    freqs = freqs.to(pos.device)
    freqs = torch.outer(pos, freqs)  # type: ignore  # [S, D/2]
    if use_real and repeat_interleave_real:
        # flux, hunyuan-dit, cogvideox
        freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float()  # [S, D]
        freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float()  # [S, D]
        return freqs_cos, freqs_sin
    elif use_real:
        # stable audio
        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()  # [S, D]
        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()  # [S, D]
        return freqs_cos, freqs_sin
    else:
        # lumina
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64  # [S, D/2]
        return freqs_cis

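A small sketch of calling the helper above in the Flux configuration (real cos/sin parts, interleaved); the axis width and position count are illustrative:

import torch

cos, sin = get_1d_rotary_pos_embed(
    dim=56, pos=32, use_real=True, repeat_interleave_real=True, freqs_dtype=torch.float64
)
print(cos.shape, sin.shape)  # torch.Size([32, 56]) torch.Size([32, 56])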
class FluxPosEmbed(nn.Module):
    # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
    def __init__(self, theta: int, axes_dim: List[int]):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        n_axes = ids.shape[-1]
        cos_out = []
        sin_out = []
        pos = ids.squeeze().float()
        is_mps = ids.device.type == "mps"
        freqs_dtype = torch.float32 if is_mps else torch.float64
        for i in range(n_axes):
            cos, sin = get_1d_rotary_pos_embed(
                self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True, freqs_dtype=freqs_dtype
            )
            cos_out.append(cos)
            sin_out.append(sin)
        freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
        freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
        return freqs_cos, freqs_sin

@dataclass
class FluxControlNetOutput(BaseOutput):
    controlnet_block_samples: Tuple[torch.Tensor]
    controlnet_single_block_samples: Tuple[torch.Tensor]


class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        patch_size: int = 1,
        in_channels: int = 64,
        num_layers: int = 19,
        num_single_layers: int = 38,
        attention_head_dim: int = 128,
        num_attention_heads: int = 24,
        joint_attention_dim: int = 4096,
        pooled_projection_dim: int = 768,
        guidance_embeds: bool = False,
        axes_dims_rope: List[int] = [16, 56, 56],
        num_mode: int = None,
    ):
        super().__init__()
        self.out_channels = in_channels
        self.inner_dim = num_attention_heads * attention_head_dim

        self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
        text_time_guidance_cls = (
            CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
        )
        self.time_text_embed = text_time_guidance_cls(
            embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
        )

        self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
        self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                FluxTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                )
                for i in range(num_layers)
            ]
        )

        self.single_transformer_blocks = nn.ModuleList(
            [
                FluxSingleTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                )
                for i in range(num_single_layers)
            ]
        )

        # controlnet_blocks
        self.controlnet_blocks = nn.ModuleList([])
        for _ in range(len(self.transformer_blocks)):
            self.controlnet_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))

        self.controlnet_single_blocks = nn.ModuleList([])
        for _ in range(len(self.single_transformer_blocks)):
            self.controlnet_single_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))

        self.union = num_mode is not None
        if self.union:
            self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim)

        self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim))

        self.gradient_checkpointing = False

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self):
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by their weight names.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    @classmethod
    def from_transformer(
        cls,
        transformer,
        num_layers: int = 4,
        num_single_layers: int = 10,
        attention_head_dim: int = 128,
        num_attention_heads: int = 24,
        load_weights_from_transformer=True,
    ):
        config = transformer.config
        config["num_layers"] = num_layers
        config["num_single_layers"] = num_single_layers
        config["attention_head_dim"] = attention_head_dim
        config["num_attention_heads"] = num_attention_heads

        controlnet = cls(**config)

        if load_weights_from_transformer:
            controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
            controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict())
            controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict())
            controlnet.x_embedder.load_state_dict(transformer.x_embedder.state_dict())
            controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)
            controlnet.single_transformer_blocks.load_state_dict(
                transformer.single_transformer_blocks.state_dict(), strict=False
            )

            controlnet.controlnet_x_embedder = zero_module(controlnet.controlnet_x_embedder)

        return controlnet

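A hedged sketch of initializing a shallow ControlNet from a pretrained Flux transformer via `from_transformer` above. The `FluxTransformer2DModel` class name and the checkpoint id are assumptions for illustration; only `from_transformer` itself is defined in this file.

import torch
from flux.transformer_flux import FluxTransformer2DModel  # assumed export of this repo's transformer module

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)
controlnet = FluxControlNetModel.from_transformer(
    transformer,
    num_layers=4,            # shallow copy of the double-stream blocks
    num_single_layers=10,    # shallow copy of the single-stream blocks
    load_weights_from_transformer=True,
)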
    def forward(
        self,
        hidden_states: torch.Tensor,
        controlnet_cond: torch.Tensor,
        controlnet_mode: torch.Tensor = None,
        conditioning_scale: float = 1.0,
        encoder_hidden_states: torch.Tensor = None,
        t5_encoder_hidden_states: torch.Tensor = None,
        pooled_projections: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        img_ids: torch.Tensor = None,
        txt_ids: torch.Tensor = None,
        guidance: torch.Tensor = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
        """
        The [`FluxTransformer2DModel`] forward method.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
                Input `hidden_states`.
            controlnet_cond (`torch.Tensor`):
                The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
            controlnet_mode (`torch.Tensor`):
                The mode tensor of shape `(batch_size, 1)`.
            conditioning_scale (`float`, defaults to `1.0`):
                The scale factor for ControlNet outputs.
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
                from the embeddings of input conditions.
            timestep (`torch.LongTensor`):
                Used to indicate denoising step.
            block_controlnet_hidden_states: (`list` of `torch.Tensor`):
                A list of tensors that if specified are added to the residuals of transformer blocks.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                tuple.

        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
        if joint_attention_kwargs is not None:
            joint_attention_kwargs = joint_attention_kwargs.copy()
            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
                )
        hidden_states = self.x_embedder(hidden_states)

        # add
        hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond)

        timestep = timestep.to(hidden_states.dtype) * 1000
        if guidance is not None:
            guidance = guidance.to(hidden_states.dtype) * 1000
        else:
            guidance = None
        temb = (
            self.time_text_embed(timestep, pooled_projections)
            if guidance is None
            else self.time_text_embed(timestep, guidance, pooled_projections)
        )
        encoder_hidden_states = self.context_embedder(encoder_hidden_states)
        if t5_encoder_hidden_states is not None:
            encoder_hidden_states = torch.cat([encoder_hidden_states, t5_encoder_hidden_states], dim=1)

        if txt_ids.ndim == 3:
            logger.warning(
                "Passing `txt_ids` 3d torch.Tensor is deprecated."
                "Please remove the batch dimension and pass it as a 2d torch Tensor"
            )
            txt_ids = txt_ids[0]

        if self.union:
            # union mode
            if controlnet_mode is None:
                raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union")
            # union mode emb
            controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
            encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1)
            txt_ids = torch.cat([txt_ids[:1], txt_ids], dim=0)

        if img_ids.ndim == 3:
            logger.warning(
                "Passing `img_ids` 3d torch.Tensor is deprecated."
                "Please remove the batch dimension and pass it as a 2d torch Tensor"
            )
            img_ids = img_ids[0]

        ids = torch.cat((txt_ids, img_ids), dim=0)
        image_rotary_emb = self.pos_embed(ids)

        block_samples = ()
        for index_block, block in enumerate(self.transformer_blocks):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    image_rotary_emb,
                    **ckpt_kwargs,
                )

            else:
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                )
            block_samples = block_samples + (hidden_states,)

        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        single_block_samples = ()
        for index_block, block in enumerate(self.single_transformer_blocks):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    temb,
                    image_rotary_emb,
                    **ckpt_kwargs,
                )

            else:
                hidden_states = block(
                    hidden_states=hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                )
            single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],)

        # controlnet block
        controlnet_block_samples = ()
        for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks):
            block_sample = controlnet_block(block_sample)
            controlnet_block_samples = controlnet_block_samples + (block_sample,)

        controlnet_single_block_samples = ()
        for single_block_sample, controlnet_block in zip(single_block_samples, self.controlnet_single_blocks):
            single_block_sample = controlnet_block(single_block_sample)
            controlnet_single_block_samples = controlnet_single_block_samples + (single_block_sample,)

        # scaling
        controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples]
        controlnet_single_block_samples = [sample * conditioning_scale for sample in controlnet_single_block_samples]

        controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples
        controlnet_single_block_samples = (
            None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples
        )

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (controlnet_block_samples, controlnet_single_block_samples)

        return FluxControlNetOutput(
            controlnet_block_samples=controlnet_block_samples,
            controlnet_single_block_samples=controlnet_single_block_samples,
        )

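The residuals returned above are already multiplied by `conditioning_scale`; the base Flux transformer is then expected to add them onto its own per-block outputs. A hedged sketch of that consumption step with illustrative names (the actual wiring lives in flux/transformer_flux.py, not here):

def add_controlnet_residuals(block_hidden_states, controlnet_block_samples):
    # controlnet_block_samples holds one tensor per double-stream block,
    # already scaled by conditioning_scale in FluxControlNetModel.forward.
    if controlnet_block_samples is None:
        return block_hidden_states
    return [hs + res for hs, res in zip(block_hidden_states, controlnet_block_samples)]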
class FluxMultiControlNetModel(ModelMixin):
    r"""
    `FluxMultiControlNetModel` wrapper class for Multi-FluxControlNetModel

    This module is a wrapper for multiple instances of the `FluxControlNetModel`. The `forward()` API is designed to be
    compatible with `FluxControlNetModel`.

    Args:
        controlnets (`List[FluxControlNetModel]`):
            Provides additional conditioning to the unet during the denoising process. You must set multiple
            `FluxControlNetModel` as a list.
    """

    def __init__(self, controlnets):
        super().__init__()
        self.nets = nn.ModuleList(controlnets)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        controlnet_cond: List[torch.tensor],
        controlnet_mode: List[torch.tensor],
        conditioning_scale: List[float],
        encoder_hidden_states: torch.Tensor = None,
        t5_encoder_hidden_states: torch.Tensor = None,
        pooled_projections: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        img_ids: torch.Tensor = None,
        txt_ids: torch.Tensor = None,
        guidance: torch.Tensor = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> Union[FluxControlNetOutput, Tuple]:
        # ControlNet-Union with multiple conditions
        # only one ControlNet is loaded, to save memory
        if len(self.nets) == 1 and self.nets[0].union:
            controlnet = self.nets[0]

            for i, (image, mode, scale) in enumerate(zip(controlnet_cond, controlnet_mode, conditioning_scale)):
                block_samples, single_block_samples = controlnet(
                    hidden_states=hidden_states,
                    controlnet_cond=image,
                    controlnet_mode=mode[:, None],
                    conditioning_scale=scale,
                    timestep=timestep,
                    guidance=guidance,
                    pooled_projections=pooled_projections,
                    encoder_hidden_states=encoder_hidden_states,
                    t5_encoder_hidden_states=t5_encoder_hidden_states,
                    txt_ids=txt_ids,
                    img_ids=img_ids,
                    joint_attention_kwargs=joint_attention_kwargs,
                    return_dict=return_dict,
                )

                # merge samples
                if i == 0:
                    control_block_samples = block_samples
                    control_single_block_samples = single_block_samples
                else:
                    control_block_samples = [
                        control_block_sample + block_sample
                        for control_block_sample, block_sample in zip(control_block_samples, block_samples)
                    ]

                    control_single_block_samples = [
                        control_single_block_sample + block_sample
                        for control_single_block_sample, block_sample in zip(
                            control_single_block_samples, single_block_samples
                        )
                    ]

        # Regular Multi-ControlNets
        # load all ControlNets into memory
        else:
            for i, (image, mode, scale, controlnet) in enumerate(
                zip(controlnet_cond, controlnet_mode, conditioning_scale, self.nets)
            ):
                block_samples, single_block_samples = controlnet(
                    hidden_states=hidden_states,
                    controlnet_cond=image,
                    controlnet_mode=mode[:, None],
                    conditioning_scale=scale,
                    timestep=timestep,
                    guidance=guidance,
                    pooled_projections=pooled_projections,
                    encoder_hidden_states=encoder_hidden_states,
                    txt_ids=txt_ids,
                    img_ids=img_ids,
                    joint_attention_kwargs=joint_attention_kwargs,
                    return_dict=return_dict,
                )

                # merge samples
                if i == 0:
                    control_block_samples = block_samples
                    control_single_block_samples = single_block_samples
                else:
                    if block_samples is not None and control_block_samples is not None:
                        control_block_samples = [
                            control_block_sample + block_sample
                            for control_block_sample, block_sample in zip(control_block_samples, block_samples)
                        ]
                    if single_block_samples is not None and control_single_block_samples is not None:
                        control_single_block_samples = [
                            control_single_block_sample + block_sample
                            for control_single_block_sample, block_sample in zip(
                                control_single_block_samples, single_block_samples
                            )
                        ]

        return control_block_samples, control_single_block_samples
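Whether several ControlNets run or one union ControlNet handles several conditions, the merge rule above is the same: residuals from each pass are summed tensor by tensor. A tiny illustration with toy tensors:

import torch

samples_a = [torch.ones(1, 4, 8), torch.ones(1, 4, 8)]
samples_b = [torch.full((1, 4, 8), 2.0), torch.full((1, 4, 8), 2.0)]
merged = [a + b for a, b in zip(samples_a, samples_b)]
print(merged[0][0, 0, 0])  # tensor(3.)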
flux/embeddings.py
ADDED
@@ -0,0 +1,1469 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from diffusers.utils import deprecate
from .activations import FP32SiLU, get_activation
from .attention_processor import Attention

def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.

    Args:
        timesteps (torch.Tensor):
            a 1-D Tensor of N indices, one per batch element. These may be fractional.
        embedding_dim (int):
            the dimension of the output.
        flip_sin_to_cos (bool):
            Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
        downscale_freq_shift (float):
            Controls the delta between frequencies between dimensions
        scale (float):
            Scaling factor applied to the embeddings.
        max_period (int):
            Controls the maximum frequency of the embeddings
    Returns:
        torch.Tensor: an [N x dim] Tensor of positional embeddings.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
    )
    exponent = exponent / (half_dim - downscale_freq_shift)

    emb = torch.exp(exponent)
    emb = timesteps[:, None].float() * emb[None, :]

    # scale embeddings
    emb = scale * emb

    # concat sine and cosine embeddings
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

    # flip sine and cosine embeddings
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    # zero pad
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb

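A quick check of the sinusoidal embedding above (illustrative values; `flip_sin_to_cos=True` puts the cosine half first):

import torch

t = torch.tensor([0.0, 10.0, 999.0])
emb = get_timestep_embedding(t, embedding_dim=8, flip_sin_to_cos=True, downscale_freq_shift=0)
print(emb.shape)  # torch.Size([3, 8]) -> [cos | sin] halves along the last dimension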
def get_2d_sincos_pos_embed(
    embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
):
    """
    grid_size: int of the grid height and width
    return:
        pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    if isinstance(grid_size, int):
        grid_size = (grid_size, grid_size)

    grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
    grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb

class PatchEmbed(nn.Module):
|
137 |
+
"""2D Image to Patch Embedding with support for SD3 cropping."""
|
138 |
+
|
139 |
+
def __init__(
|
140 |
+
self,
|
141 |
+
height=224,
|
142 |
+
width=224,
|
143 |
+
patch_size=16,
|
144 |
+
in_channels=3,
|
145 |
+
embed_dim=768,
|
146 |
+
layer_norm=False,
|
147 |
+
flatten=True,
|
148 |
+
bias=True,
|
149 |
+
interpolation_scale=1,
|
150 |
+
pos_embed_type="sincos",
|
151 |
+
pos_embed_max_size=None, # For SD3 cropping
|
152 |
+
):
|
153 |
+
super().__init__()
|
154 |
+
|
155 |
+
num_patches = (height // patch_size) * (width // patch_size)
|
156 |
+
self.flatten = flatten
|
157 |
+
self.layer_norm = layer_norm
|
158 |
+
self.pos_embed_max_size = pos_embed_max_size
|
159 |
+
|
160 |
+
self.proj = nn.Conv2d(
|
161 |
+
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
|
162 |
+
)
|
163 |
+
if layer_norm:
|
164 |
+
self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
|
165 |
+
else:
|
166 |
+
self.norm = None
|
167 |
+
|
168 |
+
self.patch_size = patch_size
|
169 |
+
self.height, self.width = height // patch_size, width // patch_size
|
170 |
+
self.base_size = height // patch_size
|
171 |
+
self.interpolation_scale = interpolation_scale
|
172 |
+
|
173 |
+
# Calculate positional embeddings based on max size or default
|
174 |
+
if pos_embed_max_size:
|
175 |
+
grid_size = pos_embed_max_size
|
176 |
+
else:
|
177 |
+
grid_size = int(num_patches**0.5)
|
178 |
+
|
179 |
+
if pos_embed_type is None:
|
180 |
+
self.pos_embed = None
|
181 |
+
elif pos_embed_type == "sincos":
|
182 |
+
pos_embed = get_2d_sincos_pos_embed(
|
183 |
+
embed_dim, grid_size, base_size=self.base_size, interpolation_scale=self.interpolation_scale
|
184 |
+
)
|
185 |
+
persistent = True if pos_embed_max_size else False
|
186 |
+
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=persistent)
|
187 |
+
else:
|
188 |
+
raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}")
|
189 |
+
|
190 |
+
def cropped_pos_embed(self, height, width):
|
191 |
+
"""Crops positional embeddings for SD3 compatibility."""
|
192 |
+
if self.pos_embed_max_size is None:
|
193 |
+
raise ValueError("`pos_embed_max_size` must be set for cropping.")
|
194 |
+
|
195 |
+
height = height // self.patch_size
|
196 |
+
width = width // self.patch_size
|
197 |
+
if height > self.pos_embed_max_size:
|
198 |
+
raise ValueError(
|
199 |
+
f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
|
200 |
+
)
|
201 |
+
if width > self.pos_embed_max_size:
|
202 |
+
raise ValueError(
|
203 |
+
f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
|
204 |
+
)
|
205 |
+
|
206 |
+
top = (self.pos_embed_max_size - height) // 2
|
207 |
+
left = (self.pos_embed_max_size - width) // 2
|
208 |
+
spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)
|
209 |
+
spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :]
|
210 |
+
spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
|
211 |
+
return spatial_pos_embed
|
212 |
+
|
213 |
+
def forward(self, latent):
|
214 |
+
if self.pos_embed_max_size is not None:
|
215 |
+
height, width = latent.shape[-2:]
|
216 |
+
else:
|
217 |
+
height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
|
218 |
+
|
219 |
+
latent = self.proj(latent)
|
220 |
+
if self.flatten:
|
221 |
+
latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
|
222 |
+
if self.layer_norm:
|
223 |
+
latent = self.norm(latent)
|
224 |
+
if self.pos_embed is None:
|
225 |
+
return latent.to(latent.dtype)
|
226 |
+
# Interpolate or crop positional embeddings as needed
|
227 |
+
if self.pos_embed_max_size:
|
228 |
+
pos_embed = self.cropped_pos_embed(height, width)
|
229 |
+
else:
|
230 |
+
if self.height != height or self.width != width:
|
231 |
+
pos_embed = get_2d_sincos_pos_embed(
|
232 |
+
embed_dim=self.pos_embed.shape[-1],
|
233 |
+
grid_size=(height, width),
|
234 |
+
base_size=self.base_size,
|
235 |
+
interpolation_scale=self.interpolation_scale,
|
236 |
+
)
|
237 |
+
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device)
|
238 |
+
else:
|
239 |
+
pos_embed = self.pos_embed
|
240 |
+
|
241 |
+
return (latent + pos_embed).to(latent.dtype)
|
242 |
+
|
243 |
+
|
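A hedged example of how the PatchEmbed module above can be exercised on its own; the latent size and embedding width below are arbitrary and only meant to show the BCHW -> BNC flow:
import torch
from flux.embeddings import PatchEmbed
patcher = PatchEmbed(height=64, width=64, patch_size=2, in_channels=4, embed_dim=768)
latent = torch.randn(1, 4, 64, 64)   # BCHW latent
tokens = patcher(latent)             # (1, 1024, 768): 32x32 patches flattened into a sequence
print(tokens.shape)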
244 |
+
class LuminaPatchEmbed(nn.Module):
|
245 |
+
"""2D Image to Patch Embedding with support for Lumina-T2X"""
|
246 |
+
|
247 |
+
def __init__(self, patch_size=2, in_channels=4, embed_dim=768, bias=True):
|
248 |
+
super().__init__()
|
249 |
+
self.patch_size = patch_size
|
250 |
+
self.proj = nn.Linear(
|
251 |
+
in_features=patch_size * patch_size * in_channels,
|
252 |
+
out_features=embed_dim,
|
253 |
+
bias=bias,
|
254 |
+
)
|
255 |
+
|
256 |
+
def forward(self, x, freqs_cis):
|
257 |
+
"""
|
258 |
+
Patchifies and embeds the input tensor(s).
|
259 |
+
|
260 |
+
Args:
|
261 |
+
x (List[torch.Tensor] | torch.Tensor): The input tensor(s) to be patchified and embedded.
|
262 |
+
|
263 |
+
Returns:
|
264 |
+
Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], torch.Tensor]: A tuple containing the patchified
|
265 |
+
and embedded tensor(s), the mask indicating the valid patches, the original image size(s), and the
|
266 |
+
frequency tensor(s).
|
267 |
+
"""
|
268 |
+
freqs_cis = freqs_cis.to(x[0].device)
|
269 |
+
patch_height = patch_width = self.patch_size
|
270 |
+
batch_size, channel, height, width = x.size()
|
271 |
+
height_tokens, width_tokens = height // patch_height, width // patch_width
|
272 |
+
|
273 |
+
x = x.view(batch_size, channel, height_tokens, patch_height, width_tokens, patch_width).permute(
|
274 |
+
0, 2, 4, 1, 3, 5
|
275 |
+
)
|
276 |
+
x = x.flatten(3)
|
277 |
+
x = self.proj(x)
|
278 |
+
x = x.flatten(1, 2)
|
279 |
+
|
280 |
+
mask = torch.ones(x.shape[0], x.shape[1], dtype=torch.int32, device=x.device)
|
281 |
+
|
282 |
+
return (
|
283 |
+
x,
|
284 |
+
mask,
|
285 |
+
[(height, width)] * batch_size,
|
286 |
+
freqs_cis[:height_tokens, :width_tokens].flatten(0, 1).unsqueeze(0),
|
287 |
+
)
|
288 |
+
|
289 |
+
|
290 |
+
def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
|
291 |
+
"""
|
292 |
+
RoPE for image tokens with 2d structure.
|
293 |
+
|
294 |
+
Args:
|
295 |
+
embed_dim: (`int`):
|
296 |
+
The embedding dimension size
|
297 |
+
crops_coords (`Tuple[int]`)
|
298 |
+
The top-left and bottom-right coordinates of the crop.
|
299 |
+
grid_size (`Tuple[int]`):
|
300 |
+
The grid size of the positional embedding.
|
301 |
+
use_real (`bool`):
|
302 |
+
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
|
303 |
+
|
304 |
+
Returns:
|
305 |
+
`torch.Tensor`: positional embedding with shape `(grid_size[0] * grid_size[1], embed_dim/2)` when `use_real` is False, otherwise a `(cos, sin)` tuple of tensors each with shape `(grid_size[0] * grid_size[1], embed_dim)`.
|
306 |
+
"""
|
307 |
+
start, stop = crops_coords
|
308 |
+
grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
|
309 |
+
grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
|
310 |
+
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
311 |
+
grid = np.stack(grid, axis=0) # [2, W, H]
|
312 |
+
|
313 |
+
grid = grid.reshape([2, 1, *grid.shape[1:]])
|
314 |
+
pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
|
315 |
+
return pos_embed
|
316 |
+
|
317 |
+
|
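As a rough illustration of the 2D RoPE helper above (the crop coordinates and grid size are made-up values, and the import path is an assumption):
from flux.embeddings import get_2d_rotary_pos_embed
crops_coords = ((0, 0), (32, 32))                 # top-left and bottom-right of the crop
grid_size = (32, 32)                              # 32x32 token grid
cos, sin = get_2d_rotary_pos_embed(64, crops_coords, grid_size, use_real=True)
print(cos.shape, sin.shape)                       # each (1024, 64)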
318 |
+
def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
|
319 |
+
assert embed_dim % 4 == 0
|
320 |
+
|
321 |
+
# use half of dimensions to encode grid_h
|
322 |
+
emb_h = get_1d_rotary_pos_embed(
|
323 |
+
embed_dim // 2, grid[0].reshape(-1), use_real=use_real
|
324 |
+
) # (H*W, D/2) if use_real else (H*W, D/4)
|
325 |
+
emb_w = get_1d_rotary_pos_embed(
|
326 |
+
embed_dim // 2, grid[1].reshape(-1), use_real=use_real
|
327 |
+
) # (H*W, D/2) if use_real else (H*W, D/4)
|
328 |
+
|
329 |
+
if use_real:
|
330 |
+
cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D)
|
331 |
+
sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D)
|
332 |
+
return cos, sin
|
333 |
+
else:
|
334 |
+
emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2)
|
335 |
+
return emb
|
336 |
+
|
337 |
+
|
338 |
+
def get_2d_rotary_pos_embed_lumina(embed_dim, len_h, len_w, linear_factor=1.0, ntk_factor=1.0):
|
339 |
+
assert embed_dim % 4 == 0
|
340 |
+
|
341 |
+
emb_h = get_1d_rotary_pos_embed(
|
342 |
+
embed_dim // 2, len_h, linear_factor=linear_factor, ntk_factor=ntk_factor
|
343 |
+
) # (H, D/4)
|
344 |
+
emb_w = get_1d_rotary_pos_embed(
|
345 |
+
embed_dim // 2, len_w, linear_factor=linear_factor, ntk_factor=ntk_factor
|
346 |
+
) # (W, D/4)
|
347 |
+
emb_h = emb_h.view(len_h, 1, embed_dim // 4, 1).repeat(1, len_w, 1, 1) # (H, W, D/4, 1)
|
348 |
+
emb_w = emb_w.view(1, len_w, embed_dim // 4, 1).repeat(len_h, 1, 1, 1) # (H, W, D/4, 1)
|
349 |
+
|
350 |
+
emb = torch.cat([emb_h, emb_w], dim=-1).flatten(2) # (H, W, D/2)
|
351 |
+
return emb
|
352 |
+
|
353 |
+
|
354 |
+
def get_1d_rotary_pos_embed(
|
355 |
+
dim: int,
|
356 |
+
pos: Union[np.ndarray, int],
|
357 |
+
theta: float = 10000.0,
|
358 |
+
use_real=False,
|
359 |
+
linear_factor=1.0,
|
360 |
+
ntk_factor=1.0,
|
361 |
+
repeat_interleave_real=True,
|
362 |
+
):
|
363 |
+
"""
|
364 |
+
Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
|
365 |
+
|
366 |
+
This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the position
|
367 |
+
indices 'pos'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64
|
368 |
+
data type.
|
369 |
+
|
370 |
+
Args:
|
371 |
+
dim (`int`): Dimension of the frequency tensor.
|
372 |
+
pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
|
373 |
+
theta (`float`, *optional*, defaults to 10000.0):
|
374 |
+
Scaling factor for frequency computation. Defaults to 10000.0.
|
375 |
+
use_real (`bool`, *optional*):
|
376 |
+
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
|
377 |
+
linear_factor (`float`, *optional*, defaults to 1.0):
|
378 |
+
Scaling factor for the context extrapolation. Defaults to 1.0.
|
379 |
+
ntk_factor (`float`, *optional*, defaults to 1.0):
|
380 |
+
Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
|
381 |
+
repeat_interleave_real (`bool`, *optional*, defaults to `True`):
|
382 |
+
If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
|
383 |
+
Otherwise, they are concatenated with themselves.
|
384 |
+
Returns:
|
385 |
+
`torch.Tensor`: Precomputed frequency tensor with complex exponentials, shaped [S, D/2] (or a `(cos, sin)` tuple of [S, D] tensors when `use_real` is True).
|
386 |
+
"""
|
387 |
+
assert dim % 2 == 0
|
388 |
+
|
389 |
+
if isinstance(pos, int):
|
390 |
+
pos = np.arange(pos)
|
391 |
+
theta = theta * ntk_factor
|
392 |
+
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) / linear_factor # [D/2]
|
393 |
+
t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S]
|
394 |
+
freqs = torch.outer(t, freqs).float() # type: ignore # [S, D/2]
|
395 |
+
if use_real and repeat_interleave_real:
|
396 |
+
freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
|
397 |
+
freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
|
398 |
+
return freqs_cos, freqs_sin
|
399 |
+
elif use_real:
|
400 |
+
freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1) # [S, D]
|
401 |
+
freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1) # [S, D]
|
402 |
+
return freqs_cos, freqs_sin
|
403 |
+
else:
|
404 |
+
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
|
405 |
+
return freqs_cis
|
406 |
+
|
407 |
+
|
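A minimal sketch of both return modes of get_1d_rotary_pos_embed; the dimension and sequence length are arbitrary example values:
from flux.embeddings import get_1d_rotary_pos_embed
freqs_cis = get_1d_rotary_pos_embed(64, 128)                  # complex64 tensor, shape (128, 32)
cos, sin = get_1d_rotary_pos_embed(64, 128, use_real=True)    # real tensors, each (128, 64)
print(freqs_cis.shape, cos.shape, sin.shape)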
408 |
+
def apply_rotary_emb(
|
409 |
+
x: torch.Tensor,
|
410 |
+
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
|
411 |
+
use_real: bool = True,
|
412 |
+
use_real_unbind_dim: int = -1,
|
413 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
414 |
+
"""
|
415 |
+
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
|
416 |
+
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
|
417 |
+
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
|
418 |
+
tensors contain rotary embeddings and are returned as real tensors.
|
419 |
+
|
420 |
+
Args:
|
421 |
+
x (`torch.Tensor`):
|
422 |
+
Query or key tensor to apply rotary embeddings to, with shape [B, H, S, D].
|
423 |
+
freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
|
424 |
+
|
425 |
+
Returns:
|
426 |
+
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
|
427 |
+
"""
|
428 |
+
if use_real:
|
429 |
+
cos, sin = freqs_cis # [S, D]
|
430 |
+
cos = cos[None, None]
|
431 |
+
sin = sin[None, None]
|
432 |
+
cos, sin = cos.to(x.device), sin.to(x.device)
|
433 |
+
|
434 |
+
if use_real_unbind_dim == -1:
|
435 |
+
# Use for example in Lumina
|
436 |
+
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
|
437 |
+
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
438 |
+
elif use_real_unbind_dim == -2:
|
439 |
+
# Use for example in Stable Audio
|
440 |
+
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
|
441 |
+
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
|
442 |
+
else:
|
443 |
+
raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
|
444 |
+
|
445 |
+
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
|
446 |
+
|
447 |
+
return out
|
448 |
+
else:
|
449 |
+
x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
|
450 |
+
freqs_cis = freqs_cis.unsqueeze(2)
|
451 |
+
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
|
452 |
+
|
453 |
+
return x_out.type_as(x)
|
454 |
+
|
455 |
+
|
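A hedged example tying the two RoPE helpers together on a query tensor in the [B, H, S, D] layout described above (shapes are illustrative only):
import torch
from flux.embeddings import apply_rotary_emb, get_1d_rotary_pos_embed
q = torch.randn(1, 8, 128, 64)                                # [B, H, S, D]
freqs_cis = get_1d_rotary_pos_embed(64, 128, use_real=True)   # (cos, sin), each [S, D]
q_rot = apply_rotary_emb(q, freqs_cis)                        # rotated query, same shape as q
print(q_rot.shape)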
456 |
+
class TimestepEmbedding(nn.Module):
|
457 |
+
def __init__(
|
458 |
+
self,
|
459 |
+
in_channels: int,
|
460 |
+
time_embed_dim: int,
|
461 |
+
act_fn: str = "silu",
|
462 |
+
out_dim: int = None,
|
463 |
+
post_act_fn: Optional[str] = None,
|
464 |
+
cond_proj_dim=None,
|
465 |
+
sample_proj_bias=True,
|
466 |
+
):
|
467 |
+
super().__init__()
|
468 |
+
|
469 |
+
self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
|
470 |
+
|
471 |
+
if cond_proj_dim is not None:
|
472 |
+
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
|
473 |
+
else:
|
474 |
+
self.cond_proj = None
|
475 |
+
|
476 |
+
self.act = get_activation(act_fn)
|
477 |
+
|
478 |
+
if out_dim is not None:
|
479 |
+
time_embed_dim_out = out_dim
|
480 |
+
else:
|
481 |
+
time_embed_dim_out = time_embed_dim
|
482 |
+
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)
|
483 |
+
|
484 |
+
if post_act_fn is None:
|
485 |
+
self.post_act = None
|
486 |
+
else:
|
487 |
+
self.post_act = get_activation(post_act_fn)
|
488 |
+
|
489 |
+
def forward(self, sample, condition=None):
|
490 |
+
if condition is not None:
|
491 |
+
sample = sample + self.cond_proj(condition)
|
492 |
+
sample = self.linear_1(sample)
|
493 |
+
|
494 |
+
if self.act is not None:
|
495 |
+
sample = self.act(sample)
|
496 |
+
|
497 |
+
sample = self.linear_2(sample)
|
498 |
+
|
499 |
+
if self.post_act is not None:
|
500 |
+
sample = self.post_act(sample)
|
501 |
+
return sample
|
502 |
+
|
503 |
+
|
504 |
+
class Timesteps(nn.Module):
|
505 |
+
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
|
506 |
+
super().__init__()
|
507 |
+
self.num_channels = num_channels
|
508 |
+
self.flip_sin_to_cos = flip_sin_to_cos
|
509 |
+
self.downscale_freq_shift = downscale_freq_shift
|
510 |
+
self.scale = scale
|
511 |
+
|
512 |
+
def forward(self, timesteps):
|
513 |
+
t_emb = get_timestep_embedding(
|
514 |
+
timesteps,
|
515 |
+
self.num_channels,
|
516 |
+
flip_sin_to_cos=self.flip_sin_to_cos,
|
517 |
+
downscale_freq_shift=self.downscale_freq_shift,
|
518 |
+
scale=self.scale,
|
519 |
+
)
|
520 |
+
return t_emb
|
521 |
+
|
522 |
+
|
523 |
+
class GaussianFourierProjection(nn.Module):
|
524 |
+
"""Gaussian Fourier embeddings for noise levels."""
|
525 |
+
|
526 |
+
def __init__(
|
527 |
+
self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
|
528 |
+
):
|
529 |
+
super().__init__()
|
530 |
+
self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
|
531 |
+
self.log = log
|
532 |
+
self.flip_sin_to_cos = flip_sin_to_cos
|
533 |
+
|
534 |
+
if set_W_to_weight:
|
535 |
+
# to delete later
|
536 |
+
del self.weight
|
537 |
+
self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
|
538 |
+
self.weight = self.W
|
539 |
+
del self.W
|
540 |
+
|
541 |
+
def forward(self, x):
|
542 |
+
if self.log:
|
543 |
+
x = torch.log(x)
|
544 |
+
|
545 |
+
x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
|
546 |
+
|
547 |
+
if self.flip_sin_to_cos:
|
548 |
+
out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1)
|
549 |
+
else:
|
550 |
+
out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
|
551 |
+
return out
|
552 |
+
|
553 |
+
|
554 |
+
class SinusoidalPositionalEmbedding(nn.Module):
|
555 |
+
"""Apply positional information to a sequence of embeddings.
|
556 |
+
|
557 |
+
Takes in a sequence of embeddings with shape (batch_size, seq_length, embed_dim) and adds positional embeddings to
|
558 |
+
them
|
559 |
+
|
560 |
+
Args:
|
561 |
+
embed_dim: (int): Dimension of the positional embedding.
|
562 |
+
max_seq_length: Maximum sequence length to apply positional embeddings
|
563 |
+
|
564 |
+
"""
|
565 |
+
|
566 |
+
def __init__(self, embed_dim: int, max_seq_length: int = 32):
|
567 |
+
super().__init__()
|
568 |
+
position = torch.arange(max_seq_length).unsqueeze(1)
|
569 |
+
div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim))
|
570 |
+
pe = torch.zeros(1, max_seq_length, embed_dim)
|
571 |
+
pe[0, :, 0::2] = torch.sin(position * div_term)
|
572 |
+
pe[0, :, 1::2] = torch.cos(position * div_term)
|
573 |
+
self.register_buffer("pe", pe)
|
574 |
+
|
575 |
+
def forward(self, x):
|
576 |
+
_, seq_length, _ = x.shape
|
577 |
+
x = x + self.pe[:, :seq_length]
|
578 |
+
return x
|
579 |
+
|
580 |
+
|
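A small sketch of the module above, assuming the same import path as the other examples; the batch and sequence sizes are arbitrary:
import torch
from flux.embeddings import SinusoidalPositionalEmbedding
pos_embed = SinusoidalPositionalEmbedding(embed_dim=64, max_seq_length=32)
tokens = torch.randn(2, 16, 64)   # (batch, seq_len, embed_dim) with seq_len <= max_seq_length
tokens = pos_embed(tokens)        # same shape, fixed sinusoidal positions added
print(tokens.shape)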
581 |
+
class ImagePositionalEmbeddings(nn.Module):
|
582 |
+
"""
|
583 |
+
Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
|
584 |
+
height and width of the latent space.
|
585 |
+
|
586 |
+
For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092
|
587 |
+
|
588 |
+
For VQ-diffusion:
|
589 |
+
|
590 |
+
Output vector embeddings are used as input for the transformer.
|
591 |
+
|
592 |
+
Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE.
|
593 |
+
|
594 |
+
Args:
|
595 |
+
num_embed (`int`):
|
596 |
+
Number of embeddings for the latent pixels embeddings.
|
597 |
+
height (`int`):
|
598 |
+
Height of the latent image i.e. the number of height embeddings.
|
599 |
+
width (`int`):
|
600 |
+
Width of the latent image i.e. the number of width embeddings.
|
601 |
+
embed_dim (`int`):
|
602 |
+
Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings.
|
603 |
+
"""
|
604 |
+
|
605 |
+
def __init__(
|
606 |
+
self,
|
607 |
+
num_embed: int,
|
608 |
+
height: int,
|
609 |
+
width: int,
|
610 |
+
embed_dim: int,
|
611 |
+
):
|
612 |
+
super().__init__()
|
613 |
+
|
614 |
+
self.height = height
|
615 |
+
self.width = width
|
616 |
+
self.num_embed = num_embed
|
617 |
+
self.embed_dim = embed_dim
|
618 |
+
|
619 |
+
self.emb = nn.Embedding(self.num_embed, embed_dim)
|
620 |
+
self.height_emb = nn.Embedding(self.height, embed_dim)
|
621 |
+
self.width_emb = nn.Embedding(self.width, embed_dim)
|
622 |
+
|
623 |
+
def forward(self, index):
|
624 |
+
emb = self.emb(index)
|
625 |
+
|
626 |
+
height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height))
|
627 |
+
|
628 |
+
# 1 x H x D -> 1 x H x 1 x D
|
629 |
+
height_emb = height_emb.unsqueeze(2)
|
630 |
+
|
631 |
+
width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width))
|
632 |
+
|
633 |
+
# 1 x W x D -> 1 x 1 x W x D
|
634 |
+
width_emb = width_emb.unsqueeze(1)
|
635 |
+
|
636 |
+
pos_emb = height_emb + width_emb
|
637 |
+
|
638 |
+
# 1 x H x W x D -> 1 x L x D
|
639 |
+
pos_emb = pos_emb.view(1, self.height * self.width, -1)
|
640 |
+
|
641 |
+
emb = emb + pos_emb[:, : emb.shape[1], :]
|
642 |
+
|
643 |
+
return emb
|
644 |
+
|
645 |
+
|
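An illustrative sketch of the VQ-diffusion embedder above; the codebook size and latent grid are made-up values:
import torch
from flux.embeddings import ImagePositionalEmbeddings
embedder = ImagePositionalEmbeddings(num_embed=1024, height=16, width=16, embed_dim=256)
index = torch.randint(0, 1024, (2, 256))   # (batch, height*width) latent pixel classes
emb = embedder(index)                      # (2, 256, 256): class embeddings + 2D positions
print(emb.shape)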
646 |
+
class LabelEmbedding(nn.Module):
|
647 |
+
"""
|
648 |
+
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
|
649 |
+
|
650 |
+
Args:
|
651 |
+
num_classes (`int`): The number of classes.
|
652 |
+
hidden_size (`int`): The size of the vector embeddings.
|
653 |
+
dropout_prob (`float`): The probability of dropping a label.
|
654 |
+
"""
|
655 |
+
|
656 |
+
def __init__(self, num_classes, hidden_size, dropout_prob):
|
657 |
+
super().__init__()
|
658 |
+
use_cfg_embedding = dropout_prob > 0
|
659 |
+
self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
|
660 |
+
self.num_classes = num_classes
|
661 |
+
self.dropout_prob = dropout_prob
|
662 |
+
|
663 |
+
def token_drop(self, labels, force_drop_ids=None):
|
664 |
+
"""
|
665 |
+
Drops labels to enable classifier-free guidance.
|
666 |
+
"""
|
667 |
+
if force_drop_ids is None:
|
668 |
+
drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
|
669 |
+
else:
|
670 |
+
drop_ids = torch.tensor(force_drop_ids == 1)
|
671 |
+
labels = torch.where(drop_ids, self.num_classes, labels)
|
672 |
+
return labels
|
673 |
+
|
674 |
+
def forward(self, labels: torch.LongTensor, force_drop_ids=None):
|
675 |
+
use_dropout = self.dropout_prob > 0
|
676 |
+
if (self.training and use_dropout) or (force_drop_ids is not None):
|
677 |
+
labels = self.token_drop(labels, force_drop_ids)
|
678 |
+
embeddings = self.embedding_table(labels)
|
679 |
+
return embeddings
|
680 |
+
|
681 |
+
|
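A minimal sketch of the label embedder above with 10% classifier-free-guidance dropout; the class count and hidden size are arbitrary:
import torch
from flux.embeddings import LabelEmbedding
label_embed = LabelEmbedding(num_classes=1000, hidden_size=768, dropout_prob=0.1)
labels = torch.randint(0, 1000, (4,))
emb = label_embed(labels)   # (4, 768); in training mode ~10% of labels map to the extra null class
print(emb.shape)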
682 |
+
class TextImageProjection(nn.Module):
|
683 |
+
def __init__(
|
684 |
+
self,
|
685 |
+
text_embed_dim: int = 1024,
|
686 |
+
image_embed_dim: int = 768,
|
687 |
+
cross_attention_dim: int = 768,
|
688 |
+
num_image_text_embeds: int = 10,
|
689 |
+
):
|
690 |
+
super().__init__()
|
691 |
+
|
692 |
+
self.num_image_text_embeds = num_image_text_embeds
|
693 |
+
self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
|
694 |
+
self.text_proj = nn.Linear(text_embed_dim, cross_attention_dim)
|
695 |
+
|
696 |
+
def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
|
697 |
+
batch_size = text_embeds.shape[0]
|
698 |
+
|
699 |
+
# image
|
700 |
+
image_text_embeds = self.image_embeds(image_embeds)
|
701 |
+
image_text_embeds = image_text_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
|
702 |
+
|
703 |
+
# text
|
704 |
+
text_embeds = self.text_proj(text_embeds)
|
705 |
+
|
706 |
+
return torch.cat([image_text_embeds, text_embeds], dim=1)
|
707 |
+
|
708 |
+
|
709 |
+
class ImageProjection(nn.Module):
|
710 |
+
def __init__(
|
711 |
+
self,
|
712 |
+
image_embed_dim: int = 768,
|
713 |
+
cross_attention_dim: int = 768,
|
714 |
+
num_image_text_embeds: int = 32,
|
715 |
+
):
|
716 |
+
super().__init__()
|
717 |
+
|
718 |
+
self.num_image_text_embeds = num_image_text_embeds
|
719 |
+
self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
|
720 |
+
self.norm = nn.LayerNorm(cross_attention_dim)
|
721 |
+
|
722 |
+
def forward(self, image_embeds: torch.Tensor):
|
723 |
+
batch_size = image_embeds.shape[0]
|
724 |
+
|
725 |
+
# image
|
726 |
+
image_embeds = self.image_embeds(image_embeds)
|
727 |
+
image_embeds = image_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
|
728 |
+
image_embeds = self.norm(image_embeds)
|
729 |
+
return image_embeds
|
730 |
+
|
731 |
+
|
732 |
+
class IPAdapterFullImageProjection(nn.Module):
|
733 |
+
def __init__(self, image_embed_dim=1024, cross_attention_dim=1024):
|
734 |
+
super().__init__()
|
735 |
+
from .attention import FeedForward
|
736 |
+
|
737 |
+
self.ff = FeedForward(image_embed_dim, cross_attention_dim, mult=1, activation_fn="gelu")
|
738 |
+
self.norm = nn.LayerNorm(cross_attention_dim)
|
739 |
+
|
740 |
+
def forward(self, image_embeds: torch.Tensor):
|
741 |
+
return self.norm(self.ff(image_embeds))
|
742 |
+
|
743 |
+
|
744 |
+
class IPAdapterFaceIDImageProjection(nn.Module):
|
745 |
+
def __init__(self, image_embed_dim=1024, cross_attention_dim=1024, mult=1, num_tokens=1):
|
746 |
+
super().__init__()
|
747 |
+
from .attention import FeedForward
|
748 |
+
|
749 |
+
self.num_tokens = num_tokens
|
750 |
+
self.cross_attention_dim = cross_attention_dim
|
751 |
+
self.ff = FeedForward(image_embed_dim, cross_attention_dim * num_tokens, mult=mult, activation_fn="gelu")
|
752 |
+
self.norm = nn.LayerNorm(cross_attention_dim)
|
753 |
+
|
754 |
+
def forward(self, image_embeds: torch.Tensor):
|
755 |
+
x = self.ff(image_embeds)
|
756 |
+
x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
|
757 |
+
return self.norm(x)
|
758 |
+
|
759 |
+
|
760 |
+
class CombinedTimestepLabelEmbeddings(nn.Module):
|
761 |
+
def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
|
762 |
+
super().__init__()
|
763 |
+
|
764 |
+
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
|
765 |
+
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
766 |
+
self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob)
|
767 |
+
|
768 |
+
def forward(self, timestep, class_labels, hidden_dtype=None):
|
769 |
+
timesteps_proj = self.time_proj(timestep)
|
770 |
+
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
|
771 |
+
|
772 |
+
class_labels = self.class_embedder(class_labels) # (N, D)
|
773 |
+
|
774 |
+
conditioning = timesteps_emb + class_labels # (N, D)
|
775 |
+
|
776 |
+
return conditioning
|
777 |
+
|
778 |
+
|
779 |
+
class CombinedTimestepTextProjEmbeddings(nn.Module):
|
780 |
+
def __init__(self, embedding_dim, pooled_projection_dim):
|
781 |
+
super().__init__()
|
782 |
+
|
783 |
+
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
784 |
+
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
785 |
+
self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
|
786 |
+
|
787 |
+
def forward(self, timestep, pooled_projection):
|
788 |
+
timesteps_proj = self.time_proj(timestep)
|
789 |
+
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
|
790 |
+
|
791 |
+
pooled_projections = self.text_embedder(pooled_projection)
|
792 |
+
|
793 |
+
conditioning = timesteps_emb + pooled_projections
|
794 |
+
|
795 |
+
return conditioning
|
796 |
+
|
797 |
+
|
798 |
+
class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
|
799 |
+
def __init__(self, embedding_dim, pooled_projection_dim):
|
800 |
+
super().__init__()
|
801 |
+
|
802 |
+
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
803 |
+
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
804 |
+
self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
805 |
+
self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
|
806 |
+
|
807 |
+
def forward(self, timestep, guidance, pooled_projection):
|
808 |
+
timesteps_proj = self.time_proj(timestep)
|
809 |
+
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
|
810 |
+
|
811 |
+
guidance_proj = self.time_proj(guidance)
|
812 |
+
guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype)) # (N, D)
|
813 |
+
|
814 |
+
time_guidance_emb = timesteps_emb + guidance_emb
|
815 |
+
|
816 |
+
pooled_projections = self.text_embedder(pooled_projection)
|
817 |
+
conditioning = time_guidance_emb + pooled_projections
|
818 |
+
|
819 |
+
return conditioning
|
820 |
+
|
821 |
+
|
822 |
+
class HunyuanDiTAttentionPool(nn.Module):
|
823 |
+
# Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6
|
824 |
+
|
825 |
+
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
|
826 |
+
super().__init__()
|
827 |
+
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim**0.5)
|
828 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim)
|
829 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim)
|
830 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim)
|
831 |
+
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
|
832 |
+
self.num_heads = num_heads
|
833 |
+
|
834 |
+
def forward(self, x):
|
835 |
+
x = x.permute(1, 0, 2) # NLC -> LNC
|
836 |
+
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC
|
837 |
+
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (L+1)NC
|
838 |
+
x, _ = F.multi_head_attention_forward(
|
839 |
+
query=x[:1],
|
840 |
+
key=x,
|
841 |
+
value=x,
|
842 |
+
embed_dim_to_check=x.shape[-1],
|
843 |
+
num_heads=self.num_heads,
|
844 |
+
q_proj_weight=self.q_proj.weight,
|
845 |
+
k_proj_weight=self.k_proj.weight,
|
846 |
+
v_proj_weight=self.v_proj.weight,
|
847 |
+
in_proj_weight=None,
|
848 |
+
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
|
849 |
+
bias_k=None,
|
850 |
+
bias_v=None,
|
851 |
+
add_zero_attn=False,
|
852 |
+
dropout_p=0,
|
853 |
+
out_proj_weight=self.c_proj.weight,
|
854 |
+
out_proj_bias=self.c_proj.bias,
|
855 |
+
use_separate_proj_weight=True,
|
856 |
+
training=self.training,
|
857 |
+
need_weights=False,
|
858 |
+
)
|
859 |
+
return x.squeeze(0)
|
860 |
+
|
861 |
+
|
862 |
+
class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module):
|
863 |
+
def __init__(
|
864 |
+
self,
|
865 |
+
embedding_dim,
|
866 |
+
pooled_projection_dim=1024,
|
867 |
+
seq_len=256,
|
868 |
+
cross_attention_dim=2048,
|
869 |
+
use_style_cond_and_image_meta_size=True,
|
870 |
+
):
|
871 |
+
super().__init__()
|
872 |
+
|
873 |
+
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
874 |
+
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
875 |
+
|
876 |
+
self.size_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
877 |
+
|
878 |
+
self.pooler = HunyuanDiTAttentionPool(
|
879 |
+
seq_len, cross_attention_dim, num_heads=8, output_dim=pooled_projection_dim
|
880 |
+
)
|
881 |
+
|
882 |
+
# Here we use a default learned embedder layer for future extension.
|
883 |
+
self.use_style_cond_and_image_meta_size = use_style_cond_and_image_meta_size
|
884 |
+
if use_style_cond_and_image_meta_size:
|
885 |
+
self.style_embedder = nn.Embedding(1, embedding_dim)
|
886 |
+
extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim
|
887 |
+
else:
|
888 |
+
extra_in_dim = pooled_projection_dim
|
889 |
+
|
890 |
+
self.extra_embedder = PixArtAlphaTextProjection(
|
891 |
+
in_features=extra_in_dim,
|
892 |
+
hidden_size=embedding_dim * 4,
|
893 |
+
out_features=embedding_dim,
|
894 |
+
act_fn="silu_fp32",
|
895 |
+
)
|
896 |
+
|
897 |
+
def forward(self, timestep, encoder_hidden_states, image_meta_size, style, hidden_dtype=None):
|
898 |
+
timesteps_proj = self.time_proj(timestep)
|
899 |
+
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, 256)
|
900 |
+
|
901 |
+
# extra condition1: text
|
902 |
+
pooled_projections = self.pooler(encoder_hidden_states) # (N, 1024)
|
903 |
+
|
904 |
+
if self.use_style_cond_and_image_meta_size:
|
905 |
+
# extra condition2: image meta size embedding
|
906 |
+
image_meta_size = self.size_proj(image_meta_size.view(-1))
|
907 |
+
image_meta_size = image_meta_size.to(dtype=hidden_dtype)
|
908 |
+
image_meta_size = image_meta_size.view(-1, 6 * 256) # (N, 1536)
|
909 |
+
|
910 |
+
# extra condition3: style embedding
|
911 |
+
style_embedding = self.style_embedder(style) # (N, embedding_dim)
|
912 |
+
|
913 |
+
# Concatenate all extra vectors
|
914 |
+
extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1)
|
915 |
+
else:
|
916 |
+
extra_cond = torch.cat([pooled_projections], dim=1)
|
917 |
+
|
918 |
+
conditioning = timesteps_emb + self.extra_embedder(extra_cond) # [B, D]
|
919 |
+
|
920 |
+
return conditioning
|
921 |
+
|
922 |
+
|
923 |
+
class LuminaCombinedTimestepCaptionEmbedding(nn.Module):
|
924 |
+
def __init__(self, hidden_size=4096, cross_attention_dim=2048, frequency_embedding_size=256):
|
925 |
+
super().__init__()
|
926 |
+
self.time_proj = Timesteps(
|
927 |
+
num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0
|
928 |
+
)
|
929 |
+
|
930 |
+
self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size)
|
931 |
+
|
932 |
+
self.caption_embedder = nn.Sequential(
|
933 |
+
nn.LayerNorm(cross_attention_dim),
|
934 |
+
nn.Linear(
|
935 |
+
cross_attention_dim,
|
936 |
+
hidden_size,
|
937 |
+
bias=True,
|
938 |
+
),
|
939 |
+
)
|
940 |
+
|
941 |
+
def forward(self, timestep, caption_feat, caption_mask):
|
942 |
+
# timestep embedding:
|
943 |
+
time_freq = self.time_proj(timestep)
|
944 |
+
time_embed = self.timestep_embedder(time_freq.to(dtype=self.timestep_embedder.linear_1.weight.dtype))
|
945 |
+
|
946 |
+
# caption condition embedding:
|
947 |
+
caption_mask_float = caption_mask.float().unsqueeze(-1)
|
948 |
+
caption_feats_pool = (caption_feat * caption_mask_float).sum(dim=1) / caption_mask_float.sum(dim=1)
|
949 |
+
caption_feats_pool = caption_feats_pool.to(caption_feat)
|
950 |
+
caption_embed = self.caption_embedder(caption_feats_pool)
|
951 |
+
|
952 |
+
conditioning = time_embed + caption_embed
|
953 |
+
|
954 |
+
return conditioning
|
955 |
+
|
956 |
+
|
957 |
+
class TextTimeEmbedding(nn.Module):
|
958 |
+
def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
|
959 |
+
super().__init__()
|
960 |
+
self.norm1 = nn.LayerNorm(encoder_dim)
|
961 |
+
self.pool = AttentionPooling(num_heads, encoder_dim)
|
962 |
+
self.proj = nn.Linear(encoder_dim, time_embed_dim)
|
963 |
+
self.norm2 = nn.LayerNorm(time_embed_dim)
|
964 |
+
|
965 |
+
def forward(self, hidden_states):
|
966 |
+
hidden_states = self.norm1(hidden_states)
|
967 |
+
hidden_states = self.pool(hidden_states)
|
968 |
+
hidden_states = self.proj(hidden_states)
|
969 |
+
hidden_states = self.norm2(hidden_states)
|
970 |
+
return hidden_states
|
971 |
+
|
972 |
+
|
973 |
+
class TextImageTimeEmbedding(nn.Module):
|
974 |
+
def __init__(self, text_embed_dim: int = 768, image_embed_dim: int = 768, time_embed_dim: int = 1536):
|
975 |
+
super().__init__()
|
976 |
+
self.text_proj = nn.Linear(text_embed_dim, time_embed_dim)
|
977 |
+
self.text_norm = nn.LayerNorm(time_embed_dim)
|
978 |
+
self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
|
979 |
+
|
980 |
+
def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
|
981 |
+
# text
|
982 |
+
time_text_embeds = self.text_proj(text_embeds)
|
983 |
+
time_text_embeds = self.text_norm(time_text_embeds)
|
984 |
+
|
985 |
+
# image
|
986 |
+
time_image_embeds = self.image_proj(image_embeds)
|
987 |
+
|
988 |
+
return time_image_embeds + time_text_embeds
|
989 |
+
|
990 |
+
|
991 |
+
class ImageTimeEmbedding(nn.Module):
|
992 |
+
def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
|
993 |
+
super().__init__()
|
994 |
+
self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
|
995 |
+
self.image_norm = nn.LayerNorm(time_embed_dim)
|
996 |
+
|
997 |
+
def forward(self, image_embeds: torch.Tensor):
|
998 |
+
# image
|
999 |
+
time_image_embeds = self.image_proj(image_embeds)
|
1000 |
+
time_image_embeds = self.image_norm(time_image_embeds)
|
1001 |
+
return time_image_embeds
|
1002 |
+
|
1003 |
+
|
1004 |
+
class ImageHintTimeEmbedding(nn.Module):
|
1005 |
+
def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
|
1006 |
+
super().__init__()
|
1007 |
+
self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
|
1008 |
+
self.image_norm = nn.LayerNorm(time_embed_dim)
|
1009 |
+
self.input_hint_block = nn.Sequential(
|
1010 |
+
nn.Conv2d(3, 16, 3, padding=1),
|
1011 |
+
nn.SiLU(),
|
1012 |
+
nn.Conv2d(16, 16, 3, padding=1),
|
1013 |
+
nn.SiLU(),
|
1014 |
+
nn.Conv2d(16, 32, 3, padding=1, stride=2),
|
1015 |
+
nn.SiLU(),
|
1016 |
+
nn.Conv2d(32, 32, 3, padding=1),
|
1017 |
+
nn.SiLU(),
|
1018 |
+
nn.Conv2d(32, 96, 3, padding=1, stride=2),
|
1019 |
+
nn.SiLU(),
|
1020 |
+
nn.Conv2d(96, 96, 3, padding=1),
|
1021 |
+
nn.SiLU(),
|
1022 |
+
nn.Conv2d(96, 256, 3, padding=1, stride=2),
|
1023 |
+
nn.SiLU(),
|
1024 |
+
nn.Conv2d(256, 4, 3, padding=1),
|
1025 |
+
)
|
1026 |
+
|
1027 |
+
def forward(self, image_embeds: torch.Tensor, hint: torch.Tensor):
|
1028 |
+
# image
|
1029 |
+
time_image_embeds = self.image_proj(image_embeds)
|
1030 |
+
time_image_embeds = self.image_norm(time_image_embeds)
|
1031 |
+
hint = self.input_hint_block(hint)
|
1032 |
+
return time_image_embeds, hint
|
1033 |
+
|
1034 |
+
|
1035 |
+
class AttentionPooling(nn.Module):
|
1036 |
+
# Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54
|
1037 |
+
|
1038 |
+
def __init__(self, num_heads, embed_dim, dtype=None):
|
1039 |
+
super().__init__()
|
1040 |
+
self.dtype = dtype
|
1041 |
+
self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim**0.5)
|
1042 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
|
1043 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
|
1044 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
|
1045 |
+
self.num_heads = num_heads
|
1046 |
+
self.dim_per_head = embed_dim // self.num_heads
|
1047 |
+
|
1048 |
+
def forward(self, x):
|
1049 |
+
bs, length, width = x.size()
|
1050 |
+
|
1051 |
+
def shape(x):
|
1052 |
+
# (bs, length, width) --> (bs, length, n_heads, dim_per_head)
|
1053 |
+
x = x.view(bs, -1, self.num_heads, self.dim_per_head)
|
1054 |
+
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
|
1055 |
+
x = x.transpose(1, 2)
|
1056 |
+
# (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
|
1057 |
+
x = x.reshape(bs * self.num_heads, -1, self.dim_per_head)
|
1058 |
+
# (bs*n_heads, length, dim_per_head) --> (bs*n_heads, dim_per_head, length)
|
1059 |
+
x = x.transpose(1, 2)
|
1060 |
+
return x
|
1061 |
+
|
1062 |
+
class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype)
|
1063 |
+
x = torch.cat([class_token, x], dim=1) # (bs, length+1, width)
|
1064 |
+
|
1065 |
+
# (bs*n_heads, class_token_length, dim_per_head)
|
1066 |
+
q = shape(self.q_proj(class_token))
|
1067 |
+
# (bs*n_heads, length+class_token_length, dim_per_head)
|
1068 |
+
k = shape(self.k_proj(x))
|
1069 |
+
v = shape(self.v_proj(x))
|
1070 |
+
|
1071 |
+
# (bs*n_heads, class_token_length, length+class_token_length):
|
1072 |
+
scale = 1 / math.sqrt(math.sqrt(self.dim_per_head))
|
1073 |
+
weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards
|
1074 |
+
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
1075 |
+
|
1076 |
+
# (bs*n_heads, dim_per_head, class_token_length)
|
1077 |
+
a = torch.einsum("bts,bcs->bct", weight, v)
|
1078 |
+
|
1079 |
+
# (bs, length+1, width)
|
1080 |
+
a = a.reshape(bs, -1, 1).transpose(1, 2)
|
1081 |
+
|
1082 |
+
return a[:, 0, :] # cls_token
|
1083 |
+
|
1084 |
+
|
1085 |
+
def get_fourier_embeds_from_boundingbox(embed_dim, box):
|
1086 |
+
"""
|
1087 |
+
Args:
|
1088 |
+
embed_dim: int
|
1089 |
+
box: a 3-D tensor [B x N x 4] representing the bounding boxes for GLIGEN pipeline
|
1090 |
+
Returns:
|
1091 |
+
[B x N x embed_dim*2*4] tensor of positional embeddings (sin/cos pairs for each of the four xyxy coordinates)
|
1092 |
+
"""
|
1093 |
+
|
1094 |
+
batch_size, num_boxes = box.shape[:2]
|
1095 |
+
|
1096 |
+
emb = 100 ** (torch.arange(embed_dim) / embed_dim)
|
1097 |
+
emb = emb[None, None, None].to(device=box.device, dtype=box.dtype)
|
1098 |
+
emb = emb * box.unsqueeze(-1)
|
1099 |
+
|
1100 |
+
emb = torch.stack((emb.sin(), emb.cos()), dim=-1)
|
1101 |
+
emb = emb.permute(0, 1, 3, 4, 2).reshape(batch_size, num_boxes, embed_dim * 2 * 4)
|
1102 |
+
|
1103 |
+
return emb
|
1104 |
+
|
1105 |
+
|
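A hedged example of the bounding-box Fourier features above; note the output width is embed_dim * 2 * 4 (sin/cos times the four xyxy coordinates), and the box values are arbitrary:
import torch
from flux.embeddings import get_fourier_embeds_from_boundingbox
boxes = torch.tensor([[[0.1, 0.1, 0.5, 0.5], [0.2, 0.3, 0.9, 0.8]]])   # (B=1, N=2, 4) in xyxy
emb = get_fourier_embeds_from_boundingbox(8, boxes)                    # (1, 2, 64) = 8 * 2 * 4
print(emb.shape)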
1106 |
+
class GLIGENTextBoundingboxProjection(nn.Module):
|
1107 |
+
def __init__(self, positive_len, out_dim, feature_type="text-only", fourier_freqs=8):
|
1108 |
+
super().__init__()
|
1109 |
+
self.positive_len = positive_len
|
1110 |
+
self.out_dim = out_dim
|
1111 |
+
|
1112 |
+
self.fourier_embedder_dim = fourier_freqs
|
1113 |
+
self.position_dim = fourier_freqs * 2 * 4 # 2: sin/cos, 4: xyxy
|
1114 |
+
|
1115 |
+
if isinstance(out_dim, tuple):
|
1116 |
+
out_dim = out_dim[0]
|
1117 |
+
|
1118 |
+
if feature_type == "text-only":
|
1119 |
+
self.linears = nn.Sequential(
|
1120 |
+
nn.Linear(self.positive_len + self.position_dim, 512),
|
1121 |
+
nn.SiLU(),
|
1122 |
+
nn.Linear(512, 512),
|
1123 |
+
nn.SiLU(),
|
1124 |
+
nn.Linear(512, out_dim),
|
1125 |
+
)
|
1126 |
+
self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
|
1127 |
+
|
1128 |
+
elif feature_type == "text-image":
|
1129 |
+
self.linears_text = nn.Sequential(
|
1130 |
+
nn.Linear(self.positive_len + self.position_dim, 512),
|
1131 |
+
nn.SiLU(),
|
1132 |
+
nn.Linear(512, 512),
|
1133 |
+
nn.SiLU(),
|
1134 |
+
nn.Linear(512, out_dim),
|
1135 |
+
)
|
1136 |
+
self.linears_image = nn.Sequential(
|
1137 |
+
nn.Linear(self.positive_len + self.position_dim, 512),
|
1138 |
+
nn.SiLU(),
|
1139 |
+
nn.Linear(512, 512),
|
1140 |
+
nn.SiLU(),
|
1141 |
+
nn.Linear(512, out_dim),
|
1142 |
+
)
|
1143 |
+
self.null_text_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
|
1144 |
+
self.null_image_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
|
1145 |
+
|
1146 |
+
self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))
|
1147 |
+
|
1148 |
+
def forward(
|
1149 |
+
self,
|
1150 |
+
boxes,
|
1151 |
+
masks,
|
1152 |
+
positive_embeddings=None,
|
1153 |
+
phrases_masks=None,
|
1154 |
+
image_masks=None,
|
1155 |
+
phrases_embeddings=None,
|
1156 |
+
image_embeddings=None,
|
1157 |
+
):
|
1158 |
+
masks = masks.unsqueeze(-1)
|
1159 |
+
|
1160 |
+
# embed box positions (the input may include padding as a placeholder)
|
1161 |
+
xyxy_embedding = get_fourier_embeds_from_boundingbox(self.fourier_embedder_dim, boxes) # B*N*4 -> B*N*C
|
1162 |
+
|
1163 |
+
# learnable null embedding
|
1164 |
+
xyxy_null = self.null_position_feature.view(1, 1, -1)
|
1165 |
+
|
1166 |
+
# replace padding with learnable null embedding
|
1167 |
+
xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null
|
1168 |
+
|
1169 |
+
# positionet with text only information
|
1170 |
+
if positive_embeddings is not None:
|
1171 |
+
# learnable null embedding
|
1172 |
+
positive_null = self.null_positive_feature.view(1, 1, -1)
|
1173 |
+
|
1174 |
+
# replace padding with learnable null embedding
|
1175 |
+
positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null
|
1176 |
+
|
1177 |
+
objs = self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1))
|
1178 |
+
|
1179 |
+
# positionet with text and image information
|
1180 |
+
else:
|
1181 |
+
phrases_masks = phrases_masks.unsqueeze(-1)
|
1182 |
+
image_masks = image_masks.unsqueeze(-1)
|
1183 |
+
|
1184 |
+
# learnable null embedding
|
1185 |
+
text_null = self.null_text_feature.view(1, 1, -1)
|
1186 |
+
image_null = self.null_image_feature.view(1, 1, -1)
|
1187 |
+
|
1188 |
+
# replace padding with learnable null embedding
|
1189 |
+
phrases_embeddings = phrases_embeddings * phrases_masks + (1 - phrases_masks) * text_null
|
1190 |
+
image_embeddings = image_embeddings * image_masks + (1 - image_masks) * image_null
|
1191 |
+
|
1192 |
+
objs_text = self.linears_text(torch.cat([phrases_embeddings, xyxy_embedding], dim=-1))
|
1193 |
+
objs_image = self.linears_image(torch.cat([image_embeddings, xyxy_embedding], dim=-1))
|
1194 |
+
objs = torch.cat([objs_text, objs_image], dim=1)
|
1195 |
+
|
1196 |
+
return objs
|
1197 |
+
|
1198 |
+
|
1199 |
+
class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
|
1200 |
+
"""
|
1201 |
+
For PixArt-Alpha.
|
1202 |
+
|
1203 |
+
Reference:
|
1204 |
+
https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
|
1205 |
+
"""
|
1206 |
+
|
1207 |
+
def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False):
|
1208 |
+
super().__init__()
|
1209 |
+
|
1210 |
+
self.outdim = size_emb_dim
|
1211 |
+
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
1212 |
+
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
1213 |
+
|
1214 |
+
self.use_additional_conditions = use_additional_conditions
|
1215 |
+
if use_additional_conditions:
|
1216 |
+
self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
1217 |
+
self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
|
1218 |
+
self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
|
1219 |
+
|
1220 |
+
def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
|
1221 |
+
timesteps_proj = self.time_proj(timestep)
|
1222 |
+
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
|
1223 |
+
|
1224 |
+
if self.use_additional_conditions:
|
1225 |
+
resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype)
|
1226 |
+
resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1)
|
1227 |
+
aspect_ratio_emb = self.additional_condition_proj(aspect_ratio.flatten()).to(hidden_dtype)
|
1228 |
+
aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio_emb).reshape(batch_size, -1)
|
1229 |
+
conditioning = timesteps_emb + torch.cat([resolution_emb, aspect_ratio_emb], dim=1)
|
1230 |
+
else:
|
1231 |
+
conditioning = timesteps_emb
|
1232 |
+
|
1233 |
+
return conditioning
|
1234 |
+
|
1235 |
+
|
1236 |
+
class PixArtAlphaTextProjection(nn.Module):
|
1237 |
+
"""
|
1238 |
+
Projects caption embeddings. Also handles dropout for classifier-free guidance.
|
1239 |
+
|
1240 |
+
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
|
1241 |
+
"""
|
1242 |
+
|
1243 |
+
def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"):
|
1244 |
+
super().__init__()
|
1245 |
+
if out_features is None:
|
1246 |
+
out_features = hidden_size
|
1247 |
+
self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
|
1248 |
+
if act_fn == "gelu_tanh":
|
1249 |
+
self.act_1 = nn.GELU(approximate="tanh")
|
1250 |
+
elif act_fn == "silu":
|
1251 |
+
self.act_1 = nn.SiLU()
|
1252 |
+
elif act_fn == "silu_fp32":
|
1253 |
+
self.act_1 = FP32SiLU()
|
1254 |
+
else:
|
1255 |
+
raise ValueError(f"Unknown activation function: {act_fn}")
|
1256 |
+
self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True)
|
1257 |
+
|
1258 |
+
def forward(self, caption):
|
1259 |
+
hidden_states = self.linear_1(caption)
|
1260 |
+
hidden_states = self.act_1(hidden_states)
|
1261 |
+
hidden_states = self.linear_2(hidden_states)
|
1262 |
+
return hidden_states
|
1263 |
+
|
1264 |
+
|
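A short sketch of the caption projection above, as it is used by the combined timestep/text embedders earlier in this file; the widths are illustrative assumptions:
import torch
from flux.embeddings import PixArtAlphaTextProjection
proj = PixArtAlphaTextProjection(in_features=2048, hidden_size=3072, act_fn="silu")
pooled = torch.randn(2, 2048)   # pooled text features
cond = proj(pooled)             # (2, 3072)
print(cond.shape)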
1265 |
+
class IPAdapterPlusImageProjectionBlock(nn.Module):
|
1266 |
+
def __init__(
|
1267 |
+
self,
|
1268 |
+
embed_dims: int = 768,
|
1269 |
+
dim_head: int = 64,
|
1270 |
+
heads: int = 16,
|
1271 |
+
ffn_ratio: float = 4,
|
1272 |
+
) -> None:
|
1273 |
+
super().__init__()
|
1274 |
+
from .attention import FeedForward
|
1275 |
+
|
1276 |
+
self.ln0 = nn.LayerNorm(embed_dims)
|
1277 |
+
self.ln1 = nn.LayerNorm(embed_dims)
|
1278 |
+
self.attn = Attention(
|
1279 |
+
query_dim=embed_dims,
|
1280 |
+
dim_head=dim_head,
|
1281 |
+
heads=heads,
|
1282 |
+
out_bias=False,
|
1283 |
+
)
|
1284 |
+
self.ff = nn.Sequential(
|
1285 |
+
nn.LayerNorm(embed_dims),
|
1286 |
+
FeedForward(embed_dims, embed_dims, activation_fn="gelu", mult=ffn_ratio, bias=False),
|
1287 |
+
)
|
1288 |
+
|
1289 |
+
def forward(self, x, latents, residual):
|
1290 |
+
encoder_hidden_states = self.ln0(x)
|
1291 |
+
latents = self.ln1(latents)
|
1292 |
+
encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2)
|
1293 |
+
latents = self.attn(latents, encoder_hidden_states) + residual
|
1294 |
+
latents = self.ff(latents) + latents
|
1295 |
+
return latents
|
1296 |
+
|
1297 |
+
|
1298 |
+
class IPAdapterPlusImageProjection(nn.Module):
|
1299 |
+
"""Resampler of IP-Adapter Plus.
|
1300 |
+
|
1301 |
+
Args:
|
1302 |
+
embed_dims (int): The feature dimension. Defaults to 768. output_dims (int): The number of output channels,
|
1303 |
+
that is the same
|
1304 |
+
number of the channels in the `unet.config.cross_attention_dim`. Defaults to 1024.
|
1305 |
+
hidden_dims (int):
|
1306 |
+
The number of hidden channels. Defaults to 1280. depth (int): The number of blocks. Defaults
|
1307 |
+
to 8. dim_head (int): The number of head channels. Defaults to 64. heads (int): Parallel attention heads.
|
1308 |
+
Defaults to 16. num_queries (int):
|
1309 |
+
The number of queries. Defaults to 8. ffn_ratio (float): The expansion ratio
|
1310 |
+
of feedforward network hidden
|
1311 |
+
layer channels. Defaults to 4.
|
1312 |
+
"""
|
1313 |
+
|
1314 |
+
def __init__(
|
1315 |
+
self,
|
1316 |
+
embed_dims: int = 768,
|
1317 |
+
output_dims: int = 1024,
|
1318 |
+
hidden_dims: int = 1280,
|
1319 |
+
depth: int = 4,
|
1320 |
+
dim_head: int = 64,
|
1321 |
+
heads: int = 16,
|
1322 |
+
num_queries: int = 8,
|
1323 |
+
ffn_ratio: float = 4,
|
1324 |
+
) -> None:
|
1325 |
+
super().__init__()
|
1326 |
+
self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dims) / hidden_dims**0.5)
|
1327 |
+
|
1328 |
+
self.proj_in = nn.Linear(embed_dims, hidden_dims)
|
1329 |
+
|
1330 |
+
self.proj_out = nn.Linear(hidden_dims, output_dims)
|
1331 |
+
self.norm_out = nn.LayerNorm(output_dims)
|
1332 |
+
|
1333 |
+
self.layers = nn.ModuleList(
|
1334 |
+
[IPAdapterPlusImageProjectionBlock(hidden_dims, dim_head, heads, ffn_ratio) for _ in range(depth)]
|
1335 |
+
)
|
1336 |
+
|
1337 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
1338 |
+
"""Forward pass.
|
1339 |
+
|
1340 |
+
Args:
|
1341 |
+
x (torch.Tensor): Input Tensor.
|
1342 |
+
Returns:
|
1343 |
+
torch.Tensor: Output Tensor.
|
1344 |
+
"""
|
1345 |
+
latents = self.latents.repeat(x.size(0), 1, 1)
|
1346 |
+
|
1347 |
+
x = self.proj_in(x)
|
1348 |
+
|
1349 |
+
for block in self.layers:
|
1350 |
+
residual = latents
|
1351 |
+
latents = block(x, latents, residual)
|
1352 |
+
|
1353 |
+
latents = self.proj_out(latents)
|
1354 |
+
return self.norm_out(latents)
|
1355 |
+
|
1356 |
+
|
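An illustrative run of the resampler above; the 257-token input mimics CLIP patch features, but the exact shapes are assumptions for this sketch:
import torch
from flux.embeddings import IPAdapterPlusImageProjection
resampler = IPAdapterPlusImageProjection(embed_dims=768, output_dims=1024, hidden_dims=1280, num_queries=8)
clip_tokens = torch.randn(2, 257, 768)   # (batch, tokens, embed_dims)
ip_tokens = resampler(clip_tokens)       # (2, 8, 1024): num_queries tokens at output_dims width
print(ip_tokens.shape)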
1357 |
+
class IPAdapterFaceIDPlusImageProjection(nn.Module):
|
1358 |
+
"""FacePerceiverResampler of IP-Adapter Plus.
|
1359 |
+
|
1360 |
+
Args:
|
1361 |
+
embed_dims (int): The feature dimension. Defaults to 768. output_dims (int): The number of output channels,
|
1362 |
+
that is the same
|
1363 |
+
number of the channels in the `unet.config.cross_attention_dim`. Defaults to 1024.
|
1364 |
+
hidden_dims (int):
|
1365 |
+
The number of hidden channels. Defaults to 1280. depth (int): The number of blocks. Defaults
|
1366 |
+
to 8. dim_head (int): The number of head channels. Defaults to 64. heads (int): Parallel attention heads.
|
1367 |
+
Defaults to 16. num_tokens (int): Number of tokens num_queries (int): The number of queries. Defaults to 8.
|
1368 |
+
ffn_ratio (float): The expansion ratio of feedforward network hidden
|
1369 |
+
layer channels. Defaults to 4.
|
1370 |
+
ffproj_ratio (float): The expansion ratio of feedforward network hidden
|
1371 |
+
layer channels (for ID embeddings). Defaults to 4.
|
1372 |
+
"""
|
1373 |
+
|
1374 |
+
def __init__(
|
1375 |
+
self,
|
1376 |
+
embed_dims: int = 768,
|
1377 |
+
output_dims: int = 768,
|
1378 |
+
hidden_dims: int = 1280,
|
1379 |
+
id_embeddings_dim: int = 512,
|
1380 |
+
depth: int = 4,
|
1381 |
+
dim_head: int = 64,
|
1382 |
+
heads: int = 16,
|
1383 |
+
num_tokens: int = 4,
|
1384 |
+
num_queries: int = 8,
|
1385 |
+
ffn_ratio: float = 4,
|
1386 |
+
ffproj_ratio: int = 2,
|
1387 |
+
) -> None:
|
1388 |
+
super().__init__()
|
1389 |
+
from .attention import FeedForward
|
1390 |
+
|
1391 |
+
self.num_tokens = num_tokens
|
1392 |
+
self.embed_dim = embed_dims
|
1393 |
+
self.clip_embeds = None
|
1394 |
+
self.shortcut = False
|
1395 |
+
self.shortcut_scale = 1.0
|
1396 |
+
|
1397 |
+
self.proj = FeedForward(id_embeddings_dim, embed_dims * num_tokens, activation_fn="gelu", mult=ffproj_ratio)
|
1398 |
+
self.norm = nn.LayerNorm(embed_dims)
|
1399 |
+
|
1400 |
+
self.proj_in = nn.Linear(hidden_dims, embed_dims)
|
1401 |
+
|
1402 |
+
self.proj_out = nn.Linear(embed_dims, output_dims)
|
1403 |
+
self.norm_out = nn.LayerNorm(output_dims)
|
1404 |
+
|
1405 |
+
self.layers = nn.ModuleList(
|
1406 |
+
[IPAdapterPlusImageProjectionBlock(embed_dims, dim_head, heads, ffn_ratio) for _ in range(depth)]
|
1407 |
+
)
|
1408 |
+
|
1409 |
+
def forward(self, id_embeds: torch.Tensor) -> torch.Tensor:
|
1410 |
+
"""Forward pass.
|
1411 |
+
|
1412 |
+
Args:
|
1413 |
+
id_embeds (torch.Tensor): Input Tensor (ID embeds).
|
1414 |
+
Returns:
|
1415 |
+
torch.Tensor: Output Tensor.
|
1416 |
+
"""
|
1417 |
+
id_embeds = id_embeds.to(self.clip_embeds.dtype)
|
1418 |
+
id_embeds = self.proj(id_embeds)
|
1419 |
+
id_embeds = id_embeds.reshape(-1, self.num_tokens, self.embed_dim)
|
1420 |
+
id_embeds = self.norm(id_embeds)
|
1421 |
+
latents = id_embeds
|
1422 |
+
|
1423 |
+
clip_embeds = self.proj_in(self.clip_embeds)
|
1424 |
+
x = clip_embeds.reshape(-1, clip_embeds.shape[2], clip_embeds.shape[3])
|
1425 |
+
|
1426 |
+
for block in self.layers:
|
1427 |
+
residual = latents
|
1428 |
+
latents = block(x, latents, residual)
|
1429 |
+
|
1430 |
+
latents = self.proj_out(latents)
|
1431 |
+
out = self.norm_out(latents)
|
1432 |
+
if self.shortcut:
|
1433 |
+
out = id_embeds + self.shortcut_scale * out
|
1434 |
+
return out
|
1435 |
+
|
1436 |
+
|
1437 |
+
class MultiIPAdapterImageProjection(nn.Module):
|
1438 |
+
def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]):
|
1439 |
+
super().__init__()
|
1440 |
+
self.image_projection_layers = nn.ModuleList(IPAdapterImageProjectionLayers)
|
1441 |
+
|
1442 |
+
def forward(self, image_embeds: List[torch.Tensor]):
|
1443 |
+
projected_image_embeds = []
|
1444 |
+
|
1445 |
+
# currently, we accept `image_embeds` as
|
1446 |
+
# 1. a tensor (deprecated) with shape [batch_size, embed_dim] or [batch_size, sequence_length, embed_dim]
|
1447 |
+
# 2. a list of `n` tensors where `n` is the number of ip-adapters, each tensor can have shape [batch_size, num_images, embed_dim] or [batch_size, num_images, sequence_length, embed_dim]
|
1448 |
+
if not isinstance(image_embeds, list):
|
1449 |
+
deprecation_message = (
|
1450 |
+
"You have passed a tensor as `image_embeds`.This is deprecated and will be removed in a future release."
|
1451 |
+
" Please make sure to update your script to pass `image_embeds` as a list of tensors to suppress this warning."
|
1452 |
+
)
|
1453 |
+
deprecate("image_embeds not a list", "1.0.0", deprecation_message, standard_warn=False)
|
1454 |
+
image_embeds = [image_embeds.unsqueeze(1)]
|
1455 |
+
|
1456 |
+
if len(image_embeds) != len(self.image_projection_layers):
|
1457 |
+
raise ValueError(
|
1458 |
+
f"image_embeds must have the same length as image_projection_layers, got {len(image_embeds)} and {len(self.image_projection_layers)}"
|
1459 |
+
)
|
1460 |
+
|
1461 |
+
for image_embed, image_projection_layer in zip(image_embeds, self.image_projection_layers):
|
1462 |
+
batch_size, num_images = image_embed.shape[0], image_embed.shape[1]
|
1463 |
+
image_embed = image_embed.reshape((batch_size * num_images,) + image_embed.shape[2:])
|
1464 |
+
image_embed = image_projection_layer(image_embed)
|
1465 |
+
image_embed = image_embed.reshape((batch_size, num_images) + image_embed.shape[1:])
|
1466 |
+
|
1467 |
+
projected_image_embeds.append(image_embed)
|
1468 |
+
|
1469 |
+
return projected_image_embeds
|
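Closing out this file, a hedged example of the multi-adapter wrapper above with a single ImageProjection layer; the per-adapter list format follows the comment in forward, and all sizes are arbitrary:
import torch
from flux.embeddings import ImageProjection, MultiIPAdapterImageProjection
multi_proj = MultiIPAdapterImageProjection(
[ImageProjection(image_embed_dim=768, cross_attention_dim=768, num_image_text_embeds=4)]
)
image_embeds = [torch.randn(2, 1, 768)]   # one adapter, one image per sample
projected = multi_proj(image_embeds)      # list with one tensor of shape (2, 1, 4, 768)
print(projected[0].shape)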
flux/flux_network.py
ADDED
@@ -0,0 +1,183 @@
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from .transformer_flux import FluxTransformer2DModel
|
4 |
+
|
5 |
+
class FluxNetwork(nn.Module):
|
6 |
+
TARGET_REPLACE_MODULE = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]  # module types that can be made trainable
|
7 |
+
FLUX_PREFIX = "flux"
|
8 |
+
|
9 |
+
def __init__(self, flux_model: FluxTransformer2DModel):
|
10 |
+
super().__init__()
|
11 |
+
self.flux_model = flux_model
|
12 |
+
self.trainable_component_names = []  # records the names of the trainable components
|
13 |
+
|
14 |
+
@staticmethod
|
15 |
+
def generate_trainable_components(layers, num_transformer_blocks=19):
|
16 |
+
transformer_components = [
|
17 |
+
"attn.to_q",
|
18 |
+
"attn.to_k",
|
19 |
+
"attn.to_v",
|
20 |
+
"attn.to_out",
|
21 |
+
"norm1",
|
22 |
+
"norm1_context",
|
23 |
+
]
|
24 |
+
|
25 |
+
single_transformer_components = [
|
26 |
+
"attn.to_q",
|
27 |
+
"attn.to_k",
|
28 |
+
"attn.to_v",
|
29 |
+
"norm",
|
30 |
+
#"proj_mlp",
|
31 |
+
]
|
32 |
+
|
33 |
+
components = ["context_embedder"]  # always include the context_embedder
|
34 |
+
for layer in layers:
|
35 |
+
if layer < num_transformer_blocks:
|
36 |
+
prefix = f"transformer_blocks.{layer}"
|
37 |
+
base_components = transformer_components
|
38 |
+
else:
|
39 |
+
prefix = f"single_transformer_blocks.{layer - num_transformer_blocks}"
|
40 |
+
base_components = single_transformer_components
|
41 |
+
components.extend([f"{prefix}.{comp}" for comp in base_components])
|
42 |
+
|
43 |
+
return components
|
44 |
+
|
45 |
+
#def apply_to(self, num_layers=1, additional_components=None):
|
46 |
+
# component_names = self.generate_trainable_components(num_layers)
|
47 |
+
#
|
48 |
+
# if additional_components:
|
49 |
+
# component_names.extend(additional_components)
|
50 |
+
#
|
51 |
+
# self.trainable_component_names = [] # 重置
|
52 |
+
# for name in component_names:
|
53 |
+
# recursive_getattr(self.flux_model, name).requires_grad_(True)
|
54 |
+
# self.trainable_component_names.append(name) # 记录名称
|
55 |
+
|
56 |
+
#def apply_to(self, num_layers=1, additional_components=None):
|
57 |
+
# component_names = self.generate_trainable_components(num_layers)
|
58 |
+
#
|
59 |
+
# if additional_components:
|
60 |
+
# component_names.extend(additional_components)
|
61 |
+
#
|
62 |
+
# self.trainable_component_names = [] # 重置
|
63 |
+
# for name in component_names:
|
64 |
+
# component = recursive_getattr(self.flux_model, name)
|
65 |
+
# if isinstance(component, nn.Module):
|
66 |
+
# component.requires_grad_(True)
|
67 |
+
# self.trainable_component_names.append(name)
|
68 |
+
# else:
|
69 |
+
# print(f"Warning: {name} is not a Module, skipping.")
|
70 |
+
|
71 |
+
def apply_to(self, layers=None, additional_components=None):
|
72 |
+
if layers is None:
|
73 |
+
layers = list(range(57)) # 默认包含所有层
|
74 |
+
|
75 |
+
component_names = self.generate_trainable_components(layers)
|
76 |
+
|
77 |
+
if additional_components:
|
78 |
+
component_names.extend(additional_components)
|
79 |
+
|
80 |
+
self.trainable_component_names = [] # 重置
|
81 |
+
for name in component_names:
|
82 |
+
try:
|
83 |
+
component = recursive_getattr(self.flux_model, name)
|
84 |
+
if isinstance(component, nn.Module):
|
85 |
+
component.requires_grad_(True)
|
86 |
+
self.trainable_component_names.append(name)
|
87 |
+
else:
|
88 |
+
print(f"Warning: {name} is not a Module, skipping.")
|
89 |
+
except AttributeError:
|
90 |
+
print(f"Warning: {name} not found in the model, skipping.")
|
91 |
+
|
92 |
+
def prepare_grad_etc(self):
|
93 |
+
# 供flux_model调用,用于冻结/解冻组件
|
94 |
+
self.flux_model.requires_grad_(False)
|
95 |
+
for name in self.trainable_component_names:
|
96 |
+
recursive_getattr(self.flux_model, name).requires_grad_(True)
|
97 |
+
|
98 |
+
def get_trainable_params(self):
|
99 |
+
# 返回需要训练的参数
|
100 |
+
params = []
|
101 |
+
for name in self.trainable_component_names:
|
102 |
+
params.extend(recursive_getattr(self.flux_model, name).parameters())
|
103 |
+
return params
|
104 |
+
|
105 |
+
def print_trainable_params_info(self):
|
106 |
+
total_params = 0
|
107 |
+
for name in self.trainable_component_names:
|
108 |
+
module = recursive_getattr(self.flux_model, name)
|
109 |
+
module_params = sum(p.numel() for p in module.parameters() if p.requires_grad)
|
110 |
+
total_params += module_params
|
111 |
+
#print(f'{name}: {module_params} trainable parameters')
|
112 |
+
print(f'Total trainable params: {total_params}')
|
113 |
+
|
114 |
+
def save_weights(self, file, dtype=None):
|
115 |
+
# 保存需要训练的组件参数
|
116 |
+
state_dict = {}
|
117 |
+
for name in self.trainable_component_names:
|
118 |
+
state_dict[name] = recursive_getattr(self.flux_model, name).state_dict()
|
119 |
+
if dtype is not None:
|
120 |
+
for v in state_dict.values():
|
121 |
+
v = {k: t.detach().clone().to("cpu").to(dtype) for k, t in v.items()}
|
122 |
+
torch.save(state_dict, file)
|
123 |
+
|
124 |
+
#def load_weights(self, file):
|
125 |
+
# # 加载需要训练的组件参数
|
126 |
+
# state_dict = torch.load(file, weights_only=True)
|
127 |
+
# for name in state_dict:
|
128 |
+
# module = recursive_getattr(self.flux_model, name)
|
129 |
+
# module.load_state_dict(state_dict[name])
|
130 |
+
# print(f"加载参数: {name}")
|
131 |
+
|
132 |
+
def load_weights(self, file, device):
|
133 |
+
print(f"Loading weights from {file}")
|
134 |
+
try:
|
135 |
+
state_dict = torch.load(file, map_location=device, weights_only=True)
|
136 |
+
except Exception as e:
|
137 |
+
print(f"Failed to load weights from {file}: {str(e)}")
|
138 |
+
return False
|
139 |
+
|
140 |
+
successfully_loaded = []
|
141 |
+
failed_to_load = []
|
142 |
+
|
143 |
+
for name in state_dict:
|
144 |
+
try:
|
145 |
+
module = recursive_getattr(self.flux_model, name)
|
146 |
+
module_state_dict = module.state_dict()
|
147 |
+
|
148 |
+
# 检查state_dict的键是否匹配
|
149 |
+
if set(state_dict[name].keys()) != set(module_state_dict.keys()):
|
150 |
+
raise ValueError(f"State dict keys for {name} do not match")
|
151 |
+
|
152 |
+
# 检查张量的形状是否匹配
|
153 |
+
for key in state_dict[name]:
|
154 |
+
if state_dict[name][key].shape != module_state_dict[key].shape:
|
155 |
+
raise ValueError(f"Shape mismatch for {name}.{key}")
|
156 |
+
|
157 |
+
module.load_state_dict(state_dict[name])
|
158 |
+
successfully_loaded.append(name)
|
159 |
+
|
160 |
+
except Exception as e:
|
161 |
+
print(f"Failed to load weights for {name}: {str(e)}")
|
162 |
+
failed_to_load.append(name)
|
163 |
+
|
164 |
+
if successfully_loaded:
|
165 |
+
print(f"Successfully loaded weights for: {', '.join(successfully_loaded)}")
|
166 |
+
if failed_to_load:
|
167 |
+
print(f"Failed to load weights for: {', '.join(failed_to_load)}")
|
168 |
+
|
169 |
+
return len(failed_to_load) == 0 # 如果没有加载失败的组件,则返回True
|
170 |
+
|
171 |
+
# 改进的递归获取属性函数
|
172 |
+
def recursive_getattr(obj, attr):
|
173 |
+
attrs = attr.split(".")
|
174 |
+
for i in range(len(attrs)):
|
175 |
+
obj = getattr(obj, attrs[i])
|
176 |
+
return obj
|
177 |
+
|
178 |
+
# 递归设置属性函数
|
179 |
+
def recursive_setattr(obj, attr, val):
|
180 |
+
attrs = attr.split(".")
|
181 |
+
for i in range(len(attrs)-1):
|
182 |
+
obj = getattr(obj, attrs[i])
|
183 |
+
setattr(obj, attrs[-1], val)
|
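A minimal usage sketch of the FluxNetwork helper above. The checkpoint id, layer indices, learning rate, and file name are illustrative assumptions, not values used by this repository; the intent is only to show the apply_to / prepare_grad_etc / save_weights flow.

import torch
from flux.transformer_flux import FluxTransformer2DModel
from flux.flux_network import FluxNetwork

# Hypothetical checkpoint and layer selection.
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)
network = FluxNetwork(transformer)
network.apply_to(layers=[0, 1, 2])      # mark a few double-stream blocks (plus context_embedder) as trainable
network.prepare_grad_etc()              # freeze everything else in the transformer
network.print_trainable_params_info()
optimizer = torch.optim.AdamW(network.get_trainable_params(), lr=1e-5)
# ... training loop ...
network.save_weights("flux_partial_finetune.pt", dtype=torch.bfloat16)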
flux/lora/lora_base.py
ADDED
@@ -0,0 +1,752 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import inspect
import os
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union

import safetensors
import torch
import torch.nn as nn
from huggingface_hub import model_info
from huggingface_hub.constants import HF_HUB_OFFLINE

from diffusers.models.modeling_utils import ModelMixin, load_state_dict
from diffusers.utils import (
    USE_PEFT_BACKEND,
    _get_model_file,
    delete_adapter_layers,
    deprecate,
    is_accelerate_available,
    is_peft_available,
    is_transformers_available,
    logging,
    recurse_remove_peft_layers,
    set_adapter_layers,
    set_weights_and_activate_adapters,
)


if is_transformers_available():
    from transformers import PreTrainedModel

if is_peft_available():
    from peft.tuners.tuners_utils import BaseTunerLayer

if is_accelerate_available():
    from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module

logger = logging.get_logger(__name__)


def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
    """
    Fuses LoRAs for the text encoder.

    Args:
        text_encoder (`torch.nn.Module`):
            The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder`
            attribute.
        lora_scale (`float`, defaults to 1.0):
            Controls how much to influence the outputs with the LoRA parameters.
        safe_fusing (`bool`, defaults to `False`):
            Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
        adapter_names (`List[str]` or `str`):
            The names of the adapters to use.
    """
    merge_kwargs = {"safe_merge": safe_fusing}

    for module in text_encoder.modules():
        if isinstance(module, BaseTunerLayer):
            if lora_scale != 1.0:
                module.scale_layer(lora_scale)

            # For BC with previous PEFT versions, we need to check the signature
            # of the `merge` method to see if it supports the `adapter_names` argument.
            supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
            if "adapter_names" in supported_merge_kwargs:
                merge_kwargs["adapter_names"] = adapter_names
            elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
                raise ValueError(
                    "The `adapter_names` argument is not supported with your PEFT version. "
                    "Please upgrade to the latest version of PEFT. `pip install -U peft`"
                )

            module.merge(**merge_kwargs)


def unfuse_text_encoder_lora(text_encoder):
    """
    Unfuses LoRAs for the text encoder.

    Args:
        text_encoder (`torch.nn.Module`):
            The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder`
            attribute.
    """
    for module in text_encoder.modules():
        if isinstance(module, BaseTunerLayer):
            module.unmerge()


def set_adapters_for_text_encoder(
    adapter_names: Union[List[str], str],
    text_encoder: Optional["PreTrainedModel"] = None,  # noqa: F821
    text_encoder_weights: Optional[Union[float, List[float], List[None]]] = None,
):
    """
    Sets the adapter layers for the text encoder.

    Args:
        adapter_names (`List[str]` or `str`):
            The names of the adapters to use.
        text_encoder (`torch.nn.Module`, *optional*):
            The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder`
            attribute.
        text_encoder_weights (`List[float]`, *optional*):
            The weights to use for the text encoder. If `None`, the weights are set to `1.0` for all the adapters.
    """
    if text_encoder is None:
        raise ValueError(
            "The pipeline does not have a default `pipe.text_encoder` class. Please make sure to pass a `text_encoder` instead."
        )

    def process_weights(adapter_names, weights):
        # Expand weights into a list, one entry per adapter
        # e.g. for 2 adapters:  7 -> [7,7] ; [3, None] -> [3, None]
        if not isinstance(weights, list):
            weights = [weights] * len(adapter_names)

        if len(adapter_names) != len(weights):
            raise ValueError(
                f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}"
            )

        # Set None values to default of 1.0
        # e.g. [7,7] -> [7,7] ; [3, None] -> [3,1]
        weights = [w if w is not None else 1.0 for w in weights]

        return weights

    adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
    text_encoder_weights = process_weights(adapter_names, text_encoder_weights)
    set_weights_and_activate_adapters(text_encoder, adapter_names, text_encoder_weights)


def disable_lora_for_text_encoder(text_encoder: Optional["PreTrainedModel"] = None):
    """
    Disables the LoRA layers for the text encoder.

    Args:
        text_encoder (`torch.nn.Module`, *optional*):
            The text encoder module to disable the LoRA layers for. If `None`, it will try to get the `text_encoder`
            attribute.
    """
    if text_encoder is None:
        raise ValueError("Text Encoder not found.")
    set_adapter_layers(text_encoder, enabled=False)


def enable_lora_for_text_encoder(text_encoder: Optional["PreTrainedModel"] = None):
    """
    Enables the LoRA layers for the text encoder.

    Args:
        text_encoder (`torch.nn.Module`, *optional*):
            The text encoder module to enable the LoRA layers for. If `None`, it will try to get the `text_encoder`
            attribute.
    """
    if text_encoder is None:
        raise ValueError("Text Encoder not found.")
    set_adapter_layers(text_encoder, enabled=True)


def _remove_text_encoder_monkey_patch(text_encoder):
    recurse_remove_peft_layers(text_encoder)
    if getattr(text_encoder, "peft_config", None) is not None:
        del text_encoder.peft_config
        text_encoder._hf_peft_config_loaded = None


class LoraBaseMixin:
    """Utility class for handling LoRAs."""

    _lora_loadable_modules = []
    num_fused_loras = 0

    def load_lora_weights(self, **kwargs):
        raise NotImplementedError("`load_lora_weights()` is not implemented.")

    @classmethod
    def save_lora_weights(cls, **kwargs):
        raise NotImplementedError("`save_lora_weights()` not implemented.")

    @classmethod
    def lora_state_dict(cls, **kwargs):
        raise NotImplementedError("`lora_state_dict()` is not implemented.")

    @classmethod
    def _optionally_disable_offloading(cls, _pipeline):
        """
        Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.

        Args:
            _pipeline (`DiffusionPipeline`):
                The pipeline to disable offloading for.

        Returns:
            tuple:
                A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
        """
        is_model_cpu_offload = False
        is_sequential_cpu_offload = False

        if _pipeline is not None and _pipeline.hf_device_map is None:
            for _, component in _pipeline.components.items():
                if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
                    if not is_model_cpu_offload:
                        is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
                    if not is_sequential_cpu_offload:
                        is_sequential_cpu_offload = (
                            isinstance(component._hf_hook, AlignDevicesHook)
                            or hasattr(component._hf_hook, "hooks")
                            and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
                        )

                    logger.info(
                        "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
                    )
                    remove_hook_from_module(component, recurse=is_sequential_cpu_offload)

        return (is_model_cpu_offload, is_sequential_cpu_offload)

    @classmethod
    def _fetch_state_dict(
        cls,
        pretrained_model_name_or_path_or_dict,
        weight_name,
        use_safetensors,
        local_files_only,
        cache_dir,
        force_download,
        proxies,
        token,
        revision,
        subfolder,
        user_agent,
        allow_pickle,
    ):
        from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE

        model_file = None
        if not isinstance(pretrained_model_name_or_path_or_dict, dict):
            # Let's first try to load .safetensors weights
            if (use_safetensors and weight_name is None) or (
                weight_name is not None and weight_name.endswith(".safetensors")
            ):
                try:
                    # Here we're relaxing the loading check to enable more Inference API
                    # friendliness where sometimes, it's not at all possible to automatically
                    # determine `weight_name`.
                    if weight_name is None:
                        weight_name = cls._best_guess_weight_name(
                            pretrained_model_name_or_path_or_dict,
                            file_extension=".safetensors",
                            local_files_only=local_files_only,
                        )
                    model_file = _get_model_file(
                        pretrained_model_name_or_path_or_dict,
                        weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
                        cache_dir=cache_dir,
                        force_download=force_download,
                        proxies=proxies,
                        local_files_only=local_files_only,
                        token=token,
                        revision=revision,
                        subfolder=subfolder,
                        user_agent=user_agent,
                    )
                    state_dict = safetensors.torch.load_file(model_file, device="cpu")
                except (IOError, safetensors.SafetensorError) as e:
                    if not allow_pickle:
                        raise e
                    # try loading non-safetensors weights
                    model_file = None
                    pass

            if model_file is None:
                if weight_name is None:
                    weight_name = cls._best_guess_weight_name(
                        pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only
                    )
                model_file = _get_model_file(
                    pretrained_model_name_or_path_or_dict,
                    weights_name=weight_name or LORA_WEIGHT_NAME,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    local_files_only=local_files_only,
                    token=token,
                    revision=revision,
                    subfolder=subfolder,
                    user_agent=user_agent,
                )
                state_dict = load_state_dict(model_file)
        else:
            state_dict = pretrained_model_name_or_path_or_dict

        return state_dict

    @classmethod
    def _best_guess_weight_name(
        cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False
    ):
        from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE

        if local_files_only or HF_HUB_OFFLINE:
            raise ValueError("When using the offline mode, you must specify a `weight_name`.")

        targeted_files = []

        if os.path.isfile(pretrained_model_name_or_path_or_dict):
            return
        elif os.path.isdir(pretrained_model_name_or_path_or_dict):
            targeted_files = [
                f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension)
            ]
        else:
            files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings
            targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)]
        if len(targeted_files) == 0:
            return

        # "scheduler" does not correspond to a LoRA checkpoint.
        # "optimizer" does not correspond to a LoRA checkpoint
        # only top-level checkpoints are considered and not the other ones, hence "checkpoint".
        unallowed_substrings = {"scheduler", "optimizer", "checkpoint"}
        targeted_files = list(
            filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files)
        )

        if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files):
            targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME), targeted_files))
        elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files):
            targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files))

        if len(targeted_files) > 1:
            raise ValueError(
                f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one  `.safetensors` or `.bin` file in  {pretrained_model_name_or_path_or_dict}."
            )
        weight_name = targeted_files[0]
        return weight_name

    def unload_lora_weights(self):
        """
        Unloads the LoRA parameters.

        Examples:

        ```python
        >>> # Assuming `pipeline` is already loaded with the LoRA parameters.
        >>> pipeline.unload_lora_weights()
        >>> ...
        ```
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        for component in self._lora_loadable_modules:
            model = getattr(self, component, None)
            if model is not None:
                if issubclass(model.__class__, ModelMixin):
                    model.unload_lora()
                elif issubclass(model.__class__, PreTrainedModel):
                    _remove_text_encoder_monkey_patch(model)

    def fuse_lora(
        self,
        components: List[str] = [],
        lora_scale: float = 1.0,
        safe_fusing: bool = False,
        adapter_names: Optional[List[str]] = None,
        **kwargs,
    ):
        r"""
        Fuses the LoRA parameters into the original parameters of the corresponding blocks.

        <Tip warning={true}>

        This is an experimental API.

        </Tip>

        Args:
            components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
            lora_scale (`float`, defaults to 1.0):
                Controls how much to influence the outputs with the LoRA parameters.
            safe_fusing (`bool`, defaults to `False`):
                Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
            adapter_names (`List[str]`, *optional*):
                Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.

        Example:

        ```py
        from diffusers import DiffusionPipeline
        import torch

        pipeline = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
        ).to("cuda")
        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
        pipeline.fuse_lora(lora_scale=0.7)
        ```
        """
        if "fuse_unet" in kwargs:
            depr_message = "Passing `fuse_unet` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_unet` will be removed in a future version."
            deprecate(
                "fuse_unet",
                "1.0.0",
                depr_message,
            )
        if "fuse_transformer" in kwargs:
            depr_message = "Passing `fuse_transformer` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_transformer` will be removed in a future version."
            deprecate(
                "fuse_transformer",
                "1.0.0",
                depr_message,
            )
        if "fuse_text_encoder" in kwargs:
            depr_message = "Passing `fuse_text_encoder` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_text_encoder` will be removed in a future version."
            deprecate(
                "fuse_text_encoder",
                "1.0.0",
                depr_message,
            )

        if len(components) == 0:
            raise ValueError("`components` cannot be an empty list.")

        for fuse_component in components:
            if fuse_component not in self._lora_loadable_modules:
                raise ValueError(f"{fuse_component} is not found in {self._lora_loadable_modules=}.")

            model = getattr(self, fuse_component, None)
            if model is not None:
                # check if diffusers model
                if issubclass(model.__class__, ModelMixin):
                    model.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)
                # handle transformers models.
                if issubclass(model.__class__, PreTrainedModel):
                    fuse_text_encoder_lora(
                        model, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
                    )

        self.num_fused_loras += 1

    def unfuse_lora(self, components: List[str] = [], **kwargs):
        r"""
        Reverses the effect of
        [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).

        <Tip warning={true}>

        This is an experimental API.

        </Tip>

        Args:
            components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
            unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
            unfuse_text_encoder (`bool`, defaults to `True`):
                Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
                LoRA parameters then it won't have any effect.
        """
        if "unfuse_unet" in kwargs:
            depr_message = "Passing `unfuse_unet` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_unet` will be removed in a future version."
            deprecate(
                "unfuse_unet",
                "1.0.0",
                depr_message,
            )
        if "unfuse_transformer" in kwargs:
            depr_message = "Passing `unfuse_transformer` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_transformer` will be removed in a future version."
            deprecate(
                "unfuse_transformer",
                "1.0.0",
                depr_message,
            )
        if "unfuse_text_encoder" in kwargs:
            depr_message = "Passing `unfuse_text_encoder` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_text_encoder` will be removed in a future version."
            deprecate(
                "unfuse_text_encoder",
                "1.0.0",
                depr_message,
            )

        if len(components) == 0:
            raise ValueError("`components` cannot be an empty list.")

        for fuse_component in components:
            if fuse_component not in self._lora_loadable_modules:
                raise ValueError(f"{fuse_component} is not found in {self._lora_loadable_modules=}.")

            model = getattr(self, fuse_component, None)
            if model is not None:
                if issubclass(model.__class__, (ModelMixin, PreTrainedModel)):
                    for module in model.modules():
                        if isinstance(module, BaseTunerLayer):
                            module.unmerge()

        self.num_fused_loras -= 1

    def set_adapters(
        self,
        adapter_names: Union[List[str], str],
        adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None,
    ):
        adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names

        adapter_weights = copy.deepcopy(adapter_weights)

        # Expand weights into a list, one entry per adapter
        if not isinstance(adapter_weights, list):
            adapter_weights = [adapter_weights] * len(adapter_names)

        if len(adapter_names) != len(adapter_weights):
            raise ValueError(
                f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(adapter_weights)}"
            )

        list_adapters = self.get_list_adapters()  # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]}
        all_adapters = {
            adapter for adapters in list_adapters.values() for adapter in adapters
        }  # eg ["adapter1", "adapter2"]
        invert_list_adapters = {
            adapter: [part for part, adapters in list_adapters.items() if adapter in adapters]
            for adapter in all_adapters
        }  # eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]}

        # Decompose weights into weights for denoiser and text encoders.
        _component_adapter_weights = {}
        for component in self._lora_loadable_modules:
            model = getattr(self, component)

            for adapter_name, weights in zip(adapter_names, adapter_weights):
                if isinstance(weights, dict):
                    component_adapter_weights = weights.pop(component, None)

                    if component_adapter_weights is not None and not hasattr(self, component):
                        logger.warning(
                            f"Lora weight dict contains {component} weights but will be ignored because pipeline does not have {component}."
                        )

                    if component_adapter_weights is not None and component not in invert_list_adapters[adapter_name]:
                        logger.warning(
                            (
                                f"Lora weight dict for adapter '{adapter_name}' contains {component},"
                                f"but this will be ignored because {adapter_name} does not contain weights for {component}."
                                f"Valid parts for {adapter_name} are: {invert_list_adapters[adapter_name]}."
                            )
                        )

                else:
                    component_adapter_weights = weights

                _component_adapter_weights.setdefault(component, [])
                _component_adapter_weights[component].append(component_adapter_weights)

            if issubclass(model.__class__, ModelMixin):
                model.set_adapters(adapter_names, _component_adapter_weights[component])
            elif issubclass(model.__class__, PreTrainedModel):
                set_adapters_for_text_encoder(adapter_names, model, _component_adapter_weights[component])

    def disable_lora(self):
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        for component in self._lora_loadable_modules:
            model = getattr(self, component, None)
            if model is not None:
                if issubclass(model.__class__, ModelMixin):
                    model.disable_lora()
                elif issubclass(model.__class__, PreTrainedModel):
                    disable_lora_for_text_encoder(model)

    def enable_lora(self):
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        for component in self._lora_loadable_modules:
            model = getattr(self, component, None)
            if model is not None:
                if issubclass(model.__class__, ModelMixin):
                    model.enable_lora()
                elif issubclass(model.__class__, PreTrainedModel):
                    enable_lora_for_text_encoder(model)

    def delete_adapters(self, adapter_names: Union[List[str], str]):
        """
        Deletes the LoRA layers of `adapter_name` for the unet and text-encoder(s).

        Args:
            adapter_names (`Union[List[str], str]`):
                The names of the adapters to delete. Can be a single string or a list of strings.
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        if isinstance(adapter_names, str):
            adapter_names = [adapter_names]

        for component in self._lora_loadable_modules:
            model = getattr(self, component, None)
            if model is not None:
                if issubclass(model.__class__, ModelMixin):
                    model.delete_adapters(adapter_names)
                elif issubclass(model.__class__, PreTrainedModel):
                    for adapter_name in adapter_names:
                        delete_adapter_layers(model, adapter_name)

    def get_active_adapters(self) -> List[str]:
        """
        Gets the list of the current active adapters.

        Example:

        ```python
        from diffusers import DiffusionPipeline

        pipeline = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
        ).to("cuda")
        pipeline.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
        pipeline.get_active_adapters()
        ```
        """
        if not USE_PEFT_BACKEND:
            raise ValueError(
                "PEFT backend is required for this method. Please install the latest version of PEFT `pip install -U peft`"
            )

        active_adapters = []

        for component in self._lora_loadable_modules:
            model = getattr(self, component, None)
            if model is not None and issubclass(model.__class__, ModelMixin):
                for module in model.modules():
                    if isinstance(module, BaseTunerLayer):
                        active_adapters = module.active_adapters
                        break

        return active_adapters

    def get_list_adapters(self) -> Dict[str, List[str]]:
        """
        Gets the current list of all available adapters in the pipeline.
        """
        if not USE_PEFT_BACKEND:
            raise ValueError(
                "PEFT backend is required for this method. Please install the latest version of PEFT `pip install -U peft`"
            )

        set_adapters = {}

        for component in self._lora_loadable_modules:
            model = getattr(self, component, None)
            if (
                model is not None
                and issubclass(model.__class__, (ModelMixin, PreTrainedModel))
                and hasattr(model, "peft_config")
            ):
                set_adapters[component] = list(model.peft_config.keys())

        return set_adapters

    def set_lora_device(self, adapter_names: List[str], device: Union[torch.device, str, int]) -> None:
        """
        Moves the LoRAs listed in `adapter_names` to a target device. Useful for offloading the LoRA to the CPU in case
        you want to load multiple adapters and free some GPU memory.

        Args:
            adapter_names (`List[str]`):
                List of adapters to send device to.
            device (`Union[torch.device, str, int]`):
                Device to send the adapters to. Can be either a torch device, a str or an integer.
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        for component in self._lora_loadable_modules:
            model = getattr(self, component, None)
            if model is not None:
                for module in model.modules():
                    if isinstance(module, BaseTunerLayer):
                        for adapter_name in adapter_names:
                            module.lora_A[adapter_name].to(device)
                            module.lora_B[adapter_name].to(device)
                            # this is a param, not a module, so device placement is not in-place -> re-assign
                            if hasattr(module, "lora_magnitude_vector") and module.lora_magnitude_vector is not None:
                                module.lora_magnitude_vector[adapter_name] = module.lora_magnitude_vector[
                                    adapter_name
                                ].to(device)

    @staticmethod
    def pack_weights(layers, prefix):
        layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
        layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
        return layers_state_dict

    @staticmethod
    def write_lora_layers(
        state_dict: Dict[str, torch.Tensor],
        save_directory: str,
        is_main_process: bool,
        weight_name: str,
        save_function: Callable,
        safe_serialization: bool,
    ):
        from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE

        if os.path.isfile(save_directory):
            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
            return

        if save_function is None:
            if safe_serialization:

                def save_function(weights, filename):
                    return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})

            else:
                save_function = torch.save

        os.makedirs(save_directory, exist_ok=True)

        if weight_name is None:
            if safe_serialization:
                weight_name = LORA_WEIGHT_NAME_SAFE
            else:
                weight_name = LORA_WEIGHT_NAME

        save_path = Path(save_directory, weight_name).as_posix()
        save_function(state_dict, save_path)
        logger.info(f"Model weights saved in {save_path}")

    @property
    def lora_scale(self) -> float:
        # property function that returns the lora scale which can be set at run time by the pipeline.
        # if _lora_scale has not been set, return 1
        return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
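For context, a short sketch of how a pipeline built on LoraBaseMixin is typically driven. The LoRA repository ids and adapter names below are placeholders, not real checkpoints, and the component name "transformer" assumes a Flux-style pipeline.

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
pipe.load_lora_weights("user/style-lora", adapter_name="style")     # placeholder repo id
pipe.load_lora_weights("user/detail-lora", adapter_name="detail")   # placeholder repo id
pipe.set_adapters(["style", "detail"], adapter_weights=[0.8, 0.5])  # per-adapter scaling
print(pipe.get_list_adapters())                                     # which components hold which adapters
pipe.fuse_lora(components=["transformer"], lora_scale=1.0)          # bake the LoRA deltas into the base weights
pipe.unfuse_lora(components=["transformer"])                        # and restore the original weights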
flux/lora/lora_conversion_utils.py
ADDED
@@ -0,0 +1,328 @@
1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import re
|
16 |
+
|
17 |
+
from diffusers.utils import is_peft_version, logging
|
18 |
+
|
19 |
+
|
20 |
+
logger = logging.get_logger(__name__)
|
21 |
+
|
22 |
+
|
23 |
+
def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", block_slice_pos=5):
|
24 |
+
# 1. get all state_dict_keys
|
25 |
+
all_keys = list(state_dict.keys())
|
26 |
+
sgm_patterns = ["input_blocks", "middle_block", "output_blocks"]
|
27 |
+
|
28 |
+
# 2. check if needs remapping, if not return original dict
|
29 |
+
is_in_sgm_format = False
|
30 |
+
for key in all_keys:
|
31 |
+
if any(p in key for p in sgm_patterns):
|
32 |
+
is_in_sgm_format = True
|
33 |
+
break
|
34 |
+
|
35 |
+
if not is_in_sgm_format:
|
36 |
+
return state_dict
|
37 |
+
|
38 |
+
# 3. Else remap from SGM patterns
|
39 |
+
new_state_dict = {}
|
40 |
+
inner_block_map = ["resnets", "attentions", "upsamplers"]
|
41 |
+
|
42 |
+
# Retrieves # of down, mid and up blocks
|
43 |
+
input_block_ids, middle_block_ids, output_block_ids = set(), set(), set()
|
44 |
+
|
45 |
+
for layer in all_keys:
|
46 |
+
if "text" in layer:
|
47 |
+
new_state_dict[layer] = state_dict.pop(layer)
|
48 |
+
else:
|
49 |
+
layer_id = int(layer.split(delimiter)[:block_slice_pos][-1])
|
50 |
+
if sgm_patterns[0] in layer:
|
51 |
+
input_block_ids.add(layer_id)
|
52 |
+
elif sgm_patterns[1] in layer:
|
53 |
+
middle_block_ids.add(layer_id)
|
54 |
+
elif sgm_patterns[2] in layer:
|
55 |
+
output_block_ids.add(layer_id)
|
56 |
+
else:
|
57 |
+
raise ValueError(f"Checkpoint not supported because layer {layer} not supported.")
|
58 |
+
|
59 |
+
input_blocks = {
|
60 |
+
layer_id: [key for key in state_dict if f"input_blocks{delimiter}{layer_id}" in key]
|
61 |
+
for layer_id in input_block_ids
|
62 |
+
}
|
63 |
+
middle_blocks = {
|
64 |
+
layer_id: [key for key in state_dict if f"middle_block{delimiter}{layer_id}" in key]
|
65 |
+
for layer_id in middle_block_ids
|
66 |
+
}
|
67 |
+
output_blocks = {
|
68 |
+
layer_id: [key for key in state_dict if f"output_blocks{delimiter}{layer_id}" in key]
|
69 |
+
for layer_id in output_block_ids
|
70 |
+
}
|
71 |
+
|
72 |
+
# Rename keys accordingly
|
73 |
+
for i in input_block_ids:
|
74 |
+
block_id = (i - 1) // (unet_config.layers_per_block + 1)
|
75 |
+
layer_in_block_id = (i - 1) % (unet_config.layers_per_block + 1)
|
76 |
+
|
77 |
+
for key in input_blocks[i]:
|
78 |
+
inner_block_id = int(key.split(delimiter)[block_slice_pos])
|
79 |
+
inner_block_key = inner_block_map[inner_block_id] if "op" not in key else "downsamplers"
|
80 |
+
inner_layers_in_block = str(layer_in_block_id) if "op" not in key else "0"
|
81 |
+
new_key = delimiter.join(
|
82 |
+
key.split(delimiter)[: block_slice_pos - 1]
|
83 |
+
+ [str(block_id), inner_block_key, inner_layers_in_block]
|
84 |
+
+ key.split(delimiter)[block_slice_pos + 1 :]
|
85 |
+
)
|
86 |
+
new_state_dict[new_key] = state_dict.pop(key)
|
87 |
+
|
88 |
+
for i in middle_block_ids:
|
89 |
+
key_part = None
|
90 |
+
if i == 0:
|
91 |
+
key_part = [inner_block_map[0], "0"]
|
92 |
+
elif i == 1:
|
93 |
+
key_part = [inner_block_map[1], "0"]
|
94 |
+
elif i == 2:
|
95 |
+
key_part = [inner_block_map[0], "1"]
|
96 |
+
else:
|
97 |
+
raise ValueError(f"Invalid middle block id {i}.")
|
98 |
+
|
99 |
+
for key in middle_blocks[i]:
|
100 |
+
new_key = delimiter.join(
|
101 |
+
key.split(delimiter)[: block_slice_pos - 1] + key_part + key.split(delimiter)[block_slice_pos:]
|
102 |
+
)
|
103 |
+
new_state_dict[new_key] = state_dict.pop(key)
|
104 |
+
|
105 |
+
for i in output_block_ids:
|
106 |
+
block_id = i // (unet_config.layers_per_block + 1)
|
107 |
+
layer_in_block_id = i % (unet_config.layers_per_block + 1)
|
108 |
+
|
109 |
+
for key in output_blocks[i]:
|
110 |
+
inner_block_id = int(key.split(delimiter)[block_slice_pos])
|
111 |
+
inner_block_key = inner_block_map[inner_block_id]
|
112 |
+
inner_layers_in_block = str(layer_in_block_id) if inner_block_id < 2 else "0"
|
113 |
+
new_key = delimiter.join(
|
114 |
+
key.split(delimiter)[: block_slice_pos - 1]
|
115 |
+
+ [str(block_id), inner_block_key, inner_layers_in_block]
|
116 |
+
+ key.split(delimiter)[block_slice_pos + 1 :]
|
117 |
+
)
|
118 |
+
new_state_dict[new_key] = state_dict.pop(key)
|
119 |
+
|
120 |
+
if len(state_dict) > 0:
|
121 |
+
raise ValueError("At this point all state dict entries have to be converted.")
|
122 |
+
|
123 |
+
return new_state_dict
|
124 |
+
|
125 |
+
|
126 |
+
def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_name="text_encoder"):
|
127 |
+
"""
|
128 |
+
Converts a non-Diffusers LoRA state dict to a Diffusers compatible state dict.
|
129 |
+
|
130 |
+
Args:
|
131 |
+
state_dict (`dict`): The state dict to convert.
|
132 |
+
unet_name (`str`, optional): The name of the U-Net module in the Diffusers model. Defaults to "unet".
|
133 |
+
text_encoder_name (`str`, optional): The name of the text encoder module in the Diffusers model. Defaults to
|
134 |
+
"text_encoder".
|
135 |
+
|
136 |
+
Returns:
|
137 |
+
`tuple`: A tuple containing the converted state dict and a dictionary of alphas.
|
138 |
+
"""
|
139 |
+
unet_state_dict = {}
|
140 |
+
te_state_dict = {}
|
141 |
+
te2_state_dict = {}
|
142 |
+
network_alphas = {}
|
143 |
+
|
144 |
+
# Check for DoRA-enabled LoRAs.
|
145 |
+
dora_present_in_unet = any("dora_scale" in k and "lora_unet_" in k for k in state_dict)
|
146 |
+
dora_present_in_te = any("dora_scale" in k and ("lora_te_" in k or "lora_te1_" in k) for k in state_dict)
|
147 |
+
dora_present_in_te2 = any("dora_scale" in k and "lora_te2_" in k for k in state_dict)
|
148 |
+
if dora_present_in_unet or dora_present_in_te or dora_present_in_te2:
|
149 |
+
if is_peft_version("<", "0.9.0"):
|
150 |
+
raise ValueError(
|
151 |
+
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
|
152 |
+
)
|
153 |
+
|
154 |
+
# Iterate over all LoRA weights.
|
155 |
+
all_lora_keys = list(state_dict.keys())
|
156 |
+
for key in all_lora_keys:
|
157 |
+
if not key.endswith("lora_down.weight"):
|
158 |
+
continue
|
159 |
+
|
160 |
+
# Extract LoRA name.
|
161 |
+
lora_name = key.split(".")[0]
|
162 |
+
|
163 |
+
# Find corresponding up weight and alpha.
|
164 |
+
lora_name_up = lora_name + ".lora_up.weight"
|
165 |
+
lora_name_alpha = lora_name + ".alpha"
|
166 |
+
|
167 |
+
# Handle U-Net LoRAs.
|
168 |
+
if lora_name.startswith("lora_unet_"):
|
169 |
+
diffusers_name = _convert_unet_lora_key(key)
|
170 |
+
|
171 |
+
# Store down and up weights.
|
172 |
+
unet_state_dict[diffusers_name] = state_dict.pop(key)
|
173 |
+
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
|
174 |
+
|
175 |
+
# Store DoRA scale if present.
|
176 |
+
if dora_present_in_unet:
|
177 |
+
dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
|
178 |
+
unet_state_dict[
|
179 |
+
diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")
|
180 |
+
] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
|
181 |
+
|
182 |
+
# Handle text encoder LoRAs.
|
183 |
+
elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
|
184 |
+
diffusers_name = _convert_text_encoder_lora_key(key, lora_name)
|
185 |
+
|
186 |
+
# Store down and up weights for te or te2.
|
187 |
+
if lora_name.startswith(("lora_te_", "lora_te1_")):
|
188 |
+
te_state_dict[diffusers_name] = state_dict.pop(key)
|
189 |
+
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
|
190 |
+
else:
|
191 |
+
te2_state_dict[diffusers_name] = state_dict.pop(key)
|
192 |
+
te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
|
193 |
+
|
194 |
+
# Store DoRA scale if present.
|
195 |
+
if dora_present_in_te or dora_present_in_te2:
|
196 |
+
dora_scale_key_to_replace_te = (
|
197 |
+
"_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer."
|
198 |
+
)
|
199 |
+
if lora_name.startswith(("lora_te_", "lora_te1_")):
|
200 |
+
te_state_dict[
|
201 |
+
diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
|
202 |
+
] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
|
203 |
+
elif lora_name.startswith("lora_te2_"):
|
204 |
+
te2_state_dict[
|
205 |
+
diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
|
206 |
+
] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
|
207 |
+
|
208 |
+
# Store alpha if present.
|
209 |
+
if lora_name_alpha in state_dict:
|
210 |
+
alpha = state_dict.pop(lora_name_alpha).item()
|
211 |
+
network_alphas.update(_get_alpha_name(lora_name_alpha, diffusers_name, alpha))
|
212 |
+
|
213 |
+
# Check if any keys remain.
|
214 |
+
if len(state_dict) > 0:
|
215 |
+
raise ValueError(f"The following keys have not been correctly renamed: \n\n {', '.join(state_dict.keys())}")
|
216 |
+
|
217 |
+
logger.info("Non-diffusers checkpoint detected.")
|
218 |
+
|
219 |
+
# Construct final state dict.
|
220 |
+
unet_state_dict = {f"{unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()}
|
221 |
+
te_state_dict = {f"{text_encoder_name}.{module_name}": params for module_name, params in te_state_dict.items()}
|
222 |
+
te2_state_dict = (
|
223 |
+
{f"text_encoder_2.{module_name}": params for module_name, params in te2_state_dict.items()}
|
224 |
+
if len(te2_state_dict) > 0
|
225 |
+
else None
|
226 |
+
)
|
227 |
+
if te2_state_dict is not None:
|
228 |
+
te_state_dict.update(te2_state_dict)
|
229 |
+
|
230 |
+
new_state_dict = {**unet_state_dict, **te_state_dict}
|
231 |
+
return new_state_dict, network_alphas
|
232 |
+
|
233 |
+
|
234 |
+
def _convert_unet_lora_key(key):
|
235 |
+
"""
|
236 |
+
Converts a U-Net LoRA key to a Diffusers compatible key.
|
237 |
+
"""
|
238 |
+
diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
|
239 |
+
|
240 |
+
# Replace common U-Net naming patterns.
|
241 |
+
diffusers_name = diffusers_name.replace("input.blocks", "down_blocks")
|
242 |
+
diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
|
243 |
+
diffusers_name = diffusers_name.replace("middle.block", "mid_block")
|
244 |
+
diffusers_name = diffusers_name.replace("mid.block", "mid_block")
|
245 |
+
    diffusers_name = diffusers_name.replace("output.blocks", "up_blocks")
    diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
    diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
    diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
    diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
    diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
    diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
    diffusers_name = diffusers_name.replace("proj.in", "proj_in")
    diffusers_name = diffusers_name.replace("proj.out", "proj_out")
    diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj")

    # SDXL specific conversions.
    if "emb" in diffusers_name and "time.emb.proj" not in diffusers_name:
        pattern = r"\.\d+(?=\D*$)"
        diffusers_name = re.sub(pattern, "", diffusers_name, count=1)
    if ".in." in diffusers_name:
        diffusers_name = diffusers_name.replace("in.layers.2", "conv1")
    if ".out." in diffusers_name:
        diffusers_name = diffusers_name.replace("out.layers.3", "conv2")
    if "downsamplers" in diffusers_name or "upsamplers" in diffusers_name:
        diffusers_name = diffusers_name.replace("op", "conv")
    if "skip" in diffusers_name:
        diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut")

    # LyCORIS specific conversions.
    if "time.emb.proj" in diffusers_name:
        diffusers_name = diffusers_name.replace("time.emb.proj", "time_emb_proj")
    if "conv.shortcut" in diffusers_name:
        diffusers_name = diffusers_name.replace("conv.shortcut", "conv_shortcut")

    # General conversions.
    if "transformer_blocks" in diffusers_name:
        if "attn1" in diffusers_name or "attn2" in diffusers_name:
            diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
            diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
        elif "ff" in diffusers_name:
            pass
        elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
            pass
        else:
            pass

    return diffusers_name


def _convert_text_encoder_lora_key(key, lora_name):
    """
    Converts a text encoder LoRA key to a Diffusers compatible key.
    """
    if lora_name.startswith(("lora_te_", "lora_te1_")):
        key_to_replace = "lora_te_" if lora_name.startswith("lora_te_") else "lora_te1_"
    else:
        key_to_replace = "lora_te2_"

    diffusers_name = key.replace(key_to_replace, "").replace("_", ".")
    diffusers_name = diffusers_name.replace("text.model", "text_model")
    diffusers_name = diffusers_name.replace("self.attn", "self_attn")
    diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
    diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
    diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
    diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
    diffusers_name = diffusers_name.replace("text.projection", "text_projection")

    if "self_attn" in diffusers_name or "text_projection" in diffusers_name:
        pass
    elif "mlp" in diffusers_name:
        # Be aware that this is the new diffusers convention and the rest of the code might
        # not utilize it yet.
        diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
    return diffusers_name


def _get_alpha_name(lora_name_alpha, diffusers_name, alpha):
    """
    Gets the correct alpha name for the Diffusers model.
    """
    if lora_name_alpha.startswith("lora_unet_"):
        prefix = "unet."
    elif lora_name_alpha.startswith(("lora_te_", "lora_te1_")):
        prefix = "text_encoder."
    else:
        prefix = "text_encoder_2."
    new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha"
    return {new_name: alpha}
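For reference, a minimal sketch (not part of the uploaded file; it assumes the helpers above are importable from this module) of what the text-encoder key conversion produces for a typical Kohya-style LoRA key:

```py
# Hypothetical input key, shown only to illustrate the mapping performed above.
key = "lora_te1_text_model_encoder_layers_0_self_attn_q_proj.lora_down.weight"
lora_name = key.split(".")[0]
print(_convert_text_encoder_lora_key(key, lora_name))
# -> text_model.encoder.layers.0.self_attn.to_q_lora.down.weight
```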
flux/lora/lora_pipeline.py
ADDED
The diff for this file is too large to render.
flux/lora/peft.py
ADDED
@@ -0,0 +1,395 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from functools import partial
from typing import Dict, List, Optional, Union

from diffusers.utils import (
    MIN_PEFT_VERSION,
    USE_PEFT_BACKEND,
    check_peft_version,
    delete_adapter_layers,
    is_peft_available,
    set_adapter_layers,
    set_weights_and_activate_adapters,
)
# from .unet_loader_utils import _maybe_expand_lora_scales


_SET_ADAPTER_SCALE_FN_MAPPING = {
    # "UNet2DConditionModel": _maybe_expand_lora_scales,
    # "UNetMotionModel": _maybe_expand_lora_scales,
    "SD3Transformer2DModel": lambda model_cls, weights: weights,
    "FluxTransformer2DModel": lambda model_cls, weights: weights,
}


class PeftAdapterMixin:
    """
    A class containing all functions for loading and using adapter weights that are supported in the PEFT library. For
    more details about adapters and injecting them in a base model, check out the PEFT
    [documentation](https://huggingface.co/docs/peft/index).

    Install the latest version of PEFT, and use this mixin to:

    - Attach new adapters in the model.
    - Attach multiple adapters and iteratively activate/deactivate them.
    - Activate/deactivate all adapters from the model.
    - Get a list of the active adapters.
    """

    _hf_peft_config_loaded = False

    def set_adapters(
        self,
        adapter_names: Union[List[str], str],
        weights: Optional[Union[float, Dict, List[float], List[Dict], List[None]]] = None,
    ):
        """
        Set the currently active adapters for use in the UNet.

        Args:
            adapter_names (`List[str]` or `str`):
                The names of the adapters to use.
            adapter_weights (`Union[List[float], float]`, *optional*):
                The adapter(s) weights to use with the UNet. If `None`, the weights are set to `1.0` for all the
                adapters.

        Example:

        ```py
        from diffusers import AutoPipelineForText2Image
        import torch

        pipeline = AutoPipelineForText2Image.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
        ).to("cuda")
        pipeline.load_lora_weights(
            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
        )
        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
        pipeline.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5])
        ```
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for `set_adapters()`.")

        adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names

        # Expand weights into a list, one entry per adapter
        # examples for e.g. 2 adapters: [{...}, 7] -> [7,7] ; None -> [None, None]
        if not isinstance(weights, list):
            weights = [weights] * len(adapter_names)

        if len(adapter_names) != len(weights):
            raise ValueError(
                f"Length of adapter names {len(adapter_names)} is not equal to the length of their weights {len(weights)}."
            )

        # Set None values to default of 1.0
        # e.g. [{...}, 7] -> [{...}, 7] ; [None, None] -> [1.0, 1.0]
        weights = [w if w is not None else 1.0 for w in weights]

        # e.g. [{...}, 7] -> [{expanded dict...}, 7]
        scale_expansion_fn = _SET_ADAPTER_SCALE_FN_MAPPING[self.__class__.__name__]
        weights = scale_expansion_fn(self, weights)

        set_weights_and_activate_adapters(self, adapter_names, weights)

    def add_adapter(self, adapter_config, adapter_name: str = "default") -> None:
        r"""
        Adds a new adapter to the current model for training. If no adapter name is passed, a default name is assigned
        to the adapter to follow the convention of the PEFT library.

        If you are not familiar with adapters and PEFT methods, we invite you to read more about them in the PEFT
        [documentation](https://huggingface.co/docs/peft).

        Args:
            adapter_config (`[~peft.PeftConfig]`):
                The configuration of the adapter to add; supported adapters are non-prefix tuning and adaption prompt
                methods.
            adapter_name (`str`, *optional*, defaults to `"default"`):
                The name of the adapter to add. If no name is passed, a default name is assigned to the adapter.
        """
        check_peft_version(min_version=MIN_PEFT_VERSION)

        if not is_peft_available():
            raise ImportError("PEFT is not available. Please install PEFT to use this function: `pip install peft`.")

        from peft import PeftConfig, inject_adapter_in_model

        if not self._hf_peft_config_loaded:
            self._hf_peft_config_loaded = True
        elif adapter_name in self.peft_config:
            raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.")

        if not isinstance(adapter_config, PeftConfig):
            raise ValueError(
                f"adapter_config should be an instance of PeftConfig. Got {type(adapter_config)} instead."
            )

        # Unlike transformers, here we don't need to retrieve the name_or_path of the unet as the loading logic is
        # handled by the `load_lora_layers` or `StableDiffusionLoraLoaderMixin`. Therefore we set it to `None` here.
        adapter_config.base_model_name_or_path = None
        inject_adapter_in_model(adapter_config, self, adapter_name)
        self.set_adapter(adapter_name)

    def set_adapter(self, adapter_name: Union[str, List[str]]) -> None:
        """
        Sets a specific adapter by forcing the model to only use that adapter and disables the other adapters.

        If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
        [documentation](https://huggingface.co/docs/peft).

        Args:
            adapter_name (`Union[str, List[str]]`):
                The list of adapters to set or the adapter name in the case of a single adapter.
        """
        check_peft_version(min_version=MIN_PEFT_VERSION)

        if not self._hf_peft_config_loaded:
            raise ValueError("No adapter loaded. Please load an adapter first.")

        if isinstance(adapter_name, str):
            adapter_name = [adapter_name]

        missing = set(adapter_name) - set(self.peft_config)
        if len(missing) > 0:
            raise ValueError(
                f"Following adapter(s) could not be found: {', '.join(missing)}. Make sure you are passing the correct adapter name(s)."
                f" current loaded adapters are: {list(self.peft_config.keys())}"
            )

        from peft.tuners.tuners_utils import BaseTunerLayer

        _adapters_has_been_set = False

        for _, module in self.named_modules():
            if isinstance(module, BaseTunerLayer):
                if hasattr(module, "set_adapter"):
                    module.set_adapter(adapter_name)
                # Previous versions of PEFT do not support multi-adapter inference
                elif not hasattr(module, "set_adapter") and len(adapter_name) != 1:
                    raise ValueError(
                        "You are trying to set multiple adapters and you have a PEFT version that does not support multi-adapter inference. Please upgrade to the latest version of PEFT."
                        " `pip install -U peft` or `pip install -U git+https://github.com/huggingface/peft.git`"
                    )
                else:
                    module.active_adapter = adapter_name
                _adapters_has_been_set = True

        if not _adapters_has_been_set:
            raise ValueError(
                "Did not succeed in setting the adapter. Please make sure you are using a model that supports adapters."
            )

    def disable_adapters(self) -> None:
        r"""
        Disable all adapters attached to the model and fallback to inference with the base model only.

        If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
        [documentation](https://huggingface.co/docs/peft).
        """
        check_peft_version(min_version=MIN_PEFT_VERSION)

        if not self._hf_peft_config_loaded:
            raise ValueError("No adapter loaded. Please load an adapter first.")

        from peft.tuners.tuners_utils import BaseTunerLayer

        for _, module in self.named_modules():
            if isinstance(module, BaseTunerLayer):
                if hasattr(module, "enable_adapters"):
                    module.enable_adapters(enabled=False)
                else:
                    # support for older PEFT versions
                    module.disable_adapters = True

    def enable_adapters(self) -> None:
        """
        Enable adapters that are attached to the model. The model uses `self.active_adapters()` to retrieve the list of
        adapters to enable.

        If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
        [documentation](https://huggingface.co/docs/peft).
        """
        check_peft_version(min_version=MIN_PEFT_VERSION)

        if not self._hf_peft_config_loaded:
            raise ValueError("No adapter loaded. Please load an adapter first.")

        from peft.tuners.tuners_utils import BaseTunerLayer

        for _, module in self.named_modules():
            if isinstance(module, BaseTunerLayer):
                if hasattr(module, "enable_adapters"):
                    module.enable_adapters(enabled=True)
                else:
                    # support for older PEFT versions
                    module.disable_adapters = False

    def active_adapters(self) -> List[str]:
        """
        Gets the current list of active adapters of the model.

        If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
        [documentation](https://huggingface.co/docs/peft).
        """
        check_peft_version(min_version=MIN_PEFT_VERSION)

        if not is_peft_available():
            raise ImportError("PEFT is not available. Please install PEFT to use this function: `pip install peft`.")

        if not self._hf_peft_config_loaded:
            raise ValueError("No adapter loaded. Please load an adapter first.")

        from peft.tuners.tuners_utils import BaseTunerLayer

        for _, module in self.named_modules():
            if isinstance(module, BaseTunerLayer):
                return module.active_adapter

    def fuse_lora(self, lora_scale=1.0, safe_fusing=False, adapter_names=None):
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for `fuse_lora()`.")

        self.lora_scale = lora_scale
        self._safe_fusing = safe_fusing
        self.apply(partial(self._fuse_lora_apply, adapter_names=adapter_names))

    def _fuse_lora_apply(self, module, adapter_names=None):
        from peft.tuners.tuners_utils import BaseTunerLayer

        merge_kwargs = {"safe_merge": self._safe_fusing}

        if isinstance(module, BaseTunerLayer):
            if self.lora_scale != 1.0:
                module.scale_layer(self.lora_scale)

            # For BC with previous PEFT versions, we need to check the signature
            # of the `merge` method to see if it supports the `adapter_names` argument.
            supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
            if "adapter_names" in supported_merge_kwargs:
                merge_kwargs["adapter_names"] = adapter_names
            elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
                raise ValueError(
                    "The `adapter_names` argument is not supported with your PEFT version. Please upgrade"
                    " to the latest version of PEFT. `pip install -U peft`"
                )

            module.merge(**merge_kwargs)

    def unfuse_lora(self):
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for `unfuse_lora()`.")
        self.apply(self._unfuse_lora_apply)

    def _unfuse_lora_apply(self, module):
        from peft.tuners.tuners_utils import BaseTunerLayer

        if isinstance(module, BaseTunerLayer):
            module.unmerge()

    def unload_lora(self):
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for `unload_lora()`.")

        from diffusers.utils import recurse_remove_peft_layers

        recurse_remove_peft_layers(self)
        if hasattr(self, "peft_config"):
            del self.peft_config

    def disable_lora(self):
        """
        Disables the active LoRA layers of the underlying model.

        Example:

        ```py
        from diffusers import AutoPipelineForText2Image
        import torch

        pipeline = AutoPipelineForText2Image.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
        ).to("cuda")
        pipeline.load_lora_weights(
            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
        )
        pipeline.disable_lora()
        ```
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")
        set_adapter_layers(self, enabled=False)

    def enable_lora(self):
        """
        Enables the active LoRA layers of the underlying model.

        Example:

        ```py
        from diffusers import AutoPipelineForText2Image
        import torch

        pipeline = AutoPipelineForText2Image.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
        ).to("cuda")
        pipeline.load_lora_weights(
            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
        )
        pipeline.enable_lora()
        ```
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")
        set_adapter_layers(self, enabled=True)

    def delete_adapters(self, adapter_names: Union[List[str], str]):
        """
        Delete an adapter's LoRA layers from the underlying model.

        Args:
            adapter_names (`Union[List[str], str]`):
                The names (single string or list of strings) of the adapter to delete.

        Example:

        ```py
        from diffusers import AutoPipelineForText2Image
        import torch

        pipeline = AutoPipelineForText2Image.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
        ).to("cuda")
        pipeline.load_lora_weights(
            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_names="cinematic"
        )
        pipeline.delete_adapters("cinematic")
        ```
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        if isinstance(adapter_names, str):
            adapter_names = [adapter_names]

        for adapter_name in adapter_names:
            delete_adapter_layers(self, adapter_name)

            # Pop also the corresponding adapter from the config
            if hasattr(self, "peft_config"):
                self.peft_config.pop(adapter_name, None)
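As a rough usage sketch (not part of the uploaded file), this is how the mixin above is typically driven from a model class that inherits it, e.g. the `FluxTransformer2DModel` registered in `_SET_ADAPTER_SCALE_FN_MAPPING`; the `peft` package and its `LoraConfig` are assumed to be installed:

```py
from peft import LoraConfig

# `transformer` is assumed to be an instance of a class that inherits PeftAdapterMixin.
config = LoraConfig(r=16, lora_alpha=16, target_modules=["to_q", "to_k", "to_v"])
transformer.add_adapter(config, adapter_name="my_lora")  # inject and activate the adapter
transformer.set_adapters(["my_lora"], weights=[0.8])     # set its scale
transformer.fuse_lora(lora_scale=0.8)                    # merge LoRA weights into the base layers
transformer.unfuse_lora()                                # undo the merge
transformer.disable_adapters()                           # fall back to the base model
```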
flux/normalization.py
ADDED
@@ -0,0 +1,393 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numbers
from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from diffusers.utils import is_torch_version
from .embeddings import get_activation
from .embeddings import (
    CombinedTimestepLabelEmbeddings,
    PixArtAlphaCombinedTimestepSizeEmbeddings,
)


class AdaLayerNorm(nn.Module):
    r"""
    Norm layer modified to incorporate timestep embeddings.

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        num_embeddings (`int`): The size of the embeddings dictionary.
    """

    def __init__(self, embedding_dim: int, num_embeddings: int):
        super().__init__()
        self.emb = nn.Embedding(num_embeddings, embedding_dim)
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)

    def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
        emb = self.linear(self.silu(self.emb(timestep)))
        scale, shift = torch.chunk(emb, 2)
        x = self.norm(x) * (1 + scale) + shift
        return x


class FP32LayerNorm(nn.LayerNorm):
    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        origin_dtype = inputs.dtype
        return F.layer_norm(
            inputs.float(),
            self.normalized_shape,
            self.weight.float() if self.weight is not None else None,
            self.bias.float() if self.bias is not None else None,
            self.eps,
        ).to(origin_dtype)


class AdaLayerNormZero(nn.Module):
    r"""
    Norm layer adaptive layer norm zero (adaLN-Zero).

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        num_embeddings (`int`): The size of the embeddings dictionary.
    """

    def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None, norm_type="layer_norm", bias=True):
        super().__init__()
        if num_embeddings is not None:
            self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
        else:
            self.emb = None

        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
        if norm_type == "layer_norm":
            self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
        elif norm_type == "fp32_layer_norm":
            self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False)
        else:
            raise ValueError(
                f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
            )

    def forward(
        self,
        x: torch.Tensor,
        timestep: Optional[torch.Tensor] = None,
        class_labels: Optional[torch.LongTensor] = None,
        hidden_dtype: Optional[torch.dtype] = None,
        emb: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        if self.emb is not None:
            emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
        emb = self.linear(self.silu(emb))
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp


class AdaLayerNormZeroSingle(nn.Module):
    r"""
    Norm layer adaptive layer norm zero (adaLN-Zero).

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        num_embeddings (`int`): The size of the embeddings dictionary.
    """

    def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
        super().__init__()

        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
        if norm_type == "layer_norm":
            self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
        else:
            raise ValueError(
                f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
            )

    def forward(
        self,
        x: torch.Tensor,
        emb: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        emb = self.linear(self.silu(emb))
        shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
        return x, gate_msa


class LuminaRMSNormZero(nn.Module):
    """
    Norm layer adaptive RMS normalization zero.

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
    """

    def __init__(self, embedding_dim: int, norm_eps: float, norm_elementwise_affine: bool):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(
            min(embedding_dim, 1024),
            4 * embedding_dim,
            bias=True,
        )
        self.norm = RMSNorm(embedding_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)

    def forward(
        self,
        x: torch.Tensor,
        emb: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # emb = self.emb(timestep, encoder_hidden_states, encoder_mask)
        emb = self.linear(self.silu(emb))
        scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
        x = self.norm(x) * (1 + scale_msa[:, None])

        return x, gate_msa, scale_mlp, gate_mlp


class AdaLayerNormSingle(nn.Module):
    r"""
    Norm layer adaptive layer norm single (adaLN-single).

    As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        use_additional_conditions (`bool`): To use additional conditions for normalization or not.
    """

    def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
        super().__init__()

        self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
            embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
        )

        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)

    def forward(
        self,
        timestep: torch.Tensor,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        batch_size: Optional[int] = None,
        hidden_dtype: Optional[torch.dtype] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # No modulation happening here.
        embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
        return self.linear(self.silu(embedded_timestep)), embedded_timestep


class AdaGroupNorm(nn.Module):
    r"""
    GroupNorm layer modified to incorporate timestep embeddings.

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        num_embeddings (`int`): The size of the embeddings dictionary.
        num_groups (`int`): The number of groups to separate the channels into.
        act_fn (`str`, *optional*, defaults to `None`): The activation function to use.
        eps (`float`, *optional*, defaults to `1e-5`): The epsilon value to use for numerical stability.
    """

    def __init__(
        self, embedding_dim: int, out_dim: int, num_groups: int, act_fn: Optional[str] = None, eps: float = 1e-5
    ):
        super().__init__()
        self.num_groups = num_groups
        self.eps = eps

        if act_fn is None:
            self.act = None
        else:
            self.act = get_activation(act_fn)

        self.linear = nn.Linear(embedding_dim, out_dim * 2)

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        if self.act:
            emb = self.act(emb)
        emb = self.linear(emb)
        emb = emb[:, :, None, None]
        scale, shift = emb.chunk(2, dim=1)

        x = F.group_norm(x, self.num_groups, eps=self.eps)
        x = x * (1 + scale) + shift
        return x


class AdaLayerNormContinuous(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        conditioning_embedding_dim: int,
        # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
        # because the output is immediately scaled and shifted by the projected conditioning embeddings.
        # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
        # However, this is how it was implemented in the original code, and it's rather likely you should
        # set `elementwise_affine` to False.
        elementwise_affine=True,
        eps=1e-5,
        bias=True,
        norm_type="layer_norm",
    ):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
        if norm_type == "layer_norm":
            self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
        elif norm_type == "rms_norm":
            self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
        else:
            raise ValueError(f"unknown norm_type {norm_type}")

    def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
        # convert back to the original dtype in case `conditioning_embedding` is upcasted to float32 (needed for hunyuanDiT)
        emb = self.linear(self.silu(conditioning_embedding).to(x.dtype))
        scale, shift = torch.chunk(emb, 2, dim=1)
        x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
        return x


class LuminaLayerNormContinuous(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        conditioning_embedding_dim: int,
        # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
        # because the output is immediately scaled and shifted by the projected conditioning embeddings.
        # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
        # However, this is how it was implemented in the original code, and it's rather likely you should
        # set `elementwise_affine` to False.
        elementwise_affine=True,
        eps=1e-5,
        bias=True,
        norm_type="layer_norm",
        out_dim: Optional[int] = None,
    ):
        super().__init__()
        # AdaLN
        self.silu = nn.SiLU()
        self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)
        if norm_type == "layer_norm":
            self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
        else:
            raise ValueError(f"unknown norm_type {norm_type}")
        # linear_2
        if out_dim is not None:
            self.linear_2 = nn.Linear(
                embedding_dim,
                out_dim,
                bias=bias,
            )
        else:
            # keep the attribute defined so the `is not None` check in forward() does not raise
            self.linear_2 = None

    def forward(
        self,
        x: torch.Tensor,
        conditioning_embedding: torch.Tensor,
    ) -> torch.Tensor:
        # convert back to the original dtype in case `conditioning_embedding` is upcasted to float32 (needed for hunyuanDiT)
        emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
        scale = emb
        x = self.norm(x) * (1 + scale)[:, None, :]

        if self.linear_2 is not None:
            x = self.linear_2(x)

        return x


if is_torch_version(">=", "2.1.0"):
    LayerNorm = nn.LayerNorm
else:
    # Has optional bias parameter compared to torch layer norm
    # TODO: replace with torch layernorm once min required torch version >= 2.1
    class LayerNorm(nn.Module):
        def __init__(self, dim, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True):
            super().__init__()

            self.eps = eps

            if isinstance(dim, numbers.Integral):
                dim = (dim,)

            self.dim = torch.Size(dim)

            if elementwise_affine:
                self.weight = nn.Parameter(torch.ones(dim))
                self.bias = nn.Parameter(torch.zeros(dim)) if bias else None
            else:
                self.weight = None
                self.bias = None

        def forward(self, input):
            return F.layer_norm(input, self.dim, self.weight, self.bias, self.eps)


class RMSNorm(nn.Module):
    def __init__(self, dim, eps: float, elementwise_affine: bool = True):
        super().__init__()

        self.eps = eps

        if isinstance(dim, numbers.Integral):
            dim = (dim,)

        self.dim = torch.Size(dim)

        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim))
        else:
            self.weight = None

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)

        if self.weight is not None:
            # convert into half-precision if necessary
            if self.weight.dtype in [torch.float16, torch.bfloat16]:
                hidden_states = hidden_states.to(self.weight.dtype)
            hidden_states = hidden_states * self.weight
        else:
            hidden_states = hidden_states.to(input_dtype)

        return hidden_states


class GlobalResponseNorm(nn.Module):
    # Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))

    def forward(self, x):
        gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        nx = gx / (gx.mean(dim=-1, keepdim=True) + 1e-6)
        return self.gamma * (x * nx) + self.beta + x
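A minimal shape-check sketch (not part of the uploaded file) for the modulation layers above, assuming this module is importable as `flux.normalization`:

```py
import torch
from flux.normalization import AdaLayerNormZeroSingle, RMSNorm

batch, seq, dim = 2, 16, 64
x = torch.randn(batch, seq, dim)
emb = torch.randn(batch, dim)        # e.g. a pooled text + timestep conditioning embedding

norm = AdaLayerNormZeroSingle(dim)
x_mod, gate = norm(x, emb=emb)       # x_mod: (2, 16, 64), gate: (2, 64)

rms = RMSNorm(dim, eps=1e-6)
y = rms(x)                           # same shape as x, normalized over the last dimension
```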
flux/pipeline_flux.py
ADDED
@@ -0,0 +1,749 @@
# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

from diffusers.image_processor import VaeImageProcessor
from .lora.lora_pipeline import FluxLoraLoaderMixin
from diffusers.models.autoencoders import AutoencoderKL
from .transformer_flux import FluxTransformer2DModel
from .scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from diffusers.utils import (
    USE_PEFT_BACKEND,
    is_torch_xla_available,
    logging,
    replace_example_docstring,
    scale_lora_layers,
    unscale_lora_layers,
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from .pipeline_output import FluxPipelineOutput


if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import FluxPipeline

        >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
        >>> pipe.to("cuda")
        >>> prompt = "A cat holding a sign that says hello world"
        >>> # Depending on the variant being used, the pipeline call will slightly vary.
        >>> # Refer to the pipeline documentation for more details.
        >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
        >>> image.save("flux.png")
        ```
"""


def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.16,
):
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    mu = image_seq_len * m + b
    return mu

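For context, a small worked example of `calculate_shift` (illustrative only, not part of the uploaded file): `mu` interpolates linearly between `base_shift` and `max_shift` as the packed-latent sequence length grows.

```py
# A 1024x1024 image gives a 128x128 VAE latent, packed 2x2 into 64x64 = 4096 tokens,
# which equals max_seq_len, so the shift reaches max_shift exactly.
calculate_shift(4096)  # -> 1.16
# A 512x512 image gives 32x32 = 1024 packed tokens and an intermediate shift.
calculate_shift(1024)  # -> ~0.632
```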
80 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
81 |
+
def retrieve_timesteps(
|
82 |
+
scheduler,
|
83 |
+
num_inference_steps: Optional[int] = None,
|
84 |
+
device: Optional[Union[str, torch.device]] = None,
|
85 |
+
timesteps: Optional[List[int]] = None,
|
86 |
+
sigmas: Optional[List[float]] = None,
|
87 |
+
**kwargs,
|
88 |
+
):
|
89 |
+
"""
|
90 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
91 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
92 |
+
|
93 |
+
Args:
|
94 |
+
scheduler (`SchedulerMixin`):
|
95 |
+
The scheduler to get timesteps from.
|
96 |
+
num_inference_steps (`int`):
|
97 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
98 |
+
must be `None`.
|
99 |
+
device (`str` or `torch.device`, *optional*):
|
100 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
101 |
+
timesteps (`List[int]`, *optional*):
|
102 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
103 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
104 |
+
sigmas (`List[float]`, *optional*):
|
105 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
106 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
107 |
+
|
108 |
+
Returns:
|
109 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
110 |
+
second element is the number of inference steps.
|
111 |
+
"""
|
112 |
+
if timesteps is not None and sigmas is not None:
|
113 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
114 |
+
if timesteps is not None:
|
115 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
116 |
+
if not accepts_timesteps:
|
117 |
+
raise ValueError(
|
118 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
119 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
120 |
+
)
|
121 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
122 |
+
timesteps = scheduler.timesteps
|
123 |
+
num_inference_steps = len(timesteps)
|
124 |
+
elif sigmas is not None:
|
125 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
126 |
+
if not accept_sigmas:
|
127 |
+
raise ValueError(
|
128 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
129 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
130 |
+
)
|
131 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
132 |
+
timesteps = scheduler.timesteps
|
133 |
+
num_inference_steps = len(timesteps)
|
134 |
+
else:
|
135 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
136 |
+
timesteps = scheduler.timesteps
|
137 |
+
return timesteps, num_inference_steps
|
138 |
+
|
139 |
+
|
140 |
+
class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
141 |
+
r"""
|
142 |
+
The Flux pipeline for text-to-image generation.
|
143 |
+
|
144 |
+
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
|
145 |
+
|
146 |
+
Args:
|
147 |
+
transformer ([`FluxTransformer2DModel`]):
|
148 |
+
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
149 |
+
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
150 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
151 |
+
vae ([`AutoencoderKL`]):
|
152 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
153 |
+
text_encoder ([`CLIPTextModel`]):
|
154 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
155 |
+
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
156 |
+
text_encoder_2 ([`T5EncoderModel`]):
|
157 |
+
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
158 |
+
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
|
159 |
+
tokenizer (`CLIPTokenizer`):
|
160 |
+
Tokenizer of class
|
161 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
|
162 |
+
tokenizer_2 (`T5TokenizerFast`):
|
163 |
+
Second Tokenizer of class
|
164 |
+
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
|
165 |
+
"""
|
166 |
+
|
167 |
+
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
|
168 |
+
_optional_components = []
|
169 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
170 |
+
|
171 |
+
def __init__(
|
172 |
+
self,
|
173 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
174 |
+
vae: AutoencoderKL,
|
175 |
+
text_encoder: CLIPTextModel,
|
176 |
+
tokenizer: CLIPTokenizer,
|
177 |
+
text_encoder_2: T5EncoderModel,
|
178 |
+
tokenizer_2: T5TokenizerFast,
|
179 |
+
transformer: FluxTransformer2DModel,
|
180 |
+
):
|
181 |
+
super().__init__()
|
182 |
+
|
183 |
+
self.register_modules(
|
184 |
+
vae=vae,
|
185 |
+
text_encoder=text_encoder,
|
186 |
+
text_encoder_2=text_encoder_2,
|
187 |
+
tokenizer=tokenizer,
|
188 |
+
tokenizer_2=tokenizer_2,
|
189 |
+
transformer=transformer,
|
190 |
+
scheduler=scheduler,
|
191 |
+
)
|
192 |
+
self.vae_scale_factor = (
|
193 |
+
2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
|
194 |
+
)
|
195 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
196 |
+
self.tokenizer_max_length = (
|
197 |
+
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
|
198 |
+
)
|
199 |
+
self.default_sample_size = 64
|
200 |
+
|
201 |
+
def _get_t5_prompt_embeds(
|
202 |
+
self,
|
203 |
+
prompt: Union[str, List[str]] = None,
|
204 |
+
num_images_per_prompt: int = 1,
|
205 |
+
max_sequence_length: int = 512,
|
206 |
+
device: Optional[torch.device] = None,
|
207 |
+
dtype: Optional[torch.dtype] = None,
|
208 |
+
):
|
209 |
+
device = device or self._execution_device
|
210 |
+
dtype = dtype or self.text_encoder.dtype
|
211 |
+
|
212 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
213 |
+
batch_size = len(prompt)
|
214 |
+
|
215 |
+
text_inputs = self.tokenizer_2(
|
216 |
+
prompt,
|
217 |
+
padding="max_length",
|
218 |
+
max_length=max_sequence_length,
|
219 |
+
truncation=True,
|
220 |
+
return_length=False,
|
221 |
+
return_overflowing_tokens=False,
|
222 |
+
return_tensors="pt",
|
223 |
+
)
|
224 |
+
text_input_ids = text_inputs.input_ids
|
225 |
+
untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
|
226 |
+
|
227 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
228 |
+
removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
229 |
+
logger.warning(
|
230 |
+
"The following part of your input was truncated because `max_sequence_length` is set to "
|
231 |
+
f" {max_sequence_length} tokens: {removed_text}"
|
232 |
+
)
|
233 |
+
|
234 |
+
prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
|
235 |
+
|
236 |
+
dtype = self.text_encoder_2.dtype
|
237 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
238 |
+
|
239 |
+
_, seq_len, _ = prompt_embeds.shape
|
240 |
+
|
241 |
+
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
242 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
243 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
244 |
+
|
245 |
+
return prompt_embeds
|
246 |
+
|
247 |
+
def _get_clip_prompt_embeds(
|
248 |
+
self,
|
249 |
+
prompt: Union[str, List[str]],
|
250 |
+
num_images_per_prompt: int = 1,
|
251 |
+
device: Optional[torch.device] = None,
|
252 |
+
):
|
253 |
+
device = device or self._execution_device
|
254 |
+
|
255 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
256 |
+
batch_size = len(prompt)
|
257 |
+
|
258 |
+
text_inputs = self.tokenizer(
|
259 |
+
prompt,
|
260 |
+
padding="max_length",
|
261 |
+
max_length=self.tokenizer_max_length,
|
262 |
+
truncation=True,
|
263 |
+
return_overflowing_tokens=False,
|
264 |
+
return_length=False,
|
265 |
+
return_tensors="pt",
|
266 |
+
)
|
267 |
+
|
268 |
+
text_input_ids = text_inputs.input_ids
|
269 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
270 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
271 |
+
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
272 |
+
logger.warning(
|
273 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
274 |
+
f" {self.tokenizer_max_length} tokens: {removed_text}"
|
275 |
+
)
|
276 |
+
prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
|
277 |
+
|
278 |
+
# Use pooled output of CLIPTextModel
|
279 |
+
prompt_embeds = prompt_embeds.pooler_output
|
280 |
+
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
281 |
+
|
282 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
283 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
284 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
285 |
+
|
286 |
+
return prompt_embeds
|
287 |
+
|
288 |
+
def encode_prompt(
|
289 |
+
self,
|
290 |
+
prompt: Union[str, List[str]],
|
291 |
+
prompt_2: Union[str, List[str]],
|
292 |
+
device: Optional[torch.device] = None,
|
293 |
+
num_images_per_prompt: int = 1,
|
294 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
295 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
296 |
+
max_sequence_length: int = 512,
|
297 |
+
lora_scale: Optional[float] = None,
|
298 |
+
):
|
299 |
+
r"""
|
300 |
+
|
301 |
+
Args:
|
302 |
+
prompt (`str` or `List[str]`, *optional*):
|
303 |
+
prompt to be encoded
|
304 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
305 |
+
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
306 |
+
used in all text-encoders
|
307 |
+
device: (`torch.device`):
|
308 |
+
torch device
|
309 |
+
num_images_per_prompt (`int`):
|
310 |
+
number of images that should be generated per prompt
|
311 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
312 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
313 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
314 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
315 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
316 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
317 |
+
lora_scale (`float`, *optional*):
|
318 |
+
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
319 |
+
"""
|
320 |
+
device = device or self._execution_device
|
321 |
+
|
322 |
+
# set lora scale so that monkey patched LoRA
|
323 |
+
# function of text encoder can correctly access it
|
324 |
+
if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
|
325 |
+
self._lora_scale = lora_scale
|
326 |
+
|
327 |
+
# dynamically adjust the LoRA scale
|
328 |
+
if self.text_encoder is not None and USE_PEFT_BACKEND:
|
329 |
+
scale_lora_layers(self.text_encoder, lora_scale)
|
330 |
+
if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
|
331 |
+
scale_lora_layers(self.text_encoder_2, lora_scale)
|
332 |
+
|
333 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
334 |
+
if prompt is not None:
|
335 |
+
batch_size = len(prompt)
|
336 |
+
else:
|
337 |
+
batch_size = prompt_embeds.shape[0]
|
338 |
+
|
339 |
+
if prompt_embeds is None:
|
340 |
+
prompt_2 = prompt_2 or prompt
|
341 |
+
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
|
342 |
+
|
343 |
+
# We only use the pooled prompt output from the CLIPTextModel
|
344 |
+
pooled_prompt_embeds = self._get_clip_prompt_embeds(
|
345 |
+
prompt=prompt,
|
346 |
+
device=device,
|
347 |
+
num_images_per_prompt=num_images_per_prompt,
|
348 |
+
)
|
349 |
+
prompt_embeds = self._get_t5_prompt_embeds(
|
350 |
+
prompt=prompt_2,
|
351 |
+
num_images_per_prompt=num_images_per_prompt,
|
352 |
+
max_sequence_length=max_sequence_length,
|
353 |
+
device=device,
|
354 |
+
)
|
355 |
+
|
356 |
+
if self.text_encoder is not None:
|
357 |
+
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
|
358 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
359 |
+
unscale_lora_layers(self.text_encoder, lora_scale)
|
360 |
+
|
361 |
+
if self.text_encoder_2 is not None:
|
362 |
+
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
|
363 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
364 |
+
unscale_lora_layers(self.text_encoder_2, lora_scale)
|
365 |
+
|
366 |
+
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
|
367 |
+
text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
368 |
+
text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
|
369 |
+
|
370 |
+
return prompt_embeds, pooled_prompt_embeds, text_ids
|
371 |
+
|
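For orientation, a minimal shape sketch of what `encode_prompt` returns. The 768 / 4096 hidden sizes are the usual CLIP-L / T5-XXL dimensions and are assumed here, not read from this file:

```py
import torch

# Illustrative sizes only; actual values come from the loaded text encoders.
batch_size, num_images_per_prompt, max_sequence_length = 1, 2, 512
t5_dim, clip_dim = 4096, 768

prompt_embeds = torch.zeros(batch_size * num_images_per_prompt, max_sequence_length, t5_dim)
pooled_prompt_embeds = torch.zeros(batch_size * num_images_per_prompt, clip_dim)

# text_ids mirrors the construction above: an all-zero (batch, seq_len, 3) grid of
# positional ids, repeated once per generated image.
text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).repeat(num_images_per_prompt, 1, 1)

print(prompt_embeds.shape, pooled_prompt_embeds.shape, text_ids.shape)
# torch.Size([2, 512, 4096]) torch.Size([2, 768]) torch.Size([2, 512, 3])
```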
372 |
+
def check_inputs(
|
373 |
+
self,
|
374 |
+
prompt,
|
375 |
+
prompt_2,
|
376 |
+
height,
|
377 |
+
width,
|
378 |
+
prompt_embeds=None,
|
379 |
+
pooled_prompt_embeds=None,
|
380 |
+
callback_on_step_end_tensor_inputs=None,
|
381 |
+
max_sequence_length=None,
|
382 |
+
):
|
383 |
+
if height % 8 != 0 or width % 8 != 0:
|
384 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
385 |
+
|
386 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
387 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
388 |
+
):
|
389 |
+
raise ValueError(
|
390 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
391 |
+
)
|
392 |
+
|
393 |
+
if prompt is not None and prompt_embeds is not None:
|
394 |
+
raise ValueError(
|
395 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
396 |
+
" only forward one of the two."
|
397 |
+
)
|
398 |
+
elif prompt_2 is not None and prompt_embeds is not None:
|
399 |
+
raise ValueError(
|
400 |
+
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
401 |
+
" only forward one of the two."
|
402 |
+
)
|
403 |
+
elif prompt is None and prompt_embeds is None:
|
404 |
+
raise ValueError(
|
405 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
406 |
+
)
|
407 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
408 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
409 |
+
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
|
410 |
+
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
|
411 |
+
|
412 |
+
if prompt_embeds is not None and pooled_prompt_embeds is None:
|
413 |
+
raise ValueError(
|
414 |
+
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
415 |
+
)
|
416 |
+
|
417 |
+
if max_sequence_length is not None and max_sequence_length > 512:
|
418 |
+
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
|
419 |
+
|
420 |
+
@staticmethod
|
421 |
+
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
422 |
+
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
|
423 |
+
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
|
424 |
+
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
|
425 |
+
|
426 |
+
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
427 |
+
|
428 |
+
latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
|
429 |
+
latent_image_ids = latent_image_ids.reshape(
|
430 |
+
batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
431 |
+
)
|
432 |
+
|
433 |
+
return latent_image_ids.to(device=device, dtype=dtype)
|
434 |
+
|
435 |
+
@staticmethod
|
436 |
+
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
437 |
+
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
438 |
+
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
439 |
+
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
|
440 |
+
|
441 |
+
return latents
|
442 |
+
|
443 |
+
@staticmethod
|
444 |
+
def _unpack_latents(latents, height, width, vae_scale_factor):
|
445 |
+
batch_size, num_patches, channels = latents.shape
|
446 |
+
|
447 |
+
height = height // vae_scale_factor
|
448 |
+
width = width // vae_scale_factor
|
449 |
+
|
450 |
+
latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
|
451 |
+
latents = latents.permute(0, 3, 1, 4, 2, 5)
|
452 |
+
|
453 |
+
latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
|
454 |
+
|
455 |
+
return latents
|
456 |
+
|
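A self-contained round-trip sketch of `_pack_latents` / `_unpack_latents` for a 1024x1024 image. The 16-channel, 128x128 latent assumes FLUX's 64-channel transformer input together with this pipeline's `vae_scale_factor` of 16:

```py
import torch

B, C, H, W = 1, 16, 128, 128          # latent grid for a 1024x1024 image
latents = torch.randn(B, C, H, W)

# _pack_latents: fold each 2x2 latent patch into the channel axis -> one token per patch
packed = latents.view(B, C, H // 2, 2, W // 2, 2)
packed = packed.permute(0, 2, 4, 1, 3, 5)
packed = packed.reshape(B, (H // 2) * (W // 2), C * 4)        # (1, 4096, 64)

# _unpack_latents: invert the folding to recover the spatial latent
unpacked = packed.view(B, H // 2, W // 2, C, 2, 2)
unpacked = unpacked.permute(0, 3, 1, 4, 2, 5)
unpacked = unpacked.reshape(B, C, H, W)                       # (1, 16, 128, 128)

assert torch.equal(latents, unpacked)
```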
457 |
+
def prepare_latents(
|
458 |
+
self,
|
459 |
+
batch_size,
|
460 |
+
num_channels_latents,
|
461 |
+
height,
|
462 |
+
width,
|
463 |
+
dtype,
|
464 |
+
device,
|
465 |
+
generator,
|
466 |
+
latents=None,
|
467 |
+
):
|
468 |
+
height = 2 * (int(height) // self.vae_scale_factor)
|
469 |
+
width = 2 * (int(width) // self.vae_scale_factor)
|
470 |
+
|
471 |
+
shape = (batch_size, num_channels_latents, height, width)
|
472 |
+
|
473 |
+
if latents is not None:
|
474 |
+
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
|
475 |
+
return latents.to(device=device, dtype=dtype), latent_image_ids
|
476 |
+
|
477 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
478 |
+
raise ValueError(
|
479 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
480 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
481 |
+
)
|
482 |
+
|
483 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
484 |
+
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
|
485 |
+
|
486 |
+
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
|
487 |
+
|
488 |
+
return latents, latent_image_ids
|
489 |
+
|
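Putting the numbers together for a default 1024x1024 generation (a sketch; the transformer's `in_channels = 64` is assumed):

```py
# prepare_latents arithmetic with vae_scale_factor = 16:
height = 2 * (1024 // 16)          # 128
width = 2 * (1024 // 16)           # 128
num_channels_latents = 64 // 4     # 16
# randn_tensor shape:      (batch, 16, 128, 128)
# after _pack_latents:     (batch, 64 * 64, 64) = (batch, 4096, 64)
# latent_image_ids shape:  (batch, 4096, 3)  -> image_seq_len = 4096 in __call__
```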
490 |
+
@property
|
491 |
+
def guidance_scale(self):
|
492 |
+
return self._guidance_scale
|
493 |
+
|
494 |
+
@property
|
495 |
+
def joint_attention_kwargs(self):
|
496 |
+
return self._joint_attention_kwargs
|
497 |
+
|
498 |
+
@property
|
499 |
+
def num_timesteps(self):
|
500 |
+
return self._num_timesteps
|
501 |
+
|
502 |
+
@property
|
503 |
+
def interrupt(self):
|
504 |
+
return self._interrupt
|
505 |
+
|
506 |
+
@torch.no_grad()
|
507 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
508 |
+
def __call__(
|
509 |
+
self,
|
510 |
+
prompt: Union[str, List[str]] = None,
|
511 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
512 |
+
height: Optional[int] = None,
|
513 |
+
width: Optional[int] = None,
|
514 |
+
num_inference_steps: int = 28,
|
515 |
+
timesteps: List[int] = None,
|
516 |
+
guidance_scale: float = 7.0,
|
517 |
+
num_images_per_prompt: Optional[int] = 1,
|
518 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
519 |
+
latents: Optional[torch.FloatTensor] = None,
|
520 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
521 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
522 |
+
output_type: Optional[str] = "pil",
|
523 |
+
return_dict: bool = True,
|
524 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
525 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
526 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
527 |
+
max_sequence_length: int = 512,
|
528 |
+
):
|
529 |
+
r"""
|
530 |
+
Function invoked when calling the pipeline for generation.
|
531 |
+
|
532 |
+
Args:
|
533 |
+
prompt (`str` or `List[str]`, *optional*):
|
534 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
|
535 |
+
instead.
|
536 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
537 |
+
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
|
538 |
+
will be used instead
|
539 |
+
height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
|
540 |
+
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
541 |
+
width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
|
542 |
+
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
543 |
+
num_inference_steps (`int`, *optional*, defaults to 28):
|
544 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
545 |
+
expense of slower inference.
|
546 |
+
timesteps (`List[int]`, *optional*):
|
547 |
+
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
548 |
+
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
549 |
+
passed will be used. Must be in descending order.
|
550 |
+
guidance_scale (`float`, *optional*, defaults to 7.0):
|
551 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
552 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
553 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
554 |
+
1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
|
555 |
+
usually at the expense of lower image quality.
|
556 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
557 |
+
The number of images to generate per prompt.
|
558 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
559 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
560 |
+
to make generation deterministic.
|
561 |
+
latents (`torch.FloatTensor`, *optional*):
|
562 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
563 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
564 |
+
tensor will be generated by sampling using the supplied random `generator`.
|
565 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
566 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
567 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
568 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
569 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
570 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
571 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
572 |
+
The output format of the generated image. Choose between
|
573 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
574 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
575 |
+
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
|
576 |
+
joint_attention_kwargs (`dict`, *optional*):
|
577 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
578 |
+
`self.processor` in
|
579 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
580 |
+
callback_on_step_end (`Callable`, *optional*):
|
581 |
+
A callback function executed at the end of each denoising step during inference. The function is called
|
582 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
583 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
584 |
+
`callback_on_step_end_tensor_inputs`.
|
585 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
586 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
587 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
588 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
589 |
+
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
|
590 |
+
|
591 |
+
Examples:
|
592 |
+
|
593 |
+
Returns:
|
594 |
+
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
|
595 |
+
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
|
596 |
+
images.
|
597 |
+
"""
|
598 |
+
|
599 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
600 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
601 |
+
|
602 |
+
# 1. Check inputs. Raise error if not correct
|
603 |
+
self.check_inputs(
|
604 |
+
prompt,
|
605 |
+
prompt_2,
|
606 |
+
height,
|
607 |
+
width,
|
608 |
+
prompt_embeds=prompt_embeds,
|
609 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
610 |
+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
611 |
+
max_sequence_length=max_sequence_length,
|
612 |
+
)
|
613 |
+
|
614 |
+
self._guidance_scale = guidance_scale
|
615 |
+
self._joint_attention_kwargs = joint_attention_kwargs
|
616 |
+
self._interrupt = False
|
617 |
+
|
618 |
+
# 2. Define call parameters
|
619 |
+
if prompt is not None and isinstance(prompt, str):
|
620 |
+
batch_size = 1
|
621 |
+
elif prompt is not None and isinstance(prompt, list):
|
622 |
+
batch_size = len(prompt)
|
623 |
+
else:
|
624 |
+
batch_size = prompt_embeds.shape[0]
|
625 |
+
|
626 |
+
device = self._execution_device
|
627 |
+
|
628 |
+
lora_scale = (
|
629 |
+
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
|
630 |
+
)
|
631 |
+
(
|
632 |
+
prompt_embeds,
|
633 |
+
pooled_prompt_embeds,
|
634 |
+
text_ids,
|
635 |
+
) = self.encode_prompt(
|
636 |
+
prompt=prompt,
|
637 |
+
prompt_2=prompt_2,
|
638 |
+
prompt_embeds=prompt_embeds,
|
639 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
640 |
+
device=device,
|
641 |
+
num_images_per_prompt=num_images_per_prompt,
|
642 |
+
max_sequence_length=max_sequence_length,
|
643 |
+
lora_scale=lora_scale,
|
644 |
+
)
|
645 |
+
|
646 |
+
# 4. Prepare latent variables
|
647 |
+
num_channels_latents = self.transformer.config.in_channels // 4
|
648 |
+
latents, latent_image_ids = self.prepare_latents(
|
649 |
+
batch_size * num_images_per_prompt,
|
650 |
+
num_channels_latents,
|
651 |
+
height,
|
652 |
+
width,
|
653 |
+
prompt_embeds.dtype,
|
654 |
+
device,
|
655 |
+
generator,
|
656 |
+
latents,
|
657 |
+
)
|
658 |
+
|
659 |
+
# 5. Prepare timesteps
|
660 |
+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
661 |
+
image_seq_len = latents.shape[1]
|
662 |
+
mu = calculate_shift(
|
663 |
+
image_seq_len,
|
664 |
+
self.scheduler.config.base_image_seq_len,
|
665 |
+
self.scheduler.config.max_image_seq_len,
|
666 |
+
self.scheduler.config.base_shift,
|
667 |
+
self.scheduler.config.max_shift,
|
668 |
+
)
|
669 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
670 |
+
self.scheduler,
|
671 |
+
num_inference_steps,
|
672 |
+
device,
|
673 |
+
timesteps,
|
674 |
+
sigmas,
|
675 |
+
mu=mu,
|
676 |
+
)
|
677 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
678 |
+
self._num_timesteps = len(timesteps)
|
679 |
+
|
680 |
+
# 6. Denoising loop
|
681 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
682 |
+
for i, t in enumerate(timesteps):
|
683 |
+
if self.interrupt:
|
684 |
+
continue
|
685 |
+
|
686 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
687 |
+
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
688 |
+
|
689 |
+
# handle guidance
|
690 |
+
if self.transformer.config.guidance_embeds:
|
691 |
+
guidance = torch.tensor([guidance_scale], device=device)
|
692 |
+
guidance = guidance.expand(latents.shape[0])
|
693 |
+
else:
|
694 |
+
guidance = None
|
695 |
+
|
696 |
+
noise_pred = self.transformer(
|
697 |
+
hidden_states=latents,
|
698 |
+
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs the same for the model for testing)
|
699 |
+
timestep=timestep / 1000,
|
700 |
+
guidance=guidance,
|
701 |
+
pooled_projections=pooled_prompt_embeds,
|
702 |
+
encoder_hidden_states=prompt_embeds,
|
703 |
+
txt_ids=text_ids,
|
704 |
+
img_ids=latent_image_ids,
|
705 |
+
joint_attention_kwargs=self.joint_attention_kwargs,
|
706 |
+
return_dict=False,
|
707 |
+
)[0]
|
708 |
+
|
709 |
+
# compute the previous noisy sample x_t -> x_t-1
|
710 |
+
latents_dtype = latents.dtype
|
711 |
+
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
712 |
+
|
713 |
+
if latents.dtype != latents_dtype:
|
714 |
+
if torch.backends.mps.is_available():
|
715 |
+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
716 |
+
latents = latents.to(latents_dtype)
|
717 |
+
|
718 |
+
if callback_on_step_end is not None:
|
719 |
+
callback_kwargs = {}
|
720 |
+
for k in callback_on_step_end_tensor_inputs:
|
721 |
+
callback_kwargs[k] = locals()[k]
|
722 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
723 |
+
|
724 |
+
latents = callback_outputs.pop("latents", latents)
|
725 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
726 |
+
|
727 |
+
# call the callback, if provided
|
728 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
729 |
+
progress_bar.update()
|
730 |
+
|
731 |
+
if XLA_AVAILABLE:
|
732 |
+
xm.mark_step()
|
733 |
+
|
734 |
+
if output_type == "latent":
|
735 |
+
image = latents
|
736 |
+
|
737 |
+
else:
|
738 |
+
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
739 |
+
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
740 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
741 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
742 |
+
|
743 |
+
# Offload all models
|
744 |
+
self.maybe_free_model_hooks()
|
745 |
+
|
746 |
+
if not return_dict:
|
747 |
+
return (image,)
|
748 |
+
|
749 |
+
return FluxPipelineOutput(images=image)
|
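A minimal sketch of the `callback_on_step_end` hook documented above. The import path and the `from_pretrained` checkpoint are assumptions (upstream FLUX.1-dev weights loaded through this repo-local class), not something this file prescribes:

```py
import torch
from flux.pipeline_flux import FluxPipeline  # repo-local import path, assumed

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.to("cuda")

def log_step(pipeline, step, timestep, callback_kwargs):
    # "latents" is available because it is listed in callback_on_step_end_tensor_inputs.
    latents = callback_kwargs["latents"]
    print(f"step {step}: latents {tuple(latents.shape)}")
    return callback_kwargs  # the denoising loop pops tensors back out of this dict

image = pipe(
    "A cat holding a sign that says hello world",
    num_inference_steps=28,
    guidance_scale=3.5,
    callback_on_step_end=log_step,
    callback_on_step_end_tensor_inputs=["latents"],
).images[0]
image.save("flux_callback.png")
```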
flux/pipeline_flux_chameleon.py
ADDED
@@ -0,0 +1,758 @@
1 |
+
# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import inspect
|
16 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
17 |
+
|
18 |
+
import numpy as np
|
19 |
+
import torch
|
20 |
+
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
|
21 |
+
|
22 |
+
from diffusers.image_processor import VaeImageProcessor
|
23 |
+
from .lora.lora_pipeline import FluxLoraLoaderMixin
|
24 |
+
from diffusers.models.autoencoders import AutoencoderKL
|
25 |
+
from .transformer_flux import FluxTransformer2DModel
|
26 |
+
from .scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
|
27 |
+
from diffusers.utils import (
|
28 |
+
USE_PEFT_BACKEND,
|
29 |
+
is_torch_xla_available,
|
30 |
+
logging,
|
31 |
+
replace_example_docstring,
|
32 |
+
scale_lora_layers,
|
33 |
+
unscale_lora_layers,
|
34 |
+
)
|
35 |
+
from diffusers.utils.torch_utils import randn_tensor
|
36 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
37 |
+
from .pipeline_output import FluxPipelineOutput
|
38 |
+
|
39 |
+
|
40 |
+
if is_torch_xla_available():
|
41 |
+
import torch_xla.core.xla_model as xm
|
42 |
+
|
43 |
+
XLA_AVAILABLE = True
|
44 |
+
else:
|
45 |
+
XLA_AVAILABLE = False
|
46 |
+
|
47 |
+
|
48 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
49 |
+
|
50 |
+
EXAMPLE_DOC_STRING = """
|
51 |
+
Examples:
|
52 |
+
```py
|
53 |
+
>>> import torch
|
54 |
+
>>> from diffusers import FluxPipeline
|
55 |
+
|
56 |
+
>>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
|
57 |
+
>>> pipe.to("cuda")
|
58 |
+
>>> prompt = "A cat holding a sign that says hello world"
|
59 |
+
>>> # Depending on the variant being used, the pipeline call will slightly vary.
|
60 |
+
>>> # Refer to the pipeline documentation for more details.
|
61 |
+
>>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
|
62 |
+
>>> image.save("flux.png")
|
63 |
+
```
|
64 |
+
"""
|
65 |
+
|
66 |
+
|
67 |
+
def calculate_shift(
|
68 |
+
image_seq_len,
|
69 |
+
base_seq_len: int = 256,
|
70 |
+
max_seq_len: int = 4096,
|
71 |
+
base_shift: float = 0.5,
|
72 |
+
max_shift: float = 1.16,
|
73 |
+
):
|
74 |
+
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
75 |
+
b = base_shift - m * base_seq_len
|
76 |
+
mu = image_seq_len * m + b
|
77 |
+
return mu
|
78 |
+
|
79 |
+
|
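A quick sanity check on the interpolation above, using the signature defaults:

```py
# calculate_shift is the line through (base_seq_len, base_shift) and (max_seq_len, max_shift):
#   m = (1.16 - 0.5) / (4096 - 256) ≈ 1.71875e-4,  b = 0.5 - m * 256 = 0.456
mu_small = calculate_shift(256)    # -> 0.5   (shortest image sequence)
mu_large = calculate_shift(4096)   # -> 1.16  (4096 tokens, i.e. a 1024x1024 generation)
```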
80 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
81 |
+
def retrieve_timesteps(
|
82 |
+
scheduler,
|
83 |
+
num_inference_steps: Optional[int] = None,
|
84 |
+
device: Optional[Union[str, torch.device]] = None,
|
85 |
+
timesteps: Optional[List[int]] = None,
|
86 |
+
sigmas: Optional[List[float]] = None,
|
87 |
+
**kwargs,
|
88 |
+
):
|
89 |
+
"""
|
90 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
91 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
92 |
+
|
93 |
+
Args:
|
94 |
+
scheduler (`SchedulerMixin`):
|
95 |
+
The scheduler to get timesteps from.
|
96 |
+
num_inference_steps (`int`):
|
97 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
98 |
+
must be `None`.
|
99 |
+
device (`str` or `torch.device`, *optional*):
|
100 |
+
The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
|
101 |
+
timesteps (`List[int]`, *optional*):
|
102 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
103 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
104 |
+
sigmas (`List[float]`, *optional*):
|
105 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
106 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
107 |
+
|
108 |
+
Returns:
|
109 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
110 |
+
second element is the number of inference steps.
|
111 |
+
"""
|
112 |
+
if timesteps is not None and sigmas is not None:
|
113 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
114 |
+
if timesteps is not None:
|
115 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
116 |
+
if not accepts_timesteps:
|
117 |
+
raise ValueError(
|
118 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
119 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
120 |
+
)
|
121 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
122 |
+
timesteps = scheduler.timesteps
|
123 |
+
num_inference_steps = len(timesteps)
|
124 |
+
elif sigmas is not None:
|
125 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
126 |
+
if not accept_sigmas:
|
127 |
+
raise ValueError(
|
128 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
129 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
130 |
+
)
|
131 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
132 |
+
timesteps = scheduler.timesteps
|
133 |
+
num_inference_steps = len(timesteps)
|
134 |
+
else:
|
135 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
136 |
+
timesteps = scheduler.timesteps
|
137 |
+
return timesteps, num_inference_steps
|
138 |
+
|
139 |
+
|
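A usage sketch mirroring step 5 of `__call__` below; `pipe` is assumed to be an already-loaded instance of this pipeline:

```py
import numpy as np

num_inference_steps = 28
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
mu = calculate_shift(4096)  # image_seq_len for a 1024x1024 generation
timesteps, num_inference_steps = retrieve_timesteps(
    pipe.scheduler,          # FlowMatchEulerDiscreteScheduler
    num_inference_steps,
    "cuda",
    None,                    # no custom integer timesteps
    sigmas,
    mu=mu,                   # forwarded to scheduler.set_timesteps(sigmas=..., mu=mu)
)
```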
140 |
+
class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
141 |
+
r"""
|
142 |
+
The Flux pipeline for text-to-image generation.
|
143 |
+
|
144 |
+
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
|
145 |
+
|
146 |
+
Args:
|
147 |
+
transformer ([`FluxTransformer2DModel`]):
|
148 |
+
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
149 |
+
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
150 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
151 |
+
vae ([`AutoencoderKL`]):
|
152 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
153 |
+
text_encoder ([`CLIPTextModel`]):
|
154 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
155 |
+
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
156 |
+
text_encoder_2 ([`T5EncoderModel`]):
|
157 |
+
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
158 |
+
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
|
159 |
+
tokenizer (`CLIPTokenizer`):
|
160 |
+
Tokenizer of class
|
161 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
|
162 |
+
tokenizer_2 (`T5TokenizerFast`):
|
163 |
+
Second Tokenizer of class
|
164 |
+
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
|
165 |
+
"""
|
166 |
+
|
167 |
+
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
|
168 |
+
_optional_components = []
|
169 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
170 |
+
|
171 |
+
def __init__(
|
172 |
+
self,
|
173 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
174 |
+
vae: AutoencoderKL,
|
175 |
+
text_encoder: CLIPTextModel,
|
176 |
+
tokenizer: CLIPTokenizer,
|
177 |
+
transformer: FluxTransformer2DModel,
|
178 |
+
text_encoder_2: T5EncoderModel | None = None,
|
179 |
+
tokenizer_2: T5TokenizerFast | None = None,
|
180 |
+
):
|
181 |
+
super().__init__()
|
182 |
+
|
183 |
+
self.register_modules(
|
184 |
+
vae=vae,
|
185 |
+
text_encoder=text_encoder,
|
186 |
+
#text_encoder_2=text_encoder_2,
|
187 |
+
tokenizer=tokenizer,
|
188 |
+
#tokenizer_2=tokenizer_2,
|
189 |
+
transformer=transformer,
|
190 |
+
scheduler=scheduler,
|
191 |
+
)
|
192 |
+
self.vae_scale_factor = (
|
193 |
+
2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
|
194 |
+
)
|
195 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
196 |
+
self.tokenizer_max_length = (
|
197 |
+
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
|
198 |
+
)
|
199 |
+
self.default_sample_size = 64
|
200 |
+
|
201 |
+
def _get_t5_prompt_embeds(
|
202 |
+
self,
|
203 |
+
prompt: Union[str, List[str]] = None,
|
204 |
+
num_images_per_prompt: int = 1,
|
205 |
+
max_sequence_length: int = 512,
|
206 |
+
device: Optional[torch.device] = None,
|
207 |
+
dtype: Optional[torch.dtype] = None,
|
208 |
+
):
|
209 |
+
device = device or self._execution_device
|
210 |
+
dtype = dtype or self.text_encoder.dtype
|
211 |
+
|
212 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
213 |
+
batch_size = len(prompt)
|
214 |
+
|
215 |
+
text_inputs = self.tokenizer_2(
|
216 |
+
prompt,
|
217 |
+
padding="max_length",
|
218 |
+
max_length=max_sequence_length,
|
219 |
+
truncation=True,
|
220 |
+
return_length=False,
|
221 |
+
return_overflowing_tokens=False,
|
222 |
+
return_tensors="pt",
|
223 |
+
)
|
224 |
+
text_input_ids = text_inputs.input_ids
|
225 |
+
untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
|
226 |
+
|
227 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
228 |
+
removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
229 |
+
logger.warning(
|
230 |
+
"The following part of your input was truncated because `max_sequence_length` is set to "
|
231 |
+
f" {max_sequence_length} tokens: {removed_text}"
|
232 |
+
)
|
233 |
+
|
234 |
+
prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
|
235 |
+
|
236 |
+
dtype = self.text_encoder_2.dtype
|
237 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
238 |
+
|
239 |
+
_, seq_len, _ = prompt_embeds.shape
|
240 |
+
|
241 |
+
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
242 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
243 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
244 |
+
|
245 |
+
return prompt_embeds
|
246 |
+
|
247 |
+
def _get_clip_prompt_embeds(
|
248 |
+
self,
|
249 |
+
prompt: Union[str, List[str]],
|
250 |
+
num_images_per_prompt: int = 1,
|
251 |
+
device: Optional[torch.device] = None,
|
252 |
+
):
|
253 |
+
device = device or self._execution_device
|
254 |
+
|
255 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
256 |
+
batch_size = len(prompt)
|
257 |
+
|
258 |
+
text_inputs = self.tokenizer(
|
259 |
+
prompt,
|
260 |
+
padding="max_length",
|
261 |
+
max_length=self.tokenizer_max_length,
|
262 |
+
truncation=True,
|
263 |
+
return_overflowing_tokens=False,
|
264 |
+
return_length=False,
|
265 |
+
return_tensors="pt",
|
266 |
+
)
|
267 |
+
|
268 |
+
text_input_ids = text_inputs.input_ids
|
269 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
270 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
271 |
+
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
272 |
+
logger.warning(
|
273 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
274 |
+
f" {self.tokenizer_max_length} tokens: {removed_text}"
|
275 |
+
)
|
276 |
+
prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
|
277 |
+
|
278 |
+
# Use pooled output of CLIPTextModel
|
279 |
+
prompt_embeds = prompt_embeds.pooler_output
|
280 |
+
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
281 |
+
|
282 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
283 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
284 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
285 |
+
|
286 |
+
return prompt_embeds
|
287 |
+
|
288 |
+
def encode_prompt(
|
289 |
+
self,
|
290 |
+
prompt: Union[str, List[str]],
|
291 |
+
prompt_2: Union[str, List[str]],
|
292 |
+
device: Optional[torch.device] = None,
|
293 |
+
num_images_per_prompt: int = 1,
|
294 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
295 |
+
t5_prompt_embeds: Optional[torch.FloatTensor] = None,
|
296 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
297 |
+
max_sequence_length: int = 512,
|
298 |
+
lora_scale: Optional[float] = None,
|
299 |
+
):
|
300 |
+
r"""
|
301 |
+
|
302 |
+
Args:
|
303 |
+
prompt (`str` or `List[str]`, *optional*):
|
304 |
+
prompt to be encoded
|
305 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
306 |
+
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
307 |
+
used in all text-encoders
|
308 |
+
device: (`torch.device`):
|
309 |
+
torch device
|
310 |
+
num_images_per_prompt (`int`):
|
311 |
+
number of images that should be generated per prompt
|
312 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
313 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
314 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
315 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
316 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
317 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
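t5_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated additional T5 embeddings. They are not re-encoded here; only the length of the
returned `text_ids` is extended to cover them (the tensor itself is consumed by the transformer).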
318 |
+
lora_scale (`float`, *optional*):
|
319 |
+
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
320 |
+
"""
|
321 |
+
device = device or self._execution_device
|
322 |
+
|
323 |
+
# set lora scale so that monkey patched LoRA
|
324 |
+
# function of text encoder can correctly access it
|
325 |
+
if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
|
326 |
+
self._lora_scale = lora_scale
|
327 |
+
|
328 |
+
# dynamically adjust the LoRA scale
|
329 |
+
if self.text_encoder is not None and USE_PEFT_BACKEND:
|
330 |
+
scale_lora_layers(self.text_encoder, lora_scale)
|
331 |
+
if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
|
332 |
+
scale_lora_layers(self.text_encoder_2, lora_scale)
|
333 |
+
|
334 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
335 |
+
if prompt is not None:
|
336 |
+
batch_size = len(prompt)
|
337 |
+
else:
|
338 |
+
batch_size = prompt_embeds.shape[0]
|
339 |
+
|
340 |
+
if prompt_embeds is None:
|
341 |
+
prompt_2 = prompt_2 or prompt
|
342 |
+
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
|
343 |
+
|
344 |
+
# We only use the pooled prompt output from the CLIPTextModel
|
345 |
+
pooled_prompt_embeds = self._get_clip_prompt_embeds(
|
346 |
+
prompt=prompt,
|
347 |
+
device=device,
|
348 |
+
num_images_per_prompt=num_images_per_prompt,
|
349 |
+
)
|
350 |
+
prompt_embeds = self._get_t5_prompt_embeds(
|
351 |
+
prompt=prompt_2,
|
352 |
+
num_images_per_prompt=num_images_per_prompt,
|
353 |
+
max_sequence_length=max_sequence_length,
|
354 |
+
device=device,
|
355 |
+
)
|
356 |
+
|
357 |
+
if self.text_encoder is not None:
|
358 |
+
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
|
359 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
360 |
+
unscale_lora_layers(self.text_encoder, lora_scale)
|
361 |
+
|
362 |
+
#if self.text_encoder_2 is not None:
|
363 |
+
# if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
|
364 |
+
# # Retrieve the original scale by scaling back the LoRA layers
|
365 |
+
# unscale_lora_layers(self.text_encoder_2, lora_scale)
|
366 |
+
|
367 |
+
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
|
368 |
+
if t5_prompt_embeds is not None:
|
369 |
+
text_ids = torch.zeros(batch_size, prompt_embeds.shape[1] + t5_prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
370 |
+
else:
|
371 |
+
text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
372 |
+
|
373 |
+
text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
|
374 |
+
|
375 |
+
return prompt_embeds, pooled_prompt_embeds, text_ids
|
376 |
+
|
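The behavioural difference from the standard pipeline's `encode_prompt` is only in `text_ids`: when external `t5_prompt_embeds` are supplied (elsewhere in this repo they are presumably produced on the Qwen2-VL side, see `model.py`), the positional-id grid has to span both conditioning streams. A shape-only sketch with illustrative sizes:

```py
import torch

batch_size, num_images_per_prompt = 1, 1
prompt_embeds = torch.zeros(batch_size, 64, 4096)       # whatever is passed as encoder_hidden_states
t5_prompt_embeds = torch.zeros(batch_size, 512, 4096)   # optional extra stream (illustrative shapes)

text_ids = torch.zeros(batch_size, prompt_embeds.shape[1] + t5_prompt_embeds.shape[1], 3)
text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
print(text_ids.shape)  # torch.Size([1, 576, 3]) -> one id row per conditioning token
```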
377 |
+
def check_inputs(
|
378 |
+
self,
|
379 |
+
prompt,
|
380 |
+
prompt_2,
|
381 |
+
height,
|
382 |
+
width,
|
383 |
+
prompt_embeds=None,
|
384 |
+
pooled_prompt_embeds=None,
|
385 |
+
callback_on_step_end_tensor_inputs=None,
|
386 |
+
max_sequence_length=None,
|
387 |
+
):
|
388 |
+
if height % 8 != 0 or width % 8 != 0:
|
389 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
390 |
+
|
391 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
392 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
393 |
+
):
|
394 |
+
raise ValueError(
|
395 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
396 |
+
)
|
397 |
+
|
398 |
+
if prompt is not None and prompt_embeds is not None:
|
399 |
+
raise ValueError(
|
400 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
401 |
+
" only forward one of the two."
|
402 |
+
)
|
403 |
+
elif prompt_2 is not None and prompt_embeds is not None:
|
404 |
+
raise ValueError(
|
405 |
+
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
406 |
+
" only forward one of the two."
|
407 |
+
)
|
408 |
+
elif prompt is None and prompt_embeds is None:
|
409 |
+
raise ValueError(
|
410 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
411 |
+
)
|
412 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
413 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
414 |
+
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
|
415 |
+
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
|
416 |
+
|
417 |
+
if prompt_embeds is not None and pooled_prompt_embeds is None:
|
418 |
+
raise ValueError(
|
419 |
+
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
420 |
+
)
|
421 |
+
|
422 |
+
if max_sequence_length is not None and max_sequence_length > 512:
|
423 |
+
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
|
424 |
+
|
425 |
+
@staticmethod
|
426 |
+
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
427 |
+
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
|
428 |
+
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
|
429 |
+
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
|
430 |
+
|
431 |
+
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
432 |
+
|
433 |
+
latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
|
434 |
+
latent_image_ids = latent_image_ids.reshape(
|
435 |
+
batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
436 |
+
)
|
437 |
+
|
438 |
+
return latent_image_ids.to(device=device, dtype=dtype)
|
439 |
+
|
440 |
+
@staticmethod
|
441 |
+
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
442 |
+
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
443 |
+
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
444 |
+
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
|
445 |
+
|
446 |
+
return latents
|
447 |
+
|
448 |
+
@staticmethod
|
449 |
+
def _unpack_latents(latents, height, width, vae_scale_factor):
|
450 |
+
batch_size, num_patches, channels = latents.shape
|
451 |
+
|
452 |
+
height = height // vae_scale_factor
|
453 |
+
width = width // vae_scale_factor
|
454 |
+
|
455 |
+
latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
|
456 |
+
latents = latents.permute(0, 3, 1, 4, 2, 5)
|
457 |
+
|
458 |
+
latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
|
459 |
+
|
460 |
+
return latents
|
461 |
+
|
462 |
+
def prepare_latents(
|
463 |
+
self,
|
464 |
+
batch_size,
|
465 |
+
num_channels_latents,
|
466 |
+
height,
|
467 |
+
width,
|
468 |
+
dtype,
|
469 |
+
device,
|
470 |
+
generator,
|
471 |
+
latents=None,
|
472 |
+
):
|
473 |
+
height = 2 * (int(height) // self.vae_scale_factor)
|
474 |
+
width = 2 * (int(width) // self.vae_scale_factor)
|
475 |
+
|
476 |
+
shape = (batch_size, num_channels_latents, height, width)
|
477 |
+
|
478 |
+
if latents is not None:
|
479 |
+
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
|
480 |
+
return latents.to(device=device, dtype=dtype), latent_image_ids
|
481 |
+
|
482 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
483 |
+
raise ValueError(
|
484 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
485 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
486 |
+
)
|
487 |
+
|
488 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
489 |
+
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
|
490 |
+
|
491 |
+
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
|
492 |
+
|
493 |
+
return latents, latent_image_ids
|
494 |
+
|
495 |
+
@property
|
496 |
+
def guidance_scale(self):
|
497 |
+
return self._guidance_scale
|
498 |
+
|
499 |
+
@property
|
500 |
+
def joint_attention_kwargs(self):
|
501 |
+
return self._joint_attention_kwargs
|
502 |
+
|
503 |
+
@property
|
504 |
+
def num_timesteps(self):
|
505 |
+
return self._num_timesteps
|
506 |
+
|
507 |
+
@property
|
508 |
+
def interrupt(self):
|
509 |
+
return self._interrupt
|
510 |
+
|
511 |
+
#@torch.inference_mode()
|
512 |
+
@torch.no_grad()
|
513 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
514 |
+
def __call__(
|
515 |
+
self,
|
516 |
+
prompt: Union[str, List[str]] = None,
|
517 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
518 |
+
height: Optional[int] = None,
|
519 |
+
width: Optional[int] = None,
|
520 |
+
num_inference_steps: int = 28,
|
521 |
+
timesteps: List[int] = None,
|
522 |
+
guidance_scale: float = 7.0,
|
523 |
+
num_images_per_prompt: Optional[int] = 1,
|
524 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
525 |
+
latents: Optional[torch.FloatTensor] = None,
|
526 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
527 |
+
t5_prompt_embeds: Optional[torch.FloatTensor] = None,
|
528 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
529 |
+
output_type: Optional[str] = "pil",
|
530 |
+
return_dict: bool = True,
|
531 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
532 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
533 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
534 |
+
max_sequence_length: int = 512,
|
535 |
+
):
|
536 |
+
r"""
|
537 |
+
Function invoked when calling the pipeline for generation.
|
538 |
+
|
539 |
+
Args:
|
540 |
+
prompt (`str` or `List[str]`, *optional*):
|
541 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
|
542 |
+
instead.
|
543 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
544 |
+
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
|
545 |
+
will be used instead
|
546 |
+
height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
|
547 |
+
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
548 |
+
width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
|
549 |
+
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
550 |
+
num_inference_steps (`int`, *optional*, defaults to 28):
|
551 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
552 |
+
expense of slower inference.
|
553 |
+
timesteps (`List[int]`, *optional*):
|
554 |
+
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
555 |
+
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
556 |
+
passed will be used. Must be in descending order.
|
557 |
+
guidance_scale (`float`, *optional*, defaults to 7.0):
|
558 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
559 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
560 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
561 |
+
1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
|
562 |
+
usually at the expense of lower image quality.
|
563 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
564 |
+
The number of images to generate per prompt.
|
565 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
566 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
567 |
+
to make generation deterministic.
|
568 |
+
latents (`torch.FloatTensor`, *optional*):
|
569 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
570 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
571 |
+
tensor will be generated by sampling using the supplied random `generator`.
|
572 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
573 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
574 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
575 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
576 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
577 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
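t5_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated additional T5 embeddings, forwarded to the transformer as `t5_encoder_hidden_states`
alongside `prompt_embeds`.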
578 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
579 |
+
The output format of the generated image. Choose between
|
580 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
581 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
582 |
+
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
|
583 |
+
joint_attention_kwargs (`dict`, *optional*):
|
584 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
585 |
+
`self.processor` in
|
586 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
587 |
+
callback_on_step_end (`Callable`, *optional*):
|
588 |
+
A callback function executed at the end of each denoising step during inference. The function is called
|
589 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
590 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
591 |
+
`callback_on_step_end_tensor_inputs`.
|
592 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
593 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
594 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
595 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
596 |
+
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
|
597 |
+
|
598 |
+
Examples:
|
599 |
+
|
600 |
+
Returns:
|
601 |
+
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
|
602 |
+
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
|
603 |
+
images.
|
604 |
+
"""
|
605 |
+
|
606 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
607 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
608 |
+
|
609 |
+
# 1. Check inputs. Raise error if not correct
|
610 |
+
self.check_inputs(
|
611 |
+
prompt,
|
612 |
+
prompt_2,
|
613 |
+
height,
|
614 |
+
width,
|
615 |
+
prompt_embeds=prompt_embeds,
|
616 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
617 |
+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
618 |
+
max_sequence_length=max_sequence_length,
|
619 |
+
)
|
620 |
+
|
621 |
+
self._guidance_scale = guidance_scale
|
622 |
+
self._joint_attention_kwargs = joint_attention_kwargs
|
623 |
+
self._interrupt = False
|
624 |
+
|
625 |
+
# 2. Define call parameters
|
626 |
+
if prompt is not None and isinstance(prompt, str):
|
627 |
+
batch_size = 1
|
628 |
+
elif prompt is not None and isinstance(prompt, list):
|
629 |
+
batch_size = len(prompt)
|
630 |
+
else:
|
631 |
+
batch_size = prompt_embeds.shape[0]
|
632 |
+
|
633 |
+
device = self._execution_device
|
634 |
+
|
635 |
+
lora_scale = (
|
636 |
+
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
|
637 |
+
)
|
638 |
+
(
|
639 |
+
prompt_embeds,
|
640 |
+
pooled_prompt_embeds,
|
641 |
+
text_ids,
|
642 |
+
) = self.encode_prompt(
|
643 |
+
prompt=prompt,
|
644 |
+
prompt_2=prompt_2,
|
645 |
+
prompt_embeds=prompt_embeds,
|
646 |
+
t5_prompt_embeds=t5_prompt_embeds,
|
647 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
648 |
+
device=device,
|
649 |
+
num_images_per_prompt=num_images_per_prompt,
|
650 |
+
max_sequence_length=max_sequence_length,
|
651 |
+
lora_scale=lora_scale,
|
652 |
+
)
|
653 |
+
|
654 |
+
# 4. Prepare latent variables
|
655 |
+
num_channels_latents = self.transformer.config.in_channels // 4
|
656 |
+
latents, latent_image_ids = self.prepare_latents(
|
657 |
+
batch_size * num_images_per_prompt,
|
658 |
+
num_channels_latents,
|
659 |
+
height,
|
660 |
+
width,
|
661 |
+
prompt_embeds.dtype,
|
662 |
+
device,
|
663 |
+
generator,
|
664 |
+
latents,
|
665 |
+
)
|
666 |
+
|
667 |
+
# 5. Prepare timesteps
|
668 |
+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
669 |
+
image_seq_len = latents.shape[1]
|
670 |
+
mu = calculate_shift(
|
671 |
+
image_seq_len,
|
672 |
+
self.scheduler.config.base_image_seq_len,
|
673 |
+
self.scheduler.config.max_image_seq_len,
|
674 |
+
self.scheduler.config.base_shift,
|
675 |
+
self.scheduler.config.max_shift,
|
676 |
+
)
|
677 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
678 |
+
self.scheduler,
|
679 |
+
num_inference_steps,
|
680 |
+
device,
|
681 |
+
timesteps,
|
682 |
+
sigmas,
|
683 |
+
mu=mu,
|
684 |
+
)
|
685 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
686 |
+
self._num_timesteps = len(timesteps)
|
687 |
+
|
688 |
+
# 6. Denoising loop
|
689 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
690 |
+
for i, t in enumerate(timesteps):
|
691 |
+
if self.interrupt:
|
692 |
+
continue
|
693 |
+
|
694 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
695 |
+
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
696 |
+
|
697 |
+
# handle guidance
|
698 |
+
if self.transformer.config.guidance_embeds:
|
699 |
+
guidance = torch.tensor([guidance_scale], device=device)
|
700 |
+
guidance = guidance.expand(latents.shape[0])
|
701 |
+
else:
|
702 |
+
guidance = None
|
703 |
+
|
704 |
+
noise_pred = self.transformer(
|
705 |
+
hidden_states=latents,
|
706 |
+
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs the same for the model for testing)
|
707 |
+
timestep=timestep / 1000,
|
708 |
+
guidance=guidance,
|
709 |
+
pooled_projections=pooled_prompt_embeds,
|
710 |
+
encoder_hidden_states=prompt_embeds,
|
711 |
+
t5_encoder_hidden_states=t5_prompt_embeds,
|
712 |
+
txt_ids=text_ids,
|
713 |
+
img_ids=latent_image_ids,
|
714 |
+
joint_attention_kwargs=self.joint_attention_kwargs,
|
715 |
+
return_dict=False,
|
716 |
+
)[0]
|
717 |
+
|
718 |
+
# compute the previous noisy sample x_t -> x_t-1
|
719 |
+
latents_dtype = latents.dtype
|
720 |
+
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
721 |
+
|
722 |
+
if latents.dtype != latents_dtype:
|
723 |
+
if torch.backends.mps.is_available():
|
724 |
+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
725 |
+
latents = latents.to(latents_dtype)
|
726 |
+
|
727 |
+
if callback_on_step_end is not None:
|
728 |
+
callback_kwargs = {}
|
729 |
+
for k in callback_on_step_end_tensor_inputs:
|
730 |
+
callback_kwargs[k] = locals()[k]
|
731 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
732 |
+
|
733 |
+
latents = callback_outputs.pop("latents", latents)
|
734 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
735 |
+
|
736 |
+
# call the callback, if provided
|
737 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
738 |
+
progress_bar.update()
|
739 |
+
|
740 |
+
if XLA_AVAILABLE:
|
741 |
+
xm.mark_step()
|
742 |
+
|
743 |
+
if output_type == "latent":
|
744 |
+
image = latents
|
745 |
+
|
746 |
+
else:
|
747 |
+
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
748 |
+
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
749 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
750 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
751 |
+
|
752 |
+
# Offload all models
|
753 |
+
self.maybe_free_model_hooks()
|
754 |
+
|
755 |
+
if not return_dict:
|
756 |
+
return (image,)
|
757 |
+
|
758 |
+
return FluxPipelineOutput(images=image)
|
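(Illustration, not part of the uploaded files.) Step 5 above shifts the flow-matching schedule by a value `mu` that `calculate_shift` interpolates linearly from the packed latent sequence length. The short sketch below replays that computation with the defaults these pipelines define for `calculate_shift` (256/4096 tokens, 0.5/1.16 shift); the 1024x1024 resolution and the step count of 28 are assumptions chosen only to make the numbers easy to check.

```py
# Hedged numeric sketch of the timestep-shift computation used in "5. Prepare timesteps".
import numpy as np


def calculate_shift(image_seq_len, base_seq_len=256, max_seq_len=4096, base_shift=0.5, max_shift=1.16):
    # Same linear interpolation as the pipeline's calculate_shift helper.
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return image_seq_len * m + b


num_inference_steps = 28  # assumed, matches the docstring examples
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)

# For a 1024x1024 image with vae_scale_factor 16, latents are packed into
# (1024 // 16 // 2) ** 2 * 4 channels per token, giving (64 * 64) = 4096 tokens,
# so image_seq_len = 4096 and mu lands at the top of the interpolation range.
print(calculate_shift(256))   # 0.5  -> base_shift for the smallest sequence
print(calculate_shift(4096))  # 1.16 -> max_shift for a 1024x1024 generation
```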
flux/pipeline_flux_controlnet.py
ADDED
@@ -0,0 +1,945 @@
# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from transformers import (
    CLIPTextModel,
    CLIPTokenizer,
    T5EncoderModel,
    T5TokenizerFast,
)

from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from .lora.lora_pipeline import FluxLoraLoaderMixin
from diffusers.loaders import FromSingleFileMixin
from diffusers.models.autoencoders import AutoencoderKL
from .controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
from .transformer_flux import FluxTransformer2DModel
from .scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from diffusers.utils import (
    USE_PEFT_BACKEND,
    is_torch_xla_available,
    logging,
    replace_example_docstring,
    scale_lora_layers,
    unscale_lora_layers,
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from .pipeline_output import FluxPipelineOutput


if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers.utils import load_image
        >>> from diffusers import FluxControlNetPipeline
        >>> from diffusers import FluxControlNetModel

        >>> controlnet_model = "InstantX/FLUX.1-dev-controlnet-canny"
        >>> controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
        >>> pipe = FluxControlNetPipeline.from_pretrained(
        ...     base_model, controlnet=controlnet, torch_dtype=torch.bfloat16
        ... )
        >>> pipe.to("cuda")
        >>> control_image = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg")
        >>> prompt = "A girl in city, 25 years old, cool, futuristic"
        >>> image = pipe(
        ...     prompt,
        ...     control_image=control_image,
        ...     controlnet_conditioning_scale=0.6,
        ...     num_inference_steps=28,
        ...     guidance_scale=3.5,
        ... ).images[0]
        >>> image.save("flux.png")
        ```
"""


# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.16,
):
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    mu = image_seq_len * m + b
    return mu


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps


class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
    r"""
    The Flux pipeline for text-to-image generation.

    Reference: https://blackforestlabs.ai/announcing-black-forest-labs/

    Args:
        transformer ([`FluxTransformer2DModel`]):
            Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        text_encoder_2 ([`T5EncoderModel`]):
            [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
            the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
        tokenizer_2 (`T5TokenizerFast`):
            Second Tokenizer of class
            [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
    """

    model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
    _optional_components = []
    _callback_tensor_inputs = ["latents", "prompt_embeds"]

    def __init__(
        self,
        scheduler: FlowMatchEulerDiscreteScheduler,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        transformer: FluxTransformer2DModel,
        controlnet: Union[
            FluxControlNetModel, List[FluxControlNetModel], Tuple[FluxControlNetModel], FluxMultiControlNetModel
        ],
        text_encoder_2: T5EncoderModel | None = None,
        tokenizer_2: T5TokenizerFast | None = None,
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            #text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            #tokenizer_2=tokenizer_2,
            transformer=transformer,
            scheduler=scheduler,
            controlnet=controlnet,
        )
        self.vae_scale_factor = (
            2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
        )
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.tokenizer_max_length = (
            self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
        )
        self.default_sample_size = 64

    def _get_t5_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        num_images_per_prompt: int = 1,
        max_sequence_length: int = 512,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        text_inputs = self.tokenizer_2(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_length=False,
            return_overflowing_tokens=False,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because `max_sequence_length` is set to "
                f" {max_sequence_length} tokens: {removed_text}"
            )

        prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]

        dtype = self.text_encoder_2.dtype
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        _, seq_len, _ = prompt_embeds.shape

        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        return prompt_embeds

    def _get_clip_prompt_embeds(
        self,
        prompt: Union[str, List[str]],
        num_images_per_prompt: int = 1,
        device: Optional[torch.device] = None,
    ):
        device = device or self._execution_device

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer_max_length,
            truncation=True,
            return_overflowing_tokens=False,
            return_length=False,
            return_tensors="pt",
        )

        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer_max_length} tokens: {removed_text}"
            )
        prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)

        # Use pooled output of CLIPTextModel
        prompt_embeds = prompt_embeds.pooler_output
        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)

        return prompt_embeds

    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        prompt_2: Union[str, List[str]],
        device: Optional[torch.device] = None,
        num_images_per_prompt: int = 1,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        t5_prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        max_sequence_length: int = 512,
        lora_scale: Optional[float] = None,
    ):
        r"""

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
                used in all text-encoders
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                If not provided, pooled text embeddings will be generated from `prompt` input argument.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
            lora_scale (`float`, *optional*):
                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        """
        device = device or self._execution_device

        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if self.text_encoder is not None and USE_PEFT_BACKEND:
                scale_lora_layers(self.text_encoder, lora_scale)
            if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
                scale_lora_layers(self.text_encoder_2, lora_scale)

        prompt = [prompt] if isinstance(prompt, str) else prompt

        if prompt_embeds is None:
            prompt_2 = prompt_2 or prompt
            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

            # We only use the pooled prompt output from the CLIPTextModel
            pooled_prompt_embeds = self._get_clip_prompt_embeds(
                prompt=prompt,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
            )
            prompt_embeds = self._get_t5_prompt_embeds(
                prompt=prompt_2,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
            )

        if self.text_encoder is not None:
            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

        #if self.text_encoder_2 is not None:
        #    if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
        #        # Retrieve the original scale by scaling back the LoRA layers
        #        unscale_lora_layers(self.text_encoder_2, lora_scale)

        dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
        if t5_prompt_embeds is not None:
            text_ids = torch.zeros(prompt_embeds.shape[1] + t5_prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
        else:
            text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)

        return prompt_embeds, pooled_prompt_embeds, text_ids

    def check_inputs(
        self,
        prompt,
        prompt_2,
        height,
        width,
        prompt_embeds=None,
        pooled_prompt_embeds=None,
        callback_on_step_end_tensor_inputs=None,
        max_sequence_length=None,
    ):
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt_2 is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
            raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")

        if prompt_embeds is not None and pooled_prompt_embeds is None:
            raise ValueError(
                "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
            )

        if max_sequence_length is not None and max_sequence_length > 512:
            raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")

    @staticmethod
    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
    def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
        latent_image_ids = torch.zeros(height // 2, width // 2, 3)
        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]

        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

        latent_image_ids = latent_image_ids.reshape(
            latent_image_id_height * latent_image_id_width, latent_image_id_channels
        )

        return latent_image_ids.to(device=device, dtype=dtype)

    @staticmethod
    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
    def _pack_latents(latents, batch_size, num_channels_latents, height, width):
        latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
        latents = latents.permute(0, 2, 4, 1, 3, 5)
        latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)

        return latents

    @staticmethod
    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
    def _unpack_latents(latents, height, width, vae_scale_factor):
        batch_size, num_patches, channels = latents.shape

        height = height // vae_scale_factor
        width = width // vae_scale_factor

        latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
        latents = latents.permute(0, 3, 1, 4, 2, 5)

        latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)

        return latents

    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents
    def prepare_latents(
        self,
        batch_size,
        num_channels_latents,
        height,
        width,
        dtype,
        device,
        generator,
        latents=None,
    ):
        height = 2 * (int(height) // self.vae_scale_factor)
        width = 2 * (int(width) // self.vae_scale_factor)

        shape = (batch_size, num_channels_latents, height, width)

        if latents is not None:
            latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
            return latents.to(device=device, dtype=dtype), latent_image_ids

        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)

        latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)

        return latents, latent_image_ids

    # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image
    def prepare_image(
        self,
        image,
        width,
        height,
        batch_size,
        num_images_per_prompt,
        device,
        dtype,
        do_classifier_free_guidance=False,
        guess_mode=False,
    ):
        if isinstance(image, torch.Tensor):
            pass
        else:
            image = self.image_processor.preprocess(image, height=height, width=width)

        image_batch_size = image.shape[0]

        if image_batch_size == 1:
            repeat_by = batch_size
        else:
            # image batch size is the same as prompt batch size
            repeat_by = num_images_per_prompt

        image = image.repeat_interleave(repeat_by, dim=0)

        image = image.to(device=device, dtype=dtype)

        if do_classifier_free_guidance and not guess_mode:
            image = torch.cat([image] * 2)

        return image

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def joint_attention_kwargs(self):
        return self._joint_attention_kwargs

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def interrupt(self):
        return self._interrupt

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 28,
        timesteps: List[int] = None,
        guidance_scale: float = 7.0,
        control_image: PipelineImageInput = None,
        control_mode: Optional[Union[int, List[int]]] = None,
        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        t5_prompt_embeds: Optional[torch.FloatTensor] = None,
        prompt_embeds_control: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 512,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` will
                be used instead.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image. This is set to 1024 by default for the best results.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image. This is set to 1024 by default for the best results.
            num_inference_steps (`int`, *optional*, defaults to 28):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            guidance_scale (`float`, *optional*, defaults to 7.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
                `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
                The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
                specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
                as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
                width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
                images must be passed as a list such that each element of the list can be correctly batched for input
                to a single ControlNet.
            controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
                The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
                to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
                the corresponding scale as a list.
            control_mode (`int` or `List[int]`, *optional*, defaults to None):
                The control mode when applying ControlNet-Union.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                If not provided, pooled text embeddings will be generated from `prompt` input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            callback_on_step_end (`Callable`, *optional*):
                A function that calls at the end of each denoising steps during the inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.

        Examples:

        Returns:
            [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
            is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
            images.
        """

        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            prompt_2,
            height,
            width,
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            max_sequence_length=max_sequence_length,
        )

        self._guidance_scale = guidance_scale
        self._joint_attention_kwargs = joint_attention_kwargs
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        dtype = self.transformer.dtype

        lora_scale = (
            self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
        )
        (
            prompt_embeds,
            pooled_prompt_embeds,
            text_ids,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            prompt_embeds=prompt_embeds,
            t5_prompt_embeds=t5_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            lora_scale=lora_scale,
        )

        # 3. Prepare control image
        num_channels_latents = self.transformer.config.in_channels // 4
        if isinstance(self.controlnet, FluxControlNetModel):
            control_image = self.prepare_image(
                image=control_image,
                width=width,
                height=height,
                batch_size=batch_size * num_images_per_prompt,
                num_images_per_prompt=num_images_per_prompt,
                device=device,
                dtype=self.vae.dtype,
            )
            height, width = control_image.shape[-2:]

            # vae encode
            control_image = self.vae.encode(control_image).latent_dist.sample()
            control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor

            # pack
            height_control_image, width_control_image = control_image.shape[2:]
            control_image = self._pack_latents(
                control_image,
                batch_size * num_images_per_prompt,
                num_channels_latents,
                height_control_image,
                width_control_image,
            )

            # Here we ensure that `control_mode` has the same length as the control_image.
            if control_mode is not None:
                if not isinstance(control_mode, int):
                    raise ValueError(" For `FluxControlNet`, `control_mode` should be an `int` or `None`")
                control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
                control_mode = control_mode.view(-1, 1).expand(control_image.shape[0], 1)

        elif isinstance(self.controlnet, FluxMultiControlNetModel):
            control_images = []

            for control_image_ in control_image:
                control_image_ = self.prepare_image(
                    image=control_image_,
                    width=width,
                    height=height,
                    batch_size=batch_size * num_images_per_prompt,
                    num_images_per_prompt=num_images_per_prompt,
                    device=device,
                    dtype=self.vae.dtype,
                )
                height, width = control_image_.shape[-2:]

                # vae encode
                control_image_ = self.vae.encode(control_image_).latent_dist.sample()
                control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor

                # pack
                height_control_image, width_control_image = control_image_.shape[2:]
                control_image_ = self._pack_latents(
                    control_image_,
                    batch_size * num_images_per_prompt,
                    num_channels_latents,
                    height_control_image,
                    width_control_image,
                )

                control_images.append(control_image_)

            control_image = control_images

            # Here we ensure that `control_mode` has the same length as the control_image.
            if isinstance(control_mode, list) and len(control_mode) != len(control_image):
                raise ValueError(
                    "For Multi-ControlNet, `control_mode` must be a list of the same "
                    + " length as the number of controlnets (control images) specified"
                )
            if not isinstance(control_mode, list):
                control_mode = [control_mode] * len(control_image)
            # set control mode
            control_modes = []
            for cmode in control_mode:
                if cmode is None:
                    cmode = -1
                control_mode = torch.tensor(cmode).expand(control_images[0].shape[0]).to(device, dtype=torch.long)
                control_modes.append(control_mode)
            control_mode = control_modes

        # 4. Prepare latent variables
        num_channels_latents = self.transformer.config.in_channels // 4
        latents, latent_image_ids = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 5. Prepare timesteps
        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
        image_seq_len = latents.shape[1]
        mu = calculate_shift(
            image_seq_len,
            self.scheduler.config.base_image_seq_len,
            self.scheduler.config.max_image_seq_len,
            self.scheduler.config.base_shift,
            self.scheduler.config.max_shift,
        )
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            timesteps,
            sigmas,
            mu=mu,
        )

        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)

        # 6. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latents.shape[0]).to(latents.dtype)

                if isinstance(self.controlnet, FluxMultiControlNetModel):
                    use_guidance = self.controlnet.nets[0].config.guidance_embeds
                else:
                    use_guidance = self.controlnet.config.guidance_embeds

                guidance = torch.tensor([guidance_scale], device=device) if use_guidance else None
                guidance = guidance.expand(latents.shape[0]) if guidance is not None else None

                # controlnet
                controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
                    hidden_states=latents,
                    controlnet_cond=control_image,
                    controlnet_mode=control_mode,
                    conditioning_scale=controlnet_conditioning_scale,
                    timestep=timestep / 1000,
                    guidance=guidance,
                    pooled_projections=pooled_prompt_embeds,
                    encoder_hidden_states=prompt_embeds_control,
                    t5_encoder_hidden_states=t5_prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )

                guidance = (
                    torch.tensor([guidance_scale], device=device) if self.transformer.config.guidance_embeds else None
                )
                guidance = guidance.expand(latents.shape[0]) if guidance is not None else None

                noise_pred = self.transformer(
                    hidden_states=latents,
                    timestep=timestep / 1000,
                    guidance=guidance,
                    pooled_projections=pooled_prompt_embeds,
                    encoder_hidden_states=prompt_embeds,
                    t5_encoder_hidden_states=t5_prompt_embeds,
                    controlnet_block_samples=controlnet_block_samples,
                    controlnet_single_block_samples=controlnet_single_block_samples,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )[0]

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        if output_type == "latent":
            image = latents

        else:
            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor

            image = self.vae.decode(latents, return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return FluxPipelineOutput(images=image)
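(Illustration, not part of the uploaded files.) The EXAMPLE_DOC_STRING above leaves `base_model` undefined, so here is a minimal hedged sketch of what calling this repo's FluxControlNetPipeline might look like. The `black-forest-labs/FLUX.1-dev` checkpoint name and the `flux.*` import paths are assumptions made for illustration; the control image URL, prompt, and sampler settings are taken directly from the docstring example.

```py
# Hedged sketch: assumes this repo's `flux` package is importable and that the
# named Hub checkpoints are the intended ones; adjust to your setup.
import torch
from diffusers.utils import load_image

from flux.controlnet_flux import FluxControlNetModel
from flux.pipeline_flux_controlnet import FluxControlNetPipeline

base_model = "black-forest-labs/FLUX.1-dev"                  # assumed base checkpoint
controlnet_model = "InstantX/FLUX.1-dev-controlnet-canny"    # from the docstring example

controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
pipe = FluxControlNetPipeline.from_pretrained(base_model, controlnet=controlnet, torch_dtype=torch.bfloat16)
pipe.to("cuda")

control_image = load_image(
    "https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg"
)
image = pipe(
    "A girl in city, 25 years old, cool, futuristic",
    control_image=control_image,
    controlnet_conditioning_scale=0.6,
    num_inference_steps=28,
    guidance_scale=3.5,
).images[0]
image.save("flux.png")
```

Whether `from_pretrained` resolves every component for this modified pipeline (the T5 encoder and its tokenizer are optional and their registration is commented out in `__init__`) depends on the checkpoint's model index, so treat this as a starting point rather than a verified recipe.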
flux/pipeline_flux_controlnet_img2img.py
ADDED
@@ -0,0 +1,1002 @@
1 |
+
# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import inspect
|
16 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
17 |
+
|
18 |
+
import numpy as np
|
19 |
+
import torch
|
20 |
+
from transformers import (
|
21 |
+
CLIPTextModel,
|
22 |
+
CLIPTokenizer,
|
23 |
+
T5EncoderModel,
|
24 |
+
T5TokenizerFast,
|
25 |
+
)
|
26 |
+
|
27 |
+
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
28 |
+
from .lora.lora_pipeline import FluxLoraLoaderMixin
|
29 |
+
from diffusers.loaders import FromSingleFileMixin
|
30 |
+
from diffusers.models.autoencoders import AutoencoderKL
|
31 |
+
from .controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
|
32 |
+
from .transformer_flux import FluxTransformer2DModel
|
33 |
+
from .scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
|
34 |
+
from diffusers.utils import (
|
35 |
+
USE_PEFT_BACKEND,
|
36 |
+
is_torch_xla_available,
|
37 |
+
logging,
|
38 |
+
replace_example_docstring,
|
39 |
+
scale_lora_layers,
|
40 |
+
unscale_lora_layers,
|
41 |
+
)
|
42 |
+
from diffusers.utils.torch_utils import randn_tensor
|
43 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
44 |
+
from .pipeline_output import FluxPipelineOutput
|
45 |
+
|
46 |
+
|
47 |
+
if is_torch_xla_available():
|
48 |
+
import torch_xla.core.xla_model as xm
|
49 |
+
|
50 |
+
XLA_AVAILABLE = True
|
51 |
+
else:
|
52 |
+
XLA_AVAILABLE = False
|
53 |
+
|
54 |
+
|
55 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
56 |
+
|
57 |
+
EXAMPLE_DOC_STRING = """
|
58 |
+
Examples:
|
59 |
+
```py
|
60 |
+
>>> import torch
|
61 |
+
>>> from diffusers.utils import load_image
|
62 |
+
>>> from diffusers import FluxControlNetPipeline
|
63 |
+
>>> from diffusers import FluxControlNetModel
|
64 |
+
|
65 |
+
>>> controlnet_model = "InstantX/FLUX.1-dev-controlnet-canny"
|
66 |
+
>>> controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
|
67 |
+
>>> pipe = FluxControlNetPipeline.from_pretrained(
|
68 |
+
... base_model, controlnet=controlnet, torch_dtype=torch.bfloat16
|
69 |
+
... )
|
70 |
+
>>> pipe.to("cuda")
|
71 |
+
>>> control_image = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg")
|
72 |
+
>>> prompt = "A girl in city, 25 years old, cool, futuristic"
|
73 |
+
>>> image = pipe(
|
74 |
+
... prompt,
|
75 |
+
... control_image=control_image,
|
76 |
+
... controlnet_conditioning_scale=0.6,
|
77 |
+
... num_inference_steps=28,
|
78 |
+
... guidance_scale=3.5,
|
79 |
+
... ).images[0]
|
80 |
+
>>> image.save("flux.png")
|
81 |
+
```
|
82 |
+
"""
|
83 |
+
|
84 |
+
|
85 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
|
86 |
+
def calculate_shift(
|
87 |
+
image_seq_len,
|
88 |
+
base_seq_len: int = 256,
|
89 |
+
max_seq_len: int = 4096,
|
90 |
+
base_shift: float = 0.5,
|
91 |
+
max_shift: float = 1.16,
|
92 |
+
):
|
93 |
+
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
94 |
+
b = base_shift - m * base_seq_len
|
95 |
+
mu = image_seq_len * m + b
|
96 |
+
return mu
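# Worked example with the default arguments: the shift grows linearly with the packed
# sequence length, so image_seq_len = 256 gives mu = 0.5 (base_shift) and
# image_seq_len = 4096 gives mu = 1.16 (max_shift); a 1024x1024 image packs into
# (1024 // 16) * (1024 // 16) = 4096 tokens and therefore uses the maximum shift.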
|
97 |
+
|
98 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
99 |
+
def retrieve_latents(
|
100 |
+
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
101 |
+
):
|
102 |
+
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
103 |
+
return encoder_output.latent_dist.sample(generator)
|
104 |
+
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
105 |
+
return encoder_output.latent_dist.mode()
|
106 |
+
elif hasattr(encoder_output, "latents"):
|
107 |
+
return encoder_output.latents
|
108 |
+
else:
|
109 |
+
raise AttributeError("Could not access latents of provided encoder_output")
|
110 |
+
|
111 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
112 |
+
def retrieve_timesteps(
|
113 |
+
scheduler,
|
114 |
+
num_inference_steps: Optional[int] = None,
|
115 |
+
device: Optional[Union[str, torch.device]] = None,
|
116 |
+
timesteps: Optional[List[int]] = None,
|
117 |
+
sigmas: Optional[List[float]] = None,
|
118 |
+
**kwargs,
|
119 |
+
):
|
120 |
+
"""
|
121 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
122 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
123 |
+
|
124 |
+
Args:
|
125 |
+
scheduler (`SchedulerMixin`):
|
126 |
+
The scheduler to get timesteps from.
|
127 |
+
num_inference_steps (`int`):
|
128 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
129 |
+
must be `None`.
|
130 |
+
device (`str` or `torch.device`, *optional*):
|
131 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
132 |
+
timesteps (`List[int]`, *optional*):
|
133 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
134 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
135 |
+
sigmas (`List[float]`, *optional*):
|
136 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
137 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
138 |
+
|
139 |
+
Returns:
|
140 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
141 |
+
second element is the number of inference steps.
|
142 |
+
"""
|
143 |
+
if timesteps is not None and sigmas is not None:
|
144 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
145 |
+
if timesteps is not None:
|
146 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
147 |
+
if not accepts_timesteps:
|
148 |
+
raise ValueError(
|
149 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
150 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
151 |
+
)
|
152 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
153 |
+
timesteps = scheduler.timesteps
|
154 |
+
num_inference_steps = len(timesteps)
|
155 |
+
elif sigmas is not None:
|
156 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
157 |
+
if not accept_sigmas:
|
158 |
+
raise ValueError(
|
159 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
160 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
161 |
+
)
|
162 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
163 |
+
timesteps = scheduler.timesteps
|
164 |
+
num_inference_steps = len(timesteps)
|
165 |
+
else:
|
166 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
167 |
+
timesteps = scheduler.timesteps
|
168 |
+
return timesteps, num_inference_steps
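# In this pipeline, retrieve_timesteps is driven by custom sigmas plus the flow-matching
# shift computed by calculate_shift, e.g.:
#   sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
#   timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps, device, None, sigmas, mu=mu)
# Passing both `timesteps` and `sigmas` at once raises a ValueError, as checked above.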
|
169 |
+
|
170 |
+
|
171 |
+
class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
|
172 |
+
r"""
|
173 |
+
The Flux ControlNet pipeline for image-to-image generation.
|
174 |
+
|
175 |
+
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
|
176 |
+
|
177 |
+
Args:
|
178 |
+
transformer ([`FluxTransformer2DModel`]):
|
179 |
+
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
180 |
+
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
181 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
182 |
+
vae ([`AutoencoderKL`]):
|
183 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
184 |
+
text_encoder ([`CLIPTextModel`]):
|
185 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
186 |
+
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
187 |
+
text_encoder_2 ([`T5EncoderModel`]):
|
188 |
+
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
189 |
+
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
|
190 |
+
tokenizer (`CLIPTokenizer`):
|
191 |
+
Tokenizer of class
|
192 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
|
193 |
+
tokenizer_2 (`T5TokenizerFast`):
|
194 |
+
Second Tokenizer of class
|
195 |
+
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
|
196 |
+
"""
|
197 |
+
|
198 |
+
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
|
199 |
+
_optional_components = []
|
200 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
201 |
+
|
202 |
+
def __init__(
|
203 |
+
self,
|
204 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
205 |
+
vae: AutoencoderKL,
|
206 |
+
text_encoder: CLIPTextModel,
|
207 |
+
tokenizer: CLIPTokenizer,
|
208 |
+
transformer: FluxTransformer2DModel,
|
209 |
+
controlnet: Union[
|
210 |
+
FluxControlNetModel, List[FluxControlNetModel], Tuple[FluxControlNetModel], FluxMultiControlNetModel
|
211 |
+
],
|
212 |
+
text_encoder_2: T5EncoderModel | None = None,
|
213 |
+
tokenizer_2: T5TokenizerFast | None = None,
|
214 |
+
):
|
215 |
+
super().__init__()
|
216 |
+
|
217 |
+
self.register_modules(
|
218 |
+
vae=vae,
|
219 |
+
text_encoder=text_encoder,
|
220 |
+
#text_encoder_2=text_encoder_2,
|
221 |
+
tokenizer=tokenizer,
|
222 |
+
#tokenizer_2=tokenizer_2,
|
223 |
+
transformer=transformer,
|
224 |
+
scheduler=scheduler,
|
225 |
+
controlnet=controlnet,
|
226 |
+
)
|
227 |
+
self.vae_scale_factor = (
|
228 |
+
2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
|
229 |
+
)
|
230 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
231 |
+
self.tokenizer_max_length = (
|
232 |
+
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
|
233 |
+
)
|
234 |
+
self.default_sample_size = 64
|
235 |
+
|
236 |
+
def _get_t5_prompt_embeds(
|
237 |
+
self,
|
238 |
+
prompt: Union[str, List[str]] = None,
|
239 |
+
num_images_per_prompt: int = 1,
|
240 |
+
max_sequence_length: int = 512,
|
241 |
+
device: Optional[torch.device] = None,
|
242 |
+
dtype: Optional[torch.dtype] = None,
|
243 |
+
):
|
244 |
+
device = device or self._execution_device
|
245 |
+
dtype = dtype or self.text_encoder.dtype
|
246 |
+
|
247 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
248 |
+
batch_size = len(prompt)
|
249 |
+
|
250 |
+
text_inputs = self.tokenizer_2(
|
251 |
+
prompt,
|
252 |
+
padding="max_length",
|
253 |
+
max_length=max_sequence_length,
|
254 |
+
truncation=True,
|
255 |
+
return_length=False,
|
256 |
+
return_overflowing_tokens=False,
|
257 |
+
return_tensors="pt",
|
258 |
+
)
|
259 |
+
text_input_ids = text_inputs.input_ids
|
260 |
+
untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
|
261 |
+
|
262 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
263 |
+
removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
264 |
+
logger.warning(
|
265 |
+
"The following part of your input was truncated because `max_sequence_length` is set to "
|
266 |
+
f" {max_sequence_length} tokens: {removed_text}"
|
267 |
+
)
|
268 |
+
|
269 |
+
prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
|
270 |
+
|
271 |
+
dtype = self.text_encoder_2.dtype
|
272 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
273 |
+
|
274 |
+
_, seq_len, _ = prompt_embeds.shape
|
275 |
+
|
276 |
+
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
277 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
278 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
279 |
+
|
280 |
+
return prompt_embeds
|
281 |
+
|
282 |
+
def _get_clip_prompt_embeds(
|
283 |
+
self,
|
284 |
+
prompt: Union[str, List[str]],
|
285 |
+
num_images_per_prompt: int = 1,
|
286 |
+
device: Optional[torch.device] = None,
|
287 |
+
):
|
288 |
+
device = device or self._execution_device
|
289 |
+
|
290 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
291 |
+
batch_size = len(prompt)
|
292 |
+
|
293 |
+
text_inputs = self.tokenizer(
|
294 |
+
prompt,
|
295 |
+
padding="max_length",
|
296 |
+
max_length=self.tokenizer_max_length,
|
297 |
+
truncation=True,
|
298 |
+
return_overflowing_tokens=False,
|
299 |
+
return_length=False,
|
300 |
+
return_tensors="pt",
|
301 |
+
)
|
302 |
+
|
303 |
+
text_input_ids = text_inputs.input_ids
|
304 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
305 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
306 |
+
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
307 |
+
logger.warning(
|
308 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
309 |
+
f" {self.tokenizer_max_length} tokens: {removed_text}"
|
310 |
+
)
|
311 |
+
prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
|
312 |
+
|
313 |
+
# Use pooled output of CLIPTextModel
|
314 |
+
prompt_embeds = prompt_embeds.pooler_output
|
315 |
+
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
316 |
+
|
317 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
318 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
|
319 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
320 |
+
|
321 |
+
return prompt_embeds
|
322 |
+
|
323 |
+
def encode_prompt(
|
324 |
+
self,
|
325 |
+
prompt: Union[str, List[str]],
|
326 |
+
prompt_2: Union[str, List[str]],
|
327 |
+
device: Optional[torch.device] = None,
|
328 |
+
num_images_per_prompt: int = 1,
|
329 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
330 |
+
t5_prompt_embeds: Optional[torch.FloatTensor] = None,
|
331 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
332 |
+
max_sequence_length: int = 512,
|
333 |
+
lora_scale: Optional[float] = None,
|
334 |
+
):
|
335 |
+
r"""
|
336 |
+
|
337 |
+
Args:
|
338 |
+
prompt (`str` or `List[str]`, *optional*):
|
339 |
+
prompt to be encoded
|
340 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
341 |
+
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
342 |
+
used in all text-encoders
|
343 |
+
device: (`torch.device`):
|
344 |
+
torch device
|
345 |
+
num_images_per_prompt (`int`):
|
346 |
+
number of images that should be generated per prompt
|
347 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
348 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
349 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
350 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
351 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
352 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
353 |
+
clip_skip (`int`, *optional*):
|
354 |
+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
355 |
+
the output of the pre-final layer will be used for computing the prompt embeddings.
|
356 |
+
lora_scale (`float`, *optional*):
|
357 |
+
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
358 |
+
"""
|
359 |
+
device = device or self._execution_device
|
360 |
+
|
361 |
+
# set lora scale so that monkey patched LoRA
|
362 |
+
# function of text encoder can correctly access it
|
363 |
+
if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
|
364 |
+
self._lora_scale = lora_scale
|
365 |
+
|
366 |
+
# dynamically adjust the LoRA scale
|
367 |
+
if self.text_encoder is not None and USE_PEFT_BACKEND:
|
368 |
+
scale_lora_layers(self.text_encoder, lora_scale)
|
369 |
+
if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
|
370 |
+
scale_lora_layers(self.text_encoder_2, lora_scale)
|
371 |
+
|
372 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
373 |
+
|
374 |
+
if prompt_embeds is None:
|
375 |
+
prompt_2 = prompt_2 or prompt
|
376 |
+
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
|
377 |
+
|
378 |
+
# We only use the pooled prompt output from the CLIPTextModel
|
379 |
+
pooled_prompt_embeds = self._get_clip_prompt_embeds(
|
380 |
+
prompt=prompt,
|
381 |
+
device=device,
|
382 |
+
num_images_per_prompt=num_images_per_prompt,
|
383 |
+
)
|
384 |
+
prompt_embeds = self._get_t5_prompt_embeds(
|
385 |
+
prompt=prompt_2,
|
386 |
+
num_images_per_prompt=num_images_per_prompt,
|
387 |
+
max_sequence_length=max_sequence_length,
|
388 |
+
device=device,
|
389 |
+
)
|
390 |
+
|
391 |
+
if self.text_encoder is not None:
|
392 |
+
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
|
393 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
394 |
+
unscale_lora_layers(self.text_encoder, lora_scale)
|
395 |
+
|
396 |
+
#if self.text_encoder_2 is not None:
|
397 |
+
# if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
|
398 |
+
# # Retrieve the original scale by scaling back the LoRA layers
|
399 |
+
# unscale_lora_layers(self.text_encoder_2, lora_scale)
|
400 |
+
|
401 |
+
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
|
402 |
+
if t5_prompt_embeds is not None:
|
403 |
+
text_ids = torch.zeros(prompt_embeds.shape[1] + t5_prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
404 |
+
else:
|
405 |
+
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
406 |
+
|
407 |
+
return prompt_embeds, pooled_prompt_embeds, text_ids
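        # Note: `text_ids` are all-zero positional ids of shape (text_seq_len, 3); when extra
        # `t5_prompt_embeds` are supplied, the id tensor is lengthened to cover both text streams
        # so the rotary position embeddings span the concatenated text tokens.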
|
408 |
+
|
409 |
+
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
|
410 |
+
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
|
411 |
+
if isinstance(generator, list):
|
412 |
+
image_latents = [
|
413 |
+
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
|
414 |
+
for i in range(image.shape[0])
|
415 |
+
]
|
416 |
+
image_latents = torch.cat(image_latents, dim=0)
|
417 |
+
else:
|
418 |
+
image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
|
419 |
+
|
420 |
+
image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
421 |
+
|
422 |
+
return image_latents
|
423 |
+
|
424 |
+
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
|
425 |
+
def get_timesteps(self, num_inference_steps, strength, device):
|
426 |
+
# get the original timestep using init_timestep
|
427 |
+
init_timestep = min(num_inference_steps * strength, num_inference_steps)
|
428 |
+
|
429 |
+
t_start = int(max(num_inference_steps - init_timestep, 0))
|
430 |
+
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
431 |
+
if hasattr(self.scheduler, "set_begin_index"):
|
432 |
+
self.scheduler.set_begin_index(t_start * self.scheduler.order)
|
433 |
+
|
434 |
+
return timesteps, num_inference_steps - t_start
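        # With the defaults used by __call__ (num_inference_steps=28, strength=0.6) this gives
        # init_timestep = 16.8 and t_start = 11, so only the final 17 of the 28 timesteps are
        # actually run for the img2img denoising.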
|
435 |
+
|
436 |
+
def check_inputs(
|
437 |
+
self,
|
438 |
+
prompt,
|
439 |
+
prompt_2,
|
440 |
+
height,
|
441 |
+
width,
|
442 |
+
prompt_embeds=None,
|
443 |
+
pooled_prompt_embeds=None,
|
444 |
+
callback_on_step_end_tensor_inputs=None,
|
445 |
+
max_sequence_length=None,
|
446 |
+
):
|
447 |
+
if height % 8 != 0 or width % 8 != 0:
|
448 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
449 |
+
|
450 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
451 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
452 |
+
):
|
453 |
+
raise ValueError(
|
454 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
455 |
+
)
|
456 |
+
|
457 |
+
if prompt is not None and prompt_embeds is not None:
|
458 |
+
raise ValueError(
|
459 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
460 |
+
" only forward one of the two."
|
461 |
+
)
|
462 |
+
elif prompt_2 is not None and prompt_embeds is not None:
|
463 |
+
raise ValueError(
|
464 |
+
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
465 |
+
" only forward one of the two."
|
466 |
+
)
|
467 |
+
elif prompt is None and prompt_embeds is None:
|
468 |
+
raise ValueError(
|
469 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
470 |
+
)
|
471 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
472 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
473 |
+
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
|
474 |
+
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
|
475 |
+
|
476 |
+
if prompt_embeds is not None and pooled_prompt_embeds is None:
|
477 |
+
raise ValueError(
|
478 |
+
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
479 |
+
)
|
480 |
+
|
481 |
+
if max_sequence_length is not None and max_sequence_length > 512:
|
482 |
+
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
|
483 |
+
|
484 |
+
@staticmethod
|
485 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
|
486 |
+
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
487 |
+
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
|
488 |
+
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
|
489 |
+
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
|
490 |
+
|
491 |
+
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
492 |
+
|
493 |
+
latent_image_ids = latent_image_ids.reshape(
|
494 |
+
latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
495 |
+
)
|
496 |
+
|
497 |
+
return latent_image_ids.to(device=device, dtype=dtype)
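        # The result has shape ((height // 2) * (width // 2), 3): channel 0 stays zero, while
        # channels 1 and 2 hold the row and column index of each packed 2x2 latent patch. These
        # ids are passed as `img_ids` to the transformer and the ControlNet for rotary embeddings.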
|
498 |
+
|
499 |
+
@staticmethod
|
500 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
|
501 |
+
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
502 |
+
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
503 |
+
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
504 |
+
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
|
505 |
+
|
506 |
+
return latents
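        # Shape bookkeeping: (batch, channels, height, width) VAE latents become a
        # (batch, (height // 2) * (width // 2), channels * 4) token sequence, i.e. every
        # 2x2 spatial patch is folded into the channel dimension.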
|
507 |
+
|
508 |
+
@staticmethod
|
509 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
|
510 |
+
def _unpack_latents(latents, height, width, vae_scale_factor):
|
511 |
+
batch_size, num_patches, channels = latents.shape
|
512 |
+
|
513 |
+
height = height // vae_scale_factor
|
514 |
+
width = width // vae_scale_factor
|
515 |
+
|
516 |
+
latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
|
517 |
+
latents = latents.permute(0, 3, 1, 4, 2, 5)
|
518 |
+
|
519 |
+
latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
|
520 |
+
|
521 |
+
return latents
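        # Inverse of _pack_latents: the packed (batch, num_patches, channels) sequence is mapped
        # back to (batch, channels // 4, 2 * (height // vae_scale_factor), 2 * (width // vae_scale_factor))
        # latents ready for VAE decoding; `height` and `width` here are in output-image pixels.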
|
522 |
+
|
523 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents
|
524 |
+
def prepare_latents(
|
525 |
+
self,
|
526 |
+
image,
|
527 |
+
timestep,
|
528 |
+
batch_size,
|
529 |
+
num_channels_latents,
|
530 |
+
height,
|
531 |
+
width,
|
532 |
+
dtype,
|
533 |
+
device,
|
534 |
+
generator,
|
535 |
+
latents=None,
|
536 |
+
):
|
537 |
+
height = 2 * (int(height) // self.vae_scale_factor)
|
538 |
+
width = 2 * (int(width) // self.vae_scale_factor)
|
539 |
+
|
540 |
+
shape = (batch_size, num_channels_latents, height, width)
|
541 |
+
|
542 |
+
if latents is not None:
|
543 |
+
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
|
544 |
+
return latents.to(device=device, dtype=dtype), latent_image_ids
|
545 |
+
|
546 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
547 |
+
raise ValueError(
|
548 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
549 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
550 |
+
)
|
551 |
+
image = image.to(device=device, dtype=dtype)
|
552 |
+
image_latents = self._encode_vae_image(image=image, generator=generator)
|
553 |
+
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
554 |
+
if timestep == 28:
|
555 |
+
latents = self.scheduler.scale_noise(image_latents, timestep, noise)
|
556 |
+
else:
|
557 |
+
latents = noise
|
558 |
+
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
|
559 |
+
|
560 |
+
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
|
561 |
+
|
562 |
+
return latents, latent_image_ids
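        # Unlike the stock img2img pipeline, this variant only mixes the encoded image with noise
        # (via scheduler.scale_noise) when the first timestep equals 28; for any other starting
        # timestep the latents are initialized from pure noise.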
|
563 |
+
|
564 |
+
# Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image
|
565 |
+
def prepare_image(
|
566 |
+
self,
|
567 |
+
image,
|
568 |
+
width,
|
569 |
+
height,
|
570 |
+
batch_size,
|
571 |
+
num_images_per_prompt,
|
572 |
+
device,
|
573 |
+
dtype,
|
574 |
+
do_classifier_free_guidance=False,
|
575 |
+
guess_mode=False,
|
576 |
+
):
|
577 |
+
if isinstance(image, torch.Tensor):
|
578 |
+
pass
|
579 |
+
else:
|
580 |
+
image = self.image_processor.preprocess(image, height=height, width=width)
|
581 |
+
|
582 |
+
image_batch_size = image.shape[0]
|
583 |
+
|
584 |
+
if image_batch_size == 1:
|
585 |
+
repeat_by = batch_size
|
586 |
+
else:
|
587 |
+
# image batch size is the same as prompt batch size
|
588 |
+
repeat_by = num_images_per_prompt
|
589 |
+
|
590 |
+
image = image.repeat_interleave(repeat_by, dim=0)
|
591 |
+
|
592 |
+
image = image.to(device=device, dtype=dtype)
|
593 |
+
|
594 |
+
if do_classifier_free_guidance and not guess_mode:
|
595 |
+
image = torch.cat([image] * 2)
|
596 |
+
|
597 |
+
return image
|
598 |
+
|
599 |
+
@property
|
600 |
+
def guidance_scale(self):
|
601 |
+
return self._guidance_scale
|
602 |
+
|
603 |
+
@property
|
604 |
+
def joint_attention_kwargs(self):
|
605 |
+
return self._joint_attention_kwargs
|
606 |
+
|
607 |
+
@property
|
608 |
+
def num_timesteps(self):
|
609 |
+
return self._num_timesteps
|
610 |
+
|
611 |
+
@property
|
612 |
+
def interrupt(self):
|
613 |
+
return self._interrupt
|
614 |
+
|
615 |
+
@torch.no_grad()
|
616 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
617 |
+
def __call__(
|
618 |
+
self,
|
619 |
+
prompt: Union[str, List[str]] = None,
|
620 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
621 |
+
image: PipelineImageInput = None,
|
622 |
+
height: Optional[int] = None,
|
623 |
+
width: Optional[int] = None,
|
624 |
+
strength: float = 0.6,
|
625 |
+
num_inference_steps: int = 28,
|
626 |
+
timesteps: List[int] = None,
|
627 |
+
guidance_scale: float = 7.0,
|
628 |
+
control_image: PipelineImageInput = None,
|
629 |
+
control_mode: Optional[Union[int, List[int]]] = None,
|
630 |
+
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
|
631 |
+
num_images_per_prompt: Optional[int] = 1,
|
632 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
633 |
+
latents: Optional[torch.FloatTensor] = None,
|
634 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
635 |
+
t5_prompt_embeds: Optional[torch.FloatTensor] = None,
|
636 |
+
prompt_embeds_control: Optional[torch.FloatTensor] = None,
|
637 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
638 |
+
output_type: Optional[str] = "pil",
|
639 |
+
return_dict: bool = True,
|
640 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
641 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
642 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
643 |
+
max_sequence_length: int = 512,
|
644 |
+
):
|
645 |
+
r"""
|
646 |
+
Function invoked when calling the pipeline for generation.
|
647 |
+
|
648 |
+
Args:
|
649 |
+
prompt (`str` or `List[str]`, *optional*):
|
650 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
|
651 |
+
instead.
|
652 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
653 |
+
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
|
654 |
+
will be used instead.
|
655 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
656 |
+
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
657 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
658 |
+
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
659 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
660 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
661 |
+
expense of slower inference.
|
662 |
+
timesteps (`List[int]`, *optional*):
|
663 |
+
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
664 |
+
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
665 |
+
passed will be used. Must be in descending order.
|
666 |
+
guidance_scale (`float`, *optional*, defaults to 7.0):
|
667 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
668 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
669 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
670 |
+
1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
|
671 |
+
usually at the expense of lower image quality.
|
672 |
+
control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
|
673 |
+
`List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
|
674 |
+
The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
|
675 |
+
specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
|
676 |
+
as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
|
677 |
+
width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
|
678 |
+
images must be passed as a list such that each element of the list can be correctly batched for input
|
679 |
+
to a single ControlNet.
|
680 |
+
controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
|
681 |
+
The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
|
682 |
+
to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
|
683 |
+
the corresponding scale as a list.
|
684 |
+
control_mode (`int` or `List[int]`, *optional*, defaults to None):
|
685 |
+
The control mode when applying ControlNet-Union.
|
686 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
687 |
+
The number of images to generate per prompt.
|
688 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
689 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
690 |
+
to make generation deterministic.
|
691 |
+
latents (`torch.FloatTensor`, *optional*):
|
692 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
693 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
694 |
+
tensor will be generated by sampling using the supplied random `generator`.
|
695 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
696 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
697 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
698 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
699 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
700 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
701 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
702 |
+
The output format of the generated image. Choose between
|
703 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
704 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
705 |
+
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
|
706 |
+
joint_attention_kwargs (`dict`, *optional*):
|
707 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
708 |
+
`self.processor` in
|
709 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
710 |
+
callback_on_step_end (`Callable`, *optional*):
|
711 |
+
A function that is called at the end of each denoising step during inference. The function is called
|
712 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
713 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
714 |
+
`callback_on_step_end_tensor_inputs`.
|
715 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
716 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
717 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
718 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
719 |
+
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
|
720 |
+
|
721 |
+
Examples:
|
722 |
+
|
723 |
+
Returns:
|
724 |
+
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
|
725 |
+
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
|
726 |
+
images.
|
727 |
+
"""
|
728 |
+
|
729 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
730 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
731 |
+
|
732 |
+
# 1. Check inputs. Raise error if not correct
|
733 |
+
self.check_inputs(
|
734 |
+
prompt,
|
735 |
+
prompt_2,
|
736 |
+
height,
|
737 |
+
width,
|
738 |
+
prompt_embeds=prompt_embeds,
|
739 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
740 |
+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
741 |
+
max_sequence_length=max_sequence_length,
|
742 |
+
)
|
743 |
+
|
744 |
+
self._guidance_scale = guidance_scale
|
745 |
+
self._joint_attention_kwargs = joint_attention_kwargs
|
746 |
+
self._interrupt = False
|
747 |
+
|
748 |
+
# 2. Preprocess image
|
749 |
+
init_image = self.image_processor.preprocess(image, height=height, width=width)
|
750 |
+
init_image = init_image.to(dtype=torch.float32)
|
751 |
+
|
752 |
+
# 3. Define call parameters
|
753 |
+
if prompt is not None and isinstance(prompt, str):
|
754 |
+
batch_size = 1
|
755 |
+
elif prompt is not None and isinstance(prompt, list):
|
756 |
+
batch_size = len(prompt)
|
757 |
+
else:
|
758 |
+
batch_size = prompt_embeds.shape[0]
|
759 |
+
|
760 |
+
device = self._execution_device
|
761 |
+
dtype = self.transformer.dtype
|
762 |
+
|
763 |
+
lora_scale = (
|
764 |
+
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
|
765 |
+
)
|
766 |
+
(
|
767 |
+
prompt_embeds,
|
768 |
+
pooled_prompt_embeds,
|
769 |
+
text_ids,
|
770 |
+
) = self.encode_prompt(
|
771 |
+
prompt=prompt,
|
772 |
+
prompt_2=prompt_2,
|
773 |
+
prompt_embeds=prompt_embeds,
|
774 |
+
t5_prompt_embeds=t5_prompt_embeds,
|
775 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
776 |
+
device=device,
|
777 |
+
num_images_per_prompt=num_images_per_prompt,
|
778 |
+
max_sequence_length=max_sequence_length,
|
779 |
+
lora_scale=lora_scale,
|
780 |
+
)
|
781 |
+
|
782 |
+
# 4. Prepare control image
|
783 |
+
num_channels_latents = self.transformer.config.in_channels // 4
|
784 |
+
if isinstance(self.controlnet, FluxControlNetModel):
|
785 |
+
control_image = self.prepare_image(
|
786 |
+
image=control_image,
|
787 |
+
width=width,
|
788 |
+
height=height,
|
789 |
+
batch_size=batch_size * num_images_per_prompt,
|
790 |
+
num_images_per_prompt=num_images_per_prompt,
|
791 |
+
device=device,
|
792 |
+
dtype=self.vae.dtype,
|
793 |
+
)
|
794 |
+
height, width = control_image.shape[-2:]
|
795 |
+
|
796 |
+
# vae encode
|
797 |
+
control_image = self.vae.encode(control_image).latent_dist.sample()
|
798 |
+
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
799 |
+
|
800 |
+
# pack
|
801 |
+
height_control_image, width_control_image = control_image.shape[2:]
|
802 |
+
control_image = self._pack_latents(
|
803 |
+
control_image,
|
804 |
+
batch_size * num_images_per_prompt,
|
805 |
+
num_channels_latents,
|
806 |
+
height_control_image,
|
807 |
+
width_control_image,
|
808 |
+
)
|
809 |
+
|
810 |
+
# Here we ensure that `control_mode` has the same length as the control_image.
|
811 |
+
if control_mode is not None:
|
812 |
+
if not isinstance(control_mode, int):
|
813 |
+
raise ValueError(" For `FluxControlNet`, `control_mode` should be an `int` or `None`")
|
814 |
+
control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
|
815 |
+
control_mode = control_mode.view(-1, 1).expand(control_image.shape[0], 1)
|
816 |
+
|
817 |
+
elif isinstance(self.controlnet, FluxMultiControlNetModel):
|
818 |
+
control_images = []
|
819 |
+
|
820 |
+
for control_image_ in control_image:
|
821 |
+
control_image_ = self.prepare_image(
|
822 |
+
image=control_image_,
|
823 |
+
width=width,
|
824 |
+
height=height,
|
825 |
+
batch_size=batch_size * num_images_per_prompt,
|
826 |
+
num_images_per_prompt=num_images_per_prompt,
|
827 |
+
device=device,
|
828 |
+
dtype=self.vae.dtype,
|
829 |
+
)
|
830 |
+
height, width = control_image_.shape[-2:]
|
831 |
+
|
832 |
+
# vae encode
|
833 |
+
control_image_ = self.vae.encode(control_image_).latent_dist.sample()
|
834 |
+
control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
835 |
+
|
836 |
+
# pack
|
837 |
+
height_control_image, width_control_image = control_image_.shape[2:]
|
838 |
+
control_image_ = self._pack_latents(
|
839 |
+
control_image_,
|
840 |
+
batch_size * num_images_per_prompt,
|
841 |
+
num_channels_latents,
|
842 |
+
height_control_image,
|
843 |
+
width_control_image,
|
844 |
+
)
|
845 |
+
|
846 |
+
control_images.append(control_image_)
|
847 |
+
|
848 |
+
control_image = control_images
|
849 |
+
|
850 |
+
# Here we ensure that `control_mode` has the same length as the control_image.
|
851 |
+
if isinstance(control_mode, list) and len(control_mode) != len(control_image):
|
852 |
+
raise ValueError(
|
853 |
+
"For Multi-ControlNet, `control_mode` must be a list of the same "
|
854 |
+
+ "length as the number of controlnets (control images) specified"
|
855 |
+
)
|
856 |
+
if not isinstance(control_mode, list):
|
857 |
+
control_mode = [control_mode] * len(control_image)
|
858 |
+
# set control mode
|
859 |
+
control_modes = []
|
860 |
+
for cmode in control_mode:
|
861 |
+
if cmode is None:
|
862 |
+
cmode = -1
|
863 |
+
control_mode = torch.tensor(cmode).expand(control_images[0].shape[0]).to(device, dtype=torch.long)
|
864 |
+
control_modes.append(control_mode)
|
865 |
+
control_mode = control_modes
|
866 |
+
|
867 |
+
# 5. Prepare timesteps
|
868 |
+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
869 |
+
#image_seq_len = latents.shape[1]
|
870 |
+
image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor)
|
871 |
+
mu = calculate_shift(
|
872 |
+
image_seq_len,
|
873 |
+
self.scheduler.config.base_image_seq_len,
|
874 |
+
self.scheduler.config.max_image_seq_len,
|
875 |
+
self.scheduler.config.base_shift,
|
876 |
+
self.scheduler.config.max_shift,
|
877 |
+
)
|
878 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
879 |
+
self.scheduler,
|
880 |
+
num_inference_steps,
|
881 |
+
device,
|
882 |
+
timesteps,
|
883 |
+
sigmas,
|
884 |
+
mu=mu,
|
885 |
+
)
|
886 |
+
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
|
887 |
+
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
888 |
+
|
889 |
+
# 6. Prepare latent variables
|
890 |
+
num_channels_latents = self.transformer.config.in_channels // 4
|
891 |
+
latents, latent_image_ids = self.prepare_latents(
|
892 |
+
init_image,
|
893 |
+
latent_timestep,
|
894 |
+
batch_size * num_images_per_prompt,
|
895 |
+
num_channels_latents,
|
896 |
+
height,
|
897 |
+
width,
|
898 |
+
prompt_embeds.dtype,
|
899 |
+
device,
|
900 |
+
generator,
|
901 |
+
latents,
|
902 |
+
)
|
903 |
+
|
904 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
905 |
+
self._num_timesteps = len(timesteps)
|
906 |
+
|
907 |
+
# 7. Denoising loop
|
908 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
909 |
+
for i, t in enumerate(timesteps):
|
910 |
+
if self.interrupt:
|
911 |
+
continue
|
912 |
+
|
913 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
914 |
+
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
915 |
+
|
916 |
+
if isinstance(self.controlnet, FluxMultiControlNetModel):
|
917 |
+
use_guidance = self.controlnet.nets[0].config.guidance_embeds
|
918 |
+
else:
|
919 |
+
use_guidance = self.controlnet.config.guidance_embeds
|
920 |
+
|
921 |
+
guidance = torch.tensor([guidance_scale], device=device) if use_guidance else None
|
922 |
+
guidance = guidance.expand(latents.shape[0]) if guidance is not None else None
|
923 |
+
|
924 |
+
# controlnet
|
925 |
+
controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
|
926 |
+
hidden_states=latents,
|
927 |
+
controlnet_cond=control_image,
|
928 |
+
controlnet_mode=control_mode,
|
929 |
+
conditioning_scale=controlnet_conditioning_scale,
|
930 |
+
timestep=timestep / 1000,
|
931 |
+
guidance=guidance,
|
932 |
+
pooled_projections=pooled_prompt_embeds,
|
933 |
+
encoder_hidden_states=prompt_embeds_control,
|
934 |
+
t5_encoder_hidden_states=t5_prompt_embeds,
|
935 |
+
txt_ids=text_ids,
|
936 |
+
img_ids=latent_image_ids,
|
937 |
+
joint_attention_kwargs=self.joint_attention_kwargs,
|
938 |
+
return_dict=False,
|
939 |
+
)
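                # The ControlNet returns residual hidden states for the double-stream blocks
                # (`controlnet_block_samples`) and the single-stream blocks
                # (`controlnet_single_block_samples`), scaled by `conditioning_scale`; they are
                # injected into the corresponding transformer blocks in the call below.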
|
940 |
+
|
941 |
+
guidance = (
|
942 |
+
torch.tensor([guidance_scale], device=device) if self.transformer.config.guidance_embeds else None
|
943 |
+
)
|
944 |
+
guidance = guidance.expand(latents.shape[0]) if guidance is not None else None
|
945 |
+
|
946 |
+
noise_pred = self.transformer(
|
947 |
+
hidden_states=latents,
|
948 |
+
timestep=timestep / 1000,
|
949 |
+
guidance=guidance,
|
950 |
+
pooled_projections=pooled_prompt_embeds,
|
951 |
+
encoder_hidden_states=prompt_embeds,
|
952 |
+
t5_encoder_hidden_states=t5_prompt_embeds,
|
953 |
+
controlnet_block_samples=controlnet_block_samples,
|
954 |
+
controlnet_single_block_samples=controlnet_single_block_samples,
|
955 |
+
txt_ids=text_ids,
|
956 |
+
img_ids=latent_image_ids,
|
957 |
+
joint_attention_kwargs=self.joint_attention_kwargs,
|
958 |
+
return_dict=False,
|
959 |
+
)[0]
|
960 |
+
|
961 |
+
# compute the previous noisy sample x_t -> x_t-1
|
962 |
+
latents_dtype = latents.dtype
|
963 |
+
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
964 |
+
|
965 |
+
if latents.dtype != latents_dtype:
|
966 |
+
if torch.backends.mps.is_available():
|
967 |
+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
968 |
+
latents = latents.to(latents_dtype)
|
969 |
+
|
970 |
+
if callback_on_step_end is not None:
|
971 |
+
callback_kwargs = {}
|
972 |
+
for k in callback_on_step_end_tensor_inputs:
|
973 |
+
callback_kwargs[k] = locals()[k]
|
974 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
975 |
+
|
976 |
+
latents = callback_outputs.pop("latents", latents)
|
977 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
978 |
+
|
979 |
+
# call the callback, if provided
|
980 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
981 |
+
progress_bar.update()
|
982 |
+
|
983 |
+
if XLA_AVAILABLE:
|
984 |
+
xm.mark_step()
|
985 |
+
|
986 |
+
if output_type == "latent":
|
987 |
+
image = latents
|
988 |
+
|
989 |
+
else:
|
990 |
+
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
991 |
+
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
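                # The line above undoes the (x - shift_factor) * scaling_factor transform applied
                # when the VAE encoded the image, returning the latents to the VAE's native range
                # before decoding.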
|
992 |
+
|
993 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
994 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
995 |
+
|
996 |
+
# Offload all models
|
997 |
+
self.maybe_free_model_hooks()
|
998 |
+
|
999 |
+
if not return_dict:
|
1000 |
+
return (image,)
|
1001 |
+
|
1002 |
+
return FluxPipelineOutput(images=image)
|
flux/pipeline_flux_controlnet_inpainting.py
ADDED
@@ -0,0 +1,1199 @@
1 |
+
import inspect
|
2 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import PIL
|
6 |
+
import torch
|
7 |
+
from transformers import (
|
8 |
+
CLIPTextModel,
|
9 |
+
CLIPTokenizer,
|
10 |
+
T5EncoderModel,
|
11 |
+
T5TokenizerFast,
|
12 |
+
)
|
13 |
+
|
14 |
+
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
15 |
+
from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
|
16 |
+
from diffusers.models.autoencoders import AutoencoderKL
|
17 |
+
from .controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
|
18 |
+
from .transformer_flux import FluxTransformer2DModel
|
19 |
+
from .scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
|
20 |
+
from diffusers.utils import (
|
21 |
+
USE_PEFT_BACKEND,
|
22 |
+
is_torch_xla_available,
|
23 |
+
logging,
|
24 |
+
replace_example_docstring,
|
25 |
+
scale_lora_layers,
|
26 |
+
unscale_lora_layers,
|
27 |
+
)
|
28 |
+
from diffusers.utils.torch_utils import randn_tensor
|
29 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
30 |
+
from .pipeline_output import FluxPipelineOutput
|
31 |
+
|
32 |
+
|
33 |
+
if is_torch_xla_available():
|
34 |
+
import torch_xla.core.xla_model as xm
|
35 |
+
|
36 |
+
XLA_AVAILABLE = True
|
37 |
+
else:
|
38 |
+
XLA_AVAILABLE = False
|
39 |
+
|
40 |
+
logger = logging.get_logger(__name__)
|
41 |
+
|
42 |
+
EXAMPLE_DOC_STRING = """
|
43 |
+
Examples:
|
44 |
+
```py
|
45 |
+
>>> import torch
|
46 |
+
>>> from diffusers import FluxControlNetInpaintPipeline
|
47 |
+
>>> from diffusers.models import FluxControlNetModel
|
48 |
+
>>> from diffusers.utils import load_image
|
49 |
+
|
50 |
+
>>> controlnet = FluxControlNetModel.from_pretrained(
|
51 |
+
... "InstantX/FLUX.1-dev-controlnet-canny", torch_dtype=torch.float16
|
52 |
+
... )
|
53 |
+
>>> pipe = FluxControlNetInpaintPipeline.from_pretrained(
|
54 |
+
... "black-forest-labs/FLUX.1-schnell", controlnet=controlnet, torch_dtype=torch.float16
|
55 |
+
... )
|
56 |
+
>>> pipe.to("cuda")
|
57 |
+
|
58 |
+
>>> control_image = load_image(
|
59 |
+
... "https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Canny-alpha/resolve/main/canny.jpg"
|
60 |
+
... )
|
61 |
+
>>> init_image = load_image(
|
62 |
+
... "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
|
63 |
+
... )
|
64 |
+
>>> mask_image = load_image(
|
65 |
+
... "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
|
66 |
+
... )
|
67 |
+
|
68 |
+
>>> prompt = "A girl holding a sign that says InstantX"
|
69 |
+
>>> image = pipe(
|
70 |
+
... prompt,
|
71 |
+
... image=init_image,
|
72 |
+
... mask_image=mask_image,
|
73 |
+
... control_image=control_image,
|
74 |
+
... control_guidance_start=0.2,
|
75 |
+
... control_guidance_end=0.8,
|
76 |
+
... controlnet_conditioning_scale=0.7,
|
77 |
+
... strength=0.7,
|
78 |
+
... num_inference_steps=28,
|
79 |
+
... guidance_scale=3.5,
|
80 |
+
... ).images[0]
|
81 |
+
>>> image.save("flux_controlnet_inpaint.png")
|
82 |
+
```
|
83 |
+
"""
|
84 |
+
|
85 |
+
|
86 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
|
87 |
+
def calculate_shift(
|
88 |
+
image_seq_len,
|
89 |
+
base_seq_len: int = 256,
|
90 |
+
max_seq_len: int = 4096,
|
91 |
+
base_shift: float = 0.5,
|
92 |
+
max_shift: float = 1.16,
|
93 |
+
):
|
94 |
+
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
95 |
+
b = base_shift - m * base_seq_len
|
96 |
+
mu = image_seq_len * m + b
|
97 |
+
return mu
|
98 |
+
|
99 |
+
|
100 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
101 |
+
def retrieve_latents(
|
102 |
+
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
103 |
+
):
|
104 |
+
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
105 |
+
return encoder_output.latent_dist.sample(generator)
|
106 |
+
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
107 |
+
return encoder_output.latent_dist.mode()
|
108 |
+
elif hasattr(encoder_output, "latents"):
|
109 |
+
return encoder_output.latents
|
110 |
+
else:
|
111 |
+
raise AttributeError("Could not access latents of provided encoder_output")
|
112 |
+
|
113 |
+
|
114 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
115 |
+
def retrieve_timesteps(
|
116 |
+
scheduler,
|
117 |
+
num_inference_steps: Optional[int] = None,
|
118 |
+
device: Optional[Union[str, torch.device]] = None,
|
119 |
+
timesteps: Optional[List[int]] = None,
|
120 |
+
sigmas: Optional[List[float]] = None,
|
121 |
+
**kwargs,
|
122 |
+
):
|
123 |
+
r"""
|
124 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
125 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
126 |
+
|
127 |
+
Args:
|
128 |
+
scheduler (`SchedulerMixin`):
|
129 |
+
The scheduler to get timesteps from.
|
130 |
+
num_inference_steps (`int`):
|
131 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
132 |
+
must be `None`.
|
133 |
+
device (`str` or `torch.device`, *optional*):
|
134 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
135 |
+
timesteps (`List[int]`, *optional*):
|
136 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
137 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
138 |
+
sigmas (`List[float]`, *optional*):
|
139 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
140 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
141 |
+
|
142 |
+
Returns:
|
143 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
144 |
+
second element is the number of inference steps.
|
145 |
+
"""
|
146 |
+
if timesteps is not None and sigmas is not None:
|
147 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
148 |
+
if timesteps is not None:
|
149 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
150 |
+
if not accepts_timesteps:
|
151 |
+
raise ValueError(
|
152 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
153 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
154 |
+
)
|
155 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
156 |
+
timesteps = scheduler.timesteps
|
157 |
+
num_inference_steps = len(timesteps)
|
158 |
+
elif sigmas is not None:
|
159 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
160 |
+
if not accept_sigmas:
|
161 |
+
raise ValueError(
|
162 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
163 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
164 |
+
)
|
165 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
166 |
+
timesteps = scheduler.timesteps
|
167 |
+
num_inference_steps = len(timesteps)
|
168 |
+
else:
|
169 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
170 |
+
timesteps = scheduler.timesteps
|
171 |
+
return timesteps, num_inference_steps
|
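A rough usage sketch of the custom-`sigmas` path described in the docstring above (not part of this file). It assumes `pipe` is a loaded pipeline whose scheduler is the `FlowMatchEulerDiscreteScheduler` bundled with this repo, whose `set_timesteps` accepts `sigmas` and `mu`:

```py
import numpy as np

num_steps = 28
sigmas = np.linspace(1.0, 1 / num_steps, num_steps)        # same spacing `__call__` uses below
timesteps, num_steps = retrieve_timesteps(
    pipe.scheduler, num_inference_steps=None, device="cuda", sigmas=sigmas, mu=1.15
)
```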
172 |
+
|
173 |
+
|
174 |
+
class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
|
175 |
+
r"""
|
176 |
+
The Flux controlnet pipeline for inpainting.
|
177 |
+
|
178 |
+
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
|
179 |
+
|
180 |
+
Args:
|
181 |
+
transformer ([`FluxTransformer2DModel`]):
|
182 |
+
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
183 |
+
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
184 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
185 |
+
vae ([`AutoencoderKL`]):
|
186 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
187 |
+
text_encoder ([`CLIPTextModel`]):
|
188 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
189 |
+
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
190 |
+
text_encoder_2 ([`T5EncoderModel`]):
|
191 |
+
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
192 |
+
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
|
193 |
+
tokenizer (`CLIPTokenizer`):
|
194 |
+
Tokenizer of class
|
195 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
|
196 |
+
tokenizer_2 (`T5TokenizerFast`):
|
197 |
+
Second Tokenizer of class
|
198 |
+
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
|
199 |
+
"""
|
200 |
+
|
201 |
+
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
|
202 |
+
_optional_components = []
|
203 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
204 |
+
|
205 |
+
def __init__(
|
206 |
+
self,
|
207 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
208 |
+
vae: AutoencoderKL,
|
209 |
+
text_encoder: CLIPTextModel,
|
210 |
+
tokenizer: CLIPTokenizer,
|
211 |
+
transformer: FluxTransformer2DModel,
|
212 |
+
controlnet: Union[
|
213 |
+
FluxControlNetModel, List[FluxControlNetModel], Tuple[FluxControlNetModel], FluxMultiControlNetModel
|
214 |
+
],
|
215 |
+
text_encoder_2: T5EncoderModel | None = None,
|
216 |
+
tokenizer_2: T5TokenizerFast | None = None,
|
217 |
+
):
|
218 |
+
super().__init__()
|
219 |
+
if isinstance(controlnet, (list, tuple)):
|
220 |
+
controlnet = FluxMultiControlNetModel(controlnet)
|
221 |
+
|
222 |
+
self.register_modules(
|
223 |
+
scheduler=scheduler,
|
224 |
+
vae=vae,
|
225 |
+
text_encoder=text_encoder,
|
226 |
+
tokenizer=tokenizer,
|
227 |
+
text_encoder_2=text_encoder_2,
|
228 |
+
tokenizer_2=tokenizer_2,
|
229 |
+
transformer=transformer,
|
230 |
+
controlnet=controlnet,
|
231 |
+
)
|
232 |
+
|
233 |
+
self.vae_scale_factor = (
|
234 |
+
2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
|
235 |
+
)
|
236 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
237 |
+
self.mask_processor = VaeImageProcessor(
|
238 |
+
vae_scale_factor=self.vae_scale_factor,
|
239 |
+
vae_latent_channels=self.vae.config.latent_channels,
|
240 |
+
do_normalize=False,
|
241 |
+
do_binarize=True,
|
242 |
+
do_convert_grayscale=True,
|
243 |
+
)
|
244 |
+
self.tokenizer_max_length = (
|
245 |
+
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
|
246 |
+
)
|
247 |
+
self.default_sample_size = 64
|
248 |
+
|
249 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
|
250 |
+
def _get_t5_prompt_embeds(
|
251 |
+
self,
|
252 |
+
prompt: Union[str, List[str]] = None,
|
253 |
+
num_images_per_prompt: int = 1,
|
254 |
+
max_sequence_length: int = 512,
|
255 |
+
device: Optional[torch.device] = None,
|
256 |
+
dtype: Optional[torch.dtype] = None,
|
257 |
+
):
|
258 |
+
device = device or self._execution_device
|
259 |
+
dtype = dtype or self.text_encoder.dtype
|
260 |
+
|
261 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
262 |
+
batch_size = len(prompt)
|
263 |
+
|
264 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
265 |
+
prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
|
266 |
+
|
267 |
+
text_inputs = self.tokenizer_2(
|
268 |
+
prompt,
|
269 |
+
padding="max_length",
|
270 |
+
max_length=max_sequence_length,
|
271 |
+
truncation=True,
|
272 |
+
return_length=False,
|
273 |
+
return_overflowing_tokens=False,
|
274 |
+
return_tensors="pt",
|
275 |
+
)
|
276 |
+
text_input_ids = text_inputs.input_ids
|
277 |
+
untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
|
278 |
+
|
279 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
280 |
+
removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
281 |
+
logger.warning(
|
282 |
+
"The following part of your input was truncated because `max_sequence_length` is set to "
|
283 |
+
f" {max_sequence_length} tokens: {removed_text}"
|
284 |
+
)
|
285 |
+
|
286 |
+
prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
|
287 |
+
|
288 |
+
dtype = self.text_encoder_2.dtype
|
289 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
290 |
+
|
291 |
+
_, seq_len, _ = prompt_embeds.shape
|
292 |
+
|
293 |
+
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
294 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
295 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
296 |
+
|
297 |
+
return prompt_embeds
|
298 |
+
|
299 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
|
300 |
+
def _get_clip_prompt_embeds(
|
301 |
+
self,
|
302 |
+
prompt: Union[str, List[str]],
|
303 |
+
num_images_per_prompt: int = 1,
|
304 |
+
device: Optional[torch.device] = None,
|
305 |
+
):
|
306 |
+
device = device or self._execution_device
|
307 |
+
|
308 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
309 |
+
batch_size = len(prompt)
|
310 |
+
|
311 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
312 |
+
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
313 |
+
|
314 |
+
text_inputs = self.tokenizer(
|
315 |
+
prompt,
|
316 |
+
padding="max_length",
|
317 |
+
max_length=self.tokenizer_max_length,
|
318 |
+
truncation=True,
|
319 |
+
return_overflowing_tokens=False,
|
320 |
+
return_length=False,
|
321 |
+
return_tensors="pt",
|
322 |
+
)
|
323 |
+
|
324 |
+
text_input_ids = text_inputs.input_ids
|
325 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
326 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
327 |
+
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
328 |
+
logger.warning(
|
329 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
330 |
+
f" {self.tokenizer_max_length} tokens: {removed_text}"
|
331 |
+
)
|
332 |
+
prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
|
333 |
+
|
334 |
+
# Use pooled output of CLIPTextModel
|
335 |
+
prompt_embeds = prompt_embeds.pooler_output
|
336 |
+
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
337 |
+
|
338 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
339 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
|
340 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
341 |
+
|
342 |
+
return prompt_embeds
|
343 |
+
|
344 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
|
345 |
+
def encode_prompt(
|
346 |
+
self,
|
347 |
+
prompt: Union[str, List[str]],
|
348 |
+
prompt_2: Union[str, List[str]],
|
349 |
+
device: Optional[torch.device] = None,
|
350 |
+
num_images_per_prompt: int = 1,
|
351 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
352 |
+
t5_prompt_embeds: Optional[torch.FloatTensor] = None,
|
353 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
354 |
+
max_sequence_length: int = 512,
|
355 |
+
lora_scale: Optional[float] = None,
|
356 |
+
):
|
357 |
+
r"""
|
358 |
+
|
359 |
+
Args:
|
360 |
+
prompt (`str` or `List[str]`, *optional*):
|
361 |
+
prompt to be encoded
|
362 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
363 |
+
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
364 |
+
used in all text-encoders
|
365 |
+
device: (`torch.device`):
|
366 |
+
torch device
|
367 |
+
num_images_per_prompt (`int`):
|
368 |
+
number of images that should be generated per prompt
|
369 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
370 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
371 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
372 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
373 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
374 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
375 |
+
lora_scale (`float`, *optional*):
|
376 |
+
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
377 |
+
"""
|
378 |
+
device = device or self._execution_device
|
379 |
+
|
380 |
+
# set lora scale so that monkey patched LoRA
|
381 |
+
# function of text encoder can correctly access it
|
382 |
+
if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
|
383 |
+
self._lora_scale = lora_scale
|
384 |
+
|
385 |
+
# dynamically adjust the LoRA scale
|
386 |
+
if self.text_encoder is not None and USE_PEFT_BACKEND:
|
387 |
+
scale_lora_layers(self.text_encoder, lora_scale)
|
388 |
+
if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
|
389 |
+
scale_lora_layers(self.text_encoder_2, lora_scale)
|
390 |
+
|
391 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
392 |
+
|
393 |
+
if prompt_embeds is None:
|
394 |
+
prompt_2 = prompt_2 or prompt
|
395 |
+
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
|
396 |
+
|
397 |
+
# We only use the pooled prompt output from the CLIPTextModel
|
398 |
+
pooled_prompt_embeds = self._get_clip_prompt_embeds(
|
399 |
+
prompt=prompt,
|
400 |
+
device=device,
|
401 |
+
num_images_per_prompt=num_images_per_prompt,
|
402 |
+
)
|
403 |
+
prompt_embeds = self._get_t5_prompt_embeds(
|
404 |
+
prompt=prompt_2,
|
405 |
+
num_images_per_prompt=num_images_per_prompt,
|
406 |
+
max_sequence_length=max_sequence_length,
|
407 |
+
device=device,
|
408 |
+
)
|
409 |
+
|
410 |
+
if self.text_encoder is not None:
|
411 |
+
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
|
412 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
413 |
+
unscale_lora_layers(self.text_encoder, lora_scale)
|
414 |
+
|
415 |
+
if self.text_encoder_2 is not None:
|
416 |
+
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
|
417 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
418 |
+
unscale_lora_layers(self.text_encoder_2, lora_scale)
|
419 |
+
|
420 |
+
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
|
421 |
+
if t5_prompt_embeds is not None:
|
422 |
+
text_ids = torch.zeros(prompt_embeds.shape[1] + t5_prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
423 |
+
else:
|
424 |
+
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
425 |
+
|
426 |
+
|
427 |
+
return prompt_embeds, pooled_prompt_embeds, text_ids
|
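For reference, a hedged sketch of pre-computing embeddings with `encode_prompt` outside of `__call__`; `pipe` is assumed to be an already-loaded instance of this pipeline and the prompt is a placeholder:

```py
# `prompt_2=None` makes the T5 branch fall back to `prompt`.
prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt(
    prompt="a photo of a red brick wall",
    prompt_2=None,
    device=pipe._execution_device,
    num_images_per_prompt=1,
    max_sequence_length=512,
)
# These can then be passed back into `__call__` as `prompt_embeds` / `pooled_prompt_embeds`.
```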
428 |
+
|
429 |
+
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
|
430 |
+
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
|
431 |
+
if isinstance(generator, list):
|
432 |
+
image_latents = [
|
433 |
+
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
|
434 |
+
for i in range(image.shape[0])
|
435 |
+
]
|
436 |
+
image_latents = torch.cat(image_latents, dim=0)
|
437 |
+
else:
|
438 |
+
image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
|
439 |
+
|
440 |
+
image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
441 |
+
|
442 |
+
return image_latents
|
443 |
+
|
444 |
+
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
|
445 |
+
def get_timesteps(self, num_inference_steps, strength, device):
|
446 |
+
# get the original timestep using init_timestep
|
447 |
+
init_timestep = min(num_inference_steps * strength, num_inference_steps)
|
448 |
+
|
449 |
+
t_start = int(max(num_inference_steps - init_timestep, 0))
|
450 |
+
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
451 |
+
if hasattr(self.scheduler, "set_begin_index"):
|
452 |
+
self.scheduler.set_begin_index(t_start * self.scheduler.order)
|
453 |
+
|
454 |
+
return timesteps, num_inference_steps - t_start
|
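For a worked check (numbers not in the file): with `num_inference_steps=28` and `strength=0.6`, `init_timestep = min(28 * 0.6, 28) = 16.8` and `t_start = int(max(28 - 16.8, 0)) = 11`, so the scheduler is entered at index 11 and the remaining `28 - 11 = 17` steps are run.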
455 |
+
|
456 |
+
def check_inputs(
|
457 |
+
self,
|
458 |
+
prompt,
|
459 |
+
prompt_2,
|
460 |
+
image,
|
461 |
+
mask_image,
|
462 |
+
strength,
|
463 |
+
height,
|
464 |
+
width,
|
465 |
+
output_type,
|
466 |
+
prompt_embeds=None,
|
467 |
+
pooled_prompt_embeds=None,
|
468 |
+
callback_on_step_end_tensor_inputs=None,
|
469 |
+
padding_mask_crop=None,
|
470 |
+
max_sequence_length=None,
|
471 |
+
):
|
472 |
+
if strength < 0 or strength > 1:
|
473 |
+
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
474 |
+
|
475 |
+
if height % 8 != 0 or width % 8 != 0:
|
476 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
477 |
+
|
478 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
479 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
480 |
+
):
|
481 |
+
raise ValueError(
|
482 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
483 |
+
)
|
484 |
+
|
485 |
+
if prompt is not None and prompt_embeds is not None:
|
486 |
+
raise ValueError(
|
487 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
488 |
+
" only forward one of the two."
|
489 |
+
)
|
490 |
+
elif prompt_2 is not None and prompt_embeds is not None:
|
491 |
+
raise ValueError(
|
492 |
+
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
493 |
+
" only forward one of the two."
|
494 |
+
)
|
495 |
+
elif prompt is None and prompt_embeds is None:
|
496 |
+
raise ValueError(
|
497 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
498 |
+
)
|
499 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
500 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
501 |
+
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
|
502 |
+
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
|
503 |
+
|
504 |
+
if prompt_embeds is not None and pooled_prompt_embeds is None:
|
505 |
+
raise ValueError(
|
506 |
+
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
507 |
+
)
|
508 |
+
|
509 |
+
if padding_mask_crop is not None:
|
510 |
+
if not isinstance(image, PIL.Image.Image):
|
511 |
+
raise ValueError(
|
512 |
+
f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
|
513 |
+
)
|
514 |
+
if not isinstance(mask_image, PIL.Image.Image):
|
515 |
+
raise ValueError(
|
516 |
+
f"The mask image should be a PIL image when inpainting mask crop, but is of type"
|
517 |
+
f" {type(mask_image)}."
|
518 |
+
)
|
519 |
+
if output_type != "pil":
|
520 |
+
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
|
521 |
+
|
522 |
+
if max_sequence_length is not None and max_sequence_length > 512:
|
523 |
+
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
|
524 |
+
|
525 |
+
@staticmethod
|
526 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
|
527 |
+
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
528 |
+
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
|
529 |
+
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
|
530 |
+
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
|
531 |
+
|
532 |
+
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
533 |
+
|
534 |
+
latent_image_ids = latent_image_ids.reshape(
|
535 |
+
latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
536 |
+
)
|
537 |
+
|
538 |
+
return latent_image_ids.to(device=device, dtype=dtype)
|
539 |
+
|
540 |
+
@staticmethod
|
541 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
|
542 |
+
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
543 |
+
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
544 |
+
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
545 |
+
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
|
546 |
+
|
547 |
+
return latents
|
548 |
+
|
549 |
+
@staticmethod
|
550 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
|
551 |
+
def _unpack_latents(latents, height, width, vae_scale_factor):
|
552 |
+
batch_size, num_patches, channels = latents.shape
|
553 |
+
|
554 |
+
height = height // vae_scale_factor
|
555 |
+
width = width // vae_scale_factor
|
556 |
+
|
557 |
+
latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
|
558 |
+
latents = latents.permute(0, 3, 1, 4, 2, 5)
|
559 |
+
|
560 |
+
latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
|
561 |
+
|
562 |
+
return latents
|
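A brief shape sanity-check sketch for the 2×2 packing/unpacking above (sizes illustrative, matching this pipeline's `vae_scale_factor` of 16 for a 1024×1024 generation):

```py
import torch

# prepare_latents builds a 2 * (1024 // 16) = 128-per-side latent grid with 16 channels.
batch, channels, h, w = 1, 16, 128, 128
latents = torch.randn(batch, channels, h, w)

packed = FluxControlNetInpaintPipeline._pack_latents(latents, batch, channels, h, w)
assert packed.shape == (batch, (h // 2) * (w // 2), channels * 4)        # (1, 4096, 64)

# _unpack_latents takes the *pixel* height/width and divides by vae_scale_factor.
unpacked = FluxControlNetInpaintPipeline._unpack_latents(packed, 1024, 1024, vae_scale_factor=16)
assert torch.equal(unpacked, latents)                                    # exact inverse
```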
563 |
+
|
564 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux_inpaint.FluxInpaintPipeline.prepare_latents
|
565 |
+
def prepare_latents(
|
566 |
+
self,
|
567 |
+
image,
|
568 |
+
timestep,
|
569 |
+
batch_size,
|
570 |
+
num_channels_latents,
|
571 |
+
height,
|
572 |
+
width,
|
573 |
+
dtype,
|
574 |
+
device,
|
575 |
+
generator,
|
576 |
+
latents=None,
|
577 |
+
):
|
578 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
579 |
+
raise ValueError(
|
580 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
581 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
582 |
+
)
|
583 |
+
|
584 |
+
height = 2 * (int(height) // self.vae_scale_factor)
|
585 |
+
width = 2 * (int(width) // self.vae_scale_factor)
|
586 |
+
|
587 |
+
shape = (batch_size, num_channels_latents, height, width)
|
588 |
+
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
|
589 |
+
|
590 |
+
image = image.to(device=device, dtype=dtype)
|
591 |
+
image_latents = self._encode_vae_image(image=image, generator=generator)
|
592 |
+
|
593 |
+
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
|
594 |
+
# expand init_latents for batch_size
|
595 |
+
additional_image_per_prompt = batch_size // image_latents.shape[0]
|
596 |
+
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
|
597 |
+
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
|
598 |
+
raise ValueError(
|
599 |
+
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
|
600 |
+
)
|
601 |
+
else:
|
602 |
+
image_latents = torch.cat([image_latents], dim=0)
|
603 |
+
|
604 |
+
if latents is None:
|
605 |
+
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
606 |
+
latents = self.scheduler.scale_noise(image_latents, timestep, noise)
|
607 |
+
else:
|
608 |
+
noise = latents.to(device)
|
609 |
+
latents = noise
|
610 |
+
|
611 |
+
noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width)
|
612 |
+
image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width)
|
613 |
+
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
|
614 |
+
return latents, noise, image_latents, latent_image_ids
|
615 |
+
|
616 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux_inpaint.FluxInpaintPipeline.prepare_mask_latents
|
617 |
+
def prepare_mask_latents(
|
618 |
+
self,
|
619 |
+
mask,
|
620 |
+
masked_image,
|
621 |
+
batch_size,
|
622 |
+
num_channels_latents,
|
623 |
+
num_images_per_prompt,
|
624 |
+
height,
|
625 |
+
width,
|
626 |
+
dtype,
|
627 |
+
device,
|
628 |
+
generator,
|
629 |
+
):
|
630 |
+
height = 2 * (int(height) // self.vae_scale_factor)
|
631 |
+
width = 2 * (int(width) // self.vae_scale_factor)
|
632 |
+
# resize the mask to latents shape as we concatenate the mask to the latents
|
633 |
+
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
|
634 |
+
# and half precision
|
635 |
+
mask = torch.nn.functional.interpolate(mask, size=(height, width))
|
636 |
+
mask = mask.to(device=device, dtype=dtype)
|
637 |
+
|
638 |
+
batch_size = batch_size * num_images_per_prompt
|
639 |
+
|
640 |
+
masked_image = masked_image.to(device=device, dtype=dtype)
|
641 |
+
|
642 |
+
if masked_image.shape[1] == 16:
|
643 |
+
masked_image_latents = masked_image
|
644 |
+
else:
|
645 |
+
masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)
|
646 |
+
|
647 |
+
masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
648 |
+
|
649 |
+
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
|
650 |
+
if mask.shape[0] < batch_size:
|
651 |
+
if not batch_size % mask.shape[0] == 0:
|
652 |
+
raise ValueError(
|
653 |
+
"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
|
654 |
+
f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
|
655 |
+
" of masks that you pass is divisible by the total requested batch size."
|
656 |
+
)
|
657 |
+
mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
|
658 |
+
if masked_image_latents.shape[0] < batch_size:
|
659 |
+
if not batch_size % masked_image_latents.shape[0] == 0:
|
660 |
+
raise ValueError(
|
661 |
+
"The passed images and the required batch size don't match. Images are supposed to be duplicated"
|
662 |
+
f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
|
663 |
+
" Make sure the number of images that you pass is divisible by the total requested batch size."
|
664 |
+
)
|
665 |
+
masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
|
666 |
+
|
667 |
+
# aligning device to prevent device errors when concating it with the latent model input
|
668 |
+
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
|
669 |
+
|
670 |
+
masked_image_latents = self._pack_latents(
|
671 |
+
masked_image_latents,
|
672 |
+
batch_size,
|
673 |
+
num_channels_latents,
|
674 |
+
height,
|
675 |
+
width,
|
676 |
+
)
|
677 |
+
mask = self._pack_latents(
|
678 |
+
mask.repeat(1, num_channels_latents, 1, 1),
|
679 |
+
batch_size,
|
680 |
+
num_channels_latents,
|
681 |
+
height,
|
682 |
+
width,
|
683 |
+
)
|
684 |
+
|
685 |
+
return mask, masked_image_latents
|
686 |
+
|
687 |
+
# Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image
|
688 |
+
def prepare_image(
|
689 |
+
self,
|
690 |
+
image,
|
691 |
+
width,
|
692 |
+
height,
|
693 |
+
batch_size,
|
694 |
+
num_images_per_prompt,
|
695 |
+
device,
|
696 |
+
dtype,
|
697 |
+
do_classifier_free_guidance=False,
|
698 |
+
guess_mode=False,
|
699 |
+
):
|
700 |
+
if isinstance(image, torch.Tensor):
|
701 |
+
pass
|
702 |
+
else:
|
703 |
+
image = self.image_processor.preprocess(image, height=height, width=width)
|
704 |
+
|
705 |
+
image_batch_size = image.shape[0]
|
706 |
+
|
707 |
+
if image_batch_size == 1:
|
708 |
+
repeat_by = batch_size
|
709 |
+
else:
|
710 |
+
# image batch size is the same as prompt batch size
|
711 |
+
repeat_by = num_images_per_prompt
|
712 |
+
|
713 |
+
image = image.repeat_interleave(repeat_by, dim=0)
|
714 |
+
|
715 |
+
image = image.to(device=device, dtype=dtype)
|
716 |
+
|
717 |
+
if do_classifier_free_guidance and not guess_mode:
|
718 |
+
image = torch.cat([image] * 2)
|
719 |
+
|
720 |
+
return image
|
721 |
+
|
722 |
+
@property
|
723 |
+
def guidance_scale(self):
|
724 |
+
return self._guidance_scale
|
725 |
+
|
726 |
+
@property
|
727 |
+
def joint_attention_kwargs(self):
|
728 |
+
return self._joint_attention_kwargs
|
729 |
+
|
730 |
+
@property
|
731 |
+
def num_timesteps(self):
|
732 |
+
return self._num_timesteps
|
733 |
+
|
734 |
+
@property
|
735 |
+
def interrupt(self):
|
736 |
+
return self._interrupt
|
737 |
+
|
738 |
+
@torch.no_grad()
|
739 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
740 |
+
def __call__(
|
741 |
+
self,
|
742 |
+
prompt: Union[str, List[str]] = None,
|
743 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
744 |
+
image: PipelineImageInput = None,
|
745 |
+
mask_image: PipelineImageInput = None,
|
746 |
+
masked_image_latents: PipelineImageInput = None,
|
747 |
+
control_image: PipelineImageInput = None,
|
748 |
+
height: Optional[int] = None,
|
749 |
+
width: Optional[int] = None,
|
750 |
+
strength: float = 0.6,
|
751 |
+
padding_mask_crop: Optional[int] = None,
|
752 |
+
timesteps: List[int] = None,
|
753 |
+
num_inference_steps: int = 28,
|
754 |
+
guidance_scale: float = 7.0,
|
755 |
+
control_guidance_start: Union[float, List[float]] = 0.0,
|
756 |
+
control_guidance_end: Union[float, List[float]] = 1.0,
|
757 |
+
control_mode: Optional[Union[int, List[int]]] = None,
|
758 |
+
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
|
759 |
+
num_images_per_prompt: Optional[int] = 1,
|
760 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
761 |
+
latents: Optional[torch.FloatTensor] = None,
|
762 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
763 |
+
t5_prompt_embeds: Optional[torch.FloatTensor] = None,
|
764 |
+
prompt_embeds_control: Optional[torch.FloatTensor] = None,
|
765 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
766 |
+
output_type: Optional[str] = "pil",
|
767 |
+
return_dict: bool = True,
|
768 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
769 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
770 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
771 |
+
max_sequence_length: int = 512,
|
772 |
+
):
|
773 |
+
"""
|
774 |
+
Function invoked when calling the pipeline for generation.
|
775 |
+
|
776 |
+
Args:
|
777 |
+
prompt (`str` or `List[str]`, *optional*):
|
778 |
+
The prompt or prompts to guide the image generation.
|
779 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
780 |
+
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`.
|
781 |
+
image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
|
782 |
+
The image(s) to inpaint.
|
783 |
+
mask_image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
|
784 |
+
The mask image(s) to use for inpainting. White pixels in the mask will be repainted, while black pixels
|
785 |
+
will be preserved.
|
786 |
+
masked_image_latents (`torch.FloatTensor`, *optional*):
|
787 |
+
Pre-generated masked image latents.
|
788 |
+
control_image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
|
789 |
+
The ControlNet input condition. Image to control the generation.
|
790 |
+
height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
|
791 |
+
The height in pixels of the generated image.
|
792 |
+
width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
|
793 |
+
The width in pixels of the generated image.
|
794 |
+
strength (`float`, *optional*, defaults to 0.6):
|
795 |
+
Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1.
|
796 |
+
padding_mask_crop (`int`, *optional*):
|
797 |
+
The size of the padding to use when cropping the mask.
|
798 |
+
num_inference_steps (`int`, *optional*, defaults to 28):
|
799 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
800 |
+
expense of slower inference.
|
801 |
+
timesteps (`List[int]`, *optional*):
|
802 |
+
Custom timesteps to use for the denoising process.
|
803 |
+
guidance_scale (`float`, *optional*, defaults to 7.0):
|
804 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
805 |
+
control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
|
806 |
+
The percentage of total steps at which the ControlNet starts applying.
|
807 |
+
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
|
808 |
+
The percentage of total steps at which the ControlNet stops applying.
|
809 |
+
control_mode (`int` or `List[int]`, *optional*):
|
810 |
+
The mode for the ControlNet. If multiple ControlNets are used, this should be a list.
|
811 |
+
controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
|
812 |
+
The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
|
813 |
+
to the residual in the original transformer.
|
814 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
815 |
+
The number of images to generate per prompt.
|
816 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
817 |
+
One or more [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to
|
818 |
+
make generation deterministic.
|
819 |
+
latents (`torch.FloatTensor`, *optional*):
|
820 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
821 |
+
generation. Can be used to tweak the same generation with different prompts.
|
822 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
823 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
824 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
825 |
+
Pre-generated pooled text embeddings.
|
826 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
827 |
+
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
828 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
829 |
+
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
|
830 |
+
joint_attention_kwargs (`dict`, *optional*):
|
831 |
+
Additional keyword arguments to be passed to the joint attention mechanism.
|
832 |
+
callback_on_step_end (`Callable`, *optional*):
|
833 |
+
A function that calls at the end of each denoising step during the inference.
|
834 |
+
callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
|
835 |
+
The list of tensor inputs for the `callback_on_step_end` function.
|
836 |
+
max_sequence_length (`int`, *optional*, defaults to 512):
|
837 |
+
The maximum length of the sequence to be generated.
|
838 |
+
|
839 |
+
Examples:
|
840 |
+
|
841 |
+
Returns:
|
842 |
+
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
|
843 |
+
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
|
844 |
+
images.
|
845 |
+
"""
|
846 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
847 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
848 |
+
|
849 |
+
global_height = height
|
850 |
+
global_width = width
|
851 |
+
|
852 |
+
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
|
853 |
+
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
|
854 |
+
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
|
855 |
+
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
|
856 |
+
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
|
857 |
+
mult = len(self.controlnet.nets) if isinstance(self.controlnet, FluxMultiControlNetModel) else 1
|
858 |
+
control_guidance_start, control_guidance_end = (
|
859 |
+
mult * [control_guidance_start],
|
860 |
+
mult * [control_guidance_end],
|
861 |
+
)
|
862 |
+
|
863 |
+
# 1. Check inputs
|
864 |
+
self.check_inputs(
|
865 |
+
prompt,
|
866 |
+
prompt_2,
|
867 |
+
image,
|
868 |
+
mask_image,
|
869 |
+
strength,
|
870 |
+
height,
|
871 |
+
width,
|
872 |
+
output_type=output_type,
|
873 |
+
prompt_embeds=prompt_embeds,
|
874 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
875 |
+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
876 |
+
padding_mask_crop=padding_mask_crop,
|
877 |
+
max_sequence_length=max_sequence_length,
|
878 |
+
)
|
879 |
+
|
880 |
+
self._guidance_scale = guidance_scale
|
881 |
+
self._joint_attention_kwargs = joint_attention_kwargs
|
882 |
+
self._interrupt = False
|
883 |
+
|
884 |
+
# 2. Define call parameters
|
885 |
+
if prompt is not None and isinstance(prompt, str):
|
886 |
+
batch_size = 1
|
887 |
+
elif prompt is not None and isinstance(prompt, list):
|
888 |
+
batch_size = len(prompt)
|
889 |
+
else:
|
890 |
+
batch_size = prompt_embeds.shape[0]
|
891 |
+
|
892 |
+
device = self._execution_device
|
893 |
+
dtype = self.transformer.dtype
|
894 |
+
|
895 |
+
# 3. Encode input prompt
|
896 |
+
lora_scale = (
|
897 |
+
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
|
898 |
+
)
|
899 |
+
prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
|
900 |
+
prompt=prompt,
|
901 |
+
prompt_2=prompt_2,
|
902 |
+
prompt_embeds=prompt_embeds,
|
903 |
+
t5_prompt_embeds=t5_prompt_embeds,
|
904 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
905 |
+
device=device,
|
906 |
+
num_images_per_prompt=num_images_per_prompt,
|
907 |
+
max_sequence_length=max_sequence_length,
|
908 |
+
lora_scale=lora_scale,
|
909 |
+
)
|
910 |
+
|
911 |
+
# 4. Preprocess mask and image
|
912 |
+
if padding_mask_crop is not None:
|
913 |
+
crops_coords = self.mask_processor.get_crop_region(
|
914 |
+
mask_image, global_width, global_height, pad=padding_mask_crop
|
915 |
+
)
|
916 |
+
resize_mode = "fill"
|
917 |
+
else:
|
918 |
+
crops_coords = None
|
919 |
+
resize_mode = "default"
|
920 |
+
|
921 |
+
original_image = image
|
922 |
+
init_image = self.image_processor.preprocess(
|
923 |
+
image, height=global_height, width=global_width, crops_coords=crops_coords, resize_mode=resize_mode
|
924 |
+
)
|
925 |
+
init_image = init_image.to(dtype=torch.float32)
|
926 |
+
|
927 |
+
# 5. Prepare control image
|
928 |
+
num_channels_latents = self.transformer.config.in_channels // 4
|
929 |
+
if isinstance(self.controlnet, FluxControlNetModel):
|
930 |
+
control_image = self.prepare_image(
|
931 |
+
image=control_image,
|
932 |
+
width=height,
|
933 |
+
height=width,
|
934 |
+
batch_size=batch_size * num_images_per_prompt,
|
935 |
+
num_images_per_prompt=num_images_per_prompt,
|
936 |
+
device=device,
|
937 |
+
dtype=self.vae.dtype,
|
938 |
+
)
|
939 |
+
height, width = control_image.shape[-2:]
|
940 |
+
|
941 |
+
# vae encode
|
942 |
+
control_image = self.vae.encode(control_image).latent_dist.sample()
|
943 |
+
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
944 |
+
|
945 |
+
# pack
|
946 |
+
height_control_image, width_control_image = control_image.shape[2:]
|
947 |
+
control_image = self._pack_latents(
|
948 |
+
control_image,
|
949 |
+
batch_size * num_images_per_prompt,
|
950 |
+
num_channels_latents,
|
951 |
+
height_control_image,
|
952 |
+
width_control_image,
|
953 |
+
)
|
954 |
+
|
955 |
+
# set control mode
|
956 |
+
if control_mode is not None:
|
957 |
+
control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
|
958 |
+
control_mode = control_mode.reshape([-1, 1])
|
959 |
+
|
960 |
+
elif isinstance(self.controlnet, FluxMultiControlNetModel):
|
961 |
+
control_images = []
|
962 |
+
|
963 |
+
for control_image_ in control_image:
|
964 |
+
control_image_ = self.prepare_image(
|
965 |
+
image=control_image_,
|
966 |
+
width=width,
|
967 |
+
height=height,
|
968 |
+
batch_size=batch_size * num_images_per_prompt,
|
969 |
+
num_images_per_prompt=num_images_per_prompt,
|
970 |
+
device=device,
|
971 |
+
dtype=self.vae.dtype,
|
972 |
+
)
|
973 |
+
height, width = control_image_.shape[-2:]
|
974 |
+
|
975 |
+
# vae encode
|
976 |
+
control_image_ = self.vae.encode(control_image_).latent_dist.sample()
|
977 |
+
control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
978 |
+
|
979 |
+
# pack
|
980 |
+
height_control_image, width_control_image = control_image_.shape[2:]
|
981 |
+
control_image_ = self._pack_latents(
|
982 |
+
control_image_,
|
983 |
+
batch_size * num_images_per_prompt,
|
984 |
+
num_channels_latents,
|
985 |
+
height_control_image,
|
986 |
+
width_control_image,
|
987 |
+
)
|
988 |
+
|
989 |
+
control_images.append(control_image_)
|
990 |
+
|
991 |
+
control_image = control_images
|
992 |
+
|
993 |
+
## set control mode
|
994 |
+
#control_mode_ = []
|
995 |
+
#if isinstance(control_mode, list):
|
996 |
+
# for cmode in control_mode:
|
997 |
+
# if cmode is None:
|
998 |
+
# control_mode_.append(-1)
|
999 |
+
# else:
|
1000 |
+
# control_mode_.append(cmode)
|
1001 |
+
#control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long)
|
1002 |
+
#control_mode = control_mode.reshape([-1, 1])
|
1003 |
+
control_modes = []
|
1004 |
+
for cmode in control_mode:
|
1005 |
+
if cmode is None:
|
1006 |
+
cmode = -1
|
1007 |
+
control_mode = torch.tensor(cmode).expand(control_images[0].shape[0]).to(device, dtype=torch.long)
|
1008 |
+
control_modes.append(control_mode)
|
1009 |
+
control_mode = control_modes
|
1010 |
+
|
1011 |
+
# 6. Prepare timesteps
|
1012 |
+
|
1013 |
+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
1014 |
+
image_seq_len = (int(global_height) // self.vae_scale_factor) * (int(global_width) // self.vae_scale_factor)
|
1015 |
+
mu = calculate_shift(
|
1016 |
+
image_seq_len,
|
1017 |
+
self.scheduler.config.base_image_seq_len,
|
1018 |
+
self.scheduler.config.max_image_seq_len,
|
1019 |
+
self.scheduler.config.base_shift,
|
1020 |
+
self.scheduler.config.max_shift,
|
1021 |
+
)
|
1022 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
1023 |
+
self.scheduler,
|
1024 |
+
num_inference_steps,
|
1025 |
+
device,
|
1026 |
+
timesteps,
|
1027 |
+
sigmas,
|
1028 |
+
mu=mu,
|
1029 |
+
)
|
1030 |
+
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
|
1031 |
+
|
1032 |
+
if num_inference_steps < 1:
|
1033 |
+
raise ValueError(
|
1034 |
+
f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
|
1035 |
+
f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
|
1036 |
+
)
|
1037 |
+
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
1038 |
+
|
1039 |
+
# 7. Prepare latent variables
|
1040 |
+
|
1041 |
+
latents, noise, image_latents, latent_image_ids = self.prepare_latents(
|
1042 |
+
init_image,
|
1043 |
+
latent_timestep,
|
1044 |
+
batch_size * num_images_per_prompt,
|
1045 |
+
num_channels_latents,
|
1046 |
+
global_height,
|
1047 |
+
global_width,
|
1048 |
+
prompt_embeds.dtype,
|
1049 |
+
device,
|
1050 |
+
generator,
|
1051 |
+
latents,
|
1052 |
+
)
|
1053 |
+
|
1054 |
+
# 8. Prepare mask latents
|
1055 |
+
mask_condition = self.mask_processor.preprocess(
|
1056 |
+
mask_image, height=global_height, width=global_width, resize_mode=resize_mode, crops_coords=crops_coords
|
1057 |
+
)
|
1058 |
+
if masked_image_latents is None:
|
1059 |
+
masked_image = init_image * (mask_condition < 0.5)
|
1060 |
+
else:
|
1061 |
+
masked_image = masked_image_latents
|
1062 |
+
|
1063 |
+
mask, masked_image_latents = self.prepare_mask_latents(
|
1064 |
+
mask_condition,
|
1065 |
+
masked_image,
|
1066 |
+
batch_size,
|
1067 |
+
num_channels_latents,
|
1068 |
+
num_images_per_prompt,
|
1069 |
+
global_height,
|
1070 |
+
global_width,
|
1071 |
+
prompt_embeds.dtype,
|
1072 |
+
device,
|
1073 |
+
generator,
|
1074 |
+
)
|
1075 |
+
|
1076 |
+
controlnet_keep = []
|
1077 |
+
for i in range(len(timesteps)):
|
1078 |
+
keeps = [
|
1079 |
+
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
|
1080 |
+
for s, e in zip(control_guidance_start, control_guidance_end)
|
1081 |
+
]
|
1082 |
+
controlnet_keep.append(keeps[0] if isinstance(self.controlnet, FluxControlNetModel) else keeps)
|
1083 |
+
|
1084 |
+
# 9. Denoising loop
|
1085 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
1086 |
+
self._num_timesteps = len(timesteps)
|
1087 |
+
|
1088 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
1089 |
+
for i, t in enumerate(timesteps):
|
1090 |
+
if self.interrupt:
|
1091 |
+
continue
|
1092 |
+
|
1093 |
+
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
1094 |
+
|
1095 |
+
# predict the noise residual
|
1096 |
+
#if self.controlnet.config.guidance_embeds:
|
1097 |
+
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
|
1098 |
+
guidance = guidance.expand(latents.shape[0])
|
1099 |
+
#else:
|
1100 |
+
# guidance = None
|
1101 |
+
|
1102 |
+
if isinstance(controlnet_keep[i], list):
|
1103 |
+
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
|
1104 |
+
else:
|
1105 |
+
controlnet_cond_scale = controlnet_conditioning_scale
|
1106 |
+
if isinstance(controlnet_cond_scale, list):
|
1107 |
+
controlnet_cond_scale = controlnet_cond_scale[0]
|
1108 |
+
cond_scale = controlnet_cond_scale * controlnet_keep[i]
|
1109 |
+
|
1110 |
+
controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
|
1111 |
+
hidden_states=latents,
|
1112 |
+
controlnet_cond=control_image,
|
1113 |
+
controlnet_mode=control_mode,
|
1114 |
+
conditioning_scale=cond_scale,
|
1115 |
+
timestep=timestep / 1000,
|
1116 |
+
guidance=guidance,
|
1117 |
+
pooled_projections=pooled_prompt_embeds,
|
1118 |
+
encoder_hidden_states=prompt_embeds_control,
|
1119 |
+
t5_encoder_hidden_states=t5_prompt_embeds,
|
1120 |
+
txt_ids=text_ids,
|
1121 |
+
img_ids=latent_image_ids,
|
1122 |
+
joint_attention_kwargs=self.joint_attention_kwargs,
|
1123 |
+
return_dict=False,
|
1124 |
+
)
|
1125 |
+
|
1126 |
+
if self.transformer.config.guidance_embeds:
|
1127 |
+
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
|
1128 |
+
guidance = guidance.expand(latents.shape[0])
|
1129 |
+
else:
|
1130 |
+
guidance = None
|
1131 |
+
|
1132 |
+
noise_pred = self.transformer(
|
1133 |
+
hidden_states=latents,
|
1134 |
+
timestep=timestep / 1000,
|
1135 |
+
guidance=guidance,
|
1136 |
+
pooled_projections=pooled_prompt_embeds,
|
1137 |
+
encoder_hidden_states=prompt_embeds,
|
1138 |
+
t5_encoder_hidden_states=t5_prompt_embeds,
|
1139 |
+
controlnet_block_samples=controlnet_block_samples,
|
1140 |
+
controlnet_single_block_samples=controlnet_single_block_samples,
|
1141 |
+
txt_ids=text_ids,
|
1142 |
+
img_ids=latent_image_ids,
|
1143 |
+
joint_attention_kwargs=self.joint_attention_kwargs,
|
1144 |
+
return_dict=False,
|
1145 |
+
)[0]
|
1146 |
+
|
1147 |
+
# compute the previous noisy sample x_t -> x_t-1
|
1148 |
+
latents_dtype = latents.dtype
|
1149 |
+
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
1150 |
+
|
1151 |
+
# For inpainting, we need to apply the mask and add the masked image latents
|
1152 |
+
init_latents_proper = image_latents
|
1153 |
+
init_mask = mask
|
1154 |
+
|
1155 |
+
if i < len(timesteps) - 1:
|
1156 |
+
noise_timestep = timesteps[i + 1]
|
1157 |
+
init_latents_proper = self.scheduler.scale_noise(
|
1158 |
+
init_latents_proper, torch.tensor([noise_timestep]), noise
|
1159 |
+
)
|
1160 |
+
|
1161 |
+
latents = (1 - init_mask) * init_latents_proper + init_mask * latents
|
1162 |
+
|
1163 |
+
if latents.dtype != latents_dtype:
|
1164 |
+
if torch.backends.mps.is_available():
|
1165 |
+
# some platforms (e.g. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
1166 |
+
latents = latents.to(latents_dtype)
|
1167 |
+
|
1168 |
+
# call the callback, if provided
|
1169 |
+
if callback_on_step_end is not None:
|
1170 |
+
callback_kwargs = {}
|
1171 |
+
for k in callback_on_step_end_tensor_inputs:
|
1172 |
+
callback_kwargs[k] = locals()[k]
|
1173 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
1174 |
+
|
1175 |
+
latents = callback_outputs.pop("latents", latents)
|
1176 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
1177 |
+
|
1178 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
1179 |
+
progress_bar.update()
|
1180 |
+
|
1181 |
+
if XLA_AVAILABLE:
|
1182 |
+
xm.mark_step()
|
1183 |
+
|
1184 |
+
# Post-processing
|
1185 |
+
if output_type == "latent":
|
1186 |
+
image = latents
|
1187 |
+
else:
|
1188 |
+
latents = self._unpack_latents(latents, global_height, global_width, self.vae_scale_factor)
|
1189 |
+
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
1190 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
1191 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
1192 |
+
|
1193 |
+
# Offload all models
|
1194 |
+
self.maybe_free_model_hooks()
|
1195 |
+
|
1196 |
+
if not return_dict:
|
1197 |
+
return (image,)
|
1198 |
+
|
1199 |
+
return FluxPipelineOutput(images=image)
|
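Since the pipeline's own `EXAMPLE_DOC_STRING` is defined earlier in the file and not shown here, the following is only a rough usage sketch. The checkpoint ids, file names, and the `flux.controlnet_flux` import path are assumptions for illustration; substitute whatever FLUX base and ControlNet weights this repo is actually wired to.

```py
import torch
from diffusers.utils import load_image

from flux.controlnet_flux import FluxControlNetModel                    # assumed local module path
from flux.pipeline_flux_controlnet_inpainting import FluxControlNetInpaintPipeline

# Placeholder checkpoints -- replace with the weights you actually use.
controlnet = FluxControlNetModel.from_pretrained("<flux-controlnet-checkpoint>", torch_dtype=torch.bfloat16)
pipe = FluxControlNetInpaintPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", controlnet=controlnet, torch_dtype=torch.bfloat16
).to("cuda")

image = load_image("input.png")          # image to inpaint
mask = load_image("mask.png")            # white = repaint, black = keep
control = load_image("control.png")      # ControlNet conditioning image

result = pipe(
    prompt="a red brick wall",
    image=image,
    mask_image=mask,
    control_image=control,
    strength=0.85,
    num_inference_steps=28,
    guidance_scale=7.0,
).images[0]
result.save("output.png")
```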
flux/pipeline_flux_img2img.py
ADDED
@@ -0,0 +1,856 @@
1 |
+
# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import inspect
|
16 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
17 |
+
|
18 |
+
import numpy as np
|
19 |
+
import torch
|
20 |
+
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
|
21 |
+
|
22 |
+
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
23 |
+
from .lora.lora_pipeline import FluxLoraLoaderMixin
|
24 |
+
from diffusers.models.autoencoders import AutoencoderKL
|
25 |
+
from .transformer_flux import FluxTransformer2DModel
|
26 |
+
from .scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
|
27 |
+
from diffusers.utils import (
|
28 |
+
USE_PEFT_BACKEND,
|
29 |
+
is_torch_xla_available,
|
30 |
+
logging,
|
31 |
+
replace_example_docstring,
|
32 |
+
scale_lora_layers,
|
33 |
+
unscale_lora_layers,
|
34 |
+
)
|
35 |
+
from diffusers.utils.torch_utils import randn_tensor
|
36 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
37 |
+
from .pipeline_output import FluxPipelineOutput
|
38 |
+
|
39 |
+
|
40 |
+
if is_torch_xla_available():
|
41 |
+
import torch_xla.core.xla_model as xm
|
42 |
+
|
43 |
+
XLA_AVAILABLE = True
|
44 |
+
else:
|
45 |
+
XLA_AVAILABLE = False
|
46 |
+
|
47 |
+
|
48 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
49 |
+
|
50 |
+
EXAMPLE_DOC_STRING = """
|
51 |
+
Examples:
|
52 |
+
```py
|
53 |
+
>>> import torch
|
54 |
+
|
55 |
+
>>> from diffusers import FluxImg2ImgPipeline
|
56 |
+
>>> from diffusers.utils import load_image
|
57 |
+
|
58 |
+
>>> device = "cuda"
|
59 |
+
>>> pipe = FluxImg2ImgPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
|
60 |
+
>>> pipe = pipe.to(device)
|
61 |
+
|
62 |
+
>>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
|
63 |
+
>>> init_image = load_image(url).resize((1024, 1024))
|
64 |
+
|
65 |
+
>>> prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k"
|
66 |
+
|
67 |
+
>>> images = pipe(
|
68 |
+
... prompt=prompt, image=init_image, num_inference_steps=4, strength=0.95, guidance_scale=0.0
|
69 |
+
... ).images[0]
|
70 |
+
```
|
71 |
+
"""
|
72 |
+
|
73 |
+
|
74 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
|
75 |
+
def calculate_shift(
|
76 |
+
image_seq_len,
|
77 |
+
base_seq_len: int = 256,
|
78 |
+
max_seq_len: int = 4096,
|
79 |
+
base_shift: float = 0.5,
|
80 |
+
max_shift: float = 1.16,
|
81 |
+
):
|
82 |
+
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
83 |
+
b = base_shift - m * base_seq_len
|
84 |
+
mu = image_seq_len * m + b
|
85 |
+
return mu
|
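A quick worked check of the linear interpolation above: assuming the same `vae_scale_factor` of 16 used elsewhere in this repo, a 1024×1024 generation gives `image_seq_len = (1024 // 16) * (1024 // 16) = 4096`, which is exactly `max_seq_len`, so `mu` lands on `max_shift`:

```py
calculate_shift(256)     # -> 0.5  (base_shift endpoint)
calculate_shift(4096)    # -> 1.16 (max_shift endpoint, e.g. a 1024x1024 image here)
```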
86 |
+
|
87 |
+
|
88 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
89 |
+
def retrieve_latents(
|
90 |
+
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
91 |
+
):
|
92 |
+
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
93 |
+
return encoder_output.latent_dist.sample(generator)
|
94 |
+
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
95 |
+
return encoder_output.latent_dist.mode()
|
96 |
+
elif hasattr(encoder_output, "latents"):
|
97 |
+
return encoder_output.latents
|
98 |
+
else:
|
99 |
+
raise AttributeError("Could not access latents of provided encoder_output")
|
100 |
+
|
101 |
+
|
102 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
103 |
+
def retrieve_timesteps(
|
104 |
+
scheduler,
|
105 |
+
num_inference_steps: Optional[int] = None,
|
106 |
+
device: Optional[Union[str, torch.device]] = None,
|
107 |
+
timesteps: Optional[List[int]] = None,
|
108 |
+
sigmas: Optional[List[float]] = None,
|
109 |
+
**kwargs,
|
110 |
+
):
|
111 |
+
"""
|
112 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
113 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
114 |
+
|
115 |
+
Args:
|
116 |
+
scheduler (`SchedulerMixin`):
|
117 |
+
The scheduler to get timesteps from.
|
118 |
+
num_inference_steps (`int`):
|
119 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
120 |
+
must be `None`.
|
121 |
+
device (`str` or `torch.device`, *optional*):
|
122 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
123 |
+
timesteps (`List[int]`, *optional*):
|
124 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
125 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
126 |
+
sigmas (`List[float]`, *optional*):
|
127 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
128 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
129 |
+
|
130 |
+
Returns:
|
131 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
132 |
+
second element is the number of inference steps.
|
133 |
+
"""
|
134 |
+
if timesteps is not None and sigmas is not None:
|
135 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
136 |
+
if timesteps is not None:
|
137 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
138 |
+
if not accepts_timesteps:
|
139 |
+
raise ValueError(
|
140 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
141 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
142 |
+
)
|
143 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
144 |
+
timesteps = scheduler.timesteps
|
145 |
+
num_inference_steps = len(timesteps)
|
146 |
+
elif sigmas is not None:
|
147 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
148 |
+
if not accept_sigmas:
|
149 |
+
raise ValueError(
|
150 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
151 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
152 |
+
)
|
153 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
154 |
+
timesteps = scheduler.timesteps
|
155 |
+
num_inference_steps = len(timesteps)
|
156 |
+
else:
|
157 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
158 |
+
timesteps = scheduler.timesteps
|
159 |
+
return timesteps, num_inference_steps
|
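A hedged usage sketch of `retrieve_timesteps` (illustrative only; `pipe` is assumed to be an already-loaded `FluxImg2ImgPipeline`, and the `mu` kwarg is assumed to be forwarded to the scheduler copy shipped in this repo, exactly as `__call__` does further down):

```py
import numpy as np

num_inference_steps = 28
# Same schedule construction as step 4 of __call__ below.
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
mu = calculate_shift(image_seq_len=4096)  # 1024x1024 input, see calculate_shift above

timesteps, num_inference_steps = retrieve_timesteps(
    pipe.scheduler, num_inference_steps, device="cuda", sigmas=sigmas, mu=mu
)
# `timesteps` now holds the shifted sigma schedule; get_timesteps() later trims it
# according to `strength`.
```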
160 |
+
|
161 |
+
|
162 |
+
class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
163 |
+
r"""
|
164 |
+
The Flux pipeline for image-to-image generation.
|
165 |
+
|
166 |
+
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
|
167 |
+
|
168 |
+
Args:
|
169 |
+
transformer ([`FluxTransformer2DModel`]):
|
170 |
+
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
171 |
+
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
172 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
173 |
+
vae ([`AutoencoderKL`]):
|
174 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
175 |
+
text_encoder ([`CLIPTextModel`]):
|
176 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
177 |
+
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
178 |
+
text_encoder_2 ([`T5EncoderModel`]):
|
179 |
+
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
180 |
+
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
|
181 |
+
tokenizer (`CLIPTokenizer`):
|
182 |
+
Tokenizer of class
|
183 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
|
184 |
+
tokenizer_2 (`T5TokenizerFast`):
|
185 |
+
Second Tokenizer of class
|
186 |
+
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
|
187 |
+
"""
|
188 |
+
|
189 |
+
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
|
190 |
+
_optional_components = []
|
191 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
192 |
+
|
193 |
+
def __init__(
|
194 |
+
self,
|
195 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
196 |
+
vae: AutoencoderKL,
|
197 |
+
text_encoder: CLIPTextModel,
|
198 |
+
tokenizer: CLIPTokenizer,
|
199 |
+
transformer: FluxTransformer2DModel,
|
200 |
+
text_encoder_2: T5EncoderModel | None = None,
|
201 |
+
tokenizer_2: T5TokenizerFast | None = None,
|
202 |
+
):
|
203 |
+
super().__init__()
|
204 |
+
|
205 |
+
self.register_modules(
|
206 |
+
vae=vae,
|
207 |
+
text_encoder=text_encoder,
|
208 |
+
#text_encoder_2=text_encoder_2,
|
209 |
+
tokenizer=tokenizer,
|
210 |
+
#tokenizer_2=tokenizer_2,
|
211 |
+
transformer=transformer,
|
212 |
+
scheduler=scheduler,
|
213 |
+
)
|
214 |
+
self.vae_scale_factor = (
|
215 |
+
2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
|
216 |
+
)
|
217 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
218 |
+
self.tokenizer_max_length = (
|
219 |
+
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
|
220 |
+
)
|
221 |
+
self.default_sample_size = 64
|
222 |
+
|
223 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
|
224 |
+
def _get_t5_prompt_embeds(
|
225 |
+
self,
|
226 |
+
prompt: Union[str, List[str]] = None,
|
227 |
+
num_images_per_prompt: int = 1,
|
228 |
+
max_sequence_length: int = 512,
|
229 |
+
device: Optional[torch.device] = None,
|
230 |
+
dtype: Optional[torch.dtype] = None,
|
231 |
+
):
|
232 |
+
device = device or self._execution_device
|
233 |
+
dtype = dtype or self.text_encoder.dtype
|
234 |
+
|
235 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
236 |
+
batch_size = len(prompt)
|
237 |
+
|
238 |
+
text_inputs = self.tokenizer_2(
|
239 |
+
prompt,
|
240 |
+
padding="max_length",
|
241 |
+
max_length=max_sequence_length,
|
242 |
+
truncation=True,
|
243 |
+
return_length=False,
|
244 |
+
return_overflowing_tokens=False,
|
245 |
+
return_tensors="pt",
|
246 |
+
)
|
247 |
+
text_input_ids = text_inputs.input_ids
|
248 |
+
untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
|
249 |
+
|
250 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
251 |
+
removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
252 |
+
logger.warning(
|
253 |
+
"The following part of your input was truncated because `max_sequence_length` is set to "
|
254 |
+
f" {max_sequence_length} tokens: {removed_text}"
|
255 |
+
)
|
256 |
+
|
257 |
+
prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
|
258 |
+
|
259 |
+
dtype = self.text_encoder_2.dtype
|
260 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
261 |
+
|
262 |
+
_, seq_len, _ = prompt_embeds.shape
|
263 |
+
|
264 |
+
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
265 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
266 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
267 |
+
|
268 |
+
return prompt_embeds
|
269 |
+
|
270 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
|
271 |
+
def _get_clip_prompt_embeds(
|
272 |
+
self,
|
273 |
+
prompt: Union[str, List[str]],
|
274 |
+
num_images_per_prompt: int = 1,
|
275 |
+
device: Optional[torch.device] = None,
|
276 |
+
):
|
277 |
+
device = device or self._execution_device
|
278 |
+
|
279 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
280 |
+
batch_size = len(prompt)
|
281 |
+
|
282 |
+
text_inputs = self.tokenizer(
|
283 |
+
prompt,
|
284 |
+
padding="max_length",
|
285 |
+
max_length=self.tokenizer_max_length,
|
286 |
+
truncation=True,
|
287 |
+
return_overflowing_tokens=False,
|
288 |
+
return_length=False,
|
289 |
+
return_tensors="pt",
|
290 |
+
)
|
291 |
+
|
292 |
+
text_input_ids = text_inputs.input_ids
|
293 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
294 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
295 |
+
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
296 |
+
logger.warning(
|
297 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
298 |
+
f" {self.tokenizer_max_length} tokens: {removed_text}"
|
299 |
+
)
|
300 |
+
prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
|
301 |
+
|
302 |
+
# Use pooled output of CLIPTextModel
|
303 |
+
prompt_embeds = prompt_embeds.pooler_output
|
304 |
+
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
305 |
+
|
306 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
307 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
|
308 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
309 |
+
|
310 |
+
return prompt_embeds
|
311 |
+
|
312 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
|
313 |
+
def encode_prompt(
|
314 |
+
self,
|
315 |
+
prompt: Union[str, List[str]],
|
316 |
+
prompt_2: Union[str, List[str]],
|
317 |
+
device: Optional[torch.device] = None,
|
318 |
+
num_images_per_prompt: int = 1,
|
319 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
320 |
+
t5_prompt_embeds: Optional[torch.FloatTensor] = None,
|
321 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
322 |
+
max_sequence_length: int = 512,
|
323 |
+
lora_scale: Optional[float] = None,
|
324 |
+
):
|
325 |
+
r"""
|
326 |
+
|
327 |
+
Args:
|
328 |
+
prompt (`str` or `List[str]`, *optional*):
|
329 |
+
prompt to be encoded
|
330 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
331 |
+
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
332 |
+
used in all text-encoders
|
333 |
+
device: (`torch.device`):
|
334 |
+
torch device
|
335 |
+
num_images_per_prompt (`int`):
|
336 |
+
number of images that should be generated per prompt
|
337 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
338 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
339 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
340 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
341 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
342 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
343 |
+
lora_scale (`float`, *optional*):
|
344 |
+
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
345 |
+
"""
|
346 |
+
device = device or self._execution_device
|
347 |
+
|
348 |
+
# set lora scale so that monkey patched LoRA
|
349 |
+
# function of text encoder can correctly access it
|
350 |
+
if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
|
351 |
+
self._lora_scale = lora_scale
|
352 |
+
|
353 |
+
# dynamically adjust the LoRA scale
|
354 |
+
if self.text_encoder is not None and USE_PEFT_BACKEND:
|
355 |
+
scale_lora_layers(self.text_encoder, lora_scale)
|
356 |
+
if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
|
357 |
+
scale_lora_layers(self.text_encoder_2, lora_scale)
|
358 |
+
|
359 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
360 |
+
if prompt is not None:
|
361 |
+
batch_size = len(prompt)
|
362 |
+
else:
|
363 |
+
batch_size = prompt_embeds.shape[0]
|
364 |
+
|
365 |
+
if prompt_embeds is None:
|
366 |
+
prompt_2 = prompt_2 or prompt
|
367 |
+
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
|
368 |
+
|
369 |
+
# We only use the pooled prompt output from the CLIPTextModel
|
370 |
+
pooled_prompt_embeds = self._get_clip_prompt_embeds(
|
371 |
+
prompt=prompt,
|
372 |
+
device=device,
|
373 |
+
num_images_per_prompt=num_images_per_prompt,
|
374 |
+
)
|
375 |
+
prompt_embeds = self._get_t5_prompt_embeds(
|
376 |
+
prompt=prompt_2,
|
377 |
+
num_images_per_prompt=num_images_per_prompt,
|
378 |
+
max_sequence_length=max_sequence_length,
|
379 |
+
device=device,
|
380 |
+
)
|
381 |
+
|
382 |
+
if self.text_encoder is not None:
|
383 |
+
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
|
384 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
385 |
+
unscale_lora_layers(self.text_encoder, lora_scale)
|
386 |
+
|
387 |
+
#if self.text_encoder_2 is not None:
|
388 |
+
# if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
|
389 |
+
# # Retrieve the original scale by scaling back the LoRA layers
|
390 |
+
# unscale_lora_layers(self.text_encoder_2, lora_scale)
|
391 |
+
|
392 |
+
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
|
393 |
+
if t5_prompt_embeds is not None:
|
394 |
+
text_ids = torch.zeros(batch_size, prompt_embeds.shape[1] + t5_prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
395 |
+
else:
|
396 |
+
text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
397 |
+
#text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
398 |
+
|
399 |
+
return prompt_embeds, pooled_prompt_embeds, text_ids
|
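Worth noting (illustrative shape sketch, not part of the diff): the `text_ids` returned above are all-zero positional ids whose length is the only thing that matters, and when an extra `t5_prompt_embeds` tensor is supplied the ids span the concatenation of both text streams. The lengths below are hypothetical placeholders.

```py
import torch

batch_size = 1
prompt_embeds_len = 256   # hypothetical length of `prompt_embeds`
t5_len = 512              # hypothetical length of `t5_prompt_embeds`

# Mirrors the branch above where t5_prompt_embeds is not None.
text_ids = torch.zeros(batch_size, prompt_embeds_len + t5_len, 3)
print(text_ids.shape)     # torch.Size([1, 768, 3])
```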
400 |
+
|
401 |
+
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
|
402 |
+
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
|
403 |
+
if isinstance(generator, list):
|
404 |
+
image_latents = [
|
405 |
+
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
|
406 |
+
for i in range(image.shape[0])
|
407 |
+
]
|
408 |
+
image_latents = torch.cat(image_latents, dim=0)
|
409 |
+
else:
|
410 |
+
image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
|
411 |
+
|
412 |
+
image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
413 |
+
|
414 |
+
return image_latents
|
415 |
+
|
416 |
+
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
|
417 |
+
def get_timesteps(self, num_inference_steps, strength, device):
|
418 |
+
# get the original timestep using init_timestep
|
419 |
+
init_timestep = min(num_inference_steps * strength, num_inference_steps)
|
420 |
+
|
421 |
+
t_start = int(max(num_inference_steps - init_timestep, 0))
|
422 |
+
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
423 |
+
if hasattr(self.scheduler, "set_begin_index"):
|
424 |
+
self.scheduler.set_begin_index(t_start * self.scheduler.order)
|
425 |
+
|
426 |
+
return timesteps, num_inference_steps - t_start
|
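A worked example of the `strength` arithmetic above (sketch only): with `num_inference_steps=28` and `strength=0.6`, `init_timestep = 16.8` and `t_start = 11`, so only 17 of the 28 scheduled steps actually run.

```py
# Re-derivation of the strength -> effective-step-count mapping in get_timesteps.
def effective_steps(num_inference_steps: int, strength: float) -> int:
    init_timestep = min(num_inference_steps * strength, num_inference_steps)
    t_start = int(max(num_inference_steps - init_timestep, 0))
    return num_inference_steps - t_start

assert effective_steps(28, 0.6) == 17
assert effective_steps(4, 0.95) == 4    # the schnell example in the docstring
assert effective_steps(28, 1.0) == 28   # strength 1.0 ignores the init image entirely
```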
427 |
+
|
428 |
+
def check_inputs(
|
429 |
+
self,
|
430 |
+
prompt,
|
431 |
+
prompt_2,
|
432 |
+
strength,
|
433 |
+
height,
|
434 |
+
width,
|
435 |
+
prompt_embeds=None,
|
436 |
+
pooled_prompt_embeds=None,
|
437 |
+
callback_on_step_end_tensor_inputs=None,
|
438 |
+
max_sequence_length=None,
|
439 |
+
):
|
440 |
+
if strength < 0 or strength > 1:
|
441 |
+
raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
|
442 |
+
|
443 |
+
if height % 8 != 0 or width % 8 != 0:
|
444 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
445 |
+
|
446 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
447 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
448 |
+
):
|
449 |
+
raise ValueError(
|
450 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
451 |
+
)
|
452 |
+
|
453 |
+
if prompt is not None and prompt_embeds is not None:
|
454 |
+
raise ValueError(
|
455 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
456 |
+
" only forward one of the two."
|
457 |
+
)
|
458 |
+
elif prompt_2 is not None and prompt_embeds is not None:
|
459 |
+
raise ValueError(
|
460 |
+
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
461 |
+
" only forward one of the two."
|
462 |
+
)
|
463 |
+
elif prompt is None and prompt_embeds is None:
|
464 |
+
raise ValueError(
|
465 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
466 |
+
)
|
467 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
468 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
469 |
+
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
|
470 |
+
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
|
471 |
+
|
472 |
+
if prompt_embeds is not None and pooled_prompt_embeds is None:
|
473 |
+
raise ValueError(
|
474 |
+
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
475 |
+
)
|
476 |
+
|
477 |
+
if max_sequence_length is not None and max_sequence_length > 512:
|
478 |
+
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
|
479 |
+
|
480 |
+
@staticmethod
|
481 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
|
482 |
+
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
483 |
+
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
|
484 |
+
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
|
485 |
+
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
|
486 |
+
|
487 |
+
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
488 |
+
|
489 |
+
latent_image_ids = latent_image_ids.reshape(
|
490 |
+
latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
491 |
+
)
|
492 |
+
|
493 |
+
return latent_image_ids.to(device=device, dtype=dtype)
|
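A quick illustrative check of `_prepare_latent_image_ids` (not part of the diff): `prepare_latents` below passes the doubled latent height/width, so for a 1024x1024 input with `vae_scale_factor=16` the grid is 64x64 and the ids flatten to 4096 rows of `(0, row, col)`.

```py
import torch

ids = FluxImg2ImgPipeline._prepare_latent_image_ids(
    batch_size=1, height=128, width=128, device="cpu", dtype=torch.float32
)
print(ids.shape)        # torch.Size([4096, 3])
print(ids[0], ids[-1])  # tensor([0., 0., 0.])  tensor([ 0., 63., 63.])
```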
494 |
+
|
495 |
+
@staticmethod
|
496 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
|
497 |
+
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
498 |
+
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
499 |
+
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
500 |
+
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
|
501 |
+
|
502 |
+
return latents
|
503 |
+
|
504 |
+
@staticmethod
|
505 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
|
506 |
+
def _unpack_latents(latents, height, width, vae_scale_factor):
|
507 |
+
batch_size, num_patches, channels = latents.shape
|
508 |
+
|
509 |
+
height = height // vae_scale_factor
|
510 |
+
width = width // vae_scale_factor
|
511 |
+
|
512 |
+
latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
|
513 |
+
latents = latents.permute(0, 3, 1, 4, 2, 5)
|
514 |
+
|
515 |
+
latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
|
516 |
+
|
517 |
+
return latents
|
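A round-trip sanity check for `_pack_latents` / `_unpack_latents` (illustrative sketch; assumes 16 latent channels and the `vae_scale_factor=16` set in `__init__`): packing is a lossless 2x2 pixel-shuffle into a sequence, and unpacking inverts it exactly.

```py
import torch

B, C = 1, 16
H = W = 2 * (1024 // 16)  # doubled latent size for a 1024x1024 input -> 128
x = torch.randn(B, C, H, W)

packed = FluxImg2ImgPipeline._pack_latents(x, B, C, H, W)
print(packed.shape)  # torch.Size([1, 4096, 64])

unpacked = FluxImg2ImgPipeline._unpack_latents(packed, 1024, 1024, vae_scale_factor=16)
assert torch.equal(unpacked, x)  # exact inverse, no information lost
```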
518 |
+
|
519 |
+
def prepare_latents(
|
520 |
+
self,
|
521 |
+
image,
|
522 |
+
timestep,
|
523 |
+
batch_size,
|
524 |
+
num_channels_latents,
|
525 |
+
height,
|
526 |
+
width,
|
527 |
+
dtype,
|
528 |
+
device,
|
529 |
+
generator,
|
530 |
+
latents=None,
|
531 |
+
):
|
532 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
533 |
+
raise ValueError(
|
534 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
535 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
536 |
+
)
|
537 |
+
|
538 |
+
height = 2 * (int(height) // self.vae_scale_factor)
|
539 |
+
width = 2 * (int(width) // self.vae_scale_factor)
|
540 |
+
|
541 |
+
shape = (batch_size, num_channels_latents, height, width)
|
542 |
+
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
|
543 |
+
|
544 |
+
if latents is not None:
|
545 |
+
return latents.to(device=device, dtype=dtype), latent_image_ids
|
546 |
+
|
547 |
+
image = image.to(device=device, dtype=dtype)
|
548 |
+
image_latents = self._encode_vae_image(image=image, generator=generator)
|
549 |
+
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
|
550 |
+
# expand init_latents for batch_size
|
551 |
+
additional_image_per_prompt = batch_size // image_latents.shape[0]
|
552 |
+
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
|
553 |
+
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
|
554 |
+
raise ValueError(
|
555 |
+
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
|
556 |
+
)
|
557 |
+
else:
|
558 |
+
image_latents = torch.cat([image_latents], dim=0)
|
559 |
+
|
560 |
+
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
561 |
+
latents = self.scheduler.scale_noise(image_latents, timestep, noise)
|
562 |
+
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
|
563 |
+
return latents, latent_image_ids
|
564 |
+
|
565 |
+
@property
|
566 |
+
def guidance_scale(self):
|
567 |
+
return self._guidance_scale
|
568 |
+
|
569 |
+
@property
|
570 |
+
def joint_attention_kwargs(self):
|
571 |
+
return self._joint_attention_kwargs
|
572 |
+
|
573 |
+
@property
|
574 |
+
def num_timesteps(self):
|
575 |
+
return self._num_timesteps
|
576 |
+
|
577 |
+
@property
|
578 |
+
def interrupt(self):
|
579 |
+
return self._interrupt
|
580 |
+
|
581 |
+
@torch.no_grad()
|
582 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
583 |
+
def __call__(
|
584 |
+
self,
|
585 |
+
prompt: Union[str, List[str]] = None,
|
586 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
587 |
+
image: PipelineImageInput = None,
|
588 |
+
height: Optional[int] = None,
|
589 |
+
width: Optional[int] = None,
|
590 |
+
strength: float = 0.6,
|
591 |
+
num_inference_steps: int = 28,
|
592 |
+
timesteps: List[int] = None,
|
593 |
+
guidance_scale: float = 7.0,
|
594 |
+
num_images_per_prompt: Optional[int] = 1,
|
595 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
596 |
+
latents: Optional[torch.FloatTensor] = None,
|
597 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
598 |
+
t5_prompt_embeds: Optional[torch.FloatTensor] = None,
|
599 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
600 |
+
output_type: Optional[str] = "pil",
|
601 |
+
return_dict: bool = True,
|
602 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
603 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
604 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
605 |
+
max_sequence_length: int = 512,
|
606 |
+
):
|
607 |
+
r"""
|
608 |
+
Function invoked when calling the pipeline for generation.
|
609 |
+
|
610 |
+
Args:
|
611 |
+
prompt (`str` or `List[str]`, *optional*):
|
612 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
|
613 |
+
instead.
|
614 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
615 |
+
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
|
616 |
+
will be used instead
|
617 |
+
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
618 |
+
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
|
619 |
+
numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a list
|
620 |
+
of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
|
621 |
+
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
|
622 |
+
latents as `image`, but if passing latents directly it is not encoded again.
|
623 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
624 |
+
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
625 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
626 |
+
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
627 |
+
strength (`float`, *optional*, defaults to 0.6):
|
628 |
+
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
|
629 |
+
starting point and more noise is added the higher the `strength`. The number of denoising steps depends
|
630 |
+
on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
|
631 |
+
process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
|
632 |
+
essentially ignores `image`.
|
633 |
+
num_inference_steps (`int`, *optional*, defaults to 28):
|
634 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
635 |
+
expense of slower inference.
|
636 |
+
timesteps (`List[int]`, *optional*):
|
637 |
+
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
638 |
+
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
639 |
+
passed will be used. Must be in descending order.
|
640 |
+
guidance_scale (`float`, *optional*, defaults to 7.0):
|
641 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
642 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
643 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
644 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
645 |
+
usually at the expense of lower image quality.
|
646 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
647 |
+
The number of images to generate per prompt.
|
648 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
649 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
650 |
+
to make generation deterministic.
|
651 |
+
latents (`torch.FloatTensor`, *optional*):
|
652 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
653 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
654 |
+
tensor will be generated by sampling using the supplied random `generator`.
|
655 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
656 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
657 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
658 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
659 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
660 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
661 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
662 |
+
The output format of the generated image. Choose between
|
663 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
664 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
665 |
+
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
|
666 |
+
joint_attention_kwargs (`dict`, *optional*):
|
667 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
668 |
+
`self.processor` in
|
669 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
670 |
+
callback_on_step_end (`Callable`, *optional*):
|
671 |
+
A function that is called at the end of each denoising step during inference. The function is called
|
672 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
673 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
674 |
+
`callback_on_step_end_tensor_inputs`.
|
675 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
676 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
677 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
678 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
679 |
+
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
|
680 |
+
|
681 |
+
Examples:
|
682 |
+
|
683 |
+
Returns:
|
684 |
+
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
|
685 |
+
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
|
686 |
+
images.
|
687 |
+
"""
|
688 |
+
|
689 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
690 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
691 |
+
|
692 |
+
# 1. Check inputs. Raise error if not correct
|
693 |
+
self.check_inputs(
|
694 |
+
prompt,
|
695 |
+
prompt_2,
|
696 |
+
strength,
|
697 |
+
height,
|
698 |
+
width,
|
699 |
+
prompt_embeds=prompt_embeds,
|
700 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
701 |
+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
702 |
+
max_sequence_length=max_sequence_length,
|
703 |
+
)
|
704 |
+
|
705 |
+
self._guidance_scale = guidance_scale
|
706 |
+
self._joint_attention_kwargs = joint_attention_kwargs
|
707 |
+
self._interrupt = False
|
708 |
+
|
709 |
+
# 2. Preprocess image
|
710 |
+
init_image = self.image_processor.preprocess(image, height=height, width=width)
|
711 |
+
init_image = init_image.to(dtype=torch.float32)
|
712 |
+
|
713 |
+
# 3. Define call parameters
|
714 |
+
if prompt is not None and isinstance(prompt, str):
|
715 |
+
batch_size = 1
|
716 |
+
elif prompt is not None and isinstance(prompt, list):
|
717 |
+
batch_size = len(prompt)
|
718 |
+
else:
|
719 |
+
batch_size = prompt_embeds.shape[0]
|
720 |
+
|
721 |
+
device = self._execution_device
|
722 |
+
|
723 |
+
lora_scale = (
|
724 |
+
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
|
725 |
+
)
|
726 |
+
(
|
727 |
+
prompt_embeds,
|
728 |
+
pooled_prompt_embeds,
|
729 |
+
text_ids,
|
730 |
+
) = self.encode_prompt(
|
731 |
+
prompt=prompt,
|
732 |
+
prompt_2=prompt_2,
|
733 |
+
prompt_embeds=prompt_embeds,
|
734 |
+
t5_prompt_embeds=t5_prompt_embeds,
|
735 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
736 |
+
device=device,
|
737 |
+
num_images_per_prompt=num_images_per_prompt,
|
738 |
+
max_sequence_length=max_sequence_length,
|
739 |
+
lora_scale=lora_scale,
|
740 |
+
)
|
741 |
+
|
742 |
+
# 4. Prepare timesteps
|
743 |
+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
744 |
+
image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor)
|
745 |
+
mu = calculate_shift(
|
746 |
+
image_seq_len,
|
747 |
+
self.scheduler.config.base_image_seq_len,
|
748 |
+
self.scheduler.config.max_image_seq_len,
|
749 |
+
self.scheduler.config.base_shift,
|
750 |
+
self.scheduler.config.max_shift,
|
751 |
+
)
|
752 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
753 |
+
self.scheduler,
|
754 |
+
num_inference_steps,
|
755 |
+
device,
|
756 |
+
timesteps,
|
757 |
+
sigmas,
|
758 |
+
mu=mu,
|
759 |
+
)
|
760 |
+
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
|
761 |
+
|
762 |
+
if num_inference_steps < 1:
|
763 |
+
raise ValueError(
|
764 |
+
f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline "
|
765 |
+
f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
|
766 |
+
)
|
767 |
+
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
768 |
+
|
769 |
+
# 5. Prepare latent variables
|
770 |
+
num_channels_latents = self.transformer.config.in_channels // 4
|
771 |
+
|
772 |
+
latents, latent_image_ids = self.prepare_latents(
|
773 |
+
init_image,
|
774 |
+
latent_timestep,
|
775 |
+
batch_size * num_images_per_prompt,
|
776 |
+
num_channels_latents,
|
777 |
+
height,
|
778 |
+
width,
|
779 |
+
prompt_embeds.dtype,
|
780 |
+
device,
|
781 |
+
generator,
|
782 |
+
latents,
|
783 |
+
)
|
784 |
+
|
785 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
786 |
+
self._num_timesteps = len(timesteps)
|
787 |
+
|
788 |
+
# handle guidance
|
789 |
+
if self.transformer.config.guidance_embeds:
|
790 |
+
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
|
791 |
+
guidance = guidance.expand(latents.shape[0])
|
792 |
+
else:
|
793 |
+
guidance = None
|
794 |
+
|
795 |
+
# 6. Denoising loop
|
796 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
797 |
+
for i, t in enumerate(timesteps):
|
798 |
+
if self.interrupt:
|
799 |
+
continue
|
800 |
+
|
801 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
802 |
+
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
803 |
+
noise_pred = self.transformer(
|
804 |
+
hidden_states=latents,
|
805 |
+
timestep=timestep / 1000,
|
806 |
+
guidance=guidance,
|
807 |
+
pooled_projections=pooled_prompt_embeds,
|
808 |
+
encoder_hidden_states=prompt_embeds,
|
809 |
+
t5_encoder_hidden_states=t5_prompt_embeds,
|
810 |
+
txt_ids=text_ids,
|
811 |
+
img_ids=latent_image_ids,
|
812 |
+
joint_attention_kwargs=self.joint_attention_kwargs,
|
813 |
+
return_dict=False,
|
814 |
+
)[0]
|
815 |
+
|
816 |
+
# compute the previous noisy sample x_t -> x_t-1
|
817 |
+
latents_dtype = latents.dtype
|
818 |
+
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
819 |
+
|
820 |
+
if latents.dtype != latents_dtype:
|
821 |
+
if torch.backends.mps.is_available():
|
822 |
+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
823 |
+
latents = latents.to(latents_dtype)
|
824 |
+
|
825 |
+
if callback_on_step_end is not None:
|
826 |
+
callback_kwargs = {}
|
827 |
+
for k in callback_on_step_end_tensor_inputs:
|
828 |
+
callback_kwargs[k] = locals()[k]
|
829 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
830 |
+
|
831 |
+
latents = callback_outputs.pop("latents", latents)
|
832 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
833 |
+
|
834 |
+
# call the callback, if provided
|
835 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
836 |
+
progress_bar.update()
|
837 |
+
|
838 |
+
if XLA_AVAILABLE:
|
839 |
+
xm.mark_step()
|
840 |
+
|
841 |
+
if output_type == "latent":
|
842 |
+
image = latents
|
843 |
+
|
844 |
+
else:
|
845 |
+
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
846 |
+
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
847 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
848 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
849 |
+
|
850 |
+
# Offload all models
|
851 |
+
self.maybe_free_model_hooks()
|
852 |
+
|
853 |
+
if not return_dict:
|
854 |
+
return (image,)
|
855 |
+
|
856 |
+
return FluxPipelineOutput(images=image)
|
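Since this variant registers the pipeline without `text_encoder_2`/`tokenizer_2`, the intended use is to pass pre-computed embeddings straight into `__call__`. A hedged sketch under that assumption; `pipe`, `init_image` and the embedding tensors are placeholders for whatever the surrounding app actually produces (random tensors stand in purely to illustrate the call signature; real shapes depend on the encoders used):

```py
import torch

device, dtype = "cuda", torch.bfloat16
# Placeholder embeddings; in practice these come from external encoders.
prompt_embeds        = torch.randn(1, 256, 4096, device=device, dtype=dtype)
t5_prompt_embeds     = torch.randn(1, 512, 4096, device=device, dtype=dtype)
pooled_prompt_embeds = torch.randn(1, 768, device=device, dtype=dtype)

image = pipe(  # `pipe`: an already-constructed FluxImg2ImgPipeline from this file
    prompt=None,
    image=init_image,
    prompt_embeds=prompt_embeds,
    t5_prompt_embeds=t5_prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    strength=0.9,
    num_inference_steps=28,
    guidance_scale=3.5,
).images[0]
```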
flux/pipeline_flux_inpaint.py
ADDED
@@ -0,0 +1,1021 @@
1 |
+
# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import inspect
|
16 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
17 |
+
|
18 |
+
import numpy as np
|
19 |
+
import PIL.Image
|
20 |
+
import torch
|
21 |
+
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
|
22 |
+
|
23 |
+
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
24 |
+
from .lora.lora_pipeline import FluxLoraLoaderMixin
|
25 |
+
from diffusers.models.autoencoders import AutoencoderKL
|
26 |
+
from .transformer_flux import FluxTransformer2DModel
|
27 |
+
from .scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
|
28 |
+
from diffusers.utils import (
|
29 |
+
USE_PEFT_BACKEND,
|
30 |
+
is_torch_xla_available,
|
31 |
+
logging,
|
32 |
+
replace_example_docstring,
|
33 |
+
scale_lora_layers,
|
34 |
+
unscale_lora_layers,
|
35 |
+
)
|
36 |
+
from diffusers.utils.torch_utils import randn_tensor
|
37 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
38 |
+
from .pipeline_output import FluxPipelineOutput
|
39 |
+
|
40 |
+
|
41 |
+
if is_torch_xla_available():
|
42 |
+
import torch_xla.core.xla_model as xm
|
43 |
+
|
44 |
+
XLA_AVAILABLE = True
|
45 |
+
else:
|
46 |
+
XLA_AVAILABLE = False
|
47 |
+
|
48 |
+
|
49 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
50 |
+
|
51 |
+
EXAMPLE_DOC_STRING = """
|
52 |
+
Examples:
|
53 |
+
```py
|
54 |
+
>>> import torch
|
55 |
+
>>> from diffusers import FluxInpaintPipeline
|
56 |
+
>>> from diffusers.utils import load_image
|
57 |
+
|
58 |
+
>>> pipe = FluxInpaintPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
|
59 |
+
>>> pipe.to("cuda")
|
60 |
+
>>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
|
61 |
+
>>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
|
62 |
+
>>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
|
63 |
+
>>> source = load_image(img_url)
|
64 |
+
>>> mask = load_image(mask_url)
|
65 |
+
>>> image = pipe(prompt=prompt, image=source, mask_image=mask).images[0]
|
66 |
+
>>> image.save("flux_inpainting.png")
|
67 |
+
```
|
68 |
+
"""
|
69 |
+
|
70 |
+
|
71 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
|
72 |
+
def calculate_shift(
|
73 |
+
image_seq_len,
|
74 |
+
base_seq_len: int = 256,
|
75 |
+
max_seq_len: int = 4096,
|
76 |
+
base_shift: float = 0.5,
|
77 |
+
max_shift: float = 1.16,
|
78 |
+
):
|
79 |
+
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
80 |
+
b = base_shift - m * base_seq_len
|
81 |
+
mu = image_seq_len * m + b
|
82 |
+
return mu
|
83 |
+
|
84 |
+
|
85 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
86 |
+
def retrieve_latents(
|
87 |
+
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
88 |
+
):
|
89 |
+
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
90 |
+
return encoder_output.latent_dist.sample(generator)
|
91 |
+
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
92 |
+
return encoder_output.latent_dist.mode()
|
93 |
+
elif hasattr(encoder_output, "latents"):
|
94 |
+
return encoder_output.latents
|
95 |
+
else:
|
96 |
+
raise AttributeError("Could not access latents of provided encoder_output")
|
97 |
+
|
98 |
+
|
99 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
100 |
+
def retrieve_timesteps(
|
101 |
+
scheduler,
|
102 |
+
num_inference_steps: Optional[int] = None,
|
103 |
+
device: Optional[Union[str, torch.device]] = None,
|
104 |
+
timesteps: Optional[List[int]] = None,
|
105 |
+
sigmas: Optional[List[float]] = None,
|
106 |
+
**kwargs,
|
107 |
+
):
|
108 |
+
"""
|
109 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
110 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
111 |
+
|
112 |
+
Args:
|
113 |
+
scheduler (`SchedulerMixin`):
|
114 |
+
The scheduler to get timesteps from.
|
115 |
+
num_inference_steps (`int`):
|
116 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
117 |
+
must be `None`.
|
118 |
+
device (`str` or `torch.device`, *optional*):
|
119 |
+
The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
|
120 |
+
timesteps (`List[int]`, *optional*):
|
121 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
122 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
123 |
+
sigmas (`List[float]`, *optional*):
|
124 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
125 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
126 |
+
|
127 |
+
Returns:
|
128 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
129 |
+
second element is the number of inference steps.
|
130 |
+
"""
|
131 |
+
if timesteps is not None and sigmas is not None:
|
132 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
133 |
+
if timesteps is not None:
|
134 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
135 |
+
if not accepts_timesteps:
|
136 |
+
raise ValueError(
|
137 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
138 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
139 |
+
)
|
140 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
141 |
+
timesteps = scheduler.timesteps
|
142 |
+
num_inference_steps = len(timesteps)
|
143 |
+
elif sigmas is not None:
|
144 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
145 |
+
if not accept_sigmas:
|
146 |
+
raise ValueError(
|
147 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
148 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
149 |
+
)
|
150 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
151 |
+
timesteps = scheduler.timesteps
|
152 |
+
num_inference_steps = len(timesteps)
|
153 |
+
else:
|
154 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
155 |
+
timesteps = scheduler.timesteps
|
156 |
+
return timesteps, num_inference_steps
|
157 |
+
|
158 |
+
|
159 |
+
class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
160 |
+
r"""
|
161 |
+
The Flux pipeline for image inpainting.
|
162 |
+
|
163 |
+
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
|
164 |
+
|
165 |
+
Args:
|
166 |
+
transformer ([`FluxTransformer2DModel`]):
|
167 |
+
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
168 |
+
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
169 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
170 |
+
vae ([`AutoencoderKL`]):
|
171 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
172 |
+
text_encoder ([`CLIPTextModel`]):
|
173 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
174 |
+
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
175 |
+
text_encoder_2 ([`T5EncoderModel`]):
|
176 |
+
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
177 |
+
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
|
178 |
+
tokenizer (`CLIPTokenizer`):
|
179 |
+
Tokenizer of class
|
180 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
|
181 |
+
tokenizer_2 (`T5TokenizerFast`):
|
182 |
+
Second Tokenizer of class
|
183 |
+
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
|
184 |
+
"""
|
185 |
+
|
186 |
+
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
|
187 |
+
_optional_components = []
|
188 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
189 |
+
|
190 |
+
def __init__(
|
191 |
+
self,
|
192 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
193 |
+
vae: AutoencoderKL,
|
194 |
+
text_encoder: CLIPTextModel,
|
195 |
+
tokenizer: CLIPTokenizer,
|
196 |
+
transformer: FluxTransformer2DModel,
|
197 |
+
text_encoder_2: T5EncoderModel | None = None,
|
198 |
+
tokenizer_2: T5TokenizerFast | None = None,
|
199 |
+
):
|
200 |
+
super().__init__()
|
201 |
+
|
202 |
+
self.register_modules(
|
203 |
+
vae=vae,
|
204 |
+
text_encoder=text_encoder,
|
205 |
+
#text_encoder_2=text_encoder_2,
|
206 |
+
tokenizer=tokenizer,
|
207 |
+
#tokenizer_2=tokenizer_2,
|
208 |
+
transformer=transformer,
|
209 |
+
scheduler=scheduler,
|
210 |
+
)
|
211 |
+
self.vae_scale_factor = (
|
212 |
+
2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
|
213 |
+
)
|
214 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
215 |
+
self.mask_processor = VaeImageProcessor(
|
216 |
+
vae_scale_factor=self.vae_scale_factor,
|
217 |
+
vae_latent_channels=self.vae.config.latent_channels,
|
218 |
+
do_normalize=False,
|
219 |
+
do_binarize=True,
|
220 |
+
do_convert_grayscale=True,
|
221 |
+
)
|
222 |
+
self.tokenizer_max_length = (
|
223 |
+
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
|
224 |
+
)
|
225 |
+
self.default_sample_size = 64
|
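An illustrative note on the `mask_processor` configured above (sketch only; it relies on the `do_normalize` / `do_binarize` / `do_convert_grayscale` options of `VaeImageProcessor`): masks are converted to single-channel tensors in [0, 1] and thresholded rather than normalized to [-1, 1].

```py
import numpy as np
import PIL.Image
from diffusers.image_processor import VaeImageProcessor

mask_processor = VaeImageProcessor(
    vae_scale_factor=16, do_normalize=False, do_binarize=True, do_convert_grayscale=True
)
mask = PIL.Image.fromarray((np.random.rand(1024, 1024) * 255).astype("uint8"))
mask_t = mask_processor.preprocess(mask, height=1024, width=1024)
print(mask_t.shape, mask_t.unique())  # torch.Size([1, 1, 1024, 1024]) tensor([0., 1.])
```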
226 |
+
|
227 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
|
228 |
+
def _get_t5_prompt_embeds(
|
229 |
+
self,
|
230 |
+
prompt: Union[str, List[str]] = None,
|
231 |
+
num_images_per_prompt: int = 1,
|
232 |
+
max_sequence_length: int = 512,
|
233 |
+
device: Optional[torch.device] = None,
|
234 |
+
dtype: Optional[torch.dtype] = None,
|
235 |
+
):
|
236 |
+
device = device or self._execution_device
|
237 |
+
dtype = dtype or self.text_encoder.dtype
|
238 |
+
|
239 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
240 |
+
batch_size = len(prompt)
|
241 |
+
|
242 |
+
text_inputs = self.tokenizer_2(
|
243 |
+
prompt,
|
244 |
+
padding="max_length",
|
245 |
+
max_length=max_sequence_length,
|
246 |
+
truncation=True,
|
247 |
+
return_length=False,
|
248 |
+
return_overflowing_tokens=False,
|
249 |
+
return_tensors="pt",
|
250 |
+
)
|
251 |
+
text_input_ids = text_inputs.input_ids
|
252 |
+
untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
|
253 |
+
|
254 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
255 |
+
removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
256 |
+
logger.warning(
|
257 |
+
"The following part of your input was truncated because `max_sequence_length` is set to "
|
258 |
+
f" {max_sequence_length} tokens: {removed_text}"
|
259 |
+
)
|
260 |
+
|
261 |
+
prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
|
262 |
+
|
263 |
+
dtype = self.text_encoder_2.dtype
|
264 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
265 |
+
|
266 |
+
_, seq_len, _ = prompt_embeds.shape
|
267 |
+
|
268 |
+
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
269 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
270 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
271 |
+
|
272 |
+
return prompt_embeds
|
273 |
+
|
274 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
|
275 |
+
def _get_clip_prompt_embeds(
|
276 |
+
self,
|
277 |
+
prompt: Union[str, List[str]],
|
278 |
+
num_images_per_prompt: int = 1,
|
279 |
+
device: Optional[torch.device] = None,
|
280 |
+
):
|
281 |
+
device = device or self._execution_device
|
282 |
+
|
283 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
284 |
+
batch_size = len(prompt)
|
285 |
+
|
286 |
+
text_inputs = self.tokenizer(
|
287 |
+
prompt,
|
288 |
+
padding="max_length",
|
289 |
+
max_length=self.tokenizer_max_length,
|
290 |
+
truncation=True,
|
291 |
+
return_overflowing_tokens=False,
|
292 |
+
return_length=False,
|
293 |
+
return_tensors="pt",
|
294 |
+
)
|
295 |
+
|
296 |
+
text_input_ids = text_inputs.input_ids
|
297 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
298 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
299 |
+
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
300 |
+
logger.warning(
|
301 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
302 |
+
f" {self.tokenizer_max_length} tokens: {removed_text}"
|
303 |
+
)
|
304 |
+
prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
|
305 |
+
|
306 |
+
# Use pooled output of CLIPTextModel
|
307 |
+
prompt_embeds = prompt_embeds.pooler_output
|
308 |
+
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
309 |
+
|
310 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
311 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
|
312 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
313 |
+
|
314 |
+
return prompt_embeds
|
315 |
+
|
316 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
|
317 |
+
def encode_prompt(
|
318 |
+
self,
|
319 |
+
prompt: Union[str, List[str]],
|
320 |
+
prompt_2: Union[str, List[str]],
|
321 |
+
device: Optional[torch.device] = None,
|
322 |
+
num_images_per_prompt: int = 1,
|
323 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
324 |
+
t5_prompt_embeds: Optional[torch.FloatTensor] = None,
|
325 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
326 |
+
max_sequence_length: int = 512,
|
327 |
+
lora_scale: Optional[float] = None,
|
328 |
+
):
|
329 |
+
r"""
|
330 |
+
|
331 |
+
Args:
|
332 |
+
prompt (`str` or `List[str]`, *optional*):
|
333 |
+
prompt to be encoded
|
334 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
335 |
+
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
336 |
+
used in all text-encoders
|
337 |
+
device: (`torch.device`):
|
338 |
+
torch device
|
339 |
+
num_images_per_prompt (`int`):
|
340 |
+
number of images that should be generated per prompt
|
341 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
342 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
343 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
344 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
345 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
346 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
347 |
+
lora_scale (`float`, *optional*):
|
348 |
+
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
349 |
+
"""
|
350 |
+
device = device or self._execution_device
|
351 |
+
|
352 |
+
# set lora scale so that monkey patched LoRA
|
353 |
+
# function of text encoder can correctly access it
|
354 |
+
if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
|
355 |
+
self._lora_scale = lora_scale
|
356 |
+
|
357 |
+
# dynamically adjust the LoRA scale
|
358 |
+
if self.text_encoder is not None and USE_PEFT_BACKEND:
|
359 |
+
scale_lora_layers(self.text_encoder, lora_scale)
|
360 |
+
if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
|
361 |
+
scale_lora_layers(self.text_encoder_2, lora_scale)
|
362 |
+
|
363 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
364 |
+
if prompt is not None:
|
365 |
+
batch_size = len(prompt)
|
366 |
+
else:
|
367 |
+
batch_size = prompt_embeds.shape[0]
|
368 |
+
|
369 |
+
if prompt_embeds is None:
|
370 |
+
prompt_2 = prompt_2 or prompt
|
371 |
+
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
|
372 |
+
|
373 |
+
# We only use the pooled prompt output from the CLIPTextModel
|
374 |
+
pooled_prompt_embeds = self._get_clip_prompt_embeds(
|
375 |
+
prompt=prompt,
|
376 |
+
device=device,
|
377 |
+
num_images_per_prompt=num_images_per_prompt,
|
378 |
+
)
|
379 |
+
prompt_embeds = self._get_t5_prompt_embeds(
|
380 |
+
prompt=prompt_2,
|
381 |
+
num_images_per_prompt=num_images_per_prompt,
|
382 |
+
max_sequence_length=max_sequence_length,
|
383 |
+
device=device,
|
384 |
+
)
|
385 |
+
|
386 |
+
if self.text_encoder is not None:
|
387 |
+
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
|
388 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
389 |
+
unscale_lora_layers(self.text_encoder, lora_scale)
|
390 |
+
|
391 |
+
#if self.text_encoder_2 is not None:
|
392 |
+
# if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
|
393 |
+
# # Retrieve the original scale by scaling back the LoRA layers
|
394 |
+
# unscale_lora_layers(self.text_encoder_2, lora_scale)
|
395 |
+
|
396 |
+
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
|
397 |
+
if t5_prompt_embeds is not None:
|
398 |
+
text_ids = torch.zeros(batch_size, prompt_embeds.shape[1] + t5_prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
399 |
+
else:
|
400 |
+
text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
401 |
+
#text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
402 |
+
|
403 |
+
return prompt_embeds, pooled_prompt_embeds, text_ids
|
404 |
+
|
405 |
+
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
|
406 |
+
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
|
407 |
+
if isinstance(generator, list):
|
408 |
+
image_latents = [
|
409 |
+
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
|
410 |
+
for i in range(image.shape[0])
|
411 |
+
]
|
412 |
+
image_latents = torch.cat(image_latents, dim=0)
|
413 |
+
else:
|
414 |
+
image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
|
415 |
+
|
416 |
+
image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
417 |
+
|
418 |
+
return image_latents
|
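The shift/scale applied to the encoded latents above is undone again right before VAE decoding at the end of `__call__`; a toy round-trip, with illustrative values standing in for `vae.config`:

# Toy round-trip of the latent normalization above (shift/scale values are
# illustrative placeholders for vae.config.shift_factor / scaling_factor).
import torch

shift_factor, scaling_factor = 0.12, 0.36
raw_latents = torch.randn(1, 16, 64, 64)

normalized = (raw_latents - shift_factor) * scaling_factor     # encode side
recovered = normalized / scaling_factor + shift_factor         # decode side
assert torch.allclose(recovered, raw_latents, atol=1e-6)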
419 |
+
|
420 |
+
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
|
421 |
+
def get_timesteps(self, num_inference_steps, strength, device):
|
422 |
+
# get the original timestep using init_timestep
|
423 |
+
init_timestep = min(num_inference_steps * strength, num_inference_steps)
|
424 |
+
|
425 |
+
t_start = int(max(num_inference_steps - init_timestep, 0))
|
426 |
+
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
427 |
+
if hasattr(self.scheduler, "set_begin_index"):
|
428 |
+
self.scheduler.set_begin_index(t_start * self.scheduler.order)
|
429 |
+
|
430 |
+
return timesteps, num_inference_steps - t_start
|
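A quick check of the strength arithmetic above, using this pipeline's own defaults (`num_inference_steps=28`, `strength=0.6`) and a first-order scheduler:

# Worked example of the strength-based timestep truncation above.
num_inference_steps, strength = 28, 0.6

init_timestep = min(num_inference_steps * strength, num_inference_steps)   # 16.8
t_start = int(max(num_inference_steps - init_timestep, 0))                 # 11
remaining = num_inference_steps - t_start
print(t_start, remaining)   # 11 17 -> the first 11 timesteps are skipped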
431 |
+
|
432 |
+
def check_inputs(
|
433 |
+
self,
|
434 |
+
prompt,
|
435 |
+
prompt_2,
|
436 |
+
image,
|
437 |
+
mask_image,
|
438 |
+
strength,
|
439 |
+
height,
|
440 |
+
width,
|
441 |
+
output_type,
|
442 |
+
prompt_embeds=None,
|
443 |
+
pooled_prompt_embeds=None,
|
444 |
+
callback_on_step_end_tensor_inputs=None,
|
445 |
+
padding_mask_crop=None,
|
446 |
+
max_sequence_length=None,
|
447 |
+
):
|
448 |
+
if strength < 0 or strength > 1:
|
449 |
+
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
450 |
+
|
451 |
+
if height % 8 != 0 or width % 8 != 0:
|
452 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
453 |
+
|
454 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
455 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
456 |
+
):
|
457 |
+
raise ValueError(
|
458 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
459 |
+
)
|
460 |
+
|
461 |
+
if prompt is not None and prompt_embeds is not None:
|
462 |
+
raise ValueError(
|
463 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
464 |
+
" only forward one of the two."
|
465 |
+
)
|
466 |
+
elif prompt_2 is not None and prompt_embeds is not None:
|
467 |
+
raise ValueError(
|
468 |
+
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
469 |
+
" only forward one of the two."
|
470 |
+
)
|
471 |
+
elif prompt is None and prompt_embeds is None:
|
472 |
+
raise ValueError(
|
473 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
474 |
+
)
|
475 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
476 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
477 |
+
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
|
478 |
+
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
|
479 |
+
|
480 |
+
if prompt_embeds is not None and pooled_prompt_embeds is None:
|
481 |
+
raise ValueError(
|
482 |
+
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
483 |
+
)
|
484 |
+
|
485 |
+
if padding_mask_crop is not None:
|
486 |
+
if not isinstance(image, PIL.Image.Image):
|
487 |
+
raise ValueError(
|
488 |
+
f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
|
489 |
+
)
|
490 |
+
if not isinstance(mask_image, PIL.Image.Image):
|
491 |
+
raise ValueError(
|
492 |
+
f"The mask image should be a PIL image when inpainting mask crop, but is of type"
|
493 |
+
f" {type(mask_image)}."
|
494 |
+
)
|
495 |
+
if output_type != "pil":
|
496 |
+
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
|
497 |
+
|
498 |
+
if max_sequence_length is not None and max_sequence_length > 512:
|
499 |
+
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
|
500 |
+
|
501 |
+
@staticmethod
|
502 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
|
503 |
+
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
504 |
+
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
|
505 |
+
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
|
506 |
+
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
|
507 |
+
|
508 |
+
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
509 |
+
|
510 |
+
latent_image_ids = latent_image_ids.reshape(
|
511 |
+
latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
512 |
+
)
|
513 |
+
|
514 |
+
return latent_image_ids.to(device=device, dtype=dtype)
|
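The id tensor built above is just a flattened (unused, row, column) index per 2x2 latent patch, which the transformer's rotary embedding consumes as image positions; a small sketch with an assumed 64x64 latent grid:

# Sketch of the latent image ids above (64x64 latent grid assumed).
import torch

height, width = 64, 64                               # latent size before packing
ids = torch.zeros(height // 2, width // 2, 3)
ids[..., 1] += torch.arange(height // 2)[:, None]    # row of each 2x2 patch
ids[..., 2] += torch.arange(width // 2)[None, :]     # column of each 2x2 patch
ids = ids.reshape(-1, 3)
print(ids.shape)   # torch.Size([1024, 3])
print(ids[:3])     # tensor([[0., 0., 0.], [0., 0., 1.], [0., 0., 2.]])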
515 |
+
|
516 |
+
@staticmethod
|
517 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
|
518 |
+
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
519 |
+
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
520 |
+
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
521 |
+
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
|
522 |
+
|
523 |
+
return latents
|
524 |
+
|
525 |
+
@staticmethod
|
526 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
|
527 |
+
def _unpack_latents(latents, height, width, vae_scale_factor):
|
528 |
+
batch_size, num_patches, channels = latents.shape
|
529 |
+
|
530 |
+
height = height // vae_scale_factor
|
531 |
+
width = width // vae_scale_factor
|
532 |
+
|
533 |
+
latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
|
534 |
+
latents = latents.permute(0, 3, 1, 4, 2, 5)
|
535 |
+
|
536 |
+
latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
|
537 |
+
|
538 |
+
return latents
|
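`_pack_latents` and `_unpack_latents` are exact inverses: packing turns a (B, C, H, W) latent into a sequence of 2x2 patches of width 4*C, and unpacking reverses it. A standalone round-trip check with assumed sizes:

# Standalone round-trip of the 2x2 packing above (sizes assumed).
import torch

batch, channels, height, width = 1, 16, 64, 64
latents = torch.randn(batch, channels, height, width)

packed = latents.view(batch, channels, height // 2, 2, width // 2, 2)
packed = packed.permute(0, 2, 4, 1, 3, 5)
packed = packed.reshape(batch, (height // 2) * (width // 2), channels * 4)
print(packed.shape)   # torch.Size([1, 1024, 64])

unpacked = packed.view(batch, height // 2, width // 2, channels, 2, 2)
unpacked = unpacked.permute(0, 3, 1, 4, 2, 5).reshape(batch, channels, height, width)
assert torch.equal(unpacked, latents)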
539 |
+
|
540 |
+
def prepare_latents(
|
541 |
+
self,
|
542 |
+
image,
|
543 |
+
timestep,
|
544 |
+
batch_size,
|
545 |
+
num_channels_latents,
|
546 |
+
height,
|
547 |
+
width,
|
548 |
+
dtype,
|
549 |
+
device,
|
550 |
+
generator,
|
551 |
+
latents=None,
|
552 |
+
):
|
553 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
554 |
+
raise ValueError(
|
555 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
556 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
557 |
+
)
|
558 |
+
|
559 |
+
height = 2 * (int(height) // self.vae_scale_factor)
|
560 |
+
width = 2 * (int(width) // self.vae_scale_factor)
|
561 |
+
|
562 |
+
shape = (batch_size, num_channels_latents, height, width)
|
563 |
+
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
|
564 |
+
|
565 |
+
image = image.to(device=device, dtype=dtype)
|
566 |
+
image_latents = self._encode_vae_image(image=image, generator=generator)
|
567 |
+
|
568 |
+
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
|
569 |
+
# expand init_latents for batch_size
|
570 |
+
additional_image_per_prompt = batch_size // image_latents.shape[0]
|
571 |
+
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
|
572 |
+
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
|
573 |
+
raise ValueError(
|
574 |
+
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
|
575 |
+
)
|
576 |
+
else:
|
577 |
+
image_latents = torch.cat([image_latents], dim=0)
|
578 |
+
|
579 |
+
if latents is None:
|
580 |
+
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
581 |
+
latents = self.scheduler.scale_noise(image_latents, timestep, noise)
|
582 |
+
else:
|
583 |
+
noise = latents.to(device)
|
584 |
+
latents = noise
|
585 |
+
|
586 |
+
noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width)
|
587 |
+
image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width)
|
588 |
+
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
|
589 |
+
return latents, noise, image_latents, latent_image_ids
|
590 |
+
|
591 |
+
def prepare_mask_latents(
|
592 |
+
self,
|
593 |
+
mask,
|
594 |
+
masked_image,
|
595 |
+
batch_size,
|
596 |
+
num_channels_latents,
|
597 |
+
num_images_per_prompt,
|
598 |
+
height,
|
599 |
+
width,
|
600 |
+
dtype,
|
601 |
+
device,
|
602 |
+
generator,
|
603 |
+
):
|
604 |
+
height = 2 * (int(height) // self.vae_scale_factor)
|
605 |
+
width = 2 * (int(width) // self.vae_scale_factor)
|
606 |
+
# resize the mask to latents shape as we concatenate the mask to the latents
|
607 |
+
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
|
608 |
+
# and half precision
|
609 |
+
mask = torch.nn.functional.interpolate(mask, size=(height, width))
|
610 |
+
mask = mask.to(device=device, dtype=dtype)
|
611 |
+
|
612 |
+
batch_size = batch_size * num_images_per_prompt
|
613 |
+
|
614 |
+
masked_image = masked_image.to(device=device, dtype=dtype)
|
615 |
+
|
616 |
+
if masked_image.shape[1] == 16:
|
617 |
+
masked_image_latents = masked_image
|
618 |
+
else:
|
619 |
+
masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)
|
620 |
+
|
621 |
+
masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
622 |
+
|
623 |
+
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
|
624 |
+
if mask.shape[0] < batch_size:
|
625 |
+
if not batch_size % mask.shape[0] == 0:
|
626 |
+
raise ValueError(
|
627 |
+
"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
|
628 |
+
f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
|
629 |
+
" of masks that you pass is divisible by the total requested batch size."
|
630 |
+
)
|
631 |
+
mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
|
632 |
+
if masked_image_latents.shape[0] < batch_size:
|
633 |
+
if not batch_size % masked_image_latents.shape[0] == 0:
|
634 |
+
raise ValueError(
|
635 |
+
"The passed images and the required batch size don't match. Images are supposed to be duplicated"
|
636 |
+
f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
|
637 |
+
" Make sure the number of images that you pass is divisible by the total requested batch size."
|
638 |
+
)
|
639 |
+
masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
|
640 |
+
|
641 |
+
# aligning device to prevent device errors when concating it with the latent model input
|
642 |
+
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
|
643 |
+
|
644 |
+
masked_image_latents = self._pack_latents(
|
645 |
+
masked_image_latents,
|
646 |
+
batch_size,
|
647 |
+
num_channels_latents,
|
648 |
+
height,
|
649 |
+
width,
|
650 |
+
)
|
651 |
+
mask = self._pack_latents(
|
652 |
+
mask.repeat(1, num_channels_latents, 1, 1),
|
653 |
+
batch_size,
|
654 |
+
num_channels_latents,
|
655 |
+
height,
|
656 |
+
width,
|
657 |
+
)
|
658 |
+
|
659 |
+
return mask, masked_image_latents
|
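The single-channel mask above is tiled across the latent channels and packed with the same 2x2 scheme as the image latents, so `(1 - mask)` can gate the packed sequence element-wise during the denoising loop; a shape sketch with assumed sizes:

# Shape sketch of the mask packing above (sizes assumed; not the pipeline code).
import torch

batch, num_channels_latents, height, width = 1, 16, 64, 64
mask = (torch.rand(batch, 1, height, width) > 0.5).float()

mask = mask.repeat(1, num_channels_latents, 1, 1)              # (1, 16, 64, 64)
mask = mask.view(batch, num_channels_latents, height // 2, 2, width // 2, 2)
mask = mask.permute(0, 2, 4, 1, 3, 5)
mask = mask.reshape(batch, (height // 2) * (width // 2), num_channels_latents * 4)
print(mask.shape)   # torch.Size([1, 1024, 64]) -> same layout as the packed latents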
660 |
+
|
661 |
+
@property
|
662 |
+
def guidance_scale(self):
|
663 |
+
return self._guidance_scale
|
664 |
+
|
665 |
+
@property
|
666 |
+
def joint_attention_kwargs(self):
|
667 |
+
return self._joint_attention_kwargs
|
668 |
+
|
669 |
+
@property
|
670 |
+
def num_timesteps(self):
|
671 |
+
return self._num_timesteps
|
672 |
+
|
673 |
+
@property
|
674 |
+
def interrupt(self):
|
675 |
+
return self._interrupt
|
676 |
+
|
677 |
+
@torch.no_grad()
|
678 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
679 |
+
def __call__(
|
680 |
+
self,
|
681 |
+
prompt: Union[str, List[str]] = None,
|
682 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
683 |
+
image: PipelineImageInput = None,
|
684 |
+
mask_image: PipelineImageInput = None,
|
685 |
+
masked_image_latents: PipelineImageInput = None,
|
686 |
+
height: Optional[int] = None,
|
687 |
+
width: Optional[int] = None,
|
688 |
+
padding_mask_crop: Optional[int] = None,
|
689 |
+
strength: float = 0.6,
|
690 |
+
num_inference_steps: int = 28,
|
691 |
+
timesteps: List[int] = None,
|
692 |
+
guidance_scale: float = 7.0,
|
693 |
+
num_images_per_prompt: Optional[int] = 1,
|
694 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
695 |
+
latents: Optional[torch.FloatTensor] = None,
|
696 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
697 |
+
t5_prompt_embeds: Optional[torch.FloatTensor] = None,
|
698 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
699 |
+
output_type: Optional[str] = "pil",
|
700 |
+
return_dict: bool = True,
|
701 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
702 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
703 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
704 |
+
max_sequence_length: int = 512,
|
705 |
+
):
|
706 |
+
r"""
|
707 |
+
Function invoked when calling the pipeline for generation.
|
708 |
+
|
709 |
+
Args:
|
710 |
+
prompt (`str` or `List[str]`, *optional*):
|
711 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
|
712 |
+
instead.
|
713 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
714 |
+
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
|
715 |
+
will be used instead
|
716 |
+
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
717 |
+
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
|
718 |
+
numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
|
719 |
+
or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
|
720 |
+
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
|
721 |
+
latents as `image`, but if passing latents directly it is not encoded again.
|
722 |
+
mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
723 |
+
`Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask
|
724 |
+
are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a
|
725 |
+
single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one
|
726 |
+
color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B,
|
727 |
+
H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W,
|
728 |
+
1)`, or `(H, W)`.
|
729 |
+
masked_image_latents (`torch.Tensor`, `List[torch.Tensor]`):
|
730 |
+
`Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask
|
731 |
+
latents tensor will be generated from `mask_image`.
|
732 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
733 |
+
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
734 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
735 |
+
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
736 |
+
padding_mask_crop (`int`, *optional*, defaults to `None`):
|
737 |
+
The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to
|
738 |
+
image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region
|
739 |
+
with the same aspect ratio as the image that contains all the masked area, and then expand that area based
|
740 |
+
on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before
|
741 |
+
resizing to the original image size for inpainting. This is useful when the masked area is small while
|
742 |
+
the image is large and contains information irrelevant to inpainting, such as the background.
|
743 |
+
strength (`float`, *optional*, defaults to 0.6):
|
744 |
+
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
|
745 |
+
starting point and more noise is added the higher the `strength`. The number of denoising steps depends
|
746 |
+
on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
|
747 |
+
process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
|
748 |
+
essentially ignores `image`.
|
749 |
+
num_inference_steps (`int`, *optional*, defaults to 28):
|
750 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
751 |
+
expense of slower inference.
|
752 |
+
timesteps (`List[int]`, *optional*):
|
753 |
+
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
754 |
+
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
755 |
+
passed will be used. Must be in descending order.
|
756 |
+
guidance_scale (`float`, *optional*, defaults to 7.0):
|
757 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
758 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
759 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
760 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
761 |
+
usually at the expense of lower image quality.
|
762 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
763 |
+
The number of images to generate per prompt.
|
764 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
765 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
766 |
+
to make generation deterministic.
|
767 |
+
latents (`torch.FloatTensor`, *optional*):
|
768 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
769 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
770 |
+
tensor will be generated by sampling using the supplied random `generator`.
|
771 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
772 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
773 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
774 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
775 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
776 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
777 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
778 |
+
The output format of the generated image. Choose between
|
779 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
780 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
781 |
+
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
|
782 |
+
joint_attention_kwargs (`dict`, *optional*):
|
783 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
784 |
+
`self.processor` in
|
785 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
786 |
+
callback_on_step_end (`Callable`, *optional*):
|
787 |
+
A function that is called at the end of each denoising step during inference. The function is called
|
788 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
789 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
790 |
+
`callback_on_step_end_tensor_inputs`.
|
791 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
792 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
793 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
794 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
795 |
+
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
|
796 |
+
|
797 |
+
Examples:
|
798 |
+
|
799 |
+
Returns:
|
800 |
+
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
|
801 |
+
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
|
802 |
+
images.
|
803 |
+
"""
|
804 |
+
|
805 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
806 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
807 |
+
|
808 |
+
# 1. Check inputs. Raise error if not correct
|
809 |
+
self.check_inputs(
|
810 |
+
prompt,
|
811 |
+
prompt_2,
|
812 |
+
image,
|
813 |
+
mask_image,
|
814 |
+
strength,
|
815 |
+
height,
|
816 |
+
width,
|
817 |
+
output_type=output_type,
|
818 |
+
prompt_embeds=prompt_embeds,
|
819 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
820 |
+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
821 |
+
padding_mask_crop=padding_mask_crop,
|
822 |
+
max_sequence_length=max_sequence_length,
|
823 |
+
)
|
824 |
+
|
825 |
+
self._guidance_scale = guidance_scale
|
826 |
+
self._joint_attention_kwargs = joint_attention_kwargs
|
827 |
+
self._interrupt = False
|
828 |
+
|
829 |
+
# 2. Preprocess mask and image
|
830 |
+
if padding_mask_crop is not None:
|
831 |
+
crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
|
832 |
+
resize_mode = "fill"
|
833 |
+
else:
|
834 |
+
crops_coords = None
|
835 |
+
resize_mode = "default"
|
836 |
+
|
837 |
+
original_image = image
|
838 |
+
init_image = self.image_processor.preprocess(
|
839 |
+
image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
|
840 |
+
)
|
841 |
+
init_image = init_image.to(dtype=torch.float32)
|
842 |
+
|
843 |
+
# 3. Define call parameters
|
844 |
+
if prompt is not None and isinstance(prompt, str):
|
845 |
+
batch_size = 1
|
846 |
+
elif prompt is not None and isinstance(prompt, list):
|
847 |
+
batch_size = len(prompt)
|
848 |
+
else:
|
849 |
+
batch_size = prompt_embeds.shape[0]
|
850 |
+
|
851 |
+
device = self._execution_device
|
852 |
+
|
853 |
+
lora_scale = (
|
854 |
+
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
|
855 |
+
)
|
856 |
+
(
|
857 |
+
prompt_embeds,
|
858 |
+
pooled_prompt_embeds,
|
859 |
+
text_ids,
|
860 |
+
) = self.encode_prompt(
|
861 |
+
prompt=prompt,
|
862 |
+
prompt_2=prompt_2,
|
863 |
+
prompt_embeds=prompt_embeds,
|
864 |
+
t5_prompt_embeds=t5_prompt_embeds,
|
865 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
866 |
+
device=device,
|
867 |
+
num_images_per_prompt=num_images_per_prompt,
|
868 |
+
max_sequence_length=max_sequence_length,
|
869 |
+
lora_scale=lora_scale,
|
870 |
+
)
|
871 |
+
|
872 |
+
# 4.Prepare timesteps
|
873 |
+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
874 |
+
image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor)
|
875 |
+
mu = calculate_shift(
|
876 |
+
image_seq_len,
|
877 |
+
self.scheduler.config.base_image_seq_len,
|
878 |
+
self.scheduler.config.max_image_seq_len,
|
879 |
+
self.scheduler.config.base_shift,
|
880 |
+
self.scheduler.config.max_shift,
|
881 |
+
)
|
882 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
883 |
+
self.scheduler,
|
884 |
+
num_inference_steps,
|
885 |
+
device,
|
886 |
+
timesteps,
|
887 |
+
sigmas,
|
888 |
+
mu=mu,
|
889 |
+
)
|
890 |
+
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
|
891 |
+
|
892 |
+
if num_inference_steps < 1:
|
893 |
+
raise ValueError(
|
894 |
+
f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
|
895 |
+
f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
|
896 |
+
)
|
897 |
+
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
898 |
+
|
899 |
+
# 5. Prepare latent variables
|
900 |
+
num_channels_latents = self.transformer.config.in_channels // 4
|
901 |
+
num_channels_transformer = self.transformer.config.in_channels
|
902 |
+
|
903 |
+
latents, noise, image_latents, latent_image_ids = self.prepare_latents(
|
904 |
+
init_image,
|
905 |
+
latent_timestep,
|
906 |
+
batch_size * num_images_per_prompt,
|
907 |
+
num_channels_latents,
|
908 |
+
height,
|
909 |
+
width,
|
910 |
+
prompt_embeds.dtype,
|
911 |
+
device,
|
912 |
+
generator,
|
913 |
+
latents,
|
914 |
+
)
|
915 |
+
|
916 |
+
mask_condition = self.mask_processor.preprocess(
|
917 |
+
mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
|
918 |
+
)
|
919 |
+
|
920 |
+
if masked_image_latents is None:
|
921 |
+
masked_image = init_image * (mask_condition < 0.5)
|
922 |
+
else:
|
923 |
+
masked_image = masked_image_latents
|
924 |
+
|
925 |
+
mask, masked_image_latents = self.prepare_mask_latents(
|
926 |
+
mask_condition,
|
927 |
+
masked_image,
|
928 |
+
batch_size,
|
929 |
+
num_channels_latents,
|
930 |
+
num_images_per_prompt,
|
931 |
+
height,
|
932 |
+
width,
|
933 |
+
prompt_embeds.dtype,
|
934 |
+
device,
|
935 |
+
generator,
|
936 |
+
)
|
937 |
+
|
938 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
939 |
+
self._num_timesteps = len(timesteps)
|
940 |
+
|
941 |
+
# handle guidance
|
942 |
+
if self.transformer.config.guidance_embeds:
|
943 |
+
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
|
944 |
+
guidance = guidance.expand(latents.shape[0])
|
945 |
+
else:
|
946 |
+
guidance = None
|
947 |
+
|
948 |
+
# 6. Denoising loop
|
949 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
950 |
+
for i, t in enumerate(timesteps):
|
951 |
+
if self.interrupt:
|
952 |
+
continue
|
953 |
+
|
954 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
955 |
+
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
956 |
+
noise_pred = self.transformer(
|
957 |
+
hidden_states=latents,
|
958 |
+
timestep=timestep / 1000,
|
959 |
+
guidance=guidance,
|
960 |
+
pooled_projections=pooled_prompt_embeds,
|
961 |
+
encoder_hidden_states=prompt_embeds,
|
962 |
+
t5_encoder_hidden_states=t5_prompt_embeds,
|
963 |
+
txt_ids=text_ids,
|
964 |
+
img_ids=latent_image_ids,
|
965 |
+
joint_attention_kwargs=self.joint_attention_kwargs,
|
966 |
+
return_dict=False,
|
967 |
+
)[0]
|
968 |
+
|
969 |
+
# compute the previous noisy sample x_t -> x_t-1
|
970 |
+
latents_dtype = latents.dtype
|
971 |
+
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
972 |
+
|
973 |
+
# for 64 channel transformer only.
|
974 |
+
init_latents_proper = image_latents
|
975 |
+
init_mask = mask
|
976 |
+
|
977 |
+
if i < len(timesteps) - 1:
|
978 |
+
noise_timestep = timesteps[i + 1]
|
979 |
+
init_latents_proper = self.scheduler.scale_noise(
|
980 |
+
init_latents_proper, torch.tensor([noise_timestep]), noise
|
981 |
+
)
|
982 |
+
|
983 |
+
latents = (1 - init_mask) * init_latents_proper + init_mask * latents
|
984 |
+
|
985 |
+
if latents.dtype != latents_dtype:
|
986 |
+
if torch.backends.mps.is_available():
|
987 |
+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
988 |
+
latents = latents.to(latents_dtype)
|
989 |
+
|
990 |
+
if callback_on_step_end is not None:
|
991 |
+
callback_kwargs = {}
|
992 |
+
for k in callback_on_step_end_tensor_inputs:
|
993 |
+
callback_kwargs[k] = locals()[k]
|
994 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
995 |
+
|
996 |
+
latents = callback_outputs.pop("latents", latents)
|
997 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
998 |
+
|
999 |
+
# call the callback, if provided
|
1000 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
1001 |
+
progress_bar.update()
|
1002 |
+
|
1003 |
+
if XLA_AVAILABLE:
|
1004 |
+
xm.mark_step()
|
1005 |
+
|
1006 |
+
if output_type == "latent":
|
1007 |
+
image = latents
|
1008 |
+
|
1009 |
+
else:
|
1010 |
+
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
1011 |
+
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
1012 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
1013 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
1014 |
+
|
1015 |
+
# Offload all models
|
1016 |
+
self.maybe_free_model_hooks()
|
1017 |
+
|
1018 |
+
if not return_dict:
|
1019 |
+
return (image,)
|
1020 |
+
|
1021 |
+
return FluxPipelineOutput(images=image)
|
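End to end, `__call__` is used like other diffusers inpainting pipelines: white mask pixels are repainted, black pixels are kept, and `strength` controls how far the masked region may drift from the input. A hedged usage sketch; the import path, class name, and checkpoint id below are assumptions, not values pinned by this hunk:

# Hypothetical usage sketch; class name, import path and checkpoint are assumed.
import torch
from PIL import Image

from flux.pipeline_flux_inpaint import FluxInpaintPipeline   # assumed name

pipe = FluxInpaintPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

image = Image.open("room.png").convert("RGB")
mask = Image.open("room_mask.png").convert("L")   # white = repaint, black = keep

result = pipe(prompt="a red velvet sofa", image=image, mask_image=mask,
              strength=0.85, num_inference_steps=28, guidance_scale=7.0)
result.images[0].save("room_inpainted.png")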
flux/pipeline_output.py
ADDED
@@ -0,0 +1,21 @@
from dataclasses import dataclass
from typing import List, Union

import numpy as np
import PIL.Image

from diffusers.utils import BaseOutput


@dataclass
class FluxPipelineOutput(BaseOutput):
    """
    Output class for Flux image generation pipelines.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`)
            List of denoised PIL images of length `batch_size` or a numpy array of shape `(batch_size, height,
            width, num_channels)` representing the denoised images of the diffusion pipeline.
    """

    images: Union[List[PIL.Image.Image], np.ndarray]
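`BaseOutput` dataclasses double as read-only dicts, which is why the pipeline can return either this dataclass or the plain `(image,)` tuple; a tiny consumption sketch (the import path is assumed):

# Tiny sketch of consuming the output dataclass (dummy 8x8 image).
import numpy as np
import PIL.Image

from flux.pipeline_output import FluxPipelineOutput   # assumed import path

dummy = PIL.Image.fromarray(np.zeros((8, 8, 3), dtype=np.uint8))
out = FluxPipelineOutput(images=[dummy])

print(out.images[0].size)      # attribute access -> (8, 8)
print(out["images"][0].size)   # dict-style access provided by BaseOutput
print(len(out.to_tuple()))     # 1, matching the pipeline's `return (image,)` branch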
flux/scheduling_flow_match_euler_discrete.py
ADDED
@@ -0,0 +1,325 @@
1 |
+
# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import math
|
16 |
+
from dataclasses import dataclass
|
17 |
+
from typing import List, Optional, Tuple, Union
|
18 |
+
|
19 |
+
import numpy as np
|
20 |
+
import torch
|
21 |
+
|
22 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
23 |
+
from diffusers.utils import BaseOutput, logging
|
24 |
+
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
25 |
+
|
26 |
+
|
27 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
28 |
+
|
29 |
+
|
30 |
+
@dataclass
|
31 |
+
class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
|
32 |
+
"""
|
33 |
+
Output class for the scheduler's `step` function output.
|
34 |
+
|
35 |
+
Args:
|
36 |
+
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
37 |
+
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
|
38 |
+
denoising loop.
|
39 |
+
"""
|
40 |
+
|
41 |
+
prev_sample: torch.FloatTensor
|
42 |
+
|
43 |
+
|
44 |
+
class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
45 |
+
"""
|
46 |
+
Euler scheduler.
|
47 |
+
|
48 |
+
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
|
49 |
+
methods the library implements for all schedulers such as loading and saving.
|
50 |
+
|
51 |
+
Args:
|
52 |
+
num_train_timesteps (`int`, defaults to 1000):
|
53 |
+
The number of diffusion steps to train the model.
|
54 |
+
timestep_spacing (`str`, defaults to `"linspace"`):
|
55 |
+
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
56 |
+
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
57 |
+
shift (`float`, defaults to 1.0):
|
58 |
+
The shift value for the timestep schedule.
|
59 |
+
"""
|
60 |
+
|
61 |
+
_compatibles = []
|
62 |
+
order = 1
|
63 |
+
|
64 |
+
@register_to_config
|
65 |
+
def __init__(
|
66 |
+
self,
|
67 |
+
num_train_timesteps: int = 1000,
|
68 |
+
shift: float = 1.0,
|
69 |
+
use_dynamic_shifting=False,
|
70 |
+
base_shift: Optional[float] = 0.5,
|
71 |
+
max_shift: Optional[float] = 1.15,
|
72 |
+
base_image_seq_len: Optional[int] = 256,
|
73 |
+
max_image_seq_len: Optional[int] = 4096,
|
74 |
+
):
|
75 |
+
timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
|
76 |
+
timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
|
77 |
+
|
78 |
+
sigmas = timesteps / num_train_timesteps
|
79 |
+
if not use_dynamic_shifting:
|
80 |
+
# when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
|
81 |
+
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
|
82 |
+
|
83 |
+
self.timesteps = sigmas * num_train_timesteps
|
84 |
+
|
85 |
+
self._step_index = None
|
86 |
+
self._begin_index = None
|
87 |
+
|
88 |
+
self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
89 |
+
self.sigma_min = self.sigmas[-1].item()
|
90 |
+
self.sigma_max = self.sigmas[0].item()
|
91 |
+
|
92 |
+
@property
|
93 |
+
def step_index(self):
|
94 |
+
"""
|
95 |
+
The index counter for the current timestep. It increases by 1 after each scheduler step.
|
96 |
+
"""
|
97 |
+
return self._step_index
|
98 |
+
|
99 |
+
@property
|
100 |
+
def begin_index(self):
|
101 |
+
"""
|
102 |
+
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
103 |
+
"""
|
104 |
+
return self._begin_index
|
105 |
+
|
106 |
+
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
107 |
+
def set_begin_index(self, begin_index: int = 0):
|
108 |
+
"""
|
109 |
+
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
110 |
+
|
111 |
+
Args:
|
112 |
+
begin_index (`int`):
|
113 |
+
The begin index for the scheduler.
|
114 |
+
"""
|
115 |
+
self._begin_index = begin_index
|
116 |
+
|
117 |
+
def scale_noise(
|
118 |
+
self,
|
119 |
+
sample: torch.FloatTensor,
|
120 |
+
timestep: Union[float, torch.FloatTensor],
|
121 |
+
noise: Optional[torch.FloatTensor] = None,
|
122 |
+
) -> torch.FloatTensor:
|
123 |
+
"""
|
124 |
+
Forward process in flow-matching
|
125 |
+
|
126 |
+
Args:
|
127 |
+
sample (`torch.FloatTensor`):
|
128 |
+
The input sample.
|
129 |
+
timestep (`int`, *optional*):
|
130 |
+
The current timestep in the diffusion chain.
|
131 |
+
|
132 |
+
Returns:
|
133 |
+
`torch.FloatTensor`:
|
134 |
+
A scaled input sample.
|
135 |
+
"""
|
136 |
+
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
137 |
+
sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype)
|
138 |
+
|
139 |
+
if sample.device.type == "mps" and torch.is_floating_point(timestep):
|
140 |
+
# mps does not support float64
|
141 |
+
schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32)
|
142 |
+
timestep = timestep.to(sample.device, dtype=torch.float32)
|
143 |
+
else:
|
144 |
+
schedule_timesteps = self.timesteps.to(sample.device)
|
145 |
+
timestep = timestep.to(sample.device)
|
146 |
+
|
147 |
+
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
|
148 |
+
if self.begin_index is None:
|
149 |
+
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep]
|
150 |
+
elif self.step_index is not None:
|
151 |
+
# add_noise is called after first denoising step (for inpainting)
|
152 |
+
step_indices = [self.step_index] * timestep.shape[0]
|
153 |
+
else:
|
154 |
+
# add noise is called before first denoising step to create initial latent(img2img)
|
155 |
+
step_indices = [self.begin_index] * timestep.shape[0]
|
156 |
+
|
157 |
+
sigma = sigmas[step_indices].flatten()
|
158 |
+
while len(sigma.shape) < len(sample.shape):
|
159 |
+
sigma = sigma.unsqueeze(-1)
|
160 |
+
|
161 |
+
sample = sigma * noise + (1.0 - sigma) * sample
|
162 |
+
|
163 |
+
return sample
|
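`scale_noise` is the flow-matching forward process: a linear interpolation between the clean sample and Gaussian noise, weighted by sigma; a scalar sketch with made-up numbers:

# Scalar sketch of the interpolation above (values made up).
import torch

sample = torch.tensor([1.0, -2.0, 0.5])   # "clean" latent values
noise = torch.tensor([0.3, 0.1, -0.7])
sigma = 0.25                              # low-noise timestep -> mostly sample

noisy = sigma * noise + (1.0 - sigma) * sample
print(noisy)   # tensor([ 0.8250, -1.4750,  0.2000])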
164 |
+
|
165 |
+
def _sigma_to_t(self, sigma):
|
166 |
+
return sigma * self.config.num_train_timesteps
|
167 |
+
|
168 |
+
def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
|
169 |
+
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
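`time_shift` is the resolution-dependent warp used when `use_dynamic_shifting` is on: a larger `mu` (longer image sequence, see `calculate_shift` in the pipeline) pushes every sigma toward 1, so more of the schedule is spent at high noise. A numeric check, using the config's `base_shift`/`max_shift` as sample values for `mu`:

# Numeric check of the shift formula above.
import math

def time_shift(mu, sigma, t):
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

t = 0.5
print(round(time_shift(0.5, 1.0, t), 2))    # 0.62 -> mild shift (base_shift)
print(round(time_shift(1.15, 1.0, t), 2))   # 0.76 -> stronger shift (max_shift)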
170 |
+
|
171 |
+
def set_timesteps(
|
172 |
+
self,
|
173 |
+
num_inference_steps: int = None,
|
174 |
+
device: Union[str, torch.device] = None,
|
175 |
+
sigmas: Optional[List[float]] = None,
|
176 |
+
mu: Optional[float] = None,
|
177 |
+
):
|
178 |
+
"""
|
179 |
+
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
180 |
+
|
181 |
+
Args:
|
182 |
+
num_inference_steps (`int`):
|
183 |
+
The number of diffusion steps used when generating samples with a pre-trained model.
|
184 |
+
device (`str` or `torch.device`, *optional*):
|
185 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
186 |
+
"""
|
187 |
+
|
188 |
+
if self.config.use_dynamic_shifting and mu is None:
|
189 |
+
raise ValueError(" you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`")
|
190 |
+
|
191 |
+
if sigmas is None:
|
192 |
+
self.num_inference_steps = num_inference_steps
|
193 |
+
timesteps = np.linspace(
|
194 |
+
self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
|
195 |
+
)
|
196 |
+
|
197 |
+
sigmas = timesteps / self.config.num_train_timesteps
|
198 |
+
|
199 |
+
if self.config.use_dynamic_shifting:
|
200 |
+
sigmas = self.time_shift(mu, 1.0, sigmas)
|
201 |
+
else:
|
202 |
+
sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
|
203 |
+
|
204 |
+
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
|
205 |
+
timesteps = sigmas * self.config.num_train_timesteps
|
206 |
+
|
207 |
+
self.timesteps = timesteps.to(device=device)
|
208 |
+
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
|
209 |
+
|
210 |
+
self._step_index = None
|
211 |
+
self._begin_index = None
|
212 |
+
|
213 |
+
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
214 |
+
if schedule_timesteps is None:
|
215 |
+
schedule_timesteps = self.timesteps
|
216 |
+
|
217 |
+
indices = (schedule_timesteps == timestep).nonzero()
|
218 |
+
|
219 |
+
# The sigma index that is taken for the **very** first `step`
|
220 |
+
# is always the second index (or the last index if there is only 1)
|
221 |
+
# This way we can ensure we don't accidentally skip a sigma in
|
222 |
+
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
223 |
+
pos = 1 if len(indices) > 1 else 0
|
224 |
+
|
225 |
+
return indices[pos].item()
|
226 |
+
|
227 |
+
def _init_step_index(self, timestep):
|
228 |
+
if self.begin_index is None:
|
229 |
+
if isinstance(timestep, torch.Tensor):
|
230 |
+
timestep = timestep.to(self.timesteps.device)
|
231 |
+
self._step_index = self.index_for_timestep(timestep)
|
232 |
+
else:
|
233 |
+
self._step_index = self._begin_index
|
234 |
+
|
235 |
+
def step(
|
236 |
+
self,
|
237 |
+
model_output: torch.FloatTensor,
|
238 |
+
timestep: Union[float, torch.FloatTensor],
|
239 |
+
sample: torch.FloatTensor,
|
240 |
+
s_churn: float = 0.0,
|
241 |
+
s_tmin: float = 0.0,
|
242 |
+
s_tmax: float = float("inf"),
|
243 |
+
s_noise: float = 1.0,
|
244 |
+
generator: Optional[torch.Generator] = None,
|
245 |
+
return_dict: bool = True,
|
246 |
+
) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
|
247 |
+
"""
|
248 |
+
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
|
249 |
+
process from the learned model outputs (most often the predicted noise).
|
250 |
+
|
251 |
+
Args:
|
252 |
+
model_output (`torch.FloatTensor`):
|
253 |
+
The direct output from learned diffusion model.
|
254 |
+
timestep (`float`):
|
255 |
+
The current discrete timestep in the diffusion chain.
|
256 |
+
sample (`torch.FloatTensor`):
|
257 |
+
A current instance of a sample created by the diffusion process.
|
258 |
+
s_churn (`float`):
|
259 |
+
s_tmin (`float`):
|
260 |
+
s_tmax (`float`):
|
261 |
+
s_noise (`float`, defaults to 1.0):
|
262 |
+
Scaling factor for noise added to the sample.
|
263 |
+
generator (`torch.Generator`, *optional*):
|
264 |
+
A random number generator.
|
265 |
+
return_dict (`bool`):
|
266 |
+
Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
|
267 |
+
tuple.
|
268 |
+
|
269 |
+
Returns:
|
270 |
+
[`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
|
271 |
+
If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
|
272 |
+
returned, otherwise a tuple is returned where the first element is the sample tensor.
|
273 |
+
"""
|
274 |
+
|
275 |
+
if (
|
276 |
+
isinstance(timestep, int)
|
277 |
+
or isinstance(timestep, torch.IntTensor)
|
278 |
+
or isinstance(timestep, torch.LongTensor)
|
279 |
+
):
|
280 |
+
raise ValueError(
|
281 |
+
(
|
282 |
+
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
|
283 |
+
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
|
284 |
+
" one of the `scheduler.timesteps` as a timestep."
|
285 |
+
),
|
286 |
+
)
|
287 |
+
|
288 |
+
if self.step_index is None:
|
289 |
+
self._init_step_index(timestep)
|
290 |
+
|
291 |
+
# Upcast to avoid precision issues when computing prev_sample
|
292 |
+
sample = sample.to(torch.float32)
|
293 |
+
|
294 |
+
sigma = self.sigmas[self.step_index]
|
295 |
+
sigma_next = self.sigmas[self.step_index + 1]
|
296 |
+
|
297 |
+
prev_sample = sample + (sigma_next - sigma) * model_output
|
298 |
+
|
299 |
+
# Cast sample back to model compatible dtype
|
300 |
+
prev_sample = prev_sample.to(model_output.dtype)
|
301 |
+
|
302 |
+
# upon completion increase step index by one
|
303 |
+
self._step_index += 1
|
304 |
+
|
305 |
+
if not return_dict:
|
306 |
+
return (prev_sample,)
|
307 |
+
|
308 |
+
return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
|
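The update in `step` is one explicit Euler move along the predicted velocity, `prev = sample + (sigma_next - sigma) * model_output`; a scalar sketch over a toy two-step schedule:

# Scalar sketch of the Euler update above (numbers made up, velocity held fixed).
sigmas = [1.0, 0.5, 0.0]     # toy schedule with the trailing zero sigma
sample = 2.0                 # current noisy value
velocity = -1.5              # stand-in for model_output

for i in range(len(sigmas) - 1):
    sample = sample + (sigmas[i + 1] - sigmas[i]) * velocity
    print(i, sample)         # 0 2.75    1 3.5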
309 |
+
|
310 |
+
def __len__(self):
|
311 |
+
return self.config.num_train_timesteps
|
312 |
+
|
313 |
+
def step_to_x0(self, model_output: torch.FloatTensor, timestep: Union[float, torch.FloatTensor], sample: torch.FloatTensor) -> torch.FloatTensor:
|
314 |
+
"""
|
315 |
+
Compute the predicted x_0 given the model output and current sample at timestep t.
|
316 |
+
"""
|
317 |
+
if self.step_index is None:
|
318 |
+
self._init_step_index(timestep)
|
319 |
+
|
320 |
+
sigma = self.sigmas[self.step_index]
|
321 |
+
sigma_from = sigma
|
322 |
+
sigma_to = self.sigmas[-1] # This corresponds to x_0
|
323 |
+
|
324 |
+
x0 = sample + (sigma_to - sigma_from) * model_output
|
325 |
+
return x0
|
flux/transformer_flux.py
ADDED
@@ -0,0 +1,572 @@
1 |
+
# Copyright 2024 Black Forest Labs, The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
from typing import Any, Dict, List, Optional, Union
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
import torch.nn.functional as F
|
21 |
+
|
22 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
23 |
+
from .lora.peft import PeftAdapterMixin
|
24 |
+
from diffusers.models.attention import FeedForward
|
25 |
+
from .attention_processor import Attention, FluxAttnProcessor2_0, FluxSingleAttnProcessor2_0
|
26 |
+
from diffusers.models.modeling_utils import ModelMixin
|
27 |
+
from .normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
|
28 |
+
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
29 |
+
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
30 |
+
from .embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings
|
31 |
+
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
32 |
+
import numpy as np
|
33 |
+
|
34 |
+
|
35 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
36 |
+
|
37 |
+
def get_1d_rotary_pos_embed(
|
38 |
+
dim: int,
|
39 |
+
pos: Union[np.ndarray, int],
|
40 |
+
theta: float = 10000.0,
|
41 |
+
use_real=False,
|
42 |
+
linear_factor=1.0,
|
43 |
+
ntk_factor=1.0,
|
44 |
+
repeat_interleave_real=True,
|
45 |
+
freqs_dtype=torch.float32, # torch.float32, torch.float64 (flux)
|
46 |
+
):
|
47 |
+
"""
|
48 |
+
Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
|
49 |
+
|
50 |
+
This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the
|
51 |
+
position indices 'pos'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64
|
52 |
+
data type.
|
53 |
+
|
54 |
+
Args:
|
55 |
+
dim (`int`): Dimension of the frequency tensor.
|
56 |
+
pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
|
57 |
+
theta (`float`, *optional*, defaults to 10000.0):
|
58 |
+
Scaling factor for frequency computation. Defaults to 10000.0.
|
59 |
+
use_real (`bool`, *optional*):
|
60 |
+
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
|
61 |
+
linear_factor (`float`, *optional*, defaults to 1.0):
|
62 |
+
Scaling factor for the context extrapolation. Defaults to 1.0.
|
63 |
+
ntk_factor (`float`, *optional*, defaults to 1.0):
|
64 |
+
Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
|
65 |
+
repeat_interleave_real (`bool`, *optional*, defaults to `True`):
|
66 |
+
If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
|
67 |
+
Otherwise, they are concatenated with themselves.
|
68 |
+
freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
|
69 |
+
the dtype of the frequency tensor.
|
70 |
+
Returns:
|
71 |
+
`torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
|
72 |
+
"""
|
73 |
+
assert dim % 2 == 0
|
74 |
+
|
75 |
+
if isinstance(pos, int):
|
76 |
+
pos = torch.arange(pos)
|
77 |
+
if isinstance(pos, np.ndarray):
|
78 |
+
pos = torch.from_numpy(pos) # type: ignore # [S]
|
79 |
+
|
80 |
+
theta = theta * ntk_factor
|
81 |
+
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor # [D/2]
|
82 |
+
freqs = freqs.to(pos.device)
|
83 |
+
freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
|
84 |
+
if use_real and repeat_interleave_real:
|
85 |
+
# flux, hunyuan-dit, cogvideox
|
86 |
+
freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
|
87 |
+
freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
|
88 |
+
return freqs_cos, freqs_sin
|
89 |
+
elif use_real:
|
90 |
+
# stable audio
|
91 |
+
freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D]
|
92 |
+
freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D]
|
93 |
+
return freqs_cos, freqs_sin
|
94 |
+
else:
|
95 |
+
# lumina
|
96 |
+
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
|
97 |
+
return freqs_cis
|
98 |
+
|
99 |
+
class FluxPosEmbed(nn.Module):
|
100 |
+
# modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
|
101 |
+
def __init__(self, theta: int, axes_dim: List[int]):
|
102 |
+
super().__init__()
|
103 |
+
self.theta = theta
|
104 |
+
self.axes_dim = axes_dim
|
105 |
+
|
106 |
+
def forward(self, ids: torch.Tensor) -> torch.Tensor:
|
107 |
+
n_axes = ids.shape[-1]
|
108 |
+
cos_out = []
|
109 |
+
sin_out = []
|
110 |
+
pos = ids.squeeze().float()
|
111 |
+
is_mps = ids.device.type == "mps"
|
112 |
+
freqs_dtype = torch.float32 if is_mps else torch.float64
|
113 |
+
for i in range(n_axes):
|
114 |
+
cos, sin = get_1d_rotary_pos_embed(
|
115 |
+
self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True, freqs_dtype=freqs_dtype
|
116 |
+
)
|
117 |
+
cos_out.append(cos)
|
118 |
+
sin_out.append(sin)
|
119 |
+
freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
|
120 |
+
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
|
121 |
+
return freqs_cos, freqs_sin
|
122 |
+
|
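# --- Illustrative usage (not part of the upstream file) ---------------------------------
# Sketch of the id layout FluxPosEmbed expects: each row of `ids` is (t, row, column).
# Text tokens carry zeros, image tokens carry their latent-grid coordinates, and axes_dim
# must sum to the attention head dimension (16 + 56 + 56 = 128 for the released Flux weights).
import torch
from flux.transformer_flux import FluxPosEmbed  # module path as in this repo

pos_embed = FluxPosEmbed(theta=10000, axes_dim=[16, 56, 56])

txt_ids = torch.zeros(4, 3)                         # 4 text tokens, no spatial position
img_ids = torch.zeros(4, 3)                         # a hypothetical 2x2 latent-patch grid
img_ids[:, 1] = torch.tensor([0.0, 0.0, 1.0, 1.0])  # latent row index
img_ids[:, 2] = torch.tensor([0.0, 1.0, 0.0, 1.0])  # latent column index

cos, sin = pos_embed(torch.cat([txt_ids, img_ids], dim=0))
print(cos.shape, sin.shape)                         # torch.Size([8, 128]) torch.Size([8, 128])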
123 |
+
# YiYi to-do: refactor rope related functions/classes
|
124 |
+
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
|
125 |
+
assert dim % 2 == 0, "The dimension must be even."
|
126 |
+
|
127 |
+
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
|
128 |
+
omega = 1.0 / (theta**scale)
|
129 |
+
|
130 |
+
batch_size, seq_length = pos.shape
|
131 |
+
out = torch.einsum("...n,d->...nd", pos, omega)
|
132 |
+
cos_out = torch.cos(out)
|
133 |
+
sin_out = torch.sin(out)
|
134 |
+
|
135 |
+
stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
|
136 |
+
out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
|
137 |
+
return out.float()
|
138 |
+
|
139 |
+
|
140 |
+
# YiYi to-do: refactor rope related functions/classes
|
141 |
+
class EmbedND(nn.Module):
|
142 |
+
def __init__(self, dim: int, theta: int, axes_dim: List[int]):
|
143 |
+
super().__init__()
|
144 |
+
self.dim = dim
|
145 |
+
self.theta = theta
|
146 |
+
self.axes_dim = axes_dim
|
147 |
+
|
148 |
+
def forward(self, ids: torch.Tensor) -> torch.Tensor:
|
149 |
+
n_axes = ids.shape[-1]
|
150 |
+
emb = torch.cat(
|
151 |
+
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
|
152 |
+
dim=-3,
|
153 |
+
)
|
154 |
+
return emb.unsqueeze(1)
|
155 |
+
|
156 |
+
|
157 |
+
@maybe_allow_in_graph
|
158 |
+
class FluxSingleTransformerBlock(nn.Module):
|
159 |
+
r"""
|
160 |
+
A single-stream DiT block as used in Flux: self-attention and a GELU MLP run in parallel on the same normalized input and are merged by a single output projection.
|
161 |
+
|
162 |
+
Reference: https://arxiv.org/abs/2403.03206
|
163 |
+
|
164 |
+
Parameters:
|
165 |
+
dim (`int`): The number of channels in the input and output.
|
166 |
+
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
167 |
+
attention_head_dim (`int`): The number of channels in each head.
|
168 |
+
mlp_ratio (`float`, *optional*, defaults to 4.0): Ratio of the hidden dimension of the parallel MLP branch
|
169 |
+
to `dim`.
|
170 |
+
"""
|
171 |
+
|
172 |
+
def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
|
173 |
+
super().__init__()
|
174 |
+
self.mlp_hidden_dim = int(dim * mlp_ratio)
|
175 |
+
|
176 |
+
self.norm = AdaLayerNormZeroSingle(dim)
|
177 |
+
self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
|
178 |
+
self.act_mlp = nn.GELU(approximate="tanh")
|
179 |
+
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
|
180 |
+
|
181 |
+
processor = FluxSingleAttnProcessor2_0()
|
182 |
+
self.attn = Attention(
|
183 |
+
query_dim=dim,
|
184 |
+
cross_attention_dim=None,
|
185 |
+
dim_head=attention_head_dim,
|
186 |
+
heads=num_attention_heads,
|
187 |
+
out_dim=dim,
|
188 |
+
bias=True,
|
189 |
+
processor=processor,
|
190 |
+
qk_norm="rms_norm",
|
191 |
+
eps=1e-6,
|
192 |
+
pre_only=True,
|
193 |
+
)
|
194 |
+
|
195 |
+
def forward(
|
196 |
+
self,
|
197 |
+
hidden_states: torch.FloatTensor,
|
198 |
+
temb: torch.FloatTensor,
|
199 |
+
image_rotary_emb=None,
|
200 |
+
):
|
201 |
+
residual = hidden_states
|
202 |
+
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
|
203 |
+
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
|
204 |
+
|
205 |
+
attn_output = self.attn(
|
206 |
+
hidden_states=norm_hidden_states,
|
207 |
+
image_rotary_emb=image_rotary_emb,
|
208 |
+
)
|
209 |
+
|
210 |
+
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
|
211 |
+
gate = gate.unsqueeze(1)
|
212 |
+
hidden_states = gate * self.proj_out(hidden_states)
|
213 |
+
hidden_states = residual + hidden_states
|
214 |
+
|
215 |
+
return hidden_states
|
216 |
+
|
217 |
+
|
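# --- Illustrative dataflow sketch (not part of the upstream file) ------------------------
# The single-stream block above runs attention and a GELU MLP in parallel on the same
# normalized input, concatenates the two branches, and maps them back to `dim` with one
# projection. The standalone sketch below reproduces that dataflow with plain tensors
# (illustrative sizes; it does not use the repo's Attention class).
import torch
import torch.nn.functional as F
from torch import nn

dim, heads, mlp_ratio = 3072, 24, 4.0
x = torch.randn(1, 16, dim)                         # [batch, tokens, dim], already normalized

qkv = nn.Linear(dim, 3 * dim)
proj_mlp = nn.Linear(dim, int(dim * mlp_ratio))
proj_out = nn.Linear(dim + int(dim * mlp_ratio), dim)

q, k, v = qkv(x).chunk(3, dim=-1)
q, k, v = [t.view(1, 16, heads, -1).transpose(1, 2) for t in (q, k, v)]
attn = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(1, 16, dim)

mlp = F.gelu(proj_mlp(x), approximate="tanh")
out = proj_out(torch.cat([attn, mlp], dim=2))       # fuse both branches, project back to dim
print(out.shape)                                    # torch.Size([1, 16, 3072])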
218 |
+
@maybe_allow_in_graph
|
219 |
+
class FluxTransformerBlock(nn.Module):
|
220 |
+
r"""
|
221 |
+
A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
|
222 |
+
|
223 |
+
Reference: https://arxiv.org/abs/2403.03206
|
224 |
+
|
225 |
+
Parameters:
|
226 |
+
dim (`int`): The number of channels in the input and output.
|
227 |
+
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
228 |
+
attention_head_dim (`int`): The number of channels in each head.
|
229 |
+
qk_norm (`str`, *optional*, defaults to `"rms_norm"`): Normalization applied to the query and key projections.
|
230 |
+
eps (`float`, *optional*, defaults to 1e-6): Epsilon used by the normalization layers.
|
231 |
+
"""
|
232 |
+
|
233 |
+
def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6):
|
234 |
+
super().__init__()
|
235 |
+
|
236 |
+
self.norm1 = AdaLayerNormZero(dim)
|
237 |
+
|
238 |
+
self.norm1_context = AdaLayerNormZero(dim)
|
239 |
+
|
240 |
+
if hasattr(F, "scaled_dot_product_attention"):
|
241 |
+
processor = FluxAttnProcessor2_0()
|
242 |
+
else:
|
243 |
+
raise ValueError(
|
244 |
+
"The current PyTorch version does not support the `scaled_dot_product_attention` function."
|
245 |
+
)
|
246 |
+
self.attn = Attention(
|
247 |
+
query_dim=dim,
|
248 |
+
cross_attention_dim=None,
|
249 |
+
added_kv_proj_dim=dim,
|
250 |
+
dim_head=attention_head_dim,
|
251 |
+
heads=num_attention_heads,
|
252 |
+
out_dim=dim,
|
253 |
+
context_pre_only=False,
|
254 |
+
bias=True,
|
255 |
+
processor=processor,
|
256 |
+
qk_norm=qk_norm,
|
257 |
+
eps=eps,
|
258 |
+
)
|
259 |
+
|
260 |
+
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
261 |
+
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
262 |
+
|
263 |
+
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
264 |
+
self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
265 |
+
|
266 |
+
# let chunk size default to None
|
267 |
+
self._chunk_size = None
|
268 |
+
self._chunk_dim = 0
|
269 |
+
|
270 |
+
def forward(
|
271 |
+
self,
|
272 |
+
hidden_states: torch.FloatTensor,
|
273 |
+
encoder_hidden_states: torch.FloatTensor,
|
274 |
+
temb: torch.FloatTensor,
|
275 |
+
image_rotary_emb=None,
|
276 |
+
):
|
277 |
+
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
|
278 |
+
|
279 |
+
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
|
280 |
+
encoder_hidden_states, emb=temb
|
281 |
+
)
|
282 |
+
|
283 |
+
# Attention.
|
284 |
+
attn_output, context_attn_output = self.attn(
|
285 |
+
hidden_states=norm_hidden_states,
|
286 |
+
encoder_hidden_states=norm_encoder_hidden_states,
|
287 |
+
image_rotary_emb=image_rotary_emb,
|
288 |
+
)
|
289 |
+
|
290 |
+
# Process attention outputs for the `hidden_states`.
|
291 |
+
attn_output = gate_msa.unsqueeze(1) * attn_output
|
292 |
+
hidden_states = hidden_states + attn_output
|
293 |
+
|
294 |
+
norm_hidden_states = self.norm2(hidden_states)
|
295 |
+
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
296 |
+
|
297 |
+
ff_output = self.ff(norm_hidden_states)
|
298 |
+
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
299 |
+
|
300 |
+
hidden_states = hidden_states + ff_output
|
301 |
+
|
302 |
+
# Process attention outputs for the `encoder_hidden_states`.
|
303 |
+
|
304 |
+
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
|
305 |
+
encoder_hidden_states = encoder_hidden_states + context_attn_output
|
306 |
+
|
307 |
+
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
|
308 |
+
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
|
309 |
+
|
310 |
+
context_ff_output = self.ff_context(norm_encoder_hidden_states)
|
311 |
+
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
|
312 |
+
|
313 |
+
return encoder_hidden_states, hidden_states
|
314 |
+
|
315 |
+
|
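# --- Illustrative modulation sketch (not part of the upstream file) ----------------------
# Both streams of the block above are wrapped in adaLN-Zero style modulation: the timestep/
# text embedding `temb` yields shift, scale, and gate vectors that condition the normalized
# activations and gate the residual update. A schematic of that pattern (in the real
# AdaLayerNormZero the modulation linear is zero-initialized; this sketch is not).
import torch
from torch import nn

dim = 3072
norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
to_mod = nn.Linear(dim, 3 * dim)                    # produces shift, scale, gate from temb

def modulated_sublayer(x, temb, sublayer):
    shift, scale, gate = to_mod(nn.functional.silu(temb)).chunk(3, dim=-1)
    h = norm(x) * (1 + scale[:, None]) + shift[:, None]   # condition the normalized activations
    return x + gate[:, None] * sublayer(h)                # gated residual update

x, temb = torch.randn(1, 16, dim), torch.randn(1, dim)
print(modulated_sublayer(x, temb, nn.Linear(dim, dim)).shape)   # torch.Size([1, 16, 3072])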
316 |
+
class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
317 |
+
"""
|
318 |
+
The Transformer model introduced in Flux.
|
319 |
+
|
320 |
+
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
|
321 |
+
|
322 |
+
Parameters:
|
323 |
+
patch_size (`int`): Patch size to turn the input data into small patches.
|
324 |
+
in_channels (`int`, *optional*, defaults to 64): The number of channels in the input.
|
325 |
+
num_layers (`int`, *optional*, defaults to 19): The number of layers of MMDiT blocks to use.
|
326 |
+
num_single_layers (`int`, *optional*, defaults to 38): The number of layers of single DiT blocks to use.
|
327 |
+
attention_head_dim (`int`, *optional*, defaults to 128): The number of channels in each head.
|
328 |
+
num_attention_heads (`int`, *optional*, defaults to 24): The number of heads to use for multi-head attention.
|
329 |
+
joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
330 |
+
pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
|
331 |
+
guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
|
332 |
+
"""
|
333 |
+
|
334 |
+
_supports_gradient_checkpointing = True
|
335 |
+
|
336 |
+
@register_to_config
|
337 |
+
def __init__(
|
338 |
+
self,
|
339 |
+
patch_size: int = 1,
|
340 |
+
in_channels: int = 64,
|
341 |
+
num_layers: int = 19,
|
342 |
+
num_single_layers: int = 38,
|
343 |
+
attention_head_dim: int = 128,
|
344 |
+
num_attention_heads: int = 24,
|
345 |
+
joint_attention_dim: int = 4096,
|
346 |
+
pooled_projection_dim: int = 768,
|
347 |
+
guidance_embeds: bool = False,
|
348 |
+
axes_dims_rope: List[int] = [16, 56, 56],
|
349 |
+
):
|
350 |
+
super().__init__()
|
351 |
+
self.out_channels = in_channels
|
352 |
+
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
|
353 |
+
|
354 |
+
#self.pos_embed = EmbedND(dim=self.inner_dim, theta=10000, axes_dim=axes_dims_rope)
|
355 |
+
self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
|
356 |
+
text_time_guidance_cls = (
|
357 |
+
CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
|
358 |
+
)
|
359 |
+
self.time_text_embed = text_time_guidance_cls(
|
360 |
+
embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
|
361 |
+
)
|
362 |
+
|
363 |
+
self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
|
364 |
+
self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
|
365 |
+
|
366 |
+
self.transformer_blocks = nn.ModuleList(
|
367 |
+
[
|
368 |
+
FluxTransformerBlock(
|
369 |
+
dim=self.inner_dim,
|
370 |
+
num_attention_heads=self.config.num_attention_heads,
|
371 |
+
attention_head_dim=self.config.attention_head_dim,
|
372 |
+
)
|
373 |
+
for i in range(self.config.num_layers)
|
374 |
+
]
|
375 |
+
)
|
376 |
+
|
377 |
+
self.single_transformer_blocks = nn.ModuleList(
|
378 |
+
[
|
379 |
+
FluxSingleTransformerBlock(
|
380 |
+
dim=self.inner_dim,
|
381 |
+
num_attention_heads=self.config.num_attention_heads,
|
382 |
+
attention_head_dim=self.config.attention_head_dim,
|
383 |
+
)
|
384 |
+
for i in range(self.config.num_single_layers)
|
385 |
+
]
|
386 |
+
)
|
387 |
+
|
388 |
+
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
|
389 |
+
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
|
390 |
+
|
391 |
+
self.gradient_checkpointing = True
|
392 |
+
|
393 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
394 |
+
if hasattr(module, "gradient_checkpointing"):
|
395 |
+
module.gradient_checkpointing = value
|
396 |
+
|
397 |
+
def forward(
|
398 |
+
self,
|
399 |
+
hidden_states: torch.Tensor,
|
400 |
+
encoder_hidden_states: torch.Tensor = None,
|
401 |
+
t5_encoder_hidden_states: torch.Tensor = None,
|
402 |
+
pooled_projections: torch.Tensor = None,
|
403 |
+
timestep: torch.LongTensor = None,
|
404 |
+
img_ids: torch.Tensor = None,
|
405 |
+
txt_ids: torch.Tensor = None,
|
406 |
+
guidance: torch.Tensor = None,
|
407 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
408 |
+
controlnet_block_samples=None,
|
409 |
+
controlnet_single_block_samples=None,
|
410 |
+
return_dict: bool = True,
|
411 |
+
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
|
412 |
+
"""
|
413 |
+
The [`FluxTransformer2DModel`] forward method.
|
414 |
+
|
415 |
+
Args:
|
416 |
+
hidden_states (`torch.FloatTensor` of shape `(batch_size, image_sequence_length, in_channels)`):
|
417 |
+
Input `hidden_states`.
|
418 |
+
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
|
419 |
+
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
420 |
+
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
|
421 |
+
from the embeddings of input conditions.
|
422 |
+
timestep ( `torch.LongTensor`):
|
423 |
+
Used to indicate denoising step.
|
424 |
+
controlnet_block_samples / controlnet_single_block_samples (`list` of `torch.Tensor`, *optional*):
|
425 |
+
A list of tensors that if specified are added to the residuals of transformer blocks.
|
426 |
+
joint_attention_kwargs (`dict`, *optional*):
|
427 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
428 |
+
`self.processor` in
|
429 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
430 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
431 |
+
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
432 |
+
tuple.
|
433 |
+
|
434 |
+
Returns:
|
435 |
+
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
436 |
+
`tuple` where the first element is the sample tensor.
|
437 |
+
"""
|
438 |
+
if joint_attention_kwargs is not None:
|
439 |
+
joint_attention_kwargs = joint_attention_kwargs.copy()
|
440 |
+
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
441 |
+
else:
|
442 |
+
lora_scale = 1.0
|
443 |
+
|
444 |
+
if USE_PEFT_BACKEND:
|
445 |
+
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
446 |
+
scale_lora_layers(self, lora_scale)
|
447 |
+
else:
|
448 |
+
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
|
449 |
+
logger.warning(
|
450 |
+
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
451 |
+
)
|
452 |
+
hidden_states = self.x_embedder(hidden_states)
|
453 |
+
|
454 |
+
timestep = timestep.to(hidden_states.dtype) * 1000
|
455 |
+
if guidance is not None:
|
456 |
+
guidance = guidance.to(hidden_states.dtype) * 1000
|
457 |
+
else:
|
458 |
+
guidance = None
|
459 |
+
temb = (
|
460 |
+
self.time_text_embed(timestep, pooled_projections)
|
461 |
+
if guidance is None
|
462 |
+
else self.time_text_embed(timestep, guidance, pooled_projections)
|
463 |
+
)
|
464 |
+
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
465 |
+
if t5_encoder_hidden_states is not None:
|
466 |
+
encoder_hidden_states = torch.cat([encoder_hidden_states, t5_encoder_hidden_states], dim=1)
|
467 |
+
|
468 |
+
#ids = torch.cat((txt_ids, img_ids), dim=1)
|
469 |
+
if txt_ids.ndim == 3:
|
470 |
+
#logger.warning(
|
471 |
+
# "Passing `txt_ids` 3d torch.Tensor is deprecated."
|
472 |
+
# "Please remove the batch dimension and pass it as a 2d torch Tensor"
|
473 |
+
#)
|
474 |
+
txt_ids = txt_ids[0]
|
475 |
+
if img_ids.ndim == 3:
|
476 |
+
#logger.warning(
|
477 |
+
# "Passing `img_ids` 3d torch.Tensor is deprecated."
|
478 |
+
# "Please remove the batch dimension and pass it as a 2d torch Tensor"
|
479 |
+
#)
|
480 |
+
img_ids = img_ids[0]
|
481 |
+
ids = torch.cat((txt_ids, img_ids), dim=0)
|
482 |
+
image_rotary_emb = self.pos_embed(ids)
|
483 |
+
|
484 |
+
for index_block, block in enumerate(self.transformer_blocks):
|
485 |
+
|
486 |
+
if self.training and self.gradient_checkpointing:
|
487 |
+
|
488 |
+
def create_custom_forward(module, return_dict=None):
|
489 |
+
def custom_forward(*inputs):
|
490 |
+
if return_dict is not None:
|
491 |
+
return module(*inputs, return_dict=return_dict)
|
492 |
+
else:
|
493 |
+
return module(*inputs)
|
494 |
+
|
495 |
+
return custom_forward
|
496 |
+
|
497 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
498 |
+
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
|
499 |
+
create_custom_forward(block),
|
500 |
+
hidden_states,
|
501 |
+
encoder_hidden_states,
|
502 |
+
temb,
|
503 |
+
image_rotary_emb,
|
504 |
+
**ckpt_kwargs,
|
505 |
+
)
|
506 |
+
|
507 |
+
else:
|
508 |
+
encoder_hidden_states, hidden_states = block(
|
509 |
+
hidden_states=hidden_states,
|
510 |
+
encoder_hidden_states=encoder_hidden_states,
|
511 |
+
temb=temb,
|
512 |
+
image_rotary_emb=image_rotary_emb,
|
513 |
+
)
|
514 |
+
|
515 |
+
# controlnet residual
|
516 |
+
if controlnet_block_samples is not None:
|
517 |
+
interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
|
518 |
+
interval_control = int(np.ceil(interval_control))
|
519 |
+
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
|
520 |
+
|
521 |
+
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
522 |
+
|
523 |
+
for index_block, block in enumerate(self.single_transformer_blocks):
|
524 |
+
if self.training and self.gradient_checkpointing:
|
525 |
+
|
526 |
+
def create_custom_forward(module, return_dict=None):
|
527 |
+
def custom_forward(*inputs):
|
528 |
+
if return_dict is not None:
|
529 |
+
return module(*inputs, return_dict=return_dict)
|
530 |
+
else:
|
531 |
+
return module(*inputs)
|
532 |
+
|
533 |
+
return custom_forward
|
534 |
+
|
535 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
536 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
537 |
+
create_custom_forward(block),
|
538 |
+
hidden_states,
|
539 |
+
temb,
|
540 |
+
image_rotary_emb,
|
541 |
+
**ckpt_kwargs,
|
542 |
+
)
|
543 |
+
|
544 |
+
else:
|
545 |
+
hidden_states = block(
|
546 |
+
hidden_states=hidden_states,
|
547 |
+
temb=temb,
|
548 |
+
image_rotary_emb=image_rotary_emb,
|
549 |
+
)
|
550 |
+
|
551 |
+
# controlnet residual
|
552 |
+
if controlnet_single_block_samples is not None:
|
553 |
+
interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
|
554 |
+
interval_control = int(np.ceil(interval_control))
|
555 |
+
hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
|
556 |
+
hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
557 |
+
+ controlnet_single_block_samples[index_block // interval_control]
|
558 |
+
)
|
559 |
+
|
560 |
+
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
561 |
+
|
562 |
+
hidden_states = self.norm_out(hidden_states, temb)
|
563 |
+
output = self.proj_out(hidden_states)
|
564 |
+
|
565 |
+
if USE_PEFT_BACKEND:
|
566 |
+
# remove `lora_scale` from each PEFT layer
|
567 |
+
unscale_lora_layers(self, lora_scale)
|
568 |
+
|
569 |
+
if not return_dict:
|
570 |
+
return (output,)
|
571 |
+
|
572 |
+
return Transformer2DModelOutput(sample=output)
|
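# --- Illustrative forward pass (not part of the upstream file) ---------------------------
# Shape sketch for FluxTransformer2DModel.forward using a tiny randomly initialized config
# (the released checkpoint is ~12B parameters). It shows the packed-latent convention:
# hidden states are [batch, tokens, in_channels] and txt/img ids are 2-D with (t, row, col) axes.
import torch
from flux.transformer_flux import FluxTransformer2DModel

model = FluxTransformer2DModel(
    num_layers=1, num_single_layers=1, attention_head_dim=32, num_attention_heads=4,
    joint_attention_dim=64, pooled_projection_dim=32, in_channels=16,
    axes_dims_rope=[8, 12, 12],                     # must sum to attention_head_dim
).eval()

b, txt_len, h, w = 1, 8, 4, 4                       # a 4x4 packed latent grid
img_ids = torch.zeros(h * w, 3)
img_ids[:, 1] = torch.arange(h).repeat_interleave(w)    # latent row index
img_ids[:, 2] = torch.arange(w).repeat(h)               # latent column index

with torch.no_grad():
    out = model(
        hidden_states=torch.randn(b, h * w, 16),
        encoder_hidden_states=torch.randn(b, txt_len, 64),
        pooled_projections=torch.randn(b, 32),
        timestep=torch.tensor([1.0]),
        img_ids=img_ids,
        txt_ids=torch.zeros(txt_len, 3),
        return_dict=False,
    )[0]
print(out.shape)                                    # torch.Size([1, 16, 16])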
main.py
ADDED
@@ -0,0 +1,154 @@
1 |
+
import argparse
|
2 |
+
import torch
|
3 |
+
from PIL import Image
|
4 |
+
import numpy as np
|
5 |
+
from model import FluxModel
|
6 |
+
|
7 |
+
def parse_args():
|
8 |
+
parser = argparse.ArgumentParser(description='Flux Image Generation Tool')
|
9 |
+
|
10 |
+
# Required arguments
|
11 |
+
parser.add_argument('--mode', type=str, required=True,
|
12 |
+
choices=['variation', 'img2img', 'inpaint', 'controlnet', 'controlnet-inpaint'],
|
13 |
+
help='Generation mode')
|
14 |
+
parser.add_argument('--input_image', type=str, required=True,
|
15 |
+
help='Path to the input image')
|
16 |
+
|
17 |
+
# Optional arguments
|
18 |
+
parser.add_argument('--prompt', type=str, default="",
|
19 |
+
help='Text prompt to guide the generation')
|
20 |
+
parser.add_argument('--reference_image', type=str, default=None,
|
21 |
+
help='Path to the reference image (for img2img/controlnet modes)')
|
22 |
+
parser.add_argument('--mask_image', type=str, default=None,
|
23 |
+
help='Path to the mask image (for inpainting modes)')
|
24 |
+
parser.add_argument('--output_dir', type=str, default='outputs',
|
25 |
+
help='Directory to save generated images')
|
26 |
+
parser.add_argument('--image_count', type=int, default=1,
|
27 |
+
help='Number of images to generate')
|
28 |
+
parser.add_argument('--aspect_ratio', type=str, default='1:1',
|
29 |
+
choices=['1:1', '16:9', '9:16', '2.4:1', '3:4', '4:3'],
|
30 |
+
help='Output image aspect ratio')
|
31 |
+
parser.add_argument('--steps', type=int, default=28,
|
32 |
+
help='Number of inference steps')
|
33 |
+
parser.add_argument('--guidance_scale', type=float, default=7.5,
|
34 |
+
help='Guidance scale for generation')
|
35 |
+
parser.add_argument('--denoise_strength', type=float, default=0.8,
|
36 |
+
help='Denoising strength for img2img/inpaint')
|
37 |
+
|
38 |
+
# Attention related arguments
|
39 |
+
parser.add_argument('--center_x', type=float, default=None,
|
40 |
+
help='X coordinate of attention center (0-1)')
|
41 |
+
parser.add_argument('--center_y', type=float, default=None,
|
42 |
+
help='Y coordinate of attention center (0-1)')
|
43 |
+
parser.add_argument('--radius', type=float, default=None,
|
44 |
+
help='Radius of attention circle (0-1)')
|
45 |
+
|
46 |
+
# ControlNet related arguments
|
47 |
+
parser.add_argument('--line_mode', action='store_true',
|
48 |
+
help='Enable line detection mode for ControlNet')
|
49 |
+
parser.add_argument('--depth_mode', action='store_true',
|
50 |
+
help='Enable depth mode for ControlNet')
|
51 |
+
parser.add_argument('--line_strength', type=float, default=0.4,
|
52 |
+
help='Strength of line guidance')
|
53 |
+
parser.add_argument('--depth_strength', type=float, default=0.2,
|
54 |
+
help='Strength of depth guidance')
|
55 |
+
|
56 |
+
# Device selection
|
57 |
+
parser.add_argument('--device', type=str, default='cuda',
|
58 |
+
choices=['cuda', 'cpu'],
|
59 |
+
help='Device to run the model on')
|
60 |
+
parser.add_argument('--turbo', action='store_true',
|
61 |
+
help='Enable turbo mode for faster inference')
|
62 |
+
|
63 |
+
return parser.parse_args()
|
64 |
+
|
65 |
+
def load_image(image_path):
|
66 |
+
"""Load and return a PIL Image."""
|
67 |
+
try:
|
68 |
+
return Image.open(image_path).convert('RGB')
|
69 |
+
except Exception as e:
|
70 |
+
raise ValueError(f"Error loading image {image_path}: {str(e)}")
|
71 |
+
|
72 |
+
def save_images(images, output_dir, prefix="generated"):
|
73 |
+
"""Save generated images with sequential numbering."""
|
74 |
+
import os
|
75 |
+
os.makedirs(output_dir, exist_ok=True)
|
76 |
+
|
77 |
+
for i, image in enumerate(images):
|
78 |
+
output_path = os.path.join(output_dir, f"{prefix}_{i+1}.png")
|
79 |
+
image.save(output_path)
|
80 |
+
print(f"Saved image to {output_path}")
|
81 |
+
|
82 |
+
def get_required_features(args):
|
83 |
+
"""Determine which model features are required based on the arguments."""
|
84 |
+
features = []
|
85 |
+
|
86 |
+
if args.mode in ['controlnet', 'controlnet-inpaint']:
|
87 |
+
features.append('controlnet')
|
88 |
+
if args.depth_mode:
|
89 |
+
features.append('depth')
|
90 |
+
if args.line_mode:
|
91 |
+
features.append('line')
|
92 |
+
|
93 |
+
if args.mode in ['inpaint', 'controlnet-inpaint']:
|
94 |
+
features.append('sam') # If you're using SAM for mask generation
|
95 |
+
|
96 |
+
return features
|
97 |
+
|
98 |
+
|
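# --- Illustrative usage (not part of the original file) ----------------------------------
# How the feature resolution above plays out for a controlnet-inpaint run with depth
# guidance enabled (assumes the repo's dependencies are installed so `main` is importable).
from argparse import Namespace
from main import get_required_features

args = Namespace(mode="controlnet-inpaint", depth_mode=True, line_mode=False)
print(get_required_features(args))                  # ['controlnet', 'depth', 'sam']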
99 |
+
def main():
|
100 |
+
args = parse_args()
|
101 |
+
|
102 |
+
# Check CUDA availability if requested
|
103 |
+
if args.device == 'cuda' and not torch.cuda.is_available():
|
104 |
+
print("CUDA requested but not available. Falling back to CPU.")
|
105 |
+
args.device = 'cpu'
|
106 |
+
|
107 |
+
# Determine required features based on mode and arguments
|
108 |
+
required_features = get_required_features(args)
|
109 |
+
|
110 |
+
# Initialize model with only required features
|
111 |
+
print(f"Initializing model on {args.device} with features: {required_features}")
|
112 |
+
model = FluxModel(
|
113 |
+
is_turbo=args.turbo,
|
114 |
+
device=args.device,
|
115 |
+
required_features=required_features
|
116 |
+
)
|
117 |
+
|
118 |
+
# Load input images
|
119 |
+
input_image = load_image(args.input_image)
|
120 |
+
reference_image = load_image(args.reference_image) if args.reference_image else None
|
121 |
+
mask_image = load_image(args.mask_image) if args.mask_image else None
|
122 |
+
|
123 |
+
# Validate inputs based on mode
|
124 |
+
if args.mode in ['inpaint', 'controlnet-inpaint'] and mask_image is None:
|
125 |
+
raise ValueError(f"{args.mode} mode requires a mask image")
|
126 |
+
|
127 |
+
# Generate images
|
128 |
+
print(f"Generating {args.image_count} images in {args.mode} mode...")
|
129 |
+
generated_images = model.generate(
|
130 |
+
input_image_a=input_image,
|
131 |
+
input_image_b=reference_image,
|
132 |
+
prompt=args.prompt,
|
133 |
+
mask_image=mask_image,
|
134 |
+
mode=args.mode,
|
135 |
+
imageCount=args.image_count,
|
136 |
+
aspect_ratio=args.aspect_ratio,
|
137 |
+
num_inference_steps=args.steps,
|
138 |
+
guidance_scale=args.guidance_scale,
|
139 |
+
denoise_strength=args.denoise_strength,
|
140 |
+
center_x=args.center_x,
|
141 |
+
center_y=args.center_y,
|
142 |
+
radius=args.radius,
|
143 |
+
line_mode=args.line_mode,
|
144 |
+
depth_mode=args.depth_mode,
|
145 |
+
line_strength=args.line_strength,
|
146 |
+
depth_strength=args.depth_strength
|
147 |
+
)
|
148 |
+
|
149 |
+
# Save generated images
|
150 |
+
save_images(generated_images, args.output_dir)
|
151 |
+
print("Generation completed successfully!")
|
152 |
+
|
153 |
+
if __name__ == "__main__":
|
154 |
+
main()
|
model.py
ADDED
@@ -0,0 +1,644 @@
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
from PIL import Image
|
4 |
+
from transformers import CLIPTokenizer, CLIPTextModel, AutoProcessor, T5EncoderModel, T5TokenizerFast
|
5 |
+
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
|
6 |
+
from flux.transformer_flux import FluxTransformer2DModel
|
7 |
+
|
8 |
+
from flux.pipeline_flux_chameleon import FluxPipeline
|
9 |
+
from flux.pipeline_flux_img2img import FluxImg2ImgPipeline
|
10 |
+
from flux.pipeline_flux_inpaint import FluxInpaintPipeline
|
11 |
+
from flux.pipeline_flux_controlnet import FluxControlNetPipeline, FluxControlNetModel
|
12 |
+
from flux.pipeline_flux_controlnet_img2img import FluxControlNetImg2ImgPipeline
|
13 |
+
from flux.controlnet_flux import FluxMultiControlNetModel
|
14 |
+
from flux.pipeline_flux_controlnet_inpainting import FluxControlNetInpaintPipeline
|
15 |
+
|
16 |
+
from qwen2_vl.modeling_qwen2_vl import Qwen2VLSimplifiedModel
|
17 |
+
import os
|
18 |
+
import cv2
|
19 |
+
import numpy as np
|
20 |
+
import math
|
21 |
+
|
22 |
+
def get_model_path(model_name):
|
23 |
+
"""Get the full path for a model based on the checkpoints directory."""
|
24 |
+
base_dir = os.getenv('CHECKPOINT_DIR', 'checkpoints') # Allow environment variable override
|
25 |
+
return os.path.join(base_dir, model_name)
|
26 |
+
|
27 |
+
# Model paths configuration
|
28 |
+
MODEL_PATHS = {
|
29 |
+
'flux': get_model_path('flux'),
|
30 |
+
'qwen2vl': get_model_path('qwen2-vl'),
|
31 |
+
'controlnet': get_model_path('controlnet'),
|
32 |
+
'depth_anything': {
|
33 |
+
'path': get_model_path('depth-anything-v2'),
|
34 |
+
'weights': 'depth_anything_v2_vitl.pth'
|
35 |
+
},
|
36 |
+
'anyline': {
|
37 |
+
'path': get_model_path('anyline'),
|
38 |
+
'weights': 'MTEED.pth'
|
39 |
+
},
|
40 |
+
'sam2': {
|
41 |
+
'path': get_model_path('segment-anything-2'),
|
42 |
+
'weights': 'sam2_hiera_large.pt',
|
43 |
+
'config': 'sam2_hiera_l.yaml'
|
44 |
+
}
|
45 |
+
}
|
46 |
+
|
47 |
+
|
48 |
+
ASPECT_RATIOS = {
|
49 |
+
"1:1": (1024, 1024),
|
50 |
+
"16:9": (1344, 768),
|
51 |
+
"9:16": (768, 1344),
|
52 |
+
"2.4:1": (1536, 640),
|
53 |
+
"3:4": (896, 1152),
|
54 |
+
"4:3": (1152, 896),
|
55 |
+
}
|
56 |
+
|
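# --- Illustrative usage (not part of the original file) ----------------------------------
# MODEL_PATHS is resolved at import time, so set CHECKPOINT_DIR before importing model.py
# if the weights live somewhere other than ./checkpoints (the path below is hypothetical).
import os
os.environ["CHECKPOINT_DIR"] = "/data/flux-checkpoints"

from model import ASPECT_RATIOS, MODEL_PATHS
print(MODEL_PATHS["flux"])                          # /data/flux-checkpoints/flux
print(ASPECT_RATIOS["16:9"])                        # (1344, 768)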
57 |
+
class Qwen2Connector(nn.Module):
|
58 |
+
def __init__(self, input_dim=3584, output_dim=4096):
|
59 |
+
super().__init__()
|
60 |
+
self.linear = nn.Linear(input_dim, output_dim)
|
61 |
+
|
62 |
+
def forward(self, x):
|
63 |
+
return self.linear(x)
|
64 |
+
|
65 |
+
class FluxModel:
|
66 |
+
def __init__(self, is_turbo=False, device="cuda", required_features=None):
|
67 |
+
"""
|
68 |
+
Initialize FluxModel with specified features
|
69 |
+
Args:
|
70 |
+
is_turbo: Enable turbo mode for faster inference
|
71 |
+
device: Device to run the model on
|
72 |
+
required_features: List of required features ['controlnet', 'depth', 'line', 'sam']
|
73 |
+
"""
|
74 |
+
self.device = torch.device(device)
|
75 |
+
self.dtype = torch.bfloat16
|
76 |
+
if required_features is None:
|
77 |
+
required_features = []
|
78 |
+
|
79 |
+
self._line_detector_imported = False
|
80 |
+
self._depth_model_imported = False
|
81 |
+
self._sam_imported = False
|
82 |
+
self._turbo_imported = False
|
83 |
+
|
84 |
+
# Initialize base models (always required)
|
85 |
+
self._init_base_models()
|
86 |
+
|
87 |
+
# Initialize optional models based on requirements
|
88 |
+
if 'controlnet' in required_features or any(f in required_features for f in ['depth', 'line']):
|
89 |
+
self._init_controlnet()
|
90 |
+
|
91 |
+
if 'depth' in required_features:
|
92 |
+
self._init_depth_model()
|
93 |
+
|
94 |
+
if 'line' in required_features:
|
95 |
+
self._init_line_detector()
|
96 |
+
|
97 |
+
if 'sam' in required_features:
|
98 |
+
self._init_sam()
|
99 |
+
|
100 |
+
if is_turbo:
|
101 |
+
self._enable_turbo()
|
102 |
+
|
103 |
+
def _init_base_models(self):
|
104 |
+
"""Initialize the core models that are always needed"""
|
105 |
+
# Qwen2VL and connector initialization
|
106 |
+
self.qwen2vl = Qwen2VLSimplifiedModel.from_pretrained(
|
107 |
+
MODEL_PATHS['qwen2vl'],
|
108 |
+
torch_dtype=self.dtype
|
109 |
+
)
|
110 |
+
self.qwen2vl.requires_grad_(False).to(self.device)
|
111 |
+
|
112 |
+
self.connector = Qwen2Connector(input_dim=3584, output_dim=4096)
|
113 |
+
connector_path = os.path.join(MODEL_PATHS['qwen2vl'], "connector.pt")
|
114 |
+
if os.path.exists(connector_path):
|
115 |
+
connector_state_dict = torch.load(connector_path, map_location=self.device, weights_only=True)
|
116 |
+
connector_state_dict = {k.replace('module.', ''): v for k, v in connector_state_dict.items()}
|
117 |
+
self.connector.load_state_dict(connector_state_dict)
|
118 |
+
self.connector.to(self.dtype).to(self.device)
|
119 |
+
|
120 |
+
# Text encoders initialization
|
121 |
+
self.tokenizer = CLIPTokenizer.from_pretrained(MODEL_PATHS['flux'], subfolder="tokenizer")
|
122 |
+
self.text_encoder = CLIPTextModel.from_pretrained(MODEL_PATHS['flux'], subfolder="text_encoder")
|
123 |
+
self.text_encoder_two = T5EncoderModel.from_pretrained(MODEL_PATHS['flux'], subfolder="text_encoder_2")
|
124 |
+
self.tokenizer_two = T5TokenizerFast.from_pretrained(MODEL_PATHS['flux'], subfolder="tokenizer_2")
|
125 |
+
|
126 |
+
self.text_encoder.requires_grad_(False).to(self.dtype).to(self.device)
|
127 |
+
self.text_encoder_two.requires_grad_(False).to(self.dtype).to(self.device)
|
128 |
+
|
129 |
+
# T5 context embedder
|
130 |
+
self.t5_context_embedder = nn.Linear(4096, 3072)
|
131 |
+
t5_embedder_path = os.path.join(MODEL_PATHS['qwen2vl'], "t5_embedder.pt")
|
132 |
+
t5_embedder_state_dict = torch.load(t5_embedder_path, map_location=self.device, weights_only=True)
|
133 |
+
self.t5_context_embedder.load_state_dict(t5_embedder_state_dict)
|
134 |
+
self.t5_context_embedder.to(self.dtype).to(self.device)
|
135 |
+
|
136 |
+
# Basic components
|
137 |
+
self.noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(MODEL_PATHS['flux'], subfolder="scheduler", shift=1)
|
138 |
+
self.vae = AutoencoderKL.from_pretrained(MODEL_PATHS['flux'], subfolder="vae")
|
139 |
+
self.transformer = FluxTransformer2DModel.from_pretrained(MODEL_PATHS['flux'], subfolder="transformer")
|
140 |
+
|
141 |
+
self.vae.requires_grad_(False).to(self.dtype).to(self.device)
|
142 |
+
self.transformer.requires_grad_(False).to(self.dtype).to(self.device)
|
143 |
+
|
144 |
+
def _init_controlnet(self):
|
145 |
+
"""Initialize ControlNet model"""
|
146 |
+
self.controlnet_union = FluxControlNetModel.from_pretrained(
|
147 |
+
MODEL_PATHS['controlnet'],
|
148 |
+
torch_dtype=torch.bfloat16
|
149 |
+
)
|
150 |
+
self.controlnet_union.requires_grad_(False).to(self.device)
|
151 |
+
self.controlnet = FluxMultiControlNetModel([self.controlnet_union])
|
152 |
+
|
153 |
+
def _init_depth_model(self):
|
154 |
+
"""Initialize Depth Anything V2 model"""
|
155 |
+
if not self._depth_model_imported:
|
156 |
+
from depth_anything_v2.dpt import DepthAnythingV2
|
157 |
+
self._depth_model_imported = True
|
158 |
+
|
159 |
+
self.depth_model = DepthAnythingV2(
|
160 |
+
encoder='vitl',
|
161 |
+
features=256,
|
162 |
+
out_channels=[256, 512, 1024, 1024]
|
163 |
+
)
|
164 |
+
depth_weights = os.path.join(MODEL_PATHS['depth_anything']['path'],
|
165 |
+
MODEL_PATHS['depth_anything']['weights'])
|
166 |
+
self.depth_model.load_state_dict(torch.load(depth_weights, map_location=self.device))
|
167 |
+
self.depth_model.requires_grad_(False).to(self.device)
|
168 |
+
|
169 |
+
def _init_line_detector(self):
|
170 |
+
"""Initialize line detection model"""
|
171 |
+
if not self._line_detector_imported:
|
172 |
+
from controlnet_aux import AnylineDetector
|
173 |
+
self._line_detector_imported = True
|
174 |
+
|
175 |
+
self.anyline = AnylineDetector.from_pretrained(
|
176 |
+
MODEL_PATHS['anyline']['path'],
|
177 |
+
filename=MODEL_PATHS['anyline']['weights']
|
178 |
+
)
|
179 |
+
self.anyline.to(self.device)
|
180 |
+
|
181 |
+
def _init_sam(self):
|
182 |
+
"""Initialize SAM2 model"""
|
183 |
+
if not self._sam_imported:
|
184 |
+
from sam2.build_sam import build_sam2
|
185 |
+
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
186 |
+
self._sam_imported = True
|
187 |
+
|
188 |
+
sam2_checkpoint = os.path.join(MODEL_PATHS['sam2']['path'],
|
189 |
+
MODEL_PATHS['sam2']['weights'])
|
190 |
+
model_cfg = os.path.join(MODEL_PATHS['sam2']['path'],
|
191 |
+
MODEL_PATHS['sam2']['config'])
|
192 |
+
self.sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=self.device)
|
193 |
+
self.sam2_predictor = SAM2ImagePredictor(self.sam2_model)
|
194 |
+
|
195 |
+
def _enable_turbo(self):
|
196 |
+
"""Enable turbo mode for faster inference"""
|
197 |
+
if not self._turbo_imported:
|
198 |
+
from optimum.quanto import freeze, qfloat8, quantize
|
199 |
+
self._turbo_imported = True
|
200 |
+
|
201 |
+
quantize(
|
202 |
+
self.transformer,
|
203 |
+
weights=qfloat8,
|
204 |
+
exclude=[
|
205 |
+
"*.norm", "*.norm1", "*.norm2", "*.norm2_context",
|
206 |
+
"proj_out", "x_embedder", "norm_out", "context_embedder",
|
207 |
+
],
|
208 |
+
)
|
209 |
+
freeze(self.transformer)
|
210 |
+
|
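# --- Illustrative usage (not part of the original file) ----------------------------------
# The same optimum-quanto recipe as _enable_turbo above, applied to a toy module so it can
# be run without the Flux weights: weights are cast to float8 and then frozen in place.
import torch
from torch import nn
from optimum.quanto import freeze, qfloat8, quantize

toy = nn.Sequential(nn.Linear(64, 128), nn.GELU(), nn.Linear(128, 64))
quantize(toy, weights=qfloat8)                      # swap Linear weights for float8 qtensors
freeze(toy)                                         # materialize the quantized weights
print(toy(torch.randn(2, 64)).shape)                # torch.Size([2, 64])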
211 |
+
def generate_mask(self, image, input_points, input_labels):
|
212 |
+
"""
|
213 |
+
Generate a segmentation mask with SAM2.
|
214 |
+
|
215 |
+
Args:
|
216 |
+
image: PIL Image or numpy array
|
217 |
+
input_points: numpy array of shape (N, 2) with point coordinates
|
218 |
+
input_labels: numpy array of shape (N,); 1 marks a foreground point, 0 a background point
|
219 |
+
|
220 |
+
Returns:
|
221 |
+
PIL Image: the mask with the highest predicted score
|
222 |
+
"""
|
223 |
+
try:
|
224 |
+
# Ensure the image is a numpy array
|
225 |
+
if isinstance(image, Image.Image):
|
226 |
+
image_array = np.array(image)
|
227 |
+
else:
|
228 |
+
image_array = image
|
229 |
+
|
230 |
+
# Set the image on the predictor
|
231 |
+
self.sam2_predictor.set_image(image_array)
|
232 |
+
|
233 |
+
# Run mask prediction
|
234 |
+
with torch.inference_mode():
|
235 |
+
masks, scores, logits = self.sam2_predictor.predict(
|
236 |
+
point_coords=input_points,
|
237 |
+
point_labels=input_labels,
|
238 |
+
multimask_output=True,
|
239 |
+
)
|
240 |
+
|
241 |
+
# Return the highest-scoring mask
|
242 |
+
best_mask_idx = scores.argmax()
|
243 |
+
mask = masks[best_mask_idx]
|
244 |
+
mask_image = Image.fromarray((mask * 255).astype(np.uint8))
|
245 |
+
return mask_image
|
246 |
+
|
247 |
+
except Exception as e:
|
248 |
+
print(f"Mask generation failed: {str(e)}")
|
249 |
+
raise
|
250 |
+
|
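# --- Illustrative usage (not part of the original file) ----------------------------------
# Point-prompt format expected by generate_mask: (N, 2) pixel coordinates in (x, y) order
# and (N,) labels where 1 marks a foreground point and 0 a background point. `flux_model`
# is a hypothetical FluxModel built with required_features=['sam'].
import numpy as np

input_points = np.array([[320, 240], [90, 60]])     # keep the object at (320, 240)
input_labels = np.array([1, 0])                     # exclude the region at (90, 60)
# mask = flux_model.generate_mask(image, input_points, input_labels)
# mask.save("mask.png")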
251 |
+
def recover_2d_shape(self, image_hidden_state, grid_thw):
|
252 |
+
batch_size, num_tokens, hidden_dim = image_hidden_state.shape
|
253 |
+
_, h, w = grid_thw
|
254 |
+
h_out = h // 2
|
255 |
+
w_out = w // 2
|
256 |
+
# Reshape to (batch_size, height, width, hidden_dim)
|
257 |
+
reshaped = image_hidden_state.view(batch_size, h_out, w_out, hidden_dim)
|
258 |
+
return reshaped
|
259 |
+
|
260 |
+
def generate_attention_matrix(self, center_x, center_y, radius, image_shape):
|
261 |
+
height, width = image_shape
|
262 |
+
y, x = np.ogrid[:height, :width]
|
263 |
+
center_y, center_x = center_y * height, center_x * width
|
264 |
+
distances = np.sqrt((x - center_x)**2 + (y - center_y)**2)
|
265 |
+
attention = np.clip(1 - distances / (radius * min(height, width)), 0, 1)
|
266 |
+
return attention
|
267 |
+
|
268 |
+
def apply_attention(self, image_hidden_state, image_grid_thw, center_x, center_y, radius):
|
269 |
+
qwen2_2d_image_embedding = self.recover_2d_shape(image_hidden_state, tuple(image_grid_thw.tolist()[0]))
|
270 |
+
attention_matrix = self.generate_attention_matrix(
|
271 |
+
center_x, center_y, radius,
|
272 |
+
(qwen2_2d_image_embedding.size(1), qwen2_2d_image_embedding.size(2))
|
273 |
+
)
|
274 |
+
attention_tensor = torch.from_numpy(attention_matrix).to(self.dtype).unsqueeze(0).unsqueeze(-1)
|
275 |
+
qwen2_2d_image_embedding = qwen2_2d_image_embedding * attention_tensor.to(self.device)
|
276 |
+
return qwen2_2d_image_embedding.view(1, -1, qwen2_2d_image_embedding.size(3))
|
277 |
+
|
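# --- Standalone sketch (not part of the original file) -----------------------------------
# The radial weighting used above, reproduced without the model: weight 1.0 at the chosen
# center, decaying linearly to 0 once the distance exceeds radius * min(height, width).
import numpy as np

def radial_attention(center_x, center_y, radius, height, width):
    y, x = np.ogrid[:height, :width]
    cy, cx = center_y * height, center_x * width
    dist = np.sqrt((x - cx) ** 2 + (y - cy) ** 2)
    return np.clip(1 - dist / (radius * min(height, width)), 0, 1)

weights = radial_attention(0.5, 0.5, 0.25, 32, 32)
print(weights.shape, float(weights.max()), float(weights.min()))   # (32, 32) 1.0 0.0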
278 |
+
def compute_text_embeddings(self, prompt):
|
279 |
+
with torch.no_grad():
|
280 |
+
text_inputs = self.tokenizer(prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt")
|
281 |
+
text_input_ids = text_inputs.input_ids.to(self.device)
|
282 |
+
prompt_embeds = self.text_encoder(text_input_ids, output_hidden_states=False)
|
283 |
+
pooled_prompt_embeds = prompt_embeds.pooler_output
|
284 |
+
return pooled_prompt_embeds.to(self.dtype)
|
285 |
+
|
286 |
+
def compute_t5_text_embeddings(
|
287 |
+
self,
|
288 |
+
max_sequence_length=256,
|
289 |
+
prompt=None,
|
290 |
+
num_images_per_prompt=1,
|
291 |
+
device=None,
|
292 |
+
):
|
293 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
294 |
+
batch_size = len(prompt)
|
295 |
+
|
296 |
+
text_inputs = self.tokenizer_two(
|
297 |
+
prompt,
|
298 |
+
padding="max_length",
|
299 |
+
max_length=max_sequence_length,
|
300 |
+
truncation=True,
|
301 |
+
return_length=False,
|
302 |
+
return_overflowing_tokens=False,
|
303 |
+
return_tensors="pt",
|
304 |
+
)
|
305 |
+
text_input_ids = text_inputs.input_ids
|
306 |
+
prompt_embeds = self.text_encoder_two(text_input_ids.to(device))[0]
|
307 |
+
|
308 |
+
dtype = self.text_encoder_two.dtype
|
309 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
310 |
+
|
311 |
+
_, seq_len, _ = prompt_embeds.shape
|
312 |
+
|
313 |
+
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
314 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
315 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
316 |
+
|
317 |
+
return prompt_embeds
|
318 |
+
|
319 |
+
def process_image(self, image):
|
320 |
+
message = [
|
321 |
+
{
|
322 |
+
"role": "user",
|
323 |
+
"content": [
|
324 |
+
{"type": "image", "image": image},
|
325 |
+
{"type": "text", "text": "Describe this image."},
|
326 |
+
]
|
327 |
+
}
|
328 |
+
]
|
329 |
+
text = self.qwen2vl_processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
|
330 |
+
|
331 |
+
with torch.no_grad():
|
332 |
+
inputs = self.qwen2vl_processor(text=[text], images=[image], padding=True, return_tensors="pt").to(self.device)
|
333 |
+
output_hidden_state, image_token_mask, image_grid_thw = self.qwen2vl(**inputs)
|
334 |
+
image_hidden_state = output_hidden_state[image_token_mask].view(1, -1, output_hidden_state.size(-1))
|
335 |
+
|
336 |
+
return image_hidden_state, image_grid_thw
|
337 |
+
|
338 |
+
def resize_image(self, img, max_pixels=1050000):
|
339 |
+
# Ensure the input is a PIL Image
|
340 |
+
if not isinstance(img, Image.Image):
|
341 |
+
img = Image.fromarray(img)
|
342 |
+
|
343 |
+
width, height = img.size
|
344 |
+
num_pixels = width * height
|
345 |
+
|
346 |
+
if num_pixels > max_pixels:
|
347 |
+
scale = math.sqrt(max_pixels / num_pixels)
|
348 |
+
new_width = int(width * scale)
|
349 |
+
new_height = int(height * scale)
|
350 |
+
# Round width and height down so both are divisible by 8
|
351 |
+
new_width = new_width - (new_width % 8)
|
352 |
+
new_height = new_height - (new_height % 8)
|
353 |
+
img = img.resize((new_width, new_height), Image.LANCZOS)
|
354 |
+
else:
|
355 |
+
# Even if no downscaling is needed, the dimensions must still be divisible by 8
|
356 |
+
new_width = width - (width % 8)
|
357 |
+
new_height = height - (height % 8)
|
358 |
+
if new_width != width or new_height != height:
|
359 |
+
img = img.resize((new_width, new_height), Image.LANCZOS)
|
360 |
+
|
361 |
+
return img
|
362 |
+
|
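# --- Standalone sketch (not part of the original file) -----------------------------------
# The resize rule above in isolation: cap the area at ~1.05 megapixels, then round both
# sides down to multiples of 8 so they stay compatible with the 8x VAE downsampling.
import math

def fit_under(width, height, max_pixels=1_050_000):
    if width * height > max_pixels:
        scale = math.sqrt(max_pixels / (width * height))
        width, height = int(width * scale), int(height * scale)
    return width - width % 8, height - height % 8

print(fit_under(2000, 1500))                        # (1176, 880): under the cap, divisible by 8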
363 |
+
def generate_depth_map(self, image):
|
364 |
+
"""Generate depth map using Depth Anything V2"""
|
365 |
+
# Convert PIL to numpy array
|
366 |
+
image_np = np.array(image)
|
367 |
+
|
368 |
+
# Convert RGB to BGR for cv2
|
369 |
+
image_bgr = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
|
370 |
+
|
371 |
+
# Generate depth map
|
372 |
+
with torch.no_grad():
|
373 |
+
depth = self.depth_model.infer_image(image_bgr)
|
374 |
+
|
375 |
+
# Normalize depth to 0-1 range
|
376 |
+
depth_norm = (depth - depth.min()) / (depth.max() - depth.min())
|
377 |
+
|
378 |
+
# Convert to RGB image
|
379 |
+
depth_rgb = (depth_norm * 255).astype(np.uint8)
|
380 |
+
depth_rgb = cv2.cvtColor(depth_rgb, cv2.COLOR_GRAY2RGB)
|
381 |
+
|
382 |
+
return Image.fromarray(depth_rgb)
|
383 |
+
|
384 |
+
|
385 |
+
def generate(self, input_image_a, input_image_b=None, prompt="", guidance_scale=3.5, num_inference_steps=28,
|
386 |
+
aspect_ratio="1:1", center_x=None, center_y=None, radius=None, mode="variation",
|
387 |
+
denoise_strength=0.8, mask_image=None, imageCount=2,
|
388 |
+
line_mode=True, depth_mode=True, line_strength=0.4, depth_strength=0.2):
|
389 |
+
|
390 |
+
batch_size = imageCount
|
391 |
+
if aspect_ratio not in ASPECT_RATIOS:
|
392 |
+
raise ValueError(f"Invalid aspect ratio. Choose from {list(ASPECT_RATIOS.keys())}")
|
393 |
+
|
394 |
+
width, height = ASPECT_RATIOS[aspect_ratio]
|
395 |
+
|
396 |
+
pooled_prompt_embeds = self.compute_text_embeddings(prompt="")
|
397 |
+
t5_prompt_embeds = None
|
398 |
+
if prompt != "":
|
399 |
+
self.qwen2vl_processor = AutoProcessor.from_pretrained(MODEL_PATHS['qwen2vl'], min_pixels=256*28*28, max_pixels=256*28*28)
|
400 |
+
t5_prompt_embeds = self.compute_t5_text_embeddings(prompt=prompt, device=self.device)
|
401 |
+
t5_prompt_embeds = self.t5_context_embedder(t5_prompt_embeds)
|
402 |
+
else:
|
403 |
+
self.qwen2vl_processor = AutoProcessor.from_pretrained(MODEL_PATHS['qwen2vl'], min_pixels=512*28*28, max_pixels=512*28*28)
|
404 |
+
|
405 |
+
qwen2_hidden_state_a, image_grid_thw_a = self.process_image(input_image_a)
|
406 |
+
# Apply the attention weighting only when all attention parameters are provided
|
407 |
+
if mode == "variation":
|
408 |
+
if center_x is not None and center_y is not None and radius is not None:
|
409 |
+
qwen2_hidden_state_a = self.apply_attention(qwen2_hidden_state_a, image_grid_thw_a, center_x, center_y, radius)
|
410 |
+
qwen2_hidden_state_a = self.connector(qwen2_hidden_state_a)
|
411 |
+
|
412 |
+
if mode == "img2img" or mode == "inpaint":
|
413 |
+
if input_image_b:
|
414 |
+
qwen2_hidden_state_b, image_grid_thw_b = self.process_image(input_image_b)
|
415 |
+
if center_x is not None and center_y is not None and radius is not None:
|
416 |
+
qwen2_hidden_state_b = self.apply_attention(qwen2_hidden_state_b, image_grid_thw_b, center_x, center_y, radius)
|
417 |
+
qwen2_hidden_state_b = self.connector(qwen2_hidden_state_b)
|
418 |
+
else:
|
419 |
+
qwen2_hidden_state_a = self.connector(qwen2_hidden_state_a)
|
420 |
+
qwen2_hidden_state_b = None
|
421 |
+
|
422 |
+
if mode == "controlnet" or mode == "controlnet-inpaint":
|
423 |
+
qwen2_hidden_state_b = None
|
424 |
+
if input_image_b:
|
425 |
+
qwen2_hidden_state_b, image_grid_thw_b = self.process_image(input_image_b)
|
426 |
+
if center_x is not None and center_y is not None and radius is not None:
|
427 |
+
qwen2_hidden_state_b = self.apply_attention(qwen2_hidden_state_b, image_grid_thw_b, center_x, center_y, radius)
|
428 |
+
qwen2_hidden_state_b = self.connector(qwen2_hidden_state_b)
|
429 |
+
qwen2_hidden_state_a = self.connector(qwen2_hidden_state_a)
|
430 |
+
|
431 |
+
#############################
|
432 |
+
# IMAGE GENERATION
|
433 |
+
#############################
|
434 |
+
if mode == "variation":
|
435 |
+
# Initialize different pipelines
|
436 |
+
pipeline = FluxPipeline(
|
437 |
+
transformer=self.transformer,
|
438 |
+
scheduler=self.noise_scheduler,
|
439 |
+
vae=self.vae,
|
440 |
+
text_encoder=self.text_encoder,
|
441 |
+
tokenizer=self.tokenizer,
|
442 |
+
)
|
443 |
+
|
444 |
+
gen_images = pipeline(
|
445 |
+
prompt_embeds=qwen2_hidden_state_a.repeat(batch_size, 1, 1),
|
446 |
+
t5_prompt_embeds=t5_prompt_embeds.repeat(batch_size, 1, 1) if t5_prompt_embeds is not None else None,
|
447 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
448 |
+
num_inference_steps=num_inference_steps,
|
449 |
+
guidance_scale=guidance_scale,
|
450 |
+
height=height,
|
451 |
+
width=width,
|
452 |
+
).images
|
453 |
+
|
454 |
+
|
455 |
+
#############################
|
456 |
+
# IMAGE-TO-IMAGE
|
457 |
+
#############################
|
458 |
+
elif mode == "img2img":
|
459 |
+
input_image_a = self.resize_image(input_image_a)
|
460 |
+
width, height = input_image_a.size
|
461 |
+
|
462 |
+
img2img_pipeline = FluxImg2ImgPipeline(
|
463 |
+
transformer=self.transformer,
|
464 |
+
scheduler=self.noise_scheduler,
|
465 |
+
vae=self.vae,
|
466 |
+
text_encoder=self.text_encoder,
|
467 |
+
tokenizer=self.tokenizer,
|
468 |
+
)
|
469 |
+
|
470 |
+
gen_images = img2img_pipeline(
|
471 |
+
image=input_image_a,
|
472 |
+
strength=denoise_strength,
|
473 |
+
prompt_embeds=qwen2_hidden_state_b.repeat(batch_size, 1, 1) if qwen2_hidden_state_b is not None else qwen2_hidden_state_a.repeat(batch_size, 1, 1),
|
474 |
+
t5_prompt_embeds=t5_prompt_embeds.repeat(batch_size, 1, 1) if t5_prompt_embeds is not None else None,
|
475 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
476 |
+
num_inference_steps=num_inference_steps,
|
477 |
+
guidance_scale=guidance_scale,
|
478 |
+
height=height,
|
479 |
+
width=width,
|
480 |
+
).images
|
481 |
+
|
482 |
+
|
483 |
+
#############################
|
484 |
+
# INPAINTING
|
485 |
+
#############################
|
486 |
+
elif mode == "inpaint":
|
487 |
+
if mask_image is None:
|
488 |
+
raise ValueError("Mask image is required for inpainting mode")
|
489 |
+
|
490 |
+
input_image_a = self.resize_image(input_image_a)
|
491 |
+
mask_image = self.resize_image(mask_image)
|
492 |
+
width, height = input_image_a.size
|
493 |
+
|
494 |
+
inpaint_pipeline = FluxInpaintPipeline(
|
495 |
+
transformer=self.transformer,
|
496 |
+
scheduler=self.noise_scheduler,
|
497 |
+
vae=self.vae,
|
498 |
+
text_encoder=self.text_encoder,
|
499 |
+
tokenizer=self.tokenizer,
|
500 |
+
)
|
501 |
+
|
502 |
+
gen_images = inpaint_pipeline(
|
503 |
+
image=input_image_a,
|
504 |
+
mask_image=mask_image,
|
505 |
+
strength=denoise_strength,
|
506 |
+
prompt_embeds=qwen2_hidden_state_b.repeat(batch_size, 1, 1) if qwen2_hidden_state_b is not None else qwen2_hidden_state_a.repeat(batch_size, 1, 1),
|
507 |
+
t5_prompt_embeds=t5_prompt_embeds.repeat(batch_size, 1, 1) if t5_prompt_embeds is not None else None,
|
508 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
509 |
+
num_inference_steps=num_inference_steps,
|
510 |
+
guidance_scale=guidance_scale,
|
511 |
+
height=height,
|
512 |
+
width=width,
|
513 |
+
).images
|
514 |
+
|
515 |
+
#############################
|
516 |
+
# CONTROLNET
|
517 |
+
#############################
|
518 |
+
elif mode == "controlnet":
|
519 |
+
input_image_a = self.resize_image(input_image_a)
|
520 |
+
width, height = input_image_a.size
|
521 |
+
|
522 |
+
controlnet_pipeline = FluxControlNetImg2ImgPipeline(
|
523 |
+
transformer=self.transformer,
|
524 |
+
scheduler=self.noise_scheduler,
|
525 |
+
vae=self.vae,
|
526 |
+
text_encoder=self.text_encoder,
|
527 |
+
tokenizer=self.tokenizer,
|
528 |
+
controlnet=self.controlnet,
|
529 |
+
)
|
530 |
+
|
531 |
+
# Prepare the lists of control images and control modes
|
532 |
+
control_images = []
|
533 |
+
control_modes = []
|
534 |
+
conditioning_scales = []
|
535 |
+
|
536 |
+
# Add control modes based on the user's selection
|
537 |
+
if depth_mode:
|
538 |
+
control_image_depth = self.generate_depth_map(input_image_a)
|
539 |
+
control_images.append(control_image_depth)
|
540 |
+
control_modes.append(2) # depth mode
|
541 |
+
conditioning_scales.append(depth_strength)
|
542 |
+
|
543 |
+
if line_mode:
|
544 |
+
control_image_canny = self.anyline(input_image_a, detect_resolution=1280)
|
545 |
+
control_images.append(control_image_canny)
|
546 |
+
control_modes.append(0) # line mode
|
547 |
+
conditioning_scales.append(line_strength)
|
548 |
+
|
549 |
+
# If neither mode is enabled, fall back to combined line + depth control
|
550 |
+
if not line_mode and not depth_mode:
|
551 |
+
control_image_depth = self.generate_depth_map(input_image_a)
|
552 |
+
control_image_canny = self.anyline(input_image_a, detect_resolution=1280)
|
553 |
+
control_images = [control_image_depth, control_image_canny]
|
554 |
+
control_modes = [2, 0]
|
555 |
+
conditioning_scales = [0.2, 0.4]
|
556 |
+
|
557 |
+
if qwen2_hidden_state_b is not None:
|
558 |
+
qwen2_hidden_state_b = qwen2_hidden_state_b[:, :qwen2_hidden_state_a.shape[1], :]
|
559 |
+
qwen2_hidden_state_a = qwen2_hidden_state_a[:, :qwen2_hidden_state_b.shape[1], :]
|
560 |
+
|
561 |
+
gen_images = controlnet_pipeline(
|
562 |
+
image=input_image_a,
|
563 |
+
strength=denoise_strength,
|
564 |
+
control_image=control_images,
|
565 |
+
control_mode=control_modes,
|
566 |
+
controlnet_conditioning_scale=conditioning_scales,
|
567 |
+
prompt_embeds=qwen2_hidden_state_b.repeat(batch_size, 1, 1) if qwen2_hidden_state_b is not None else qwen2_hidden_state_a.repeat(batch_size, 1, 1),
|
568 |
+
t5_prompt_embeds=t5_prompt_embeds.repeat(batch_size, 1, 1) if t5_prompt_embeds is not None else None,
|
569 |
+
prompt_embeds_control=qwen2_hidden_state_a.repeat(batch_size, 1, 1),
|
570 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
571 |
+
num_inference_steps=num_inference_steps,
|
572 |
+
guidance_scale=guidance_scale,
|
573 |
+
height=height,
|
574 |
+
width=width,
|
575 |
+
).images
|
576 |
+
|
577 |
+
#############################
|
578 |
+
# CONTROLNET INPAINT
|
579 |
+
#############################
|
580 |
+
elif mode == "controlnet-inpaint":
|
581 |
+
input_image_a = self.resize_image(input_image_a)
|
582 |
+
mask_image = self.resize_image(mask_image)
|
583 |
+
width, height = input_image_a.size
|
584 |
+
|
585 |
+
controlnet_pipeline = FluxControlNetInpaintPipeline(
|
586 |
+
transformer=self.transformer,
|
587 |
+
scheduler=self.noise_scheduler,
|
588 |
+
vae=self.vae,
|
589 |
+
text_encoder=self.text_encoder,
|
590 |
+
tokenizer=self.tokenizer,
|
591 |
+
controlnet=self.controlnet,
|
592 |
+
)
|
593 |
+
|
594 |
+
# Prepare the lists of control images and control modes
|
595 |
+
control_images = []
|
596 |
+
control_modes = []
|
597 |
+
conditioning_scales = []
|
598 |
+
|
599 |
+
# Add control modes based on the user's selection
|
600 |
+
if depth_mode:
|
601 |
+
control_image_depth = self.generate_depth_map(input_image_a)
|
602 |
+
control_images.append(control_image_depth)
|
603 |
+
control_modes.append(2) # depth mode
|
604 |
+
conditioning_scales.append(depth_strength)
|
605 |
+
|
606 |
+
if line_mode:
|
607 |
+
control_image_canny = self.anyline(input_image_a, detect_resolution=1280)
|
608 |
+
control_images.append(control_image_canny)
|
609 |
+
control_modes.append(0) # line mode
|
610 |
+
conditioning_scales.append(line_strength)
|
611 |
+
|
612 |
+
# If neither mode is enabled, fall back to combined line + depth control
|
613 |
+
if not line_mode and not depth_mode:
|
614 |
+
control_image_depth = self.generate_depth_map(input_image_a)
|
615 |
+
control_image_canny = self.anyline(input_image_a, detect_resolution=1280)
|
616 |
+
control_images = [control_image_depth, control_image_canny]
|
617 |
+
control_modes = [2, 0]
|
618 |
+
conditioning_scales = [0.2, 0.4]
|
619 |
+
|
620 |
+
if qwen2_hidden_state_b is not None:
|
621 |
+
qwen2_hidden_state_b = qwen2_hidden_state_b[:, :qwen2_hidden_state_a.shape[1], :]
|
622 |
+
qwen2_hidden_state_a = qwen2_hidden_state_a[:, :qwen2_hidden_state_b.shape[1], :]
|
623 |
+
|
624 |
+
gen_images = controlnet_pipeline(
|
625 |
+
image=input_image_a,
|
626 |
+
mask_image=mask_image,
|
627 |
+
control_image=control_images,
|
628 |
+
control_mode=control_modes,
|
629 |
+
controlnet_conditioning_scale=conditioning_scales,
|
630 |
+
strength=denoise_strength,
|
631 |
+
prompt_embeds=qwen2_hidden_state_b.repeat(batch_size, 1, 1) if qwen2_hidden_state_b is not None else qwen2_hidden_state_a.repeat(batch_size, 1, 1),
|
632 |
+
t5_prompt_embeds=t5_prompt_embeds.repeat(batch_size, 1, 1) if t5_prompt_embeds is not None else None,
|
633 |
+
prompt_embeds_control=qwen2_hidden_state_a.repeat(batch_size, 1, 1),
|
634 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
635 |
+
num_inference_steps=num_inference_steps,
|
636 |
+
guidance_scale=guidance_scale,
|
637 |
+
height=height,
|
638 |
+
width=width,
|
639 |
+
).images
|
640 |
+
|
641 |
+
else:
|
642 |
+
raise ValueError(f"Invalid mode: {mode}")
|
643 |
+
|
644 |
+
return gen_images
|
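
A minimal sketch of how the controlnet-inpaint branch above might be driven end to end. This is not part of the upload; it assumes model.py exposes the same FluxModel class and generate() signature as modelmod.py below, that the checkpoints referenced in MODEL_PATHS are already downloaded, and that the file names are placeholders.

# Hypothetical driver script (sketch only; paths and file names are placeholders)
from PIL import Image
from model import FluxModel

model = FluxModel(required_features=['controlnet', 'depth', 'line'])

source = Image.open("example.jpg").convert("RGB")   # image to edit
mask = Image.open("mask.png").convert("L")          # white = region to repaint

images = model.generate(
    input_image_a=source,
    mask_image=mask,
    prompt="replace the masked area with a wooden table",
    mode="controlnet-inpaint",
    line_mode=True, depth_mode=True,      # both hints feed FluxMultiControlNetModel
    line_strength=0.4, depth_strength=0.2,  # same defaults as the code above
    denoise_strength=0.8,
    imageCount=1,
)
images[0].save("inpainted.png")
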
modelmod.py
ADDED
@@ -0,0 +1,650 @@
import torch
from torch import nn
from PIL import Image
from transformers import CLIPTokenizer, CLIPTextModel, AutoProcessor, T5EncoderModel, T5TokenizerFast, BitsAndBytesConfig
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
from flux.transformer_flux import FluxTransformer2DModel

from flux.pipeline_flux_chameleon import FluxPipeline
from flux.pipeline_flux_img2img import FluxImg2ImgPipeline
from flux.pipeline_flux_inpaint import FluxInpaintPipeline
from flux.pipeline_flux_controlnet import FluxControlNetPipeline, FluxControlNetModel
from flux.pipeline_flux_controlnet_img2img import FluxControlNetImg2ImgPipeline
from flux.controlnet_flux import FluxMultiControlNetModel
from flux.pipeline_flux_controlnet_inpainting import FluxControlNetInpaintPipeline

from qwen2_vl.modeling_qwen2_vl import Qwen2VLSimplifiedModel
import os
import cv2
import numpy as np
import math
nf4_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16)


def get_model_path(model_name):
    """Get the full path for a model based on the checkpoints directory."""
    base_dir = os.getenv('CHECKPOINT_DIR', 'checkpoints')  # Allow environment variable override
    return os.path.join(base_dir, model_name)

# Model paths configuration
MODEL_PATHS = {
    'flux': get_model_path('flux'),
    'qwen2vl': get_model_path('qwen2-vl'),
    'controlnet': get_model_path('controlnet'),
    'depth_anything': {
        'path': get_model_path('depth-anything-v2'),
        'weights': 'depth_anything_v2_vitl.pth'
    },
    'anyline': {
        'path': get_model_path('anyline'),
        'weights': 'MTEED.pth'
    },
    'sam2': {
        'path': get_model_path('segment-anything-2'),
        'weights': 'sam2_hiera_large.pt',
        'config': 'sam2_hiera_l.yaml'
    }
}


ASPECT_RATIOS = {
    "1:1": (1024, 1024),
    "16:9": (1344, 768),
    "9:16": (768, 1344),
    "2.4:1": (1536, 640),
    "3:4": (896, 1152),
    "4:3": (1152, 896),
}

class Qwen2Connector(nn.Module):
    def __init__(self, input_dim=3584, output_dim=4096):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.linear(x)

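# Note (editorial sketch, not part of the uploaded file): Qwen2Connector is a single linear
# projection that maps Qwen2-VL image-token hidden states (last dim 3584) into the 4096-dim
# space consumed as `prompt_embeds` by the Flux pipelines below; _init_base_models() loads its
# trained weights from connector.pt when that file is present. Roughly (token count arbitrary):
#     x = torch.randn(1, 729, 3584)   # (batch, image tokens, qwen2 hidden)
#     y = Qwen2Connector()(x)         # -> (1, 729, 4096)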
class FluxModel:
    def __init__(self, is_turbo=False, device="cuda", required_features=None, is_quantization=True):
        """
        Initialize FluxModel with specified features
        Args:
            is_turbo: Enable turbo mode for faster inference
            device: Device to run the model on
            required_features: List of required features ['controlnet', 'depth', 'line', 'sam']
        """
        self.device = torch.device(device)
        self.qkwargs = {"quantization_config": nf4_config} if is_quantization else {}
        self.dtype = torch.bfloat16
        if required_features is None:
            required_features = []

        self._line_detector_imported = False
        self._depth_model_imported = False
        self._sam_imported = False
        self._turbo_imported = False

        # Initialize base models (always required)
        self._init_base_models()

        # Initialize optional models based on requirements
        if 'controlnet' in required_features or any(f in required_features for f in ['depth', 'line']):
            self._init_controlnet()

        if 'depth' in required_features:
            self._init_depth_model()

        if 'line' in required_features:
            self._init_line_detector()

        if 'sam' in required_features:
            self._init_sam()

        if is_turbo:
            self._enable_turbo()

    def _init_base_models(self):
        """Initialize the core models that are always needed"""
        # Qwen2VL and connector initialization
        self.qwen2vl = Qwen2VLSimplifiedModel.from_pretrained(
            MODEL_PATHS['qwen2vl'],
            torch_dtype=self.dtype,
            **self.qkwargs
        )
        self.qwen2vl.requires_grad_(False).to(self.device)

        self.connector = Qwen2Connector(input_dim=3584, output_dim=4096)
        connector_path = os.path.join(MODEL_PATHS['qwen2vl'], "connector.pt")
        if os.path.exists(connector_path):
            connector_state_dict = torch.load(connector_path, map_location=self.device, weights_only=True)
            connector_state_dict = {k.replace('module.', ''): v for k, v in connector_state_dict.items()}
            self.connector.load_state_dict(connector_state_dict)
        self.connector.to(self.dtype).to(self.device)

        # Text encoders initialization
        self.tokenizer = CLIPTokenizer.from_pretrained(MODEL_PATHS['flux'], subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(MODEL_PATHS['flux'], subfolder="text_encoder")
        self.text_encoder_two = T5EncoderModel.from_pretrained(MODEL_PATHS['flux'], subfolder="text_encoder_2", **self.qkwargs)
        self.tokenizer_two = T5TokenizerFast.from_pretrained(MODEL_PATHS['flux'], subfolder="tokenizer_2")

        self.text_encoder.requires_grad_(False).to(self.dtype).to(self.device)
        #self.text_encoder_two.requires_grad_(False).to(self.dtype).to(self.device)
        self.text_encoder_two.requires_grad_(False).to(self.device)

        # T5 context embedder
        self.t5_context_embedder = nn.Linear(4096, 3072)
        t5_embedder_path = os.path.join(MODEL_PATHS['qwen2vl'], "t5_embedder.pt")
        t5_embedder_state_dict = torch.load(t5_embedder_path, map_location=self.device, weights_only=True)
        self.t5_context_embedder.load_state_dict(t5_embedder_state_dict)
        self.t5_context_embedder.to(self.dtype).to(self.device)

        # Basic components
        self.noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(MODEL_PATHS['flux'], subfolder="scheduler", shift=1)
        self.vae = AutoencoderKL.from_pretrained(MODEL_PATHS['flux'], subfolder="vae")
        self.transformer = FluxTransformer2DModel.from_pretrained(MODEL_PATHS['flux'], subfolder="transformer", **self.qkwargs)

        self.vae.requires_grad_(False).to(self.dtype).to(self.device)
        #self.transformer.requires_grad_(False).to(self.dtype).to(self.device)
        self.transformer.requires_grad_(False).to(self.device)

    def _init_controlnet(self):
        """Initialize ControlNet model"""
        self.controlnet_union = FluxControlNetModel.from_pretrained(
            MODEL_PATHS['controlnet'],
            torch_dtype=torch.bfloat16
        )
        self.controlnet_union.requires_grad_(False).to(self.device)
        self.controlnet = FluxMultiControlNetModel([self.controlnet_union])

    def _init_depth_model(self):
        """Initialize Depth Anything V2 model"""
        if not self._depth_model_imported:
            from depth_anything_v2.dpt import DepthAnythingV2
            self._depth_model_imported = True

        self.depth_model = DepthAnythingV2(
            encoder='vitl',
            features=256,
            out_channels=[256, 512, 1024, 1024]
        )
        depth_weights = os.path.join(MODEL_PATHS['depth_anything']['path'],
                                     MODEL_PATHS['depth_anything']['weights'])
        self.depth_model.load_state_dict(torch.load(depth_weights, map_location=self.device))
        self.depth_model.requires_grad_(False).to(self.device)

    def _init_line_detector(self):
        """Initialize line detection model"""
        if not self._line_detector_imported:
            from controlnet_aux import AnylineDetector
            self._line_detector_imported = True

        self.anyline = AnylineDetector.from_pretrained(
            MODEL_PATHS['anyline']['path'],
            filename=MODEL_PATHS['anyline']['weights']
        )
        self.anyline.to(self.device)

    def _init_sam(self):
        """Initialize SAM2 model"""
        if not self._sam_imported:
            from sam2.build_sam import build_sam2
            from sam2.sam2_image_predictor import SAM2ImagePredictor
            self._sam_imported = True

        sam2_checkpoint = os.path.join(MODEL_PATHS['sam2']['path'],
                                       MODEL_PATHS['sam2']['weights'])
        model_cfg = os.path.join(MODEL_PATHS['sam2']['path'],
                                 MODEL_PATHS['sam2']['config'])
        self.sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=self.device)
        self.sam2_predictor = SAM2ImagePredictor(self.sam2_model)

    def _enable_turbo(self):
        """Enable turbo mode for faster inference"""
        if not self._turbo_imported:
            from optimum.quanto import freeze, qfloat8, quantize
            self._turbo_imported = True

        quantize(
            self.transformer,
            weights=qfloat8,
            exclude=[
                "*.norm", "*.norm1", "*.norm2", "*.norm2_context",
                "proj_out", "x_embedder", "norm_out", "context_embedder",
            ],
        )
        freeze(self.transformer)

    def generate_mask(self, image, input_points, input_labels):
        """
        Generate a segmentation mask with SAM2

        Args:
            image: PIL Image or numpy array
            input_points: numpy array of shape (N, 2) containing point coordinates
            input_labels: numpy array of shape (N,); 1 marks a foreground point, 0 a background point

        Returns:
            PIL Image: the mask with the highest score
        """
        try:
            # Make sure the image is a numpy array
            if isinstance(image, Image.Image):
                image_array = np.array(image)
            else:
                image_array = image

            # Set the image on the predictor
            self.sam2_predictor.set_image(image_array)

            # Run prediction
            with torch.inference_mode():
                masks, scores, logits = self.sam2_predictor.predict(
                    point_coords=input_points,
                    point_labels=input_labels,
                    multimask_output=True,
                )

            # Return the mask with the highest score
            best_mask_idx = scores.argmax()
            mask = masks[best_mask_idx]
            mask_image = Image.fromarray((mask * 255).astype(np.uint8))
            return mask_image

        except Exception as e:
            print(f"Mask generation failed: {str(e)}")
            raise

    def recover_2d_shape(self, image_hidden_state, grid_thw):
        batch_size, num_tokens, hidden_dim = image_hidden_state.shape
        _, h, w = grid_thw
        h_out = h // 2
        w_out = w // 2
        # Reshape to (batch_size, height, width, hidden_dim)
        reshaped = image_hidden_state.view(batch_size, h_out, w_out, hidden_dim)
        return reshaped

    def generate_attention_matrix(self, center_x, center_y, radius, image_shape):
        height, width = image_shape
        y, x = np.ogrid[:height, :width]
        center_y, center_x = center_y * height, center_x * width
        distances = np.sqrt((x - center_x)**2 + (y - center_y)**2)
        attention = np.clip(1 - distances / (radius * min(height, width)), 0, 1)
        return attention

    def apply_attention(self, image_hidden_state, image_grid_thw, center_x, center_y, radius):
        qwen2_2d_image_embedding = self.recover_2d_shape(image_hidden_state, tuple(image_grid_thw.tolist()[0]))
        attention_matrix = self.generate_attention_matrix(
            center_x, center_y, radius,
            (qwen2_2d_image_embedding.size(1), qwen2_2d_image_embedding.size(2))
        )
        attention_tensor = torch.from_numpy(attention_matrix).to(self.dtype).unsqueeze(0).unsqueeze(-1)
        qwen2_2d_image_embedding = qwen2_2d_image_embedding * attention_tensor.to(self.device)
        return qwen2_2d_image_embedding.view(1, -1, qwen2_2d_image_embedding.size(3))

+
def compute_text_embeddings(self, prompt):
|
285 |
+
with torch.no_grad():
|
286 |
+
text_inputs = self.tokenizer(prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt")
|
287 |
+
text_input_ids = text_inputs.input_ids.to(self.device)
|
288 |
+
prompt_embeds = self.text_encoder(text_input_ids, output_hidden_states=False)
|
289 |
+
pooled_prompt_embeds = prompt_embeds.pooler_output
|
290 |
+
return pooled_prompt_embeds.to(self.dtype)
|
291 |
+
|
292 |
+
def compute_t5_text_embeddings(
|
293 |
+
self,
|
294 |
+
max_sequence_length=256,
|
295 |
+
prompt=None,
|
296 |
+
num_images_per_prompt=1,
|
297 |
+
device=None,
|
298 |
+
):
|
299 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
300 |
+
batch_size = len(prompt)
|
301 |
+
|
302 |
+
text_inputs = self.tokenizer_two(
|
303 |
+
prompt,
|
304 |
+
padding="max_length",
|
305 |
+
max_length=max_sequence_length,
|
306 |
+
truncation=True,
|
307 |
+
return_length=False,
|
308 |
+
return_overflowing_tokens=False,
|
309 |
+
return_tensors="pt",
|
310 |
+
)
|
311 |
+
text_input_ids = text_inputs.input_ids
|
312 |
+
prompt_embeds = self.text_encoder_two(text_input_ids.to(device))[0]
|
313 |
+
|
314 |
+
dtype = self.text_encoder_two.dtype
|
315 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
316 |
+
|
317 |
+
_, seq_len, _ = prompt_embeds.shape
|
318 |
+
|
319 |
+
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
320 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
321 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
322 |
+
|
323 |
+
return prompt_embeds
|
324 |
+
|
325 |
+
def process_image(self, image):
|
326 |
+
message = [
|
327 |
+
{
|
328 |
+
"role": "user",
|
329 |
+
"content": [
|
330 |
+
{"type": "image", "image": image},
|
331 |
+
{"type": "text", "text": "Describe this image."},
|
332 |
+
]
|
333 |
+
}
|
334 |
+
]
|
335 |
+
text = self.qwen2vl_processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
|
336 |
+
|
337 |
+
with torch.no_grad():
|
338 |
+
inputs = self.qwen2vl_processor(text=[text], images=[image], padding=True, return_tensors="pt").to(self.device)
|
339 |
+
output_hidden_state, image_token_mask, image_grid_thw = self.qwen2vl(**inputs)
|
340 |
+
image_hidden_state = output_hidden_state[image_token_mask].view(1, -1, output_hidden_state.size(-1))
|
341 |
+
|
342 |
+
return image_hidden_state, image_grid_thw
|
343 |
+
|
344 |
+
def resize_image(self, img, max_pixels=1050000):
|
345 |
+
# 确保输入是 PIL Image
|
346 |
+
if not isinstance(img, Image.Image):
|
347 |
+
img = Image.fromarray(img)
|
348 |
+
|
349 |
+
width, height = img.size
|
350 |
+
num_pixels = width * height
|
351 |
+
|
352 |
+
if num_pixels > max_pixels:
|
353 |
+
scale = math.sqrt(max_pixels / num_pixels)
|
354 |
+
new_width = int(width * scale)
|
355 |
+
new_height = int(height * scale)
|
356 |
+
# 调整宽度和高度,使其能被8整除
|
357 |
+
new_width = new_width - (new_width % 8)
|
358 |
+
new_height = new_height - (new_height % 8)
|
359 |
+
img = img.resize((new_width, new_height), Image.LANCZOS)
|
360 |
+
else:
|
361 |
+
# 如果图片不需要缩小,仍然需要确保尺寸能被8整除
|
362 |
+
new_width = width - (width % 8)
|
363 |
+
new_height = height - (height % 8)
|
364 |
+
if new_width != width or new_height != height:
|
365 |
+
img = img.resize((new_width, new_height), Image.LANCZOS)
|
366 |
+
|
367 |
+
return img
|
368 |
+
|
369 |
+
def generate_depth_map(self, image):
|
370 |
+
"""Generate depth map using Depth Anything V2"""
|
371 |
+
# Convert PIL to numpy array
|
372 |
+
image_np = np.array(image)
|
373 |
+
|
374 |
+
# Convert RGB to BGR for cv2
|
375 |
+
image_bgr = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
|
376 |
+
|
377 |
+
# Generate depth map
|
378 |
+
with torch.no_grad():
|
379 |
+
depth = self.depth_model.infer_image(image_bgr)
|
380 |
+
|
381 |
+
# Normalize depth to 0-1 range
|
382 |
+
depth_norm = (depth - depth.min()) / (depth.max() - depth.min())
|
383 |
+
|
384 |
+
# Convert to RGB image
|
385 |
+
depth_rgb = (depth_norm * 255).astype(np.uint8)
|
386 |
+
depth_rgb = cv2.cvtColor(depth_rgb, cv2.COLOR_GRAY2RGB)
|
387 |
+
|
388 |
+
return Image.fromarray(depth_rgb)
|
389 |
+
|
390 |
+
|
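    # Note (editorial sketch, not part of the uploaded file): resize_image() keeps both sides
    # divisible by 8 and caps the area at ~1.05 MP. For example, a 1920x1080 input (2.07 MP) is
    # scaled by sqrt(1050000 / 2073600) ~= 0.712 to 1366x768 and then rounded down to 1360x768,
    # while a 1000x700 input is already under the cap and only needs rounding, giving 1000x696.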
    def generate(self, input_image_a, input_image_b=None, prompt="", guidance_scale=3.5, num_inference_steps=28,
                 aspect_ratio="1:1", center_x=None, center_y=None, radius=None, mode="variation",
                 denoise_strength=0.8, mask_image=None, imageCount=2,
                 line_mode=True, depth_mode=True, line_strength=0.4, depth_strength=0.2):

        batch_size = imageCount
        if aspect_ratio not in ASPECT_RATIOS:
            raise ValueError(f"Invalid aspect ratio. Choose from {list(ASPECT_RATIOS.keys())}")

        width, height = ASPECT_RATIOS[aspect_ratio]

        pooled_prompt_embeds = self.compute_text_embeddings(prompt="")
        t5_prompt_embeds = None
        if prompt != "":
            self.qwen2vl_processor = AutoProcessor.from_pretrained(MODEL_PATHS['qwen2vl'], min_pixels=256*28*28, max_pixels=256*28*28)
            t5_prompt_embeds = self.compute_t5_text_embeddings(prompt=prompt, device=self.device).to(self.dtype)
            t5_prompt_embeds = self.t5_context_embedder(t5_prompt_embeds)
        else:
            self.qwen2vl_processor = AutoProcessor.from_pretrained(MODEL_PATHS['qwen2vl'], min_pixels=512*28*28, max_pixels=512*28*28)

        qwen2_hidden_state_a, image_grid_thw_a = self.process_image(input_image_a)
        # Apply the attention mechanism only when all attention parameters are provided
        if mode == "variation":
            if center_x is not None and center_y is not None and radius is not None:
                qwen2_hidden_state_a = self.apply_attention(qwen2_hidden_state_a, image_grid_thw_a, center_x, center_y, radius)
            qwen2_hidden_state_a = self.connector(qwen2_hidden_state_a)

        if mode == "img2img" or mode == "inpaint":
            if input_image_b:
                qwen2_hidden_state_b, image_grid_thw_b = self.process_image(input_image_b)
                if center_x is not None and center_y is not None and radius is not None:
                    qwen2_hidden_state_b = self.apply_attention(qwen2_hidden_state_b, image_grid_thw_b, center_x, center_y, radius)
                qwen2_hidden_state_b = self.connector(qwen2_hidden_state_b)
            else:
                qwen2_hidden_state_a = self.connector(qwen2_hidden_state_a)
                qwen2_hidden_state_b = None

        if mode == "controlnet" or mode == "controlnet-inpaint":
            qwen2_hidden_state_b = None
            if input_image_b:
                qwen2_hidden_state_b, image_grid_thw_b = self.process_image(input_image_b)
                if center_x is not None and center_y is not None and radius is not None:
                    qwen2_hidden_state_b = self.apply_attention(qwen2_hidden_state_b, image_grid_thw_b, center_x, center_y, radius)
                qwen2_hidden_state_b = self.connector(qwen2_hidden_state_b)
            qwen2_hidden_state_a = self.connector(qwen2_hidden_state_a)

        #############################
        # IMAGE GENERATION
        #############################
        if mode == "variation":
            # Initialize different pipelines
            pipeline = FluxPipeline(
                transformer=self.transformer,
                scheduler=self.noise_scheduler,
                vae=self.vae,
                text_encoder=self.text_encoder,
                tokenizer=self.tokenizer,
            )

            gen_images = pipeline(
                prompt_embeds=qwen2_hidden_state_a.repeat(batch_size, 1, 1),
                t5_prompt_embeds=t5_prompt_embeds.repeat(batch_size, 1, 1) if t5_prompt_embeds is not None else None,
                pooled_prompt_embeds=pooled_prompt_embeds,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                height=height,
                width=width,
            ).images


        #############################
        # IMAGE-TO-IMAGE
        #############################
        elif mode == "img2img":
            input_image_a = self.resize_image(input_image_a)
            width, height = input_image_a.size

            img2img_pipeline = FluxImg2ImgPipeline(
                transformer=self.transformer,
                scheduler=self.noise_scheduler,
                vae=self.vae,
                text_encoder=self.text_encoder,
                tokenizer=self.tokenizer,
            )

            gen_images = img2img_pipeline(
                image=input_image_a,
                strength=denoise_strength,
                prompt_embeds=qwen2_hidden_state_b.repeat(batch_size, 1, 1) if qwen2_hidden_state_b is not None else qwen2_hidden_state_a.repeat(batch_size, 1, 1),
                t5_prompt_embeds=t5_prompt_embeds.repeat(batch_size, 1, 1) if t5_prompt_embeds is not None else None,
                pooled_prompt_embeds=pooled_prompt_embeds,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                height=height,
                width=width,
            ).images


        #############################
        # INPAINTING
        #############################
        elif mode == "inpaint":
            if mask_image is None:
                raise ValueError("Mask image is required for inpainting mode")

            input_image_a = self.resize_image(input_image_a)
            mask_image = self.resize_image(mask_image)
            width, height = input_image_a.size

            inpaint_pipeline = FluxInpaintPipeline(
                transformer=self.transformer,
                scheduler=self.noise_scheduler,
                vae=self.vae,
                text_encoder=self.text_encoder,
                tokenizer=self.tokenizer,
            )

            gen_images = inpaint_pipeline(
                image=input_image_a,
                mask_image=mask_image,
                strength=denoise_strength,
                prompt_embeds=qwen2_hidden_state_b.repeat(batch_size, 1, 1) if qwen2_hidden_state_b is not None else qwen2_hidden_state_a.repeat(batch_size, 1, 1),
                t5_prompt_embeds=t5_prompt_embeds.repeat(batch_size, 1, 1) if t5_prompt_embeds is not None else None,
                pooled_prompt_embeds=pooled_prompt_embeds,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                height=height,
                width=width,
            ).images

        #############################
        # CONTROLNET
        #############################
        elif mode == "controlnet":
            input_image_a = self.resize_image(input_image_a)
            width, height = input_image_a.size

            controlnet_pipeline = FluxControlNetImg2ImgPipeline(
                transformer=self.transformer,
                scheduler=self.noise_scheduler,
                vae=self.vae,
                text_encoder=self.text_encoder,
                tokenizer=self.tokenizer,
                controlnet=self.controlnet,
            )

            # Prepare the lists of control images, modes and scales
            control_images = []
            control_modes = []
            conditioning_scales = []

            # Add control modes according to the user's selection
            if depth_mode:
                control_image_depth = self.generate_depth_map(input_image_a)
                control_images.append(control_image_depth)
                control_modes.append(2)  # depth mode
                conditioning_scales.append(depth_strength)

            if line_mode:
                control_image_canny = self.anyline(input_image_a, detect_resolution=1280)
                control_images.append(control_image_canny)
                control_modes.append(0)  # line mode
                conditioning_scales.append(line_strength)

            # If no control mode is enabled, fall back to the default line+depth combination
            if not line_mode and not depth_mode:
                control_image_depth = self.generate_depth_map(input_image_a)
                control_image_canny = self.anyline(input_image_a, detect_resolution=1280)
                control_images = [control_image_depth, control_image_canny]
                control_modes = [2, 0]
                conditioning_scales = [0.2, 0.4]

            if qwen2_hidden_state_b is not None:
                qwen2_hidden_state_b = qwen2_hidden_state_b[:, :qwen2_hidden_state_a.shape[1], :]
                qwen2_hidden_state_a = qwen2_hidden_state_a[:, :qwen2_hidden_state_b.shape[1], :]

            gen_images = controlnet_pipeline(
                image=input_image_a,
                strength=denoise_strength,
                control_image=control_images,
                control_mode=control_modes,
                controlnet_conditioning_scale=conditioning_scales,
                prompt_embeds=qwen2_hidden_state_b.repeat(batch_size, 1, 1) if qwen2_hidden_state_b is not None else qwen2_hidden_state_a.repeat(batch_size, 1, 1),
                t5_prompt_embeds=t5_prompt_embeds.repeat(batch_size, 1, 1) if t5_prompt_embeds is not None else None,
                prompt_embeds_control=qwen2_hidden_state_a.repeat(batch_size, 1, 1),
                pooled_prompt_embeds=pooled_prompt_embeds,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                height=height,
                width=width,
            ).images

        #############################
        # CONTROLNET INPAINT
        #############################
        elif mode == "controlnet-inpaint":
            input_image_a = self.resize_image(input_image_a)
            mask_image = self.resize_image(mask_image)
            width, height = input_image_a.size

            controlnet_pipeline = FluxControlNetInpaintPipeline(
                transformer=self.transformer,
                scheduler=self.noise_scheduler,
                vae=self.vae,
                text_encoder=self.text_encoder,
                tokenizer=self.tokenizer,
                controlnet=self.controlnet,
            )

            # Prepare the lists of control images, modes and scales
            control_images = []
            control_modes = []
            conditioning_scales = []

            # Add control modes according to the user's selection
            if depth_mode:
                control_image_depth = self.generate_depth_map(input_image_a)
                control_images.append(control_image_depth)
                control_modes.append(2)  # depth mode
                conditioning_scales.append(depth_strength)

            if line_mode:
                control_image_canny = self.anyline(input_image_a, detect_resolution=1280)
                control_images.append(control_image_canny)
                control_modes.append(0)  # line mode
                conditioning_scales.append(line_strength)

            # If no control mode is enabled, fall back to the default line+depth combination
            if not line_mode and not depth_mode:
                control_image_depth = self.generate_depth_map(input_image_a)
                control_image_canny = self.anyline(input_image_a, detect_resolution=1280)
                control_images = [control_image_depth, control_image_canny]
                control_modes = [2, 0]
                conditioning_scales = [0.2, 0.4]

            if qwen2_hidden_state_b is not None:
                qwen2_hidden_state_b = qwen2_hidden_state_b[:, :qwen2_hidden_state_a.shape[1], :]
                qwen2_hidden_state_a = qwen2_hidden_state_a[:, :qwen2_hidden_state_b.shape[1], :]

            gen_images = controlnet_pipeline(
                image=input_image_a,
                mask_image=mask_image,
                control_image=control_images,
                control_mode=control_modes,
                controlnet_conditioning_scale=conditioning_scales,
                strength=denoise_strength,
                prompt_embeds=qwen2_hidden_state_b.repeat(batch_size, 1, 1) if qwen2_hidden_state_b is not None else qwen2_hidden_state_a.repeat(batch_size, 1, 1),
                t5_prompt_embeds=t5_prompt_embeds.repeat(batch_size, 1, 1) if t5_prompt_embeds is not None else None,
                prompt_embeds_control=qwen2_hidden_state_a.repeat(batch_size, 1, 1),
                pooled_prompt_embeds=pooled_prompt_embeds,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                height=height,
                width=width,
            ).images

        else:
            raise ValueError(f"Invalid mode: {mode}")

        return gen_images
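
A minimal sketch of the simplest path through generate(): a "variation" call with an empty prompt and the optional attention focus. It is not part of the upload and assumes the checkpoints referenced in MODEL_PATHS are available; the file names are placeholders.

# Hypothetical usage sketch (not part of the uploaded file)
from PIL import Image
from modelmod import FluxModel

# Base features only; ControlNet / depth / line / SAM weights are not loaded for this mode.
model = FluxModel(is_turbo=False, required_features=None)

reference = Image.open("reference.jpg").convert("RGB")   # placeholder input

# Two 16:9 variations, steering attention toward the upper-left quarter of the reference.
variations = model.generate(
    input_image_a=reference,
    prompt="",                      # empty prompt: the T5 branch is skipped, only Qwen2-VL features are used
    mode="variation",
    aspect_ratio="16:9",
    center_x=0.25, center_y=0.25, radius=0.3,
    imageCount=2,
)
for i, im in enumerate(variations):
    im.save(f"variation_{i}.png")
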
qwen2_vl/configuration_qwen2_vl.py
ADDED
@@ -0,0 +1,206 @@
# coding=utf-8
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Qwen2VL model configuration"""

import os
from typing import Union

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging


logger = logging.get_logger(__name__)


class Qwen2VLVisionConfig(PretrainedConfig):
    model_type = "qwen2_vl"

    def __init__(
        self,
        depth=32,
        embed_dim=1280,
        hidden_size=3584,
        hidden_act="quick_gelu",
        mlp_ratio=4,
        num_heads=16,
        in_channels=3,
        patch_size=14,
        spatial_merge_size=2,
        temporal_patch_size=2,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.depth = depth
        self.embed_dim = embed_dim
        self.hidden_size = hidden_size
        self.hidden_act = hidden_act
        self.mlp_ratio = mlp_ratio
        self.num_heads = num_heads
        self.in_channels = in_channels
        self.patch_size = patch_size
        self.spatial_merge_size = spatial_merge_size
        self.temporal_patch_size = temporal_patch_size

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        cls._set_token_in_kwargs(kwargs)

        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        if config_dict.get("model_type") == "qwen2_vl":
            config_dict = config_dict["vision_config"]

        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, **kwargs)


class Qwen2VLConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Qwen2VLModel`]. It is used to instantiate a
    Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of
    Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 152064):
            Vocabulary size of the Qwen2VL model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`Qwen2VLModel`]
        hidden_size (`int`, *optional*, defaults to 8192):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 29568):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 80):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 64):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_key_value_heads (`int`, *optional*, defaults to 8):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 32768):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied.
        rope_theta (`float`, *optional*, defaults to 1000000.0):
            The base period of the RoPE embeddings.
        use_sliding_window (`bool`, *optional*, defaults to `False`):
            Whether to use sliding window attention.
        sliding_window (`int`, *optional*, defaults to 4096):
            Sliding window attention (SWA) window size. If not specified, will default to `4096`.
        max_window_layers (`int`, *optional*, defaults to 80):
            The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        vision_config (`Dict`, *optional*):
            The config for the visual encoder initialization.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
            strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
            `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
            `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
            these scaling strategies behave:
            https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
            experimental feature, subject to breaking API changes in future versions.

    ```python
    >>> from transformers import Qwen2VLForConditionalGeneration, Qwen2VLConfig

    >>> # Initializing a Qwen2VL style configuration
    >>> configuration = Qwen2VLConfig()

    >>> # Initializing a model from the Qwen2-VL-7B style configuration
    >>> model = Qwen2VLForConditionalGeneration(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "qwen2_vl"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=152064,
        hidden_size=8192,
        intermediate_size=29568,
        num_hidden_layers=80,
        num_attention_heads=64,
        num_key_value_heads=8,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-05,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=1000000.0,
        use_sliding_window=False,
        sliding_window=4096,
        max_window_layers=80,
        attention_dropout=0.0,
        vision_config=None,
        rope_scaling=None,
        **kwargs,
    ):
        if isinstance(vision_config, dict):
            self.vision_config = Qwen2VLVisionConfig(**vision_config)
        elif vision_config is None:
            self.vision_config = Qwen2VLVisionConfig()

        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.use_sliding_window = use_sliding_window
        self.sliding_window = sliding_window
        self.max_window_layers = max_window_layers

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_dropout = attention_dropout
        self.rope_scaling = rope_scaling

        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
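
A small sketch of how the two config classes compose: vision_config may be passed as a plain dict and is wrapped in Qwen2VLVisionConfig inside __init__. The values shown are the class defaults from the file above, not a tuned configuration.

# Hypothetical usage sketch (not part of the uploaded file)
from qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLVisionConfig

config = Qwen2VLConfig(
    vision_config={"depth": 32, "embed_dim": 1280, "hidden_size": 3584},  # wrapped automatically
)
print(type(config.vision_config))                                  # Qwen2VLVisionConfig
print(config.vision_config.patch_size, config.num_key_value_heads) # 14 8 (class defaults)
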
qwen2_vl/image_processing_qwen2_vl.py
ADDED
@@ -0,0 +1,458 @@
# coding=utf-8
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for Qwen2-VL."""

import math
from typing import Dict, List, Optional, Union

import numpy as np

from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.image_transforms import (
    convert_to_rgb,
    resize,
    to_channel_dimension_format,
)
from transformers.image_utils import (
    OPENAI_CLIP_MEAN,
    OPENAI_CLIP_STD,
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    VideoInput,
    get_image_size,
    infer_channel_dimension_format,
    is_scaled_image,
    is_valid_image,
    make_list_of_images,
    to_numpy_array,
    valid_images,
    validate_preprocess_arguments,
)
from transformers.utils import TensorType, is_vision_available, logging


logger = logging.get_logger(__name__)


if is_vision_available():
    from PIL import Image


def make_batched_images(images) -> List[List[ImageInput]]:
    """
    Accepts images in list or nested list format, and makes a list of images for preprocessing.

    Args:
        images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`):
            The input image.

    Returns:
        list: A list of images.
    """
    if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]):
        return [img for img_list in images for img in img_list]

    elif isinstance(images, (list, tuple)) and is_valid_image(images[0]):
        return images

    elif is_valid_image(images):
        return [images]

    raise ValueError(f"Could not make batched images from {images}")


# Copied from transformers.models.llava_next_video.image_processing_llava_next_video.make_batched_videos
def make_batched_videos(videos) -> List[VideoInput]:
    if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]):
        return videos

    elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]):
        if isinstance(videos[0], Image.Image):
            return [videos]
        elif len(videos[0].shape) == 4:
            return [list(video) for video in videos]

    elif is_valid_image(videos) and len(videos.shape) == 4:
        return [list(videos)]

    raise ValueError(f"Could not make batched video from {videos}")


def smart_resize(
    height: int, width: int, factor: int = 28, min_pixels: int = 56 * 56, max_pixels: int = 14 * 14 * 4 * 1280
):
    """Rescales the image so that the following conditions are met:

    1. Both dimensions (height and width) are divisible by 'factor'.

    2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].

    3. The aspect ratio of the image is maintained as closely as possible.

    """
    if height < factor or width < factor:
        raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}")
    elif max(height, width) / min(height, width) > 200:
        raise ValueError(
            f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
        )
    h_bar = round(height / factor) * factor
    w_bar = round(width / factor) * factor
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = math.floor(height / beta / factor) * factor
        w_bar = math.floor(width / beta / factor) * factor
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor
    return h_bar, w_bar

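# Note (editorial worked example, not part of the uploaded file): for a 1000x700 image with
# factor=28 (patch_size 14 * merge_size 2) and the default max_pixels = 14*14*4*1280 = 1,003,520,
# the first rounding step gives h_bar = round(1000/28)*28 = 1008 and w_bar = round(700/28)*28 = 700.
# Since 1008 * 700 = 705,600 lies between min_pixels (3,136) and max_pixels, the image is resized
# to 1008x700: both sides are multiples of 28 and the aspect ratio is almost unchanged.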
130 |
+
class Qwen2VLImageProcessor(BaseImageProcessor):
|
131 |
+
r"""
|
132 |
+
Constructs a Qwen2-VL image processor that dynamically resizes images based on the original images.
|
133 |
+
|
134 |
+
Args:
|
135 |
+
do_resize (`bool`, *optional*, defaults to `True`):
|
136 |
+
Whether to resize the image's (height, width) dimensions.
|
137 |
+
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
|
138 |
+
Resampling filter to use when resizing the image.
|
139 |
+
do_rescale (`bool`, *optional*, defaults to `True`):
|
140 |
+
Whether to rescale the image by the specified scale `rescale_factor`.
|
141 |
+
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
142 |
+
Scale factor to use if rescaling the image.
|
143 |
+
do_normalize (`bool`, *optional*, defaults to `True`):
|
144 |
+
Whether to normalize the image.
|
145 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
|
146 |
+
Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
|
147 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
|
148 |
+
Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image.
|
149 |
+
do_convert_rgb (`bool`, *optional*, defaults to `True`):
|
150 |
+
Whether to convert the image to RGB.
|
151 |
+
min_pixels (`int`, *optional*, defaults to `56 * 56`):
|
152 |
+
The min pixels of the image to resize the image.
|
153 |
+
max_pixels (`int`, *optional*, defaults to `28 * 28 * 1280`):
|
154 |
+
The max pixels of the image to resize the image.
|
155 |
+
patch_size (`int`, *optional*, defaults to 14):
|
156 |
+
The spacial patch size of the vision encoder.
|
157 |
+
temporal_patch_size (`int`, *optional*, defaults to 2):
|
158 |
+
The temporal patch size of the vision encoder.
|
159 |
+
merge_size (`int`, *optional*, defaults to 2):
|
160 |
+
The merge size of the vision encoder to llm encoder.
|
161 |
+
"""
|
162 |
+
|
163 |
+
model_input_names = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw"]
|
164 |
+
|
165 |
+
def __init__(
|
166 |
+
self,
|
167 |
+
do_resize: bool = True,
|
168 |
+
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
169 |
+
do_rescale: bool = True,
|
170 |
+
rescale_factor: Union[int, float] = 1 / 255,
|
171 |
+
do_normalize: bool = True,
|
172 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
173 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
174 |
+
do_convert_rgb: bool = True,
|
175 |
+
min_pixels: int = 56 * 56,
|
176 |
+
max_pixels: int = 28 * 28 * 1280,
|
177 |
+
patch_size: int = 14,
|
178 |
+
temporal_patch_size: int = 2,
|
179 |
+
merge_size: int = 2,
|
180 |
+
**kwargs,
|
181 |
+
) -> None:
|
182 |
+
super().__init__(**kwargs)
|
183 |
+
self.do_resize = do_resize
|
184 |
+
self.resample = resample
|
185 |
+
self.do_rescale = do_rescale
|
186 |
+
self.rescale_factor = rescale_factor
|
187 |
+
self.do_normalize = do_normalize
|
188 |
+
self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
|
189 |
+
self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
|
190 |
+
self.min_pixels = min_pixels
|
191 |
+
self.max_pixels = max_pixels
|
192 |
+
self.patch_size = patch_size
|
193 |
+
self.temporal_patch_size = temporal_patch_size
|
194 |
+
self.merge_size = merge_size
|
195 |
+
self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels}
|
196 |
+
self.do_convert_rgb = do_convert_rgb
|
197 |
+
|
198 |
+
def _preprocess(
|
199 |
+
self,
|
200 |
+
images: Union[ImageInput, VideoInput],
|
201 |
+
do_resize: bool = None,
|
202 |
+
resample: PILImageResampling = None,
|
203 |
+
do_rescale: bool = None,
|
204 |
+
rescale_factor: float = None,
|
205 |
+
do_normalize: bool = None,
|
206 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
207 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
208 |
+
do_convert_rgb: bool = None,
|
209 |
+
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
210 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
211 |
+
):
|
212 |
+
"""
|
213 |
+
Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.
|
214 |
+
|
215 |
+
Args:
|
216 |
+
images (`ImageInput`):
|
217 |
+
Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`.
|
218 |
+
vision_info (`List[Dict]`, *optional*):
|
219 |
+
Optional list of dictionaries containing additional information about vision inputs.
|
220 |
+
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
221 |
+
Whether to resize the image.
|
222 |
+
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
|
223 |
+
Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums.
|
224 |
+
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
225 |
+
Whether to rescale the image.
|
226 |
+
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
227 |
+
Scale factor to use if rescaling the image.
|
228 |
+
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
229 |
+
Whether to normalize the image.
|
230 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
231 |
+
Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
|
232 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
233 |
+
Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
|
234 |
+
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
235 |
+
Whether to convert the image to RGB.
|
236 |
+
data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
|
237 |
+
The channel dimension format for the output image. Can be one of:
|
238 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
239 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
240 |
+
- Unset: Use the channel dimension format of the input image.
|
241 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
242 |
+
The channel dimension format for the input image. Can be one of:
|
243 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
244 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
245 |
+
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
246 |
+
"""
|
247 |
+
images = make_list_of_images(images)
|
248 |
+
|
249 |
+
if do_convert_rgb:
|
250 |
+
images = [convert_to_rgb(image) for image in images]
|
251 |
+
|
252 |
+
# All transformations expect numpy arrays.
|
253 |
+
images = [to_numpy_array(image) for image in images]
|
254 |
+
|
255 |
+
if is_scaled_image(images[0]) and do_rescale:
|
256 |
+
logger.warning_once(
|
257 |
+
"It looks like you are trying to rescale already rescaled images. If the input"
|
258 |
+
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
259 |
+
)
|
260 |
+
if input_data_format is None:
|
261 |
+
# We assume that all images have the same channel dimension format.
|
262 |
+
input_data_format = infer_channel_dimension_format(images[0])
|
263 |
+
|
264 |
+
height, width = get_image_size(images[0], channel_dim=input_data_format)
|
265 |
+
resized_height, resized_width = height, width
|
266 |
+
processed_images = []
|
267 |
+
for image in images:
|
268 |
+
if do_resize:
|
269 |
+
resized_height, resized_width = smart_resize(
|
270 |
+
height,
|
271 |
+
width,
|
272 |
+
factor=self.patch_size * self.merge_size,
|
273 |
+
min_pixels=self.min_pixels,
|
274 |
+
max_pixels=self.max_pixels,
|
275 |
+
)
|
276 |
+
image = resize(
|
277 |
+
image, size=(resized_height, resized_width), resample=resample, input_data_format=input_data_format
|
278 |
+
)
|
279 |
+
|
280 |
+
if do_rescale:
|
281 |
+
image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format)
|
282 |
+
|
283 |
+
if do_normalize:
|
284 |
+
image = self.normalize(
|
285 |
+
image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
|
286 |
+
)
|
287 |
+
|
288 |
+
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
|
289 |
+
processed_images.append(image)
|
290 |
+
|
291 |
+
patches = np.array(processed_images)
|
292 |
+
if data_format == ChannelDimension.LAST:
|
293 |
+
patches = patches.transpose(0, 3, 1, 2)
|
294 |
+
if patches.shape[0] == 1:
|
295 |
+
patches = np.tile(patches, (self.temporal_patch_size, 1, 1, 1))
|
296 |
+
channel = patches.shape[1]
|
297 |
+
grid_t = patches.shape[0] // self.temporal_patch_size
|
298 |
+
grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
|
299 |
+
patches = patches.reshape(
|
300 |
+
grid_t,
|
301 |
+
self.temporal_patch_size,
|
302 |
+
channel,
|
303 |
+
grid_h // self.merge_size,
|
304 |
+
self.merge_size,
|
305 |
+
self.patch_size,
|
306 |
+
grid_w // self.merge_size,
|
307 |
+
self.merge_size,
|
308 |
+
self.patch_size,
|
309 |
+
)
|
310 |
+
patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8)
|
311 |
+
flatten_patches = patches.reshape(
|
312 |
+
grid_t * grid_h * grid_w, channel * self.temporal_patch_size * self.patch_size * self.patch_size
|
313 |
+
)
|
314 |
+
|
315 |
+
return flatten_patches, (grid_t, grid_h, grid_w)
|
316 |
+
|
317 |
+
def preprocess(
|
318 |
+
self,
|
319 |
+
images: ImageInput,
|
320 |
+
videos: VideoInput = None,
|
321 |
+
do_resize: bool = None,
|
322 |
+
size: Dict[str, int] = None,
|
323 |
+
resample: PILImageResampling = None,
|
324 |
+
do_rescale: bool = None,
|
325 |
+
rescale_factor: float = None,
|
326 |
+
do_normalize: bool = None,
|
327 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
328 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
329 |
+
do_convert_rgb: bool = None,
|
330 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
331 |
+
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
332 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
333 |
+
):
|
334 |
+
"""
|
335 |
+
Args:
|
336 |
+
images (`ImageInput`):
|
337 |
+
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
|
338 |
+
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
|
339 |
+
videos (`VideoInput`):
|
340 |
+
Video to preprocess. Expects a single or batch of videos with pixel values ranging from 0 to 255. If
|
341 |
+
passing in videos with pixel values between 0 and 1, set `do_rescale=False`.
|
342 |
+
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
343 |
+
Whether to resize the image.
|
344 |
+
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
345 |
+
Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
|
346 |
+
the longest edge resized to keep the input aspect ratio.
|
347 |
+
resample (`int`, *optional*, defaults to `self.resample`):
|
348 |
+
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
|
349 |
+
has an effect if `do_resize` is set to `True`.
|
350 |
+
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
351 |
+
Whether to rescale the image.
|
352 |
+
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
353 |
+
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
354 |
+
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
355 |
+
Whether to normalize the image.
|
356 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
357 |
+
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
|
358 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
359 |
+
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
|
360 |
+
`True`.
|
361 |
+
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
362 |
+
Whether to convert the image to RGB.
|
363 |
+
return_tensors (`str` or `TensorType`, *optional*):
|
364 |
+
The type of tensors to return. Can be one of:
|
365 |
+
- Unset: Return a list of `np.ndarray`.
|
366 |
+
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
367 |
+
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
368 |
+
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
369 |
+
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
370 |
+
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
371 |
+
The channel dimension format for the output image. Can be one of:
|
372 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
373 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
374 |
+
- Unset: Use the channel dimension format of the input image.
|
375 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
376 |
+
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
377 |
+
from the input image. Can be one of:
|
378 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
379 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
380 |
+
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
381 |
+
|
382 |
+
"""
|
383 |
+
do_resize = do_resize if do_resize is not None else self.do_resize
|
384 |
+
size = size if size is not None else self.size
|
385 |
+
resample = resample if resample is not None else self.resample
|
386 |
+
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
387 |
+
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
388 |
+
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
389 |
+
image_mean = image_mean if image_mean is not None else self.image_mean
|
390 |
+
image_std = image_std if image_std is not None else self.image_std
|
391 |
+
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
|
392 |
+
|
393 |
+
if images is not None:
|
394 |
+
images = make_batched_images(images)
|
395 |
+
if videos is not None:
|
396 |
+
videos = make_batched_videos(videos)
|
397 |
+
|
398 |
+
if images is not None and not valid_images(images):
|
399 |
+
raise ValueError(
|
400 |
+
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
401 |
+
"torch.Tensor, tf.Tensor or jax.ndarray."
|
402 |
+
)
|
403 |
+
|
404 |
+
validate_preprocess_arguments(
|
405 |
+
rescale_factor=rescale_factor,
|
406 |
+
do_normalize=do_normalize,
|
407 |
+
image_mean=image_mean,
|
408 |
+
image_std=image_std,
|
409 |
+
do_resize=do_resize,
|
410 |
+
size=size,
|
411 |
+
resample=resample,
|
412 |
+
)
|
413 |
+
|
414 |
+
if images is not None:
|
415 |
+
pixel_values, vision_grid_thws = [], []
|
416 |
+
for image in images:
|
417 |
+
patches, image_grid_thw = self._preprocess(
|
418 |
+
image,
|
419 |
+
do_resize=do_resize,
|
420 |
+
resample=resample,
|
421 |
+
do_rescale=do_rescale,
|
422 |
+
rescale_factor=rescale_factor,
|
423 |
+
do_normalize=do_normalize,
|
424 |
+
image_mean=image_mean,
|
425 |
+
image_std=image_std,
|
426 |
+
data_format=data_format,
|
427 |
+
do_convert_rgb=do_convert_rgb,
|
428 |
+
input_data_format=input_data_format,
|
429 |
+
)
|
430 |
+
pixel_values.extend(patches)
|
431 |
+
vision_grid_thws.append(image_grid_thw)
|
432 |
+
pixel_values = np.array(pixel_values)
|
433 |
+
vision_grid_thws = np.array(vision_grid_thws)
|
434 |
+
data = {"pixel_values": pixel_values, "image_grid_thw": vision_grid_thws}
|
435 |
+
|
436 |
+
if videos is not None:
|
437 |
+
pixel_values, vision_grid_thws = [], []
|
438 |
+
for images in videos:
|
439 |
+
patches, video_grid_thw = self._preprocess(
|
440 |
+
images,
|
441 |
+
do_resize=do_resize,
|
442 |
+
resample=resample,
|
443 |
+
do_rescale=do_rescale,
|
444 |
+
rescale_factor=rescale_factor,
|
445 |
+
do_normalize=do_normalize,
|
446 |
+
image_mean=image_mean,
|
447 |
+
image_std=image_std,
|
448 |
+
data_format=data_format,
|
449 |
+
do_convert_rgb=do_convert_rgb,
|
450 |
+
input_data_format=input_data_format,
|
451 |
+
)
|
452 |
+
pixel_values.extend(patches)
|
453 |
+
vision_grid_thws.append(video_grid_thw)
|
454 |
+
pixel_values = np.array(pixel_values)
|
455 |
+
vision_grid_thws = np.array(vision_grid_thws)
|
456 |
+
data = {"pixel_values_videos": pixel_values, "video_grid_thw": vision_grid_thws}
|
457 |
+
|
458 |
+
return BatchFeature(data=data, tensor_type=return_tensors)
|
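Illustration (not part of the uploaded file): a minimal sketch of the patch flattening that `_preprocess` performs above, assuming the default patch_size=14, merge_size=2 and temporal_patch_size=2. A single 112x112 RGB image is tiled along the temporal axis and flattened into one row per 14x14 patch.

import numpy as np

patch, merge, t_patch, channels = 14, 2, 2, 3        # assumed defaults
resized_h = resized_w = 112                           # already a multiple of patch * merge = 28
img = np.random.rand(1, channels, resized_h, resized_w)
img = np.tile(img, (t_patch, 1, 1, 1))                # a single image is repeated on the temporal axis
grid_t = img.shape[0] // t_patch
grid_h, grid_w = resized_h // patch, resized_w // patch
patches = img.reshape(grid_t, t_patch, channels, grid_h // merge, merge, patch,
                      grid_w // merge, merge, patch)
patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8)
flat = patches.reshape(grid_t * grid_h * grid_w, channels * t_patch * patch * patch)
print(flat.shape, (grid_t, grid_h, grid_w))           # (64, 1176) (1, 8, 8)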
qwen2_vl/modeling_qwen2_vl.py
ADDED
@@ -0,0 +1,1952 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
5 |
+
# and OPT implementations in this library. It has been modified from its
|
6 |
+
# original forms to accommodate minor architectural differences compared
|
7 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
8 |
+
#
|
9 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
10 |
+
# you may not use this file except in compliance with the License.
|
11 |
+
# You may obtain a copy of the License at
|
12 |
+
#
|
13 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
14 |
+
#
|
15 |
+
# Unless required by applicable law or agreed to in writing, software
|
16 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
17 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
18 |
+
# See the License for the specific language governing permissions and
|
19 |
+
# limitations under the License.
|
20 |
+
"""PyTorch Qwen2-VL model."""
|
21 |
+
|
22 |
+
import math
|
23 |
+
from dataclasses import dataclass
|
24 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
25 |
+
|
26 |
+
import torch
|
27 |
+
import torch.nn as nn
|
28 |
+
import torch.nn.functional as F
|
29 |
+
import torch.utils.checkpoint
|
30 |
+
from torch.nn import CrossEntropyLoss, LayerNorm
|
31 |
+
|
32 |
+
from transformers.activations import ACT2FN
|
33 |
+
from transformers.cache_utils import Cache, StaticCache
|
34 |
+
from transformers.modeling_attn_mask_utils import (
|
35 |
+
AttentionMaskConverter,
|
36 |
+
)
|
37 |
+
from transformers.modeling_outputs import (
|
38 |
+
BaseModelOutputWithPast,
|
39 |
+
ModelOutput,
|
40 |
+
)
|
41 |
+
from transformers.modeling_utils import PreTrainedModel
|
42 |
+
from transformers.utils import (
|
43 |
+
add_start_docstrings,
|
44 |
+
add_start_docstrings_to_model_forward,
|
45 |
+
is_flash_attn_2_available,
|
46 |
+
is_flash_attn_greater_or_equal_2_10,
|
47 |
+
logging,
|
48 |
+
replace_return_docstrings,
|
49 |
+
)
|
50 |
+
from qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLVisionConfig
|
51 |
+
|
52 |
+
import traceback
|
53 |
+
|
54 |
+
|
55 |
+
if is_flash_attn_2_available():
|
56 |
+
from flash_attn import flash_attn_varlen_func
|
57 |
+
|
58 |
+
from ...modeling_flash_attention_utils import _flash_attention_forward
|
59 |
+
else:
|
60 |
+
flash_attn_varlen_func = None
|
61 |
+
|
62 |
+
|
63 |
+
logger = logging.get_logger(__name__)
|
64 |
+
|
65 |
+
_CONFIG_FOR_DOC = "Qwen2VLConfig"
|
66 |
+
|
67 |
+
|
68 |
+
@dataclass
|
69 |
+
class Qwen2VLCausalLMOutputWithPast(ModelOutput):
|
70 |
+
"""
|
71 |
+
Base class for Qwen2VL causal language model (or autoregressive) outputs.
|
72 |
+
|
73 |
+
Args:
|
74 |
+
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
75 |
+
Language modeling loss (for next-token prediction).
|
76 |
+
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
|
77 |
+
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
78 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
79 |
+
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
80 |
+
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
|
81 |
+
|
82 |
+
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
|
83 |
+
`past_key_values` input) to speed up sequential decoding.
|
84 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
85 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
86 |
+
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
87 |
+
|
88 |
+
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
89 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
90 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
91 |
+
sequence_length)`.
|
92 |
+
|
93 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
94 |
+
heads.
|
95 |
+
rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
|
96 |
+
The rope index difference between sequence length and multimodal rope.
|
97 |
+
"""
|
98 |
+
|
99 |
+
loss: Optional[torch.FloatTensor] = None
|
100 |
+
logits: torch.FloatTensor = None
|
101 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None
|
102 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
103 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
104 |
+
rope_deltas: Optional[torch.LongTensor] = None
|
105 |
+
|
106 |
+
|
107 |
+
# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2RotaryEmbedding
|
108 |
+
class Qwen2RotaryEmbedding(nn.Module):
|
109 |
+
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
110 |
+
super().__init__()
|
111 |
+
|
112 |
+
self.dim = dim
|
113 |
+
self.max_position_embeddings = max_position_embeddings
|
114 |
+
self.base = base
|
115 |
+
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
|
116 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
117 |
+
|
118 |
+
# Build here to make `torch.jit.trace` work.
|
119 |
+
self._set_cos_sin_cache(
|
120 |
+
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
|
121 |
+
)
|
122 |
+
|
123 |
+
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
124 |
+
self.max_seq_len_cached = seq_len
|
125 |
+
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
|
126 |
+
|
127 |
+
freqs = torch.outer(t, self.inv_freq)
|
128 |
+
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
129 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
130 |
+
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
131 |
+
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
132 |
+
|
133 |
+
def forward(self, x, seq_len=None):
|
134 |
+
# x: [bs, num_attention_heads, seq_len, head_size]
|
135 |
+
if seq_len > self.max_seq_len_cached:
|
136 |
+
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
137 |
+
|
138 |
+
return (
|
139 |
+
self.cos_cached[:seq_len].to(dtype=x.dtype),
|
140 |
+
self.sin_cached[:seq_len].to(dtype=x.dtype),
|
141 |
+
)
|
142 |
+
|
143 |
+
|
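For reference, a toy recomputation (illustrative sizes, not from a real checkpoint) of the cos/sin tables built in `_set_cos_sin_cache` above:

import torch

dim, max_pos, base = 8, 6, 10000
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))  # (4,)
t = torch.arange(max_pos, dtype=torch.int64).type_as(inv_freq)
freqs = torch.outer(t, inv_freq)               # (6, 4)
emb = torch.cat((freqs, freqs), dim=-1)        # (6, 8): the cos_cached / sin_cached tables
print(emb.cos().shape, emb.sin().shape)        # torch.Size([6, 8]) torch.Size([6, 8])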
144 |
+
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
145 |
+
def rotate_half(x):
|
146 |
+
"""Rotates half the hidden dims of the input."""
|
147 |
+
x1 = x[..., : x.shape[-1] // 2]
|
148 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
149 |
+
return torch.cat((-x2, x1), dim=-1)
|
150 |
+
|
151 |
+
|
152 |
+
def apply_multimodal_rotary_pos_emb(q, k, cos, sin, position_ids, mrope_section, unsqueeze_dim=1):
|
153 |
+
"""Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/).
|
154 |
+
|
155 |
+
Explanation:
|
156 |
+
Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding
|
157 |
+
sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For
|
158 |
+
vision embedding part, we apply rotary position embedding on the temporal, height and width dimensions separately.
|
159 |
+
Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding.
|
160 |
+
For the text embedding part, we just apply 1D rotary position embedding. The three rotary position indices (temporal,
|
161 |
+
height and width) of a text token are always the same, so the text rotary position embedding is no
|
162 |
+
different from that of modern LLMs.
|
163 |
+
|
164 |
+
Args:
|
165 |
+
q (`torch.Tensor`): The query tensor.
|
166 |
+
k (`torch.Tensor`): The key tensor.
|
167 |
+
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
168 |
+
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
169 |
+
position_ids (`torch.Tensor`):
|
170 |
+
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
|
171 |
+
used to pass offset position ids when working with a KV-cache.
|
172 |
+
mrope_section(`List(int)`):
|
173 |
+
Multimodal rope section is for channel dimension of temporal, height and width in rope calculation.
|
174 |
+
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
175 |
+
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
176 |
+
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
177 |
+
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
178 |
+
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
179 |
+
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
180 |
+
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
181 |
+
Returns:
|
182 |
+
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
183 |
+
"""
|
184 |
+
cos = cos[position_ids]
|
185 |
+
sin = sin[position_ids]
|
186 |
+
mrope_section = mrope_section * 2
|
187 |
+
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
|
188 |
+
unsqueeze_dim
|
189 |
+
)
|
190 |
+
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
|
191 |
+
unsqueeze_dim
|
192 |
+
)
|
193 |
+
|
194 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
195 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
196 |
+
return q_embed, k_embed
|
197 |
+
|
198 |
+
|
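A shape-only sketch of the `mrope_section` interleaving above; the [16, 24, 24] split and the tensor sizes are illustrative, not taken from a real config:

import torch

head_dim = 128
mrope_section = [16, 24, 24]                   # temporal / height / width channel split (sums to head_dim // 2)
cos = torch.randn(3, 2, 7, head_dim)           # (t/h/w axis, batch, seq, head_dim), i.e. cos[position_ids]
chunks = cos.split(mrope_section * 2, dim=-1)  # doubled split to cover the duplicated (freqs, freqs) layout
merged = torch.cat([m[i % 3] for i, m in enumerate(chunks)], dim=-1)
print(merged.shape)                            # torch.Size([2, 7, 128]): ready to broadcast against q and k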
199 |
+
def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
|
200 |
+
orig_dtype = tensor.dtype
|
201 |
+
tensor = tensor.float()
|
202 |
+
cos = freqs.cos()
|
203 |
+
sin = freqs.sin()
|
204 |
+
cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
|
205 |
+
sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
|
206 |
+
output = (tensor * cos) + (rotate_half(tensor) * sin)
|
207 |
+
output = output.to(orig_dtype)
|
208 |
+
return output
|
209 |
+
|
210 |
+
|
211 |
+
class VisionRotaryEmbedding(nn.Module):
|
212 |
+
def __init__(self, dim: int, theta: float = 10000.0) -> None:
|
213 |
+
super().__init__()
|
214 |
+
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
|
215 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
216 |
+
|
217 |
+
def forward(self, seqlen: int) -> torch.Tensor:
|
218 |
+
seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
|
219 |
+
freqs = torch.outer(seq, self.inv_freq)
|
220 |
+
return freqs
|
221 |
+
|
222 |
+
|
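A simplified 1-D sketch of how the rotary table produced by `VisionRotaryEmbedding` is consumed by `apply_rotary_pos_emb_vision` above; in the real vision tower the per-axis dim is `head_dim // 2` and the height/width frequencies are concatenated, which is omitted here:

import torch

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

head_dim, seqlen, heads = 8, 5, 16
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2, dtype=torch.float) / head_dim))
freqs = torch.outer(torch.arange(seqlen, dtype=torch.float), inv_freq)   # (5, 4)
q = torch.randn(1, seqlen, heads, head_dim)                              # (batch, patches, heads, head_dim)
cos = freqs.cos().unsqueeze(1).repeat(1, 1, 2).unsqueeze(0)              # (1, 5, 1, 8)
sin = freqs.sin().unsqueeze(1).repeat(1, 1, 2).unsqueeze(0)
print(((q * cos) + (rotate_half(q) * sin)).shape)                        # torch.Size([1, 5, 16, 8])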
223 |
+
class PatchEmbed(nn.Module):
|
224 |
+
def __init__(
|
225 |
+
self,
|
226 |
+
patch_size: int = 14,
|
227 |
+
temporal_patch_size: int = 2,
|
228 |
+
in_channels: int = 3,
|
229 |
+
embed_dim: int = 1152,
|
230 |
+
) -> None:
|
231 |
+
super().__init__()
|
232 |
+
self.patch_size = patch_size
|
233 |
+
self.temporal_patch_size = temporal_patch_size
|
234 |
+
self.in_channels = in_channels
|
235 |
+
self.embed_dim = embed_dim
|
236 |
+
|
237 |
+
kernel_size = [temporal_patch_size, patch_size, patch_size]
|
238 |
+
self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False)
|
239 |
+
|
240 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
241 |
+
target_dtype = self.proj.weight.dtype
|
242 |
+
hidden_states = hidden_states.view(
|
243 |
+
-1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
|
244 |
+
)
|
245 |
+
hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
|
246 |
+
return hidden_states
|
247 |
+
|
248 |
+
|
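A quick shape check (illustrative sizes) for the non-overlapping Conv3d projection in `PatchEmbed` above, fed with flattened patches of the kind the image processor emits:

import torch
import torch.nn as nn

patch, t_patch, embed_dim = 14, 2, 1152
proj = nn.Conv3d(3, embed_dim, kernel_size=(t_patch, patch, patch),
                 stride=(t_patch, patch, patch), bias=False)
flat = torch.randn(64, 3 * t_patch * patch * patch)       # 64 flattened patches, 1176 values each
x = flat.view(-1, 3, t_patch, patch, patch)
print(proj(x).view(-1, embed_dim).shape)                   # torch.Size([64, 1152])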
249 |
+
class PatchMerger(nn.Module):
|
250 |
+
def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
|
251 |
+
super().__init__()
|
252 |
+
self.hidden_size = context_dim * (spatial_merge_size**2)
|
253 |
+
self.ln_q = LayerNorm(context_dim, eps=1e-6)
|
254 |
+
self.mlp = nn.Sequential(
|
255 |
+
nn.Linear(self.hidden_size, self.hidden_size),
|
256 |
+
nn.GELU(),
|
257 |
+
nn.Linear(self.hidden_size, dim),
|
258 |
+
)
|
259 |
+
|
260 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
261 |
+
x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
|
262 |
+
return x
|
263 |
+
|
264 |
+
|
265 |
+
class VisionMlp(nn.Module):
|
266 |
+
def __init__(self, dim: int, hidden_dim: int, hidden_act: str) -> None:
|
267 |
+
super().__init__()
|
268 |
+
self.fc1 = nn.Linear(dim, hidden_dim)
|
269 |
+
self.act = ACT2FN[hidden_act]
|
270 |
+
self.fc2 = nn.Linear(hidden_dim, dim)
|
271 |
+
|
272 |
+
def forward(self, x) -> torch.Tensor:
|
273 |
+
return self.fc2(self.act(self.fc1(x)))
|
274 |
+
|
275 |
+
|
276 |
+
class VisionAttention(nn.Module):
|
277 |
+
def __init__(self, dim: int, num_heads: int = 16) -> None:
|
278 |
+
super().__init__()
|
279 |
+
self.num_heads = num_heads
|
280 |
+
self.head_dim = dim // num_heads
|
281 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=True)
|
282 |
+
self.proj = nn.Linear(dim, dim)
|
283 |
+
|
284 |
+
def forward(
|
285 |
+
self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
|
286 |
+
) -> torch.Tensor:
|
287 |
+
seq_length = hidden_states.shape[0]
|
288 |
+
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
|
289 |
+
q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
290 |
+
k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
291 |
+
|
292 |
+
attention_mask = torch.full(
|
293 |
+
[1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
|
294 |
+
)
|
295 |
+
for i in range(1, len(cu_seqlens)):
|
296 |
+
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
|
297 |
+
|
298 |
+
q = q.transpose(0, 1)
|
299 |
+
k = k.transpose(0, 1)
|
300 |
+
v = v.transpose(0, 1)
|
301 |
+
attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
|
302 |
+
attn_weights = attn_weights + attention_mask
|
303 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
|
304 |
+
attn_output = torch.matmul(attn_weights, v)
|
305 |
+
attn_output = attn_output.transpose(0, 1)
|
306 |
+
attn_output = attn_output.reshape(seq_length, -1)
|
307 |
+
attn_output = self.proj(attn_output)
|
308 |
+
return attn_output
|
309 |
+
|
310 |
+
|
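The eager path above builds a block-diagonal mask from `cu_seqlens` so patches only attend within their own image; a tiny sketch with made-up boundaries:

import torch

cu_seqlens = torch.tensor([0, 3, 7])          # two images contributing 3 and 4 patches
seq_len = int(cu_seqlens[-1])
allowed = torch.zeros(seq_len, seq_len, dtype=torch.bool)
for i in range(1, len(cu_seqlens)):
    s, e = cu_seqlens[i - 1], cu_seqlens[i]
    allowed[s:e, s:e] = True                  # attention only inside each image's block
print(allowed.int())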
311 |
+
class VisionFlashAttention2(nn.Module):
|
312 |
+
def __init__(self, dim: int, num_heads: int = 16) -> None:
|
313 |
+
super().__init__()
|
314 |
+
self.num_heads = num_heads
|
315 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=True)
|
316 |
+
self.proj = nn.Linear(dim, dim)
|
317 |
+
|
318 |
+
def forward(
|
319 |
+
self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
|
320 |
+
) -> torch.Tensor:
|
321 |
+
seq_length = hidden_states.shape[0]
|
322 |
+
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
|
323 |
+
q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
324 |
+
k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
325 |
+
|
326 |
+
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
327 |
+
attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
|
328 |
+
seq_length, -1
|
329 |
+
)
|
330 |
+
attn_output = self.proj(attn_output)
|
331 |
+
return attn_output
|
332 |
+
|
333 |
+
|
334 |
+
class VisionSdpaAttention(nn.Module):
|
335 |
+
def __init__(self, dim: int, num_heads: int = 16) -> None:
|
336 |
+
super().__init__()
|
337 |
+
self.num_heads = num_heads
|
338 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=True)
|
339 |
+
self.proj = nn.Linear(dim, dim)
|
340 |
+
|
341 |
+
def forward(
|
342 |
+
self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
|
343 |
+
) -> torch.Tensor:
|
344 |
+
seq_length = hidden_states.shape[0]
|
345 |
+
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
|
346 |
+
q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
347 |
+
k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
348 |
+
|
349 |
+
attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
|
350 |
+
for i in range(1, len(cu_seqlens)):
|
351 |
+
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
|
352 |
+
q = q.transpose(0, 1)
|
353 |
+
k = k.transpose(0, 1)
|
354 |
+
v = v.transpose(0, 1)
|
355 |
+
attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
|
356 |
+
attn_output = attn_output.transpose(0, 1)
|
357 |
+
attn_output = attn_output.reshape(seq_length, -1)
|
358 |
+
attn_output = self.proj(attn_output)
|
359 |
+
return attn_output
|
360 |
+
|
361 |
+
|
362 |
+
QWEN2_VL_VISION_ATTENTION_CLASSES = {
|
363 |
+
"eager": VisionAttention,
|
364 |
+
"flash_attention_2": VisionFlashAttention2,
|
365 |
+
"sdpa": VisionSdpaAttention,
|
366 |
+
}
|
367 |
+
|
368 |
+
|
369 |
+
class Qwen2VLVisionBlock(nn.Module):
|
370 |
+
def __init__(self, config, attn_implementation: str = "sdpa") -> None:
|
371 |
+
super().__init__()
|
372 |
+
self.norm1 = LayerNorm(config.embed_dim, eps=1e-6)
|
373 |
+
self.norm2 = LayerNorm(config.embed_dim, eps=1e-6)
|
374 |
+
mlp_hidden_dim = int(config.embed_dim * config.mlp_ratio)
|
375 |
+
|
376 |
+
self.attn = QWEN2_VL_VISION_ATTENTION_CLASSES[attn_implementation](
|
377 |
+
config.embed_dim, num_heads=config.num_heads
|
378 |
+
)
|
379 |
+
self.mlp = VisionMlp(dim=config.embed_dim, hidden_dim=mlp_hidden_dim, hidden_act=config.hidden_act)
|
380 |
+
|
381 |
+
def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor:
|
382 |
+
hidden_states = hidden_states + self.attn(
|
383 |
+
self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
|
384 |
+
)
|
385 |
+
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
|
386 |
+
return hidden_states
|
387 |
+
|
388 |
+
|
389 |
+
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
|
390 |
+
def _prepare_4d_causal_attention_mask_with_cache_position(
|
391 |
+
attention_mask: torch.Tensor,
|
392 |
+
sequence_length: int,
|
393 |
+
target_length: int,
|
394 |
+
dtype: torch.dtype,
|
395 |
+
device: torch.device,
|
396 |
+
min_dtype: float,
|
397 |
+
cache_position: torch.Tensor,
|
398 |
+
batch_size: int,
|
399 |
+
):
|
400 |
+
"""
|
401 |
+
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
402 |
+
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
403 |
+
|
404 |
+
Args:
|
405 |
+
attention_mask (`torch.Tensor`):
|
406 |
+
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
|
407 |
+
sequence_length (`int`):
|
408 |
+
The sequence length being processed.
|
409 |
+
target_length (`int`):
|
410 |
+
The target length: when generating with a static cache, the mask should be as long as the static cache, to account for the 0 padding (the part of the cache that is not filled yet).
|
411 |
+
dtype (`torch.dtype`):
|
412 |
+
The dtype to use for the 4D attention mask.
|
413 |
+
device (`torch.device`):
|
414 |
+
The device to place the 4D attention mask on.
|
415 |
+
min_dtype (`float`):
|
416 |
+
The minimum value representable with the dtype `dtype`.
|
417 |
+
cache_position (`torch.Tensor`):
|
418 |
+
Indices depicting the position of the input sequence tokens in the sequence.
|
419 |
+
batch_size (`torch.Tensor`):
|
420 |
+
Batch size.
|
421 |
+
"""
|
422 |
+
if attention_mask is not None and attention_mask.dim() == 4:
|
423 |
+
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
424 |
+
causal_mask = attention_mask
|
425 |
+
else:
|
426 |
+
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
|
427 |
+
if sequence_length != 1:
|
428 |
+
causal_mask = torch.triu(causal_mask, diagonal=1)
|
429 |
+
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
430 |
+
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
431 |
+
if attention_mask is not None:
|
432 |
+
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
433 |
+
mask_length = attention_mask.shape[-1]
|
434 |
+
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
435 |
+
padding_mask = padding_mask == 0
|
436 |
+
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
437 |
+
padding_mask, min_dtype
|
438 |
+
)
|
439 |
+
|
440 |
+
return causal_mask
|
441 |
+
|
442 |
+
|
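A small worked example (toy sizes) of the triu + `cache_position` logic above: with a static cache of length 6 and a 3-token query starting at position 2, each row may attend to every cache slot up to its own position:

import torch

seq_len, target_len = 3, 6
min_dtype = torch.finfo(torch.float32).min
cache_position = torch.arange(2, 2 + seq_len)               # query tokens sit at positions 2, 3, 4
mask = torch.full((seq_len, target_len), fill_value=min_dtype)
mask = torch.triu(mask, diagonal=1)
mask *= torch.arange(target_len) > cache_position.reshape(-1, 1)
print((mask == 0).int())                                     # 1 marks attendable key positions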
443 |
+
# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2RMSNorm
|
444 |
+
class Qwen2RMSNorm(nn.Module):
|
445 |
+
def __init__(self, hidden_size, eps=1e-6):
|
446 |
+
"""
|
447 |
+
Qwen2RMSNorm is equivalent to T5LayerNorm
|
448 |
+
"""
|
449 |
+
super().__init__()
|
450 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
451 |
+
self.variance_epsilon = eps
|
452 |
+
|
453 |
+
def forward(self, hidden_states):
|
454 |
+
input_dtype = hidden_states.dtype
|
455 |
+
hidden_states = hidden_states.to(torch.float32)
|
456 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
457 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
458 |
+
return self.weight * hidden_states.to(input_dtype)
|
459 |
+
|
460 |
+
def extra_repr(self):
|
461 |
+
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
462 |
+
|
463 |
+
|
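The normalization above in a few lines, for reference (random inputs, unit weight):

import torch

x = torch.randn(2, 5, 8)
weight, eps = torch.ones(8), 1e-6
y = weight * (x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps))
print(y.shape, y.pow(2).mean(-1).sqrt()[0, 0].item())   # same shape as x; per-token RMS is ~1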
464 |
+
# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2MLP
|
465 |
+
class Qwen2MLP(nn.Module):
|
466 |
+
def __init__(self, config):
|
467 |
+
super().__init__()
|
468 |
+
self.hidden_size = config.hidden_size
|
469 |
+
self.intermediate_size = config.intermediate_size
|
470 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
471 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
472 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
473 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
474 |
+
|
475 |
+
def forward(self, hidden_state):
|
476 |
+
return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
|
477 |
+
|
478 |
+
|
479 |
+
# Copied from transformers.models.llama.modeling_llama.repeat_kv
|
480 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
481 |
+
"""
|
482 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
483 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
484 |
+
"""
|
485 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
486 |
+
if n_rep == 1:
|
487 |
+
return hidden_states
|
488 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
489 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
490 |
+
|
491 |
+
|
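A quick shape illustration (made-up head counts) of the grouped-query expansion performed by `repeat_kv` above:

import torch

k = torch.randn(1, 4, 10, 64)                             # (batch, kv_heads, seq, head_dim)
n_rep = 7                                                 # 28 query heads / 4 kv heads
k_rep = k[:, :, None, :, :].expand(1, 4, n_rep, 10, 64)
print(k_rep.reshape(1, 4 * n_rep, 10, 64).shape)          # torch.Size([1, 28, 10, 64])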
492 |
+
class Qwen2VLAttention(nn.Module):
|
493 |
+
"""
|
494 |
+
Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
|
495 |
+
and "Generating Long Sequences with Sparse Transformers".
|
496 |
+
"""
|
497 |
+
|
498 |
+
def __init__(self, config: Qwen2VLConfig, layer_idx: Optional[int] = None):
|
499 |
+
super().__init__()
|
500 |
+
self.config = config
|
501 |
+
self.layer_idx = layer_idx
|
502 |
+
if layer_idx is None:
|
503 |
+
logger.warning_once(
|
504 |
+
f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
|
505 |
+
"lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
|
506 |
+
"when creating this class."
|
507 |
+
)
|
508 |
+
|
509 |
+
self.hidden_size = config.hidden_size
|
510 |
+
self.num_heads = config.num_attention_heads
|
511 |
+
self.head_dim = self.hidden_size // self.num_heads
|
512 |
+
self.num_key_value_heads = config.num_key_value_heads
|
513 |
+
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
514 |
+
self.max_position_embeddings = config.max_position_embeddings
|
515 |
+
self.rope_theta = config.rope_theta
|
516 |
+
self.is_causal = True
|
517 |
+
self.attention_dropout = config.attention_dropout
|
518 |
+
self.rope_scaling = config.rope_scaling
|
519 |
+
|
520 |
+
if (self.head_dim * self.num_heads) != self.hidden_size:
|
521 |
+
raise ValueError(
|
522 |
+
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
523 |
+
f" and `num_heads`: {self.num_heads})."
|
524 |
+
)
|
525 |
+
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
|
526 |
+
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
|
527 |
+
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
|
528 |
+
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
529 |
+
|
530 |
+
self.rotary_emb = Qwen2RotaryEmbedding(
|
531 |
+
self.head_dim,
|
532 |
+
max_position_embeddings=self.max_position_embeddings,
|
533 |
+
base=self.rope_theta,
|
534 |
+
)
|
535 |
+
|
536 |
+
def forward(
|
537 |
+
self,
|
538 |
+
hidden_states: torch.Tensor,
|
539 |
+
attention_mask: Optional[torch.Tensor] = None,
|
540 |
+
position_ids: Optional[torch.LongTensor] = None,
|
541 |
+
past_key_value: Optional[Cache] = None,
|
542 |
+
output_attentions: bool = False,
|
543 |
+
use_cache: bool = False,
|
544 |
+
cache_position: Optional[torch.LongTensor] = None,
|
545 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
546 |
+
bsz, q_len, _ = hidden_states.size()
|
547 |
+
|
548 |
+
query_states = self.q_proj(hidden_states)
|
549 |
+
key_states = self.k_proj(hidden_states)
|
550 |
+
value_states = self.v_proj(hidden_states)
|
551 |
+
|
552 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
553 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
554 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
555 |
+
|
556 |
+
kv_seq_len = key_states.shape[-2]
|
557 |
+
if past_key_value is not None:
|
558 |
+
if self.layer_idx is None:
|
559 |
+
raise ValueError(
|
560 |
+
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
561 |
+
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
562 |
+
"with a layer index."
|
563 |
+
)
|
564 |
+
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
565 |
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
566 |
+
query_states, key_states = apply_multimodal_rotary_pos_emb(
|
567 |
+
query_states, key_states, cos, sin, position_ids, self.rope_scaling["mrope_section"]
|
568 |
+
)
|
569 |
+
|
570 |
+
if past_key_value is not None:
|
571 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
572 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
573 |
+
|
574 |
+
# repeat k/v heads if n_kv_heads < n_heads
|
575 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
576 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
577 |
+
|
578 |
+
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
579 |
+
|
580 |
+
if attention_mask is not None: # no matter the length, we just slice it
|
581 |
+
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
582 |
+
attn_weights = attn_weights + causal_mask
|
583 |
+
|
584 |
+
# upcast attention to fp32
|
585 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
586 |
+
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
587 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
588 |
+
|
589 |
+
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
590 |
+
raise ValueError(
|
591 |
+
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
592 |
+
f" {attn_output.size()}"
|
593 |
+
)
|
594 |
+
|
595 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
596 |
+
attn_output = attn_output.reshape(bsz, q_len, -1)
|
597 |
+
|
598 |
+
attn_output = self.o_proj(attn_output)
|
599 |
+
|
600 |
+
if not output_attentions:
|
601 |
+
attn_weights = None
|
602 |
+
|
603 |
+
return attn_output, attn_weights, past_key_value
|
604 |
+
|
605 |
+
|
606 |
+
class Qwen2VLFlashAttention2(Qwen2VLAttention):
|
607 |
+
"""
|
608 |
+
Qwen2VL flash attention module, following Qwen2VL attention module. This module inherits from `Qwen2VLAttention`
|
609 |
+
as the weights of the module stay untouched. The only required change would be on the forward pass
|
610 |
+
where it needs to correctly call the public API of flash attention and deal with padding tokens
|
611 |
+
in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom
|
612 |
+
config.max_window_layers layers.
|
613 |
+
"""
|
614 |
+
|
615 |
+
def __init__(self, *args, **kwargs):
|
616 |
+
super().__init__(*args, **kwargs)
|
617 |
+
|
618 |
+
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
619 |
+
# flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
620 |
+
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
621 |
+
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
622 |
+
|
623 |
+
def forward(
|
624 |
+
self,
|
625 |
+
hidden_states: torch.Tensor,
|
626 |
+
attention_mask: Optional[torch.Tensor] = None,
|
627 |
+
position_ids: Optional[torch.LongTensor] = None,
|
628 |
+
past_key_value: Optional[Cache] = None,
|
629 |
+
output_attentions: bool = False,
|
630 |
+
use_cache: bool = False,
|
631 |
+
cache_position: Optional[torch.LongTensor] = None,
|
632 |
+
):
|
633 |
+
bsz, q_len, _ = hidden_states.size()
|
634 |
+
|
635 |
+
query_states = self.q_proj(hidden_states)
|
636 |
+
key_states = self.k_proj(hidden_states)
|
637 |
+
value_states = self.v_proj(hidden_states)
|
638 |
+
|
639 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
640 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
641 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
642 |
+
|
643 |
+
kv_seq_len = key_states.shape[-2]
|
644 |
+
if past_key_value is not None:
|
645 |
+
if self.layer_idx is None:
|
646 |
+
raise ValueError(
|
647 |
+
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
648 |
+
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
649 |
+
"with a layer index."
|
650 |
+
)
|
651 |
+
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
652 |
+
|
653 |
+
# Because the input can be padded, the absolute sequence length depends on the max position id.
|
654 |
+
rotary_seq_len = (
|
655 |
+
max(kv_seq_len, position_ids[:, -1].max().item() + 1) if position_ids is not None else kv_seq_len
|
656 |
+
)
|
657 |
+
|
658 |
+
cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
|
659 |
+
|
660 |
+
query_states, key_states = apply_multimodal_rotary_pos_emb(
|
661 |
+
query_states, key_states, cos, sin, position_ids, self.rope_scaling["mrope_section"]
|
662 |
+
)
|
663 |
+
|
664 |
+
if past_key_value is not None:
|
665 |
+
# Activate cache slicing only if the config has a `sliding_window` attribute set
|
666 |
+
cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
|
667 |
+
if (
|
668 |
+
getattr(self.config, "sliding_window", None) is not None
|
669 |
+
and kv_seq_len > self.config.sliding_window
|
670 |
+
and cache_has_contents
|
671 |
+
):
|
672 |
+
slicing_tokens = 1 - self.config.sliding_window
|
673 |
+
|
674 |
+
past_key = past_key_value[self.layer_idx][0]
|
675 |
+
past_value = past_key_value[self.layer_idx][1]
|
676 |
+
|
677 |
+
past_key = past_key[:, :, slicing_tokens:, :].contiguous()
|
678 |
+
past_value = past_value[:, :, slicing_tokens:, :].contiguous()
|
679 |
+
|
680 |
+
if past_key.shape[-2] != self.config.sliding_window - 1:
|
681 |
+
raise ValueError(
|
682 |
+
f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
|
683 |
+
f" {past_key.shape}"
|
684 |
+
)
|
685 |
+
|
686 |
+
if attention_mask is not None:
|
687 |
+
attention_mask = attention_mask[:, slicing_tokens:]
|
688 |
+
attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
|
689 |
+
|
690 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
691 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
692 |
+
|
693 |
+
# repeat k/v heads if n_kv_heads < n_heads
|
694 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
695 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
696 |
+
dropout_rate = 0.0 if not self.training else self.attention_dropout
|
697 |
+
|
698 |
+
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
699 |
+
# therefore the input hidden states get silently cast to float32. Hence, we need to
|
700 |
+
# cast them back to float16 just to be sure everything works as expected.
|
701 |
+
input_dtype = query_states.dtype
|
702 |
+
if input_dtype == torch.float32:
|
703 |
+
if torch.is_autocast_enabled():
|
704 |
+
target_dtype = torch.get_autocast_gpu_dtype()
|
705 |
+
# Handle the case where the model is quantized
|
706 |
+
elif hasattr(self.config, "_pre_quantization_dtype"):
|
707 |
+
target_dtype = self.config._pre_quantization_dtype
|
708 |
+
else:
|
709 |
+
target_dtype = self.q_proj.weight.dtype
|
710 |
+
|
711 |
+
logger.warning_once(
|
712 |
+
f"The input hidden states seem to have been silently cast to float32; this might be related to"
|
713 |
+
f" the fact that you have upcast embedding or layer norm layers to float32. We will cast the input back to"
|
714 |
+
f" {target_dtype}."
|
715 |
+
)
|
716 |
+
|
717 |
+
query_states = query_states.to(target_dtype)
|
718 |
+
key_states = key_states.to(target_dtype)
|
719 |
+
value_states = value_states.to(target_dtype)
|
720 |
+
|
721 |
+
# Reshape to the expected shape for Flash Attention
|
722 |
+
query_states = query_states.transpose(1, 2)
|
723 |
+
key_states = key_states.transpose(1, 2)
|
724 |
+
value_states = value_states.transpose(1, 2)
|
725 |
+
|
726 |
+
if (
|
727 |
+
self.config.use_sliding_window
|
728 |
+
and getattr(self.config, "sliding_window", None) is not None
|
729 |
+
and self.layer_idx >= self.config.max_window_layers
|
730 |
+
):
|
731 |
+
sliding_window = self.config.sliding_window
|
732 |
+
else:
|
733 |
+
sliding_window = None
|
734 |
+
|
735 |
+
attn_output = _flash_attention_forward(
|
736 |
+
query_states,
|
737 |
+
key_states,
|
738 |
+
value_states,
|
739 |
+
attention_mask,
|
740 |
+
q_len,
|
741 |
+
dropout=dropout_rate,
|
742 |
+
sliding_window=sliding_window,
|
743 |
+
is_causal=self.is_causal,
|
744 |
+
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
745 |
+
)
|
746 |
+
|
747 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
748 |
+
attn_output = self.o_proj(attn_output)
|
749 |
+
|
750 |
+
if not output_attentions:
|
751 |
+
attn_weights = None
|
752 |
+
|
753 |
+
return attn_output, attn_weights, past_key_value
|
754 |
+
|
755 |
+
|
756 |
+
class Qwen2VLSdpaAttention(Qwen2VLAttention):
|
757 |
+
"""
|
758 |
+
Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
|
759 |
+
`Qwen2Attention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
|
760 |
+
SDPA API.
|
761 |
+
"""
|
762 |
+
|
763 |
+
# Adapted from Qwen2Attention.forward
|
764 |
+
def forward(
|
765 |
+
self,
|
766 |
+
hidden_states: torch.Tensor,
|
767 |
+
attention_mask: Optional[torch.Tensor] = None,
|
768 |
+
position_ids: Optional[torch.LongTensor] = None,
|
769 |
+
past_key_value: Optional[Cache] = None,
|
770 |
+
output_attentions: bool = False,
|
771 |
+
use_cache: bool = False,
|
772 |
+
cache_position: Optional[torch.LongTensor] = None,
|
773 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
774 |
+
if output_attentions:
|
775 |
+
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
776 |
+
logger.warning_once(
|
777 |
+
"Qwen2VLModel is using Qwen2VLSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
|
778 |
+
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
779 |
+
)
|
780 |
+
return super().forward(
|
781 |
+
hidden_states=hidden_states,
|
782 |
+
attention_mask=attention_mask,
|
783 |
+
position_ids=position_ids,
|
784 |
+
past_key_value=past_key_value,
|
785 |
+
output_attentions=output_attentions,
|
786 |
+
use_cache=use_cache,
|
787 |
+
)
|
788 |
+
|
789 |
+
bsz, q_len, _ = hidden_states.size()
|
790 |
+
|
791 |
+
query_states = self.q_proj(hidden_states)
|
792 |
+
key_states = self.k_proj(hidden_states)
|
793 |
+
value_states = self.v_proj(hidden_states)
|
794 |
+
|
795 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
796 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
797 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
798 |
+
|
799 |
+
kv_seq_len = key_states.shape[-2]
|
800 |
+
if past_key_value is not None:
|
801 |
+
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
802 |
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
803 |
+
query_states, key_states = apply_multimodal_rotary_pos_emb(
|
804 |
+
query_states, key_states, cos, sin, position_ids, self.rope_scaling["mrope_section"]
|
805 |
+
)
|
806 |
+
|
807 |
+
if past_key_value is not None:
|
808 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
809 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
810 |
+
|
811 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
812 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
813 |
+
|
814 |
+
causal_mask = attention_mask
|
815 |
+
if attention_mask is not None: # no matter the length, we just slice it
|
816 |
+
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
817 |
+
|
818 |
+
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
819 |
+
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
820 |
+
if query_states.device.type == "cuda" and attention_mask is not None:
|
821 |
+
query_states = query_states.contiguous()
|
822 |
+
key_states = key_states.contiguous()
|
823 |
+
value_states = value_states.contiguous()
|
824 |
+
|
825 |
+
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
826 |
+
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
827 |
+
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
|
828 |
+
is_causal = True if causal_mask is None and q_len > 1 else False
|
829 |
+
|
830 |
+
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
831 |
+
query_states,
|
832 |
+
key_states,
|
833 |
+
value_states,
|
834 |
+
attn_mask=causal_mask,
|
835 |
+
dropout_p=self.attention_dropout if self.training else 0.0,
|
836 |
+
is_causal=is_causal,
|
837 |
+
)
|
838 |
+
|
839 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
840 |
+
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
|
841 |
+
|
842 |
+
attn_output = self.o_proj(attn_output)
|
843 |
+
|
844 |
+
return attn_output, None, past_key_value
|
845 |
+
|
846 |
+
|
847 |
+
QWEN2_VL_ATTENTION_CLASSES = {
|
848 |
+
"eager": Qwen2VLAttention,
|
849 |
+
"flash_attention_2": Qwen2VLFlashAttention2,
|
850 |
+
"sdpa": Qwen2VLSdpaAttention,
|
851 |
+
}
|
852 |
+
|
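# Minimal usage sketch (names of the config values are assumptions for illustration): the
# attention backend is picked by indexing this dict with `config._attn_implementation`,
# which `from_pretrained(..., attn_implementation="sdpa")` ultimately sets, e.g.
#   >>> attn_cls = QWEN2_VL_ATTENTION_CLASSES["sdpa"]   # -> Qwen2VLSdpaAttention
#   >>> layer_attn = attn_cls(config, layer_idx=0)      # as done in Qwen2VLDecoderLayer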
853 |
+
|
854 |
+
class Qwen2VLDecoderLayer(nn.Module):
|
855 |
+
def __init__(self, config: Qwen2VLConfig, layer_idx: int):
|
856 |
+
super().__init__()
|
857 |
+
self.hidden_size = config.hidden_size
|
858 |
+
|
859 |
+
if config.use_sliding_window and config._attn_implementation != "flash_attention_2":
|
860 |
+
logger.warning_once(
|
861 |
+
f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
|
862 |
+
"unexpected results may be encountered."
|
863 |
+
)
|
864 |
+
self.self_attn = QWEN2_VL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
|
865 |
+
|
866 |
+
self.mlp = Qwen2MLP(config)
|
867 |
+
self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
868 |
+
self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
869 |
+
|
870 |
+
def forward(
|
871 |
+
self,
|
872 |
+
hidden_states: torch.Tensor,
|
873 |
+
attention_mask: Optional[torch.Tensor] = None,
|
874 |
+
position_ids: Optional[torch.LongTensor] = None,
|
875 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
876 |
+
output_attentions: Optional[bool] = False,
|
877 |
+
use_cache: Optional[bool] = False,
|
878 |
+
cache_position: Optional[torch.LongTensor] = None,
|
879 |
+
**kwargs,
|
880 |
+
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
881 |
+
"""
|
882 |
+
Args:
|
883 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
884 |
+
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
885 |
+
`(batch, sequence_length)` where padding elements are indicated by 0.
|
886 |
+
output_attentions (`bool`, *optional*):
|
887 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
888 |
+
returned tensors for more detail.
|
889 |
+
use_cache (`bool`, *optional*):
|
890 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
891 |
+
(see `past_key_values`).
|
892 |
+
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
893 |
+
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
894 |
+
Indices depicting the position of the input sequence tokens in the sequence.
|
895 |
+
kwargs (`dict`, *optional*):
|
896 |
+
Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
|
897 |
+
into the model
|
898 |
+
"""
|
899 |
+
|
900 |
+
residual = hidden_states
|
901 |
+
|
902 |
+
hidden_states = self.input_layernorm(hidden_states)
|
903 |
+
|
904 |
+
# Self Attention
|
905 |
+
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
906 |
+
hidden_states=hidden_states,
|
907 |
+
attention_mask=attention_mask,
|
908 |
+
position_ids=position_ids,
|
909 |
+
past_key_value=past_key_value,
|
910 |
+
output_attentions=output_attentions,
|
911 |
+
use_cache=use_cache,
|
912 |
+
cache_position=cache_position,
|
913 |
+
)
|
914 |
+
hidden_states = residual + hidden_states
|
915 |
+
|
916 |
+
# Fully Connected
|
917 |
+
residual = hidden_states
|
918 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
919 |
+
hidden_states = self.mlp(hidden_states)
|
920 |
+
hidden_states = residual + hidden_states
|
921 |
+
|
922 |
+
outputs = (hidden_states,)
|
923 |
+
|
924 |
+
if output_attentions:
|
925 |
+
outputs += (self_attn_weights,)
|
926 |
+
|
927 |
+
if use_cache:
|
928 |
+
outputs += (present_key_value,)
|
929 |
+
|
930 |
+
return outputs
|
931 |
+
|
932 |
+
|
933 |
+
QWEN2VL_START_DOCSTRING = r"""
|
934 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
935 |
+
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
|
936 |
+
etc.)
|
937 |
+
|
938 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
939 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
940 |
+
and behavior.
|
941 |
+
|
942 |
+
Parameters:
|
943 |
+
config ([`Qwen2VLConfig`]):
|
944 |
+
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
945 |
+
load the weights associated with the model, only the configuration. Check out the
|
946 |
+
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
947 |
+
"""
|
948 |
+
|
949 |
+
|
950 |
+
@add_start_docstrings(
|
951 |
+
"The bare Qwen2VL Model outputting raw hidden-states without any specific head on top.",
|
952 |
+
QWEN2VL_START_DOCSTRING,
|
953 |
+
)
|
954 |
+
class Qwen2VLPreTrainedModel(PreTrainedModel):
|
955 |
+
config_class = Qwen2VLConfig
|
956 |
+
base_model_prefix = "model"
|
957 |
+
supports_gradient_checkpointing = True
|
958 |
+
_no_split_modules = ["Qwen2VLDecoderLayer", "Qwen2VLVisionBlock"]
|
959 |
+
_skip_keys_device_placement = "past_key_values"
|
960 |
+
_supports_flash_attn_2 = True
|
961 |
+
_supports_sdpa = True
|
962 |
+
_supports_cache_class = True
|
963 |
+
_supports_static_cache = True
|
964 |
+
|
965 |
+
def _init_weights(self, module):
|
966 |
+
std = self.config.initializer_range
|
967 |
+
if isinstance(module, (nn.Linear, nn.Conv3d)):
|
968 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
969 |
+
if module.bias is not None:
|
970 |
+
module.bias.data.zero_()
|
971 |
+
elif isinstance(module, nn.Embedding):
|
972 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
973 |
+
if module.padding_idx is not None:
|
974 |
+
module.weight.data[module.padding_idx].zero_()
|
975 |
+
|
976 |
+
|
977 |
+
class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
|
978 |
+
config_class = Qwen2VLVisionConfig
|
979 |
+
_no_split_modules = ["Qwen2VLVisionBlock"]
|
980 |
+
|
981 |
+
def __init__(self, config) -> None:
|
982 |
+
super().__init__(config)
|
983 |
+
self.spatial_merge_size = config.spatial_merge_size
|
984 |
+
|
985 |
+
self.patch_embed = PatchEmbed(
|
986 |
+
patch_size=config.patch_size,
|
987 |
+
temporal_patch_size=config.temporal_patch_size,
|
988 |
+
in_channels=config.in_channels,
|
989 |
+
embed_dim=config.embed_dim,
|
990 |
+
)
|
991 |
+
|
992 |
+
head_dim = config.embed_dim // config.num_heads
|
993 |
+
self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
|
994 |
+
|
995 |
+
self.blocks = nn.ModuleList(
|
996 |
+
[Qwen2VLVisionBlock(config, config._attn_implementation) for _ in range(config.depth)]
|
997 |
+
)
|
998 |
+
self.merger = PatchMerger(dim=config.hidden_size, context_dim=config.embed_dim)
|
999 |
+
|
1000 |
+
def get_dtype(self) -> torch.dtype:
|
1001 |
+
return self.blocks[0].mlp.fc2.weight.dtype
|
1002 |
+
|
1003 |
+
def get_device(self) -> torch.device:
|
1004 |
+
return self.blocks[0].mlp.fc2.weight.device
|
1005 |
+
|
1006 |
+
def rot_pos_emb(self, grid_thw):
|
1007 |
+
pos_ids = []
|
1008 |
+
for t, h, w in grid_thw:
|
1009 |
+
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
|
1010 |
+
hpos_ids = hpos_ids.reshape(
|
1011 |
+
h // self.spatial_merge_size,
|
1012 |
+
self.spatial_merge_size,
|
1013 |
+
w // self.spatial_merge_size,
|
1014 |
+
self.spatial_merge_size,
|
1015 |
+
)
|
1016 |
+
hpos_ids = hpos_ids.permute(0, 2, 1, 3)
|
1017 |
+
hpos_ids = hpos_ids.flatten()
|
1018 |
+
|
1019 |
+
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
|
1020 |
+
wpos_ids = wpos_ids.reshape(
|
1021 |
+
h // self.spatial_merge_size,
|
1022 |
+
self.spatial_merge_size,
|
1023 |
+
w // self.spatial_merge_size,
|
1024 |
+
self.spatial_merge_size,
|
1025 |
+
)
|
1026 |
+
wpos_ids = wpos_ids.permute(0, 2, 1, 3)
|
1027 |
+
wpos_ids = wpos_ids.flatten()
|
1028 |
+
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
|
1029 |
+
pos_ids = torch.cat(pos_ids, dim=0)
|
1030 |
+
max_grid_size = grid_thw[:, 1:].max()
|
1031 |
+
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
|
1032 |
+
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
|
1033 |
+
return rotary_pos_emb
|
1034 |
+
|
1035 |
+
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
|
1036 |
+
hidden_states = self.patch_embed(hidden_states)
|
1037 |
+
rotary_pos_emb = self.rot_pos_emb(grid_thw)
|
1038 |
+
|
1039 |
+
cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
|
1040 |
+
dim=0, dtype=torch.int32
|
1041 |
+
)
|
1042 |
+
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
1043 |
+
|
1044 |
+
for blk in self.blocks:
|
1045 |
+
hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
|
1046 |
+
|
1047 |
+
return self.merger(hidden_states)
|
1048 |
+
|
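# Worked example for the forward pass above (grid values assumed for illustration): with
# grid_thw = [[1, 4, 6], [2, 4, 6]] each frame contributes h * w = 24 patch tokens, so
#   >>> counts = torch.repeat_interleave(torch.tensor([24, 24]), torch.tensor([1, 2]))
#   tensor([24, 24, 24])
#   >>> F.pad(counts.cumsum(0, dtype=torch.int32), (1, 0), value=0)
#   tensor([ 0, 24, 48, 72], dtype=torch.int32)
# i.e. cumulative per-frame boundaries that bound the attention of each vision block.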
1049 |
+
|
1050 |
+
@add_start_docstrings(
|
1051 |
+
"The bare Qwen2VL Model outputting raw hidden-states without any specific head on top.",
|
1052 |
+
QWEN2VL_START_DOCSTRING,
|
1053 |
+
)
|
1054 |
+
class Qwen2VLModel(Qwen2VLPreTrainedModel):
|
1055 |
+
def __init__(self, config: Qwen2VLConfig):
|
1056 |
+
super().__init__(config)
|
1057 |
+
self.padding_idx = config.pad_token_id
|
1058 |
+
self.vocab_size = config.vocab_size
|
1059 |
+
|
1060 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
1061 |
+
self.layers = nn.ModuleList(
|
1062 |
+
[Qwen2VLDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
1063 |
+
)
|
1064 |
+
self._attn_implementation = config._attn_implementation
|
1065 |
+
self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
1066 |
+
|
1067 |
+
self.gradient_checkpointing = False
|
1068 |
+
# Initialize weights and apply final processing
|
1069 |
+
self.post_init()
|
1070 |
+
|
1071 |
+
def get_input_embeddings(self):
|
1072 |
+
return self.embed_tokens
|
1073 |
+
|
1074 |
+
def set_input_embeddings(self, value):
|
1075 |
+
self.embed_tokens = value
|
1076 |
+
|
1077 |
+
def forward(
|
1078 |
+
self,
|
1079 |
+
input_ids: torch.LongTensor = None,
|
1080 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1081 |
+
position_ids: Optional[torch.LongTensor] = None,
|
1082 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
1083 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1084 |
+
use_cache: Optional[bool] = None,
|
1085 |
+
output_attentions: Optional[bool] = None,
|
1086 |
+
output_hidden_states: Optional[bool] = None,
|
1087 |
+
return_dict: Optional[bool] = None,
|
1088 |
+
cache_position: Optional[torch.LongTensor] = None,
|
1089 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
1090 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
1091 |
+
output_hidden_states = (
|
1092 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
1093 |
+
)
|
1094 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
1095 |
+
|
1096 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1097 |
+
|
1098 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
1099 |
+
raise ValueError(
|
1100 |
+
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
|
1101 |
+
)
|
1102 |
+
|
1103 |
+
if self.gradient_checkpointing and self.training:
|
1104 |
+
if use_cache:
|
1105 |
+
logger.warning_once(
|
1106 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
1107 |
+
)
|
1108 |
+
use_cache = False
|
1109 |
+
|
1110 |
+
if inputs_embeds is None:
|
1111 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
1112 |
+
|
1113 |
+
if cache_position is None:
|
1114 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
1115 |
+
cache_position = torch.arange(
|
1116 |
+
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
1117 |
+
)
|
1118 |
+
if position_ids is None:
|
1119 |
+
# the hard-coded `3` is for the temporal, height and width dimensions.
|
1120 |
+
position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
|
1121 |
+
|
1122 |
+
causal_mask = self._update_causal_mask(
|
1123 |
+
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
1124 |
+
)
|
1125 |
+
|
1126 |
+
hidden_states = inputs_embeds
|
1127 |
+
|
1128 |
+
# decoder layers
|
1129 |
+
all_hidden_states = () if output_hidden_states else None
|
1130 |
+
all_self_attns = () if output_attentions else None
|
1131 |
+
next_decoder_cache = None
|
1132 |
+
|
1133 |
+
for decoder_layer in self.layers:
|
1134 |
+
if output_hidden_states:
|
1135 |
+
all_hidden_states += (hidden_states,)
|
1136 |
+
|
1137 |
+
if self.gradient_checkpointing and self.training:
|
1138 |
+
layer_outputs = self._gradient_checkpointing_func(
|
1139 |
+
decoder_layer.__call__,
|
1140 |
+
hidden_states,
|
1141 |
+
causal_mask,
|
1142 |
+
position_ids,
|
1143 |
+
past_key_values,
|
1144 |
+
output_attentions,
|
1145 |
+
use_cache,
|
1146 |
+
cache_position,
|
1147 |
+
)
|
1148 |
+
else:
|
1149 |
+
layer_outputs = decoder_layer(
|
1150 |
+
hidden_states,
|
1151 |
+
attention_mask=causal_mask,
|
1152 |
+
position_ids=position_ids,
|
1153 |
+
past_key_value=past_key_values,
|
1154 |
+
output_attentions=output_attentions,
|
1155 |
+
use_cache=use_cache,
|
1156 |
+
cache_position=cache_position,
|
1157 |
+
)
|
1158 |
+
|
1159 |
+
hidden_states = layer_outputs[0]
|
1160 |
+
|
1161 |
+
if use_cache:
|
1162 |
+
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
1163 |
+
|
1164 |
+
if output_attentions:
|
1165 |
+
all_self_attns += (layer_outputs[1],)
|
1166 |
+
|
1167 |
+
hidden_states = self.norm(hidden_states)
|
1168 |
+
|
1169 |
+
# add hidden states from the last decoder layer
|
1170 |
+
if output_hidden_states:
|
1171 |
+
all_hidden_states += (hidden_states,)
|
1172 |
+
|
1173 |
+
next_cache = next_decoder_cache if use_cache else None
|
1174 |
+
|
1175 |
+
if not return_dict:
|
1176 |
+
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
1177 |
+
return BaseModelOutputWithPast(
|
1178 |
+
last_hidden_state=hidden_states,
|
1179 |
+
past_key_values=next_cache,
|
1180 |
+
hidden_states=all_hidden_states,
|
1181 |
+
attentions=all_self_attns,
|
1182 |
+
)
|
1183 |
+
|
1184 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
|
1185 |
+
def _update_causal_mask(
|
1186 |
+
self,
|
1187 |
+
attention_mask: torch.Tensor,
|
1188 |
+
input_tensor: torch.Tensor,
|
1189 |
+
cache_position: torch.Tensor,
|
1190 |
+
past_key_values: Cache,
|
1191 |
+
output_attentions: bool,
|
1192 |
+
):
|
1193 |
+
if self.config._attn_implementation == "flash_attention_2":
|
1194 |
+
if attention_mask is not None and 0.0 in attention_mask:
|
1195 |
+
return attention_mask
|
1196 |
+
return None
|
1197 |
+
|
1198 |
+
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
1199 |
+
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
1200 |
+
# to infer the attention mask.
|
1201 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
1202 |
+
using_static_cache = isinstance(past_key_values, StaticCache)
|
1203 |
+
|
1204 |
+
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
1205 |
+
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
|
1206 |
+
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
1207 |
+
attention_mask,
|
1208 |
+
inputs_embeds=input_tensor,
|
1209 |
+
past_key_values_length=past_seen_tokens,
|
1210 |
+
is_training=self.training,
|
1211 |
+
):
|
1212 |
+
return None
|
1213 |
+
|
1214 |
+
dtype, device = input_tensor.dtype, input_tensor.device
|
1215 |
+
min_dtype = torch.finfo(dtype).min
|
1216 |
+
sequence_length = input_tensor.shape[1]
|
1217 |
+
if using_static_cache:
|
1218 |
+
target_length = past_key_values.get_max_length()
|
1219 |
+
else:
|
1220 |
+
target_length = (
|
1221 |
+
attention_mask.shape[-1]
|
1222 |
+
if isinstance(attention_mask, torch.Tensor)
|
1223 |
+
else past_seen_tokens + sequence_length + 1
|
1224 |
+
)
|
1225 |
+
|
1226 |
+
# In case the provided `attention_mask` is 2D, we generate a causal mask here (4D).
|
1227 |
+
causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
|
1228 |
+
attention_mask,
|
1229 |
+
sequence_length=sequence_length,
|
1230 |
+
target_length=target_length,
|
1231 |
+
dtype=dtype,
|
1232 |
+
device=device,
|
1233 |
+
min_dtype=min_dtype,
|
1234 |
+
cache_position=cache_position,
|
1235 |
+
batch_size=input_tensor.shape[0],
|
1236 |
+
)
|
1237 |
+
|
1238 |
+
if (
|
1239 |
+
self.config._attn_implementation == "sdpa"
|
1240 |
+
and attention_mask is not None
|
1241 |
+
and attention_mask.device.type == "cuda"
|
1242 |
+
and not output_attentions
|
1243 |
+
):
|
1244 |
+
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
1245 |
+
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
1246 |
+
# Details: https://github.com/pytorch/pytorch/issues/110213
|
1247 |
+
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
1248 |
+
|
1249 |
+
return causal_mask
|
1250 |
+
|
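# Shape sketch for the mask built above (dimensions assumed): a 2D padding mask of shape
# (batch, kv_len) with 1 = keep / 0 = pad is expanded into a 4D additive mask of shape
# (batch, 1, query_len, kv_len), where positions that may not be attended to hold
# torch.finfo(dtype).min and allowed positions hold 0, which is what the SDPA and eager
# attention paths consume as `attn_mask` / `attention_mask`.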
1251 |
+
|
1252 |
+
QWEN2_VL_INPUTS_DOCSTRING = r"""
|
1253 |
+
Args:
|
1254 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
1255 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
1256 |
+
it.
|
1257 |
+
|
1258 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
1259 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
1260 |
+
|
1261 |
+
[What are input IDs?](../glossary#input-ids)
|
1262 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
1263 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
1264 |
+
|
1265 |
+
- 1 for tokens that are **not masked**,
|
1266 |
+
- 0 for tokens that are **masked**.
|
1267 |
+
|
1268 |
+
[What are attention masks?](../glossary#attention-mask)
|
1269 |
+
|
1270 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
1271 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
1272 |
+
|
1273 |
+
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
|
1274 |
+
`past_key_values`).
|
1275 |
+
|
1276 |
+
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
1277 |
+
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
1278 |
+
information on the default strategy.
|
1279 |
+
|
1280 |
+
- 1 indicates the head is **not masked**,
|
1281 |
+
- 0 indicates the head is **masked**.
|
1282 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
1283 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
1284 |
+
config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
|
1285 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
1286 |
+
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
1287 |
+
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
1288 |
+
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
1289 |
+
|
1290 |
+
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
1291 |
+
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
1292 |
+
|
1293 |
+
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
1294 |
+
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
1295 |
+
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
1296 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
1297 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
1298 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
1299 |
+
model's internal embedding lookup matrix.
|
1300 |
+
use_cache (`bool`, *optional*):
|
1301 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
1302 |
+
`past_key_values`).
|
1303 |
+
output_attentions (`bool`, *optional*):
|
1304 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
1305 |
+
tensors for more detail.
|
1306 |
+
output_hidden_states (`bool`, *optional*):
|
1307 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
1308 |
+
more detail.
|
1309 |
+
return_dict (`bool`, *optional*):
|
1310 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
1311 |
+
pixel_values (`torch.FloatTensor` of shape `(seq_length, num_channels * image_size * image_size)`):
|
1312 |
+
The tensors corresponding to the input images. Pixel values can be obtained using
|
1313 |
+
[`AutoImageProcessor`]. See [`Qwen2VLImageProcessor.__call__`] for details. [`Qwen2VLProcessor`] uses
|
1314 |
+
[`Qwen2VLImageProcessor`] for processing images.
|
1315 |
+
pixel_values_videos (`torch.FloatTensor` of shape `(seq_length, num_channels * temporal_size * image_size * image_size)`):
|
1316 |
+
The tensors corresponding to the input videos. Pixel values can be obtained using
|
1317 |
+
[`AutoImageProcessor`]. See [`Qwen2VLImageProcessor.__call__`] for details. [`Qwen2VLProcessor`] uses
|
1318 |
+
[`Qwen2VLImageProcessor`] for processing videos.
|
1319 |
+
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
|
1320 |
+
The temporal, height and width of feature shape of each image in LLM.
|
1321 |
+
video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
|
1322 |
+
The temporal, height and width of feature shape of each video in LLM.
|
1323 |
+
rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
|
1324 |
+
The rope index difference between sequence length and multimodal rope.
|
1325 |
+
"""
|
1326 |
+
|
1327 |
+
|
1328 |
+
class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel):
|
1329 |
+
_tied_weights_keys = ["lm_head.weight"]
|
1330 |
+
|
1331 |
+
def __init__(self, config):
|
1332 |
+
super().__init__(config)
|
1333 |
+
self.visual = Qwen2VisionTransformerPretrainedModel._from_config(
|
1334 |
+
config.vision_config, attn_implementation=config._attn_implementation
|
1335 |
+
)
|
1336 |
+
self.model = Qwen2VLModel(config)
|
1337 |
+
self.vocab_size = config.vocab_size
|
1338 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
1339 |
+
self.padding_side = "left" # set it to left by default, user can use setter to change padding_sides
|
1340 |
+
|
1341 |
+
# Initialize weights and apply final processing
|
1342 |
+
self.post_init()
|
1343 |
+
|
1344 |
+
def get_input_embeddings(self):
|
1345 |
+
return self.model.embed_tokens
|
1346 |
+
|
1347 |
+
def set_input_embeddings(self, value):
|
1348 |
+
self.model.embed_tokens = value
|
1349 |
+
|
1350 |
+
def get_output_embeddings(self):
|
1351 |
+
return self.lm_head
|
1352 |
+
|
1353 |
+
def set_output_embeddings(self, new_embeddings):
|
1354 |
+
self.lm_head = new_embeddings
|
1355 |
+
|
1356 |
+
def set_decoder(self, decoder):
|
1357 |
+
self.model = decoder
|
1358 |
+
|
1359 |
+
def get_decoder(self):
|
1360 |
+
return self.model
|
1361 |
+
|
1362 |
+
def get_rope_index(
|
1363 |
+
self,
|
1364 |
+
input_ids: torch.LongTensor,
|
1365 |
+
image_grid_thw: Optional[torch.LongTensor] = None,
|
1366 |
+
video_grid_thw: Optional[torch.LongTensor] = None,
|
1367 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1368 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
1369 |
+
"""
|
1370 |
+
Calculate the 3D rope index based on image and video's temporal, height and width in LLM.
|
1371 |
+
|
1372 |
+
Explanation:
|
1373 |
+
Each embedding sequence contains vision embeddings and text embeddings, or just text embeddings.
|
1374 |
+
|
1375 |
+
For a pure text embedding sequence, the rotary position embedding is no different from modern LLMs.
|
1376 |
+
Examples:
|
1377 |
+
input_ids: [T T T T T], here T is for text.
|
1378 |
+
temporal position_ids: [0, 1, 2, 3, 4]
|
1379 |
+
height position_ids: [0, 1, 2, 3, 4]
|
1380 |
+
width position_ids: [0, 1, 2, 3, 4]
|
1381 |
+
|
1382 |
+
For a vision-and-text embedding sequence, we calculate 3D rotary position embeddings for the vision part
|
1383 |
+
and 1D rotary position embeddings for the text part.
|
1384 |
+
Examples:
|
1385 |
+
Assume we have a video input with 3 temporal patches, 2 height patches and 2 width patches.
|
1386 |
+
input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
|
1387 |
+
vision temporal position_ids: [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
|
1388 |
+
vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
|
1389 |
+
vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
|
1390 |
+
text temporal position_ids: [3, 4, 5, 6, 7]
|
1391 |
+
text height position_ids: [3, 4, 5, 6, 7]
|
1392 |
+
text width position_ids: [3, 4, 5, 6, 7]
|
1393 |
+
Here we calculate the text start position_ids as the max vision position_ids plus 1.
|
1394 |
+
|
1395 |
+
Args:
|
1396 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
1397 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
1398 |
+
it.
|
1399 |
+
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
|
1400 |
+
The temporal, height and width of feature shape of each image in LLM.
|
1401 |
+
video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
|
1402 |
+
The temporal, height and width of feature shape of each video in LLM.
|
1403 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
1404 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
1405 |
+
|
1406 |
+
- 1 for tokens that are **not masked**,
|
1407 |
+
- 0 for tokens that are **masked**.
|
1408 |
+
|
1409 |
+
Returns:
|
1410 |
+
position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`)
|
1411 |
+
mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`)
|
1412 |
+
"""
|
1413 |
+
spatial_merge_size = self.config.vision_config.spatial_merge_size
|
1414 |
+
image_token_id = self.config.image_token_id
|
1415 |
+
video_token_id = self.config.video_token_id
|
1416 |
+
vision_start_token_id = self.config.vision_start_token_id
|
1417 |
+
mrope_position_deltas = []
|
1418 |
+
if image_grid_thw is not None or video_grid_thw is not None:
|
1419 |
+
total_input_ids = input_ids
|
1420 |
+
position_ids = torch.ones(
|
1421 |
+
3, input_ids.shape[0], input_ids.shape[1], dtype=input_ids.dtype, device=input_ids.device
|
1422 |
+
)
|
1423 |
+
image_index, video_index = 0, 0
|
1424 |
+
for i, input_ids in enumerate(total_input_ids):
|
1425 |
+
if attention_mask is not None:
|
1426 |
+
input_ids = input_ids[attention_mask[i] == 1]
|
1427 |
+
image_nums, video_nums = 0, 0
|
1428 |
+
vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
|
1429 |
+
vision_tokens = input_ids[vision_start_indices + 1]
|
1430 |
+
image_nums = (vision_tokens == image_token_id).sum()
|
1431 |
+
video_nums = (vision_tokens == video_token_id).sum()
|
1432 |
+
input_tokens = input_ids.tolist()
|
1433 |
+
llm_pos_ids_list: list = []
|
1434 |
+
st = 0
|
1435 |
+
remain_images, remain_videos = image_nums, video_nums
|
1436 |
+
for _ in range(image_nums + video_nums):
|
1437 |
+
if image_token_id in input_tokens and remain_images > 0:
|
1438 |
+
ed_image = input_tokens.index(image_token_id, st)
|
1439 |
+
else:
|
1440 |
+
ed_image = len(input_tokens) + 1
|
1441 |
+
if video_token_id in input_tokens and remain_videos > 0:
|
1442 |
+
ed_video = input_tokens.index(video_token_id, st)
|
1443 |
+
else:
|
1444 |
+
ed_video = len(input_tokens) + 1
|
1445 |
+
if ed_image < ed_video:
|
1446 |
+
t, h, w = (
|
1447 |
+
image_grid_thw[image_index][0],
|
1448 |
+
image_grid_thw[image_index][1],
|
1449 |
+
image_grid_thw[image_index][2],
|
1450 |
+
)
|
1451 |
+
image_index += 1
|
1452 |
+
remain_images -= 1
|
1453 |
+
ed = ed_image
|
1454 |
+
else:
|
1455 |
+
t, h, w = (
|
1456 |
+
video_grid_thw[video_index][0],
|
1457 |
+
video_grid_thw[video_index][1],
|
1458 |
+
video_grid_thw[video_index][2],
|
1459 |
+
)
|
1460 |
+
video_index += 1
|
1461 |
+
remain_videos -= 1
|
1462 |
+
ed = ed_video
|
1463 |
+
llm_grid_t, llm_grid_h, llm_grid_w = (
|
1464 |
+
t.item(),
|
1465 |
+
h.item() // spatial_merge_size,
|
1466 |
+
w.item() // spatial_merge_size,
|
1467 |
+
)
|
1468 |
+
text_len = ed - st
|
1469 |
+
|
1470 |
+
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
1471 |
+
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
|
1472 |
+
|
1473 |
+
t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
|
1474 |
+
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
|
1475 |
+
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
|
1476 |
+
llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
|
1477 |
+
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
|
1478 |
+
|
1479 |
+
if st < len(input_tokens):
|
1480 |
+
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
1481 |
+
text_len = len(input_tokens) - st
|
1482 |
+
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
|
1483 |
+
|
1484 |
+
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
|
1485 |
+
position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
|
1486 |
+
mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
|
1487 |
+
mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
|
1488 |
+
return position_ids, mrope_position_deltas
|
1489 |
+
else:
|
1490 |
+
if attention_mask is not None:
|
1491 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
1492 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
1493 |
+
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device)
|
1494 |
+
max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
|
1495 |
+
mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
|
1496 |
+
else:
|
1497 |
+
position_ids = (
|
1498 |
+
torch.arange(input_ids.shape[1], device=input_ids.device)
|
1499 |
+
.view(1, 1, -1)
|
1500 |
+
.expand(3, input_ids.shape[0], -1)
|
1501 |
+
)
|
1502 |
+
mrope_position_deltas = torch.zeros(
|
1503 |
+
[input_ids.shape[0], 1],
|
1504 |
+
device=input_ids.device,
|
1505 |
+
dtype=input_ids.dtype,
|
1506 |
+
)
|
1507 |
+
|
1508 |
+
return position_ids, mrope_position_deltas
|
1509 |
+
|
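    # Note on the return value (an explanatory sketch, not upstream documentation): because
    # vision tokens share compressed 3D positions, the largest position id is usually smaller
    # than the raw sequence length. `mrope_position_deltas` stores
    # (max position id + 1 - sequence length) per sample, and prepare_inputs_for_generation()
    # adds this delta to `cache_position` so that newly generated text tokens continue from
    # the compressed multimodal positions during incremental decoding.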
1510 |
+
def _update_model_kwargs_for_generation(
|
1511 |
+
self,
|
1512 |
+
outputs: ModelOutput,
|
1513 |
+
model_kwargs: Dict[str, Any],
|
1514 |
+
is_encoder_decoder: bool = False,
|
1515 |
+
num_new_tokens: int = 1,
|
1516 |
+
) -> Dict[str, Any]:
|
1517 |
+
model_kwargs = super()._update_model_kwargs_for_generation(
|
1518 |
+
outputs=outputs,
|
1519 |
+
model_kwargs=model_kwargs,
|
1520 |
+
is_encoder_decoder=is_encoder_decoder,
|
1521 |
+
num_new_tokens=num_new_tokens,
|
1522 |
+
)
|
1523 |
+
|
1524 |
+
if getattr(outputs, "rope_deltas", None) is not None:
|
1525 |
+
model_kwargs["rope_deltas"] = outputs.rope_deltas
|
1526 |
+
|
1527 |
+
return model_kwargs
|
1528 |
+
|
1529 |
+
@add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING)
|
1530 |
+
@replace_return_docstrings(output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
1531 |
+
def forward(
|
1532 |
+
self,
|
1533 |
+
input_ids: torch.LongTensor = None,
|
1534 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1535 |
+
position_ids: Optional[torch.LongTensor] = None,
|
1536 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
1537 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1538 |
+
labels: Optional[torch.LongTensor] = None,
|
1539 |
+
use_cache: Optional[bool] = None,
|
1540 |
+
output_attentions: Optional[bool] = None,
|
1541 |
+
output_hidden_states: Optional[bool] = None,
|
1542 |
+
return_dict: Optional[bool] = None,
|
1543 |
+
pixel_values: Optional[torch.Tensor] = None,
|
1544 |
+
pixel_values_videos: Optional[torch.FloatTensor] = None,
|
1545 |
+
image_grid_thw: Optional[torch.LongTensor] = None,
|
1546 |
+
video_grid_thw: Optional[torch.LongTensor] = None,
|
1547 |
+
rope_deltas: Optional[torch.LongTensor] = None,
|
1548 |
+
) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
|
1549 |
+
r"""
|
1550 |
+
Args:
|
1551 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
1552 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
1553 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
1554 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
1555 |
+
|
1556 |
+
Returns:
|
1557 |
+
|
1558 |
+
Example:
|
1559 |
+
|
1560 |
+
```python
|
1561 |
+
>>> from PIL import Image
|
1562 |
+
>>> import requests
|
1563 |
+
>>> from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
|
1564 |
+
|
1565 |
+
>>> model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
|
1566 |
+
>>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
|
1567 |
+
|
1568 |
+
>>> messages = [
|
1569 |
+
{
|
1570 |
+
"role": "user",
|
1571 |
+
"content": [
|
1572 |
+
{"type": "image"},
|
1573 |
+
{"type": "text", "text": "What is shown in this image?"},
|
1574 |
+
],
|
1575 |
+
},
|
1576 |
+
]
|
1577 |
+
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
1578 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
1579 |
+
|
1580 |
+
>>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
1581 |
+
>>> inputs = processor(text=[text], images=[image], return_tensors="pt")
|
1582 |
+
|
1583 |
+
>>> # Generate
|
1584 |
+
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
1585 |
+
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
1586 |
+
"The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
|
1587 |
+
```"""
|
1588 |
+
|
1589 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
1590 |
+
output_hidden_states = (
|
1591 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
1592 |
+
)
|
1593 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1594 |
+
|
1595 |
+
if inputs_embeds is None:
|
1596 |
+
inputs_embeds = self.model.embed_tokens(input_ids)
|
1597 |
+
if pixel_values is not None:
|
1598 |
+
pixel_values = pixel_values.type(self.visual.get_dtype())
|
1599 |
+
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).to(inputs_embeds.device)
|
1600 |
+
image_mask = input_ids == self.config.image_token_id
|
1601 |
+
if self.training:
|
1602 |
+
inputs_embeds = inputs_embeds.clone()
|
1603 |
+
inputs_embeds[image_mask] = image_embeds
|
1604 |
+
if pixel_values_videos is not None:
|
1605 |
+
pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
|
1606 |
+
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw).to(inputs_embeds.device)
|
1607 |
+
video_mask = input_ids == self.config.video_token_id
|
1608 |
+
inputs_embeds[video_mask] = video_embeds
|
1609 |
+
if attention_mask is not None:
|
1610 |
+
attention_mask = attention_mask.to(inputs_embeds.device)
|
1611 |
+
|
1612 |
+
outputs = self.model(
|
1613 |
+
input_ids=None,
|
1614 |
+
position_ids=position_ids,
|
1615 |
+
attention_mask=attention_mask,
|
1616 |
+
past_key_values=past_key_values,
|
1617 |
+
inputs_embeds=inputs_embeds,
|
1618 |
+
use_cache=use_cache,
|
1619 |
+
output_attentions=output_attentions,
|
1620 |
+
output_hidden_states=output_hidden_states,
|
1621 |
+
return_dict=return_dict,
|
1622 |
+
)
|
1623 |
+
|
1624 |
+
hidden_states = outputs[0]
|
1625 |
+
logits = self.lm_head(hidden_states)
|
1626 |
+
logits = logits.float()
|
1627 |
+
|
1628 |
+
loss = None
|
1629 |
+
if labels is not None:
|
1630 |
+
# Shift so that tokens < n predict n
|
1631 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
1632 |
+
shift_labels = labels[..., 1:].contiguous()
|
1633 |
+
# Flatten the tokens
|
1634 |
+
loss_fct = CrossEntropyLoss()
|
1635 |
+
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
1636 |
+
shift_labels = shift_labels.view(-1)
|
1637 |
+
# Enable model parallelism
|
1638 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
1639 |
+
loss = loss_fct(shift_logits, shift_labels)
|
1640 |
+
|
1641 |
+
if not return_dict:
|
1642 |
+
output = (logits,) + outputs[1:]
|
1643 |
+
return (loss,) + output if loss is not None else output
|
1644 |
+
|
1645 |
+
return Qwen2VLCausalLMOutputWithPast(
|
1646 |
+
loss=loss,
|
1647 |
+
logits=logits,
|
1648 |
+
past_key_values=outputs.past_key_values,
|
1649 |
+
hidden_states=outputs.hidden_states,
|
1650 |
+
attentions=outputs.attentions,
|
1651 |
+
rope_deltas=rope_deltas,
|
1652 |
+
)
|
1653 |
+
|
1654 |
+
def prepare_inputs_for_generation(
|
1655 |
+
self,
|
1656 |
+
input_ids,
|
1657 |
+
past_key_values=None,
|
1658 |
+
attention_mask=None,
|
1659 |
+
inputs_embeds=None,
|
1660 |
+
cache_position=None,
|
1661 |
+
position_ids=None,
|
1662 |
+
use_cache=True,
|
1663 |
+
pixel_values=None,
|
1664 |
+
pixel_values_videos=None,
|
1665 |
+
image_grid_thw=None,
|
1666 |
+
video_grid_thw=None,
|
1667 |
+
**kwargs,
|
1668 |
+
):
|
1669 |
+
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
1670 |
+
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
1671 |
+
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
1672 |
+
if past_key_values is not None:
|
1673 |
+
if inputs_embeds is not None: # Exception 1
|
1674 |
+
input_ids = input_ids[:, -cache_position.shape[0] :]
|
1675 |
+
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
|
1676 |
+
input_ids = input_ids[:, cache_position]
|
1677 |
+
|
1678 |
+
rope_deltas = kwargs.get("rope_deltas", None)
|
1679 |
+
if attention_mask is not None and position_ids is None:
|
1680 |
+
if cache_position is None or (cache_position is not None and cache_position[0] == 0):
|
1681 |
+
position_ids, rope_deltas = self.get_rope_index(
|
1682 |
+
input_ids, image_grid_thw, video_grid_thw, attention_mask
|
1683 |
+
)
|
1684 |
+
else:
|
1685 |
+
batch_size, seq_length = input_ids.shape
|
1686 |
+
delta = (
|
1687 |
+
cache_position[0] + rope_deltas if cache_position is not None and rope_deltas is not None else 0
|
1688 |
+
)
|
1689 |
+
position_ids = torch.arange(seq_length, device=input_ids.device)
|
1690 |
+
position_ids = position_ids.view(1, -1).expand(batch_size, -1)
|
1691 |
+
position_ids = position_ids.add(delta)
|
1692 |
+
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
|
1693 |
+
|
1694 |
+
if cache_position[0] != 0:
|
1695 |
+
pixel_values = None
|
1696 |
+
pixel_values_videos = None
|
1697 |
+
|
1698 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
1699 |
+
if inputs_embeds is not None and cache_position[0] == 0:
|
1700 |
+
model_inputs = {"inputs_embeds": inputs_embeds}
|
1701 |
+
else:
|
1702 |
+
model_inputs = {"input_ids": input_ids}
|
1703 |
+
|
1704 |
+
if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
|
1705 |
+
if inputs_embeds is not None:
|
1706 |
+
batch_size, sequence_length = inputs_embeds.shape
|
1707 |
+
device = inputs_embeds.device
|
1708 |
+
else:
|
1709 |
+
batch_size, sequence_length = input_ids.shape
|
1710 |
+
device = input_ids.device
|
1711 |
+
|
1712 |
+
dtype = self.lm_head.weight.dtype
|
1713 |
+
min_dtype = torch.finfo(dtype).min
|
1714 |
+
|
1715 |
+
attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
|
1716 |
+
attention_mask,
|
1717 |
+
sequence_length=sequence_length,
|
1718 |
+
target_length=past_key_values.get_max_length(),
|
1719 |
+
dtype=dtype,
|
1720 |
+
device=device,
|
1721 |
+
min_dtype=min_dtype,
|
1722 |
+
cache_position=cache_position,
|
1723 |
+
batch_size=batch_size,
|
1724 |
+
)
|
1725 |
+
|
1726 |
+
model_inputs.update(
|
1727 |
+
{
|
1728 |
+
"position_ids": position_ids,
|
1729 |
+
"past_key_values": past_key_values,
|
1730 |
+
"use_cache": use_cache,
|
1731 |
+
"attention_mask": attention_mask,
|
1732 |
+
"pixel_values": pixel_values,
|
1733 |
+
"pixel_values_videos": pixel_values_videos,
|
1734 |
+
"image_grid_thw": image_grid_thw,
|
1735 |
+
"video_grid_thw": video_grid_thw,
|
1736 |
+
"rope_deltas": rope_deltas,
|
1737 |
+
}
|
1738 |
+
)
|
1739 |
+
return model_inputs
|
1740 |
+
|
1741 |
+
|
1742 |
+
class Qwen2VLSimplifiedModel(Qwen2VLPreTrainedModel):
|
1743 |
+
def __init__(self, config):
|
1744 |
+
super().__init__(config)
|
1745 |
+
self.visual = Qwen2VisionTransformerPretrainedModel._from_config(
|
1746 |
+
config.vision_config, attn_implementation=config._attn_implementation
|
1747 |
+
)
|
1748 |
+
self.model = Qwen2VLModel(config)
|
1749 |
+
self.hidden_size = config.hidden_size
|
1750 |
+
|
1751 |
+
# Initialize weights and apply final processing
|
1752 |
+
self.post_init()
|
1753 |
+
|
1754 |
+
def get_input_embeddings(self):
|
1755 |
+
return self.model.embed_tokens
|
1756 |
+
|
1757 |
+
def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder=False, num_new_tokens=1):
|
1758 |
+
# Generation-related update logic is removed; return the kwargs unchanged
|
1759 |
+
return model_kwargs
|
1760 |
+
|
1761 |
+
def forward(
|
1762 |
+
self,
|
1763 |
+
input_ids: torch.LongTensor = None,
|
1764 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1765 |
+
position_ids: Optional[torch.LongTensor] = None,
|
1766 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1767 |
+
output_attentions: Optional[bool] = None,
|
1768 |
+
output_hidden_states: Optional[bool] = None,
|
1769 |
+
return_dict: Optional[bool] = None,
|
1770 |
+
pixel_values: Optional[torch.Tensor] = None,
|
1771 |
+
pixel_values_videos: Optional[torch.FloatTensor] = None,
|
1772 |
+
image_grid_thw: Optional[torch.LongTensor] = None,
|
1773 |
+
video_grid_thw: Optional[torch.LongTensor] = None
|
1774 |
+
):
|
1775 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
1776 |
+
output_hidden_states = (
|
1777 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
1778 |
+
)
|
1779 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1780 |
+
|
1781 |
+
if inputs_embeds is None:
|
1782 |
+
inputs_embeds = self.model.embed_tokens(input_ids)
|
1783 |
+
if pixel_values is not None:
|
1784 |
+
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).to(inputs_embeds.device)
|
1785 |
+
image_mask = input_ids == self.config.image_token_id
|
1786 |
+
inputs_embeds[image_mask] = image_embeds
|
1787 |
+
if attention_mask is not None:
|
1788 |
+
attention_mask = attention_mask.to(inputs_embeds.device)
|
1789 |
+
|
1790 |
+
if position_ids is None:
|
1791 |
+
position_ids, _ = self.get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask)
|
1792 |
+
|
1793 |
+
outputs = self.model(
|
1794 |
+
input_ids=None,
|
1795 |
+
position_ids=position_ids,
|
1796 |
+
attention_mask=attention_mask,
|
1797 |
+
inputs_embeds=inputs_embeds,
|
1798 |
+
output_attentions=output_attentions,
|
1799 |
+
output_hidden_states=output_hidden_states,
|
1800 |
+
return_dict=return_dict,
|
1801 |
+
)
|
1802 |
+
hidden_states = outputs[0]
|
1803 |
+
|
1804 |
+
return hidden_states, image_mask, image_grid_thw
|
1805 |
+
|
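    # Usage sketch (call site assumed; this mirrors how the Flux conditioning code is
    # expected to consume this class):
    #   >>> hidden, image_mask, grid = model(input_ids=ids, attention_mask=mask,
    #   ...                                  pixel_values=pixels, image_grid_thw=grid_thw)
    #   >>> image_features = hidden[image_mask]   # (num_image_tokens, hidden_size)
    # Returning the mask lets the caller recover exactly the positions that were filled
    # with visual embeddings.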
1806 |
+
def get_rope_index(
|
1807 |
+
self,
|
1808 |
+
input_ids: torch.LongTensor,
|
1809 |
+
image_grid_thw: Optional[torch.LongTensor] = None,
|
1810 |
+
video_grid_thw: Optional[torch.LongTensor] = None,
|
1811 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1812 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
1813 |
+
"""
|
1814 |
+
Calculate the 3D rope index based on image and video's temporal, height and width in LLM.
|
1815 |
+
|
1816 |
+
Explanation:
|
1817 |
+
Each embedding sequence contains vision embeddings and text embeddings, or just text embeddings.
|
1818 |
+
|
1819 |
+
For a pure text embedding sequence, the rotary position embedding is no different from modern LLMs.
|
1820 |
+
Examples:
|
1821 |
+
input_ids: [T T T T T], here T is for text.
|
1822 |
+
temporal position_ids: [0, 1, 2, 3, 4]
|
1823 |
+
height position_ids: [0, 1, 2, 3, 4]
|
1824 |
+
width position_ids: [0, 1, 2, 3, 4]
|
1825 |
+
|
1826 |
+
For a vision-and-text embedding sequence, we calculate 3D rotary position embeddings for the vision part
|
1827 |
+
and 1D rotary position embeddings for the text part.
|
1828 |
+
Examples:
|
1829 |
+
Assume we have a video input with 3 temporal patches, 2 height patches and 2 width patches.
|
1830 |
+
input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
|
1831 |
+
vision temporal position_ids: [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
|
1832 |
+
vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
|
1833 |
+
vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
|
1834 |
+
text temporal position_ids: [3, 4, 5, 6, 7]
|
1835 |
+
text height position_ids: [3, 4, 5, 6, 7]
|
1836 |
+
text width position_ids: [3, 4, 5, 6, 7]
|
1837 |
+
Here we calculate the text start position_ids as the max vision position_ids plus 1.
|
1838 |
+
|
1839 |
+
Args:
|
1840 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
1841 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
1842 |
+
it.
|
1843 |
+
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
|
1844 |
+
The temporal, height and width of feature shape of each image in LLM.
|
1845 |
+
video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
|
1846 |
+
The temporal, height and width of feature shape of each video in LLM.
|
1847 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
1848 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
1849 |
+
|
1850 |
+
- 1 for tokens that are **not masked**,
|
1851 |
+
- 0 for tokens that are **masked**.
|
1852 |
+
|
1853 |
+
Returns:
|
1854 |
+
position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`)
|
1855 |
+
mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`)
|
1856 |
+
"""
|
1857 |
+
spatial_merge_size = self.config.vision_config.spatial_merge_size
|
1858 |
+
image_token_id = self.config.image_token_id
|
1859 |
+
video_token_id = self.config.video_token_id
|
1860 |
+
vision_start_token_id = self.config.vision_start_token_id
|
1861 |
+
mrope_position_deltas = []
|
1862 |
+
if image_grid_thw is not None or video_grid_thw is not None:
|
1863 |
+
total_input_ids = input_ids
|
1864 |
+
position_ids = torch.ones(
|
1865 |
+
3, input_ids.shape[0], input_ids.shape[1], dtype=input_ids.dtype, device=input_ids.device
|
1866 |
+
)
|
1867 |
+
image_index, video_index = 0, 0
|
1868 |
+
for i, input_ids in enumerate(total_input_ids):
|
1869 |
+
if attention_mask is not None:
|
1870 |
+
                    input_ids = input_ids[attention_mask[i] == 1]
                image_nums, video_nums = 0, 0
                vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
                vision_tokens = input_ids[vision_start_indices + 1]
                image_nums = (vision_tokens == image_token_id).sum()
                video_nums = (vision_tokens == video_token_id).sum()
                input_tokens = input_ids.tolist()
                llm_pos_ids_list: list = []
                st = 0
                remain_images, remain_videos = image_nums, video_nums
                for _ in range(image_nums + video_nums):
                    if image_token_id in input_tokens and remain_images > 0:
                        ed_image = input_tokens.index(image_token_id, st)
                    else:
                        ed_image = len(input_tokens) + 1
                    if video_token_id in input_tokens and remain_videos > 0:
                        ed_video = input_tokens.index(video_token_id, st)
                    else:
                        ed_video = len(input_tokens) + 1
                    if ed_image < ed_video:
                        t, h, w = (
                            image_grid_thw[image_index][0],
                            image_grid_thw[image_index][1],
                            image_grid_thw[image_index][2],
                        )
                        image_index += 1
                        remain_images -= 1
                        ed = ed_image
                    else:
                        t, h, w = (
                            video_grid_thw[video_index][0],
                            video_grid_thw[video_index][1],
                            video_grid_thw[video_index][2],
                        )
                        video_index += 1
                        remain_videos -= 1
                        ed = ed_video
                    llm_grid_t, llm_grid_h, llm_grid_w = (
                        t.item(),
                        h.item() // spatial_merge_size,
                        w.item() // spatial_merge_size,
                    )
                    text_len = ed - st

                    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

                    t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
                    h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
                    w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
                    llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
                    st = ed + llm_grid_t * llm_grid_h * llm_grid_w

                if st < len(input_tokens):
                    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    text_len = len(input_tokens) - st
                    llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

                llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
                position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
                mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
            mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
            return position_ids, mrope_position_deltas
        else:
            if attention_mask is not None:
                position_ids = attention_mask.long().cumsum(-1) - 1
                position_ids.masked_fill_(attention_mask == 0, 1)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device)
                max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
                mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
            else:
                position_ids = (
                    torch.arange(input_ids.shape[1], device=input_ids.device)
                    .view(1, 1, -1)
                    .expand(3, input_ids.shape[0], -1)
                )
                mrope_position_deltas = torch.zeros(
                    [input_ids.shape[0], 1],
                    device=input_ids.device,
                    dtype=input_ids.dtype,
                )

            return position_ids, mrope_position_deltas
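The loop above builds Qwen2-VL's 3D rotary ("mrope") position ids: text tokens advance all three axes together, while each image or video block gets separate temporal, height and width indices that start just past the preceding text segment. The following is a minimal sketch of that index layout only; the grid sizes and offsets (llm_grid_*, text_len, st_idx) are made-up values for illustration, not taken from this repo.

import torch

# Hypothetical values for the sketch: a 1x2x3 merged vision grid preceded by 4 text tokens.
llm_grid_t, llm_grid_h, llm_grid_w = 1, 2, 3
text_len, st_idx = 4, 0

# Text tokens share the same position index on the t, h and w axes.
text_pos = torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx

# Vision tokens get independent t/h/w indices, offset past the text segment.
t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
vision_pos = torch.stack([t_index, h_index, w_index]) + text_len + st_idx

positions = torch.cat([text_pos, vision_pos], dim=1)  # shape (3, text_len + t*h*w)
print(positions)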
qwen2_vl/processing_qwen2_vl.py
ADDED
@@ -0,0 +1,183 @@
# coding=utf-8
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Processor class for Qwen2-VL.
"""

from typing import List, Optional, Union

from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput, VideoInput
from transformers.processing_utils import ProcessorMixin
from transformers.tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
from transformers.utils import TensorType, logging


logger = logging.get_logger(__name__)


class Qwen2VLProcessor(ProcessorMixin):
    r"""
    Constructs a Qwen2-VL processor which wraps a Qwen2-VL image processor and a Qwen2 tokenizer into a single processor.

    [`Qwen2VLProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the
    [`~Qwen2VLProcessor.__call__`] and [`~Qwen2VLProcessor.decode`] for more information.

    Args:
        image_processor ([`Qwen2VLImageProcessor`], *optional*):
            The image processor is a required input.
        tokenizer ([`Qwen2TokenizerFast`], *optional*):
            The tokenizer is a required input.
        chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
            in a chat into a tokenizable string.
    """

    attributes = ["image_processor", "tokenizer"]
    valid_kwargs = ["chat_template"]
    image_processor_class = "Qwen2VLImageProcessor"
    tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")

    def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
        super().__init__(image_processor, tokenizer, chat_template=chat_template)

    def __call__(
        self,
        images: ImageInput = None,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        videos: VideoInput = None,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TruncationStrategy] = None,
        max_length: int = None,
        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
    ) -> BatchFeature:
        """
        Main method to prepare for the model one or several sequence(s) and image(s). This method forwards the `text`
        and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
        the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwargs` arguments to
        Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`.

        Args:
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. Both channels-first and channels-last formats are supported.
            text (`str`, `List[str]`, `List[List[str]]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
            videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The video or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
                tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported.
            padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
                Select a strategy to pad the returned sequences (according to the model's padding side and padding
                index) among:
                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
                  sequence is provided).
                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
                  acceptable input length for the model if that argument is not provided.
                - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
                  lengths).
            max_length (`int`, *optional*):
                Maximum length of the returned list and optionally padding length (see above).
            truncation (`bool`, *optional*):
                Activates truncation to cut input sequences longer than `max_length` to `max_length`.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:

                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
            - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
            - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`.
            - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`.
        """
        if images is not None:
            image_inputs = self.image_processor(images=images, videos=None, return_tensors=return_tensors)
            image_grid_thw = image_inputs["image_grid_thw"]
        else:
            image_inputs = {}
            image_grid_thw = None

        if videos is not None:
            videos_inputs = self.image_processor(images=None, videos=videos, return_tensors=return_tensors)
            video_grid_thw = videos_inputs["video_grid_thw"]
        else:
            videos_inputs = {}
            video_grid_thw = None

        if not isinstance(text, list):
            text = [text]

        if image_grid_thw is not None:
            merge_length = self.image_processor.merge_size**2
            index = 0
            for i in range(len(text)):
                while "<|image_pad|>" in text[i]:
                    text[i] = text[i].replace(
                        "<|image_pad|>", "<|placeholder|>" * (image_grid_thw[index].prod() // merge_length), 1
                    )
                    index += 1
                text[i] = text[i].replace("<|placeholder|>", "<|image_pad|>")

        if video_grid_thw is not None:
            merge_length = self.image_processor.merge_size**2
            index = 0
            for i in range(len(text)):
                while "<|video_pad|>" in text[i]:
                    text[i] = text[i].replace(
                        "<|video_pad|>", "<|placeholder|>" * (video_grid_thw[index].prod() // merge_length), 1
                    )
                    index += 1
                text[i] = text[i].replace("<|placeholder|>", "<|video_pad|>")

        text_inputs = self.tokenizer(
            text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length
        )

        return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs})

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
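For orientation only (not part of the uploaded files): a minimal usage sketch of the processor defined above. The checkpoint name, placeholder image and prompt are assumptions; the prompt must contain one <|image_pad|> region per image, which __call__ expands to grid_t * grid_h * grid_w / merge_size**2 placeholder tokens before tokenization.

from PIL import Image
from transformers import AutoProcessor

# Hypothetical checkpoint; any Qwen2-VL checkpoint shipping this processor class should behave the same.
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")

image = Image.new("RGB", (448, 448))  # blank placeholder image, just for the sketch
text = "<|vision_start|><|image_pad|><|vision_end|>Describe this image."

inputs = processor(images=image, text=text, return_tensors="pt")
print(sorted(inputs.keys()))  # attention_mask, image_grid_thw, input_ids, pixel_values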
requirements.txt
ADDED
@@ -0,0 +1,22 @@
spaces
huggingface_hub
gradio_imageslider
requests
numpy<2
torch
torchvision
transformers
diffusers>=0.30.0
accelerate
Pillow
opencv-python-headless>=4.8.0
protobuf
sentencepiece
git+https://github.com/huggingface/controlnet_aux
mediapipe
sam2
optimum
optimum-quanto
matplotlib
xformers
bitsandbytes
technical-report.pdf
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d9cdff3966330d9a59053cd1c2163421d803e9802533061d3e6bca23a62054f8
size 66144700