RohitGandikota committed
Commit 4cbd4f2 · verified · 1 Parent(s): 7ad90cf

adding utils for sliders

utils/__init__.py ADDED
File without changes
utils/clip_util.py ADDED
@@ -0,0 +1,142 @@
1
+ from typing import List, Optional
2
+ import math, random, os
3
+ import pandas as pd
4
+ import numpy as np
5
+ import torch
6
+ from tqdm.auto import tqdm
7
+ from sklearn.decomposition import PCA
8
+
9
+
10
+ def extract_clip_features(clip, image, encoder):
11
+ """
12
+ Extracts feature embeddings from an image using either CLIP or DINOv2 models.
13
+
14
+ Args:
15
+ clip (torch.nn.Module): The feature extraction model (either CLIP or DINOv2)
16
+ image (torch.Tensor): Input image tensor normalized according to model requirements
17
+ encoder (str): Type of encoder to use ('dinov2-small' or 'clip')
18
+
19
+ Returns:
20
+ torch.Tensor: Feature embeddings extracted from the image
21
+
22
+ Note:
23
+ - For DINOv2 models, uses the pooled output features
24
+ - For CLIP models, uses the image features from the vision encoder
25
+ - The input image should already be properly resized and normalized
26
+ """
27
+ # Handle DINOv2 models
28
+ if 'dino' in encoder:
29
+ denoised = clip(image)
30
+ denoised = denoised.pooler_output
31
+ # Handle CLIP models
32
+ else:
33
+ denoised = clip.get_image_features(image)
34
+
35
+ return denoised
36
+
37
+ @torch.no_grad()
38
+ def compute_clip_pca(
39
+ diverse_prompts: List[str],
40
+ pipe,
41
+ clip_model,
42
+ clip_processor,
43
+ device,
44
+ guidance_scale,
45
+ params,
46
+ total_samples = 5000,
47
+ num_pca_components = 100,
48
+ batch_size = 10
49
+
50
+ ) -> torch.Tensor:
51
+ """
52
+ Extract CLIP features from generated images based on prompts.
53
+
54
+ Args:
55
+ diverse_prompts: List of prompts to generate images from
56
+ pipe, clip_model, clip_processor, device: generation pipeline and CLIP components used for feature extraction
57
+ params: Training configuration dictionary (save paths, image size, denoising steps)
58
+
59
+ Returns:
60
+ Tensor of CLIP principal components, the training prompts, and the saved image paths
61
+ """
62
+
63
+
64
+ # Calculate how many total batches we need
65
+ num_batches = math.ceil(total_samples / batch_size)
66
+ # Randomly sample prompts (with replacement if needed)
67
+ sampled_prompts_clip = random.choices(diverse_prompts, k=num_batches)
68
+
69
+ clip_features_path = f"{params['savepath_training_images']}/clip_principle_directions.pt"
70
+
71
+ if os.path.exists(clip_features_path):
72
+ df = pd.read_csv(f"{params['savepath_training_images']}/training_data.csv")
73
+ prompts_training = list(df.prompt)
74
+ image_paths = list(df.image_path)
75
+ return torch.load(clip_features_path).to(device), prompts_training, image_paths
76
+
77
+ os.makedirs(params['savepath_training_images'], exist_ok=True)
78
+
79
+ # Generate images and extract features
80
+ img_idx = 0
81
+ clip_features = []
82
+ image_paths = []
83
+ prompts_training = []
84
+ print('Calculating Semantic PCA')
85
+
86
+ for prompt in tqdm(sampled_prompts_clip):
87
+ if 'max_sequence_length' in params:
88
+ images = pipe(prompt,
89
+ num_images_per_prompt = batch_size,
90
+ num_inference_steps = params['max_denoising_steps'],
91
+ guidance_scale=guidance_scale,
92
+ max_sequence_length = params['max_sequence_length'],
93
+ height = params['height'],
94
+ width = params['width'],
95
+ ).images
96
+ else:
97
+ images = pipe(prompt,
98
+ num_images_per_prompt = batch_size,
99
+ num_inference_steps = params['max_denoising_steps'],
100
+ guidance_scale=guidance_scale,
101
+ height = params['height'],
102
+ width = params['width'],
103
+ ).images
104
+
105
+
106
+ # Process images
107
+ clip_inputs = clip_processor(images=images, return_tensors="pt", padding=True)
108
+ pixel_values = clip_inputs['pixel_values'].to(device)
109
+
110
+ # Get image embeddings
111
+ with torch.no_grad():
112
+ image_features = clip_model.get_image_features(pixel_values)
113
+
114
+ # Normalize embeddings
115
+ clip_feats = image_features / image_features.norm(dim=1, keepdim=True)
116
+ clip_features.append(clip_feats)
117
+
118
+ for im in images:
119
+ image_path = f"{params['savepath_training_images']}/{img_idx}.png"
120
+ im.save(image_path)
121
+ image_paths.append(image_path)
122
+ prompts_training.append(prompt)
123
+ img_idx += 1
124
+
125
+
126
+ clip_features = torch.cat(clip_features)
127
+
128
+
129
+ # Calculate principal components
130
+ pca = PCA(n_components=num_pca_components)
131
+ clip_embeds_np = clip_features.float().cpu().numpy()
132
+ pca.fit(clip_embeds_np)
133
+ clip_principles = torch.from_numpy(pca.components_).to(device, dtype=pipe.vae.dtype)
134
+
135
+ # Save results
136
+ torch.save(clip_principles, clip_features_path)
137
+ pd.DataFrame({
138
+ 'prompt': prompts_training,
139
+ 'image_path': image_paths
140
+ }).to_csv(f"{params['savepath_training_images']}/training_data.csv", index=False)
141
+
142
+ return clip_principles, prompts_training, image_paths
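Usage sketch (not part of the commit): a minimal way to exercise compute_clip_pca, assuming an SDXL base checkpoint and the TinyCLIP model referenced elsewhere in this commit; every value in params and the prompt list is an illustrative assumption.

import torch
from diffusers import StableDiffusionXLPipeline
from transformers import CLIPModel, CLIPProcessor
from utils.clip_util import compute_clip_pca

device = 'cuda:0'
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
).to(device)
clip_model = CLIPModel.from_pretrained("wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M").to(device)
clip_processor = CLIPProcessor.from_pretrained("wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M")

params = {
    'savepath_training_images': './training_images',  # images, CSV, and PCA tensor are cached here
    'max_denoising_steps': 30,
    'height': 1024,
    'width': 1024,
}

# small total_samples keeps this a quick smoke test; the defaults generate 5000 images
principal_dirs, prompts, image_paths = compute_clip_pca(
    diverse_prompts=["a portrait of a person, detailed, 8k"],
    pipe=pipe,
    clip_model=clip_model,
    clip_processor=clip_processor,
    device=device,
    guidance_scale=5.0,
    params=params,
    total_samples=100,
    num_pca_components=50,
    batch_size=10,
)
print(principal_dirs.shape)  # (num_pca_components, clip_feature_dim)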
utils/flux_utils.py ADDED
@@ -0,0 +1,404 @@
1
+ import os , torch
2
+ import argparse
3
+ import copy
4
+ import gc
5
+ import itertools
6
+ import logging
7
+ import math
8
+
9
+ import random
10
+ import shutil
11
+ import warnings
12
+ from contextlib import nullcontext
13
+ from pathlib import Path
14
+
15
+ import numpy as np
16
+ import torch
17
+ import torch.utils.checkpoint
18
+ import transformers
19
+ from accelerate import Accelerator
20
+ from accelerate.logging import get_logger
21
+ from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
22
+ from huggingface_hub import create_repo, upload_folder
23
+ from huggingface_hub.utils import insecure_hashlib
24
+ from PIL import Image
25
+ from PIL.ImageOps import exif_transpose
26
+ from torch.utils.data import Dataset
27
+ from torchvision import transforms
28
+ from torchvision.transforms.functional import crop
29
+ from tqdm.auto import tqdm
30
+ from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast
31
+
32
+ import diffusers
33
+ from diffusers import (
34
+ AutoencoderKL,
35
+ FlowMatchEulerDiscreteScheduler,
36
+ FluxTransformer2DModel,
37
+ )
38
+ from diffusers.optimization import get_scheduler
39
+ from diffusers.training_utils import (
40
+ _set_state_dict_into_text_encoder,
41
+ cast_training_params,
42
+ compute_density_for_timestep_sampling,
43
+ compute_loss_weighting_for_sd3,
44
+ )
45
+ from diffusers.utils import (
46
+ check_min_version,
47
+ convert_unet_state_dict_to_peft,
48
+ is_wandb_available,
49
+ )
50
+ from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
51
+ from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
+ from diffusers.image_processor import VaeImageProcessor
+ from diffusers import FluxPipeline
+ _pack_latents = FluxPipeline._pack_latents # assumption: reuse diffusers' Flux latent-packing helper, needed by get_noisy_image_flux below
52
+
53
+
54
+ from collections import defaultdict
55
+
56
+
57
+ from typing import List, Optional
58
+ import argparse
59
+ import ast
60
+ from pathlib import Path
61
+ from diffusers import DiffusionPipeline, UNet2DConditionModel, LCMScheduler
62
+ from huggingface_hub import hf_hub_download
63
+ import gc
64
+ import torch.nn.functional as F
65
+ import os
66
+ import torch
67
+ from tqdm.auto import tqdm
68
+ import time, datetime
69
+ import numpy as np
70
+ from torch.optim import AdamW
71
+ from contextlib import ExitStack
72
+ from safetensors.torch import load_file
73
+ import torch.nn as nn
74
+ import random
75
+ from transformers import CLIPModel
76
+
77
+ from transformers import logging
78
+ logging.set_verbosity_warning()
79
+
80
+ from diffusers import logging
81
+ logging.set_verbosity_error()
82
+
83
+
84
+ def flush():
85
+ torch.cuda.empty_cache()
86
+ gc.collect()
87
+ flush()
88
+ def unwrap_model(model):
89
+ options = (torch.nn.parallel.DistributedDataParallel, torch.nn.DataParallel)
90
+ #if is_deepspeed_available():
91
+ # options += (DeepSpeedEngine,)
92
+ while isinstance(model, options):
93
+ model = model.module
94
+ return model
95
+
96
+
97
+ # Function to log gradients
98
+ def log_gradients(named_parameters):
99
+ grad_dict = defaultdict(lambda: defaultdict(float))
100
+ for name, param in named_parameters:
101
+ if param.requires_grad and param.grad is not None:
102
+ grad_dict[name]['mean'] = param.grad.abs().mean().item()
103
+ grad_dict[name]['std'] = param.grad.std().item()
104
+ grad_dict[name]['max'] = param.grad.abs().max().item()
105
+ grad_dict[name]['min'] = param.grad.abs().min().item()
106
+ return grad_dict
107
+
108
+ def import_model_class_from_model_name_or_path(
109
+ pretrained_model_name_or_path: str, subfolder: str = "text_encoder",
110
+ ):
111
+ text_encoder_config = PretrainedConfig.from_pretrained(
112
+ pretrained_model_name_or_path, subfolder=subfolder
113
+ , device_map='cuda:0'
114
+ )
115
+ model_class = text_encoder_config.architectures[0]
116
+ if model_class == "CLIPTextModel":
117
+ from transformers import CLIPTextModel
118
+
119
+ return CLIPTextModel
120
+ elif model_class == "T5EncoderModel":
121
+ from transformers import T5EncoderModel
122
+
123
+ return T5EncoderModel
124
+ else:
125
+ raise ValueError(f"{model_class} is not supported.")
126
+ def load_text_encoders(pretrained_model_name_or_path, class_one, class_two, weight_dtype):
127
+ text_encoder_one = class_one.from_pretrained(
128
+ pretrained_model_name_or_path,
129
+ subfolder="text_encoder",
130
+ torch_dtype=weight_dtype,
131
+ device_map='cuda:0'
132
+ )
133
+ text_encoder_two = class_two.from_pretrained(
134
+ pretrained_model_name_or_path,
135
+ subfolder="text_encoder_2",
136
+ torch_dtype=weight_dtype,
137
+ device_map='cuda:0'
138
+ )
139
+ return text_encoder_one, text_encoder_two
140
+ import matplotlib.pyplot as plt
141
+ def plot_labeled_images(images, labels):
142
+ # Determine the number of images
143
+ n = len(images)
144
+
145
+ # Create a new figure with a single row
146
+ fig, axes = plt.subplots(1, n, figsize=(5*n, 5))
147
+
148
+ # If there's only one image, axes will be a single object, not an array
149
+ if n == 1:
150
+ axes = [axes]
151
+
152
+ # Plot each image
153
+ for i, (img, label) in enumerate(zip(images, labels)):
154
+ # Convert PIL image to numpy array
155
+ img_array = np.array(img)
156
+
157
+ # Display the image
158
+ axes[i].imshow(img_array)
159
+ axes[i].axis('off') # Turn off axis
160
+
161
+ # Set the title (label) for the image
162
+ axes[i].set_title(label)
163
+
164
+ # Adjust the layout and display the plot
165
+ plt.tight_layout()
166
+ plt.show()
167
+
168
+
169
+ def tokenize_prompt(tokenizer, prompt, max_sequence_length):
170
+ text_inputs = tokenizer(
171
+ prompt,
172
+ padding="max_length",
173
+ max_length=max_sequence_length,
174
+ truncation=True,
175
+ return_length=False,
176
+ return_overflowing_tokens=False,
177
+ return_tensors="pt",
178
+ )
179
+ text_input_ids = text_inputs.input_ids
180
+ return text_input_ids
181
+
182
+
183
+ def _encode_prompt_with_t5(
184
+ text_encoder,
185
+ tokenizer,
186
+ max_sequence_length=512,
187
+ prompt=None,
188
+ num_images_per_prompt=1,
189
+ device=None,
190
+ text_input_ids=None,
191
+ ):
192
+ prompt = [prompt] if isinstance(prompt, str) else prompt
193
+ batch_size = len(prompt)
194
+
195
+ if tokenizer is not None:
196
+ text_inputs = tokenizer(
197
+ prompt,
198
+ padding="max_length",
199
+ max_length=max_sequence_length,
200
+ truncation=True,
201
+ return_length=False,
202
+ return_overflowing_tokens=False,
203
+ return_tensors="pt",
204
+ )
205
+ text_input_ids = text_inputs.input_ids
206
+ else:
207
+ if text_input_ids is None:
208
+ raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
209
+
210
+ prompt_embeds = text_encoder(text_input_ids.to(device))[0]
211
+
212
+ dtype = text_encoder.dtype
213
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
214
+
215
+ _, seq_len, _ = prompt_embeds.shape
216
+
217
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
218
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
219
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
220
+
221
+ return prompt_embeds
222
+
223
+
224
+ def _encode_prompt_with_clip(
225
+ text_encoder,
226
+ tokenizer,
227
+ prompt: str,
228
+ device=None,
229
+ text_input_ids=None,
230
+ num_images_per_prompt: int = 1,
231
+ ):
232
+ prompt = [prompt] if isinstance(prompt, str) else prompt
233
+ batch_size = len(prompt)
234
+
235
+ if tokenizer is not None:
236
+ text_inputs = tokenizer(
237
+ prompt,
238
+ padding="max_length",
239
+ max_length=77,
240
+ truncation=True,
241
+ return_overflowing_tokens=False,
242
+ return_length=False,
243
+ return_tensors="pt",
244
+ )
245
+
246
+ text_input_ids = text_inputs.input_ids
247
+ else:
248
+ if text_input_ids is None:
249
+ raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
250
+
251
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
252
+
253
+ # Use pooled output of CLIPTextModel
254
+ prompt_embeds = prompt_embeds.pooler_output
255
+ prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
256
+
257
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
258
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
259
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
260
+
261
+ return prompt_embeds
262
+
263
+ def encode_prompt(
264
+ text_encoders,
265
+ tokenizers,
266
+ prompt: str,
267
+ max_sequence_length,
268
+ device=None,
269
+ num_images_per_prompt: int = 1,
270
+ text_input_ids_list=None,
271
+ ):
272
+ prompt = [prompt] if isinstance(prompt, str) else prompt
273
+ batch_size = len(prompt)
274
+ dtype = text_encoders[0].dtype
275
+
276
+ pooled_prompt_embeds = _encode_prompt_with_clip(
277
+ text_encoder=text_encoders[0],
278
+ tokenizer=tokenizers[0],
279
+ prompt=prompt,
280
+ device=device if device is not None else text_encoders[0].device,
281
+ num_images_per_prompt=num_images_per_prompt,
282
+ text_input_ids=text_input_ids_list[0] if text_input_ids_list else None,
283
+ )
284
+
285
+ prompt_embeds = _encode_prompt_with_t5(
286
+ text_encoder=text_encoders[1],
287
+ tokenizer=tokenizers[1],
288
+ max_sequence_length=max_sequence_length,
289
+ prompt=prompt,
290
+ num_images_per_prompt=num_images_per_prompt,
291
+ device=device if device is not None else text_encoders[1].device,
292
+ text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,
293
+ )
294
+
295
+ text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
296
+ text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
297
+
298
+ return prompt_embeds, pooled_prompt_embeds, text_ids
299
+
300
+ def compute_text_embeddings(prompt, text_encoders, tokenizers,max_sequence_length=256):
301
+ device = text_encoders[0].device
302
+ with torch.no_grad():
303
+ prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
304
+ text_encoders, tokenizers, prompt, max_sequence_length=max_sequence_length
305
+ )
306
+ prompt_embeds = prompt_embeds.to(device)
307
+ pooled_prompt_embeds = pooled_prompt_embeds.to(device)
308
+ text_ids = text_ids.to(device)
309
+ return prompt_embeds, pooled_prompt_embeds, text_ids
310
+
311
+
312
+ def get_sigmas(timesteps, n_dim=4, device='cuda:0', dtype=torch.bfloat16):
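+ # NOTE: expects a module-level `noise_scheduler_copy` in this file; assign it (e.g. flux_utils.noise_scheduler_copy = copy.deepcopy(noise_scheduler)) before calling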
313
+ sigmas = noise_scheduler_copy.sigmas.to(device=device, dtype=dtype)
314
+ schedule_timesteps = noise_scheduler_copy.timesteps.to(device)
315
+ timesteps = timesteps.to(device)
316
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
317
+
318
+ sigma = sigmas[step_indices].flatten()
319
+ while len(sigma.shape) < n_dim:
320
+ sigma = sigma.unsqueeze(-1)
321
+ return sigma
322
+
323
+
324
+ def plot_history(history):
325
+ fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 5))
326
+ ax1.plot(history['concept'])
327
+ ax1.set_title('Concept Loss')
328
+ ax2.plot(movingaverage(history['concept'], 10))
329
+ ax2.set_title('Moving Average Concept Loss')
330
+ plt.tight_layout()
331
+ plt.show()
332
+
333
+ def movingaverage(interval, window_size):
334
+ window = np.ones(int(window_size))/float(window_size)
335
+ return np.convolve(interval, window, 'same')
336
+
337
+
338
+
339
+ @torch.no_grad()
340
+ def get_noisy_image_flux(
341
+ image,
342
+ vae,
343
+ transformer,
344
+ scheduler,
345
+ timesteps_to=1000,
346
+ generator=None,
347
+ **kwargs,
348
+ ):
349
+ """
350
+ Gets noisy latents for a given image using Flux pipeline approach.
351
+
352
+ Args:
353
+ image: PIL image or tensor
354
+ vae: Flux VAE model
355
+ transformer: Flux transformer model
356
+ scheduler: Flux noise scheduler
357
+ timesteps_to: Target timestep
358
+ generator: Random generator for reproducibility
359
+
360
+ Returns:
361
+ tuple: (noisy_latents, noise)
362
+ """
363
+ device = vae.device
364
+ vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
365
+ image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)
366
+
367
+ # Preprocess image
368
+ if not isinstance(image, torch.Tensor):
369
+ image = image_processor.preprocess(image)
370
+ image = image.to(device=device, dtype=torch.float32)
371
+
372
+ # Encode through VAE
373
+ init_latents = vae.encode(image).latents
374
+ init_latents = vae.config.scaling_factor * init_latents
375
+
376
+ # Get shape for noise
377
+ shape = init_latents.shape
378
+
379
+ # Generate noise
380
+ noise = randn_tensor(shape, generator=generator, device=device)
381
+
382
+ # Pack latents using Flux's method
383
+ init_latents = _pack_latents(
384
+ init_latents,
385
+ shape[0], # batch size
386
+ transformer.config.in_channels // 4,
387
+ height=shape[2],
388
+ width=shape[3]
389
+ )
390
+ noise = _pack_latents(
391
+ noise,
392
+ shape[0],
393
+ transformer.config.in_channels // 4,
394
+ height=shape[2],
395
+ width=shape[3]
396
+ )
397
+
398
+ # Get timestep
399
+ timestep = scheduler.timesteps[timesteps_to:timesteps_to+1]
400
+
401
+ # Add noise to latents
402
+ noisy_latents = scheduler.add_noise(init_latents, noise, timestep)
403
+
404
+ return noisy_latents, noise
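Usage sketch (not part of the commit): exercising the prompt-encoding helpers above, assuming the black-forest-labs/FLUX.1-schnell checkpoint layout; note that load_text_encoders pins both encoders to cuda:0 and the T5 encoder needs substantial GPU memory.

import torch
from transformers import CLIPTokenizer, T5TokenizerFast
from utils.flux_utils import (
    import_model_class_from_model_name_or_path,
    load_text_encoders,
    compute_text_embeddings,
)

model_id = "black-forest-labs/FLUX.1-schnell"   # assumed checkpoint; any Flux-style repo layout works
dtype = torch.bfloat16

tokenizer_one = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
tokenizer_two = T5TokenizerFast.from_pretrained(model_id, subfolder="tokenizer_2")

cls_one = import_model_class_from_model_name_or_path(model_id)                               # CLIPTextModel
cls_two = import_model_class_from_model_name_or_path(model_id, subfolder="text_encoder_2")   # T5EncoderModel
text_encoder_one, text_encoder_two = load_text_encoders(model_id, cls_one, cls_two, dtype)

prompt_embeds, pooled_embeds, text_ids = compute_text_embeddings(
    "a photo of a dog",
    [text_encoder_one, text_encoder_two],
    [tokenizer_one, tokenizer_two],
    max_sequence_length=256,   # schnell uses a 256-token T5 context
)
print(prompt_embeds.shape, pooled_embeds.shape, text_ids.shape)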
utils/lora.py ADDED
@@ -0,0 +1,320 @@
1
+ # ref:
2
+ # - https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
3
+ # - https://github.com/kohya-ss/sd-scripts/blob/main/networks/lora.py
4
+
5
+ import os
6
+ import math
7
+ from typing import Optional, List, Type, Set, Literal
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from diffusers import UNet2DConditionModel
12
+ from safetensors.torch import save_file
13
+ from datetime import datetime
14
+
15
+ UNET_TARGET_REPLACE_MODULE_TRANSFORMER = [
16
+ # "Transformer2DModel", # どうやらこっちの方らしい? # attn1, 2
17
+ "Attention"
18
+ ]
19
+ UNET_TARGET_REPLACE_MODULE_CONV = [
20
+ "ResnetBlock2D",
21
+ "Downsample2D",
22
+ "Upsample2D",
23
+ "DownBlock2D",
24
+ "UpBlock2D",
25
+
26
+ ] # locon, 3clier
27
+
28
+ LORA_PREFIX_UNET = "lora_unet"
29
+
30
+ DEFAULT_TARGET_REPLACE = UNET_TARGET_REPLACE_MODULE_TRANSFORMER
31
+
32
+ TRAINING_METHODS = Literal[
33
+ "noxattn", # train all layers except x-attns and time_embed layers
34
+ "innoxattn", # train all layers except self attention layers
35
+ "selfattn", # ESD-u, train only self attention layers
36
+ "xattn", # ESD-x, train only x attention layers
37
+ "xattn-up", # all up blocks only
38
+ "xattn-down",# all down blocks only
39
+ "xattn-mid",# mid blocks only
40
+ "full", # train all layers
41
+ "xattn-strict", # q and k values
42
+ "noxattn-hspace",
43
+ "noxattn-hspace-last",
44
+ "flux-attn",
45
+ # "xlayer",
46
+ # "outxattn",
47
+ # "outsattn",
48
+ # "inxattn",
49
+ # "inmidsattn",
50
+ # "selflayer",
51
+ ]
52
+
53
+ def load_ortho_dict(n):
54
+ path = f'/share/u/rohit/orthogonal_basis/{n:09}.ckpt'
55
+ if os.path.isfile(path):
56
+ return torch.load(path)
57
+ else:
58
+ x = torch.randn(n,n)
59
+ eig, _, _ = torch.svd(x)
60
+ torch.save(eig, path)
61
+ return eig
62
+
63
+ def init_ortho_proj(rank, weight):
64
+ seed = torch.seed()
65
+ torch.manual_seed(datetime.now().timestamp())
66
+ q_index = torch.randint(high=weight.size(0),size=(rank,))
67
+ torch.manual_seed(seed)
68
+
69
+ ortho_q_init = load_ortho_dict(weight.size(0)).to(dtype=weight.dtype)[:,q_index]
70
+ return nn.Parameter(ortho_q_init)
71
+
72
+
73
+ class LoRAModule(nn.Module):
74
+ """
75
+ replaces forward method of the original Linear, instead of replacing the original Linear module.
76
+ """
77
+
78
+ def __init__(
79
+ self,
80
+ lora_name,
81
+ org_module: nn.Module,
82
+ multiplier=1.0,
83
+ lora_dim=4,
84
+ alpha=1,
85
+ train_method='xattn',
86
+ fast_init = False
87
+ ):
88
+ """if alpha == 0 or None, alpha is rank (no scaling)."""
89
+ super().__init__()
90
+ self.lora_name = lora_name
91
+ self.lora_dim = lora_dim
92
+
93
+ if "Linear" in org_module.__class__.__name__:
94
+ in_dim = org_module.in_features
95
+ out_dim = org_module.out_features
96
+ self.lora_down = nn.Linear(in_dim, lora_dim, bias=False)
97
+ self.lora_up = nn.Linear(lora_dim, out_dim, bias=False)
98
+
99
+ elif "Conv" in org_module.__class__.__name__: # 一応
100
+ in_dim = org_module.in_channels
101
+ out_dim = org_module.out_channels
102
+
103
+ self.lora_dim = min(self.lora_dim, in_dim, out_dim)
104
+ if self.lora_dim != lora_dim:
105
+ print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
106
+
107
+ kernel_size = org_module.kernel_size
108
+ stride = org_module.stride
109
+ padding = org_module.padding
110
+ self.lora_down = nn.Conv2d(
111
+ in_dim, self.lora_dim, kernel_size, stride, padding, bias=False
112
+ )
113
+ self.lora_up = nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
114
+
115
+ if type(alpha) == torch.Tensor:
116
+ alpha = alpha.detach().numpy()
117
+ alpha = lora_dim if alpha is None or alpha == 0 else alpha
118
+ self.scale = alpha / self.lora_dim
119
+ self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える
120
+
121
+ # same as microsoft's
122
+ nn.init.kaiming_uniform_(self.lora_down.weight, a=1)
123
+ if train_method == 'full':
124
+ nn.init.zeros_(self.lora_up.weight)
125
+ else:
126
+ if not fast_init:
127
+ self.lora_up.weight = init_ortho_proj(lora_dim, self.lora_up.weight)
128
+ self.lora_up.weight.requires_grad_(False)
129
+ else:
130
+ nn.init.zeros_(self.lora_up.weight)
131
+
132
+ self.multiplier = multiplier
133
+ self.org_module = org_module # remove in applying
134
+
135
+ def apply_to(self):
136
+ self.org_forward = self.org_module.forward
137
+ self.org_module.forward = self.forward
138
+ del self.org_module
139
+
140
+ def forward(self, x):
141
+ return (
142
+ self.org_forward(x)
143
+ + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
144
+ )
145
+
146
+
147
+ class LoRANetwork(nn.Module):
148
+ def __init__(
149
+ self,
150
+ unet: UNet2DConditionModel,
151
+ rank: int = 4,
152
+ multiplier: float = 1.0,
153
+ alpha: float = 1.0,
154
+ train_method: TRAINING_METHODS = "full",
155
+ layers = ['Linear', 'Conv'],
156
+ fast_init = False,
157
+ ) -> None:
158
+ super().__init__()
159
+ self.lora_scale = 1
160
+ self.multiplier = multiplier
161
+ self.lora_dim = rank
162
+ self.alpha = alpha
163
+ self.train_method=train_method
164
+ # LoRA only
165
+ self.module = LoRAModule
166
+
167
+ # create the LoRA modules for the U-Net
168
+ self.unet_loras = self.create_modules(
169
+ LORA_PREFIX_UNET,
170
+ unet,
171
+ DEFAULT_TARGET_REPLACE,
172
+ self.lora_dim,
173
+ self.multiplier,
174
+ train_method=train_method,
175
+ layers = layers,
176
+ fast_init=fast_init,
177
+ )
178
+ print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
179
+
180
+ # assertion: check that no lora names are duplicated
181
+ lora_names = set()
182
+ for lora in self.unet_loras:
183
+ assert (
184
+ lora.lora_name not in lora_names
185
+ ), f"duplicated lora name: {lora.lora_name}. {lora_names}"
186
+ lora_names.add(lora.lora_name)
187
+
188
+ # apply the LoRA modules
189
+ for lora in self.unet_loras:
190
+ lora.apply_to()
191
+ self.add_module(
192
+ lora.lora_name,
193
+ lora,
194
+ )
195
+
196
+ del unet
197
+
198
+ torch.cuda.empty_cache()
199
+
200
+ def create_modules(
201
+ self,
202
+ prefix: str,
203
+ root_module: nn.Module,
204
+ target_replace_modules: List[str],
205
+ rank: int,
206
+ multiplier: float,
207
+ train_method: TRAINING_METHODS,
208
+ layers: List[str],
209
+ fast_init: bool,
210
+ ) -> list:
211
+ filt_layers = []
212
+ if 'Linear' in layers:
213
+ filt_layers.extend(["Linear", "LoRACompatibleLinear"])
214
+ if 'Conv' in layers:
215
+ filt_layers.extend(["Conv2d", "LoRACompatibleConv"])
216
+ loras = []
217
+ names = []
218
+ for name, module in root_module.named_modules():
219
+ if train_method == "noxattn" or train_method == "noxattn-hspace" or train_method == "noxattn-hspace-last": # Cross Attention と Time Embed 以外学習
220
+ if "attn2" in name or "time_embed" in name:
221
+ continue
222
+ elif train_method == "innoxattn": # Cross Attention 以外学習
223
+ if "attn2" in name:
224
+ continue
225
+ elif train_method == "selfattn": # Self Attention のみ学習
226
+ if "attn1" not in name:
227
+ continue
228
+ elif train_method in ["xattn", "xattn-strict", "xattn-up", "xattn-down", "xattn-mid"]: # Cross Attention のみ学習
229
+ if "attn2" not in name:
230
+ continue
231
+ if train_method == 'xattn-up':
232
+ if 'up_block' not in name:
233
+ continue
234
+ if train_method == 'xattn-down':
235
+ if 'down_block' not in name:
236
+ continue
237
+ if train_method == 'xattn-mid':
238
+ if 'mid_block' not in name:
239
+ continue
240
+ elif train_method == "full": # 全部学習
241
+ pass
242
+ elif train_method == "flux-attn":
243
+ if "attn" not in name:
244
+ continue
245
+ else:
246
+ raise NotImplementedError(
247
+ f"train_method: {train_method} is not implemented."
248
+ )
249
+ if module.__class__.__name__ in target_replace_modules:
250
+ for child_name, child_module in module.named_modules():
251
+ if child_module.__class__.__name__ in filt_layers:
252
+
253
+
254
+ if train_method == 'xattn-strict':
255
+ if 'out' in child_name:
256
+ continue
257
+ if 'to_q' in child_name:
258
+ continue
259
+ if train_method == 'noxattn-hspace':
260
+ if 'mid_block' not in name:
261
+ continue
262
+ if train_method == 'noxattn-hspace-last':
263
+ if 'mid_block' not in name or '.1' not in name or 'conv2' not in child_name:
264
+ continue
265
+ lora_name = prefix + "." + name + "." + child_name
266
+ lora_name = lora_name.replace(".", "_")
267
+ # print(f"{lora_name}")
268
+ lora = self.module(
269
+ lora_name, child_module, multiplier, rank, self.alpha, train_method, fast_init
270
+ )
271
+ # print(name, child_name)
272
+ # print(child_module.weight.shape)
273
+ if lora_name not in names:
274
+ loras.append(lora)
275
+ names.append(lora_name)
276
+ # print(f'@@@@@@@@@@@@@@@@@@@@@@@@@@@@ \n {names}')
277
+ return loras
278
+
279
+ def prepare_optimizer_params(self):
280
+ all_params = []
281
+
282
+ if self.unet_loras: # effectively the only case
283
+ params = []
284
+ if self.train_method == 'full':
285
+ [params.extend(lora.parameters()) for lora in self.unet_loras]
286
+ else:
287
+ [params.extend(lora.lora_down.parameters()) for lora in self.unet_loras]
288
+ param_data = {"params": params}
289
+ all_params.append(param_data)
290
+
291
+ return all_params
292
+
293
+ def save_weights(self, file, dtype=None, metadata: Optional[dict] = None):
294
+ state_dict = self.state_dict()
295
+
296
+ if dtype is not None:
297
+ for key in list(state_dict.keys()):
298
+ v = state_dict[key]
299
+ v = v.detach().clone().to("cpu").to(dtype)
300
+ state_dict[key] = v
301
+
302
+ # for key in list(state_dict.keys()):
303
+ # if not key.startswith("lora"):
304
+ # # drop keys that are not lora
305
+ # del state_dict[key]
306
+
307
+ if os.path.splitext(file)[1] == ".safetensors":
308
+ save_file(state_dict, file, metadata)
309
+ else:
310
+ torch.save(state_dict, file)
311
+ def set_lora_slider(self, scale):
312
+ self.lora_scale = scale
313
+
314
+ def __enter__(self):
315
+ for lora in self.unet_loras:
316
+ lora.multiplier = 1.0 * self.lora_scale
317
+
318
+ def __exit__(self, exc_type, exc_value, tb):
319
+ for lora in self.unet_loras:
320
+ lora.multiplier = 0
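Usage sketch (not part of the commit): attaching a LoRA slider network to an SDXL UNet. The checkpoint, rank, and dtype are assumptions; fast_init=True is used so the example does not depend on the hard-coded orthogonal-basis cache path in init_ortho_proj.

import torch
from diffusers import UNet2DConditionModel
from utils.lora import LoRANetwork

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", torch_dtype=torch.bfloat16
).to('cuda:0')

# constructing the network patches the UNet's cross-attention forwards in place
network = LoRANetwork(unet, rank=4, multiplier=1.0, alpha=1.0,
                      train_method="xattn", fast_init=True).to('cuda:0', dtype=torch.bfloat16)

optimizer = torch.optim.AdamW(network.prepare_optimizer_params(), lr=1e-4)

network.set_lora_slider(scale=2.0)
with network:   # __enter__ activates the LoRA deltas at the chosen scale, __exit__ zeroes them again
    # run the UNet forward / training step here
    pass

network.save_weights("slider_0.pt", dtype=torch.bfloat16)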
utils/model_util.py ADDED
@@ -0,0 +1,527 @@
1
+ from typing import Literal, Union, Optional
2
+
3
+ import torch, gc, os
4
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection, T5TokenizerFast
5
+ from transformers import (
6
+ AutoModel,
7
+ CLIPModel,
8
+ CLIPProcessor,
9
+ )
10
+ from huggingface_hub import hf_hub_download
11
+ from diffusers import (
12
+ UNet2DConditionModel,
13
+ SchedulerMixin,
14
+ StableDiffusionPipeline,
15
+ StableDiffusionXLPipeline,
16
+ FluxPipeline,
17
+ AutoencoderKL,
18
+ FluxTransformer2DModel,
19
+ )
20
+ import copy
21
+ from diffusers.schedulers import (
22
+ DDIMScheduler,
23
+ DDPMScheduler,
24
+ LMSDiscreteScheduler,
25
+ EulerAncestralDiscreteScheduler,
26
+ FlowMatchEulerDiscreteScheduler,
27
+ )
28
+ from diffusers import LCMScheduler, AutoencoderTiny
29
+ import sys
30
+ sys.path.append('.')
31
+ from .flux_utils import *
32
+
33
+ TOKENIZER_V1_MODEL_NAME = "CompVis/stable-diffusion-v1-4"
34
+ TOKENIZER_V2_MODEL_NAME = "stabilityai/stable-diffusion-2-1"
35
+
36
+ AVAILABLE_SCHEDULERS = Literal["ddim", "ddpm", "lms", "euler_a"]
37
+
38
+ SDXL_TEXT_ENCODER_TYPE = Union[CLIPTextModel, CLIPTextModelWithProjection]
39
+
40
+ DIFFUSERS_CACHE_DIR = None # if you want to change the cache dir, change this
41
+
42
+
43
+ def load_diffusers_model(
44
+ pretrained_model_name_or_path: str,
45
+ v2: bool = False,
46
+ clip_skip: Optional[int] = None,
47
+ weight_dtype: torch.dtype = torch.float32,
48
+ ) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel,]:
49
+ # the VAE is not needed here
50
+
51
+ if v2:
52
+ tokenizer = CLIPTokenizer.from_pretrained(
53
+ TOKENIZER_V2_MODEL_NAME,
54
+ subfolder="tokenizer",
55
+ torch_dtype=weight_dtype,
56
+ cache_dir=DIFFUSERS_CACHE_DIR,
57
+ )
58
+ text_encoder = CLIPTextModel.from_pretrained(
59
+ pretrained_model_name_or_path,
60
+ subfolder="text_encoder",
61
+ # default is clip skip 2
62
+ num_hidden_layers=24 - (clip_skip - 1) if clip_skip is not None else 23,
63
+ torch_dtype=weight_dtype,
64
+ cache_dir=DIFFUSERS_CACHE_DIR,
65
+ )
66
+ else:
67
+ tokenizer = CLIPTokenizer.from_pretrained(
68
+ TOKENIZER_V1_MODEL_NAME,
69
+ subfolder="tokenizer",
70
+ torch_dtype=weight_dtype,
71
+ cache_dir=DIFFUSERS_CACHE_DIR,
72
+ )
73
+ text_encoder = CLIPTextModel.from_pretrained(
74
+ pretrained_model_name_or_path,
75
+ subfolder="text_encoder",
76
+ num_hidden_layers=12 - (clip_skip - 1) if clip_skip is not None else 12,
77
+ torch_dtype=weight_dtype,
78
+ cache_dir=DIFFUSERS_CACHE_DIR,
79
+ )
80
+
81
+ unet = UNet2DConditionModel.from_pretrained(
82
+ pretrained_model_name_or_path,
83
+ subfolder="unet",
84
+ torch_dtype=weight_dtype,
85
+ cache_dir=DIFFUSERS_CACHE_DIR,
86
+ )
87
+
88
+ return tokenizer, text_encoder, unet
89
+
90
+
91
+ def load_checkpoint_model(
92
+ checkpoint_path: str,
93
+ v2: bool = False,
94
+ clip_skip: Optional[int] = None,
95
+ weight_dtype: torch.dtype = torch.float32,
96
+ ) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel,]:
97
+ pipe = StableDiffusionPipeline.from_ckpt(
98
+ checkpoint_path,
99
+ upcast_attention=True if v2 else False,
100
+ torch_dtype=weight_dtype,
101
+ cache_dir=DIFFUSERS_CACHE_DIR,
102
+ )
103
+
104
+ unet = pipe.unet
105
+ tokenizer = pipe.tokenizer
106
+ text_encoder = pipe.text_encoder
107
+ if clip_skip is not None:
108
+ if v2:
109
+ text_encoder.config.num_hidden_layers = 24 - (clip_skip - 1)
110
+ else:
111
+ text_encoder.config.num_hidden_layers = 12 - (clip_skip - 1)
112
+
113
+ del pipe
114
+
115
+ return tokenizer, text_encoder, unet
116
+
117
+
118
+ def load_models(
119
+ pretrained_model_name_or_path: str,
120
+ scheduler_name: AVAILABLE_SCHEDULERS,
121
+ v2: bool = False,
122
+ v_pred: bool = False,
123
+ weight_dtype: torch.dtype = torch.float32,
124
+ ) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel, SchedulerMixin,]:
125
+ if pretrained_model_name_or_path.endswith(
126
+ ".ckpt"
127
+ ) or pretrained_model_name_or_path.endswith(".safetensors"):
128
+ tokenizer, text_encoder, unet = load_checkpoint_model(
129
+ pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype
130
+ )
131
+ else: # diffusers
132
+ tokenizer, text_encoder, unet = load_diffusers_model(
133
+ pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype
134
+ )
135
+
136
+ # the VAE is not needed here
137
+
138
+ scheduler = create_noise_scheduler(
139
+ scheduler_name,
140
+ prediction_type="v_prediction" if v_pred else "epsilon",
141
+ )
142
+
143
+ return tokenizer, text_encoder, unet, scheduler
144
+
145
+
146
+ def load_diffusers_model_xl(
147
+ pretrained_model_name_or_path: str,
148
+ weight_dtype: torch.dtype = torch.float32,
149
+ ) -> tuple[list[CLIPTokenizer], list[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel,]:
150
+ # returns tokenizer, tokenizer_2, text_encoder, text_encoder_2, unet
151
+
152
+ tokenizers = [
153
+ CLIPTokenizer.from_pretrained(
154
+ pretrained_model_name_or_path,
155
+ subfolder="tokenizer",
156
+ torch_dtype=weight_dtype,
157
+ cache_dir=DIFFUSERS_CACHE_DIR,
158
+ ),
159
+ CLIPTokenizer.from_pretrained(
160
+ pretrained_model_name_or_path,
161
+ subfolder="tokenizer_2",
162
+ torch_dtype=weight_dtype,
163
+ cache_dir=DIFFUSERS_CACHE_DIR,
164
+ pad_token_id=0, # same as open clip
165
+ ),
166
+ ]
167
+
168
+ text_encoders = [
169
+ CLIPTextModel.from_pretrained(
170
+ pretrained_model_name_or_path,
171
+ subfolder="text_encoder",
172
+ torch_dtype=weight_dtype,
173
+ cache_dir=DIFFUSERS_CACHE_DIR,
174
+ ),
175
+ CLIPTextModelWithProjection.from_pretrained(
176
+ pretrained_model_name_or_path,
177
+ subfolder="text_encoder_2",
178
+ torch_dtype=weight_dtype,
179
+ cache_dir=DIFFUSERS_CACHE_DIR,
180
+ ),
181
+ ]
182
+
183
+ unet = UNet2DConditionModel.from_pretrained(
184
+ pretrained_model_name_or_path,
185
+ subfolder="unet",
186
+ torch_dtype=weight_dtype,
187
+ cache_dir=DIFFUSERS_CACHE_DIR,
188
+ )
189
+
190
+ return tokenizers, text_encoders, unet
191
+
192
+
193
+ def load_checkpoint_model_xl(
194
+ checkpoint_path: str,
195
+ weight_dtype: torch.dtype = torch.float32,
196
+ ) -> tuple[list[CLIPTokenizer], list[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel,]:
197
+ pipe = StableDiffusionXLPipeline.from_single_file(
198
+ checkpoint_path,
199
+ torch_dtype=weight_dtype,
200
+ cache_dir=DIFFUSERS_CACHE_DIR,
201
+ )
202
+
203
+ unet = pipe.unet
204
+ tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
205
+ text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
206
+ if len(text_encoders) == 2:
207
+ text_encoders[1].pad_token_id = 0
208
+
209
+ del pipe
210
+
211
+ return tokenizers, text_encoders, unet
212
+
213
+
214
+ def load_models_xl_(
215
+ pretrained_model_name_or_path: str,
216
+ scheduler_name: AVAILABLE_SCHEDULERS,
217
+ weight_dtype: torch.dtype = torch.float32,
218
+ ) -> tuple[
219
+ list[CLIPTokenizer],
220
+ list[SDXL_TEXT_ENCODER_TYPE],
221
+ UNet2DConditionModel,
222
+ SchedulerMixin,
223
+ ]:
224
+ if pretrained_model_name_or_path.endswith(
225
+ ".ckpt"
226
+ ) or pretrained_model_name_or_path.endswith(".safetensors"):
227
+ (
228
+ tokenizers,
229
+ text_encoders,
230
+ unet,
231
+ ) = load_checkpoint_model_xl(pretrained_model_name_or_path, weight_dtype)
232
+ else: # diffusers
233
+ (
234
+ tokenizers,
235
+ text_encoders,
236
+ unet,
237
+ ) = load_diffusers_model_xl(pretrained_model_name_or_path, weight_dtype)
238
+
239
+ scheduler = create_noise_scheduler(scheduler_name)
240
+
241
+ return tokenizers, text_encoders, unet, scheduler
242
+
243
+
244
+ def create_noise_scheduler(
245
+ scheduler_name: AVAILABLE_SCHEDULERS = "ddpm",
246
+ prediction_type: Literal["epsilon", "v_prediction"] = "epsilon",
247
+ ) -> SchedulerMixin:
248
+ # Honestly not sure which scheduler is best; the original implementation allowed choosing DDIM, DDPM, or LMS.
249
+
250
+ name = scheduler_name.lower().replace(" ", "_")
251
+ if name == "ddim":
252
+ # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddim
253
+ scheduler = DDIMScheduler(
254
+ beta_start=0.00085,
255
+ beta_end=0.012,
256
+ beta_schedule="scaled_linear",
257
+ num_train_timesteps=1000,
258
+ clip_sample=False,
259
+ prediction_type=prediction_type, # is this right?
260
+ )
261
+ elif name == "ddpm":
262
+ # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddpm
263
+ scheduler = DDPMScheduler(
264
+ beta_start=0.00085,
265
+ beta_end=0.012,
266
+ beta_schedule="scaled_linear",
267
+ num_train_timesteps=1000,
268
+ clip_sample=False,
269
+ prediction_type=prediction_type,
270
+ )
271
+ elif name == "lms":
272
+ # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/lms_discrete
273
+ scheduler = LMSDiscreteScheduler(
274
+ beta_start=0.00085,
275
+ beta_end=0.012,
276
+ beta_schedule="scaled_linear",
277
+ num_train_timesteps=1000,
278
+ prediction_type=prediction_type,
279
+ )
280
+ elif name == "euler_a":
281
+ # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/euler_ancestral
282
+ scheduler = EulerAncestralDiscreteScheduler(
283
+ beta_start=0.00085,
284
+ beta_end=0.012,
285
+ beta_schedule="scaled_linear",
286
+ num_train_timesteps=1000,
287
+ # clip_sample=False,
288
+ prediction_type=prediction_type,
289
+ )
290
+ else:
291
+ raise ValueError(f"Unknown scheduler name: {name}")
292
+
293
+ return scheduler
294
+
295
+
296
+ def load_models_xl(params):
297
+ """
298
+ Load all required models for training
299
+
300
+ Args:
301
+ params: Dictionary containing model parameters and configurations
302
+
303
+ Returns:
304
+ dict: Dictionary containing all loaded models and tokenizers
305
+ """
306
+ device = params['device']
307
+ weight_dtype = params['weight_dtype']
308
+
309
+ # Load SDXL components (UNet, text encoders, tokenizers)
310
+ scheduler_name = 'ddim'
311
+ tokenizers, text_encoders, unet, noise_scheduler = load_models_xl_(
312
+ params['pretrained_model_name_or_path'],
313
+ scheduler_name=scheduler_name,
314
+ )
315
+
316
+ # Move text encoders to device and set to eval mode
317
+ for text_encoder in text_encoders:
318
+ text_encoder.to(device, dtype=weight_dtype)
319
+ text_encoder.requires_grad_(False)
320
+ text_encoder.eval()
321
+
322
+ # Set up UNet
323
+ unet.to(device, dtype=weight_dtype)
324
+ unet.requires_grad_(False)
325
+ unet.eval()
326
+
327
+ # Load tiny VAE for efficiency
328
+ vae = AutoencoderTiny.from_pretrained(
329
+ "madebyollin/taesdxl",
330
+ torch_dtype=weight_dtype
331
+ )
332
+ vae = vae.to(device, dtype=weight_dtype)
333
+ vae.requires_grad_(False)
334
+
335
+ # Load appropriate encoder (CLIP or DinoV2)
336
+ if params['encoder'] == 'dinov2-small':
337
+ clip_model = AutoModel.from_pretrained(
338
+ 'facebook/dinov2-small',
339
+ torch_dtype=weight_dtype
340
+ )
341
+ clip_processor= None
342
+ else:
343
+ clip_model = CLIPModel.from_pretrained(
344
+ "wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M",
345
+ torch_dtype=weight_dtype
346
+ )
347
+ clip_processor = CLIPProcessor.from_pretrained("wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M")
348
+ clip_model = clip_model.to(device, dtype=weight_dtype)
349
+ clip_model.requires_grad_(False)
350
+
351
+
352
+
353
+ # If using DMD checkpoint, load it
354
+ if params['distilled'] != 'None':
355
+ if '.safetensors' in params['distilled']:
356
+ unet.load_state_dict(load_file(params['distilled'], device=device))
357
+ elif 'dmd2' in params['distilled']:
358
+ repo_name = "tianweiy/DMD2"
359
+ ckpt_name = "dmd2_sdxl_4step_unet_fp16.bin"
360
+ unet.load_state_dict(torch.load(hf_hub_download(repo_name, ckpt_name)))
361
+ else:
362
+ unet.load_state_dict(torch.load(params['distilled']))
363
+
364
+
365
+ # Set up LCM scheduler for DMD
366
+ noise_scheduler = LCMScheduler(
367
+ beta_start=0.00085,
368
+ beta_end=0.012,
369
+ beta_schedule="scaled_linear",
370
+ num_train_timesteps=1000,
371
+ prediction_type="epsilon",
372
+ original_inference_steps=1000
373
+ )
374
+
375
+ noise_scheduler.set_timesteps(params['max_denoising_steps'])
376
+ pipe = StableDiffusionXLPipeline(vae = vae,
377
+ text_encoder = text_encoders[0],
378
+ text_encoder_2 = text_encoders[1],
379
+ tokenizer = tokenizers[0],
380
+ tokenizer_2 = tokenizers[1],
381
+ unet = unet,
382
+ scheduler = noise_scheduler)
383
+ pipe.set_progress_bar_config(disable=True)
384
+ return {
385
+ 'unet': unet,
386
+ 'vae': vae,
387
+ 'clip_model': clip_model,
388
+ 'clip_processor': clip_processor,
389
+ 'tokenizers': tokenizers,
390
+ 'text_encoders': text_encoders,
391
+ 'noise_scheduler': noise_scheduler
392
+ }, pipe
393
+
394
+
395
+ def load_models_flux(params):
396
+ # Load the tokenizers
397
+ tokenizer_one = CLIPTokenizer.from_pretrained(
398
+ params['pretrained_model_name_or_path'],
399
+ subfolder="tokenizer",
400
+ torch_dtype=params['weight_dtype'], device_map=params['device']
401
+ )
402
+ tokenizer_two = T5TokenizerFast.from_pretrained(
403
+ params['pretrained_model_name_or_path'],
404
+ subfolder="tokenizer_2",
405
+ torch_dtype=params['weight_dtype'], device_map=params['device']
406
+ )
407
+
408
+ # Load scheduler and models
409
+ noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
410
+ params['pretrained_model_name_or_path'],
411
+ subfolder="scheduler",
412
+ torch_dtype=params['weight_dtype'], device=params['device']
413
+ )
414
+ noise_scheduler_copy = copy.deepcopy(noise_scheduler)
415
+
416
+
417
+
418
+ # import correct text encoder classes
419
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
420
+ params['pretrained_model_name_or_path'],
421
+ )
422
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
423
+ params['pretrained_model_name_or_path'], subfolder="text_encoder_2"
424
+ )
425
+ # Load the text encoders
426
+ text_encoder_one, text_encoder_two = load_text_encoders(params['pretrained_model_name_or_path'], text_encoder_cls_one, text_encoder_cls_two, params['weight_dtype'])
427
+
428
+ # Load VAE
429
+ vae = AutoencoderKL.from_pretrained(
430
+ params['pretrained_model_name_or_path'],
431
+ subfolder="vae",
432
+ torch_dtype=params['weight_dtype'], device_map='auto'
433
+ )
434
+ transformer = FluxTransformer2DModel.from_pretrained(
435
+ params['pretrained_model_name_or_path'],
436
+ subfolder="transformer",
437
+ torch_dtype=params['weight_dtype']
438
+ )
439
+
440
+ # We only train the additional adapter LoRA layers
441
+ transformer.requires_grad_(False)
442
+ vae.requires_grad_(False)
443
+ text_encoder_one.requires_grad_(False)
444
+ text_encoder_two.requires_grad_(False)
445
+
446
+ vae.to(params['device'])
447
+ transformer.to(params['device'])
448
+ text_encoder_one.to(params['device'])
449
+ text_encoder_two.to(params['device'])
450
+
451
+ # Load appropriate encoder (CLIP or DinoV2)
452
+ if params['encoder'] == 'dinov2-small':
453
+ clip_model = AutoModel.from_pretrained(
454
+ 'facebook/dinov2-small',
455
+ torch_dtype=params['weight_dtype']
456
+ )
457
+ clip_processor= None
458
+ else:
459
+ clip_model = CLIPModel.from_pretrained(
460
+ "wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M",
461
+ torch_dtype=params['weight_dtype']
462
+ )
463
+ clip_processor = CLIPProcessor.from_pretrained("wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M")
464
+ clip_model = clip_model.to(params['device'], dtype=params['weight_dtype'])
465
+ clip_model.requires_grad_(False)
466
+
467
+
468
+ pipe = FluxPipeline(noise_scheduler,
469
+ vae,
470
+ text_encoder_one,
471
+ tokenizer_one,
472
+ text_encoder_two,
473
+ tokenizer_two,
474
+ transformer,
475
+ )
476
+ pipe.set_progress_bar_config(disable=True)
477
+
478
+ return {
479
+ 'transformer': transformer,
480
+ 'vae': vae,
481
+ 'clip_model': clip_model,
482
+ 'clip_processor': clip_processor,
483
+ 'tokenizers': [tokenizer_one, tokenizer_two],
484
+ 'text_encoders': [text_encoder_one,text_encoder_two],
485
+ 'noise_scheduler': noise_scheduler
486
+ }, pipe
487
+
488
+ def save_checkpoint(networks, save_path, weight_dtype):
489
+ """
490
+ Save network weights and perform cleanup
491
+
492
+ Args:
493
+ networks: Dictionary of LoRA networks to save
494
+ save_path: Path to save the checkpoints
495
+ weight_dtype: Data type for the weights
496
+ """
497
+ print("Saving checkpoint...")
498
+
499
+ try:
500
+ # Create save directory if it doesn't exist
501
+ os.makedirs(save_path, exist_ok=True)
502
+
503
+ # Save each network's weights
504
+ for net_idx, network in networks.items():
505
+ save_name = f"{save_path}/slider_{net_idx}.pt"
506
+ try:
507
+ network.save_weights(
508
+ save_name,
509
+ dtype=weight_dtype,
510
+ )
511
+ except Exception as e:
512
+ print(f"Error saving network {net_idx}: {str(e)}")
513
+ continue
514
+
515
+ # Cleanup
516
+ torch.cuda.empty_cache()
517
+ gc.collect()
518
+
519
+ print("Checkpoint saved successfully.")
520
+
521
+ except Exception as e:
522
+ print(f"Error during checkpoint saving: {str(e)}")
523
+
524
+ finally:
525
+ # Ensure memory is cleaned up even if save fails
526
+ torch.cuda.empty_cache()
527
+ gc.collect()
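Usage sketch (not part of the commit): loading the SDXL training stack through load_models_xl. All params values are illustrative; 'distilled': 'dmd2' exercises the DMD2 branch above, which downloads the 4-step UNet and swaps in the LCM scheduler.

import torch
from utils.model_util import load_models_xl

params = {
    'pretrained_model_name_or_path': 'stabilityai/stable-diffusion-xl-base-1.0',
    'device': 'cuda:0',
    'weight_dtype': torch.bfloat16,
    'encoder': 'clip',        # anything other than 'dinov2-small' selects TinyCLIP
    'distilled': 'dmd2',      # or 'None' to keep the base UNet and DDIM scheduler
    'max_denoising_steps': 4,
}

models, pipe = load_models_xl(params)
images = pipe("a watercolor fox, detailed, 8k",
              num_inference_steps=params['max_denoising_steps'],
              guidance_scale=0).images
images[0].save("sample.png")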
utils/prompt_util.py ADDED
@@ -0,0 +1,80 @@
1
+ import anthropic
2
+ client = anthropic.Anthropic()
3
+ from typing import List, Optional
+ import ast
4
+
5
+ def claude_generate_prompts_sliders(prompt,
6
+ num_prompts=20,
7
+ temperature=0.2,
8
+ max_tokens=2000,
9
+ frequency_penalty=0.0,
10
+ model="claude-3-5-sonnet-20240620",
11
+ verbose=False):
12
+ assistant_prompt = f''' You are an expert in writing diverse image captions. When i provide a prompt, I want you to give me {num_prompts} alternative prompts that is similar to the provided prompt but produces diverse images. Be creative and make sure the original subjects in the original prompt are present in your prompts. Make sure that you end the prompts with keywords that will produce high quality images like ",detailed, 8k" or ",hyper-realistic, 4k".
13
+
14
+ Give me the expanded prompts in the style of a list. start with a [ and end with ] do not add any special characters like \n
15
+ I need you to give me only the python list and nothing else. Do not explain yourself
16
+
17
+ example output format:
18
+ ["prompt1", "prompt2", ...]
19
+ '''
20
+
21
+ user_prompt = prompt
22
+
23
+ message=[
24
+ {
25
+ "role": "user",
26
+ "content": [
27
+ {
28
+ "type": "text",
29
+ "text": user_prompt
30
+ }
31
+ ]
32
+ }
33
+ ]
34
+
35
+ output = client.messages.create(
36
+ model=model,
37
+ max_tokens=max_tokens,
38
+ temperature=temperature,
39
+ system=assistant_prompt,
40
+ messages=message
41
+ )
42
+ content = output.content[0].text
43
+ return content
44
+
45
+
46
+ def expand_prompts(concept_prompts: List[str], diverse_prompt_num: int, args) -> List[str]:
47
+ """
48
+ Expand the input prompts using Claude if requested.
49
+
50
+ Args:
51
+ concept_prompts: Initial list of prompts
52
+ diverse_prompt_num: Number of variations to generate per prompt
53
+ args: Training arguments
54
+
55
+ Returns:
56
+ List of expanded prompts
57
+ """
58
+ diverse_prompts = []
59
+
60
+ if diverse_prompt_num != 0:
61
+ for prompt in concept_prompts:
62
+ try:
63
+ claude_generated_prompts = claude_generate_prompts_sliders(
64
+ prompt=prompt,
65
+ num_prompts=diverse_prompt_num,
66
+ temperature=0.2,
67
+ max_tokens=8000,
68
+ frequency_penalty=0.0,
69
+ model="claude-3-5-sonnet-20240620",
70
+ verbose=False
71
+ )
72
+ diverse_prompts.extend(ast.literal_eval(claude_generated_prompts)) # safer than eval on model output
73
+ except Exception as e:
74
+ print(f"Error with Claude response: {e}")
75
+ diverse_prompts.append(prompt)
76
+ else:
77
+ diverse_prompts = concept_prompts
78
+
79
+ print(f"Using prompts: {diverse_prompts}")
80
+ return diverse_prompts
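Usage sketch (not part of the commit): expanding a concept prompt with Claude. The module builds an Anthropic client at import time, so ANTHROPIC_API_KEY must be set before the import; the prompt below is illustrative.

from utils.prompt_util import expand_prompts

concept_prompts = ["portrait of a person, detailed, 8k"]
# args is accepted but unused by expand_prompts, so None is sufficient here
diverse_prompts = expand_prompts(concept_prompts, diverse_prompt_num=5, args=None)
print(len(diverse_prompts), diverse_prompts[:2])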
utils/train_util.py ADDED
@@ -0,0 +1,722 @@
1
+ from typing import Optional, Union
2
+
3
+ import torch
4
+
5
+ from transformers import CLIPTextModel, CLIPTokenizer
6
+ from diffusers import UNet2DConditionModel, SchedulerMixin, FluxImg2ImgPipeline
7
+ from diffusers.image_processor import VaeImageProcessor
8
+ # from model_util import SDXL_TEXT_ENCODER_TYPE
9
+ from diffusers.utils.torch_utils import randn_tensor
10
+
11
+ from tqdm import tqdm
12
+
13
+ UNET_IN_CHANNELS = 4 # Stable Diffusion's in_channels is fixed at 4; the same holds for XL.
14
+ VAE_SCALE_FACTOR = 8 # 2 ** (len(vae.config.block_out_channels) - 1) = 8
15
+
16
+ UNET_ATTENTION_TIME_EMBED_DIM = 256 # XL
17
+ TEXT_ENCODER_2_PROJECTION_DIM = 1280
18
+ UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM = 2816
19
+
20
+
21
+ def get_random_noise(
22
+ batch_size: int, height: int, width: int, generator: torch.Generator = None
23
+ ) -> torch.Tensor:
24
+ return torch.randn(
25
+ (
26
+ batch_size,
27
+ UNET_IN_CHANNELS,
28
+ height // VAE_SCALE_FACTOR, # not sure the height/width order is correct, but either way it causes no real problem
29
+ width // VAE_SCALE_FACTOR,
30
+ ),
31
+ generator=generator,
32
+ device="cpu",
33
+ )
34
+
35
+
36
+ # https://www.crosslabs.org/blog/diffusion-with-offset-noise
37
+ def apply_noise_offset(latents: torch.FloatTensor, noise_offset: float):
38
+ latents = latents + noise_offset * torch.randn(
39
+ (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
40
+ )
41
+ return latents
42
+
43
+
44
+ def get_initial_latents(
45
+ scheduler: SchedulerMixin,
46
+ n_imgs: int,
47
+ height: int,
48
+ width: int,
49
+ n_prompts: int,
50
+ generator=None,
51
+ ) -> torch.Tensor:
52
+ noise = get_random_noise(n_imgs, height, width, generator=generator).repeat(
53
+ n_prompts, 1, 1, 1
54
+ )
55
+
56
+ latents = noise * scheduler.init_noise_sigma
57
+
58
+ return latents
59
+
60
+
61
+ def text_tokenize(
62
+ tokenizer: CLIPTokenizer, # one tokenizer normally; XL uses two!
63
+ prompts: list[str],
64
+ ):
65
+ return tokenizer(
66
+ prompts,
67
+ padding="max_length",
68
+ max_length=tokenizer.model_max_length,
69
+ truncation=True,
70
+ return_tensors="pt",
71
+ ).input_ids
72
+
73
+
74
+ def text_encode(text_encoder: CLIPTextModel, tokens):
75
+ return text_encoder(tokens.to(text_encoder.device))[0]
76
+
77
+
78
+ def encode_prompts(
79
+ tokenizer: CLIPTokenizer,
80
+ text_encoder: CLIPTextModel,
81
+ prompts: list[str],
82
+ ):
83
+
84
+ text_tokens = text_tokenize(tokenizer, prompts)
85
+ text_embeddings = text_encode(text_encoder, text_tokens)
86
+
87
+
88
+
89
+ return text_embeddings
90
+
91
+
92
+ # https://github.com/huggingface/diffusers/blob/78922ed7c7e66c20aa95159c7b7a6057ba7d590d/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L334-L348
93
+ def text_encode_xl(
94
+ text_encoder,
95
+ tokens: torch.FloatTensor,
96
+ num_images_per_prompt: int = 1,
97
+ ):
98
+ prompt_embeds = text_encoder(
99
+ tokens.to(text_encoder.device), output_hidden_states=True
100
+ )
101
+ pooled_prompt_embeds = prompt_embeds[0]
102
+ prompt_embeds = prompt_embeds.hidden_states[-2] # always penultimate layer
103
+
104
+ bs_embed, seq_len, _ = prompt_embeds.shape
105
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
106
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
107
+
108
+ return prompt_embeds, pooled_prompt_embeds
109
+
110
+
111
+ def encode_prompts_xl(
112
+ tokenizers,
113
+ text_encoders,
114
+ prompts: list[str],
115
+ num_images_per_prompt: int = 1,
116
+ ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
117
+ # text_encoder and text_encoder_2's penultimate layer's output
118
+ text_embeds_list = []
119
+ pooled_text_embeds = None # always text_encoder_2's pool
120
+
121
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
122
+ text_tokens_input_ids = text_tokenize(tokenizer, prompts)
123
+ text_embeds, pooled_text_embeds = text_encode_xl(
124
+ text_encoder, text_tokens_input_ids, num_images_per_prompt
125
+ )
126
+
127
+ text_embeds_list.append(text_embeds)
128
+
129
+ bs_embed = pooled_text_embeds.shape[0]
130
+ pooled_text_embeds = pooled_text_embeds.repeat(1, num_images_per_prompt).view(
131
+ bs_embed * num_images_per_prompt, -1
132
+ )
133
+
134
+ return torch.concat(text_embeds_list, dim=-1), pooled_text_embeds
135
+
136
+
137
+ def concat_embeddings(
138
+ unconditional: torch.FloatTensor,
139
+ conditional: torch.FloatTensor,
140
+ n_imgs: int,
141
+ ):
142
+ return torch.cat([unconditional, conditional]).repeat_interleave(n_imgs, dim=0)
143
+
144
+
145
+ # ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L721
146
+ def predict_noise(
147
+ unet: UNet2DConditionModel,
148
+ scheduler: SchedulerMixin,
149
+ timestep: int, # current timestep
150
+ latents: torch.FloatTensor,
151
+ text_embeddings: torch.FloatTensor, # unconditional and conditional text embeddings concatenated
152
+ guidance_scale=7.5,
153
+ ) -> torch.FloatTensor:
154
+ latent_model_input = latents
155
+ if guidance_scale!=0:
156
+ # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
157
+ latent_model_input = torch.cat([latents] * 2)
158
+
159
+ latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)
160
+
161
+ # predict the noise residual
162
+ noise_pred = unet(
163
+ latent_model_input,
164
+ timestep,
165
+ encoder_hidden_states=text_embeddings,
166
+ ).sample
167
+
168
+ # perform guidance
169
+ if guidance_scale != 1 and guidance_scale!=0:
170
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
171
+ noise_pred = noise_pred_uncond + guidance_scale * (
172
+ noise_pred_text - noise_pred_uncond
173
+ )
174
+
175
+ return noise_pred
176
+
177
+
178
+ # ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L746
179
+ @torch.no_grad()
180
+ def diffusion(
181
+ unet: UNet2DConditionModel,
182
+ scheduler: SchedulerMixin,
183
+ latents: torch.FloatTensor, # pure-noise latents
184
+ text_embeddings: torch.FloatTensor,
185
+ total_timesteps: int = 1000,
186
+ start_timesteps=0,
187
+ guidance_scale=1,
188
+ composition=False,
189
+ **kwargs,
190
+ ):
191
+ # latents_steps = []
192
+
193
+ for timestep in scheduler.timesteps[start_timesteps:total_timesteps]:
194
+ if not composition:
195
+ noise_pred = predict_noise(
196
+ unet, scheduler, timestep, latents, text_embeddings, guidance_scale=guidance_scale
197
+ )
198
+ if guidance_scale==1:
199
+ _, noise_pred = noise_pred.chunk(2)
200
+ else:
201
+ for idx in range(text_embeddings.shape[0]):
202
+ pred = predict_noise(
203
+ unet, scheduler, timestep, latents, text_embeddings[idx:idx+1], guidance_scale=1
204
+ )
205
+ uncond, pred = pred.chunk(2)  # split this prompt's [uncond, cond] prediction
206
+ if idx == 0:
207
+ noise_pred = guidance_scale * pred
208
+ else:
209
+ noise_pred += guidance_scale * pred
210
+ noise_pred += uncond
211
+
212
+
213
+ # compute the previous noisy sample x_t -> x_t-1
214
+ latents = scheduler.step(noise_pred, timestep, latents).prev_sample
215
+
216
+ # return latents_steps
217
+ return latents
218
+
219
+
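+ # A minimal end-to-end sketch wiring the helpers above into the denoising loop. It
+ # assumes SD 1.x components (unet, scheduler, tokenizer, text_encoder) are already
+ # loaded, e.g. from a DiffusionPipeline; the _example_* name and the 512x512 / 50-step
+ # settings are illustrative assumptions.
+ def _example_diffusion_loop(unet, scheduler, tokenizer, text_encoder, prompt="a photo of a dog"):
+     scheduler.set_timesteps(50)
+
+     cond = encode_prompts(tokenizer, text_encoder, [prompt])
+     uncond = encode_prompts(tokenizer, text_encoder, [""])
+     text_embeddings = concat_embeddings(uncond, cond, n_imgs=1)
+
+     # 512x512 image -> 64x64 latent for the SD 1.x VAE (factor 8)
+     latents = torch.randn(1, 4, 64, 64, device=unet.device, dtype=unet.dtype)
+     latents = latents * scheduler.init_noise_sigma
+
+     # returns denoised latents; decode with the VAE to obtain an image
+     return diffusion(unet, scheduler, latents, text_embeddings,
+                      total_timesteps=50, guidance_scale=7.5)
+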
220
+ def rescale_noise_cfg(
221
+ noise_cfg: torch.FloatTensor, noise_pred_text, guidance_rescale=0.0
222
+ ):
223
+ """
224
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
225
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
226
+ """
227
+ std_text = noise_pred_text.std(
228
+ dim=list(range(1, noise_pred_text.ndim)), keepdim=True
229
+ )
230
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
231
+ # rescale the results from guidance (fixes overexposure)
232
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
233
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
234
+ noise_cfg = (
235
+ guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
236
+ )
237
+
238
+ return noise_cfg
239
+
240
+
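+ # A quick numeric check of rescale_noise_cfg with random tensors (no model needed):
+ # with guidance_rescale=1.0 the rescaled output's per-sample std matches the
+ # text-conditional prediction's std, which is what fixes CFG over-exposure.
+ def _example_rescale_noise_cfg():
+     torch.manual_seed(0)
+     noise_pred_text = torch.randn(2, 4, 8, 8)
+     noise_cfg = noise_pred_text * 3.0 + torch.randn(2, 4, 8, 8)  # an "over-amplified" CFG output
+     fixed = rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=1.0)
+     # per-sample stds agree up to floating point
+     return fixed.std(dim=(1, 2, 3)), noise_pred_text.std(dim=(1, 2, 3))
+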
241
+ def predict_noise_xl(
242
+ unet: UNet2DConditionModel,
243
+ scheduler: SchedulerMixin,
244
+ timestep: int, # current timestep
245
+ latents: torch.FloatTensor,
246
+ text_embeddings: torch.FloatTensor, # unconditional and conditional text embeddings concatenated
247
+ add_text_embeddings: torch.FloatTensor, # pooled text embeddings
248
+ add_time_ids: torch.FloatTensor,
249
+ guidance_scale=7.5,
250
+ guidance_rescale=0.7,
251
+ ) -> torch.FloatTensor:
252
+ # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
253
+ latent_model_input = latents
254
+ if guidance_scale !=0:
255
+ latent_model_input = torch.cat([latents] * 2)
256
+
257
+ latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)
258
+
259
+ added_cond_kwargs = {
260
+ "text_embeds": add_text_embeddings,
261
+ "time_ids": add_time_ids,
262
+ }
263
+
264
+ # predict the noise residual
265
+ noise_pred = unet(
266
+ latent_model_input,
267
+ timestep,
268
+ encoder_hidden_states=text_embeddings,
269
+ added_cond_kwargs=added_cond_kwargs,
270
+ ).sample
271
+ # perform guidance
272
+ if guidance_scale != 1 and guidance_scale!=0:
273
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
274
+ noise_pred = noise_pred_uncond + guidance_scale * (
275
+ noise_pred_text - noise_pred_uncond
276
+ )
277
+
278
+ return noise_pred
279
+ # # perform guidance
280
+ # noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
281
+ # guided_target = noise_pred_uncond + guidance_scale * (
282
+ # noise_pred_text - noise_pred_uncond
283
+ # )
284
+
285
+ # # https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775
286
+ # noise_pred = rescale_noise_cfg(
287
+ # noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
288
+ # )
289
+
290
+ # return guided_target
291
+
292
+
293
+ @torch.no_grad()
294
+ def diffusion_xl(
295
+ unet: UNet2DConditionModel,
296
+ scheduler: SchedulerMixin,
297
+ latents: torch.FloatTensor, # pure-noise latents
298
+ text_embeddings: tuple[torch.FloatTensor, torch.FloatTensor],
299
+ add_text_embeddings: torch.FloatTensor, # pooled text embeddings
300
+ add_time_ids: torch.FloatTensor,
301
+ guidance_scale: float = 1.0,
302
+ total_timesteps: int = 1000,
303
+ start_timesteps=0,
304
+ composition=False,
305
+ ):
306
+ # latents_steps = []
307
+
308
+ for timestep in scheduler.timesteps[start_timesteps:total_timesteps]:
309
+ if not composition:
310
+ noise_pred = predict_noise_xl(
311
+ unet,
312
+ scheduler,
313
+ timestep,
314
+ latents,
315
+ text_embeddings,
316
+ add_text_embeddings,
317
+ add_time_ids,
318
+ guidance_scale=guidance_scale,
319
+ guidance_rescale=0.7,
320
+ )
321
+ if guidance_scale==1:
322
+ _, noise_pred = noise_pred.chunk(2)
323
+ # compute the previous noisy sample x_t -> x_t-1
324
+ latents = scheduler.step(noise_pred, timestep, latents).prev_sample
325
+
326
+ # return latents_steps
327
+ return latents
328
+
329
+
330
+ # for XL
331
+ def get_add_time_ids(
332
+ height: int,
333
+ width: int,
334
+ dynamic_crops: bool = False,
335
+ dtype: torch.dtype = torch.float32,
336
+ ):
337
+ if dynamic_crops:
338
+ # random float scale between 1 and 3
339
+ random_scale = torch.rand(1).item() * 2 + 1
340
+ original_size = (int(height * random_scale), int(width * random_scale))
341
+ # random position
342
+ crops_coords_top_left = (
343
+ torch.randint(0, original_size[0] - height, (1,)).item(),
344
+ torch.randint(0, original_size[1] - width, (1,)).item(),
345
+ )
346
+ target_size = (height, width)
347
+ else:
348
+ original_size = (height, width)
349
+ crops_coords_top_left = (0, 0)
350
+ target_size = (height, width)
351
+
352
+ # this is expected as 6
353
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
354
+
355
+ # this is expected as 2816
356
+ passed_add_embed_dim = (
357
+ UNET_ATTENTION_TIME_EMBED_DIM * len(add_time_ids) # 256 * 6
358
+ + TEXT_ENCODER_2_PROJECTION_DIM # + 1280
359
+ )
360
+ if passed_add_embed_dim != UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM:
361
+ raise ValueError(
362
+ f"Model expects an added time embedding vector of length {UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
363
+ )
364
+
365
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
366
+ return add_time_ids
367
+
368
+
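+ # The arithmetic behind the dimension check above, using the standard SDXL values
+ # (256-dim Fourier embedding per time id, 1280-dim text_encoder_2 projection,
+ # 2816-dim add_embedding input). These are assumed to match the *_DIM constants
+ # defined elsewhere in this repo; the _example_* name is illustrative.
+ def _example_add_time_ids_dim_check():
+     # (orig_h, orig_w, crop_top, crop_left, target_h, target_w) for a plain 1024x1024 sample
+     add_time_ids = [1024, 1024, 0, 0, 1024, 1024]
+     assert 256 * len(add_time_ids) + 1280 == 2816
+     return torch.tensor([add_time_ids], dtype=torch.float32)
+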
369
+ def get_optimizer(name: str):
370
+ name = name.lower()
371
+
372
+ if name.startswith("dadapt"):
373
+ import dadaptation
374
+
375
+ if name == "dadaptadam":
376
+ return dadaptation.DAdaptAdam
377
+ elif name == "dadaptlion":
378
+ return dadaptation.DAdaptLion
379
+ else:
380
+ raise ValueError("DAdapt optimizer must be dadaptadam or dadaptlion")
381
+
382
+ elif name.endswith("8bit"): # not tested
383
+ import bitsandbytes as bnb
384
+
385
+ if name == "adam8bit":
386
+ return bnb.optim.Adam8bit
387
+ elif name == "lion8bit":
388
+ return bnb.optim.Lion8bit
389
+ else:
390
+ raise ValueError("8bit optimizer must be adam8bit or lion8bit")
391
+
392
+ else:
393
+ if name == "adam":
394
+ return torch.optim.Adam
395
+ elif name == "adamw":
396
+ return torch.optim.AdamW
397
+ elif name == "lion":
398
+ from lion_pytorch import Lion
399
+
400
+ return Lion
401
+ elif name == "prodigy":
402
+ import prodigyopt
403
+
404
+ return prodigyopt.Prodigy
405
+ else:
406
+ raise ValueError("Optimizer must be adam, adamw, lion or Prodigy")
407
+
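+ # A minimal usage sketch for get_optimizer: resolve the class by name and build it
+ # over a parameter iterable (the learning rate is an illustrative value).
+ def _example_build_optimizer(parameters, name: str = "adamw", lr: float = 1e-4):
+     optimizer_cls = get_optimizer(name)
+     return optimizer_cls(parameters, lr=lr)
+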
408
+ @torch.no_grad()
409
+ def get_noisy_image(
410
+ image,
411
+ vae,
412
+ unet,
413
+ scheduler,
414
+ timesteps_to = 1000,
415
+ generator=None,
416
+ **kwargs,
417
+ ):
418
+ # latents_steps = []
419
+ vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
420
+ image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)
421
+
422
+ device = vae.device
423
+ image = image_processor.preprocess(image).to(device).to(vae.dtype)
424
+
425
+ init_latents = vae.encode(image).latents
426
+
427
+ init_latents = vae.config.scaling_factor * init_latents
428
+
429
+ init_latents = torch.cat([init_latents], dim=0)
430
+
431
+ shape = init_latents.shape
432
+
433
+ noise = randn_tensor(shape, generator=generator, device=device)
434
+
435
+ timestep = scheduler.timesteps[timesteps_to:timesteps_to+1]
436
+ # get latents
437
+ init_latents = scheduler.add_noise(init_latents, noise, timestep)
438
+
439
+ return init_latents, noise
440
+
441
+
442
+ def get_lr_scheduler(
443
+ name: Optional[str],
444
+ optimizer: torch.optim.Optimizer,
445
+ max_iterations: Optional[int],
446
+ lr_min: Optional[float],
447
+ **kwargs,
448
+ ):
449
+ if name == "cosine":
450
+ return torch.optim.lr_scheduler.CosineAnnealingLR(
451
+ optimizer, T_max=max_iterations, eta_min=lr_min, **kwargs
452
+ )
453
+ elif name == "cosine_with_restarts":
454
+ return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
455
+ optimizer, T_0=max_iterations // 10, T_mult=2, eta_min=lr_min, **kwargs
456
+ )
457
+ elif name == "step":
458
+ return torch.optim.lr_scheduler.StepLR(
459
+ optimizer, step_size=max_iterations // 100, gamma=0.999, **kwargs
460
+ )
461
+ elif name == "constant":
462
+ return torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1, **kwargs)
463
+ elif name == "linear":
464
+ return torch.optim.lr_scheduler.LinearLR(
465
+ optimizer, factor=0.5, total_iters=max_iterations // 100, **kwargs
466
+ )
467
+ else:
468
+ raise ValueError(
469
+ "Scheduler must be cosine, cosine_with_restarts, step, linear or constant"
470
+ )
471
+
472
+
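+ # A minimal usage sketch for get_lr_scheduler, pairing a cosine schedule with AdamW
+ # and stepping it once per iteration (values are illustrative).
+ def _example_lr_schedule(parameters, max_iterations: int = 1000):
+     optimizer = torch.optim.AdamW(parameters, lr=1e-4)
+     lr_scheduler = get_lr_scheduler("cosine", optimizer, max_iterations=max_iterations, lr_min=1e-7)
+     for _ in range(max_iterations):
+         optimizer.step()
+         lr_scheduler.step()
+     return lr_scheduler.get_last_lr()
+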
473
+ def get_random_resolution_in_bucket(bucket_resolution: int = 512) -> tuple[int, int]:
474
+ max_resolution = bucket_resolution
475
+ min_resolution = bucket_resolution // 2
476
+
477
+ step = 64
478
+
479
+ min_step = min_resolution // step
480
+ max_step = max_resolution // step
481
+
482
+ height = torch.randint(min_step, max_step, (1,)).item() * step
483
+ width = torch.randint(min_step, max_step, (1,)).item() * step
484
+
485
+ return height, width
486
+
487
+
488
+
489
+ def _get_t5_prompt_embeds(
490
+ text_encoder,
491
+ tokenizer,
492
+ prompt,
493
+ max_sequence_length=512,
494
+ device=None,
495
+ dtype=None
496
+ ):
497
+ """Helper function to get T5 embeddings in Flux format"""
498
+ device = device or text_encoder.device
499
+ dtype = dtype or text_encoder.dtype
500
+
501
+ prompt = [prompt] if isinstance(prompt, str) else prompt
502
+ batch_size = len(prompt)
503
+
504
+ text_inputs = tokenizer(
505
+ prompt,
506
+ padding="max_length",
507
+ max_length=max_sequence_length,
508
+ truncation=True,
509
+ return_length=False,
510
+ return_overflowing_tokens=False,
511
+ return_tensors="pt",
512
+ )
513
+ text_input_ids = text_inputs.input_ids
514
+
515
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)[0]
516
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
517
+
518
+ return prompt_embeds
519
+
520
+ def _get_clip_prompt_embeds(
521
+ text_encoder,
522
+ tokenizer,
523
+ prompt,
524
+ device=None,
525
+ ):
526
+ """Helper function to get CLIP embeddings in Flux format"""
527
+ device = device or text_encoder.device
528
+
529
+ prompt = [prompt] if isinstance(prompt, str) else prompt
530
+ batch_size = len(prompt)
531
+
532
+ text_inputs = tokenizer(
533
+ prompt,
534
+ padding="max_length",
535
+ max_length=tokenizer.model_max_length,
536
+ truncation=True,
537
+ return_overflowing_tokens=False,
538
+ return_length=False,
539
+ return_tensors="pt",
540
+ )
541
+
542
+ text_input_ids = text_inputs.input_ids
543
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
544
+
545
+ # Use pooled output for Flux
546
+ prompt_embeds = prompt_embeds.pooler_output
547
+ prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
548
+
549
+ return prompt_embeds
550
+
551
+
552
+
553
+
554
+
555
+ @torch.no_grad()
556
+ def get_noisy_image_flux(
557
+ image,
558
+ vae,
559
+ transformer,
560
+ scheduler,
561
+ timesteps_to=1000,
562
+ generator=None,
563
+ params = None
564
+ ):
565
+ """
566
+ Gets noisy latents for a given image using Flux pipeline approach.
567
+
568
+ Args:
569
+ image (Union[PIL.Image.Image, torch.Tensor]): Input image
570
+ vae (AutoencoderKL): Flux VAE model
571
+ transformer (FluxTransformer2DModel): Flux transformer model
572
+ scheduler (FlowMatchEulerDiscreteScheduler): Flux noise scheduler
573
+ timesteps_to (int, optional): Target timestep. Defaults to 1000.
574
+ generator (torch.Generator, optional): Random generator for reproducibility.
575
+
576
+ Returns:
577
+ tuple: (noisy_latents, noise) - Both in packed Flux format
578
+ """
579
+
580
+ vae_scale_factor = params['vae_scale_factor']
581
+ image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor * 2)
582
+
583
+ image = image_processor.preprocess(image, height=params['height'], width=params['width'])
584
+ image = image.to(dtype=torch.float32)
585
+
586
+ # 5. Prepare latent variables
587
+ num_channels_latents = transformer.config.in_channels // 4
588
+
589
+ latents, latent_image_ids = prepare_latents_flux(
590
+ image,
591
+ timesteps_to.repeat(params['batchsize']),
592
+ params['batchsize'],
593
+ num_channels_latents,
594
+ params['height'],
595
+ params['width'],
596
+ transformer.dtype,
597
+ transformer.device,
598
+ generator,
599
+ None,
600
+ vae_scale_factor,
601
+ vae,
602
+ scheduler
603
+ )
604
+
605
+ return latents, latent_image_ids
606
+
607
+
608
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
609
+ """
610
+ Pack latents into Flux's 2x2 patch format
611
+ """
612
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
613
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
614
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
615
+ return latents
616
+
617
+
618
+ def _unpack_latents(latents, height, width, vae_scale_factor):
619
+ """
620
+ Unpack latents from Flux's 2x2 patch format back to image space
621
+ """
622
+ batch_size, num_patches, channels = latents.shape
623
+
624
+ # Account for VAE compression and packing
625
+ height = 2 * (int(height) // (vae_scale_factor * 2))
626
+ width = 2 * (int(width) // (vae_scale_factor * 2))
627
+
628
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
629
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
630
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
631
+
632
+ return latents
633
+
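+ # A shape sanity check for the pack/unpack pair above (random tensors, no model needed).
+ # For a 1024x1024 image the Flux VAE latent is (1, 16, 128, 128) with vae_scale_factor=8,
+ # which packs to (1, 4096, 64) and unpacks back exactly.
+ def _example_pack_unpack_roundtrip():
+     torch.manual_seed(0)
+     batch, channels, latent_hw = 1, 16, 128
+     latents = torch.randn(batch, channels, latent_hw, latent_hw)
+     packed = _pack_latents(latents, batch, channels, latent_hw, latent_hw)   # (1, 4096, 64)
+     unpacked = _unpack_latents(packed, height=1024, width=1024, vae_scale_factor=8)
+     assert torch.equal(unpacked, latents)
+     return packed.shape
+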
634
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
635
+ latent_image_ids = torch.zeros(height, width, 3)
636
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
637
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
638
+
639
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
640
+
641
+ latent_image_ids = latent_image_ids.reshape(
642
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
643
+ )
644
+
645
+ return latent_image_ids.to(device=device, dtype=dtype)
646
+
647
+
648
+ def prepare_latents_flux(
649
+ image,
650
+ timestep,
651
+ batch_size,
652
+ num_channels_latents,
653
+ height,
654
+ width,
655
+ dtype,
656
+ device,
657
+ generator,
658
+ latents=None,
659
+ vae_scale_factor=None,
660
+ vae=None,
661
+ scheduler=None
662
+ ):
663
+ if isinstance(generator, list) and len(generator) != batch_size:
664
+ raise ValueError(
665
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
666
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
667
+ )
668
+
669
+ # VAE applies 8x compression on images but we must also account for packing which requires
670
+ # latent height and width to be divisible by 2.
671
+ height = 2 * (int(height) // (vae_scale_factor * 2))
672
+ width = 2 * (int(width) // (vae_scale_factor * 2))
673
+ shape = (batch_size, num_channels_latents, height, width)
674
+ latent_image_ids = _prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
675
+
676
+ if latents is not None:
677
+ return latents.to(device=device, dtype=dtype), latent_image_ids
678
+
679
+ image = image.to(device=device, dtype=dtype)
680
+ image_latents = _encode_vae_image(vae=vae, image=image, generator=generator)
681
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
682
+ # expand init_latents for batch_size
683
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
684
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
685
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
686
+ raise ValueError(
687
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
688
+ )
689
+ else:
690
+ image_latents = torch.cat([image_latents], dim=0)
691
+
692
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
693
+ latents = scheduler.scale_noise(image_latents, timestep, noise)
694
+ latents = _pack_latents(latents, batch_size, num_channels_latents, height, width)
695
+ return latents, latent_image_ids
696
+
697
+
698
+ def _encode_vae_image(vae, image: torch.Tensor, generator: torch.Generator):
699
+ if isinstance(generator, list):
700
+ image_latents = [
701
+ retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i])
702
+ for i in range(image.shape[0])
703
+ ]
704
+ image_latents = torch.cat(image_latents, dim=0)
705
+ else:
706
+ image_latents = retrieve_latents(vae.encode(image), generator=generator)
707
+
708
+ image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor
709
+ return image_latents
710
+
711
+
712
+ def retrieve_latents(
713
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
714
+ ):
715
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
716
+ return encoder_output.latent_dist.sample(generator)
717
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
718
+ return encoder_output.latent_dist.mode()
719
+ elif hasattr(encoder_output, "latents"):
720
+ return encoder_output.latents
721
+ else:
722
+ raise AttributeError("Could not access latents of provided encoder_output")
utils/utils.py ADDED
@@ -0,0 +1,945 @@
1
+ import anthropic
2
+ client = anthropic.Anthropic()
3
+ from diffusers.image_processor import VaeImageProcessor
4
+ from typing import List, Optional
5
+ import argparse
6
+ import ast
7
+ import pandas as pd
8
+ from pathlib import Path
9
+ from diffusers import DiffusionPipeline, UNet2DConditionModel, LCMScheduler, AutoencoderTiny
10
+ from huggingface_hub import hf_hub_download
11
+ import gc
12
+ import torch.nn.functional as F
13
+ import os
14
+ import torch
15
+ from tqdm.auto import tqdm
16
+ import time, datetime
17
+ import numpy as np
18
+ from torch.optim import AdamW
19
+ from contextlib import ExitStack
20
+ from safetensors.torch import load_file
21
+ import torch.nn as nn
22
+ import random
23
+ from transformers import CLIPModel
24
+
25
+ import sys
26
+ import argparse
27
+ import wandb
28
+ from diffusers import AutoencoderKL
29
+ from diffusers.image_processor import VaeImageProcessor
30
+
31
+ sys.path.append('../')
32
+ from utils.lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV
33
+
34
+ from transformers import logging
35
+ logging.set_verbosity_warning()
36
+ import matplotlib.pyplot as plt
37
+ from diffusers import logging
38
+ logging.set_verbosity_error()
39
+ modules = DEFAULT_TARGET_REPLACE
40
+ modules += UNET_TARGET_REPLACE_MODULE_CONV
41
+ import torch
42
+ import torch.nn.functional as F
43
+ from sklearn.decomposition import PCA
44
+ import random
45
+ import gc
46
+ import diffusers
47
+ from diffusers import DiffusionPipeline, FluxPipeline
48
+ from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel, LMSDiscreteScheduler, SchedulerMixin
49
+ from diffusers.loaders import AttnProcsLayers
50
+ from diffusers.models.attention_processor import LoRAAttnProcessor, AttentionProcessor
51
+ from typing import Any, Dict, List, Optional, Tuple, Union
52
+ from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
53
+ from diffusers.utils.torch_utils import randn_tensor
54
+
55
+ import inspect
56
+ import os
57
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
58
+ from diffusers.pipelines import StableDiffusionXLPipeline
59
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
60
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
61
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import retrieve_timesteps
62
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import XLA_AVAILABLE
63
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
64
+
65
+ from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
66
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
67
+
68
+ import sys
69
+ sys.path.append('../.')
70
+ from utils.flux_utils import *
71
+ import random
72
+
73
+ import torch
74
+ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
75
+
76
+
77
+ def flush():
78
+ torch.cuda.empty_cache()
79
+ gc.collect()
80
+
81
+ def calculate_shift(
82
+ image_seq_len,
83
+ base_seq_len: int = 256,
84
+ max_seq_len: int = 4096,
85
+ base_shift: float = 0.5,
86
+ max_shift: float = 1.16,
87
+ ):
88
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
89
+ b = base_shift - m * base_seq_len
90
+ mu = image_seq_len * m + b
91
+ return mu
92
+
93
+
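+ # A worked example of the linear shift above: a 1024x1024 Flux image packs to a
+ # (1024/16)^2 = 4096-token latent sequence, which maps to mu = max_shift = 1.16,
+ # while the 256-token base length maps to base_shift = 0.5.
+ def _example_calculate_shift():
+     return calculate_shift(4096), calculate_shift(256)  # ~(1.16, 0.5)
+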
94
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
95
+ def retrieve_timesteps(
96
+ scheduler,
97
+ num_inference_steps: Optional[int] = None,
98
+ device: Optional[Union[str, torch.device]] = None,
99
+ timesteps: Optional[List[int]] = None,
100
+ sigmas: Optional[List[float]] = None,
101
+ **kwargs,
102
+ ):
103
+ """
104
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
105
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
106
+
107
+ Args:
108
+ scheduler (`SchedulerMixin`):
109
+ The scheduler to get timesteps from.
110
+ num_inference_steps (`int`):
111
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
112
+ must be `None`.
113
+ device (`str` or `torch.device`, *optional*):
114
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
115
+ timesteps (`List[int]`, *optional*):
116
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
117
+ `num_inference_steps` and `sigmas` must be `None`.
118
+ sigmas (`List[float]`, *optional*):
119
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
120
+ `num_inference_steps` and `timesteps` must be `None`.
121
+
122
+ Returns:
123
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
124
+ second element is the number of inference steps.
125
+ """
126
+ if timesteps is not None and sigmas is not None:
127
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
128
+ if timesteps is not None:
129
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
130
+ if not accepts_timesteps:
131
+ raise ValueError(
132
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
133
+ f" timestep schedules. Please check whether you are using the correct scheduler."
134
+ )
135
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
136
+ timesteps = scheduler.timesteps
137
+ num_inference_steps = len(timesteps)
138
+ elif sigmas is not None:
139
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
140
+ if not accept_sigmas:
141
+ raise ValueError(
142
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
143
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
144
+ )
145
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
146
+ timesteps = scheduler.timesteps
147
+ num_inference_steps = len(timesteps)
148
+ else:
149
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
150
+ timesteps = scheduler.timesteps
151
+ return timesteps, num_inference_steps
152
+
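+ # A minimal usage sketch for retrieve_timesteps: pull a 4-step schedule from a
+ # scheduler that supports set_timesteps (DDPMScheduler here is only an example).
+ def _example_retrieve_timesteps():
+     scheduler = DDPMScheduler(num_train_timesteps=1000)
+     timesteps, n = retrieve_timesteps(scheduler, num_inference_steps=4)
+     return timesteps, n  # a tensor of 4 timesteps, n == 4
+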
153
+ def claude_generate_prompts_sliders(prompt,
154
+ num_prompts=20,
155
+ temperature=0.2,
156
+ max_tokens=2000,
157
+ frequency_penalty=0.0,
158
+ model="claude-3-5-sonnet-20240620",
159
+ verbose=False,
160
+ train_type='concept'):
161
+ gpt_assistant_prompt = f''' You are an expert in writing diverse image captions. When I provide a prompt, I want you to give me {num_prompts} alternative prompts that are similar to the provided prompt but produce diverse images. Be creative and make sure the original subjects in the provided prompt are present in your prompts. Make sure that you end the prompts with keywords that will produce high quality images like ",detailed, 8k" or ",hyper-realistic, 4k".
162
+
163
+ Give me the expanded prompts in the style of a list. start with a [ and end with ] do not add any special characters like \n
164
+ I need you to give me only the python list and nothing else. Do not explain yourself
165
+
166
+ example output format:
167
+ ["prompt1", "prompt2", ...]
168
+ '''
169
+
170
+ if train_type == 'art':
171
+ gpt_assistant_prompt = f'''You are an expert in writing art image captions. I want you to generate prompts that would create diverse artwork images.
172
+ Your role is to give me {num_prompts} diverse prompts that will make the image-generation model output creative and interesting artwork images with unique and diverse artistic styles. A prompt could look like "an <object/landscape> in the style of <an artist>" or "an <object/landscape> in the style of <an artistic style (e.g. cubism)>". Make sure that you end the prompts with enhancing keywords like ",detailed, 8k" or ",hyper-realistic, 4k".
173
+
174
+ Give me the prompts in the style of a list. start with a [ and end with ] do not add any special characters like \n
175
+ I need you to give me only the python list and nothing else. Do not explain yourself
176
+
177
+ example output format:
178
+ ["prompt1", "prompt2", ...]
179
+ '''
180
+ # if 'dog' in prompt:
181
+ # gpt_assistant_prompt = f'''You are an expert in prompting text-image generation models. I want you to generate simple prompts that would trigger the image generation model to generate a unique dog breeds.
182
+ # Your role is to give me {num_prompts} diverse prompts that will make the image-generation model to output diverse and interesting dog breeds with unique and diverse looks. make sure that you end the prompts with enhancing keywords like ",detailed, 8k" or ",hyper-realistic, 4k".
183
+
184
+ # Be creative and make sure to remember diversity is the key. Give me the prompts in the form of a list. start with a [ and end with ] do not add any special characters like \n
185
+ # I need you to give me only the python list and nothing else. Do not explain yourself
186
+
187
+ # example output format:
188
+ # ["prompt1", "prompt2", ...]
189
+ # '''
190
+
191
+ if train_type == 'artclaudesemantics':
192
+ gpt_assistant_prompt = f'''You are an expert in prompting text-image generation models. I want you to generate simple prompts that would trigger the image generation model to generate unique artistic images but DO NOT SPECIFY THE ART STYLE.
193
+ Your role is to give me {num_prompts} diverse prompts that will make the image-generation model output diverse and interesting art images, usually like "<some object or scene> in the style of" or "<some object or scene> in style of". Always end your prompts with "in the style of" so that I can manually add the style I want. Make sure that you end the prompts with enhancing keywords like ",detailed, 8k" or ",hyper-realistic, 4k".
194
+
195
+ Be creative and make sure to remember diversity is the key. Give me the prompts in the form of a list. start with a [ and end with ] do not add any special characters like \n
196
+ I need you to give me only the python list and nothing else. Do not explain yourself
197
+
198
+ example output format:
199
+ ["prompt1", "prompt2", ...]
200
+ '''
201
+ gpt_user_prompt = prompt
202
+ gpt_prompt = gpt_assistant_prompt, gpt_user_prompt
203
+ message=[
204
+ {
205
+ "role": "user",
206
+ "content": [
207
+ {
208
+ "type": "text",
209
+ "text": gpt_user_prompt
210
+ }
211
+ ]
212
+ }
213
+ ]
214
+
215
+ output = client.messages.create(
216
+ model=model,
217
+ max_tokens=max_tokens,
218
+ temperature=temperature,
219
+ system=gpt_assistant_prompt,
220
+ messages=message
221
+ )
222
+ content = output.content[0].text
223
+ return content
224
+
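+ # A minimal usage sketch for claude_generate_prompts_sliders. It assumes
+ # ANTHROPIC_API_KEY is set in the environment and that the model follows the
+ # instruction to return a plain Python list literal, which is then parsed with ast.
+ def _example_generate_prompts(seed_prompt: str = "a photo of a dog"):
+     raw = claude_generate_prompts_sliders(seed_prompt, num_prompts=5, train_type="concept")
+     return ast.literal_eval(raw)  # list[str] with 5 diverse prompts
+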
225
+ def normalize_image(image):
226
+ mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1, 1).to(image.device)
227
+ std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1, 1).to(image.device)
228
+ return (image - mean) / std
229
+
230
+
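+ # A quick check of normalize_image (no model needed): it applies the standard CLIP
+ # preprocessing mean/std per channel to an NCHW image tensor.
+ def _example_normalize_image():
+     image = torch.rand(1, 3, 224, 224)
+     return normalize_image(image).shape  # (1, 3, 224, 224)
+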
231
+ @torch.no_grad()
232
+ def call_sdxl(
233
+ self,
234
+ prompt: Union[str, List[str]] = None,
235
+ prompt_2: Optional[Union[str, List[str]]] = None,
236
+ height: Optional[int] = None,
237
+ width: Optional[int] = None,
238
+ num_inference_steps: int = 50,
239
+ timesteps: List[int] = None,
240
+ sigmas: List[float] = None,
241
+ denoising_end: Optional[float] = None,
242
+ guidance_scale: float = 5.0,
243
+ negative_prompt: Optional[Union[str, List[str]]] = None,
244
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
245
+ num_images_per_prompt: Optional[int] = 1,
246
+ eta: float = 0.0,
247
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
248
+ latents: Optional[torch.Tensor] = None,
249
+ prompt_embeds: Optional[torch.Tensor] = None,
250
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
251
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
252
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
253
+ ip_adapter_image: Optional[PipelineImageInput] = None,
254
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
255
+ output_type: Optional[str] = "pil",
256
+ return_dict: bool = True,
257
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
258
+ guidance_rescale: float = 0.0,
259
+ original_size: Optional[Tuple[int, int]] = None,
260
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
261
+ target_size: Optional[Tuple[int, int]] = None,
262
+ negative_original_size: Optional[Tuple[int, int]] = None,
263
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
264
+ negative_target_size: Optional[Tuple[int, int]] = None,
265
+ clip_skip: Optional[int] = None,
266
+ callback_on_step_end: Optional[
267
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
268
+ ] = None,
269
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
270
+ save_timesteps = None,
271
+ clip=None,
272
+ use_clip=True,
273
+ encoder='clip',
274
+ ):
275
+
276
+ callback = None
277
+ callback_steps = None
278
+
279
+ if callback is not None:
280
+ deprecate(
281
+ "callback",
282
+ "1.0.0",
283
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
284
+ )
285
+ if callback_steps is not None:
286
+ deprecate(
287
+ "callback_steps",
288
+ "1.0.0",
289
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
290
+ )
291
+
292
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
293
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
294
+
295
+ # 0. Default height and width to unet
296
+ height = height or self.default_sample_size * self.vae_scale_factor
297
+ width = width or self.default_sample_size * self.vae_scale_factor
298
+
299
+ original_size = original_size or (height, width)
300
+ target_size = target_size or (height, width)
301
+
302
+ # 1. Check inputs. Raise error if not correct
303
+ self.check_inputs(
304
+ prompt,
305
+ prompt_2,
306
+ height,
307
+ width,
308
+ callback_steps,
309
+ negative_prompt,
310
+ negative_prompt_2,
311
+ prompt_embeds,
312
+ negative_prompt_embeds,
313
+ pooled_prompt_embeds,
314
+ negative_pooled_prompt_embeds,
315
+ ip_adapter_image,
316
+ ip_adapter_image_embeds,
317
+ callback_on_step_end_tensor_inputs,
318
+ )
319
+
320
+ self._guidance_scale = guidance_scale
321
+ self._guidance_rescale = guidance_rescale
322
+ self._clip_skip = clip_skip
323
+ self._cross_attention_kwargs = cross_attention_kwargs
324
+ self._denoising_end = denoising_end
325
+ self._interrupt = False
326
+
327
+ # 2. Define call parameters
328
+ if prompt is not None and isinstance(prompt, str):
329
+ batch_size = 1
330
+ elif prompt is not None and isinstance(prompt, list):
331
+ batch_size = len(prompt)
332
+ else:
333
+ batch_size = prompt_embeds.shape[0]
334
+
335
+ device = self._execution_device
336
+
337
+ # 3. Encode input prompt
338
+ lora_scale = (
339
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
340
+ )
341
+
342
+ (
343
+ prompt_embeds,
344
+ negative_prompt_embeds,
345
+ pooled_prompt_embeds,
346
+ negative_pooled_prompt_embeds,
347
+ ) = self.encode_prompt(
348
+ prompt=prompt,
349
+ prompt_2=prompt_2,
350
+ device=device,
351
+ num_images_per_prompt=num_images_per_prompt,
352
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
353
+ negative_prompt=negative_prompt,
354
+ negative_prompt_2=negative_prompt_2,
355
+ prompt_embeds=prompt_embeds,
356
+ negative_prompt_embeds=negative_prompt_embeds,
357
+ pooled_prompt_embeds=pooled_prompt_embeds,
358
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
359
+ lora_scale=lora_scale,
360
+ clip_skip=self.clip_skip,
361
+ )
362
+
363
+ # 4. Prepare timesteps
364
+ timesteps, num_inference_steps = retrieve_timesteps(
365
+ self.scheduler, num_inference_steps, device, timesteps, sigmas
366
+ )
367
+
368
+ # 5. Prepare latent variables
369
+ num_channels_latents = self.unet.config.in_channels
370
+ latents = self.prepare_latents(
371
+ batch_size * num_images_per_prompt,
372
+ num_channels_latents,
373
+ height,
374
+ width,
375
+ prompt_embeds.dtype,
376
+ device,
377
+ generator,
378
+ latents,
379
+ )
380
+
381
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
382
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
383
+
384
+ # 7. Prepare added time ids & embeddings
385
+ add_text_embeds = pooled_prompt_embeds
386
+ if self.text_encoder_2 is None:
387
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
388
+ else:
389
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
390
+
391
+ add_time_ids = self._get_add_time_ids(
392
+ original_size,
393
+ crops_coords_top_left,
394
+ target_size,
395
+ dtype=prompt_embeds.dtype,
396
+ text_encoder_projection_dim=text_encoder_projection_dim,
397
+ )
398
+ if negative_original_size is not None and negative_target_size is not None:
399
+ negative_add_time_ids = self._get_add_time_ids(
400
+ negative_original_size,
401
+ negative_crops_coords_top_left,
402
+ negative_target_size,
403
+ dtype=prompt_embeds.dtype,
404
+ text_encoder_projection_dim=text_encoder_projection_dim,
405
+ )
406
+ else:
407
+ negative_add_time_ids = add_time_ids
408
+
409
+ if self.do_classifier_free_guidance:
410
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
411
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
412
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
413
+
414
+ prompt_embeds = prompt_embeds.to(device)
415
+ add_text_embeds = add_text_embeds.to(device)
416
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
417
+
418
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
419
+ image_embeds = self.prepare_ip_adapter_image_embeds(
420
+ ip_adapter_image,
421
+ ip_adapter_image_embeds,
422
+ device,
423
+ batch_size * num_images_per_prompt,
424
+ self.do_classifier_free_guidance,
425
+ )
426
+
427
+ # 8. Denoising loop
428
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
429
+
430
+ # 8.1 Apply denoising_end
431
+ if (
432
+ self.denoising_end is not None
433
+ and isinstance(self.denoising_end, float)
434
+ and self.denoising_end > 0
435
+ and self.denoising_end < 1
436
+ ):
437
+ discrete_timestep_cutoff = int(
438
+ round(
439
+ self.scheduler.config.num_train_timesteps
440
+ - (self.denoising_end * self.scheduler.config.num_train_timesteps)
441
+ )
442
+ )
443
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
444
+ timesteps = timesteps[:num_inference_steps]
445
+
446
+ # 9. Optionally get Guidance Scale Embedding
447
+ timestep_cond = None
448
+ if self.unet.config.time_cond_proj_dim is not None:
449
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
450
+ timestep_cond = self.get_guidance_scale_embedding(
451
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
452
+ ).to(device=device, dtype=latents.dtype)
453
+
454
+ self._num_timesteps = len(timesteps)
455
+ clip_features = []
456
+ # with self.progress_bar(total=num_inference_steps) as progress_bar:
457
+ for i, t in enumerate(timesteps):
458
+ if self.interrupt:
459
+ continue
460
+
461
+ # expand the latents if we are doing classifier free guidance
462
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
463
+
464
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
465
+
466
+ # predict the noise residual
467
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
468
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
469
+ added_cond_kwargs["image_embeds"] = image_embeds
470
+ noise_pred = self.unet(
471
+ latent_model_input,
472
+ t,
473
+ encoder_hidden_states=prompt_embeds,
474
+ timestep_cond=timestep_cond,
475
+ cross_attention_kwargs=self.cross_attention_kwargs,
476
+ added_cond_kwargs=added_cond_kwargs,
477
+ return_dict=False,
478
+ )[0]
479
+
480
+ # perform guidance
481
+ if self.do_classifier_free_guidance:
482
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
483
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
484
+
485
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
486
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
487
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
488
+
489
+ # compute the previous noisy sample x_t -> x_t-1
490
+ latents_dtype = latents.dtype
491
+ # latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
492
+
493
+ # compute the previous noisy sample x_t -> x_t-1
494
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=True)
495
+ try:
496
+ denoised = latents['pred_original_sample'] / self.vae.config.scaling_factor
497
+ except:
498
+ denoised = latents['denoised'] / self.vae.config.scaling_factor
499
+ latents = latents['prev_sample']
500
+
501
+
502
+ # if latents.dtype != latents_dtype:
503
+ # if torch.backends.mps.is_available():
504
+ # # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
505
+ latents = latents.to(self.vae.dtype)
506
+ denoised = denoised.to(self.vae.dtype)
507
+
508
+ if i in save_timesteps:
509
+ if use_clip:
510
+ denoised = self.vae.decode(denoised.to(self.vae.dtype), return_dict=False)[0]
511
+ denoised = F.adaptive_avg_pool2d(denoised, (224, 224))
512
+ denoised = normalize_image(denoised)
513
+ if 'dino' in encoder:
514
+ denoised = clip(denoised)
515
+ denoised = denoised.pooler_output
516
+ denoised = denoised.cpu().view(denoised.shape[0], -1)
517
+ else:
518
+ denoised = clip.get_image_features(denoised)
519
+ denoised = denoised.cpu().view(denoised.shape[0], -1)
520
+
521
+ # denoised = clip.get_image_features(denoised)
522
+ clip_features.append(denoised)
523
+
524
+
525
+
526
+
527
+ if callback_on_step_end is not None:
528
+ callback_kwargs = {}
529
+ for k in callback_on_step_end_tensor_inputs:
530
+ callback_kwargs[k] = locals()[k]
531
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
532
+
533
+ latents = callback_outputs.pop("latents", latents)
534
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
535
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
536
+ add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
537
+ negative_pooled_prompt_embeds = callback_outputs.pop(
538
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
539
+ )
540
+ add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
541
+ negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
542
+
543
+ # call the callback, if provided
544
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
545
+ # progress_bar.update()
546
+ if callback is not None and i % callback_steps == 0:
547
+ step_idx = i // getattr(self.scheduler, "order", 1)
548
+ callback(step_idx, t, latents)
549
+
550
+ if XLA_AVAILABLE:
551
+ xm.mark_step()
552
+
553
+ if not output_type == "latent":
554
+ # make sure the VAE is in float32 mode, as it overflows in float16
555
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
556
+
557
+ if needs_upcasting:
558
+ self.upcast_vae()
559
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
560
+ elif latents.dtype != self.vae.dtype:
561
+ if torch.backends.mps.is_available():
562
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
563
+ self.vae = self.vae.to(latents.dtype)
564
+
565
+ # unscale/denormalize the latents
566
+ # denormalize with the mean and std if available and not None
567
+ has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
568
+ has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
569
+ if has_latents_mean and has_latents_std:
570
+ latents_mean = (
571
+ torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
572
+ )
573
+ latents_std = (
574
+ torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
575
+ )
576
+ latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
577
+ else:
578
+ latents = latents / self.vae.config.scaling_factor
579
+
580
+ image = self.vae.decode(latents, return_dict=False)[0]
581
+
582
+ # cast back to fp16 if needed
583
+ if needs_upcasting:
584
+ self.vae.to(dtype=torch.float16)
585
+ else:
586
+ image = latents
587
+
588
+ if not output_type == "latent":
589
+
590
+ image = self.image_processor.postprocess(image, output_type=output_type)
591
+
592
+ # Offload all models
593
+ self.maybe_free_model_hooks()
594
+
595
+ return image, clip_features
596
+
597
+ @torch.no_grad()
598
+
599
+ def call_flux(
600
+ self,
601
+ prompt: Union[str, List[str]] = None,
602
+ prompt_2: Optional[Union[str, List[str]]] = None,
603
+ height: Optional[int] = None,
604
+ width: Optional[int] = None,
605
+ num_inference_steps: int = 28,
606
+ timesteps: List[int] = None,
607
+ guidance_scale: float = 7.0,
608
+ num_images_per_prompt: Optional[int] = 1,
609
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
610
+ latents: Optional[torch.FloatTensor] = None,
611
+ prompt_embeds: Optional[torch.FloatTensor] = None,
612
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
613
+ output_type: Optional[str] = "pil",
614
+ return_dict: bool = True,
615
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
616
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
617
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
618
+ max_sequence_length: int = 512,
619
+ verbose=False,
620
+ save_timesteps = None,
621
+ clip=None,
622
+ use_clip=True,
623
+ encoder='clip'
624
+ ):
625
+
626
+
627
+ height = height or self.default_sample_size * self.vae_scale_factor
628
+ width = width or self.default_sample_size * self.vae_scale_factor
629
+
630
+ # 1. Check inputs. Raise error if not correct
631
+ self.check_inputs(
632
+ prompt,
633
+ prompt_2,
634
+ height,
635
+ width,
636
+ prompt_embeds=prompt_embeds,
637
+ pooled_prompt_embeds=pooled_prompt_embeds,
638
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
639
+ max_sequence_length=max_sequence_length,
640
+ )
641
+
642
+ self._guidance_scale = guidance_scale
643
+ self._joint_attention_kwargs = joint_attention_kwargs
644
+ self._interrupt = False
645
+
646
+ # 2. Define call parameters
647
+ if prompt is not None and isinstance(prompt, str):
648
+ batch_size = 1
649
+ elif prompt is not None and isinstance(prompt, list):
650
+ batch_size = len(prompt)
651
+ else:
652
+ batch_size = prompt_embeds.shape[0]
653
+
654
+ device = self._execution_device
655
+
656
+ lora_scale = (
657
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
658
+ )
659
+ (
660
+ prompt_embeds,
661
+ pooled_prompt_embeds,
662
+ text_ids,
663
+ ) = self.encode_prompt(
664
+ prompt=prompt,
665
+ prompt_2=prompt_2,
666
+ prompt_embeds=prompt_embeds,
667
+ pooled_prompt_embeds=pooled_prompt_embeds,
668
+ device=device,
669
+ num_images_per_prompt=num_images_per_prompt,
670
+ max_sequence_length=max_sequence_length,
671
+ lora_scale=lora_scale,
672
+ )
673
+
674
+ # 4. Prepare latent variables
675
+ num_channels_latents = self.transformer.config.in_channels // 4
676
+ latents, latent_image_ids = self.prepare_latents(
677
+ batch_size * num_images_per_prompt,
678
+ num_channels_latents,
679
+ height,
680
+ width,
681
+ prompt_embeds.dtype,
682
+ device,
683
+ generator,
684
+ latents,
685
+ )
686
+
687
+ # 5. Prepare timesteps
688
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
689
+ image_seq_len = latents.shape[1]
690
+ mu = calculate_shift(
691
+ image_seq_len,
692
+ self.scheduler.config.base_image_seq_len,
693
+ self.scheduler.config.max_image_seq_len,
694
+ self.scheduler.config.base_shift,
695
+ self.scheduler.config.max_shift,
696
+ )
697
+ timesteps, num_inference_steps = retrieve_timesteps(
698
+ self.scheduler,
699
+ num_inference_steps,
700
+ device,
701
+ timesteps,
702
+ sigmas,
703
+ mu=mu,
704
+ )
705
+
706
+ timesteps = timesteps
707
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
708
+ self._num_timesteps = len(timesteps)
709
+
710
+ # handle guidance
711
+ if self.transformer.config.guidance_embeds:
712
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
713
+ guidance = guidance.expand(latents.shape[0])
714
+ else:
715
+ guidance = None
716
+ clip_features = []
717
+ # 6. Denoising loop
718
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
719
+ for i, t in enumerate(timesteps):
720
+ if self.interrupt:
721
+ continue
722
+
723
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
724
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
725
+
726
+ noise_pred = self.transformer(
727
+ hidden_states=latents,
728
+ # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
729
+ timestep=timestep / 1000,
730
+ guidance=guidance,
731
+ pooled_projections=pooled_prompt_embeds,
732
+ encoder_hidden_states=prompt_embeds,
733
+ txt_ids=text_ids,
734
+ img_ids=latent_image_ids,
735
+ joint_attention_kwargs=self.joint_attention_kwargs,
736
+ return_dict=False,
737
+ )[0]
738
+
739
+ # compute the previous noisy sample x_t -> x_t-1
740
+ latents_dtype = latents.dtype
741
+ # compute the previous noisy sample x_t -> x_t-1
742
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=True)
743
+
744
+
745
+ denoised = latents['prev_sample']
746
+ latents = latents['prev_sample']
747
+
748
+ denoised = self._unpack_latents(denoised, height, width, self.vae_scale_factor)
749
+ denoised = (denoised / self.vae.config.scaling_factor) + self.vae.config.shift_factor
750
+ denoised = self.vae.decode(denoised, return_dict=False)[0]
751
+ denoised = F.adaptive_avg_pool2d(denoised, (224, 224))
752
+ if 'dino' in encoder:
753
+ outputs = clip(denoised)  # mirror the SDXL path: pass the decoded image tensor to the DINO encoder
754
+ denoised = outputs.pooler_output
755
+ denoised = denoised.cpu().view(denoised.shape[0], -1)
756
+ else:
757
+ denoised = clip.get_image_features(denoised)
758
+ denoised = denoised.cpu().view(denoised.shape[0], -1)
759
+
760
+ clip_features.append(denoised)
761
+
762
+ if latents.dtype != latents_dtype:
763
+ if torch.backends.mps.is_available():
764
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
765
+ latents = latents.to(latents_dtype)
766
+
767
+ if callback_on_step_end is not None:
768
+ callback_kwargs = {}
769
+ for k in callback_on_step_end_tensor_inputs:
770
+ callback_kwargs[k] = locals()[k]
771
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
772
+
773
+ latents = callback_outputs.pop("latents", latents)
774
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
775
+
776
+ # call the callback, if provided
777
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
778
+ progress_bar.update()
779
+
780
+ if XLA_AVAILABLE:
781
+ xm.mark_step()
782
+
783
+ if output_type == "latent":
784
+ image = latents
785
+ return image
786
+
787
+ else:
788
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
789
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
790
+ image = self.vae.decode(latents, return_dict=False)[0]
791
+ image = self.image_processor.postprocess(image, output_type=output_type)
792
+
793
+ # Offload all models
794
+ self.maybe_free_model_hooks()
795
+
796
+ if not return_dict:
797
+ return (image,)
798
+
799
+ return image, clip_features
800
+
801
+
802
+
803
+
804
+ def get_diffusion_clip_directions(prompts, unet, tokenizers, text_encoders, vae, noise_scheduler, clip, batchsize=1, height=1024, width=1024, max_denoising_steps=4, savepath_training_images=None, use_clip=True,encoder='clip'):
805
+ device = unet.device
806
+ vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
807
+ image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)
808
+
809
+ os.makedirs(savepath_training_images, exist_ok=True)
810
+
811
+
812
+ if len(noise_scheduler.timesteps) != max_denoising_steps:
813
+ noise_scheduler_orig = noise_scheduler
814
+ max_denoising_steps_orig = len(noise_scheduler.timesteps)
815
+ noise_scheduler.set_timesteps(max_denoising_steps)
816
+ timesteps_distilled = noise_scheduler.timesteps
817
+
818
+ noise_scheduler.set_timesteps(max_denoising_steps_orig)
819
+ timesteps_full = noise_scheduler.timesteps
820
+ save_timesteps = []
821
+ for timesteps_to_distilled in range(max_denoising_steps):
822
+ # Get the value from timesteps_distilled that we want to find in timesteps_full
823
+ value_to_find = timesteps_distilled[timesteps_to_distilled]
824
+ timesteps_to_full = (timesteps_full == value_to_find).nonzero().item()
825
+ save_timesteps.append(timesteps_to_full)
826
+
827
+ guidance_scale = 7
828
+ else:
829
+ max_denoising_steps_orig = max_denoising_steps
830
+ save_timesteps = [i for i in range(max_denoising_steps_orig)]
831
+ guidance_scale = 7
832
+ if max_denoising_steps_orig <=4:
833
+ guidance_scale = 0
834
+
835
+ noise_scheduler.set_timesteps(max_denoising_steps_orig)
836
+ # if max_denoising_steps_orig == 1:
837
+ # noise_scheduler.set_timesteps(timesteps=[399],
838
+ # device=device)
839
+
840
+ weight_dtype = unet.dtype
841
+ device = unet.device
842
+ StableDiffusionXLPipeline.__call__ = call_sdxl
843
+ pipe = StableDiffusionXLPipeline(vae = vae,
844
+ text_encoder= text_encoders[0],
845
+ text_encoder_2=text_encoders[1],
846
+ tokenizer = tokenizers[0],
847
+ tokenizer_2= tokenizers[1],
848
+ unet=unet,
849
+ scheduler=noise_scheduler)
850
+ pipe.to(unet.device)
851
+ # print(guidance_scale, max_denoising_steps_orig, save_timesteps)
852
+ images, clip_features = pipe(prompts, guidance_scale=guidance_scale, num_inference_steps = max_denoising_steps_orig, clip=clip, save_timesteps =save_timesteps, use_clip=use_clip, encoder=encoder)
853
+
854
+ return images, torch.stack(clip_features)
855
+
856
+
857
+
858
+ def get_flux_clip_directions(prompts, transformer, tokenizers, text_encoders, vae, noise_scheduler, clip, batchsize=1, height=1024, width=1024, max_denoising_steps=4, savepath_training_images=None, use_clip=True):
859
+ device = transformer.device
860
+ FluxPipeline.__call__ = call_flux
861
+ pipe = FluxPipeline(noise_scheduler,
862
+ vae,
863
+ text_encoders[0],
864
+ tokenizers[0],
865
+ text_encoders[1],
866
+ tokenizers[1],
867
+ transformer,
868
+ )
869
+ pipe.set_progress_bar_config(disable=True)
870
+
871
+ os.makedirs(savepath_training_images, exist_ok=True)
872
+
873
+ images, clip_features = pipe(
874
+ prompts,
875
+ height=height,
876
+ width=width,
877
+ guidance_scale=0,
878
+ num_inference_steps=4,
879
+ max_sequence_length=256,
880
+ num_images_per_prompt=1,
881
+ output_type='pil',
882
+ clip=clip
883
+ )
884
+
885
+ return images, torch.stack(clip_features)
886
+
887
+
888
+
889
+
890
+ def get_diffusion_clip_directions(prompts, unet, tokenizers, text_encoders, vae, noise_scheduler, clip, batchsize=1, height=1024, width=1024, max_denoising_steps=4, savepath_training_images=None, use_clip=True,encoder='clip', num_images_per_prompt=1):
891
+
892
+
893
+ device = unet.device
894
+ vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
895
+ image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)
896
+ os.makedirs(savepath_training_images, exist_ok=True)
897
+
898
+
899
+ if len(noise_scheduler.timesteps) != max_denoising_steps:
900
+ noise_scheduler_orig = noise_scheduler
901
+ max_denoising_steps_orig = len(noise_scheduler.timesteps)
902
+ noise_scheduler.set_timesteps(max_denoising_steps)
903
+ timesteps_distilled = noise_scheduler.timesteps
904
+
905
+ noise_scheduler.set_timesteps(max_denoising_steps_orig)
906
+ timesteps_full = noise_scheduler.timesteps
907
+ save_timesteps = []
908
+ for timesteps_to_distilled in range(max_denoising_steps):
909
+ # Get the value from timesteps_distilled that we want to find in timesteps_full
910
+ value_to_find = timesteps_distilled[timesteps_to_distilled]
911
+ timesteps_to_full = (timesteps_full == value_to_find).nonzero().item()
912
+ save_timesteps.append(timesteps_to_full)
913
+
914
+ guidance_scale = 7
915
+ else:
916
+ max_denoising_steps_orig = max_denoising_steps
917
+ save_timesteps = [i for i in range(max_denoising_steps_orig)]
918
+ guidance_scale = 7
919
+ if max_denoising_steps_orig <=4:
920
+ guidance_scale = 0
921
+
922
+ noise_scheduler.set_timesteps(max_denoising_steps_orig)
923
+ # if max_denoising_steps_orig == 1:
924
+ # noise_scheduler.set_timesteps(timesteps=[399],
925
+ # device=device)
926
+
927
+ weight_dtype = unet.dtype
928
+ device = unet.device
929
+ StableDiffusionXLPipeline.__call__ = call_sdxl
930
+ pipe = StableDiffusionXLPipeline(vae = vae,
931
+ text_encoder= text_encoders[0],
932
+ text_encoder_2=text_encoders[1],
933
+ tokenizer = tokenizers[0],
934
+ tokenizer_2= tokenizers[1],
935
+ unet=unet,
936
+ scheduler=noise_scheduler)
937
+ pipe.to(unet.device)
938
+ # print(guidance_scale, max_denoising_steps_orig, save_timesteps)
939
+ images, clip_features = pipe(prompts, guidance_scale=guidance_scale, num_inference_steps = max_denoising_steps_orig, clip=clip, save_timesteps =save_timesteps, use_clip=use_clip, encoder=encoder)
940
+
941
+ return images, torch.stack(clip_features)
942
+
943
+
944
+
945
+