jimmycarter committed on
Commit
f5b866b
1 Parent(s): de30f79

Upload 26 files

.gitattributes CHANGED
@@ -33,3 +33,20 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/comparisons_full/comparison_0.jpg filter=lfs diff=lfs merge=lfs -text
+ assets/comparisons_full/comparison_1.jpg filter=lfs diff=lfs merge=lfs -text
+ assets/comparisons_full/comparison_10.jpg filter=lfs diff=lfs merge=lfs -text
+ assets/comparisons_full/comparison_11.jpg filter=lfs diff=lfs merge=lfs -text
+ assets/comparisons_full/comparison_12.jpg filter=lfs diff=lfs merge=lfs -text
+ assets/comparisons_full/comparison_2.jpg filter=lfs diff=lfs merge=lfs -text
+ assets/comparisons_full/comparison_3.jpg filter=lfs diff=lfs merge=lfs -text
+ assets/comparisons_full/comparison_4.jpg filter=lfs diff=lfs merge=lfs -text
+ assets/comparisons_full/comparison_5.jpg filter=lfs diff=lfs merge=lfs -text
+ assets/comparisons_full/comparison_6.jpg filter=lfs diff=lfs merge=lfs -text
+ assets/comparisons_full/comparison_7.jpg filter=lfs diff=lfs merge=lfs -text
+ assets/comparisons_full/comparison_8.jpg filter=lfs diff=lfs merge=lfs -text
+ assets/comparisons_full/comparison_9.jpg filter=lfs diff=lfs merge=lfs -text
+ assets/comparisons/lady.jpg filter=lfs diff=lfs merge=lfs -text
+ assets/comparisons/lime.jpg filter=lfs diff=lfs merge=lfs -text
+ assets/comparisons/teal_woman.jpg filter=lfs diff=lfs merge=lfs -text
+ assets/comparisons/witch.jpg filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,3 +1,263 @@
- ---
- license: apache-2.0
- ---
+ # LibreFLUX: A free, de-distilled FLUX model
+
+ LibreFLUX is an Apache 2.0 version of [FLUX.1-schnell](https://huggingface.co/black-forest-labs/FLUX.1-schnell) that provides the full T5 context length, uses attention masking, has classifier-free guidance restored, and has had most of the FLUX aesthetic finetuning/DPO fully removed. That means it's a lot uglier than base FLUX, but it has the potential to be more easily finetuned to any new distribution. It keeps in mind the core tenets of open source software: that it should be difficult to use, slower and clunkier than a proprietary solution, and have an aesthetic trapped somewhere inside the early 2000s.
+
+ ![De-distillation t-shirt](https://huggingface.co/jimmycarter/LibreFLUX/blob/main/assets/splash.jpg)
+
+ > The image features a man standing confidently, wearing a simple t-shirt with a humorous and quirky message printed across the front. The t-shirt reads: "I de-distilled FLUX into a slow, ugly model and all I got was this stupid t-shirt." The man’s expression suggests a mix of pride and irony, as if he's aware of the complexity behind the statement, yet amused by the underwhelming reward. The background is neutral, keeping the focus on the man and his t-shirt, which pokes fun at the frustrating and often anticlimactic nature of technical processes or complex problem-solving, distilled into a comically understated punchline.
+
+ ## Table of Contents
+
+ - [LibreFLUX: A free, de-distilled FLUX model](#libreflux-a-free-de-distilled-flux-model)
+ - [Usage](#usage)
+ - [Non-technical Report on Schnell De-distillation](#non-technical-report-on-schnell-de-distillation)
+ - [Why](#why)
+ - [Restoring the Original Training Objective](#restoring-the-original-training-objective)
+ - [FLUX and Attention Masking](#flux-and-attention-masking)
+ - [Make De-distillation Go Fast and Fit in Small GPUs](#make-de-distillation-go-fast-and-fit-in-small-gpus)
+ - [Selecting Better Layers to Train with LoKr](#selecting-better-layers-to-train-with-lokr)
+ - [Beta Timestep Scheduling and Timestep Stratification](#beta-timestep-scheduling-and-timestep-stratification)
+ - [Datasets](#datasets)
+ - [Training](#training)
+ - [Post-hoc "EMA"](#post-hoc-ema)
+ - [Results](#results)
+ - [Closing Thoughts](#closing-thoughts)
+ - [Contacting Me and Grants](#contacting-me-and-grants)
+ - [Citation](#citation)
+
+ # Usage
+
+ To use the model, just call the custom pipeline using [diffusers](https://github.com/huggingface/diffusers).
+
+ ```py
+ from diffusers import DiffusionPipeline
+ pipeline = DiffusionPipeline.from_pretrained(
+     "jimmycarter/LibreFLUX",
+     custom_pipeline="jimmycarter/LibreFLUX",
+     use_safetensors=True,
+ )
+
+ # High VRAM
+ prompt = "Photograph of a chalk board on which is written: 'I thought what I'd do was, I'd pretend I was one of those deaf-mutes.'"
+ negative_prompt = "blurry"
+ images = pipeline(
+     prompt=prompt,
+     negative_prompt=negative_prompt,
+ )
+ images[0].save('chalkboard.png')
+
+ # If you have <=24 GB VRAM, try:
+ # ! pip install optimum-quanto
+ # Then
+ from optimum.quanto import freeze, quantize, qint8
+ quantize(
+     pipeline.transformer,
+     weights=qint8,
+     exclude=[
+         "*.norm", "*.norm1", "*.norm2", "*.norm2_context",
+         "proj_out", "x_embedder", "norm_out", "context_embedder",
+     ],
+ )
+ freeze(pipeline.transformer)
+ pipeline.enable_model_cpu_offload()
+ images = pipeline(
+     prompt=prompt,
+     negative_prompt=negative_prompt,
+     device=None,
+ )
+ images[0].save('chalkboard.png')
+ ```
+
+ # Non-technical Report on Schnell De-distillation
+
+ Welcome to my non-technical report on de-distilling FLUX.1-schnell in the most un-scientific way possible with extremely limited resources. I'm not going to claim I made a good model, but I did make a model. It was trained on about 1,500 H100 hour equivalents.
+
+ ![Science.](https://huggingface.co/jimmycarter/LibreFLUX/blob/main/assets/science.png)
+
+ **Everyone is ~~an artist~~ a machine learning researcher.**
+
+ ## Why
+
+ FLUX is a good text-to-image model, but the only versions of it that are out are distilled. FLUX.1-dev is distilled so that you don't need to use CFG (classifier-free guidance): instead of making one sample each for the conditional (your prompt) and the unconditional (negative prompt) inputs, you only have to make the conditional sample. This means that FLUX.1-dev is twice as fast as the same model without distillation.
+
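+ For reference, here is what CFG does at every sampling step when it has not been distilled away. This is only a generic sketch with random tensors standing in for the two model predictions, not the pipeline's actual code:
+
+ ```py
+ import torch
+
+ # Classifier-free guidance at one sampling step. The two forward passes
+ # (conditional + unconditional) are why an undistilled model is roughly
+ # twice as slow per step.
+ pred_cond = torch.randn(1, 16, 64, 64)    # prediction given the prompt
+ pred_uncond = torch.randn(1, 16, 64, 64)  # prediction given the negative prompt
+ guidance_scale = 4.0
+ pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)
+ ```
+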
+ FLUX.1-schnell (German for "fast") is further distilled so that you only need 4 steps of conditional generation to get an image. Importantly, FLUX.1-schnell has an Apache-2.0 license, so you can use it freely without having to obtain a commercial license from Black Forest Labs. Out of the box, schnell is pretty bad when you use CFG unless you skip the first couple of steps.
+
+ The distilled FLUX models are created from their non-distilled base models by [training the student model (distilled) on output from the teacher model (non-distilled), along with some tricks like an adversarial network](https://arxiv.org/abs/2403.12015).
+
+ For de-distilled models, image generation takes a little less than twice as long because you need to compute a sample for both the conditional and unconditional prompts at each step. The benefit is that you can use them commercially for free, training is a little easier, and they may be more creative.
+
+ ## Restoring the original training objective
+
+ This part is actually really easy. You just train it on the normal flow-matching objective with MSE loss and the model starts learning how to do it again. That being said, I don't think either LibreFLUX or [OpenFLUX.1](https://huggingface.co/ostris/OpenFLUX.1) managed to fully de-distill the model. The evidence I see for that is that both models will either get strange shadows that overwhelm the image or blurriness when using CFG scale values greater than 4.0. Neither of us trained very long in comparison to the training for the original model (assumed to be around 0.5-2.0m H100 hours), so it's not particularly surprising.
+
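+ For the curious, the objective itself is just the usual rectified-flow MSE loss. Here is a minimal, self-contained sketch with random tensors standing in for the latents and the model output; it is not the actual training code:
+
+ ```py
+ import torch
+ import torch.nn.functional as F
+
+ # Flow-matching (rectified flow) objective. Shapes are illustrative only.
+ x0 = torch.randn(4, 16, 64, 64)          # clean image latents
+ noise = torch.randn_like(x0)             # pure noise
+ t = torch.rand(4).view(-1, 1, 1, 1)      # timestep / sigma in [0, 1]
+
+ x_t = (1 - t) * x0 + t * noise           # interpolate between data and noise
+ target = noise - x0                      # velocity the model should predict
+
+ pred = torch.randn_like(x0)              # stand-in for model(x_t, t, text_embeds)
+ loss = F.mse_loss(pred, target)
+ ```
+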
+ ## FLUX and attention masking
+
+ FLUX models use a text model called T5-XXL to get most of their conditioning for the text-to-image task. Importantly, they pad the text out to either 256 (schnell) or 512 (dev) tokens. 512 tokens is the maximum trained length for the model. By padding, I mean they repeat the last token until the sequence is this length.
+
+ This results in the model using these padding tokens to [store information](https://arxiv.org/abs/2309.16588). When you [visualize the attention maps of the tokens in the padding segment of the text encoder](https://github.com/kaibioinfo/FluxAttentionMap/blob/main/attentionmap.ipynb), you can see that about 10-40 tokens shortly after the last token of the text and about 10-40 tokens at the end of the padding contain information which the model uses to make images. Because these tokens are normally used to store information, any prompt long enough to leave none of these padding tokens free will end up with degraded performance.
+
+ It's easy to prevent this by masking out these padding tokens during attention. BFL and their engineers know this, but they probably decided against it because the model works as is, and most fast attention implementations only support causal (LLM-style) masking, so skipping the mask let them train faster.
+
+ I already [implemented attention masking](https://github.com/bghira/SimpleTuner/blob/main/helpers/models/flux/transformer.py#L404-L406) and I would like to be able to use all 512 tokens without degradation, so I did my finetune with it on. Small scale finetunes with it on tend to damage the model, but since I needed to train schnell so much to get it out of distillation anyway, I figured it probably didn't matter to add it.
+
+ Note that FLUX.1-schnell was only trained on 256 tokens, so my finetune allows users to use the whole 512 token sequence length.
+
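+ If you want to see roughly what the mask does, here is a toy sketch of keeping only the real prompt tokens (plus the image tokens) visible to scaled-dot-product attention. The dimensions are made up and this is not the actual transformer code:
+
+ ```py
+ import torch
+ import torch.nn.functional as F
+
+ # Toy joint sequence: 512 text tokens followed by 1024 image tokens.
+ text_len, img_len, dim, heads = 512, 1024, 64, 4
+ seq_len = text_len + img_len
+
+ # Suppose the prompt only used 77 tokens; the other 435 text slots are padding.
+ prompt_tokens = 77
+ keep = torch.zeros(seq_len, dtype=torch.bool)
+ keep[:prompt_tokens] = True   # real text tokens
+ keep[text_len:] = True        # all image tokens
+
+ q = k = v = torch.randn(1, heads, seq_len, dim)
+ # SDPA treats a boolean mask as "True = may attend", so padding keys are ignored.
+ attn_mask = keep[None, None, None, :].expand(1, heads, seq_len, seq_len)
+ out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
+ ```
+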
+ ## Make de-distillation go fast and fit in small GPUs
+
+ I avoided doing any full-rank (normal, all parameters) finetuning at all, since FLUX is big. I trained initially with the model in int8 precision using [quanto](https://github.com/huggingface/optimum-quanto). I started with a 600 million parameter [LoKr](https://arxiv.org/abs/2309.14859), since LoKr tends to approximate full-rank finetuning better than LoRA. The loss was really slow to go down when I began, so after poking around the code that initializes the matrices the LoKr applies, I settled on the function below, which injects noise at a fraction of the magnitude of the layers it applies to.
+
+ ```py
+ import torch
+
+
+ def approximate_normal_tensor(inp, target, scale=1.0):
+     # Fill `target` with noise that matches the norm, mean, and std of `inp`,
+     # then shrink it by `scale`.
+     tensor = torch.randn_like(target)
+     desired_norm = inp.norm()
+     desired_mean = inp.mean()
+     desired_std = inp.std()
+
+     current_norm = tensor.norm()
+     tensor = tensor * (desired_norm / current_norm)
+     current_std = tensor.std()
+     tensor = tensor * (desired_std / current_std)
+     tensor = tensor - tensor.mean() + desired_mean
+     tensor.mul_(scale)
+
+     target.copy_(tensor)
+
+
+ def init_lokr_network_with_perturbed_normal(lycoris, scale=1e-3):
+     with torch.no_grad():
+         for lora in lycoris.loras:
+             lora.lokr_w1.fill_(1.0)
+             approximate_normal_tensor(lora.org_weight, lora.lokr_w2, scale=scale)
+ ```
+
+ This isn't normal PEFT (parameter-efficient fine-tuning) anymore, because this perturbs all the weights of the model slightly at the beginning. It doesn't seem to cause any performance degradation in the model after testing, and it made the loss fall twice as fast for my LoKr, so I used it with `scale=1e-3`. I trained the LoKr weights in bfloat16, with the `adamw_bf16` optimizer that I ~~plagiarized~~ wrote with the magic of open source software.
+
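+ As a rough picture of why LoKr can get closer to full-rank finetuning than LoRA: the `lokr_w1` and `lokr_w2` factors above get combined with a Kronecker product into a full-size weight delta. The toy example below only illustrates the shapes; it is not LyCORIS's actual factorization (which can also further decompose the second factor):
+
+ ```py
+ import torch
+
+ # A Kronecker product of two small factors produces a full-size (and
+ # potentially full-rank) weight delta from relatively few parameters.
+ # Shapes are made up, not the ones used for FLUX.
+ w1 = torch.randn(16, 16)
+ w2 = torch.randn(192, 192)
+ delta_w = torch.kron(w1, w2)        # (16*192) x (16*192) = 3072 x 3072
+
+ base = torch.randn(3072, 3072)      # stand-in for a frozen base weight
+ effective_weight = base + delta_w
+ print(delta_w.shape, w1.numel() + w2.numel(), base.numel())
+ ```
+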
+ ## Selecting better layers to train with LoKr
+
+ FLUX is a pretty standard transformer model aside from some peculiarities. One of these peculiarities is in its "norm" layers, which contain non-linearities, so they don't act like norms except for a single normalization that is applied in the layer without any weights (LayerNorm with `elementwise_affine=False`). When you fine-tune the model and look at what changes, these layers are among the big ones that seem to change.
+
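+ For context, these "norm" layers are adaLN-style modulation blocks in the DiT family: a SiLU and a Linear produce shift/scale/gate values from the conditioning, and the only true normalization is a weightless LayerNorm. The sketch below shows the general pattern, not FLUX's exact implementation:
+
+ ```py
+ import torch
+ import torch.nn as nn
+
+ # Rough adaLN-style "norm" block: the Linear after the SiLU holds the
+ # trainable weights, while the LayerNorm itself has none.
+ class AdaLNSketch(nn.Module):
+     def __init__(self, dim):
+         super().__init__()
+         self.silu = nn.SiLU()
+         self.linear = nn.Linear(dim, 3 * dim)                     # non-linearity + weights
+         self.norm = nn.LayerNorm(dim, elementwise_affine=False)   # no weights here
+
+     def forward(self, x, conditioning):
+         shift, scale, gate = self.linear(self.silu(conditioning)).chunk(3, dim=-1)
+         return gate * (self.norm(x) * (1 + scale) + shift)
+
+ block = AdaLNSketch(64)
+ out = block(torch.randn(2, 16, 64), torch.randn(2, 1, 64))
+ ```
+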
+ The other thing about transformers is that [all the heavy lifting is most often done at the start and end layers of the network](https://arxiv.org/abs/2403.17887), so you may as well fine-tune those more than the other layers. When I looked at the cosine similarity of the hidden states between each block in diffusion transformers, it more or less reflected what has been observed with LLMs. So I made a pull request to the LyCORIS repository (which maintains a LoKr implementation) that lets you more easily pick individual layers and set different factors on them, then focused my LoKr on these layers.
+
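+ The measurement itself is simple. Something like the following (with random tensors standing in for hidden states you would normally capture with forward hooks) is enough to see which blocks barely change their input:
+
+ ```py
+ import torch
+ import torch.nn.functional as F
+
+ # Cosine similarity between the hidden states entering and leaving each block.
+ # Blocks whose outputs are nearly identical to their inputs do less work and
+ # are weaker candidates for finetuning. Sizes here are arbitrary.
+ num_blocks, tokens, dim = 19, 256, 512
+ hidden_states = [torch.randn(tokens, dim) for _ in range(num_blocks + 1)]
+
+ for i in range(num_blocks):
+     sim = F.cosine_similarity(hidden_states[i], hidden_states[i + 1], dim=-1).mean()
+     print(f"block {i}: cosine similarity {sim:.3f}")
+ ```
+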
+ ## Beta timestep scheduling and timestep stratification
+
+ One problem with diffusion models is that they are [multi-task](https://arxiv.org/abs/2211.01324) (different timesteps are considered different tasks), and the tasks all tend to be associated with differently shaped and sized gradients and different magnitudes of loss. This is very much not a big deal when you have a huge batch size, since the timesteps all get sampled more or less evenly and the gradients are smoothed out and have less variance. I also knew that the schnell model had more problems with image distortions caused by sampling at the high-noise timesteps, so I did two things:
+
+ 1. Implemented a Beta schedule that approximates the original sigmoid sampling, to let me shift the sampled timesteps toward the high-noise steps, similar to but less extreme than some of the alternative sampling methods in the SD3 research paper.
+ 2. Implemented multi-rank stratified sampling so that at each step the trained timesteps were selected per batch based on regions, which normalizes the gradients significantly, like using a higher batch size would.
+
+ ```py
+ # `self.accelerator` and `bsz` come from the surrounding training loop;
+ # sp_beta is assumed to be scipy.stats.beta, imported elsewhere in the trainer.
+ alpha = 2.0
+ beta = 2.0
+ num_processes = self.accelerator.num_processes
+ process_index = self.accelerator.process_index
+ total_bsz = num_processes * bsz
+ start_idx = process_index * bsz
+ end_idx = (process_index + 1) * bsz
+ # Stratified sampling: each rank draws its timesteps from its own slice of
+ # [0, 1), so that together the ranks cover the whole range evenly.
+ indices = torch.arange(start_idx, end_idx, dtype=torch.float64)
+ u = torch.rand(bsz)
+ p = (indices + u) / total_bsz
+ # Map the uniform draws through the Beta(alpha, beta) quantile function.
+ sigmas = torch.from_numpy(
+     sp_beta.ppf(p.numpy(), a=alpha, b=beta)
+ ).to(device=self.accelerator.device)
+ ```
+
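+ A quick way to convince yourself that the stratification helps: stratified draws cover (0, 1) evenly for the combined batch, while plain i.i.d. uniform draws clump and leave gaps. This toy check is separate from the training code above:
+
+ ```py
+ import torch
+
+ # Compare stratified draws with i.i.d. uniform draws for the same batch size.
+ total_bsz = 16
+ indices = torch.arange(total_bsz, dtype=torch.float64)
+ stratified = (indices + torch.rand(total_bsz, dtype=torch.float64)) / total_bsz
+ iid = torch.rand(total_bsz, dtype=torch.float64)
+ print(sorted(stratified.tolist()))  # exactly one draw per bin of width 1/16
+ print(sorted(iid.tolist()))         # clumps and gaps
+ ```
+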
+ ## Datasets
+
+ No one talks about what datasets they train on anymore, but I used open ones from the web, captioned with VLMs using 2-3 captions per image. There was at least one short and one long caption for every image. The datasets were diverse and most of them did not have aesthetic selection, which helped direct the model away from the traditional hyper-optimized image generation of text-to-image models. Many people think that looks worse, but I like that it can make a diverse pile of images. The model was trained on about 0.5 million high resolution images in both random square crops and random aspect-ratio crops.
+
+ ## Training
+
+ I started by training for over a month on 5x 3090s and about 500,000 images. I used a 600m LoKr for this. The model looked okay afterwards. Then, I [unexpectedly gained access to 7x H100s for compute resources](https://rundiffusion.com), so I merged my PEFT model in and began training a new LoKr with 3.2b parameters.
+
+ ## Post-hoc "EMA"
+
+ I've been too lazy to implement real [post-hoc EMA like the one from EDM2](https://github.com/lucidrains/ema-pytorch/blob/main/ema_pytorch/post_hoc_ema.py), but to approximate it I saved all the checkpoints from the H100 runs and then LERPed them iteratively with different alpha values. I evaluated those checkpoints at different CFG scales to see if any of them were superior to the last checkpoint.
+
+ ```py
+ import os
+
+ import torch
+ from safetensors.torch import load_file, save_file
+
+ # `checkpoint_files`, `alpha`, and `output_folder` are defined earlier in the script.
+ first_checkpoint_file = checkpoint_files[0]
+ ema_state_dict = load_file(first_checkpoint_file)
+ for checkpoint_file in checkpoint_files[1:]:
+     new_state_dict = load_file(checkpoint_file)
+     for k in ema_state_dict.keys():
+         # Iteratively LERP each newer checkpoint into the running average.
+         ema_state_dict[k] = torch.lerp(
+             ema_state_dict[k],
+             new_state_dict[k],
+             alpha,
+         )
+
+ output_file = os.path.join(output_folder, f"alpha_linear_{alpha}.safetensors")
+ save_file(ema_state_dict, output_file)
+ ```
+
+ After looking at all the models for alphas in `[0.2, 0.4, 0.6, 0.8, 0.9, 0.95, 0.975, 0.99, 0.995, 0.999]`, I ended up settling on alpha 0.9 using the power of my eyeballs. If I am being frank, many of the EMA models looked remarkably similar and had the same kind of "rolling around various minima" qualities that training does in general.
+
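+ As a side note on why this only loosely approximates an EMA: iterating `torch.lerp` with a fixed alpha gives the newest checkpoint weight alpha, the one before it alpha * (1 - alpha), and so on, with the oldest checkpoint keeping the remainder. A quick check of the effective weights:
+
+ ```py
+ alpha = 0.9
+ n = 5  # number of checkpoints being merged
+ # Oldest checkpoint first; weights from iteratively applying lerp(x, new, alpha).
+ weights = [(1 - alpha) ** (n - 1)] + [
+     alpha * (1 - alpha) ** (n - 1 - i) for i in range(1, n)
+ ]
+ print(weights, sum(weights))  # heavily weighted toward the newest; sums to 1.0
+ ```
+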
+ ## Results
+
+ I will go over the results briefly, but I'll start with the images.
+
+ **Figure 1.** Some side-by-side images of LibreFLUX and [OpenFLUX.1](https://huggingface.co/ostris/OpenFLUX.1). They were made using diffusers, with 512-token maximum length text embeddings for LibreFLUX and 256-token maximum length for OpenFLUX.1. LibreFLUX had attention masking on while OpenFLUX did not. The models were sampled with 35 steps at various resolutions. The negative prompt for both was simply "blurry". All inference was done with the transformer quantized to int8 by quanto.
+
+ ![Polar bear](https://huggingface.co/jimmycarter/LibreFLUX/blob/main/assets/comparisons/bear.jpg)
+
+ > A cinematic style shot of a polar bear standing confidently in the center of a vibrant nightclub. The bear is holding a large sign that reads 'Open Source! Apache 2.0' in one arm and giving a thumbs up with the other arm. Around him, the club is alive with energy as colorful lasers and disco lights illuminate the scene. People are dancing all around him, wearing glowsticks and candy bracelets, adding to the fun and electric atmosphere. The polar bear's white fur contrasts against the dark, neon-lit background, and the entire scene has a surreal, festive vibe, blending technology activism with a lively party environment.
+
+ ![Artistic picture of woman](https://huggingface.co/jimmycarter/LibreFLUX/blob/main/assets/comparisons/lady.jpg)
+
+ > widescreen, vintage style from 1970s, Extreme realism in a complex, highly detailed composition featuring a woman with extremely long flowing rainbow-colored hair. The glowing background, with its vibrant colors, exaggerated details, intricate textures, and dynamic lighting, creates a whimsical, dreamy atmosphere in photorealistic quality. Threads of light that float and weave through the air, adding movement and intrigue. Patterns on the ground or in the background that glow subtly, adding a layer of complexity.Rainbows that appear faintly in the background, adding a touch of color and wonder.Butterfly wings that shimmer in the light, adding life and movement to the scene.Beams of light that radiate softly through the scene, adding focus and direction. The woman looks away from the camera, with a soft, wistful expression, her hair framing her face.
+
+ ![Western movie poster](https://huggingface.co/jimmycarter/LibreFLUX/blob/main/assets/comparisons/lime.jpg)
+
+ > a highly detailed and atmospheric, painted western movie poster with the title text "Once Upon a Lime in the West" in a dark red western-style font and the tagline text "There were three men ... and one very sour twist", with movie credits at the bottom, featuring small white text detailing actor and director names and production company logos, inspired by classic western movie posters from the 1960s, an oversized lime is the central element in the middle ground of a rugged, sun-scorched desert landscape typical of a western, the vast expanse of dry, cracked earth stretches toward the horizon, framed by towering red rock formations, the absurdity of the lime is juxtaposed with the intense gravitas of the stoic, iconic gunfighters, as if the lime were as formidable an adversary as any seasoned gunslinger, in the foreground, the silhouettes of two iconic gunfighters stand poised, facing the lime and away from the viewer, the lime looms in the distance like a final showdown in the classic western tradition, in the foreground, the gunfighters stand with long duster coats flowing in the wind, and wide-brimmed hats tilted to cast shadows over their faces, their stances are tense, as if ready for the inevitable draw, and the weapons they carry glint, the background consists of the distant town, where the sun is casting a golden glow, old wooden buildings line the sides, with horses tied to posts and a weathered saloon sign swinging gently in the wind, in this poster, the lime plays the role of the silent villain, an almost mythical object that the gunfighters are preparing to confront, the tension of the scene is palpable, the gunfighters in the foreground have faces marked by dust and sweat, their eyes narrowed against the bright sunlight, their expressions are serious and resolute, as if they have come a long way for this final duel, the absurdity of the lime is in stark contrast with their stoic demeanor, a wide, panoramic shot captures the entire scene, with the gunfighters in the foreground, the lime in the mid-ground, and the town on the horizon, the framing emphasizes the scale of the desert and the dramatic standoff taking place, while subtly highlighting the oversized lime, the camera is positioned low, angled upward from the dusty ground toward the gunfighters, with the distant lime looming ahead, this angle lends the figures an imposing presence, while still giving the lime an absurd grandeur in the distance, the perspective draws the viewer’s eye across the desert, from the silhouettes of the gunfighters to the bizarre focal point of the lime, amplifying the tension, the lighting is harsh and unforgiving, typical of a desert setting, with the evening sun casting deep shadows across the ground, dust clouds drift subtly across the ground, creating a hazy effect, while the sky above is a vast expanse of pale blue, fading into golden hues near the horizon where the sun begins to set, the poster is shot as if using classic anamorphic lenses to capture the wide, epic scale of the desert, the color palette is warm and saturated, evoking the look of a classic spaghetti western, the lime looms unnaturally in the distance, as if conjured from the land itself, casting an absurdly grand shadow across the rugged landscape, the texture and detail evoke hand-painted, weathered posters from the golden age of westerns, with slightly frayed edges and faint creases mimicking the wear of vintage classics
+
+ ![Witch action figure](https://huggingface.co/jimmycarter/LibreFLUX/blob/main/assets/comparisons/witch.jpg)
+
+ > A boxed action figure of a beautiful elf girl witch wearing a skimpy black leotard, black thigh highs, black armlets, and a short black cloak. Her hair is pink and shoulder-length. Her eyes are green. She is a slim and attractive elf with small breasts. The accessories include an apple, magic wand, potion bottle, black cat, jack o lantern, and a book. The box is orange and black with a logo near the bottom of it that says "BAD WITCH". The box is on a shelf on the toy aisle.
+
+ ![Photograph of woman in teal room with dog](https://huggingface.co/jimmycarter/LibreFLUX/blob/main/assets/comparisons/teal_woman.jpg)
+
+ > A cute blonde woman in bikini and her doge are sitting on a couch cuddling and the expressive, stylish living room scene with a playful twist. The room is painted in a soothing turquoise color scheme, stylish living room scene bathed in a cool, textured turquoise blanket and adorned with several matching turquoise throw pillows. The room's color scheme is predominantly turquoise, relaxed demeanor. The couch is covered in a soft, reflecting light and adding to the vibrant blue hue., dark room with a sleek, spherical gold decorations, This photograph captures a scene that is whimsically styled in a vibrant, reflective cyan sunglasses. The dog's expression is cheerful, metallic fabric sofa. The dog, soothing atmosphere.
+
+ ![Selfie of a man and woman](https://huggingface.co/jimmycarter/LibreFLUX/blob/main/assets/comparisons/selfie.jpg)
+
+ > Selfie of a woman in front of the eiffel tower, a man is standing next to her and giving a thumbs up
+
+ ![Image of just text](https://huggingface.co/jimmycarter/LibreFLUX/blob/main/assets/comparisons/scars.jpg)
+
+ > An image contains three motivational phrases, all in capitalized stylized text on a colorful background: 1. At the top: "PAIN HEALS" 2. In the middle, bold and slightly larger: "CHICKS DIG SCARS" 3. At the bottom: "GLORY LASTS FOREVER"
+
+ ![Digital art with lots of details specified of McDonald's on the moon](https://huggingface.co/jimmycarter/LibreFLUX/blob/main/assets/comparisons/moon.jpg)
+
+ > An illustration featuring a McDonald's on the moon. An anthropomorphic cat in a pink top and blue jeans is ordering McDonald's, while a zebra cashier stands behind the counter. The moon's surface is visible outside the windows, with craters and a distant view of Earth. The interior of the McDonald's is similar to those on Earth but adapted to the lunar environment, with vibrant colors and futuristic design elements. The overall scene is whimsical and imaginative, blending everyday life with a fantastical setting.
+
+ LibreFLUX and OpenFLUX have their strengths and weaknesses. OpenFLUX was de-distilled using the outputs of FLUX.1-schnell, which might explain why it's worse at text but also has the FLUX hyperaesthetics. Text-to-image models [don't have any good metrics](https://arxiv.org/abs/2306.04675), so past a point of "soupiness" and single-digit FID you just need to look at the model and see if it fits what you think nice pictures are.
+
+ Both models appear to be terrible at making drawings. Because people are probably curious to see the non-cherry-picks, [I've included CFG sweep comparisons of both LibreFLUX and OpenFLUX.1 here](https://huggingface.co/jimmycarter/LibreFLUX/blob/main/assets/comparisons_full/). I'm not going to say this is the best model ever, but it might be a springboard for people who want to finetune better models from it.
+
+ ## Closing thoughts
+
+ If I had to do it again, I'd probably raise the learning rate more on the H100 run. There was a [bug in SimpleTuner](https://github.com/bghira/SimpleTuner/issues/1064) that caused me to not use the [initialization trick](#make-de-distillation-go-fast-and-fit-in-small-gpus) when on the H100s, and then [timestep stratification](#beta-timestep-scheduling-and-timestep-stratification) ended up quieting down the gradient magnitudes even more and caused the model to learn very slowly at `1e-5`. I realized this when looking at the results of EMA on the final FLUX.1-dev. The H100s really came out of nowhere, as I just got an IP address to shell into late one night around 10PM and ended up staying up all night to get everything running, so in the future I'm sure I would be more prepared.
+
+ For de-distillation of schnell I think you probably need a lot more than 1,500 H100-equivalent hours. I am very tired of training FLUX and am looking forward to a better model with fewer parameters. The model learns new concepts slowly when given piles of well-labeled data. Given the history of LLMs, we now have models like LLaMA 3.1 8B that trade blows with GPT-3.5 175B, and I am hopeful that the future holds [smaller, faster models that look better](https://openreview.net/pdf?id=jQP5o1VAVc).
+
+ As far as what I think of the FLUX "open source" goes, many models being trained and released today are attempts at raising VC cash, and I have noticed a mountain of them being promoted on Twitter. Since [a16z poached the entire SD3 dev team from Stability.ai](https://siliconcanals.com/black-forest-labs-secures-28m/), the field feels more toxic than ever, but I am hopeful that individuals and research labs will selflessly lead the path forward for open weights. I made zero dollars on this and have made zero dollars on ML to date, but I try to make contributions where I can.
+
+ ![The state of open source](https://huggingface.co/jimmycarter/LibreFLUX/blob/main/assets/opensource.png)
+
+ I would like to thank [RunDiffusion](https://rundiffusion.com) for the H100 access.
+
+ ## Contacting me and grants
+
+ You can contact me by opening an issue on the discussion page of this model. If you want to speak privately about grants because you want me to continue training this or give me a means to conduct reproducible research, leave an email address too.
+
+ ## Citation
+
+ ```
+ @misc{libreflux,
+   author = {James Carter},
+   title = {LibreFLUX: A free, de-distilled FLUX model},
+   year = {2024},
+   publisher = {Huggingface},
+   journal = {Huggingface repository},
+   howpublished = {\url{https://huggingface.co/datasets/jimmycarter/libreflux}},
+ }
+ ```
assets/comparisons/bear.jpg ADDED
assets/comparisons/lady.jpg ADDED

Git LFS Details

  • SHA256: 49a3159ecf7344b79cb1dac15d4d58e223cfa3ab5a5d0c3c20408075a975d48c
  • Pointer size: 132 Bytes
  • Size of remote file: 1.35 MB
assets/comparisons/lime.jpg ADDED

Git LFS Details

  • SHA256: c7118c7f8a5bdb986c7c730d8ff58c591d75220ca8b0292e4515d3399825ea41
  • Pointer size: 132 Bytes
  • Size of remote file: 1.26 MB
assets/comparisons/moon.jpg ADDED
assets/comparisons/scars.jpg ADDED
assets/comparisons/selfie.jpg ADDED
assets/comparisons/teal_woman.jpg ADDED

Git LFS Details

  • SHA256: 65b63c23197cb7c04ecee0e093a6f622445421696cff4b581b6e16cc293d5b59
  • Pointer size: 132 Bytes
  • Size of remote file: 1.17 MB
assets/comparisons/witch.jpg ADDED

Git LFS Details

  • SHA256: 8b1ae272c5773f2c20d277c9211c6058ccef6c6f0bc43eea2bbf5654b84fbba0
  • Pointer size: 132 Bytes
  • Size of remote file: 1.05 MB
assets/comparisons_full/comparison_0.jpg ADDED

Git LFS Details

  • SHA256: 85bb40e0d7a45c5746a5fcde617d2030c9bab539e9efb68764f527a1ccbfd1fc
  • Pointer size: 132 Bytes
  • Size of remote file: 7.65 MB
assets/comparisons_full/comparison_1.jpg ADDED

Git LFS Details

  • SHA256: 269eab0f051d10787d59c4fb5dc1ef4afa724630fcb07b3e3fc3a00b5f6593bc
  • Pointer size: 132 Bytes
  • Size of remote file: 8.34 MB
assets/comparisons_full/comparison_10.jpg ADDED

Git LFS Details

  • SHA256: 2ad909ec1b1ccfe7f546ac42b33cfd7e77036160e39d919759446711a0618af6
  • Pointer size: 132 Bytes
  • Size of remote file: 8.81 MB
assets/comparisons_full/comparison_11.jpg ADDED

Git LFS Details

  • SHA256: 1d9e2fd962c736c474420a367df9985ff0833f682fc8e7bda4ae2e2d51a1d4a1
  • Pointer size: 132 Bytes
  • Size of remote file: 4.64 MB
assets/comparisons_full/comparison_12.jpg ADDED

Git LFS Details

  • SHA256: 7e0ce943b90254f8fc396466865e01cc8b2da4abe9e80d2b098938f1ac98e415
  • Pointer size: 132 Bytes
  • Size of remote file: 4.26 MB
assets/comparisons_full/comparison_2.jpg ADDED

Git LFS Details

  • SHA256: a5dd339b56006c2c6fc5634455d5b769c66e48591c81f32cc96de55d41effe5d
  • Pointer size: 132 Bytes
  • Size of remote file: 5.01 MB
assets/comparisons_full/comparison_3.jpg ADDED

Git LFS Details

  • SHA256: ff63da27e5e5bca9aa6d94432751ed9914535be00e9756d820424b2342903332
  • Pointer size: 132 Bytes
  • Size of remote file: 6.18 MB
assets/comparisons_full/comparison_4.jpg ADDED

Git LFS Details

  • SHA256: 351f76d69d3d5ffe336820a4bced043a0f330ec7b1ff28cb62f3481f7b766001
  • Pointer size: 132 Bytes
  • Size of remote file: 6.16 MB
assets/comparisons_full/comparison_5.jpg ADDED

Git LFS Details

  • SHA256: f26a2c03bc9086c9194e598e02b64e80a31bd9e0802727133b8bcaf8750e675a
  • Pointer size: 132 Bytes
  • Size of remote file: 6.56 MB
assets/comparisons_full/comparison_6.jpg ADDED

Git LFS Details

  • SHA256: 747a7611d787a6ecaaf31ac032ab346f63f7cf5bc59b8799adc56cfdfa26fbe6
  • Pointer size: 132 Bytes
  • Size of remote file: 6.82 MB
assets/comparisons_full/comparison_7.jpg ADDED

Git LFS Details

  • SHA256: f8676c6464eae55f1185c8e6c729a70935c9995c140df803c3d3334d4c642086
  • Pointer size: 132 Bytes
  • Size of remote file: 7.92 MB
assets/comparisons_full/comparison_8.jpg ADDED

Git LFS Details

  • SHA256: 74ec133d10782a8f490fb74120d4f740be8dd08e1d80e610b242d1a91935ab0e
  • Pointer size: 132 Bytes
  • Size of remote file: 3.46 MB
assets/comparisons_full/comparison_9.jpg ADDED

Git LFS Details

  • SHA256: 43946250e55f904b6d73489aedf1a0d3fbd8c405f442dfcff5ebce410bec1e23
  • Pointer size: 132 Bytes
  • Size of remote file: 5 MB
assets/comparisons_full/prompts.py ADDED
@@ -0,0 +1,54 @@
1
+ prompts = [
2
+ ( # 0
3
+ "A wide format poster featuring George Washington atop a glorious bald eagle with its wings spread flying through the sky, background is the American flag and fireworks. Huge shiny Red white and blue, bold, gradient letters at the bottom spelling out \"WTF IS A KILOMETER\" in flaming text. 4k, masterpiece.",
4
+ (1536, 1024),
5
+ ),
6
+ ( # 1
7
+ "widescreen, vintage style from 1970s, Extreme realism in a complex, highly detailed composition featuring a woman with extremely long flowing rainbow-colored hair. The glowing background, with its vibrant colors, exaggerated details, intricate textures, and dynamic lighting, creates a whimsical, dreamy atmosphere in photorealistic quality. Threads of light that float and weave through the air, adding movement and intrigue. Patterns on the ground or in the background that glow subtly, adding a layer of complexity.Rainbows that appear faintly in the background, adding a touch of color and wonder.Butterfly wings that shimmer in the light, adding life and movement to the scene.Beams of light that radiate softly through the scene, adding focus and direction. The woman looks away from the camera, with a soft, wistful expression, her hair framing her face. ",
8
+ (1536, 1024),
9
+ ),
10
+ ( # 2
11
+ 'A cinematic style shot of a polar bear standing confidently in the center of a vibrant nightclub. The bear is holding a large sign that reads \'Open Source! Apache 2.0\' in one arm and giving a thumbs up with the other arm. Around him, the club is alive with energy as colorful lasers and disco lights illuminate the scene. People are dancing all around him, wearing glowsticks and candy bracelets, adding to the fun and electric atmosphere. The polar bear\'s white fur contrasts against the dark, neon-lit background, and the entire scene has a surreal, festive vibe, blending technology activism with a lively party environment.',
12
+ (1536, 1024),
13
+ ),
14
+ ( # 3
15
+ 'A boxed action figure of a beautiful elf girl witch wearing a skimpy black leotard, black thigh highs, black armlets, and a short black cloak. Her hair is pink and shoulder-length. Her eyes are green. She is a slim and attractive elf with small breasts. The accessories include an apple, magic wand, potion bottle, black cat, jack o lantern, and a book. The box is orange and black with a logo near the bottom of it that says "BAD WITCH". The box is on a shelf on the toy aisle.',
16
+ (1024, 1536),
17
+ ),
18
+ ( # 4
19
+ "A cute blonde woman in bikini and her doge are sitting on a couch cuddling and the expressive, stylish living room scene with a playful twist. The room is painted in a soothing turquoise color scheme, stylish living room scene bathed in a cool, textured turquoise blanket and adorned with several matching turquoise throw pillows. The room's color scheme is predominantly turquoise, relaxed demeanor. The couch is covered in a soft, reflecting light and adding to the vibrant blue hue., dark room with a sleek, spherical gold decorations, This photograph captures a scene that is whimsically styled in a vibrant, reflective cyan sunglasses. The dog's expression is cheerful, metallic fabric sofa. The dog, soothing atmosphere.",
20
+ (1536, 1024),
21
+ ),
22
+ ( # 5
23
+ "Bioluminescent, A hyperrealistic depiction of a surreal scene: a piano keyboard morphs into a spiral staircase, ascending into a swirling vortex of golden, autumnal hues. A figure with a porcelain mask, reminiscent of commedia dell'arte, emerges from beneath the keys, their hand extended towards a lone female figure in a flowing gown at the apex of the staircase. Emphasize the juxtaposition of the organic and geometric, the tangible and ethereal, with a chiaroscuro lighting style. Capture the melancholic beauty and enigmatic narrative inherent in the scene.",
24
+ (1024, 1536),
25
+ ),
26
+ ( # 6
27
+ "highly detailed cinematic movie poster with the text \"PACIFIST RIM\" in a bold, vibrant sci-fi-style font at the top and a tagline reading \"Saving the world, one bouquet at a time\" below it, with the movie credits at the bottom, in the foreground, a gigantic tailless humanoid bipedal mecha-robot and an equally massive kaiju with blue-green iridescent scales and bioluminescent accents stand face-to-face, the enormous mecha on the left is clad in battle-worn yet gleaming metallic armor plates, holding out a large bouquet of exotic, colorful flowers to the kaiju on the right, the kaiju looks surprised by the gesture, its grotesque, otherworldly, surreal form equipped with a row of cyan glowing crystalline spikes along its back, the setting is an urban waterfront, framed by towering skyscrapers and a shimmering ocean with soft waves behind them, the background is bathed in the glow of moonlight and flickering neon billboards with messages like \"HARMONY\" and \"PEACE,\" tiny people on the ground below snap photos with their phones, while some onlookers stare in disbelief, behind the two titanic figures, the calm ocean glistens under the moonlight as distant ships drift by, the robot's posture is calm and serene, its large claw-like hands extended as it presents the bouquet, both figures express an aura of peace and harmony despite their intimidating size, the monstrous kaiju, though menacing, is curious and seemingly receptive to the offering, the overall mood is tranquil, a wide-angle shot captures both the robot and kaiju in their full, towering forms, emphasizing their colossal scale against the peaceful cityscape backdrop, the waterfront and moonlight add depth, while a low-angle shot looking up at the robot and kaiju further enhances their imposing size, the lighting is vibrant and saturated, with soft moonlight reflecting off the robot’s metallic surface, neon city lights providing colorful accents, lens flare, dappled lighting",
28
+ (1024, 1536),
29
+ ),
30
+ ( # 7
31
+ 'a highly detailed and atmospheric, painted western movie poster with the title text "Once Upon a Lime in the West" in a dark red western-style font and the tagline text "There were three men ... and one very sour twist", with movie credits at the bottom, featuring small white text detailing actor and director names and production company logos, inspired by classic western movie posters from the 1960s, an oversized lime is the central element in the middle ground of a rugged, sun-scorched desert landscape typical of a western, the vast expanse of dry, cracked earth stretches toward the horizon, framed by towering red rock formations, the absurdity of the lime is juxtaposed with the intense gravitas of the stoic, iconic gunfighters, as if the lime were as formidable an adversary as any seasoned gunslinger, in the foreground, the silhouettes of two iconic gunfighters stand poised, facing the lime and away from the viewer, the lime looms in the distance like a final showdown in the classic western tradition, in the foreground, the gunfighters stand with long duster coats flowing in the wind, and wide-brimmed hats tilted to cast shadows over their faces, their stances are tense, as if ready for the inevitable draw, and the weapons they carry glint, the background consists of the distant town, where the sun is casting a golden glow, old wooden buildings line the sides, with horses tied to posts and a weathered saloon sign swinging gently in the wind, in this poster, the lime plays the role of the silent villain, an almost mythical object that the gunfighters are preparing to confront, the tension of the scene is palpable, the gunfighters in the foreground have faces marked by dust and sweat, their eyes narrowed against the bright sunlight, their expressions are serious and resolute, as if they have come a long way for this final duel, the absurdity of the lime is in stark contrast with their stoic demeanor, a wide, panoramic shot captures the entire scene, with the gunfighters in the foreground, the lime in the mid-ground, and the town on the horizon, the framing emphasizes the scale of the desert and the dramatic standoff taking place, while subtly highlighting the oversized lime, the camera is positioned low, angled upward from the dusty ground toward the gunfighters, with the distant lime looming ahead, this angle lends the figures an imposing presence, while still giving the lime an absurd grandeur in the distance, the perspective draws the viewer’s eye across the desert, from the silhouettes of the gunfighters to the bizarre focal point of the lime, amplifying the tension, the lighting is harsh and unforgiving, typical of a desert setting, with the evening sun casting deep shadows across the ground, dust clouds drift subtly across the ground, creating a hazy effect, while the sky above is a vast expanse of pale blue, fading into golden hues near the horizon where the sun begins to set, the poster is shot as if using classic anamorphic lenses to capture the wide, epic scale of the desert, the color palette is warm and saturated, evoking the look of a classic spaghetti western, the lime looms unnaturally in the distance, as if conjured from the land itself, casting an absurdly grand shadow across the rugged landscape, the texture and detail evoke hand-painted, weathered posters from the golden age of westerns, with slightly frayed edges and faint creases mimicking the wear of vintage classics',
32
+ (1024, 1536),
33
+ ),
34
+ ( # 8
35
+ 'Anime illustration of a man standing next to a cat',
36
+ (1024, 1024),
37
+ ),
38
+ ( # 9
39
+ 'Selfie of a woman in front of the eiffel tower, a man is standing next to her and giving a thumbs up',
40
+ (1024, 1024),
41
+ ),
42
+ ( # 10
43
+ "a life-sized, clear plastic action figure box with a real woman trapped inside. The box has vibrant, eye-catching colors, featuring bold logos and text reminiscent of classic action figure packaging. The woman stands stiffly in the middle, her pose rigid like a doll, her facial expression conveying a mixture of confusion and surprise. She wears a brightly colored outfit that matches the action figure aesthetic, with exaggerated accessories like a toy sword or futuristic helmet strapped to her side. The box\’s background features bold comic-book-like artwork, framing the woman with dynamic lines and cartoonish explosions, emphasizing the \"action\" theme. The plastic window on the front covers the woman\’s entire body, while the sides display branding and promotional text, like \“Superhero Edition\” or \“Ultimate Collector\’s Item!\” Around the box, toy-like details abound: barcodes, toy company logos, and descriptions of her \“powers\” or \“abilities\” written in comic-style font.",
44
+ (1024, 1536),
45
+ ),
46
+ ( # 11
47
+ 'An image contains three motivational phrases, all in capitalized stylized text on a colorful background: 1. At the top: "PAIN HEALS" 2. In the middle, bold and slightly larger: "CHICKS DIG SCARS" 3. At the bottom: "GLORY LASTS FOREVER"',
48
+ (1024, 1024),
49
+ ),
50
+ ( # 12
51
+ 'An illustration featuring a McDonald\'s on the moon. An anthropomorphic cat in a pink top and blue jeans is ordering McDonald\'s, while a zebra cashier stands behind the counter. The moon\'s surface is visible outside the windows, with craters and a distant view of Earth. The interior of the McDonald\'s is similar to those on Earth but adapted to the lunar environment, with vibrant colors and futuristic design elements. The overall scene is whimsical and imaginative, blending everyday life with a fantastical setting.',
52
+ (1024, 1024),
53
+ ),
54
+ ]
assets/science.png ADDED
assets/splash.jpg ADDED
pipeline.py ADDED
@@ -0,0 +1,1813 @@
1
+ # Copyright 2024 Stability AI, The HuggingFace Team, The InstantX Team, and Terminus Research Group. All rights reserved.
2
+ #
3
+ # Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ #
17
+ # Originally licensed under the Apache License, Version 2.0 (the "License");
18
+ # Updated to "Affero GENERAL PUBLIC LICENSE Version 3, 19 November 2007" via extensive updates to attn_mask usage.
19
+
20
+ from typing import Any, Dict, List, Optional, Union
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+ import torch.nn.functional as F
25
+
26
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
27
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
28
+ from diffusers.models.attention import FeedForward
29
+ from diffusers.models.attention_processor import (
30
+ Attention,
31
+ apply_rope,
32
+ )
33
+ from diffusers.models.modeling_utils import ModelMixin
34
+ from diffusers.models.normalization import (
35
+ AdaLayerNormContinuous,
36
+ AdaLayerNormZero,
37
+ AdaLayerNormZeroSingle,
38
+ )
39
+ from diffusers.utils import (
40
+ USE_PEFT_BACKEND,
41
+ is_torch_version,
42
+ logging,
43
+ scale_lora_layers,
44
+ unscale_lora_layers,
45
+ )
46
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
47
+ from diffusers.models.embeddings import (
48
+ CombinedTimestepGuidanceTextProjEmbeddings,
49
+ CombinedTimestepTextProjEmbeddings,
50
+ )
51
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
52
+
53
+ from dataclasses import dataclass
54
+ from typing import List, Union
55
+ import PIL.Image
56
+ from diffusers.utils import BaseOutput
57
+
58
+ import inspect
59
+ from functools import lru_cache
60
+ from typing import Any, Callable, Dict, List, Optional, Union
61
+
62
+ import numpy as np
63
+ import torch
64
+ from transformers import (
65
+ CLIPTextModel,
66
+ CLIPTokenizer,
67
+ T5EncoderModel,
68
+ T5TokenizerFast,
69
+ )
70
+
71
+ from diffusers.image_processor import VaeImageProcessor
72
+ from diffusers.loaders import SD3LoraLoaderMixin
73
+ from diffusers.models.autoencoders import AutoencoderKL
74
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
75
+ from diffusers.utils import (
76
+ USE_PEFT_BACKEND,
77
+ is_torch_xla_available,
78
+ logging,
79
+ replace_example_docstring,
80
+ scale_lora_layers,
81
+ unscale_lora_layers,
82
+ )
83
+ from diffusers.utils.torch_utils import randn_tensor
84
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
85
+
86
+ if is_torch_xla_available():
87
+ import torch_xla.core.xla_model as xm
88
+
89
+ XLA_AVAILABLE = True
90
+ else:
91
+ XLA_AVAILABLE = False
92
+
93
+
94
+ @dataclass
95
+ class FluxPipelineOutput(BaseOutput):
96
+ """
97
+ Output class for Stable Diffusion pipelines.
98
+
99
+ Args:
100
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
101
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
102
+ num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
103
+ """
104
+
105
+ images: Union[List[PIL.Image.Image], np.ndarray]
106
+
107
+
108
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
109
+
110
+
111
+ class FluxSingleAttnProcessor2_0:
112
+ r"""
113
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
114
+ """
115
+
116
+ def __init__(self):
117
+ if not hasattr(F, "scaled_dot_product_attention"):
118
+ raise ImportError(
119
+ "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
120
+ )
121
+
122
+ def __call__(
123
+ self,
124
+ attn: Attention,
125
+ hidden_states: torch.Tensor,
126
+ encoder_hidden_states: Optional[torch.Tensor] = None,
127
+ attention_mask: Optional[torch.FloatTensor] = None,
128
+ image_rotary_emb: Optional[torch.Tensor] = None,
129
+ ) -> torch.Tensor:
130
+ input_ndim = hidden_states.ndim
131
+
132
+ if input_ndim == 4:
133
+ batch_size, channel, height, width = hidden_states.shape
134
+ hidden_states = hidden_states.view(
135
+ batch_size, channel, height * width
136
+ ).transpose(1, 2)
137
+
138
+ batch_size, _, _ = hidden_states.shape
139
+ query = attn.to_q(hidden_states)
140
+ key = attn.to_k(hidden_states)
141
+ value = attn.to_v(hidden_states)
142
+
143
+ inner_dim = key.shape[-1]
144
+ head_dim = inner_dim // attn.heads
145
+
146
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
147
+
148
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
149
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
150
+
151
+ if attn.norm_q is not None:
152
+ query = attn.norm_q(query)
153
+ if attn.norm_k is not None:
154
+ key = attn.norm_k(key)
155
+
156
+ # Apply RoPE if needed
157
+ if image_rotary_emb is not None:
158
+ # YiYi to-do: update uising apply_rotary_emb
159
+ # from ..embeddings import apply_rotary_emb
160
+ # query = apply_rotary_emb(query, image_rotary_emb)
161
+ # key = apply_rotary_emb(key, image_rotary_emb)
162
+ query, key = apply_rope(query, key, image_rotary_emb)
163
+
164
+ if attention_mask is not None:
165
+ attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
166
+ attention_mask = (attention_mask > 0).bool()
167
+ attention_mask = attention_mask.to(
168
+ device=hidden_states.device, dtype=hidden_states.dtype
169
+ )
170
+
171
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
172
+ # TODO: add support for attn.scale when we move to Torch 2.1
173
+ hidden_states = F.scaled_dot_product_attention(
174
+ query,
175
+ key,
176
+ value,
177
+ dropout_p=0.0,
178
+ is_causal=False,
179
+ attn_mask=attention_mask,
180
+ )
181
+
182
+ hidden_states = hidden_states.transpose(1, 2).reshape(
183
+ batch_size, -1, attn.heads * head_dim
184
+ )
185
+ hidden_states = hidden_states.to(query.dtype)
186
+
187
+ if input_ndim == 4:
188
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
189
+ batch_size, channel, height, width
190
+ )
191
+
192
+ return hidden_states
193
+
194
+
195
+ class FluxAttnProcessor2_0:
196
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
197
+
198
+ def __init__(self):
199
+ if not hasattr(F, "scaled_dot_product_attention"):
200
+ raise ImportError(
201
+ "FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
202
+ )
203
+
204
+ def __call__(
205
+ self,
206
+ attn: Attention,
207
+ hidden_states: torch.FloatTensor,
208
+ encoder_hidden_states: torch.FloatTensor = None,
209
+ attention_mask: Optional[torch.FloatTensor] = None,
210
+ image_rotary_emb: Optional[torch.Tensor] = None,
211
+ ) -> torch.FloatTensor:
212
+ input_ndim = hidden_states.ndim
213
+ if input_ndim == 4:
214
+ batch_size, channel, height, width = hidden_states.shape
215
+ hidden_states = hidden_states.view(
216
+ batch_size, channel, height * width
217
+ ).transpose(1, 2)
218
+ context_input_ndim = encoder_hidden_states.ndim
219
+ if context_input_ndim == 4:
220
+ batch_size, channel, height, width = encoder_hidden_states.shape
221
+ encoder_hidden_states = encoder_hidden_states.view(
222
+ batch_size, channel, height * width
223
+ ).transpose(1, 2)
224
+
225
+ batch_size = encoder_hidden_states.shape[0]
226
+
227
+ # `sample` projections.
228
+ query = attn.to_q(hidden_states)
229
+ key = attn.to_k(hidden_states)
230
+ value = attn.to_v(hidden_states)
231
+
232
+ inner_dim = key.shape[-1]
233
+ head_dim = inner_dim // attn.heads
234
+
235
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
236
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
237
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
238
+
239
+ if attn.norm_q is not None:
240
+ query = attn.norm_q(query)
241
+ if attn.norm_k is not None:
242
+ key = attn.norm_k(key)
243
+
244
+ # `context` projections.
245
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
246
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
247
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
248
+
249
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
250
+ batch_size, -1, attn.heads, head_dim
251
+ ).transpose(1, 2)
252
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
253
+ batch_size, -1, attn.heads, head_dim
254
+ ).transpose(1, 2)
255
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
256
+ batch_size, -1, attn.heads, head_dim
257
+ ).transpose(1, 2)
258
+
259
+ if attn.norm_added_q is not None:
260
+ encoder_hidden_states_query_proj = attn.norm_added_q(
261
+ encoder_hidden_states_query_proj
262
+ )
263
+ if attn.norm_added_k is not None:
264
+ encoder_hidden_states_key_proj = attn.norm_added_k(
265
+ encoder_hidden_states_key_proj
266
+ )
267
+
268
+ # attention
269
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
270
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
271
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
272
+
273
+ if image_rotary_emb is not None:
274
+ # YiYi to-do: update using apply_rotary_emb
275
+ # from ..embeddings import apply_rotary_emb
276
+ # query = apply_rotary_emb(query, image_rotary_emb)
277
+ # key = apply_rotary_emb(key, image_rotary_emb)
278
+ query, key = apply_rope(query, key, image_rotary_emb)
279
+
280
+ if attention_mask is not None:
281
+ attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
282
+ attention_mask = (attention_mask > 0).bool()
283
+ attention_mask = attention_mask.to(
284
+ device=hidden_states.device, dtype=hidden_states.dtype
285
+ )
286
+
287
+ hidden_states = F.scaled_dot_product_attention(
288
+ query,
289
+ key,
290
+ value,
291
+ dropout_p=0.0,
292
+ is_causal=False,
293
+ attn_mask=attention_mask,
294
+ )
295
+ hidden_states = hidden_states.transpose(1, 2).reshape(
296
+ batch_size, -1, attn.heads * head_dim
297
+ )
298
+ hidden_states = hidden_states.to(query.dtype)
299
+
300
+ encoder_hidden_states, hidden_states = (
301
+ hidden_states[:, : encoder_hidden_states.shape[1]],
302
+ hidden_states[:, encoder_hidden_states.shape[1] :],
303
+ )
304
+
305
+ # linear proj
306
+ hidden_states = attn.to_out[0](hidden_states)
307
+ # dropout
308
+ hidden_states = attn.to_out[1](hidden_states)
309
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
310
+
311
+ if input_ndim == 4:
312
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
313
+ batch_size, channel, height, width
314
+ )
315
+ if context_input_ndim == 4:
316
+ encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(
317
+ batch_size, channel, height, width
318
+ )
319
+
320
+ return hidden_states, encoder_hidden_states
321
+
322
+
323
+ # YiYi to-do: refactor rope related functions/classes
324
+ def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
325
+ assert dim % 2 == 0, "The dimension must be even."
326
+
327
+ scale = (
328
+ torch.arange(
329
+ 0,
330
+ dim,
331
+ 2,
332
+ dtype=torch.float64, # torch.float32 if torch.backends.mps.is_available() else
333
+ device=pos.device,
334
+ )
335
+ / dim
336
+ )
337
+ omega = 1.0 / (theta**scale)
338
+
339
+ batch_size, seq_length = pos.shape
340
+ out = torch.einsum("...n,d->...nd", pos, omega)
341
+ cos_out = torch.cos(out)
342
+ sin_out = torch.sin(out)
343
+
344
+ stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
345
+ out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
346
+ return out.float()
347
+
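+ # Illustrative sketch of what `rope` returns: for position ids of shape
+ # (batch, seq_len) it produces per-position 2x2 rotation matrices of shape
+ # (batch, seq_len, dim // 2, 2, 2), e.g.
+ #   pos = torch.arange(8, dtype=torch.float32)[None]   # (1, 8)
+ #   rot = rope(pos, dim=16, theta=10000)               # (1, 8, 8, 2, 2)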
348
+
349
+ # YiYi to-do: refactor rope related functions/classes
350
+ class EmbedND(nn.Module):
351
+ def __init__(self, dim: int, theta: int, axes_dim: List[int]):
352
+ super().__init__()
353
+ self.dim = dim
354
+ self.theta = theta
355
+ self.axes_dim = axes_dim
356
+
357
+ def forward(self, ids: torch.Tensor) -> torch.Tensor:
358
+ n_axes = ids.shape[-1]
359
+ emb = torch.cat(
360
+ [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
361
+ dim=-3,
362
+ )
363
+
364
+ return emb.unsqueeze(1)
365
+
366
+
367
+ def expand_flux_attention_mask(
368
+ hidden_states: torch.Tensor,
369
+ attn_mask: torch.Tensor,
370
+ ) -> torch.Tensor:
371
+ """
372
+ Expand a mask so that the image is included.
373
+ """
374
+ bsz = attn_mask.shape[0]
375
+ assert bsz == hidden_states.shape[0]
376
+ residual_seq_len = hidden_states.shape[1]
377
+ mask_seq_len = attn_mask.shape[1]
378
+
379
+ expanded_mask = torch.ones(bsz, residual_seq_len)
380
+ expanded_mask[:, :mask_seq_len] = attn_mask
381
+
382
+ return expanded_mask
383
+
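+ # Illustrative example for `expand_flux_attention_mask`: with a 512-token
+ # text mask and 4096 packed image tokens, the expanded mask has shape
+ # (bsz, 512 + 4096); the text positions keep their 0/1 values and every
+ # image position is set to 1, so image tokens are always attended to.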
384
+
385
+ @maybe_allow_in_graph
386
+ class FluxSingleTransformerBlock(nn.Module):
387
+ r"""
388
+ A single-stream Transformer block used in Flux, adapted from the MMDiT architecture introduced in Stable Diffusion 3.
389
+
390
+ Reference: https://arxiv.org/abs/2403.03206
391
+
392
+ Parameters:
393
+ dim (`int`): The number of channels in the input and output.
394
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
395
+ attention_head_dim (`int`): The number of channels in each head.
396
+ mlp_ratio (`float`, defaults to 4.0): Ratio of the MLP hidden dimension to `dim`.
398
+ """
399
+
400
+ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
401
+ super().__init__()
402
+ self.mlp_hidden_dim = int(dim * mlp_ratio)
403
+
404
+ self.norm = AdaLayerNormZeroSingle(dim)
405
+ self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
406
+ self.act_mlp = nn.GELU(approximate="tanh")
407
+ self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
408
+
409
+ processor = FluxSingleAttnProcessor2_0()
410
+ self.attn = Attention(
411
+ query_dim=dim,
412
+ cross_attention_dim=None,
413
+ dim_head=attention_head_dim,
414
+ heads=num_attention_heads,
415
+ out_dim=dim,
416
+ bias=True,
417
+ processor=processor,
418
+ qk_norm="rms_norm",
419
+ eps=1e-6,
420
+ pre_only=True,
421
+ )
422
+
423
+ def forward(
424
+ self,
425
+ hidden_states: torch.FloatTensor,
426
+ temb: torch.FloatTensor,
427
+ image_rotary_emb=None,
428
+ attention_mask: Optional[torch.Tensor] = None,
429
+ ):
430
+ residual = hidden_states
431
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
432
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
433
+
434
+ if attention_mask is not None:
435
+ attention_mask = expand_flux_attention_mask(
436
+ hidden_states,
437
+ attention_mask,
438
+ )
439
+
440
+ attn_output = self.attn(
441
+ hidden_states=norm_hidden_states,
442
+ image_rotary_emb=image_rotary_emb,
443
+ attention_mask=attention_mask,
444
+ )
445
+
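+ # Parallel attention + MLP: both branches operate on the same normed input,
+ # are concatenated, and projected back to `dim` by `proj_out`.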
446
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
447
+ gate = gate.unsqueeze(1)
448
+ hidden_states = gate * self.proj_out(hidden_states)
449
+ hidden_states = residual + hidden_states
450
+
451
+ return hidden_states
452
+
453
+
454
+ @maybe_allow_in_graph
455
+ class FluxTransformerBlock(nn.Module):
456
+ r"""
457
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
458
+
459
+ Reference: https://arxiv.org/abs/2403.03206
460
+
461
+ Parameters:
462
+ dim (`int`): The number of channels in the input and output.
463
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
464
+ attention_head_dim (`int`): The number of channels in each head.
465
+ qk_norm (`str`, defaults to `"rms_norm"`): Normalization applied to the query and key projections.
+ eps (`float`, defaults to 1e-6): Epsilon used by the query/key normalization layers.
467
+ """
468
+
469
+ def __init__(
470
+ self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6
471
+ ):
472
+ super().__init__()
473
+
474
+ self.norm1 = AdaLayerNormZero(dim)
475
+
476
+ self.norm1_context = AdaLayerNormZero(dim)
477
+
478
+ if hasattr(F, "scaled_dot_product_attention"):
479
+ processor = FluxAttnProcessor2_0()
480
+ else:
481
+ raise ValueError(
482
+ "The current PyTorch version does not support the `scaled_dot_product_attention` function."
483
+ )
484
+ self.attn = Attention(
485
+ query_dim=dim,
486
+ cross_attention_dim=None,
487
+ added_kv_proj_dim=dim,
488
+ dim_head=attention_head_dim,
489
+ heads=num_attention_heads,
490
+ out_dim=dim,
491
+ context_pre_only=False,
492
+ bias=True,
493
+ processor=processor,
494
+ qk_norm=qk_norm,
495
+ eps=eps,
496
+ )
497
+
498
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
499
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
500
+
501
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
502
+ self.ff_context = FeedForward(
503
+ dim=dim, dim_out=dim, activation_fn="gelu-approximate"
504
+ )
505
+
506
+ # let chunk size default to None
507
+ self._chunk_size = None
508
+ self._chunk_dim = 0
509
+
510
+ def forward(
511
+ self,
512
+ hidden_states: torch.FloatTensor,
513
+ encoder_hidden_states: torch.FloatTensor,
514
+ temb: torch.FloatTensor,
515
+ image_rotary_emb=None,
516
+ attention_mask: Optional[torch.Tensor] = None,
517
+ ):
518
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
519
+ hidden_states, emb=temb
520
+ )
521
+
522
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = (
523
+ self.norm1_context(encoder_hidden_states, emb=temb)
524
+ )
525
+
526
+ if attention_mask is not None:
527
+ attention_mask = expand_flux_attention_mask(
528
+ torch.cat([encoder_hidden_states, hidden_states], dim=1),
529
+ attention_mask,
530
+ )
531
+
532
+ # Attention.
533
+ attn_output, context_attn_output = self.attn(
534
+ hidden_states=norm_hidden_states,
535
+ encoder_hidden_states=norm_encoder_hidden_states,
536
+ image_rotary_emb=image_rotary_emb,
537
+ attention_mask=attention_mask,
538
+ )
539
+
540
+ # Process attention outputs for the `hidden_states`.
541
+ attn_output = gate_msa.unsqueeze(1) * attn_output
542
+ hidden_states = hidden_states + attn_output
543
+
544
+ norm_hidden_states = self.norm2(hidden_states)
545
+ norm_hidden_states = (
546
+ norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
547
+ )
548
+
549
+ ff_output = self.ff(norm_hidden_states)
550
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
551
+
552
+ hidden_states = hidden_states + ff_output
553
+
554
+ # Process attention outputs for the `encoder_hidden_states`.
555
+
556
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
557
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
558
+
559
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
560
+ norm_encoder_hidden_states = (
561
+ norm_encoder_hidden_states * (1 + c_scale_mlp[:, None])
562
+ + c_shift_mlp[:, None]
563
+ )
564
+
565
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
566
+ encoder_hidden_states = (
567
+ encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
568
+ )
569
+
570
+ return encoder_hidden_states, hidden_states
571
+
572
+
573
+ class FluxTransformer2DModelWithMasking(
574
+ ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin
575
+ ):
576
+ """
577
+ The Transformer model introduced in Flux.
578
+
579
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
580
+
581
+ Parameters:
582
+ patch_size (`int`): Patch size to turn the input data into small patches.
583
+ in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
584
+ num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
585
+ num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
586
+ attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
587
+ num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
588
+ joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
589
+ pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
590
+ guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
591
+ """
592
+
593
+ _supports_gradient_checkpointing = True
594
+
595
+ @register_to_config
596
+ def __init__(
597
+ self,
598
+ patch_size: int = 1,
599
+ in_channels: int = 64,
600
+ num_layers: int = 19,
601
+ num_single_layers: int = 38,
602
+ attention_head_dim: int = 128,
603
+ num_attention_heads: int = 24,
604
+ joint_attention_dim: int = 4096,
605
+ pooled_projection_dim: int = 768,
606
+ guidance_embeds: bool = False,
607
+ axes_dims_rope: List[int] = [16, 56, 56],
608
+ ):
609
+ super().__init__()
610
+ self.out_channels = in_channels
611
+ self.inner_dim = (
612
+ self.config.num_attention_heads * self.config.attention_head_dim
613
+ )
614
+
615
+ self.pos_embed = EmbedND(
616
+ dim=self.inner_dim, theta=10000, axes_dim=axes_dims_rope
617
+ )
618
+ text_time_guidance_cls = (
619
+ CombinedTimestepGuidanceTextProjEmbeddings
620
+ if guidance_embeds
621
+ else CombinedTimestepTextProjEmbeddings
622
+ )
623
+ self.time_text_embed = text_time_guidance_cls(
624
+ embedding_dim=self.inner_dim,
625
+ pooled_projection_dim=self.config.pooled_projection_dim,
626
+ )
627
+
628
+ self.context_embedder = nn.Linear(
629
+ self.config.joint_attention_dim, self.inner_dim
630
+ )
631
+ self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
632
+
633
+ self.transformer_blocks = nn.ModuleList(
634
+ [
635
+ FluxTransformerBlock(
636
+ dim=self.inner_dim,
637
+ num_attention_heads=self.config.num_attention_heads,
638
+ attention_head_dim=self.config.attention_head_dim,
639
+ )
640
+ for i in range(self.config.num_layers)
641
+ ]
642
+ )
643
+
644
+ self.single_transformer_blocks = nn.ModuleList(
645
+ [
646
+ FluxSingleTransformerBlock(
647
+ dim=self.inner_dim,
648
+ num_attention_heads=self.config.num_attention_heads,
649
+ attention_head_dim=self.config.attention_head_dim,
650
+ )
651
+ for i in range(self.config.num_single_layers)
652
+ ]
653
+ )
654
+
655
+ self.norm_out = AdaLayerNormContinuous(
656
+ self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6
657
+ )
658
+ self.proj_out = nn.Linear(
659
+ self.inner_dim, patch_size * patch_size * self.out_channels, bias=True
660
+ )
661
+
662
+ self.gradient_checkpointing = False
663
+
664
+ def _set_gradient_checkpointing(self, module, value=False):
665
+ if hasattr(module, "gradient_checkpointing"):
666
+ module.gradient_checkpointing = value
667
+
668
+ def forward(
669
+ self,
670
+ hidden_states: torch.Tensor,
671
+ encoder_hidden_states: torch.Tensor = None,
672
+ pooled_projections: torch.Tensor = None,
673
+ timestep: torch.LongTensor = None,
674
+ img_ids: torch.Tensor = None,
675
+ txt_ids: torch.Tensor = None,
676
+ guidance: torch.Tensor = None,
677
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
678
+ return_dict: bool = True,
679
+ attention_mask: Optional[torch.Tensor] = None,
680
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
681
+ """
682
+ The [`FluxTransformer2DModelWithMasking`] forward method.
683
+
684
+ Args:
685
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
686
+ Input `hidden_states`.
687
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
688
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
689
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
690
+ from the embeddings of input conditions.
691
+ timestep ( `torch.LongTensor`):
692
+ Used to indicate denoising step.
693
+ attention_mask (`torch.Tensor`, *optional*):
+ Mask over the text tokens (1 = keep, 0 = padding). It is expanded inside each block so that image tokens are always attended to.
695
+ joint_attention_kwargs (`dict`, *optional*):
696
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
697
+ `self.processor` in
698
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
699
+ return_dict (`bool`, *optional*, defaults to `True`):
700
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
701
+ tuple.
702
+
703
+ Returns:
704
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
705
+ `tuple` where the first element is the sample tensor.
706
+ """
707
+ if joint_attention_kwargs is not None:
708
+ joint_attention_kwargs = joint_attention_kwargs.copy()
709
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
710
+ else:
711
+ lora_scale = 1.0
712
+
713
+ if USE_PEFT_BACKEND:
714
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
715
+ scale_lora_layers(self, lora_scale)
716
+ else:
717
+ if (
718
+ joint_attention_kwargs is not None
719
+ and joint_attention_kwargs.get("scale", None) is not None
720
+ ):
721
+ logger.warning(
722
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
723
+ )
724
+ hidden_states = self.x_embedder(hidden_states)
725
+
726
+ timestep = timestep.to(hidden_states.dtype) * 1000
727
+ if guidance is not None:
728
+ guidance = guidance.to(hidden_states.dtype) * 1000
729
+ else:
730
+ guidance = None
731
+ temb = (
732
+ self.time_text_embed(timestep, pooled_projections)
733
+ if guidance is None
734
+ else self.time_text_embed(timestep, guidance, pooled_projections)
735
+ )
736
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
737
+
738
+ ids = torch.cat((txt_ids, img_ids), dim=1)
739
+ image_rotary_emb = self.pos_embed(ids)
740
+
741
+ for index_block, block in enumerate(self.transformer_blocks):
742
+ if self.training and self.gradient_checkpointing:
743
+
744
+ def create_custom_forward(module, return_dict=None):
745
+ def custom_forward(*inputs):
746
+ if return_dict is not None:
747
+ return module(*inputs, return_dict=return_dict)
748
+ else:
749
+ return module(*inputs)
750
+
751
+ return custom_forward
752
+
753
+ ckpt_kwargs: Dict[str, Any] = (
754
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
755
+ )
756
+ encoder_hidden_states, hidden_states = (
757
+ torch.utils.checkpoint.checkpoint(
758
+ create_custom_forward(block),
759
+ hidden_states,
760
+ encoder_hidden_states,
761
+ temb,
762
+ image_rotary_emb,
763
+ attention_mask,
764
+ **ckpt_kwargs,
765
+ )
766
+ )
767
+
768
+ else:
769
+ encoder_hidden_states, hidden_states = block(
770
+ hidden_states=hidden_states,
771
+ encoder_hidden_states=encoder_hidden_states,
772
+ temb=temb,
773
+ image_rotary_emb=image_rotary_emb,
774
+ attention_mask=attention_mask,
775
+ )
776
+
777
+ # Flux places the text tokens in front of the image tokens in the
778
+ # sequence.
779
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
780
+
781
+ for index_block, block in enumerate(self.single_transformer_blocks):
782
+ if self.training and self.gradient_checkpointing:
783
+
784
+ def create_custom_forward(module, return_dict=None):
785
+ def custom_forward(*inputs):
786
+ if return_dict is not None:
787
+ return module(*inputs, return_dict=return_dict)
788
+ else:
789
+ return module(*inputs)
790
+
791
+ return custom_forward
792
+
793
+ ckpt_kwargs: Dict[str, Any] = (
794
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
795
+ )
796
+ hidden_states = torch.utils.checkpoint.checkpoint(
797
+ create_custom_forward(block),
798
+ hidden_states,
799
+ temb,
800
+ image_rotary_emb,
801
+ attention_mask,
802
+ **ckpt_kwargs,
803
+ )
804
+
805
+ else:
806
+ hidden_states = block(
807
+ hidden_states=hidden_states,
808
+ temb=temb,
809
+ image_rotary_emb=image_rotary_emb,
810
+ attention_mask=attention_mask,
811
+ )
812
+
813
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
814
+
815
+ hidden_states = self.norm_out(hidden_states, temb)
816
+ output = self.proj_out(hidden_states)
817
+
818
+ if USE_PEFT_BACKEND:
819
+ # remove `lora_scale` from each PEFT layer
820
+ unscale_lora_layers(self, lora_scale)
821
+
822
+ if not return_dict:
823
+ return (output,)
824
+
825
+ return Transformer2DModelOutput(sample=output)
826
+
827
+
828
+ if __name__ == "__main__":
829
+ dtype = torch.bfloat16
830
+ bsz = 2
831
+ img = torch.rand((bsz, 16, 64, 64)).to("cuda", dtype=dtype)
832
+ timestep = torch.tensor([0.5, 0.5]).to("cuda", dtype=torch.float32)
833
+ pooled = torch.rand(bsz, 768).to("cuda", dtype=dtype)
834
+ text = torch.rand((bsz, 512, 4096)).to("cuda", dtype=dtype)
835
+ attn_mask = torch.tensor([[1.0] * 384 + [0.0] * 128] * bsz).to(
836
+ "cuda", dtype=dtype
837
+ ) # Last 128 positions are masked
838
+
839
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
840
+ latents = latents.view(
841
+ batch_size, num_channels_latents, height // 2, 2, width // 2, 2
842
+ )
843
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
844
+ latents = latents.reshape(
845
+ batch_size, (height // 2) * (width // 2), num_channels_latents * 4
846
+ )
847
+
848
+ return latents
849
+
850
+ def _prepare_latent_image_ids(
851
+ batch_size, height, width, device="cuda", dtype=dtype
852
+ ):
853
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
854
+ latent_image_ids[..., 1] = (
855
+ latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
856
+ )
857
+ latent_image_ids[..., 2] = (
858
+ latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
859
+ )
860
+
861
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = (
862
+ latent_image_ids.shape
863
+ )
864
+
865
+ latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
866
+ latent_image_ids = latent_image_ids.reshape(
867
+ batch_size,
868
+ latent_image_id_height * latent_image_id_width,
869
+ latent_image_id_channels,
870
+ )
871
+
872
+ return latent_image_ids.to(device=device, dtype=dtype)
873
+
874
+ txt_ids = torch.zeros(bsz, text.shape[1], 3).to(device="cuda", dtype=dtype)
875
+
876
+ vae_scale_factor = 16
877
+ height = 2 * (int(512) // vae_scale_factor)
878
+ width = 2 * (int(512) // vae_scale_factor)
879
+ img_ids = _prepare_latent_image_ids(bsz, height, width)
880
+ img = _pack_latents(img, img.shape[0], 16, height, width)
881
+
882
+ # Gotta go fast
883
+ transformer = FluxTransformer2DModelWithMasking.from_config(
884
+ {
885
+ "attention_head_dim": 128,
886
+ "guidance_embeds": True,
887
+ "in_channels": 64,
888
+ "joint_attention_dim": 4096,
889
+ "num_attention_heads": 24,
890
+ "num_layers": 4,
891
+ "num_single_layers": 8,
892
+ "patch_size": 1,
893
+ "pooled_projection_dim": 768,
894
+ }
895
+ ).to("cuda", dtype=dtype)
896
+
897
+ guidance = torch.tensor([2.0], device="cuda")
898
+ guidance = guidance.expand(bsz)
899
+
900
+ with torch.no_grad():
901
+ no_mask = transformer(
902
+ img,
903
+ encoder_hidden_states=text,
904
+ pooled_projections=pooled,
905
+ timestep=timestep,
906
+ img_ids=img_ids,
907
+ txt_ids=txt_ids,
908
+ guidance=guidance,
909
+ )
910
+ mask = transformer(
911
+ img,
912
+ encoder_hidden_states=text,
913
+ pooled_projections=pooled,
914
+ timestep=timestep,
915
+ img_ids=img_ids,
916
+ txt_ids=txt_ids,
917
+ guidance=guidance,
918
+ attention_mask=attn_mask,
919
+ )
920
+
921
+ assert not torch.allclose(no_mask.sample, mask.sample)
922
+ print("Attention masking test ran OK. Differences in output were detected.")
923
+
924
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
925
+
926
+ EXAMPLE_DOC_STRING = """
927
+ Examples:
928
+ ```py
929
+ >>> import torch
930
+ >>> from diffusers import FluxPipeline
931
+
932
+ >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
933
+ >>> pipe.to("cuda")
934
+ >>> prompt = "A cat holding a sign that says hello world"
935
+ >>> # Depending on the variant being used, the pipeline call will slightly vary.
936
+ >>> # Refer to the pipeline documentation for more details.
937
+ >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
938
+ >>> image.save("flux.png")
939
+ ```
940
+ """
941
+
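+ # Hypothetical usage of the CustomPipeline defined below (the checkpoint path
+ # and argument values are placeholders, not prescribed by this file):
+ #   pipe = CustomPipeline.from_pretrained("path/to/checkpoint", torch_dtype=torch.bfloat16).to("cuda")
+ #   image = pipe(
+ #       "a photo of a cat",
+ #       num_inference_steps=28,
+ #       guidance_scale=3.5,
+ #       guidance_scale_real=3.0,  # true classifier-free guidance scale
+ #   ).images[0]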
942
+
943
+ def calculate_shift(
944
+ image_seq_len,
945
+ base_seq_len: int = 256,
946
+ max_seq_len: int = 4096,
947
+ base_shift: float = 0.5,
948
+ max_shift: float = 1.16,
949
+ ):
950
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
951
+ b = base_shift - m * base_seq_len
952
+ mu = image_seq_len * m + b
953
+ return mu
954
+
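+ # Worked example with the defaults above:
+ #   m  = (1.16 - 0.5) / (4096 - 256) ≈ 1.719e-4
+ #   b  = 0.5 - m * 256              ≈ 0.456
+ # so an image sequence of 1024 tokens gives mu ≈ 1024 * m + b ≈ 0.632.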
955
+
956
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
957
+ def retrieve_timesteps(
958
+ scheduler,
959
+ num_inference_steps: Optional[int] = None,
960
+ device: Optional[Union[str, torch.device]] = None,
961
+ timesteps: Optional[List[int]] = None,
962
+ sigmas: Optional[List[float]] = None,
963
+ **kwargs,
964
+ ):
965
+ """
966
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
967
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
968
+
969
+ Args:
970
+ scheduler (`SchedulerMixin`):
971
+ The scheduler to get timesteps from.
972
+ num_inference_steps (`int`):
973
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
974
+ must be `None`.
975
+ device (`str` or `torch.device`, *optional*):
976
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
977
+ timesteps (`List[int]`, *optional*):
978
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
979
+ `num_inference_steps` and `sigmas` must be `None`.
980
+ sigmas (`List[float]`, *optional*):
981
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
982
+ `num_inference_steps` and `timesteps` must be `None`.
983
+
984
+ Returns:
985
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
986
+ second element is the number of inference steps.
987
+ """
988
+ if timesteps is not None and sigmas is not None:
989
+ raise ValueError(
990
+ "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
991
+ )
992
+ if timesteps is not None:
993
+ accepts_timesteps = "timesteps" in set(
994
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
995
+ )
996
+ if not accepts_timesteps:
997
+ raise ValueError(
998
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
999
+ f" timestep schedules. Please check whether you are using the correct scheduler."
1000
+ )
1001
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
1002
+ timesteps = scheduler.timesteps
1003
+ num_inference_steps = len(timesteps)
1004
+ elif sigmas is not None:
1005
+ accept_sigmas = "sigmas" in set(
1006
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
1007
+ )
1008
+ if not accept_sigmas:
1009
+ raise ValueError(
1010
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
1011
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
1012
+ )
1013
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
1014
+ timesteps = scheduler.timesteps
1015
+ num_inference_steps = len(timesteps)
1016
+ else:
1017
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
1018
+ timesteps = scheduler.timesteps
1019
+ return timesteps, num_inference_steps
1020
+
1021
+
1022
+ class CustomPipeline(DiffusionPipeline, SD3LoraLoaderMixin):
1023
+ r"""
1024
+ The Flux pipeline for text-to-image generation.
1025
+
1026
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
1027
+
1028
+ Args:
1029
+ transformer ([`FluxTransformer2DModelWithMasking`]):
1030
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
1031
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
1032
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
1033
+ vae ([`AutoencoderKL`]):
1034
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
1035
+ text_encoder ([`CLIPTextModelWithProjection`]):
1036
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
1037
+ specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant,
1038
+ with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size`
1039
+ as its dimension.
1040
+ text_encoder_2 ([`CLIPTextModelWithProjection`]):
1041
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
1042
+ specifically the
1043
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
1044
+ variant.
1045
+ tokenizer (`CLIPTokenizer`):
1046
+ Tokenizer of class
1047
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
1048
+ tokenizer_2 (`CLIPTokenizer`):
1049
+ Second Tokenizer of class
1050
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
1051
+ """
1052
+
1053
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
1054
+ _optional_components = []
1055
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
1056
+
1057
+ def __init__(
1058
+ self,
1059
+ scheduler: FlowMatchEulerDiscreteScheduler,
1060
+ vae: AutoencoderKL,
1061
+ text_encoder: CLIPTextModel,
1062
+ tokenizer: CLIPTokenizer,
1063
+ text_encoder_2: T5EncoderModel,
1064
+ tokenizer_2: T5TokenizerFast,
1065
+ transformer: FluxTransformer2DModelWithMasking,
1066
+ ):
1067
+ super().__init__()
1068
+
1069
+ self.register_modules(
1070
+ vae=vae,
1071
+ text_encoder=text_encoder,
1072
+ text_encoder_2=text_encoder_2,
1073
+ tokenizer=tokenizer,
1074
+ tokenizer_2=tokenizer_2,
1075
+ transformer=transformer,
1076
+ scheduler=scheduler,
1077
+ )
1078
+ self.vae_scale_factor = (
1079
+ 2 ** (len(self.vae.config.block_out_channels))
1080
+ if hasattr(self, "vae") and self.vae is not None
1081
+ else 16
1082
+ )
1083
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
1084
+ self.tokenizer_max_length = (
1085
+ self.tokenizer.model_max_length
1086
+ if hasattr(self, "tokenizer") and self.tokenizer is not None
1087
+ else 77
1088
+ )
1089
+ self.default_sample_size = 64
1090
+
1091
+ def _get_t5_prompt_embeds(
1092
+ self,
1093
+ prompt: Union[str, List[str]] = None,
1094
+ num_images_per_prompt: int = 1,
1095
+ max_sequence_length: int = 512,
1096
+ device: Optional[torch.device] = None,
1097
+ dtype: Optional[torch.dtype] = None,
1098
+ ):
1099
+ device = device or self._execution_device
1100
+ dtype = dtype or self.text_encoder.dtype
1101
+
1102
+ prompt = [prompt] if isinstance(prompt, str) else prompt
1103
+ batch_size = len(prompt)
1104
+
1105
+ text_inputs = self.tokenizer_2(
1106
+ prompt,
1107
+ padding="max_length",
1108
+ max_length=max_sequence_length,
1109
+ truncation=True,
1110
+ return_length=False,
1111
+ return_overflowing_tokens=False,
1112
+ return_tensors="pt",
1113
+ )
1114
+ prompt_attention_mask = text_inputs.attention_mask
1115
+ text_input_ids = text_inputs.input_ids
1116
+ untruncated_ids = self.tokenizer_2(
1117
+ prompt, padding="longest", return_tensors="pt"
1118
+ ).input_ids
1119
+
1120
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
1121
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
1122
+ logger.warning(
1123
+ "The following part of your input was truncated because `max_sequence_length` is set to "
1124
+ f" {max_sequence_length} tokens: {removed_text}"
1125
+ )
1126
+
1127
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
1128
+
1129
+ dtype = self.text_encoder_2.dtype
1130
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
1131
+
1132
+ _, seq_len, _ = prompt_embeds.shape
1133
+
1134
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
1135
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
1136
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
1137
+
1138
+ return prompt_embeds, prompt_attention_mask
1139
+
1140
+ def _get_clip_prompt_embeds(
1141
+ self,
1142
+ prompt: Union[str, List[str]],
1143
+ num_images_per_prompt: int = 1,
1144
+ device: Optional[torch.device] = None,
1145
+ ):
1146
+ device = device or self._execution_device
1147
+
1148
+ prompt = [prompt] if isinstance(prompt, str) else prompt
1149
+ batch_size = len(prompt)
1150
+
1151
+ text_inputs = self.tokenizer(
1152
+ prompt,
1153
+ padding="max_length",
1154
+ max_length=self.tokenizer_max_length,
1155
+ truncation=True,
1156
+ return_overflowing_tokens=False,
1157
+ return_length=False,
1158
+ return_tensors="pt",
1159
+ )
1160
+
1161
+ text_input_ids = text_inputs.input_ids
1162
+ untruncated_ids = self.tokenizer(
1163
+ prompt, padding="longest", return_tensors="pt"
1164
+ ).input_ids
1165
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
1166
+ text_input_ids, untruncated_ids
1167
+ ):
1168
+ removed_text = self.tokenizer.batch_decode(
1169
+ untruncated_ids[:, self.tokenizer_max_length - 1 : -1]
1170
+ )
1171
+ logger.warning(
1172
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
1173
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
1174
+ )
1175
+ prompt_embeds = self.text_encoder(
1176
+ text_input_ids.to(device), output_hidden_states=False
1177
+ )
1178
+
1179
+ # Use pooled output of CLIPTextModel
1180
+ prompt_embeds = prompt_embeds.pooler_output
1181
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
1182
+
1183
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
1184
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
1185
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
1186
+
1187
+ return prompt_embeds
1188
+
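+ # Note: functools.lru_cache requires every argument to be hashable, so cached
+ # calls only work with string prompts and tensor arguments left as None.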
1189
+ @lru_cache(maxsize=128)
1190
+ def encode_prompt(
1191
+ self,
1192
+ prompt: Union[str, List[str]],
1193
+ prompt_2: Union[str, List[str]],
1194
+ device: Optional[torch.device] = None,
1195
+ num_images_per_prompt: int = 1,
1196
+ prompt_embeds: Optional[torch.FloatTensor] = None,
1197
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
1198
+ max_sequence_length: int = 512,
1199
+ lora_scale: Optional[float] = None,
1200
+ ):
1201
+ r"""
1202
+
1203
+ Args:
1204
+ prompt (`str` or `List[str]`, *optional*):
1205
+ prompt to be encoded
1206
+ prompt_2 (`str` or `List[str]`, *optional*):
1207
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
1208
+ used in all text-encoders
1209
+ device: (`torch.device`):
1210
+ torch device
1211
+ num_images_per_prompt (`int`):
1212
+ number of images that should be generated per prompt
1213
+ prompt_embeds (`torch.FloatTensor`, *optional*):
1214
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
1215
+ provided, text embeddings will be generated from `prompt` input argument.
1216
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
1217
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
1218
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
1219
+ max_sequence_length (`int`, defaults to 512):
+ Maximum number of text tokens encoded by the T5 text encoder.
1222
+ lora_scale (`float`, *optional*):
1223
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
1224
+ """
1225
+ device = device or self._execution_device
1226
+
1227
+ # set lora scale so that monkey patched LoRA
1228
+ # function of text encoder can correctly access it
1229
+ if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin):
1230
+ self._lora_scale = lora_scale
1231
+
1232
+ # dynamically adjust the LoRA scale
1233
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
1234
+ scale_lora_layers(self.text_encoder, lora_scale)
1235
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
1236
+ scale_lora_layers(self.text_encoder_2, lora_scale)
1237
+
1238
+ prompt = [prompt] if isinstance(prompt, str) else prompt
1239
+ if prompt is not None:
1240
+ batch_size = len(prompt)
1241
+ else:
1242
+ batch_size = prompt_embeds.shape[0]
1243
+
1244
+ prompt_attention_mask = None
1245
+ if prompt_embeds is None:
1246
+ prompt_2 = prompt_2 or prompt
1247
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
1248
+
1249
+ # We only use the pooled prompt output from the CLIPTextModel
1250
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
1251
+ prompt=prompt,
1252
+ device=device,
1253
+ num_images_per_prompt=num_images_per_prompt,
1254
+ )
1255
+ prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
1256
+ prompt=prompt_2,
1257
+ num_images_per_prompt=num_images_per_prompt,
1258
+ max_sequence_length=max_sequence_length,
1259
+ device=device,
1260
+ )
1261
+
1262
+ if self.text_encoder is not None:
1263
+ if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
1264
+ # Retrieve the original scale by scaling back the LoRA layers
1265
+ unscale_lora_layers(self.text_encoder, lora_scale)
1266
+
1267
+ if self.text_encoder_2 is not None:
1268
+ if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
1269
+ # Retrieve the original scale by scaling back the LoRA layers
1270
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
1271
+
1272
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
1273
+ text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
1274
+ text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
1275
+
1276
+ return prompt_embeds, pooled_prompt_embeds, text_ids, prompt_attention_mask
1277
+
1278
+ def check_inputs(
1279
+ self,
1280
+ prompt,
1281
+ prompt_2,
1282
+ height,
1283
+ width,
1284
+ prompt_embeds=None,
1285
+ pooled_prompt_embeds=None,
1286
+ callback_on_step_end_tensor_inputs=None,
1287
+ max_sequence_length=None,
1288
+ ):
1289
+ if height % 8 != 0 or width % 8 != 0:
1290
+ raise ValueError(
1291
+ f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
1292
+ )
1293
+
1294
+ if callback_on_step_end_tensor_inputs is not None and not all(
1295
+ k in self._callback_tensor_inputs
1296
+ for k in callback_on_step_end_tensor_inputs
1297
+ ):
1298
+ raise ValueError(
1299
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
1300
+ )
1301
+
1302
+ if prompt is not None and prompt_embeds is not None:
1303
+ raise ValueError(
1304
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
1305
+ " only forward one of the two."
1306
+ )
1307
+ elif prompt_2 is not None and prompt_embeds is not None:
1308
+ raise ValueError(
1309
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
1310
+ " only forward one of the two."
1311
+ )
1312
+ elif prompt is None and prompt_embeds is None:
1313
+ raise ValueError(
1314
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
1315
+ )
1316
+ elif prompt is not None and (
1317
+ not isinstance(prompt, str) and not isinstance(prompt, list)
1318
+ ):
1319
+ raise ValueError(
1320
+ f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
1321
+ )
1322
+ elif prompt_2 is not None and (
1323
+ not isinstance(prompt_2, str) and not isinstance(prompt_2, list)
1324
+ ):
1325
+ raise ValueError(
1326
+ f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}"
1327
+ )
1328
+
1329
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
1330
+ raise ValueError(
1331
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
1332
+ )
1333
+
1334
+ if max_sequence_length is not None and max_sequence_length > 512:
1335
+ raise ValueError(
1336
+ f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}"
1337
+ )
1338
+
1339
+ @staticmethod
1340
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
1341
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
1342
+ latent_image_ids[..., 1] = (
1343
+ latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
1344
+ )
1345
+ latent_image_ids[..., 2] = (
1346
+ latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
1347
+ )
1348
+
1349
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = (
1350
+ latent_image_ids.shape
1351
+ )
1352
+
1353
+ latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
1354
+ latent_image_ids = latent_image_ids.reshape(
1355
+ batch_size,
1356
+ latent_image_id_height * latent_image_id_width,
1357
+ latent_image_id_channels,
1358
+ )
1359
+
1360
+ return latent_image_ids
1361
+
1362
+ @staticmethod
1363
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
1364
+ latents = latents.view(
1365
+ batch_size, num_channels_latents, height // 2, 2, width // 2, 2
1366
+ )
1367
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
1368
+ latents = latents.reshape(
1369
+ batch_size, (height // 2) * (width // 2), num_channels_latents * 4
1370
+ )
1371
+
1372
+ return latents
1373
+
1374
+ @staticmethod
1375
+ def _unpack_latents(latents, height, width, vae_scale_factor):
1376
+ batch_size, num_patches, channels = latents.shape
1377
+
1378
+ height = height // vae_scale_factor
1379
+ width = width // vae_scale_factor
1380
+
1381
+ latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
1382
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
1383
+
1384
+ latents = latents.reshape(
1385
+ batch_size, channels // (2 * 2), height * 2, width * 2
1386
+ )
1387
+
1388
+ return latents
1389
+
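+ # Shape sketch for the two helpers above (illustrative only):
+ #   _pack_latents:   (B, 16, h, w)          -> (B, (h//2)*(w//2), 64)
+ #   _unpack_latents: (B, (h//2)*(w//2), 64) -> (B, 16, h, w)
+ # Each 2x2 patch of the 16-channel latent is flattened into one
+ # 64-dimensional token, matching the transformer's in_channels=64.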
1390
+ def prepare_latents(
1391
+ self,
1392
+ batch_size,
1393
+ num_channels_latents,
1394
+ height,
1395
+ width,
1396
+ dtype,
1397
+ device,
1398
+ generator,
1399
+ latents=None,
1400
+ ):
1401
+ height = 2 * (int(height) // self.vae_scale_factor)
1402
+ width = 2 * (int(width) // self.vae_scale_factor)
1403
+
1404
+ shape = (batch_size, num_channels_latents, height, width)
1405
+
1406
+ if latents is not None:
1407
+ latent_image_ids = self._prepare_latent_image_ids(
1408
+ batch_size, height, width, device, dtype
1409
+ )
1410
+ return latents, latent_image_ids
1411
+
1412
+ if isinstance(generator, list) and len(generator) != batch_size:
1413
+ raise ValueError(
1414
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
1415
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
1416
+ )
1417
+
1418
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
1419
+ latents = self._pack_latents(
1420
+ latents, batch_size, num_channels_latents, height, width
1421
+ )
1422
+
1423
+ latent_image_ids = self._prepare_latent_image_ids(
1424
+ batch_size, height, width, device, dtype
1425
+ )
1426
+
1427
+ return latents, latent_image_ids
1428
+
1429
+ @property
1430
+ def guidance_scale(self):
1431
+ return self._guidance_scale
1432
+
1433
+ @property
1434
+ def joint_attention_kwargs(self):
1435
+ return self._joint_attention_kwargs
1436
+
1437
+ @property
1438
+ def num_timesteps(self):
1439
+ return self._num_timesteps
1440
+
1441
+ @property
1442
+ def interrupt(self):
1443
+ return self._interrupt
1444
+
1445
+ @torch.no_grad()
1446
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
1447
+ def __call__(
1448
+ self,
1449
+ prompt: Union[str, List[str]] = None,
1450
+ prompt_mask: Optional[Union[torch.FloatTensor, List[torch.FloatTensor]]] = None,
1451
+ negative_mask: Optional[
1452
+ Union[torch.FloatTensor, List[torch.FloatTensor]]
1453
+ ] = None,
1454
+ prompt_2: Optional[Union[str, List[str]]] = None,
1455
+ height: Optional[int] = None,
1456
+ width: Optional[int] = None,
1457
+ num_inference_steps: int = 28,
1458
+ timesteps: List[int] = None,
1459
+ guidance_scale: float = 3.5,
1460
+ num_images_per_prompt: Optional[int] = 1,
1461
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1462
+ latents: Optional[torch.FloatTensor] = None,
1463
+ prompt_embeds: Optional[torch.FloatTensor] = None,
1464
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
1465
+ output_type: Optional[str] = "pil",
1466
+ return_dict: bool = True,
1467
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
1468
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
1469
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
1470
+ max_sequence_length: int = 512,
1471
+ guidance_scale_real: float = 1.0,
1472
+ negative_prompt: Union[str, List[str]] = "",
1473
+ negative_prompt_2: Union[str, List[str]] = "",
1474
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
1475
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
1476
+ no_cfg_until_timestep: int = 0,
1477
+ use_prompt_mask: bool = True,
1478
+ zero_using_prompt_mask: bool = False,
1479
+ device=torch.device('cuda'), # TODO: support non-CUDA devices; passing None falls back to self._execution_device
1480
+ ):
1481
+ r"""
1482
+ Function invoked when calling the pipeline for generation.
1483
+
1484
+ Args:
1485
+ prompt (`str` or `List[str]`, *optional*):
1486
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
1487
+ instead.
1488
+ prompt_mask (`torch.FloatTensor` or `List[torch.FloatTensor]`, *optional*):
+ Attention mask for the text prompt embeddings. If not provided, the mask produced when encoding `prompt` is used
1490
+ instead.
1491
+ prompt_2 (`str` or `List[str]`, *optional*):
1492
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
1493
+ will be used instead
1494
+ height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
1495
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
1496
+ width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
1497
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
1498
+ num_inference_steps (`int`, *optional*, defaults to 28):
1499
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1500
+ expense of slower inference.
1501
+ timesteps (`List[int]`, *optional*):
1502
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
1503
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
1504
+ passed will be used. Must be in descending order.
1505
+ guidance_scale (`float`, *optional*, defaults to 3.5):
1506
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1507
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
1508
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1509
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
1510
+ usually at the expense of lower image quality.
1511
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1512
+ The number of images to generate per prompt.
1513
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1514
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
1515
+ to make generation deterministic.
1516
+ latents (`torch.FloatTensor`, *optional*):
1517
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
1518
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1519
+ tensor will ge generated by sampling using the supplied random `generator`.
1520
+ prompt_embeds (`torch.FloatTensor`, *optional*):
1521
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
1522
+ provided, text embeddings will be generated from `prompt` input argument.
1523
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
1524
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
1525
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
1526
+ output_type (`str`, *optional*, defaults to `"pil"`):
1527
+ The output format of the generate image. Choose between
1528
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1529
+ return_dict (`bool`, *optional*, defaults to `True`):
1530
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
1531
+ joint_attention_kwargs (`dict`, *optional*):
1532
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1533
+ `self.processor` in
1534
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1535
+ callback_on_step_end (`Callable`, *optional*):
1536
+ A function that calls at the end of each denoising steps during the inference. The function is called
1537
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
1538
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
1539
+ `callback_on_step_end_tensor_inputs`.
1540
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
1541
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
1542
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
1543
+ `._callback_tensor_inputs` attribute of your pipeline class.
1544
+ max_sequence_length (`int`, defaults to 512): Maximum sequence length to use with the `prompt`.
+ guidance_scale_real (`float`, defaults to 1.0): True classifier-free guidance scale; values above 1.0 run an extra unconditional pass using `negative_prompt`.
1545
+
1546
+ Examples:
1547
+
1548
+ Returns:
1549
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
1550
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
1551
+ images.
1552
+ """
1553
+
1554
+ height = height or self.default_sample_size * self.vae_scale_factor
1555
+ width = width or self.default_sample_size * self.vae_scale_factor
1556
+
1557
+ # 1. Check inputs. Raise error if not correct
1558
+ self.check_inputs(
1559
+ prompt,
1560
+ prompt_2,
1561
+ height,
1562
+ width,
1563
+ prompt_embeds=prompt_embeds,
1564
+ pooled_prompt_embeds=pooled_prompt_embeds,
1565
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
1566
+ max_sequence_length=max_sequence_length,
1567
+ )
1568
+
1569
+ self._guidance_scale = guidance_scale
1570
+ self._guidance_scale_real = guidance_scale_real
1571
+ self._joint_attention_kwargs = joint_attention_kwargs
1572
+ self._interrupt = False
1573
+
1574
+ # 2. Define call parameters
1575
+ if prompt is not None and isinstance(prompt, str):
1576
+ batch_size = 1
1577
+ elif prompt is not None and isinstance(prompt, list):
1578
+ batch_size = len(prompt)
1579
+ else:
1580
+ batch_size = prompt_embeds.shape[0]
1581
+
1582
+ device = device or self._execution_device
1583
+
1584
+ lora_scale = (
1585
+ self.joint_attention_kwargs.get("scale", None)
1586
+ if self.joint_attention_kwargs is not None
1587
+ else None
1588
+ )
1589
+ (
1590
+ prompt_embeds,
1591
+ pooled_prompt_embeds,
1592
+ text_ids,
1593
+ _prompt_mask,
1594
+ ) = self.encode_prompt(
1595
+ prompt=prompt,
1596
+ prompt_2=prompt_2,
1597
+ prompt_embeds=prompt_embeds,
1598
+ pooled_prompt_embeds=pooled_prompt_embeds,
1599
+ device=device,
1600
+ num_images_per_prompt=num_images_per_prompt,
1601
+ max_sequence_length=max_sequence_length,
1602
+ lora_scale=lora_scale,
1603
+ )
1604
+ if _prompt_mask is not None:
1605
+ prompt_mask = _prompt_mask
1606
+
1607
+ if negative_prompt_2 == "" and negative_prompt != "":
1608
+ negative_prompt_2 = negative_prompt
1609
+
1610
+ negative_text_ids = text_ids
1611
+ if self._guidance_scale_real > 1.0 and (
1612
+ negative_prompt_embeds is None or negative_pooled_prompt_embeds is None
1613
+ ):
1614
+ (
1615
+ negative_prompt_embeds,
1616
+ negative_pooled_prompt_embeds,
1617
+ negative_text_ids,
1618
+ _neg_prompt_mask,
1619
+ ) = self.encode_prompt(
1620
+ prompt=negative_prompt,
1621
+ prompt_2=negative_prompt_2,
1622
+ prompt_embeds=None,
1623
+ pooled_prompt_embeds=None,
1624
+ device=device,
1625
+ num_images_per_prompt=num_images_per_prompt,
1626
+ max_sequence_length=max_sequence_length,
1627
+ lora_scale=lora_scale,
1628
+ )
1629
+
1630
+ if _neg_prompt_mask is not None:
1631
+ negative_mask = _neg_prompt_mask
1632
+
1633
+ # 4. Prepare latent variables
1634
+ num_channels_latents = self.transformer.config.in_channels // 4
1635
+ latents, latent_image_ids = self.prepare_latents(
1636
+ batch_size * num_images_per_prompt,
1637
+ num_channels_latents,
1638
+ height,
1639
+ width,
1640
+ prompt_embeds.dtype,
1641
+ device,
1642
+ generator,
1643
+ latents,
1644
+ )
1645
+
1646
+ # 5. Prepare timesteps
1647
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
1648
+ image_seq_len = latents.shape[1]
1649
+ mu = calculate_shift(
1650
+ image_seq_len,
1651
+ self.scheduler.config.base_image_seq_len,
1652
+ self.scheduler.config.max_image_seq_len,
1653
+ self.scheduler.config.base_shift,
1654
+ self.scheduler.config.max_shift,
1655
+ )
1656
+ timesteps, num_inference_steps = retrieve_timesteps(
1657
+ self.scheduler,
1658
+ num_inference_steps,
1659
+ device,
1660
+ timesteps,
1661
+ sigmas,
1662
+ mu=mu,
1663
+ )
1664
+ num_warmup_steps = max(
1665
+ len(timesteps) - num_inference_steps * self.scheduler.order, 0
1666
+ )
1667
+ self._num_timesteps = len(timesteps)
1668
+
1669
1672
+ text_ids = text_ids.to(device=device)
1673
+
1674
+ # handle guidance
1675
+ if self.transformer.config.guidance_embeds:
1676
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
1677
+ guidance = guidance.expand(latents.shape[0])
1678
+ else:
1679
+ guidance = None
1680
+
1681
+ if use_prompt_mask and prompt_mask is not None and not zero_using_prompt_mask:
1682
+ print('Using masking')
1683
+ elif use_prompt_mask and prompt_mask is not None and zero_using_prompt_mask:
1684
+ print('Using zeroed embeds')
1685
+ else:
1686
+ print('Not using masking')
1687
+
1688
+ if self._guidance_scale_real > 1.0:
1689
+ print('Using classifier free guidance', self._guidance_scale_real)
1690
+
1691
+ # 6. Denoising loop
1692
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
+                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+                timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+                assert prompt_mask is not None
+
+                extra_transformer_args = {}
+                if use_prompt_mask and prompt_mask is not None and not zero_using_prompt_mask:
+                    extra_transformer_args["attention_mask"] = prompt_mask
+                elif use_prompt_mask and prompt_mask is not None and zero_using_prompt_mask:
+                    mask_tens = prompt_mask.unsqueeze(-1).to(
+                        device=prompt_embeds.device, dtype=prompt_embeds.dtype
+                    )
+                    prompt_embeds = prompt_embeds * mask_tens
+
+                noise_pred = self.transformer(
+                    hidden_states=latents,
+                    # YiYi notes: divide by 1000 for now because we scale it by 1000
+                    # in the transformer model (we should not keep it, but I want to
+                    # keep the inputs the same for the model for testing)
+                    timestep=timestep / 1000,
+                    guidance=guidance,
+                    pooled_projections=pooled_prompt_embeds,
+                    encoder_hidden_states=prompt_embeds,
+                    txt_ids=text_ids,
+                    img_ids=latent_image_ids.to(device=device),
+                    joint_attention_kwargs=self.joint_attention_kwargs,
+                    return_dict=False,
+                    **extra_transformer_args,
+                )[0]
+
+                # TODO optionally use batch prediction to speed this up.
+                if self._guidance_scale_real > 1.0 and i >= no_cfg_until_timestep:
+                    extra_transformer_args_neg = {}
+                    if negative_mask is not None:
+                        extra_transformer_args_neg["attention_mask"] = negative_mask
+
+                    noise_pred_uncond = self.transformer(
+                        hidden_states=latents,
+                        # YiYi notes: divide by 1000 for now because we scale it by 1000
+                        # in the transformer model (we should not keep it, but I want to
+                        # keep the inputs the same for the model for testing)
+                        timestep=timestep / 1000,
+                        guidance=guidance,
+                        pooled_projections=negative_pooled_prompt_embeds,
+                        encoder_hidden_states=negative_prompt_embeds,
+                        txt_ids=negative_text_ids,
+                        img_ids=latent_image_ids.to(device=device),
+                        joint_attention_kwargs=self.joint_attention_kwargs,
+                        return_dict=False,
+                        **extra_transformer_args_neg,
+                    )[0]
+
+                    noise_pred = noise_pred_uncond + self._guidance_scale_real * (
+                        noise_pred - noise_pred_uncond
+                    )
+                    progress_bar.set_postfix(
+                        {
+                            'ts': timestep.detach().item() / 1000,
+                            'cfg': self._guidance_scale_real,
+                        },
+                    )
+                else:
+                    progress_bar.set_postfix(
+                        {
+                            'ts': timestep.detach().item() / 1000,
+                            'cfg': 'N/A',
+                        },
+                    )
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents_dtype = latents.dtype
+                latents = self.scheduler.step(
+                    noise_pred, t, latents, return_dict=False
+                )[0]
+
+                if latents.dtype != latents_dtype:
+                    if torch.backends.mps.is_available():
+                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+                        latents = latents.to(latents_dtype)
+
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+                # update the progress bar after warmup, or on the final step
+                if i == len(timesteps) - 1 or (
+                    (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
+                ):
+                    progress_bar.update()
+
+                if XLA_AVAILABLE:
+                    xm.mark_step()
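+
+        # For non-latent output, unpack the packed latents, undo the VAE scaling/shift,
+        # and decode to images.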
+        if output_type == "latent":
+            image = latents
+        else:
+            latents = self._unpack_latents(
+                latents, height, width, self.vae_scale_factor
+            )
+            latents = (
+                latents / self.vae.config.scaling_factor
+            ) + self.vae.config.shift_factor
+
+            image = self.vae.decode(
+                latents,
+                return_dict=False,
+            )[0]
+            image = self.image_processor.postprocess(image, output_type=output_type)
+
+        # Offload all models
+        self.maybe_free_model_hooks()
+
+        if not return_dict:
+            return (image,)
+
+        return FluxPipelineOutput(images=image)