adding utils for sliders
- utils/__init__.py +0 -0
- utils/clip_util.py +142 -0
- utils/flux_utils.py +404 -0
- utils/lora.py +320 -0
- utils/model_util.py +527 -0
- utils/prompt_util.py +80 -0
- utils/train_util.py +722 -0
- utils/utils.py +945 -0
utils/__init__.py
ADDED
File without changes
utils/clip_util.py
ADDED
@@ -0,0 +1,142 @@
from typing import List, Optional
import math, random, os
import pandas as pd
import numpy as np
import torch
from tqdm.auto import tqdm
from sklearn.decomposition import PCA


def extract_clip_features(clip, image, encoder):
    """
    Extracts feature embeddings from an image using either CLIP or DINOv2 models.

    Args:
        clip (torch.nn.Module): The feature extraction model (either CLIP or DINOv2)
        image (torch.Tensor): Input image tensor normalized according to model requirements
        encoder (str): Type of encoder to use ('dinov2-small' or 'clip')

    Returns:
        torch.Tensor: Feature embeddings extracted from the image

    Note:
        - For DINOv2 models, uses the pooled output features
        - For CLIP models, uses the image features from the vision encoder
        - The input image should already be properly resized and normalized
    """
    # Handle DINOv2 models
    if 'dino' in encoder:
        denoised = clip(image)
        denoised = denoised.pooler_output
    # Handle CLIP models
    else:
        denoised = clip.get_image_features(image)

    return denoised


@torch.no_grad()
def compute_clip_pca(
    diverse_prompts: List[str],
    pipe,
    clip_model,
    clip_processor,
    device,
    guidance_scale,
    params,
    total_samples=5000,
    num_pca_components=100,
    batch_size=10,
):
    """
    Extract CLIP features from generated images and compute their principal components.

    Args:
        diverse_prompts: List of prompts to generate images from
        pipe: Diffusion pipeline used to generate the images
        clip_model: CLIP model used to embed the generated images
        clip_processor: Processor that prepares the images for the CLIP model
        device: Device on which to run feature extraction
        guidance_scale: Guidance scale passed to the pipeline
        params: Training parameter dictionary (save path, resolution, denoising steps, ...)
        total_samples: Total number of images to generate
        num_pca_components: Number of principal components to keep
        batch_size: Number of images generated per pipeline call

    Returns:
        Tuple of (CLIP principal components, training prompts, saved image paths)
    """

    # Calculate how many total batches we need
    num_batches = math.ceil(total_samples / batch_size)
    # Randomly sample prompts (with replacement if needed)
    sampled_prompts_clip = random.choices(diverse_prompts, k=num_batches)

    clip_features_path = f"{params['savepath_training_images']}/clip_principle_directions.pt"

    if os.path.exists(clip_features_path):
        df = pd.read_csv(f"{params['savepath_training_images']}/training_data.csv")
        prompts_training = list(df.prompt)
        image_paths = list(df.image_path)
        return torch.load(clip_features_path).to(device), prompts_training, image_paths

    os.makedirs(params['savepath_training_images'], exist_ok=True)

    # Generate images and extract features
    img_idx = 0
    clip_features = []
    image_paths = []
    prompts_training = []
    print('Calculating Semantic PCA')

    for prompt in tqdm(sampled_prompts_clip):
        if 'max_sequence_length' in params:
            images = pipe(prompt,
                          num_images_per_prompt=batch_size,
                          num_inference_steps=params['max_denoising_steps'],
                          guidance_scale=guidance_scale,
                          max_sequence_length=params['max_sequence_length'],
                          height=params['height'],
                          width=params['width'],
                          ).images
        else:
            images = pipe(prompt,
                          num_images_per_prompt=batch_size,
                          num_inference_steps=params['max_denoising_steps'],
                          guidance_scale=guidance_scale,
                          height=params['height'],
                          width=params['width'],
                          ).images

        # Process images
        clip_inputs = clip_processor(images=images, return_tensors="pt", padding=True)
        pixel_values = clip_inputs['pixel_values'].to(device)

        # Get image embeddings
        with torch.no_grad():
            image_features = clip_model.get_image_features(pixel_values)

        # Normalize embeddings
        clip_feats = image_features / image_features.norm(dim=1, keepdim=True)
        clip_features.append(clip_feats)

        for im in images:
            image_path = f"{params['savepath_training_images']}/{img_idx}.png"
            im.save(image_path)
            image_paths.append(image_path)
            prompts_training.append(prompt)
            img_idx += 1

    clip_features = torch.cat(clip_features)

    # Calculate principal components
    pca = PCA(n_components=num_pca_components)
    clip_embeds_np = clip_features.float().cpu().numpy()
    pca.fit(clip_embeds_np)
    clip_principles = torch.from_numpy(pca.components_).to(device, dtype=pipe.vae.dtype)

    # Save results
    torch.save(clip_principles, clip_features_path)
    pd.DataFrame({
        'prompt': prompts_training,
        'image_path': image_paths
    }).to_csv(f"{params['savepath_training_images']}/training_data.csv", index=False)

    return clip_principles, prompts_training, image_paths
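A minimal sanity-check sketch (not part of this commit): it embeds a single image with extract_clip_features, reusing the TinyCLIP checkpoint that utils/model_util.py below loads; the blank test image is purely illustrative.

import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor
from utils.clip_util import extract_clip_features

# TinyCLIP checkpoint name taken from utils/model_util.py in this commit
clip_model = CLIPModel.from_pretrained("wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M")
clip_processor = CLIPProcessor.from_pretrained("wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M")

image = Image.new("RGB", (512, 512), "white")  # placeholder image, illustrative only
pixel_values = clip_processor(images=image, return_tensors="pt")["pixel_values"]
with torch.no_grad():
    features = extract_clip_features(clip_model, pixel_values, encoder="clip")
print(features.shape)  # -> torch.Size([1, embed_dim])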
utils/flux_utils.py
ADDED
@@ -0,0 +1,404 @@
import os, torch
import argparse
import copy
import gc
import itertools
import logging
import math

import random
import shutil
import warnings
from contextlib import nullcontext
from pathlib import Path

import numpy as np
import torch
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
from huggingface_hub.utils import insecure_hashlib
from PIL import Image
from PIL.ImageOps import exif_transpose
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms.functional import crop
from tqdm.auto import tqdm
from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast

import diffusers
from diffusers import (
    AutoencoderKL,
    FlowMatchEulerDiscreteScheduler,
    FluxTransformer2DModel,
)
from diffusers.optimization import get_scheduler
from diffusers.training_utils import (
    _set_state_dict_into_text_encoder,
    cast_training_params,
    compute_density_for_timestep_sampling,
    compute_loss_weighting_for_sd3,
)
from diffusers.utils import (
    check_min_version,
    convert_unet_state_dict_to_peft,
    is_wandb_available,
)
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.torch_utils import is_compiled_module


from collections import defaultdict


from typing import List, Optional
import argparse
import ast
from pathlib import Path
from diffusers import DiffusionPipeline, UNet2DConditionModel, LCMScheduler
from huggingface_hub import hf_hub_download
import gc
import torch.nn.functional as F
import os
import torch
from tqdm.auto import tqdm
import time, datetime
import numpy as np
from torch.optim import AdamW
from contextlib import ExitStack
from safetensors.torch import load_file
import torch.nn as nn
import random
from transformers import CLIPModel

from transformers import logging
logging.set_verbosity_warning()

from diffusers import logging
logging.set_verbosity_error()


def flush():
    torch.cuda.empty_cache()
    gc.collect()
flush()


def unwrap_model(model):
    options = (torch.nn.parallel.DistributedDataParallel, torch.nn.DataParallel)
    # if is_deepspeed_available():
    #     options += (DeepSpeedEngine,)
    while isinstance(model, options):
        model = model.module
    return model


# Function to log gradients
def log_gradients(named_parameters):
    grad_dict = defaultdict(lambda: defaultdict(float))
    for name, param in named_parameters:
        if param.requires_grad and param.grad is not None:
            grad_dict[name]['mean'] = param.grad.abs().mean().item()
            grad_dict[name]['std'] = param.grad.std().item()
            grad_dict[name]['max'] = param.grad.abs().max().item()
            grad_dict[name]['min'] = param.grad.abs().min().item()
    return grad_dict


def import_model_class_from_model_name_or_path(
    pretrained_model_name_or_path: str, subfolder: str = "text_encoder",
):
    text_encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path, subfolder=subfolder,
        device_map='cuda:0'
    )
    model_class = text_encoder_config.architectures[0]
    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel
    elif model_class == "T5EncoderModel":
        from transformers import T5EncoderModel

        return T5EncoderModel
    else:
        raise ValueError(f"{model_class} is not supported.")


def load_text_encoders(pretrained_model_name_or_path, class_one, class_two, weight_dtype):
    text_encoder_one = class_one.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
        torch_dtype=weight_dtype,
        device_map='cuda:0'
    )
    text_encoder_two = class_two.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder_2",
        torch_dtype=weight_dtype,
        device_map='cuda:0'
    )
    return text_encoder_one, text_encoder_two


import matplotlib.pyplot as plt


def plot_labeled_images(images, labels):
    # Determine the number of images
    n = len(images)

    # Create a new figure with a single row
    fig, axes = plt.subplots(1, n, figsize=(5 * n, 5))

    # If there's only one image, axes will be a single object, not an array
    if n == 1:
        axes = [axes]

    # Plot each image
    for i, (img, label) in enumerate(zip(images, labels)):
        # Convert PIL image to numpy array
        img_array = np.array(img)

        # Display the image
        axes[i].imshow(img_array)
        axes[i].axis('off')  # Turn off axis

        # Set the title (label) for the image
        axes[i].set_title(label)

    # Adjust the layout and display the plot
    plt.tight_layout()
    plt.show()


def tokenize_prompt(tokenizer, prompt, max_sequence_length):
    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        return_length=False,
        return_overflowing_tokens=False,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    return text_input_ids


def _encode_prompt_with_t5(
    text_encoder,
    tokenizer,
    max_sequence_length=512,
    prompt=None,
    num_images_per_prompt=1,
    device=None,
    text_input_ids=None,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    if tokenizer is not None:
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_length=False,
            return_overflowing_tokens=False,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
    else:
        if text_input_ids is None:
            raise ValueError("text_input_ids must be provided when the tokenizer is not specified")

    prompt_embeds = text_encoder(text_input_ids.to(device))[0]

    dtype = text_encoder.dtype
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    _, seq_len, _ = prompt_embeds.shape

    # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    return prompt_embeds


def _encode_prompt_with_clip(
    text_encoder,
    tokenizer,
    prompt: str,
    device=None,
    text_input_ids=None,
    num_images_per_prompt: int = 1,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    if tokenizer is not None:
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=77,
            truncation=True,
            return_overflowing_tokens=False,
            return_length=False,
            return_tensors="pt",
        )

        text_input_ids = text_inputs.input_ids
    else:
        if text_input_ids is None:
            raise ValueError("text_input_ids must be provided when the tokenizer is not specified")

    prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)

    # Use pooled output of CLIPTextModel
    prompt_embeds = prompt_embeds.pooler_output
    prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)

    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)

    return prompt_embeds


def encode_prompt(
    text_encoders,
    tokenizers,
    prompt: str,
    max_sequence_length,
    device=None,
    num_images_per_prompt: int = 1,
    text_input_ids_list=None,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)
    dtype = text_encoders[0].dtype

    pooled_prompt_embeds = _encode_prompt_with_clip(
        text_encoder=text_encoders[0],
        tokenizer=tokenizers[0],
        prompt=prompt,
        device=device if device is not None else text_encoders[0].device,
        num_images_per_prompt=num_images_per_prompt,
        text_input_ids=text_input_ids_list[0] if text_input_ids_list else None,
    )

    prompt_embeds = _encode_prompt_with_t5(
        text_encoder=text_encoders[1],
        tokenizer=tokenizers[1],
        max_sequence_length=max_sequence_length,
        prompt=prompt,
        num_images_per_prompt=num_images_per_prompt,
        device=device if device is not None else text_encoders[1].device,
        text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,
    )

    text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
    text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)

    return prompt_embeds, pooled_prompt_embeds, text_ids


def compute_text_embeddings(prompt, text_encoders, tokenizers, max_sequence_length=256):
    device = text_encoders[0].device
    with torch.no_grad():
        prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
            text_encoders, tokenizers, prompt, max_sequence_length=max_sequence_length
        )
        prompt_embeds = prompt_embeds.to(device)
        pooled_prompt_embeds = pooled_prompt_embeds.to(device)
        text_ids = text_ids.to(device)
    return prompt_embeds, pooled_prompt_embeds, text_ids


def get_sigmas(timesteps, n_dim=4, device='cuda:0', dtype=torch.bfloat16):
    sigmas = noise_scheduler_copy.sigmas.to(device=device, dtype=dtype)
    schedule_timesteps = noise_scheduler_copy.timesteps.to(device)
    timesteps = timesteps.to(device)
    step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

    sigma = sigmas[step_indices].flatten()
    while len(sigma.shape) < n_dim:
        sigma = sigma.unsqueeze(-1)
    return sigma


def plot_history(history):
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 5))
    ax1.plot(history['concept'])
    ax1.set_title('Concept Loss')
    ax2.plot(movingaverage(history['concept'], 10))
    ax2.set_title('Moving Average Concept Loss')
    plt.tight_layout()
    plt.show()


def movingaverage(interval, window_size):
    window = np.ones(int(window_size)) / float(window_size)
    return np.convolve(interval, window, 'same')


@torch.no_grad()
def get_noisy_image_flux(
    image,
    vae,
    transformer,
    scheduler,
    timesteps_to=1000,
    generator=None,
    **kwargs,
):
    """
    Gets noisy latents for a given image using Flux pipeline approach.

    Args:
        image: PIL image or tensor
        vae: Flux VAE model
        transformer: Flux transformer model
        scheduler: Flux noise scheduler
        timesteps_to: Target timestep
        generator: Random generator for reproducibility

    Returns:
        tuple: (noisy_latents, noise)
    """
    device = vae.device
    vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
    image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)

    # Preprocess image
    if not isinstance(image, torch.Tensor):
        image = image_processor.preprocess(image)
    image = image.to(device=device, dtype=torch.float32)

    # Encode through VAE
    init_latents = vae.encode(image).latents
    init_latents = vae.config.scaling_factor * init_latents

    # Get shape for noise
    shape = init_latents.shape

    # Generate noise
    noise = randn_tensor(shape, generator=generator, device=device)

    # Pack latents using Flux's method
    init_latents = _pack_latents(
        init_latents,
        shape[0],  # batch size
        transformer.config.in_channels // 4,
        height=shape[2],
        width=shape[3]
    )
    noise = _pack_latents(
        noise,
        shape[0],
        transformer.config.in_channels // 4,
        height=shape[2],
        width=shape[3]
    )

    # Get timestep
    timestep = scheduler.timesteps[timesteps_to:timesteps_to + 1]

    # Add noise to latents
    noisy_latents = scheduler.add_noise(init_latents, noise, timestep)

    return noisy_latents, noise
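A small usage sketch (not part of this commit) for the prompt tokenization helper above. The Flux checkpoint name is an assumption for illustration; any repository with a T5 tokenizer in a tokenizer_2 subfolder would work the same way.

from transformers import T5TokenizerFast
from utils.flux_utils import tokenize_prompt

# Model id is illustrative (assumed); it matches the layout load_models_flux expects
tokenizer_two = T5TokenizerFast.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", subfolder="tokenizer_2"
)
token_ids = tokenize_prompt(tokenizer_two, ["a watercolor fox"], max_sequence_length=256)
print(token_ids.shape)  # -> torch.Size([1, 256])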
utils/lora.py
ADDED
@@ -0,0 +1,320 @@
# ref:
# - https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
# - https://github.com/kohya-ss/sd-scripts/blob/main/networks/lora.py

import os
import math
from typing import Optional, List, Type, Set, Literal

import torch
import torch.nn as nn
from diffusers import UNet2DConditionModel
from safetensors.torch import save_file
from datetime import datetime

UNET_TARGET_REPLACE_MODULE_TRANSFORMER = [
    # "Transformer2DModel",  # apparently this is the right one?  # attn1, 2
    "Attention"
]
UNET_TARGET_REPLACE_MODULE_CONV = [
    "ResnetBlock2D",
    "Downsample2D",
    "Upsample2D",
    "DownBlock2D",
    "UpBlock2D",

]  # locon, 3clier

LORA_PREFIX_UNET = "lora_unet"

DEFAULT_TARGET_REPLACE = UNET_TARGET_REPLACE_MODULE_TRANSFORMER

TRAINING_METHODS = Literal[
    "noxattn",  # train all layers except x-attns and time_embed layers
    "innoxattn",  # train all layers except self attention layers
    "selfattn",  # ESD-u, train only self attention layers
    "xattn",  # ESD-x, train only x attention layers
    "xattn-up",  # all up blocks only
    "xattn-down",  # all down blocks only
    "xattn-mid",  # mid blocks only
    "full",  # train all layers
    "xattn-strict",  # q and k values
    "noxattn-hspace",
    "noxattn-hspace-last",
    "flux-attn",
    # "xlayer",
    # "outxattn",
    # "outsattn",
    # "inxattn",
    # "inmidsattn",
    # "selflayer",
]


def load_ortho_dict(n):
    path = f'/share/u/rohit/orthogonal_basis/{n:09}.ckpt'
    if os.path.isfile(path):
        return torch.load(path)
    else:
        x = torch.randn(n, n)
        eig, _, _ = torch.svd(x)
        torch.save(eig, path)
        return eig


def init_ortho_proj(rank, weight):
    seed = torch.seed()
    torch.manual_seed(datetime.now().timestamp())
    q_index = torch.randint(high=weight.size(0), size=(rank,))
    torch.manual_seed(seed)

    ortho_q_init = load_ortho_dict(weight.size(0)).to(dtype=weight.dtype)[:, q_index]
    return nn.Parameter(ortho_q_init)


class LoRAModule(nn.Module):
    """
    Replaces the forward method of the original Linear, instead of replacing the original Linear module.
    """

    def __init__(
        self,
        lora_name,
        org_module: nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        train_method='xattn',
        fast_init=False,
    ):
        """if alpha == 0 or None, alpha is rank (no scaling)."""
        super().__init__()
        self.lora_name = lora_name
        self.lora_dim = lora_dim

        if "Linear" in org_module.__class__.__name__:
            in_dim = org_module.in_features
            out_dim = org_module.out_features
            self.lora_down = nn.Linear(in_dim, lora_dim, bias=False)
            self.lora_up = nn.Linear(lora_dim, out_dim, bias=False)

        elif "Conv" in org_module.__class__.__name__:  # just in case
            in_dim = org_module.in_channels
            out_dim = org_module.out_channels

            self.lora_dim = min(self.lora_dim, in_dim, out_dim)
            if self.lora_dim != lora_dim:
                print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")

            kernel_size = org_module.kernel_size
            stride = org_module.stride
            padding = org_module.padding
            self.lora_down = nn.Conv2d(
                in_dim, self.lora_dim, kernel_size, stride, padding, bias=False
            )
            self.lora_up = nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)

        if type(alpha) == torch.Tensor:
            alpha = alpha.detach().numpy()
        alpha = lora_dim if alpha is None or alpha == 0 else alpha
        self.scale = alpha / self.lora_dim
        self.register_buffer("alpha", torch.tensor(alpha))  # can be treated as a constant

        # same as microsoft's
        nn.init.kaiming_uniform_(self.lora_down.weight, a=1)
        if train_method == 'full':
            nn.init.zeros_(self.lora_up.weight)
        else:
            if not fast_init:
                self.lora_up.weight = init_ortho_proj(lora_dim, self.lora_up.weight)
                self.lora_up.weight.requires_grad_(False)
            else:
                nn.init.zeros_(self.lora_up.weight)

        self.multiplier = multiplier
        self.org_module = org_module  # remove in applying

    def apply_to(self):
        self.org_forward = self.org_module.forward
        self.org_module.forward = self.forward
        del self.org_module

    def forward(self, x):
        return (
            self.org_forward(x)
            + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
        )


class LoRANetwork(nn.Module):
    def __init__(
        self,
        unet: UNet2DConditionModel,
        rank: int = 4,
        multiplier: float = 1.0,
        alpha: float = 1.0,
        train_method: TRAINING_METHODS = "full",
        layers=['Linear', 'Conv'],
        fast_init=False,
    ) -> None:
        super().__init__()
        self.lora_scale = 1
        self.multiplier = multiplier
        self.lora_dim = rank
        self.alpha = alpha
        self.train_method = train_method
        # LoRA only
        self.module = LoRAModule

        # create the LoRA modules for the unet
        self.unet_loras = self.create_modules(
            LORA_PREFIX_UNET,
            unet,
            DEFAULT_TARGET_REPLACE,
            self.lora_dim,
            self.multiplier,
            train_method=train_method,
            layers=layers,
            fast_init=fast_init,
        )
        print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")

        # assertion: check that no lora names are duplicated
        lora_names = set()
        for lora in self.unet_loras:
            assert (
                lora.lora_name not in lora_names
            ), f"duplicated lora name: {lora.lora_name}. {lora_names}"
            lora_names.add(lora.lora_name)

        # apply the LoRA modules
        for lora in self.unet_loras:
            lora.apply_to()
            self.add_module(
                lora.lora_name,
                lora,
            )

        del unet

        torch.cuda.empty_cache()

    def create_modules(
        self,
        prefix: str,
        root_module: nn.Module,
        target_replace_modules: List[str],
        rank: int,
        multiplier: float,
        train_method: TRAINING_METHODS,
        layers: List[str],
        fast_init: bool,
    ) -> list:
        filt_layers = []
        if 'Linear' in layers:
            filt_layers.extend(["Linear", "LoRACompatibleLinear"])
        if 'Conv' in layers:
            filt_layers.extend(["Conv2d", "LoRACompatibleConv"])
        loras = []
        names = []
        for name, module in root_module.named_modules():
            if train_method == "noxattn" or train_method == "noxattn-hspace" or train_method == "noxattn-hspace-last":  # train everything except cross-attention and time-embed layers
                if "attn2" in name or "time_embed" in name:
                    continue
            elif train_method == "innoxattn":  # train everything except cross-attention layers
                if "attn2" in name:
                    continue
            elif train_method == "selfattn":  # train only self-attention layers
                if "attn1" not in name:
                    continue
            elif train_method in ["xattn", "xattn-strict", "xattn-up", "xattn-down", "xattn-mid"]:  # train only cross-attention layers
                if "attn2" not in name:
                    continue
                if train_method == 'xattn-up':
                    if 'up_block' not in name:
                        continue
                if train_method == 'xattn-down':
                    if 'down_block' not in name:
                        continue
                if train_method == 'xattn-mid':
                    if 'mid_block' not in name:
                        continue
            elif train_method == "full":  # train all layers
                pass
            elif train_method == "flux-attn":
                if "attn" not in name:
                    continue
            else:
                raise NotImplementedError(
                    f"train_method: {train_method} is not implemented."
                )
            if module.__class__.__name__ in target_replace_modules:
                for child_name, child_module in module.named_modules():
                    if child_module.__class__.__name__ in filt_layers:

                        if train_method == 'xattn-strict':
                            if 'out' in child_name:
                                continue
                            if 'to_q' in child_name:
                                continue
                        if train_method == 'noxattn-hspace':
                            if 'mid_block' not in name:
                                continue
                        if train_method == 'noxattn-hspace-last':
                            if 'mid_block' not in name or '.1' not in name or 'conv2' not in child_name:
                                continue
                        lora_name = prefix + "." + name + "." + child_name
                        lora_name = lora_name.replace(".", "_")
                        # print(f"{lora_name}")
                        lora = self.module(
                            lora_name, child_module, multiplier, rank, self.alpha, train_method, fast_init
                        )
                        # print(name, child_name)
                        # print(child_module.weight.shape)
                        if lora_name not in names:
                            loras.append(lora)
                            names.append(lora_name)
        # print(f'@@@@@@@@@@@@@@@@@@@@@@@@@@@@ \n {names}')
        return loras

    def prepare_optimizer_params(self):
        all_params = []

        if self.unet_loras:  # in practice this is the only case
            params = []
            if self.train_method == 'full':
                [params.extend(lora.parameters()) for lora in self.unet_loras]
            else:
                [params.extend(lora.lora_down.parameters()) for lora in self.unet_loras]
            param_data = {"params": params}
            all_params.append(param_data)

        return all_params

    def save_weights(self, file, dtype=None, metadata: Optional[dict] = None):
        state_dict = self.state_dict()

        if dtype is not None:
            for key in list(state_dict.keys()):
                v = state_dict[key]
                v = v.detach().clone().to("cpu").to(dtype)
                state_dict[key] = v

        # for key in list(state_dict.keys()):
        #     if not key.startswith("lora"):
        #         # drop everything except lora keys
        #         del state_dict[key]

        if os.path.splitext(file)[1] == ".safetensors":
            save_file(state_dict, file, metadata)
        else:
            torch.save(state_dict, file)

    def set_lora_slider(self, scale):
        self.lora_scale = scale

    def __enter__(self):
        for lora in self.unet_loras:
            lora.multiplier = 1.0 * self.lora_scale

    def __exit__(self, exc_type, exc_value, tb):
        for lora in self.unet_loras:
            lora.multiplier = 0
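A hedged sketch (not part of this commit) of how these classes are typically wired up: attach slider LoRA adapters to an SDXL UNet, hand the trainable parameters to an optimizer, and gate the adapters with the slider scale through the context-manager interface. The SDXL model id is illustrative, and fast_init=True is used so the example does not depend on the hard-coded orthogonal-basis cache path in load_ortho_dict.

import torch
from diffusers import UNet2DConditionModel
from utils.lora import LoRANetwork

# Illustrative base model; any diffusers UNet2DConditionModel works the same way
unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
)
network = LoRANetwork(unet, rank=4, alpha=1.0, train_method="xattn", fast_init=True)

# Only the lora_down weights are trainable for non-"full" training methods
optimizer = torch.optim.AdamW(network.prepare_optimizer_params(), lr=1e-4)

# At inference time the slider scale controls the LoRA multiplier inside the block
network.set_lora_slider(2.0)
with network:
    pass  # run unet(...) here; adapters are active with multiplier = 1.0 * 2.0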
utils/model_util.py
ADDED
@@ -0,0 +1,527 @@
from typing import Literal, Union, Optional

import torch, gc, os
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection, T5TokenizerFast
from transformers import (
    AutoModel,
    CLIPModel,
    CLIPProcessor,
)
from huggingface_hub import hf_hub_download
from diffusers import (
    UNet2DConditionModel,
    SchedulerMixin,
    StableDiffusionPipeline,
    StableDiffusionXLPipeline,
    FluxPipeline,
    AutoencoderKL,
    FluxTransformer2DModel,
)
import copy
from diffusers.schedulers import (
    DDIMScheduler,
    DDPMScheduler,
    LMSDiscreteScheduler,
    EulerAncestralDiscreteScheduler,
    FlowMatchEulerDiscreteScheduler,
)
from diffusers import LCMScheduler, AutoencoderTiny
import sys
sys.path.append('.')
from .flux_utils import *

TOKENIZER_V1_MODEL_NAME = "CompVis/stable-diffusion-v1-4"
TOKENIZER_V2_MODEL_NAME = "stabilityai/stable-diffusion-2-1"

AVAILABLE_SCHEDULERS = Literal["ddim", "ddpm", "lms", "euler_a"]

SDXL_TEXT_ENCODER_TYPE = Union[CLIPTextModel, CLIPTextModelWithProjection]

DIFFUSERS_CACHE_DIR = None  # if you want to change the cache dir, change this


def load_diffusers_model(
    pretrained_model_name_or_path: str,
    v2: bool = False,
    clip_skip: Optional[int] = None,
    weight_dtype: torch.dtype = torch.float32,
) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel,]:
    # the VAE is not needed

    if v2:
        tokenizer = CLIPTokenizer.from_pretrained(
            TOKENIZER_V2_MODEL_NAME,
            subfolder="tokenizer",
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
        )
        text_encoder = CLIPTextModel.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="text_encoder",
            # default is clip skip 2
            num_hidden_layers=24 - (clip_skip - 1) if clip_skip is not None else 23,
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
        )
    else:
        tokenizer = CLIPTokenizer.from_pretrained(
            TOKENIZER_V1_MODEL_NAME,
            subfolder="tokenizer",
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
        )
        text_encoder = CLIPTextModel.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="text_encoder",
            num_hidden_layers=12 - (clip_skip - 1) if clip_skip is not None else 12,
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
        )

    unet = UNet2DConditionModel.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="unet",
        torch_dtype=weight_dtype,
        cache_dir=DIFFUSERS_CACHE_DIR,
    )

    return tokenizer, text_encoder, unet


def load_checkpoint_model(
    checkpoint_path: str,
    v2: bool = False,
    clip_skip: Optional[int] = None,
    weight_dtype: torch.dtype = torch.float32,
) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel,]:
    pipe = StableDiffusionPipeline.from_ckpt(
        checkpoint_path,
        upcast_attention=True if v2 else False,
        torch_dtype=weight_dtype,
        cache_dir=DIFFUSERS_CACHE_DIR,
    )

    unet = pipe.unet
    tokenizer = pipe.tokenizer
    text_encoder = pipe.text_encoder
    if clip_skip is not None:
        if v2:
            text_encoder.config.num_hidden_layers = 24 - (clip_skip - 1)
        else:
            text_encoder.config.num_hidden_layers = 12 - (clip_skip - 1)

    del pipe

    return tokenizer, text_encoder, unet


def load_models(
    pretrained_model_name_or_path: str,
    scheduler_name: AVAILABLE_SCHEDULERS,
    v2: bool = False,
    v_pred: bool = False,
    weight_dtype: torch.dtype = torch.float32,
) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel, SchedulerMixin,]:
    if pretrained_model_name_or_path.endswith(
        ".ckpt"
    ) or pretrained_model_name_or_path.endswith(".safetensors"):
        tokenizer, text_encoder, unet = load_checkpoint_model(
            pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype
        )
    else:  # diffusers
        tokenizer, text_encoder, unet = load_diffusers_model(
            pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype
        )

    # the VAE is not needed

    scheduler = create_noise_scheduler(
        scheduler_name,
        prediction_type="v_prediction" if v_pred else "epsilon",
    )

    return tokenizer, text_encoder, unet, scheduler


def load_diffusers_model_xl(
    pretrained_model_name_or_path: str,
    weight_dtype: torch.dtype = torch.float32,
) -> tuple[list[CLIPTokenizer], list[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel,]:
    # returns tokenizer, tokenizer_2, text_encoder, text_encoder_2, unet

    tokenizers = [
        CLIPTokenizer.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="tokenizer",
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
        ),
        CLIPTokenizer.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="tokenizer_2",
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
            pad_token_id=0,  # same as open clip
        ),
    ]

    text_encoders = [
        CLIPTextModel.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="text_encoder",
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
        ),
        CLIPTextModelWithProjection.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="text_encoder_2",
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
        ),
    ]

    unet = UNet2DConditionModel.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="unet",
        torch_dtype=weight_dtype,
        cache_dir=DIFFUSERS_CACHE_DIR,
    )

    return tokenizers, text_encoders, unet


def load_checkpoint_model_xl(
    checkpoint_path: str,
    weight_dtype: torch.dtype = torch.float32,
) -> tuple[list[CLIPTokenizer], list[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel,]:
    pipe = StableDiffusionXLPipeline.from_single_file(
        checkpoint_path,
        torch_dtype=weight_dtype,
        cache_dir=DIFFUSERS_CACHE_DIR,
    )

    unet = pipe.unet
    tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
    text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
    if len(text_encoders) == 2:
        text_encoders[1].pad_token_id = 0

    del pipe

    return tokenizers, text_encoders, unet


def load_models_xl_(
    pretrained_model_name_or_path: str,
    scheduler_name: AVAILABLE_SCHEDULERS,
    weight_dtype: torch.dtype = torch.float32,
) -> tuple[
    list[CLIPTokenizer],
    list[SDXL_TEXT_ENCODER_TYPE],
    UNet2DConditionModel,
    SchedulerMixin,
]:
    if pretrained_model_name_or_path.endswith(
        ".ckpt"
    ) or pretrained_model_name_or_path.endswith(".safetensors"):
        (
            tokenizers,
            text_encoders,
            unet,
        ) = load_checkpoint_model_xl(pretrained_model_name_or_path, weight_dtype)
    else:  # diffusers
        (
            tokenizers,
            text_encoders,
            unet,
        ) = load_diffusers_model_xl(pretrained_model_name_or_path, weight_dtype)

    scheduler = create_noise_scheduler(scheduler_name)

    return tokenizers, text_encoders, unet, scheduler


def create_noise_scheduler(
    scheduler_name: AVAILABLE_SCHEDULERS = "ddpm",
    prediction_type: Literal["epsilon", "v_prediction"] = "epsilon",
) -> SchedulerMixin:
    # Honestly, not sure which one is best. The original implementation let you pick DDIM, DDPM, or LMS, but it is unclear which is preferable.

    name = scheduler_name.lower().replace(" ", "_")
    if name == "ddim":
        # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddim
        scheduler = DDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            num_train_timesteps=1000,
            clip_sample=False,
            prediction_type=prediction_type,  # is this right?
        )
    elif name == "ddpm":
        # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddpm
        scheduler = DDPMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            num_train_timesteps=1000,
            clip_sample=False,
            prediction_type=prediction_type,
        )
    elif name == "lms":
        # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/lms_discrete
        scheduler = LMSDiscreteScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            num_train_timesteps=1000,
            prediction_type=prediction_type,
        )
    elif name == "euler_a":
        # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/euler_ancestral
        scheduler = EulerAncestralDiscreteScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            num_train_timesteps=1000,
            # clip_sample=False,
            prediction_type=prediction_type,
        )
    else:
        raise ValueError(f"Unknown scheduler name: {name}")

    return scheduler


def load_models_xl(params):
    """
    Load all required models for training

    Args:
        params: Dictionary containing model parameters and configurations

    Returns:
        dict: Dictionary containing all loaded models and tokenizers
    """
    device = params['device']
    weight_dtype = params['weight_dtype']

    # Load SDXL components (UNet, text encoders, tokenizers)
    scheduler_name = 'ddim'
    tokenizers, text_encoders, unet, noise_scheduler = load_models_xl_(
        params['pretrained_model_name_or_path'],
        scheduler_name=scheduler_name,
    )

    # Move text encoders to device and set to eval mode
    for text_encoder in text_encoders:
        text_encoder.to(device, dtype=weight_dtype)
        text_encoder.requires_grad_(False)
        text_encoder.eval()

    # Set up UNet
    unet.to(device, dtype=weight_dtype)
    unet.requires_grad_(False)
    unet.eval()

    # Load tiny VAE for efficiency
    vae = AutoencoderTiny.from_pretrained(
        "madebyollin/taesdxl",
        torch_dtype=weight_dtype
    )
    vae = vae.to(device, dtype=weight_dtype)
    vae.requires_grad_(False)

    # Load appropriate encoder (CLIP or DINOv2)
    if params['encoder'] == 'dinov2-small':
        clip_model = AutoModel.from_pretrained(
            'facebook/dinov2-small',
            torch_dtype=weight_dtype
        )
        clip_processor = None
    else:
        clip_model = CLIPModel.from_pretrained(
            "wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M",
            torch_dtype=weight_dtype
        )
        clip_processor = CLIPProcessor.from_pretrained("wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M")
    clip_model = clip_model.to(device, dtype=weight_dtype)
    clip_model.requires_grad_(False)

    # If using a DMD checkpoint, load it
    if params['distilled'] != 'None':
        if '.safetensors' in params['distilled']:
            unet.load_state_dict(load_file(params['distilled'], device=device))
        elif 'dmd2' in params['distilled']:
            repo_name = "tianweiy/DMD2"
            ckpt_name = "dmd2_sdxl_4step_unet_fp16.bin"
            unet.load_state_dict(torch.load(hf_hub_download(repo_name, ckpt_name)))
        else:
            unet.load_state_dict(torch.load(params['distilled']))

        # Set up LCM scheduler for DMD
        noise_scheduler = LCMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            num_train_timesteps=1000,
            prediction_type="epsilon",
            original_inference_steps=1000
        )

    noise_scheduler.set_timesteps(params['max_denoising_steps'])
    pipe = StableDiffusionXLPipeline(vae=vae,
                                     text_encoder=text_encoders[0],
                                     text_encoder_2=text_encoders[1],
                                     tokenizer=tokenizers[0],
                                     tokenizer_2=tokenizers[1],
                                     unet=unet,
                                     scheduler=noise_scheduler)
    pipe.set_progress_bar_config(disable=True)
    return {
        'unet': unet,
        'vae': vae,
        'clip_model': clip_model,
        'clip_processor': clip_processor,
        'tokenizers': tokenizers,
        'text_encoders': text_encoders,
        'noise_scheduler': noise_scheduler
    }, pipe


def load_models_flux(params):
    # Load the tokenizers
    tokenizer_one = CLIPTokenizer.from_pretrained(
        params['pretrained_model_name_or_path'],
        subfolder="tokenizer",
        torch_dtype=params['weight_dtype'], device_map=params['device']
    )
    tokenizer_two = T5TokenizerFast.from_pretrained(
        params['pretrained_model_name_or_path'],
        subfolder="tokenizer_2",
        torch_dtype=params['weight_dtype'], device_map=params['device']
    )

    # Load scheduler and models
    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
        params['pretrained_model_name_or_path'],
        subfolder="scheduler",
        torch_dtype=params['weight_dtype'], device=params['device']
    )
    noise_scheduler_copy = copy.deepcopy(noise_scheduler)

    # import correct text encoder classes
    text_encoder_cls_one = import_model_class_from_model_name_or_path(
        params['pretrained_model_name_or_path'],
    )
    text_encoder_cls_two = import_model_class_from_model_name_or_path(
        params['pretrained_model_name_or_path'], subfolder="text_encoder_2"
    )
    # Load the text encoders
    text_encoder_one, text_encoder_two = load_text_encoders(params['pretrained_model_name_or_path'], text_encoder_cls_one, text_encoder_cls_two, params['weight_dtype'])

    # Load VAE
    vae = AutoencoderKL.from_pretrained(
        params['pretrained_model_name_or_path'],
        subfolder="vae",
        torch_dtype=params['weight_dtype'], device_map='auto'
    )
    transformer = FluxTransformer2DModel.from_pretrained(
        params['pretrained_model_name_or_path'],
        subfolder="transformer",
        torch_dtype=params['weight_dtype']
    )

    # We only train the additional adapter LoRA layers
    transformer.requires_grad_(False)
    vae.requires_grad_(False)
    text_encoder_one.requires_grad_(False)
    text_encoder_two.requires_grad_(False)

    vae.to(params['device'])
    transformer.to(params['device'])
    text_encoder_one.to(params['device'])
    text_encoder_two.to(params['device'])

    # Load appropriate encoder (CLIP or DINOv2)
    if params['encoder'] == 'dinov2-small':
        clip_model = AutoModel.from_pretrained(
            'facebook/dinov2-small',
            torch_dtype=params['weight_dtype']
        )
        clip_processor = None
    else:
        clip_model = CLIPModel.from_pretrained(
            "wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M",
            torch_dtype=params['weight_dtype']
        )
        clip_processor = CLIPProcessor.from_pretrained("wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M")
    clip_model = clip_model.to(params['device'], dtype=params['weight_dtype'])
    clip_model.requires_grad_(False)

    pipe = FluxPipeline(noise_scheduler,
                        vae,
                        text_encoder_one,
                        tokenizer_one,
                        text_encoder_two,
                        tokenizer_two,
                        transformer,
                        )
    pipe.set_progress_bar_config(disable=True)

    return {
        'transformer': transformer,
        'vae': vae,
        'clip_model': clip_model,
        'clip_processor': clip_processor,
        'tokenizers': [tokenizer_one, tokenizer_two],
        'text_encoders': [text_encoder_one, text_encoder_two],
        'noise_scheduler': noise_scheduler
    }, pipe


def save_checkpoint(networks, save_path, weight_dtype):
    """
    Save network weights and perform cleanup

    Args:
        networks: Dictionary of LoRA networks to save
        save_path: Path to save the checkpoints
        weight_dtype: Data type for the weights
    """
    print("Saving checkpoint...")

    try:
        # Create save directory if it doesn't exist
        os.makedirs(save_path, exist_ok=True)

        # Save each network's weights
        for net_idx, network in networks.items():
            save_name = f"{save_path}/slider_{net_idx}.pt"
            try:
                network.save_weights(
                    save_name,
                    dtype=weight_dtype,
                )
            except Exception as e:
                print(f"Error saving network {net_idx}: {str(e)}")
                continue

        # Cleanup
        torch.cuda.empty_cache()
        gc.collect()

        print("Checkpoint saved successfully.")

    except Exception as e:
        print(f"Error during checkpoint saving: {str(e)}")

    finally:
        # Ensure memory is cleaned up even if save fails
        torch.cuda.empty_cache()
        gc.collect()
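A hedged end-to-end sketch (not part of this commit) of load_models_xl with the DMD2 fast path: the parameter keys mirror the ones the function reads, while the SDXL model id and the concrete values are illustrative assumptions.

import torch
from utils.model_util import load_models_xl

params = {
    'pretrained_model_name_or_path': 'stabilityai/stable-diffusion-xl-base-1.0',  # illustrative
    'device': 'cuda:0',
    'weight_dtype': torch.float16,
    'encoder': 'clip',        # or 'dinov2-small'
    'distilled': 'dmd2',      # 'None' keeps the plain DDIM schedule
    'max_denoising_steps': 4,
}
models, pipe = load_models_xl(params)
image = pipe(
    "a cozy cabin in the woods, detailed, 8k",
    num_inference_steps=params['max_denoising_steps'],
    guidance_scale=0.0,  # distilled models are typically run without CFG
).images[0]
image.save("sample.png")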
utils/prompt_util.py
ADDED
@@ -0,0 +1,80 @@
import anthropic
client = anthropic.Anthropic()
from typing import List, Optional

def claude_generate_prompts_sliders(prompt,
                                    num_prompts=20,
                                    temperature=0.2,
                                    max_tokens=2000,
                                    frequency_penalty=0.0,
                                    model="claude-3-5-sonnet-20240620",
                                    verbose=False):
    assistant_prompt = f''' You are an expert in writing diverse image captions. When i provide a prompt, I want you to give me {num_prompts} alternative prompts that is similar to the provided prompt but produces diverse images. Be creative and make sure the original subjects in the original prompt are present in your prompts. Make sure that you end the prompts with keywords that will produce high quality images like ",detailed, 8k" or ",hyper-realistic, 4k".

    Give me the expanded prompts in the style of a list. start with a [ and end with ] do not add any special characters like \n
    I need you to give me only the python list and nothing else. Do not explain yourself

    example output format:
    ["prompt1", "prompt2", ...]
    '''

    user_prompt = prompt

    message = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": user_prompt
                }
            ]
        }
    ]

    output = client.messages.create(
        model=model,
        max_tokens=max_tokens,
        temperature=temperature,
        system=assistant_prompt,
        messages=message
    )
    content = output.content[0].text
    return content


def expand_prompts(concept_prompts: List[str], diverse_prompt_num: int, args) -> List[str]:
    """
    Expand the input prompts using Claude if requested.

    Args:
        concept_prompts: Initial list of prompts
        diverse_prompt_num: Number of variations to generate per prompt
        args: Training arguments

    Returns:
        List of expanded prompts
    """
    diverse_prompts = []

    if diverse_prompt_num != 0:
        for prompt in concept_prompts:
            try:
                claude_generated_prompts = claude_generate_prompts_sliders(
                    prompt=prompt,
                    num_prompts=diverse_prompt_num,
                    temperature=0.2,
                    max_tokens=8000,
                    frequency_penalty=0.0,
                    model="claude-3-5-sonnet-20240620",
                    verbose=False
                )
                diverse_prompts.extend(eval(claude_generated_prompts))
            except Exception as e:
                print(f"Error with Claude response: {e}")
                diverse_prompts.append(prompt)
    else:
        diverse_prompts = concept_prompts

    print(f"Using prompts: {diverse_prompts}")
    return diverse_prompts
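Side note on expand_prompts: it parses Claude's reply with eval. A stricter way to parse the same ["prompt1", "prompt2", ...] format is ast.literal_eval, which rejects arbitrary expressions. The helper below is only a sketch of that alternative (the function name is made up; it is not part of this commit):

import ast

def parse_prompt_list(raw: str) -> list:
    """Parse a '["prompt1", "prompt2", ...]' string into a list of prompt strings."""
    parsed = ast.literal_eval(raw.strip())
    if not isinstance(parsed, list) or not all(isinstance(p, str) for p in parsed):
        raise ValueError(f"Expected a list of strings, got: {parsed!r}")
    return parsed

print(parse_prompt_list('["a red car, detailed, 8k", "a blue car, hyper-realistic, 4k"]'))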
utils/train_util.py
ADDED
@@ -0,0 +1,722 @@
from typing import Optional, Union

import torch

from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import UNet2DConditionModel, SchedulerMixin, FluxImg2ImgPipeline
from diffusers.image_processor import VaeImageProcessor
# from model_util import SDXL_TEXT_ENCODER_TYPE
from diffusers.utils.torch_utils import randn_tensor

from tqdm import tqdm

UNET_IN_CHANNELS = 4  # in_channels is fixed at 4 for Stable Diffusion; same for XL.
VAE_SCALE_FACTOR = 8  # 2 ** (len(vae.config.block_out_channels) - 1) = 8

UNET_ATTENTION_TIME_EMBED_DIM = 256  # XL
TEXT_ENCODER_2_PROJECTION_DIM = 1280
UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM = 2816


def get_random_noise(
    batch_size: int, height: int, width: int, generator: torch.Generator = None
) -> torch.Tensor:
    return torch.randn(
        (
            batch_size,
            UNET_IN_CHANNELS,
            height // VAE_SCALE_FACTOR,  # not sure the height/width order is right, but either way it causes no real problem
            width // VAE_SCALE_FACTOR,
        ),
        generator=generator,
        device="cpu",
    )


# https://www.crosslabs.org/blog/diffusion-with-offset-noise
def apply_noise_offset(latents: torch.FloatTensor, noise_offset: float):
    latents = latents + noise_offset * torch.randn(
        (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
    )
    return latents


def get_initial_latents(
    scheduler: SchedulerMixin,
    n_imgs: int,
    height: int,
    width: int,
    n_prompts: int,
    generator=None,
) -> torch.Tensor:
    noise = get_random_noise(n_imgs, height, width, generator=generator).repeat(
        n_prompts, 1, 1, 1
    )

    latents = noise * scheduler.init_noise_sigma

    return latents


def text_tokenize(
    tokenizer: CLIPTokenizer,  # one tokenizer normally, two for XL!
    prompts: list[str],
):
    return tokenizer(
        prompts,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    ).input_ids


def text_encode(text_encoder: CLIPTextModel, tokens):
    return text_encoder(tokens.to(text_encoder.device))[0]


def encode_prompts(
    tokenizer: CLIPTokenizer,
    text_encoder: CLIPTextModel,
    prompts: list[str],
):

    text_tokens = text_tokenize(tokenizer, prompts)
    text_embeddings = text_encode(text_encoder, text_tokens)


    return text_embeddings


# https://github.com/huggingface/diffusers/blob/78922ed7c7e66c20aa95159c7b7a6057ba7d590d/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L334-L348
def text_encode_xl(
    text_encoder,
    tokens: torch.FloatTensor,
    num_images_per_prompt: int = 1,
):
    prompt_embeds = text_encoder(
        tokens.to(text_encoder.device), output_hidden_states=True
    )
    pooled_prompt_embeds = prompt_embeds[0]
    prompt_embeds = prompt_embeds.hidden_states[-2]  # always penultimate layer

    bs_embed, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    return prompt_embeds, pooled_prompt_embeds


def encode_prompts_xl(
    tokenizers,
    text_encoders,
    prompts: list[str],
    num_images_per_prompt: int = 1,
) -> tuple[torch.FloatTensor, torch.FloatTensor]:
    # text_encoder and text_encoder_2's penultimate layer's output
    text_embeds_list = []
    pooled_text_embeds = None  # always text_encoder_2's pool

    for tokenizer, text_encoder in zip(tokenizers, text_encoders):
        text_tokens_input_ids = text_tokenize(tokenizer, prompts)
        text_embeds, pooled_text_embeds = text_encode_xl(
            text_encoder, text_tokens_input_ids, num_images_per_prompt
        )

        text_embeds_list.append(text_embeds)

    bs_embed = pooled_text_embeds.shape[0]
    pooled_text_embeds = pooled_text_embeds.repeat(1, num_images_per_prompt).view(
        bs_embed * num_images_per_prompt, -1
    )

    return torch.concat(text_embeds_list, dim=-1), pooled_text_embeds


def concat_embeddings(
    unconditional: torch.FloatTensor,
    conditional: torch.FloatTensor,
    n_imgs: int,
):
    return torch.cat([unconditional, conditional]).repeat_interleave(n_imgs, dim=0)


# ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L721
def predict_noise(
    unet: UNet2DConditionModel,
    scheduler: SchedulerMixin,
    timestep: int,  # current timestep
    latents: torch.FloatTensor,
    text_embeddings: torch.FloatTensor,  # unconditional and conditional text embeds concatenated
    guidance_scale=7.5,
) -> torch.FloatTensor:
    latent_model_input = latents
    if guidance_scale != 0:
        # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
        latent_model_input = torch.cat([latents] * 2)

    latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)

    # predict the noise residual
    noise_pred = unet(
        latent_model_input,
        timestep,
        encoder_hidden_states=text_embeddings,
    ).sample

    # perform guidance
    if guidance_scale != 1 and guidance_scale != 0:
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (
            noise_pred_text - noise_pred_uncond
        )

    return noise_pred


# ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L746
@torch.no_grad()
def diffusion(
    unet: UNet2DConditionModel,
    scheduler: SchedulerMixin,
    latents: torch.FloatTensor,  # pure-noise latents
    text_embeddings: torch.FloatTensor,
    total_timesteps: int = 1000,
    start_timesteps=0,
    guidance_scale=1,
    composition=False,
    **kwargs,
):
    # latents_steps = []

    for timestep in scheduler.timesteps[start_timesteps:total_timesteps]:
        if not composition:
            noise_pred = predict_noise(
                unet, scheduler, timestep, latents, text_embeddings, guidance_scale=guidance_scale
            )
            if guidance_scale == 1:
                _, noise_pred = noise_pred.chunk(2)
        else:
            for idx in range(text_embeddings.shape[0]):
                pred = predict_noise(
                    unet, scheduler, timestep, latents, text_embeddings[idx:idx+1], guidance_scale=1
                )
                uncond, pred = pred.chunk(2)
                if idx == 0:
                    noise_pred = guidance_scale * pred
                else:
                    noise_pred += guidance_scale * pred
            noise_pred += uncond

        # compute the previous noisy sample x_t -> x_t-1
        latents = scheduler.step(noise_pred, timestep, latents).prev_sample

    # return latents_steps
    return latents


def rescale_noise_cfg(
    noise_cfg: torch.FloatTensor, noise_pred_text, guidance_rescale=0.0
):
    """
    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
    """
    std_text = noise_pred_text.std(
        dim=list(range(1, noise_pred_text.ndim)), keepdim=True
    )
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # rescale the results from guidance (fixes overexposure)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
    noise_cfg = (
        guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
    )

    return noise_cfg


def predict_noise_xl(
    unet: UNet2DConditionModel,
    scheduler: SchedulerMixin,
    timestep: int,  # current timestep
    latents: torch.FloatTensor,
    text_embeddings: torch.FloatTensor,  # unconditional and conditional text embeds concatenated
    add_text_embeddings: torch.FloatTensor,  # the pooled text embeds
    add_time_ids: torch.FloatTensor,
    guidance_scale=7.5,
    guidance_rescale=0.7,
) -> torch.FloatTensor:
    # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
    latent_model_input = latents
    if guidance_scale != 0:
        latent_model_input = torch.cat([latents] * 2)

    latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)

    added_cond_kwargs = {
        "text_embeds": add_text_embeddings,
        "time_ids": add_time_ids,
    }

    # predict the noise residual
    noise_pred = unet(
        latent_model_input,
        timestep,
        encoder_hidden_states=text_embeddings,
        added_cond_kwargs=added_cond_kwargs,
    ).sample
    # perform guidance
    if guidance_scale != 1 and guidance_scale != 0:
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (
            noise_pred_text - noise_pred_uncond
        )

    return noise_pred
    # # perform guidance
    # noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    # guided_target = noise_pred_uncond + guidance_scale * (
    #     noise_pred_text - noise_pred_uncond
    # )

    # # https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775
    # noise_pred = rescale_noise_cfg(
    #     noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
    # )

    # return guided_target


@torch.no_grad()
def diffusion_xl(
    unet: UNet2DConditionModel,
    scheduler: SchedulerMixin,
    latents: torch.FloatTensor,  # pure-noise latents
    text_embeddings: tuple[torch.FloatTensor, torch.FloatTensor],
    add_text_embeddings: torch.FloatTensor,  # the pooled text embeds
    add_time_ids: torch.FloatTensor,
    guidance_scale: float = 1.0,
    total_timesteps: int = 1000,
    start_timesteps=0,
    composition=False,
):
    # latents_steps = []

    for timestep in scheduler.timesteps[start_timesteps:total_timesteps]:
        if not composition:
            noise_pred = predict_noise_xl(
                unet,
                scheduler,
                timestep,
                latents,
                text_embeddings,
                add_text_embeddings,
                add_time_ids,
                guidance_scale=guidance_scale,
                guidance_rescale=0.7,
            )
            if guidance_scale == 1:
                _, noise_pred = noise_pred.chunk(2)
        # compute the previous noisy sample x_t -> x_t-1
        latents = scheduler.step(noise_pred, timestep, latents).prev_sample

    # return latents_steps
    return latents


# for XL
def get_add_time_ids(
    height: int,
    width: int,
    dynamic_crops: bool = False,
    dtype: torch.dtype = torch.float32,
):
    if dynamic_crops:
        # random float scale between 1 and 3
        random_scale = torch.rand(1).item() * 2 + 1
        original_size = (int(height * random_scale), int(width * random_scale))
        # random position
        crops_coords_top_left = (
            torch.randint(0, original_size[0] - height, (1,)).item(),
            torch.randint(0, original_size[1] - width, (1,)).item(),
        )
        target_size = (height, width)
    else:
        original_size = (height, width)
        crops_coords_top_left = (0, 0)
        target_size = (height, width)

    # this is expected as 6
    add_time_ids = list(original_size + crops_coords_top_left + target_size)

    # this is expected as 2816
    passed_add_embed_dim = (
        UNET_ATTENTION_TIME_EMBED_DIM * len(add_time_ids)  # 256 * 6
        + TEXT_ENCODER_2_PROJECTION_DIM  # + 1280
    )
    if passed_add_embed_dim != UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM:
        raise ValueError(
            f"Model expects an added time embedding vector of length {UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
        )

    add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
    return add_time_ids


def get_optimizer(name: str):
    name = name.lower()

    if name.startswith("dadapt"):
        import dadaptation

        if name == "dadaptadam":
            return dadaptation.DAdaptAdam
        elif name == "dadaptlion":
            return dadaptation.DAdaptLion
        else:
            raise ValueError("DAdapt optimizer must be dadaptadam or dadaptlion")

    elif name.endswith("8bit"):  # not tested
        import bitsandbytes as bnb

        if name == "adam8bit":
            return bnb.optim.Adam8bit
        elif name == "lion8bit":
            return bnb.optim.Lion8bit
        else:
            raise ValueError("8bit optimizer must be adam8bit or lion8bit")

    else:
        if name == "adam":
            return torch.optim.Adam
        elif name == "adamw":
            return torch.optim.AdamW
        elif name == "lion":
            from lion_pytorch import Lion

            return Lion
        elif name == "prodigy":
            import prodigyopt

            return prodigyopt.Prodigy
        else:
            raise ValueError("Optimizer must be adam, adamw, lion or Prodigy")

@torch.no_grad()
def get_noisy_image(
    image,
    vae,
    unet,
    scheduler,
    timesteps_to = 1000,
    generator=None,
    **kwargs,
):
    # latents_steps = []
    vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
    image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)

    device = vae.device
    image = image_processor.preprocess(image).to(device).to(vae.dtype)

    init_latents = vae.encode(image).latents

    init_latents = vae.config.scaling_factor * init_latents

    init_latents = torch.cat([init_latents], dim=0)

    shape = init_latents.shape

    noise = randn_tensor(shape, generator=generator, device=device)

    timestep = scheduler.timesteps[timesteps_to:timesteps_to+1]
    # get latents
    init_latents = scheduler.add_noise(init_latents, noise, timestep)

    return init_latents, noise


def get_lr_scheduler(
    name: Optional[str],
    optimizer: torch.optim.Optimizer,
    max_iterations: Optional[int],
    lr_min: Optional[float],
    **kwargs,
):
    if name == "cosine":
        return torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=max_iterations, eta_min=lr_min, **kwargs
        )
    elif name == "cosine_with_restarts":
        return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=max_iterations // 10, T_mult=2, eta_min=lr_min, **kwargs
        )
    elif name == "step":
        return torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=max_iterations // 100, gamma=0.999, **kwargs
        )
    elif name == "constant":
        return torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1, **kwargs)
    elif name == "linear":
        return torch.optim.lr_scheduler.LinearLR(
            optimizer, factor=0.5, total_iters=max_iterations // 100, **kwargs
        )
    else:
        raise ValueError(
            "Scheduler must be cosine, cosine_with_restarts, step, linear or constant"
        )


def get_random_resolution_in_bucket(bucket_resolution: int = 512) -> tuple[int, int]:
    max_resolution = bucket_resolution
    min_resolution = bucket_resolution // 2

    step = 64

    min_step = min_resolution // step
    max_step = max_resolution // step

    height = torch.randint(min_step, max_step, (1,)).item() * step
    width = torch.randint(min_step, max_step, (1,)).item() * step

    return height, width



def _get_t5_prompt_embeds(
    text_encoder,
    tokenizer,
    prompt,
    max_sequence_length=512,
    device=None,
    dtype=None
):
    """Helper function to get T5 embeddings in Flux format"""
    device = device or text_encoder.device
    dtype = dtype or text_encoder.dtype

    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        return_length=False,
        return_overflowing_tokens=False,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids

    prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)[0]
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    return prompt_embeds

def _get_clip_prompt_embeds(
    text_encoder,
    tokenizer,
    prompt,
    device=None,
):
    """Helper function to get CLIP embeddings in Flux format"""
    device = device or text_encoder.device

    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_overflowing_tokens=False,
        return_length=False,
        return_tensors="pt",
    )

    text_input_ids = text_inputs.input_ids
    prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)

    # Use pooled output for Flux
    prompt_embeds = prompt_embeds.pooler_output
    prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)

    return prompt_embeds




@torch.no_grad()
def get_noisy_image_flux(
    image,
    vae,
    transformer,
    scheduler,
    timesteps_to=1000,
    generator=None,
    params = None
):
    """
    Gets noisy latents for a given image using Flux pipeline approach.

    Args:
        image (Union[PIL.Image.Image, torch.Tensor]): Input image
        vae (AutoencoderKL): Flux VAE model
        transformer (FluxTransformer2DModel): Flux transformer model
        scheduler (FlowMatchEulerDiscreteScheduler): Flux noise scheduler
        timesteps_to (int, optional): Target timestep. Defaults to 1000.
        generator (torch.Generator, optional): Random generator for reproducibility.

    Returns:
        tuple: (noisy_latents, noise) - Both in packed Flux format
    """

    vae_scale_factor = params['vae_scale_factor']
    image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor * 2)

    image = image_processor.preprocess(image, height=params['height'], width=params['width'])
    image = image.to(dtype=torch.float32)

    # 5. Prepare latent variables
    num_channels_latents = transformer.config.in_channels // 4

    latents, latent_image_ids = prepare_latents_flux(
        image,
        timesteps_to.repeat(params['batchsize']),
        params['batchsize'],
        num_channels_latents,
        params['height'],
        params['width'],
        transformer.dtype,
        transformer.device,
        generator,
        None,
        vae_scale_factor,
        vae,
        scheduler
    )

    return latents, latent_image_ids


def _pack_latents(latents, batch_size, num_channels_latents, height, width):
    """
    Pack latents into Flux's 2x2 patch format
    """
    latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
    return latents


def _unpack_latents(latents, height, width, vae_scale_factor):
    """
    Unpack latents from Flux's 2x2 patch format back to image space
    """
    batch_size, num_patches, channels = latents.shape

    # Account for VAE compression and packing
    height = 2 * (int(height) // (vae_scale_factor * 2))
    width = 2 * (int(width) // (vae_scale_factor * 2))

    latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
    latents = latents.permute(0, 3, 1, 4, 2, 5)
    latents = latents.reshape(batch_size, channels // (2 * 2), height, width)

    return latents

def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
    latent_image_ids = torch.zeros(height, width, 3)
    latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
    latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]

    latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

    latent_image_ids = latent_image_ids.reshape(
        latent_image_id_height * latent_image_id_width, latent_image_id_channels
    )

    return latent_image_ids.to(device=device, dtype=dtype)


def prepare_latents_flux(
    image,
    timestep,
    batch_size,
    num_channels_latents,
    height,
    width,
    dtype,
    device,
    generator,
    latents=None,
    vae_scale_factor=None,
    vae=None,
    scheduler=None
):
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    # VAE applies 8x compression on images but we must also account for packing which requires
    # latent height and width to be divisible by 2.
    height = 2 * (int(height) // (vae_scale_factor * 2))
    width = 2 * (int(width) // (vae_scale_factor * 2))
    shape = (batch_size, num_channels_latents, height, width)
    latent_image_ids = _prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)

    if latents is not None:
        return latents.to(device=device, dtype=dtype), latent_image_ids

    image = image.to(device=device, dtype=dtype)
    image_latents = _encode_vae_image(vae=vae, image=image, generator=generator)
    if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
        # expand init_latents for batch_size
        additional_image_per_prompt = batch_size // image_latents.shape[0]
        image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
    elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
        raise ValueError(
            f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
        )
    else:
        image_latents = torch.cat([image_latents], dim=0)

    noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
    latents = scheduler.scale_noise(image_latents, timestep, noise)
    latents = _pack_latents(latents, batch_size, num_channels_latents, height, width)
    return latents, latent_image_ids


def _encode_vae_image(vae, image: torch.Tensor, generator: torch.Generator):
    if isinstance(generator, list):
        image_latents = [
            retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i])
            for i in range(image.shape[0])
        ]
        image_latents = torch.cat(image_latents, dim=0)
    else:
        image_latents = retrieve_latents(vae.encode(image), generator=generator)

    image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor
    return image_latents


def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")
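A small sanity-check sketch for the Flux packing helpers above. The concrete numbers are assumptions for illustration (a 1024x1024 image, an 8x VAE scale factor, 16 latent channels, and the import path from this repo layout); it only shows the packed sequence shape and that _pack_latents / _unpack_latents are inverse reshapes:

import torch
from utils.train_util import _pack_latents, _unpack_latents  # import path assumed

vae_scale_factor = 8                            # assumed Flux VAE downsampling factor
height = width = 1024                           # assumed image size
lat_h = 2 * (height // (vae_scale_factor * 2))  # 128, mirrors prepare_latents_flux
lat_w = 2 * (width // (vae_scale_factor * 2))   # 128
latents = torch.randn(1, 16, lat_h, lat_w)      # 16 latent channels assumed

packed = _pack_latents(latents, 1, 16, lat_h, lat_w)
print(packed.shape)                             # torch.Size([1, 4096, 64]): (h/2 * w/2, channels * 4)

unpacked = _unpack_latents(packed, height, width, vae_scale_factor)
print(torch.equal(unpacked, latents))           # True: packing is a lossless reshape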
utils/utils.py
ADDED
@@ -0,0 +1,945 @@
1 |
+
import anthropic
|
2 |
+
client = anthropic.Anthropic()
|
3 |
+
from diffusers.image_processor import VaeImageProcessor
|
4 |
+
from typing import List, Optional
|
5 |
+
import argparse
|
6 |
+
import ast
|
7 |
+
import pandas as pd
|
8 |
+
from pathlib import Path
|
9 |
+
from diffusers import DiffusionPipeline, UNet2DConditionModel, LCMScheduler, AutoencoderTiny
|
10 |
+
from huggingface_hub import hf_hub_download
|
11 |
+
import gc
|
12 |
+
import torch.nn.functional as F
|
13 |
+
import os
|
14 |
+
import torch
|
15 |
+
from tqdm.auto import tqdm
|
16 |
+
import time, datetime
|
17 |
+
import numpy as np
|
18 |
+
from torch.optim import AdamW
|
19 |
+
from contextlib import ExitStack
|
20 |
+
from safetensors.torch import load_file
|
21 |
+
import torch.nn as nn
|
22 |
+
import random
|
23 |
+
from transformers import CLIPModel
|
24 |
+
|
25 |
+
import sys
|
26 |
+
import argparse
|
27 |
+
import wandb
|
28 |
+
from diffusers import AutoencoderKL
|
29 |
+
from diffusers.image_processor import VaeImageProcessor
|
30 |
+
|
31 |
+
sys.path.append('../')
|
32 |
+
from utils.lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV
|
33 |
+
|
34 |
+
from transformers import logging
|
35 |
+
logging.set_verbosity_warning()
|
36 |
+
import matplotlib.pyplot as plt
|
37 |
+
from diffusers import logging
|
38 |
+
logging.set_verbosity_error()
|
39 |
+
modules = DEFAULT_TARGET_REPLACE
|
40 |
+
modules += UNET_TARGET_REPLACE_MODULE_CONV
|
41 |
+
import torch
|
42 |
+
import torch.nn.functional as F
|
43 |
+
from sklearn.decomposition import PCA
|
44 |
+
import random
|
45 |
+
import gc
|
46 |
+
import diffusers
|
47 |
+
from diffusers import DiffusionPipeline, FluxPipeline
|
48 |
+
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel, LMSDiscreteScheduler, SchedulerMixin
|
49 |
+
from diffusers.loaders import AttnProcsLayers
|
50 |
+
from diffusers.models.attention_processor import LoRAAttnProcessor, AttentionProcessor
|
51 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
52 |
+
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
|
53 |
+
from diffusers.utils.torch_utils import randn_tensor
|
54 |
+
|
55 |
+
import inspect
|
56 |
+
import os
|
57 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
58 |
+
from diffusers.pipelines import StableDiffusionXLPipeline
|
59 |
+
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
60 |
+
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
61 |
+
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import retrieve_timesteps
|
62 |
+
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import XLA_AVAILABLE
|
63 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
64 |
+
|
65 |
+
from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
|
66 |
+
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
67 |
+
|
68 |
+
import sys
|
69 |
+
sys.path.append('../.')
|
70 |
+
from utils.flux_utils import *
|
71 |
+
import random
|
72 |
+
|
73 |
+
import torch
|
74 |
+
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
|
75 |
+
|
76 |
+
|
77 |
+
def flush():
|
78 |
+
torch.cuda.empty_cache()
|
79 |
+
gc.collect()
|
80 |
+
|
81 |
+
def calculate_shift(
|
82 |
+
image_seq_len,
|
83 |
+
base_seq_len: int = 256,
|
84 |
+
max_seq_len: int = 4096,
|
85 |
+
base_shift: float = 0.5,
|
86 |
+
max_shift: float = 1.16,
|
87 |
+
):
|
88 |
+
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
89 |
+
b = base_shift - m * base_seq_len
|
90 |
+
mu = image_seq_len * m + b
|
91 |
+
return mu
|
92 |
+
|
93 |
+
|
94 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
95 |
+
def retrieve_timesteps(
|
96 |
+
scheduler,
|
97 |
+
num_inference_steps: Optional[int] = None,
|
98 |
+
device: Optional[Union[str, torch.device]] = None,
|
99 |
+
timesteps: Optional[List[int]] = None,
|
100 |
+
sigmas: Optional[List[float]] = None,
|
101 |
+
**kwargs,
|
102 |
+
):
|
103 |
+
"""
|
104 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
105 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
106 |
+
|
107 |
+
Args:
|
108 |
+
scheduler (`SchedulerMixin`):
|
109 |
+
The scheduler to get timesteps from.
|
110 |
+
num_inference_steps (`int`):
|
111 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
112 |
+
must be `None`.
|
113 |
+
device (`str` or `torch.device`, *optional*):
|
114 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
115 |
+
timesteps (`List[int]`, *optional*):
|
116 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
117 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
118 |
+
sigmas (`List[float]`, *optional*):
|
119 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
120 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
121 |
+
|
122 |
+
Returns:
|
123 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
124 |
+
second element is the number of inference steps.
|
125 |
+
"""
|
126 |
+
if timesteps is not None and sigmas is not None:
|
127 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
128 |
+
if timesteps is not None:
|
129 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
130 |
+
if not accepts_timesteps:
|
131 |
+
raise ValueError(
|
132 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
133 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
134 |
+
)
|
135 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
136 |
+
timesteps = scheduler.timesteps
|
137 |
+
num_inference_steps = len(timesteps)
|
138 |
+
elif sigmas is not None:
|
139 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
140 |
+
if not accept_sigmas:
|
141 |
+
raise ValueError(
|
142 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
143 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
144 |
+
)
|
145 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
146 |
+
timesteps = scheduler.timesteps
|
147 |
+
num_inference_steps = len(timesteps)
|
148 |
+
else:
|
149 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
150 |
+
timesteps = scheduler.timesteps
|
151 |
+
return timesteps, num_inference_steps
|
152 |
+
|
153 |
+
def claude_generate_prompts_sliders(prompt,
|
154 |
+
num_prompts=20,
|
155 |
+
temperature=0.2,
|
156 |
+
max_tokens=2000,
|
157 |
+
frequency_penalty=0.0,
|
158 |
+
model="claude-3-5-sonnet-20240620",
|
159 |
+
verbose=False,
|
160 |
+
train_type='concept'):
|
161 |
+
gpt_assistant_prompt = f''' You are an expert in writing diverse image captions. When i provide a prompt, I want you to give me {num_prompts} alternative prompts that is similar to the provided prompt but produces diverse images. Be creative and make sure the original subjects in the original prompt are present in your prompts. Make sure that you end the prompts with keywords that will produce high quality images like ",detailed, 8k" or ",hyper-realistic, 4k".
|
162 |
+
|
163 |
+
Give me the expanded prompts in the style of a list. start with a [ and end with ] do not add any special characters like \n
|
164 |
+
I need you to give me only the python list and nothing else. Do not explain yourself
|
165 |
+
|
166 |
+
example output format:
|
167 |
+
["prompt1", "prompt2", ...]
|
168 |
+
'''
|
169 |
+
|
170 |
+
if train_type == 'art':
|
171 |
+
gpt_assistant_prompt = f'''You are an expert in writing art image captions. I want you to generate prompts that would create diverse artwork images.
|
172 |
+
Your role is to give me {num_prompts} diverse prompts that will make the image-generation model to output creative and interesting artwork images with unique and diverse artistic styles. A prompt could like "an <object/landscape> in the style of <an artist>" or "an <object/landscape> in the style of <an artistic style (e.g. cubism)>". make sure that you end the prompts with enhancing keywords like ",detailed, 8k" or ",hyper-realistic, 4k".
|
173 |
+
|
174 |
+
Give me the prompts in the style of a list. start with a [ and end with ] do not add any special characters like \n
|
175 |
+
I need you to give me only the python list and nothing else. Do not explain yourself
|
176 |
+
|
177 |
+
example output format:
|
178 |
+
["prompt1", "prompt2", ...]
|
179 |
+
'''
|
180 |
+
# if 'dog' in prompt:
|
181 |
+
# gpt_assistant_prompt = f'''You are an expert in prompting text-image generation models. I want you to generate simple prompts that would trigger the image generation model to generate a unique dog breeds.
|
182 |
+
# Your role is to give me {num_prompts} diverse prompts that will make the image-generation model to output diverse and interesting dog breeds with unique and diverse looks. make sure that you end the prompts with enhancing keywords like ",detailed, 8k" or ",hyper-realistic, 4k".
|
183 |
+
|
184 |
+
# Be creative and make sure to remember diversity is the key. Give me the prompts in the form of a list. start with a [ and end with ] do not add any special characters like \n
|
185 |
+
# I need you to give me only the python list and nothing else. Do not explain yourself
|
186 |
+
|
187 |
+
# example output format:
|
188 |
+
# ["prompt1", "prompt2", ...]
|
189 |
+
# '''
|
190 |
+
|
191 |
+
if train_type == 'artclaudesemantics':
|
192 |
+
gpt_assistant_prompt = f'''You are an expert in prompting text-image generation models. I want you to generate simple prompts that would trigger the image generation model to generate a unique artistic images but DO NOT SPECIFY THE ART STYLE.
|
193 |
+
Your role is to give me {num_prompts} diverse prompts that will make the image-generation model to output diverse and interesting art images. Usually like "<some object or scene> in the style of " or "<some object or scene> in style of". Always end your prompts with "in the style of" so that i can manually add the style i want. make sure that you end the prompts with enhancing keywords like ",detailed, 8k" or ",hyper-realistic, 4k".
|
194 |
+
|
195 |
+
Be creative and make sure to remember diversity is the key. Give me the prompts in the form of a list. start with a [ and end with ] do not add any special characters like \n
|
196 |
+
I need you to give me only the python list and nothing else. Do not explain yourself
|
197 |
+
|
198 |
+
example output format:
|
199 |
+
["prompt1", "prompt2", ...]
|
200 |
+
'''
|
201 |
+
gpt_user_prompt = prompt
|
202 |
+
gpt_prompt = gpt_assistant_prompt, gpt_user_prompt
|
203 |
+
message=[
|
204 |
+
{
|
205 |
+
"role": "user",
|
206 |
+
"content": [
|
207 |
+
{
|
208 |
+
"type": "text",
|
209 |
+
"text": gpt_user_prompt
|
210 |
+
}
|
211 |
+
]
|
212 |
+
}
|
213 |
+
]
|
214 |
+
|
215 |
+
output = client.messages.create(
|
216 |
+
model=model,
|
217 |
+
max_tokens=max_tokens,
|
218 |
+
temperature=temperature,
|
219 |
+
system=gpt_assistant_prompt,
|
220 |
+
messages=message
|
221 |
+
)
|
222 |
+
content = output.content[0].text
|
223 |
+
return content
|
224 |
+
|
225 |
+
def normalize_image(image):
|
226 |
+
mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1, 1).to(image.device)
|
227 |
+
std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1, 1).to(image.device)
|
228 |
+
return (image - mean) / std
|
229 |
+
|
230 |
+
|
231 |
+
@torch.no_grad()
|
232 |
+
def call_sdxl(
|
233 |
+
self,
|
234 |
+
prompt: Union[str, List[str]] = None,
|
235 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
236 |
+
height: Optional[int] = None,
|
237 |
+
width: Optional[int] = None,
|
238 |
+
num_inference_steps: int = 50,
|
239 |
+
timesteps: List[int] = None,
|
240 |
+
sigmas: List[float] = None,
|
241 |
+
denoising_end: Optional[float] = None,
|
242 |
+
guidance_scale: float = 5.0,
|
243 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
244 |
+
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
245 |
+
num_images_per_prompt: Optional[int] = 1,
|
246 |
+
eta: float = 0.0,
|
247 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
248 |
+
latents: Optional[torch.Tensor] = None,
|
249 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
250 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
251 |
+
pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
252 |
+
negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
253 |
+
ip_adapter_image: Optional[PipelineImageInput] = None,
|
254 |
+
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
|
255 |
+
output_type: Optional[str] = "pil",
|
256 |
+
return_dict: bool = True,
|
257 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
258 |
+
guidance_rescale: float = 0.0,
|
259 |
+
original_size: Optional[Tuple[int, int]] = None,
|
260 |
+
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
261 |
+
target_size: Optional[Tuple[int, int]] = None,
|
262 |
+
negative_original_size: Optional[Tuple[int, int]] = None,
|
263 |
+
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
|
264 |
+
negative_target_size: Optional[Tuple[int, int]] = None,
|
265 |
+
clip_skip: Optional[int] = None,
|
266 |
+
callback_on_step_end: Optional[
|
267 |
+
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
268 |
+
] = None,
|
269 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
270 |
+
save_timesteps = None,
|
271 |
+
clip=None,
|
272 |
+
use_clip=True,
|
273 |
+
encoder='clip',
|
274 |
+
):
|
275 |
+
|
276 |
+
callback = None
|
277 |
+
callback_steps = None
|
278 |
+
|
279 |
+
if callback is not None:
|
280 |
+
deprecate(
|
281 |
+
"callback",
|
282 |
+
"1.0.0",
|
283 |
+
"Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
|
284 |
+
)
|
285 |
+
if callback_steps is not None:
|
286 |
+
deprecate(
|
287 |
+
"callback_steps",
|
288 |
+
"1.0.0",
|
289 |
+
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
|
290 |
+
)
|
291 |
+
|
292 |
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
293 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
294 |
+
|
295 |
+
# 0. Default height and width to unet
|
296 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
297 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
298 |
+
|
299 |
+
original_size = original_size or (height, width)
|
300 |
+
target_size = target_size or (height, width)
|
301 |
+
|
302 |
+
# 1. Check inputs. Raise error if not correct
|
303 |
+
self.check_inputs(
|
304 |
+
prompt,
|
305 |
+
prompt_2,
|
306 |
+
height,
|
307 |
+
width,
|
308 |
+
callback_steps,
|
309 |
+
negative_prompt,
|
310 |
+
negative_prompt_2,
|
311 |
+
prompt_embeds,
|
312 |
+
negative_prompt_embeds,
|
313 |
+
pooled_prompt_embeds,
|
314 |
+
negative_pooled_prompt_embeds,
|
315 |
+
ip_adapter_image,
|
316 |
+
ip_adapter_image_embeds,
|
317 |
+
callback_on_step_end_tensor_inputs,
|
318 |
+
)
|
319 |
+
|
320 |
+
self._guidance_scale = guidance_scale
|
321 |
+
self._guidance_rescale = guidance_rescale
|
322 |
+
self._clip_skip = clip_skip
|
323 |
+
self._cross_attention_kwargs = cross_attention_kwargs
|
324 |
+
self._denoising_end = denoising_end
|
325 |
+
self._interrupt = False
|
326 |
+
|
327 |
+
# 2. Define call parameters
|
328 |
+
if prompt is not None and isinstance(prompt, str):
|
329 |
+
batch_size = 1
|
330 |
+
elif prompt is not None and isinstance(prompt, list):
|
331 |
+
batch_size = len(prompt)
|
332 |
+
else:
|
333 |
+
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # 3. Encode input prompt
    lora_scale = (
        self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
    )

    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = self.encode_prompt(
        prompt=prompt,
        prompt_2=prompt_2,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        do_classifier_free_guidance=self.do_classifier_free_guidance,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt_2,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        lora_scale=lora_scale,
        clip_skip=self.clip_skip,
    )

    # 4. Prepare timesteps
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler, num_inference_steps, device, timesteps, sigmas
    )

    # 5. Prepare latent variables
    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7. Prepare added time ids & embeddings
    add_text_embeds = pooled_prompt_embeds
    if self.text_encoder_2 is None:
        text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
    else:
        text_encoder_projection_dim = self.text_encoder_2.config.projection_dim

    add_time_ids = self._get_add_time_ids(
        original_size,
        crops_coords_top_left,
        target_size,
        dtype=prompt_embeds.dtype,
        text_encoder_projection_dim=text_encoder_projection_dim,
    )
    if negative_original_size is not None and negative_target_size is not None:
        negative_add_time_ids = self._get_add_time_ids(
            negative_original_size,
            negative_crops_coords_top_left,
            negative_target_size,
            dtype=prompt_embeds.dtype,
            text_encoder_projection_dim=text_encoder_projection_dim,
        )
    else:
        negative_add_time_ids = add_time_ids

    if self.do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
        add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
        add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)

    prompt_embeds = prompt_embeds.to(device)
    add_text_embeds = add_text_embeds.to(device)
    add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)

    if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
        image_embeds = self.prepare_ip_adapter_image_embeds(
            ip_adapter_image,
            ip_adapter_image_embeds,
            device,
            batch_size * num_images_per_prompt,
            self.do_classifier_free_guidance,
        )

    # 8. Denoising loop
    num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

    # 8.1 Apply denoising_end
    if (
        self.denoising_end is not None
        and isinstance(self.denoising_end, float)
        and self.denoising_end > 0
        and self.denoising_end < 1
    ):
        discrete_timestep_cutoff = int(
            round(
                self.scheduler.config.num_train_timesteps
                - (self.denoising_end * self.scheduler.config.num_train_timesteps)
            )
        )
        num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
        timesteps = timesteps[:num_inference_steps]

    # 9. Optionally get Guidance Scale Embedding
    timestep_cond = None
    if self.unet.config.time_cond_proj_dim is not None:
        guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
        timestep_cond = self.get_guidance_scale_embedding(
            guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
        ).to(device=device, dtype=latents.dtype)

    self._num_timesteps = len(timesteps)
    clip_features = []
    # The diffusers progress bar is intentionally not used here.
    for i, t in enumerate(timesteps):
        if self.interrupt:
            continue

        # expand the latents if we are doing classifier free guidance
        latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents

        latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

        # predict the noise residual
        added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
            added_cond_kwargs["image_embeds"] = image_embeds
        noise_pred = self.unet(
            latent_model_input,
            t,
            encoder_hidden_states=prompt_embeds,
            timestep_cond=timestep_cond,
            cross_attention_kwargs=self.cross_attention_kwargs,
            added_cond_kwargs=added_cond_kwargs,
            return_dict=False,
        )[0]

        # perform guidance
        if self.do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

        if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
            # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
            noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

        # compute the previous noisy sample x_t -> x_t-1, keeping the full scheduler output
        # so the predicted clean sample can be reused for feature extraction
        step_output = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=True)
        try:
            denoised = step_output['pred_original_sample'] / self.vae.config.scaling_factor
        except KeyError:
            # some schedulers (e.g. LCM) expose the clean prediction as 'denoised' instead
            denoised = step_output['denoised'] / self.vae.config.scaling_factor
        latents = step_output['prev_sample']

        latents = latents.to(self.vae.dtype)
        denoised = denoised.to(self.vae.dtype)

        if i in save_timesteps:
            if use_clip:
                denoised = self.vae.decode(denoised.to(self.vae.dtype), return_dict=False)[0]
                denoised = F.adaptive_avg_pool2d(denoised, (224, 224))
                denoised = normalize_image(denoised)
                if 'dino' in encoder:
                    denoised = clip(denoised)
                    denoised = denoised.pooler_output
                    denoised = denoised.cpu().view(denoised.shape[0], -1)
                else:
                    denoised = clip.get_image_features(denoised)
                    denoised = denoised.cpu().view(denoised.shape[0], -1)

            clip_features.append(denoised)

        if callback_on_step_end is not None:
            callback_kwargs = {}
            for k in callback_on_step_end_tensor_inputs:
                callback_kwargs[k] = locals()[k]
            callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

            latents = callback_outputs.pop("latents", latents)
            prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
            negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
            add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
            negative_pooled_prompt_embeds = callback_outputs.pop(
                "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
            )
            add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
            negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)

        # call the callback, if provided
        if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
            if callback is not None and i % callback_steps == 0:
                step_idx = i // getattr(self.scheduler, "order", 1)
                callback(step_idx, t, latents)

        if XLA_AVAILABLE:
            xm.mark_step()

    if not output_type == "latent":
        # make sure the VAE is in float32 mode, as it overflows in float16
        needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast

        if needs_upcasting:
            self.upcast_vae()
            latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
        elif latents.dtype != self.vae.dtype:
            if torch.backends.mps.is_available():
                # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                self.vae = self.vae.to(latents.dtype)

        # unscale/denormalize the latents with the mean and std if available and not None
        has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
        has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
        if has_latents_mean and has_latents_std:
            latents_mean = (
                torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
            )
            latents_std = (
                torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
            )
            latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
        else:
            latents = latents / self.vae.config.scaling_factor

        image = self.vae.decode(latents, return_dict=False)[0]

        # cast back to fp16 if needed
        if needs_upcasting:
            self.vae.to(dtype=torch.float16)
    else:
        image = latents

    if not output_type == "latent":
        image = self.image_processor.postprocess(image, output_type=output_type)

    # Offload all models
    self.maybe_free_model_hooks()

    return image, clip_features
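# --- Hypothetical sketch (not part of the original file) ---------------------------------
# The patched __call__ above relies on a normalize_image helper before handing decoded
# frames to CLIP/DINOv2. Its actual definition lives elsewhere in this repo; the version
# below is only an assumption of what it likely does, using the standard CLIP mean/std.
def _normalize_image_sketch(images):
    import torch

    # The VAE decodes to roughly [-1, 1]; map to [0, 1] first, then apply CLIP statistics.
    images = (images / 2 + 0.5).clamp(0, 1)
    mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=images.device).view(1, 3, 1, 1)
    std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=images.device).view(1, 3, 1, 1)
    return (images - mean) / std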
@torch.no_grad()
def call_flux(
    self,
    prompt: Union[str, List[str]] = None,
    prompt_2: Optional[Union[str, List[str]]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 28,
    timesteps: List[int] = None,
    guidance_scale: float = 7.0,
    num_images_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    max_sequence_length: int = 512,
    verbose=False,
    save_timesteps=None,
    clip=None,
    use_clip=True,
    encoder='clip',
):
    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        prompt_2,
        height,
        width,
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
        max_sequence_length=max_sequence_length,
    )

    self._guidance_scale = guidance_scale
    self._joint_attention_kwargs = joint_attention_kwargs
    self._interrupt = False

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # 3. Encode input prompt
    lora_scale = (
        self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
    )
    (
        prompt_embeds,
        pooled_prompt_embeds,
        text_ids,
    ) = self.encode_prompt(
        prompt=prompt,
        prompt_2=prompt_2,
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        max_sequence_length=max_sequence_length,
        lora_scale=lora_scale,
    )

    # 4. Prepare latent variables
    num_channels_latents = self.transformer.config.in_channels // 4
    latents, latent_image_ids = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 5. Prepare timesteps
    sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
    image_seq_len = latents.shape[1]
    mu = calculate_shift(
        image_seq_len,
        self.scheduler.config.base_image_seq_len,
        self.scheduler.config.max_image_seq_len,
        self.scheduler.config.base_shift,
        self.scheduler.config.max_shift,
    )
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler,
        num_inference_steps,
        device,
        timesteps,
        sigmas,
        mu=mu,
    )

    num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
    self._num_timesteps = len(timesteps)

    # handle guidance
    if self.transformer.config.guidance_embeds:
        guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
        guidance = guidance.expand(latents.shape[0])
    else:
        guidance = None

    # NOTE: unlike call_sdxl, features are collected at every denoising step;
    # save_timesteps, use_clip, and verbose are accepted but not used in this variant.
    clip_features = []

    # 6. Denoising loop
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if self.interrupt:
                continue

            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            timestep = t.expand(latents.shape[0]).to(latents.dtype)

            noise_pred = self.transformer(
                hidden_states=latents,
                # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
                timestep=timestep / 1000,
                guidance=guidance,
                pooled_projections=pooled_prompt_embeds,
                encoder_hidden_states=prompt_embeds,
                txt_ids=text_ids,
                img_ids=latent_image_ids,
                joint_attention_kwargs=self.joint_attention_kwargs,
                return_dict=False,
            )[0]

            # compute the previous noisy sample x_t -> x_t-1
            latents_dtype = latents.dtype
            step_output = self.scheduler.step(noise_pred, t, latents, return_dict=True)

            denoised = step_output['prev_sample']
            latents = step_output['prev_sample']

            # decode the current sample and extract image features
            denoised = self._unpack_latents(denoised, height, width, self.vae_scale_factor)
            denoised = (denoised / self.vae.config.scaling_factor) + self.vae.config.shift_factor
            denoised = self.vae.decode(denoised, return_dict=False)[0]
            denoised = F.adaptive_avg_pool2d(denoised, (224, 224))
            if 'dino' in encoder:
                denoised = clip(denoised)  # fixed: previously called clip(**inputs) with an undefined `inputs`
                denoised = denoised.pooler_output
                denoised = denoised.cpu().view(denoised.shape[0], -1)
            else:
                denoised = clip.get_image_features(denoised)
                denoised = denoised.cpu().view(denoised.shape[0], -1)

            clip_features.append(denoised)  # fixed: previously appended nothing

            if latents.dtype != latents_dtype:
                if torch.backends.mps.is_available():
                    # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                    latents = latents.to(latents_dtype)

            if callback_on_step_end is not None:
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()

            if XLA_AVAILABLE:
                xm.mark_step()

    if output_type == "latent":
        image = latents
        return image
    else:
        latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
        latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
        image = self.vae.decode(latents, return_dict=False)[0]
        image = self.image_processor.postprocess(image, output_type=output_type)

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (image,)

    return image, clip_features
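# --- Hypothetical helper (not part of the original file) ---------------------------------
# A minimal sketch of how the per-step CLIP features returned by either patched __call__
# above could be flattened and reduced to principal directions with scikit-learn. The
# function name, component count, and the (timesteps, batch, dim) layout are assumptions
# for illustration only.
def _pca_directions_sketch(clip_features, n_components=100):
    import torch
    from sklearn.decomposition import PCA

    feats = torch.stack(clip_features)              # (num_saved_timesteps, batch, dim)
    feats = feats.reshape(-1, feats.shape[-1])      # (num_saved_timesteps * batch, dim)
    n_components = min(n_components, feats.shape[0], feats.shape[1])
    pca = PCA(n_components=n_components)
    pca.fit(feats.float().cpu().numpy())
    # rows of components_ are unit-norm principal directions in the feature space
    return torch.from_numpy(pca.components_)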
def get_diffusion_clip_directions(prompts, unet, tokenizers, text_encoders, vae, noise_scheduler, clip,
                                  batchsize=1, height=1024, width=1024, max_denoising_steps=4,
                                  savepath_training_images=None, use_clip=True, encoder='clip'):
    device = unet.device
    vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
    image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)

    os.makedirs(savepath_training_images, exist_ok=True)

    if len(noise_scheduler.timesteps) != max_denoising_steps:
        noise_scheduler_orig = noise_scheduler
        max_denoising_steps_orig = len(noise_scheduler.timesteps)
        noise_scheduler.set_timesteps(max_denoising_steps)
        timesteps_distilled = noise_scheduler.timesteps

        noise_scheduler.set_timesteps(max_denoising_steps_orig)
        timesteps_full = noise_scheduler.timesteps
        save_timesteps = []
        for timesteps_to_distilled in range(max_denoising_steps):
            # Get the value from timesteps_distilled that we want to find in timesteps_full
            value_to_find = timesteps_distilled[timesteps_to_distilled]
            timesteps_to_full = (timesteps_full == value_to_find).nonzero().item()
            save_timesteps.append(timesteps_to_full)

        guidance_scale = 7
    else:
        max_denoising_steps_orig = max_denoising_steps
        save_timesteps = [i for i in range(max_denoising_steps_orig)]
        guidance_scale = 7
        if max_denoising_steps_orig <= 4:
            guidance_scale = 0

    noise_scheduler.set_timesteps(max_denoising_steps_orig)
    # if max_denoising_steps_orig == 1:
    #     noise_scheduler.set_timesteps(timesteps=[399], device=device)

    weight_dtype = unet.dtype
    device = unet.device
    StableDiffusionXLPipeline.__call__ = call_sdxl
    pipe = StableDiffusionXLPipeline(
        vae=vae,
        text_encoder=text_encoders[0],
        text_encoder_2=text_encoders[1],
        tokenizer=tokenizers[0],
        tokenizer_2=tokenizers[1],
        unet=unet,
        scheduler=noise_scheduler,
    )
    pipe.to(unet.device)
    # print(guidance_scale, max_denoising_steps_orig, save_timesteps)
    images, clip_features = pipe(
        prompts,
        guidance_scale=guidance_scale,
        num_inference_steps=max_denoising_steps_orig,
        clip=clip,
        save_timesteps=save_timesteps,
        use_clip=use_clip,
        encoder=encoder,
    )

    return images, torch.stack(clip_features)
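# --- Hypothetical usage sketch (not part of the original file) ---------------------------
# Illustrates the expected call signature and return shapes of get_diffusion_clip_directions
# above. The model handles, prompt list, and save directory are placeholders, not values
# defined anywhere in this repo.
def _example_sdxl_directions(prompts, unet, tokenizers, text_encoders, vae,
                             noise_scheduler, clip_model, save_dir):
    images, clip_feats = get_diffusion_clip_directions(
        prompts, unet, tokenizers, text_encoders, vae, noise_scheduler, clip_model,
        max_denoising_steps=4,
        savepath_training_images=save_dir,
        use_clip=True,
        encoder='clip',
    )
    # images: postprocessed outputs; clip_feats: (num_saved_timesteps, batch, feature_dim)
    return images, clip_feats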
def get_flux_clip_directions(prompts, transformer, tokenizers, text_encoders, vae, noise_scheduler, clip,
                             batchsize=1, height=1024, width=1024, max_denoising_steps=4,
                             savepath_training_images=None, use_clip=True):
    device = transformer.device
    FluxPipeline.__call__ = call_flux
    pipe = FluxPipeline(
        noise_scheduler,
        vae,
        text_encoders[0],
        tokenizers[0],
        text_encoders[1],
        tokenizers[1],
        transformer,
    )
    pipe.set_progress_bar_config(disable=True)

    os.makedirs(savepath_training_images, exist_ok=True)

    images, clip_features = pipe(
        prompts,
        height=height,
        width=width,
        guidance_scale=0,
        num_inference_steps=4,
        max_sequence_length=256,
        num_images_per_prompt=1,
        output_type='pil',
        clip=clip,
    )

    return images, torch.stack(clip_features)
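# --- Hypothetical usage sketch (not part of the original file) ---------------------------
# Shows how the Flux variant above might be called; all handles below are placeholders.
# Note that get_flux_clip_directions currently fixes num_inference_steps=4 and
# guidance_scale=0 internally, so only the prompts and resolution are configurable here.
def _example_flux_directions(prompts, transformer, tokenizers, text_encoders, vae,
                             noise_scheduler, clip_model, save_dir):
    images, clip_feats = get_flux_clip_directions(
        prompts, transformer, tokenizers, text_encoders, vae, noise_scheduler, clip_model,
        height=1024, width=1024,
        max_denoising_steps=4,
        savepath_training_images=save_dir,
    )
    # clip_feats: (num_inference_steps, batch, feature_dim)
    return images, clip_feats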
def get_diffusion_clip_directions(prompts, unet, tokenizers, text_encoders, vae, noise_scheduler, clip,
                                  batchsize=1, height=1024, width=1024, max_denoising_steps=4,
                                  savepath_training_images=None, use_clip=True, encoder='clip',
                                  num_images_per_prompt=1):
    device = unet.device
    vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
    image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)
    os.makedirs(savepath_training_images, exist_ok=True)

    if len(noise_scheduler.timesteps) != max_denoising_steps:
        noise_scheduler_orig = noise_scheduler
        max_denoising_steps_orig = len(noise_scheduler.timesteps)
        noise_scheduler.set_timesteps(max_denoising_steps)
        timesteps_distilled = noise_scheduler.timesteps

        noise_scheduler.set_timesteps(max_denoising_steps_orig)
        timesteps_full = noise_scheduler.timesteps
        save_timesteps = []
        for timesteps_to_distilled in range(max_denoising_steps):
            # Get the value from timesteps_distilled that we want to find in timesteps_full
            value_to_find = timesteps_distilled[timesteps_to_distilled]
            timesteps_to_full = (timesteps_full == value_to_find).nonzero().item()
            save_timesteps.append(timesteps_to_full)

        guidance_scale = 7
    else:
        max_denoising_steps_orig = max_denoising_steps
        save_timesteps = [i for i in range(max_denoising_steps_orig)]
        guidance_scale = 7
        if max_denoising_steps_orig <= 4:
            guidance_scale = 0

    noise_scheduler.set_timesteps(max_denoising_steps_orig)
    # if max_denoising_steps_orig == 1:
    #     noise_scheduler.set_timesteps(timesteps=[399], device=device)

    weight_dtype = unet.dtype
    device = unet.device
    StableDiffusionXLPipeline.__call__ = call_sdxl
    pipe = StableDiffusionXLPipeline(
        vae=vae,
        text_encoder=text_encoders[0],
        text_encoder_2=text_encoders[1],
        tokenizer=tokenizers[0],
        tokenizer_2=tokenizers[1],
        unet=unet,
        scheduler=noise_scheduler,
    )
    pipe.to(unet.device)
    # print(guidance_scale, max_denoising_steps_orig, save_timesteps)
    images, clip_features = pipe(
        prompts,
        guidance_scale=guidance_scale,
        num_inference_steps=max_denoising_steps_orig,
        clip=clip,
        save_timesteps=save_timesteps,
        use_clip=use_clip,
        encoder=encoder,
    )

    return images, torch.stack(clip_features)