import argparse
import ast
import copy
import datetime
import gc
import itertools
import logging
import math
import os
import random
import shutil
import time
import warnings
from collections import defaultdict
from contextlib import ExitStack, nullcontext
from pathlib import Path
from typing import List, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.optim import AdamW
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms.functional import crop
from safetensors.torch import load_file
from tqdm.auto import tqdm

import transformers
from transformers import CLIPModel, CLIPTokenizer, PretrainedConfig, T5TokenizerFast
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from huggingface_hub import create_repo, hf_hub_download, upload_folder
from huggingface_hub.utils import insecure_hashlib
from PIL import Image
from PIL.ImageOps import exif_transpose

import diffusers
from diffusers import (
    AutoencoderKL,
    DiffusionPipeline,
    FlowMatchEulerDiscreteScheduler,
    FluxPipeline,
    FluxTransformer2DModel,
    LCMScheduler,
    UNet2DConditionModel,
)
# VaeImageProcessor, randn_tensor, and FluxPipeline (above) are needed by
# get_noisy_image_flux further down.
from diffusers.image_processor import VaeImageProcessor
from diffusers.optimization import get_scheduler
from diffusers.training_utils import (
    _set_state_dict_into_text_encoder,
    cast_training_params,
    compute_density_for_timestep_sampling,
    compute_loss_weighting_for_sd3,
)
from diffusers.utils import (
    check_min_version,
    convert_unet_state_dict_to_peft,
    is_wandb_available,
)
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.torch_utils import is_compiled_module, randn_tensor

# Alias the library loggers so they do not shadow the stdlib `logging` imported above.
from transformers.utils import logging as transformers_logging
transformers_logging.set_verbosity_warning()
from diffusers.utils import logging as diffusers_logging
diffusers_logging.set_verbosity_error()
def flush():
    torch.cuda.empty_cache()
    gc.collect()

flush()

def unwrap_model(model):
    options = (torch.nn.parallel.DistributedDataParallel, torch.nn.DataParallel)
    # if is_deepspeed_available():
    #     options += (DeepSpeedEngine,)
    while isinstance(model, options):
        model = model.module
    return model

# Function to log gradient statistics per trainable parameter
def log_gradients(named_parameters):
    grad_dict = defaultdict(lambda: defaultdict(float))
    for name, param in named_parameters:
        if param.requires_grad and param.grad is not None:
            grad_dict[name]['mean'] = param.grad.abs().mean().item()
            grad_dict[name]['std'] = param.grad.std().item()
            grad_dict[name]['max'] = param.grad.abs().max().item()
            grad_dict[name]['min'] = param.grad.abs().min().item()
    return grad_dict

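# Hedged usage sketch for log_gradients: `transformer`, `optimizer`, and the loss
# variable below are assumptions (they come from a surrounding training loop, not
# this file). Call it right after loss.backward() and before optimizer.step(),
# while .grad is still populated:
#
#     loss.backward()
#     grad_stats = log_gradients(transformer.named_parameters())
#     for name, stats in list(grad_stats.items())[:5]:
#         print(f"{name}: mean={stats['mean']:.3e} max={stats['max']:.3e}")
#     optimizer.step()
#     optimizer.zero_grad()
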
def import_model_class_from_model_name_or_path(
    pretrained_model_name_or_path: str, subfolder: str = "text_encoder",
):
    # device_map is not a config argument, so the config is loaded plainly here;
    # the encoders themselves are placed on the GPU in load_text_encoders below.
    text_encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path, subfolder=subfolder
    )
    model_class = text_encoder_config.architectures[0]
    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel
    elif model_class == "T5EncoderModel":
        from transformers import T5EncoderModel

        return T5EncoderModel
    else:
        raise ValueError(f"{model_class} is not supported.")

def load_text_encoders(pretrained_model_name_or_path, class_one, class_two, weight_dtype):
    text_encoder_one = class_one.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
        torch_dtype=weight_dtype,
        device_map='cuda:0',
    )
    text_encoder_two = class_two.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder_2",
        torch_dtype=weight_dtype,
        device_map='cuda:0',
    )
    return text_encoder_one, text_encoder_two

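# Hedged usage sketch for the two helpers above: the model id and dtype are
# assumptions (any Flux checkpoint with CLIP + T5 text encoders follows the same
# layout):
#
#     model_id = "black-forest-labs/FLUX.1-dev"
#     tokenizer_one = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
#     tokenizer_two = T5TokenizerFast.from_pretrained(model_id, subfolder="tokenizer_2")
#     class_one = import_model_class_from_model_name_or_path(model_id, subfolder="text_encoder")
#     class_two = import_model_class_from_model_name_or_path(model_id, subfolder="text_encoder_2")
#     text_encoder_one, text_encoder_two = load_text_encoders(
#         model_id, class_one, class_two, weight_dtype=torch.bfloat16
#     )
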
import matplotlib.pyplot as plt

def plot_labeled_images(images, labels):
    # Determine the number of images
    n = len(images)
    # Create a new figure with a single row
    fig, axes = plt.subplots(1, n, figsize=(5 * n, 5))
    # If there's only one image, axes will be a single object, not an array
    if n == 1:
        axes = [axes]
    # Plot each image
    for i, (img, label) in enumerate(zip(images, labels)):
        # Convert PIL image to numpy array
        img_array = np.array(img)
        # Display the image
        axes[i].imshow(img_array)
        axes[i].axis('off')  # Turn off axis
        # Set the title (label) for the image
        axes[i].set_title(label)
    # Adjust the layout and display the plot
    plt.tight_layout()
    plt.show()

def tokenize_prompt(tokenizer, prompt, max_sequence_length):
    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        return_length=False,
        return_overflowing_tokens=False,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    return text_input_ids

def _encode_prompt_with_t5(
    text_encoder,
    tokenizer,
    max_sequence_length=512,
    prompt=None,
    num_images_per_prompt=1,
    device=None,
    text_input_ids=None,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)
    if tokenizer is not None:
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_length=False,
            return_overflowing_tokens=False,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
    else:
        if text_input_ids is None:
            raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
    prompt_embeds = text_encoder(text_input_ids.to(device))[0]
    dtype = text_encoder.dtype
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
    _, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
    return prompt_embeds

def _encode_prompt_with_clip(
    text_encoder,
    tokenizer,
    prompt: str,
    device=None,
    text_input_ids=None,
    num_images_per_prompt: int = 1,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)
    if tokenizer is not None:
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=77,
            truncation=True,
            return_overflowing_tokens=False,
            return_length=False,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
    else:
        if text_input_ids is None:
            raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
    prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
    # Use pooled output of CLIPTextModel
    prompt_embeds = prompt_embeds.pooler_output
    prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
    return prompt_embeds

def encode_prompt(
    text_encoders,
    tokenizers,
    prompt: str,
    max_sequence_length,
    device=None,
    num_images_per_prompt: int = 1,
    text_input_ids_list=None,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)
    dtype = text_encoders[0].dtype
    pooled_prompt_embeds = _encode_prompt_with_clip(
        text_encoder=text_encoders[0],
        tokenizer=tokenizers[0],
        prompt=prompt,
        device=device if device is not None else text_encoders[0].device,
        num_images_per_prompt=num_images_per_prompt,
        text_input_ids=text_input_ids_list[0] if text_input_ids_list else None,
    )
    prompt_embeds = _encode_prompt_with_t5(
        text_encoder=text_encoders[1],
        tokenizer=tokenizers[1],
        max_sequence_length=max_sequence_length,
        prompt=prompt,
        num_images_per_prompt=num_images_per_prompt,
        device=device if device is not None else text_encoders[1].device,
        text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,
    )
    text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
    text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
    return prompt_embeds, pooled_prompt_embeds, text_ids

def compute_text_embeddings(prompt, text_encoders, tokenizers, max_sequence_length=256):
    device = text_encoders[0].device
    with torch.no_grad():
        prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
            text_encoders, tokenizers, prompt, max_sequence_length=max_sequence_length
        )
        prompt_embeds = prompt_embeds.to(device)
        pooled_prompt_embeds = pooled_prompt_embeds.to(device)
        text_ids = text_ids.to(device)
    return prompt_embeds, pooled_prompt_embeds, text_ids

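# Hedged usage sketch for compute_text_embeddings: the encoders and tokenizers are
# assumed to come from the loading sketch above; the prompt is just a placeholder.
#
#     tokenizers = [tokenizer_one, tokenizer_two]
#     text_encoders = [text_encoder_one, text_encoder_two]
#     prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings(
#         "a photo of a sks dog", text_encoders, tokenizers, max_sequence_length=256
#     )
#     # prompt_embeds: (1, 256, hidden_dim), pooled_prompt_embeds: (1, pooled_dim)
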
# Note: relies on a module-level `noise_scheduler_copy` (an untouched copy of the
# training noise scheduler); see the sketch after this function.
def get_sigmas(timesteps, n_dim=4, device='cuda:0', dtype=torch.bfloat16):
    sigmas = noise_scheduler_copy.sigmas.to(device=device, dtype=dtype)
    schedule_timesteps = noise_scheduler_copy.timesteps.to(device)
    timesteps = timesteps.to(device)
    step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
    sigma = sigmas[step_indices].flatten()
    while len(sigma.shape) < n_dim:
        sigma = sigma.unsqueeze(-1)
    return sigma

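# Hedged sketch of the global that get_sigmas expects; the checkpoint id is an
# assumption. This mirrors the diffusers Flux training scripts, which keep a
# deep copy of the scheduler purely for sigma lookups:
#
#     noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
#         "black-forest-labs/FLUX.1-dev", subfolder="scheduler"
#     )
#     noise_scheduler_copy = copy.deepcopy(noise_scheduler)
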
def plot_history(history):
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 5))
    ax1.plot(history['concept'])
    ax1.set_title('Concept Loss')
    ax2.plot(movingaverage(history['concept'], 10))
    ax2.set_title('Moving Average Concept Loss')
    plt.tight_layout()
    plt.show()

def movingaverage(interval, window_size):
    window = np.ones(int(window_size)) / float(window_size)
    return np.convolve(interval, window, 'same')

def get_noisy_image_flux(
    image,
    vae,
    transformer,
    scheduler,
    timesteps_to=1000,
    generator=None,
    **kwargs,
):
    """
    Gets noisy latents for a given image using the Flux pipeline approach.
    Args:
        image: PIL image or tensor
        vae: Flux VAE model
        transformer: Flux transformer model
        scheduler: Flux noise scheduler (FlowMatchEulerDiscreteScheduler)
        timesteps_to: Index into scheduler.timesteps selecting the target timestep
        generator: Random generator for reproducibility
    Returns:
        tuple: (noisy_latents, noise), both packed into Flux's patch layout
    """
    device = vae.device
    vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
    image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)
    # Preprocess image
    if not isinstance(image, torch.Tensor):
        image = image_processor.preprocess(image)
    # Match the VAE's dtype so encode() does not hit a dtype mismatch when the VAE runs in bf16
    image = image.to(device=device, dtype=vae.dtype)
    # Encode through the VAE; AutoencoderKL returns a latent distribution, not raw latents
    init_latents = vae.encode(image).latent_dist.sample(generator=generator)
    # Flux's VAE normalizes latents with a shift as well as a scale; fall back to
    # plain scaling if the config has no shift_factor
    shift_factor = getattr(vae.config, "shift_factor", 0.0) or 0.0
    init_latents = (init_latents - shift_factor) * vae.config.scaling_factor
    # Get shape for noise
    shape = init_latents.shape
    # Generate noise
    noise = randn_tensor(shape, generator=generator, device=device, dtype=init_latents.dtype)
    # Pack latents into Flux's 2x2 patch layout using the pipeline's helper
    init_latents = FluxPipeline._pack_latents(
        init_latents,
        shape[0],  # batch size
        transformer.config.in_channels // 4,
        shape[2],  # latent height
        shape[3],  # latent width
    )
    noise = FluxPipeline._pack_latents(
        noise,
        shape[0],
        transformer.config.in_channels // 4,
        shape[2],
        shape[3],
    )
    # Get timestep
    timestep = scheduler.timesteps[timesteps_to:timesteps_to + 1]
    # Flow-matching schedulers expose scale_noise instead of add_noise
    if hasattr(scheduler, "add_noise"):
        noisy_latents = scheduler.add_noise(init_latents, noise, timestep)
    else:
        noisy_latents = scheduler.scale_noise(init_latents, timestep=timestep, noise=noise)
    return noisy_latents, noise
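
# Hedged usage sketch for get_noisy_image_flux: the pipeline id and image path are
# assumptions; the function only needs the VAE, transformer, and scheduler plus a
# PIL image. A freshly loaded FlowMatchEulerDiscreteScheduler already exposes its
# full training-timestep grid, so timesteps_to simply indexes into it.
#
#     pipe = DiffusionPipeline.from_pretrained(
#         "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
#     ).to("cuda")
#     img = Image.open("example.png").convert("RGB").resize((512, 512))
#     noisy_latents, noise = get_noisy_image_flux(
#         img, pipe.vae, pipe.transformer, pipe.scheduler, timesteps_to=500,
#         generator=torch.Generator("cuda").manual_seed(0),
#     )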