import os , torch import argparse import copy import gc import itertools import logging import math import random import shutil import warnings from contextlib import nullcontext from pathlib import Path import numpy as np import torch import torch.utils.checkpoint import transformers from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed from huggingface_hub import create_repo, upload_folder from huggingface_hub.utils import insecure_hashlib from PIL import Image from PIL.ImageOps import exif_transpose from torch.utils.data import Dataset from torchvision import transforms from torchvision.transforms.functional import crop from tqdm.auto import tqdm from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast import diffusers from diffusers import ( AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxTransformer2DModel, ) from diffusers.optimization import get_scheduler from diffusers.training_utils import ( _set_state_dict_into_text_encoder, cast_training_params, compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, ) from diffusers.utils import ( check_min_version, convert_unet_state_dict_to_peft, is_wandb_available, ) from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card from diffusers.utils.torch_utils import is_compiled_module from collections import defaultdict from typing import List, Optional import argparse import ast from pathlib import Path from diffusers import DiffusionPipeline, UNet2DConditionModel, LCMScheduler from huggingface_hub import hf_hub_download import gc import torch.nn.functional as F import os import torch from tqdm.auto import tqdm import time, datetime import numpy as np from torch.optim import AdamW from contextlib import ExitStack from safetensors.torch import load_file import torch.nn as nn import random from transformers import CLIPModel from transformers import logging logging.set_verbosity_warning() from diffusers import logging logging.set_verbosity_error() def flush(): torch.cuda.empty_cache() gc.collect() flush() def unwrap_model(model): options = (torch.nn.parallel.DistributedDataParallel, torch.nn.DataParallel) #if is_deepspeed_available(): # options += (DeepSpeedEngine,) while isinstance(model, options): model = model.module return model # Function to log gradients def log_gradients(named_parameters): grad_dict = defaultdict(lambda: defaultdict(float)) for name, param in named_parameters: if param.requires_grad and param.grad is not None: grad_dict[name]['mean'] = param.grad.abs().mean().item() grad_dict[name]['std'] = param.grad.std().item() grad_dict[name]['max'] = param.grad.abs().max().item() grad_dict[name]['min'] = param.grad.abs().min().item() return grad_dict def import_model_class_from_model_name_or_path( pretrained_model_name_or_path: str, subfolder: str = "text_encoder", ): text_encoder_config = PretrainedConfig.from_pretrained( pretrained_model_name_or_path, subfolder=subfolder , device_map='cuda:0' ) model_class = text_encoder_config.architectures[0] if model_class == "CLIPTextModel": from transformers import CLIPTextModel return CLIPTextModel elif model_class == "T5EncoderModel": from transformers import T5EncoderModel return T5EncoderModel else: raise ValueError(f"{model_class} is not supported.") def load_text_encoders(pretrained_model_name_or_path, class_one, class_two, weight_dtype): text_encoder_one = class_one.from_pretrained( pretrained_model_name_or_path, subfolder="text_encoder", torch_dtype=weight_dtype, device_map='cuda:0' ) text_encoder_two = class_two.from_pretrained( pretrained_model_name_or_path, subfolder="text_encoder_2", torch_dtype=weight_dtype, device_map='cuda:0' ) return text_encoder_one, text_encoder_two import matplotlib.pyplot as plt def plot_labeled_images(images, labels): # Determine the number of images n = len(images) # Create a new figure with a single row fig, axes = plt.subplots(1, n, figsize=(5*n, 5)) # If there's only one image, axes will be a single object, not an array if n == 1: axes = [axes] # Plot each image for i, (img, label) in enumerate(zip(images, labels)): # Convert PIL image to numpy array img_array = np.array(img) # Display the image axes[i].imshow(img_array) axes[i].axis('off') # Turn off axis # Set the title (label) for the image axes[i].set_title(label) # Adjust the layout and display the plot plt.tight_layout() plt.show() def tokenize_prompt(tokenizer, prompt, max_sequence_length): text_inputs = tokenizer( prompt, padding="max_length", max_length=max_sequence_length, truncation=True, return_length=False, return_overflowing_tokens=False, return_tensors="pt", ) text_input_ids = text_inputs.input_ids return text_input_ids def _encode_prompt_with_t5( text_encoder, tokenizer, max_sequence_length=512, prompt=None, num_images_per_prompt=1, device=None, text_input_ids=None, ): prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) if tokenizer is not None: text_inputs = tokenizer( prompt, padding="max_length", max_length=max_sequence_length, truncation=True, return_length=False, return_overflowing_tokens=False, return_tensors="pt", ) text_input_ids = text_inputs.input_ids else: if text_input_ids is None: raise ValueError("text_input_ids must be provided when the tokenizer is not specified") prompt_embeds = text_encoder(text_input_ids.to(device))[0] dtype = text_encoder.dtype prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) _, seq_len, _ = prompt_embeds.shape # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) return prompt_embeds def _encode_prompt_with_clip( text_encoder, tokenizer, prompt: str, device=None, text_input_ids=None, num_images_per_prompt: int = 1, ): prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) if tokenizer is not None: text_inputs = tokenizer( prompt, padding="max_length", max_length=77, truncation=True, return_overflowing_tokens=False, return_length=False, return_tensors="pt", ) text_input_ids = text_inputs.input_ids else: if text_input_ids is None: raise ValueError("text_input_ids must be provided when the tokenizer is not specified") prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False) # Use pooled output of CLIPTextModel prompt_embeds = prompt_embeds.pooler_output prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device) # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) return prompt_embeds def encode_prompt( text_encoders, tokenizers, prompt: str, max_sequence_length, device=None, num_images_per_prompt: int = 1, text_input_ids_list=None, ): prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) dtype = text_encoders[0].dtype pooled_prompt_embeds = _encode_prompt_with_clip( text_encoder=text_encoders[0], tokenizer=tokenizers[0], prompt=prompt, device=device if device is not None else text_encoders[0].device, num_images_per_prompt=num_images_per_prompt, text_input_ids=text_input_ids_list[0] if text_input_ids_list else None, ) prompt_embeds = _encode_prompt_with_t5( text_encoder=text_encoders[1], tokenizer=tokenizers[1], max_sequence_length=max_sequence_length, prompt=prompt, num_images_per_prompt=num_images_per_prompt, device=device if device is not None else text_encoders[1].device, text_input_ids=text_input_ids_list[1] if text_input_ids_list else None, ) text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) text_ids = text_ids.repeat(num_images_per_prompt, 1, 1) return prompt_embeds, pooled_prompt_embeds, text_ids def compute_text_embeddings(prompt, text_encoders, tokenizers,max_sequence_length=256): device = text_encoders[0].device with torch.no_grad(): prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt( text_encoders, tokenizers, prompt, max_sequence_length=max_sequence_length ) prompt_embeds = prompt_embeds.to(device) pooled_prompt_embeds = pooled_prompt_embeds.to(device) text_ids = text_ids.to(device) return prompt_embeds, pooled_prompt_embeds, text_ids def get_sigmas(timesteps, n_dim=4, device='cuda:0', dtype=torch.bfloat16): sigmas = noise_scheduler_copy.sigmas.to(device=device, dtype=dtype) schedule_timesteps = noise_scheduler_copy.timesteps.to(device) timesteps = timesteps.to(device) step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] sigma = sigmas[step_indices].flatten() while len(sigma.shape) < n_dim: sigma = sigma.unsqueeze(-1) return sigma def plot_history(history): fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 5)) ax1.plot(history['concept']) ax1.set_title('Concept Loss') ax2.plot(movingaverage(history['concept'], 10)) ax2.set_title('Moving Average Concept Loss') plt.tight_layout() plt.show() def movingaverage(interval, window_size): window = np.ones(int(window_size))/float(window_size) return np.convolve(interval, window, 'same') @torch.no_grad() def get_noisy_image_flux( image, vae, transformer, scheduler, timesteps_to=1000, generator=None, **kwargs, ): """ Gets noisy latents for a given image using Flux pipeline approach. Args: image: PIL image or tensor vae: Flux VAE model transformer: Flux transformer model scheduler: Flux noise scheduler timesteps_to: Target timestep generator: Random generator for reproducibility Returns: tuple: (noisy_latents, noise) """ device = vae.device vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1) image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor) # Preprocess image if not isinstance(image, torch.Tensor): image = image_processor.preprocess(image) image = image.to(device=device, dtype=torch.float32) # Encode through VAE init_latents = vae.encode(image).latents init_latents = vae.config.scaling_factor * init_latents # Get shape for noise shape = init_latents.shape # Generate noise noise = randn_tensor(shape, generator=generator, device=device) # Pack latents using Flux's method init_latents = _pack_latents( init_latents, shape[0], # batch size transformer.config.in_channels // 4, height=shape[2], width=shape[3] ) noise = _pack_latents( noise, shape[0], transformer.config.in_channels // 4, height=shape[2], width=shape[3] ) # Get timestep timestep = scheduler.timesteps[timesteps_to:timesteps_to+1] # Add noise to latents noisy_latents = scheduler.add_noise(init_latents, noise, timestep) return noisy_latents, noise