# This code is modified from the Huggingface repository: https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora.py, and import argparse import hashlib import itertools import json import logging import math import os import warnings from pathlib import Path import numpy as np import torch import torch.nn.functional as F import torch.utils.checkpoint import transformers from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import ProjectConfiguration, set_seed from huggingface_hub import HfApi, create_repo from model_pipeline import ( CustomDiffusionAttnProcessor, CustomDiffusionPipeline, set_use_memory_efficient_attention_xformers, ) from packaging import version from tqdm.auto import tqdm from transformers import AutoTokenizer, PretrainedConfig from utils import ( CustomDiffusionDataset, PromptDataset, collate_fn, filter, getanchorprompts, ) import diffusers from diffusers import ( AutoencoderKL, DDPMScheduler, DiffusionPipeline, DPMSolverMultistepScheduler, UNet2DConditionModel, ) from diffusers.models.cross_attention import CrossAttention from diffusers.optimization import get_scheduler from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available # Will error if the minimal version of diffusers is not installed. Remove at your own risks. check_min_version("0.14.0") logger = get_logger(__name__) def create_custom_diffusion(unet, parameter_group): for name, params in unet.named_parameters(): if parameter_group == "cross-attn": if 'attn2.to_k' in name or 'attn2.to_v' in name: params.requires_grad = True else: params.requires_grad = False elif parameter_group == 'full-weight': params.requires_grad = True elif parameter_group == 'embedding': params.requires_grad = False else: raise ValueError( "parameter_group argument only cross-attn, full-weight, embedding" ) # change attn class def change_attn(unet): for layer in unet.children(): if type(layer) == CrossAttention: bound_method = set_use_memory_efficient_attention_xformers.__get__( layer, layer.__class__) setattr( layer, 'set_use_memory_efficient_attention_xformers', bound_method) else: change_attn(layer) change_attn(unet) unet.set_attn_processor(CustomDiffusionAttnProcessor()) return unet def save_model_card(repo_id: str, images=None, base_model=str, prompt=str, repo_folder=None): img_str = "" for i, image in enumerate(images): image.save(os.path.join(repo_folder, f"image_{i}.png")) img_str += f"./image_{i}.png\n" yaml = f""" --- license: creativeml-openrail-m base_model: {base_model} instance_prompt: {prompt} tags: - stable-diffusion - stable-diffusion-diffusers - text-to-image - diffusers - custom diffusion inference: true --- """ model_card = f""" # Custom Diffusion - {repo_id} These are Custom Diffusion adaption weights for {base_model}. The weights were trained on {prompt} using [Custom Diffusion](https://www.cs.cmu.edu/~custom-diffusion). You can find some example images in the following. \n {img_str[0]} """ with open(os.path.join(repo_folder, "README.md"), "w") as f: f.write(yaml + model_card) def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): text_encoder_config = PretrainedConfig.from_pretrained( pretrained_model_name_or_path, subfolder="text_encoder", revision=revision, ) model_class = text_encoder_config.architectures[0] if model_class == "CLIPTextModel": from transformers import CLIPTextModel return CLIPTextModel elif model_class == "RobertaSeriesModelWithTransformation": from diffusers.pipelines.alt_diffusion.modeling_roberta_series import ( RobertaSeriesModelWithTransformation, ) return RobertaSeriesModelWithTransformation else: raise ValueError(f"{model_class} is not supported.") def freeze_params(params): for param in params: param.requires_grad = False def parse_args(input_args=None): parser = argparse.ArgumentParser( description="Simple example of a training script.") parser.add_argument( "--pretrained_model_name_or_path", type=str, default=None, required=True, help="Path to pretrained model or model identifier from huggingface.co/models.", ) parser.add_argument( "--revision", type=str, default=None, required=False, help="Revision of pretrained model identifier from huggingface.co/models.", ) parser.add_argument( "--tokenizer_name", type=str, default=None, help="Pretrained tokenizer name or path if not the same as model_name", ) parser.add_argument( "--concept_type", type=str, required=True, choices=['style', 'object', 'memorization'], help='the type of removed concepts' ) parser.add_argument( "--caption_target", type=str, required=True, help="target style to remove, used when kldiv loss", ) parser.add_argument( "--instance_data_dir", type=str, default=None, help="A folder containing the training data of instance images.", ) parser.add_argument( "--class_data_dir", type=str, default=None, help="A folder containing the training data of class images.", ) parser.add_argument( "--instance_prompt", type=str, help="The prompt with identifier specifying the instance", ) parser.add_argument( "--class_prompt", type=str, default=None, help="The prompt to specify images in the same class as provided instance images.", ) parser.add_argument( "--mem_impath", type=str, default="", help='the path to saved memorized image. Required when concept_type is memorization' ) parser.add_argument( "--validation_prompt", type=str, default=None, help="A prompt that is used during validation to verify that the model is learning.", ) parser.add_argument( "--num_validation_images", type=int, default=2, help="Number of images that should be generated during validation with `validation_prompt`.", ) parser.add_argument( "--validation_steps", type=int, default=500, help=( "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt" " `args.validation_prompt` multiple times: `args.num_validation_images`." ), ) parser.add_argument( "--with_prior_preservation", default=False, action="store_true", help="Flag to add prior preservation loss.", ) parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") parser.add_argument( "--train_size", type=int, default=1000, help='the number of generated images used for ablating the concept' ) parser.add_argument( "--output_dir", type=str, default="custom-diffusion-model", help="The output directory where the model predictions and checkpoints will be written.", ) parser.add_argument( "--num_class_images", type=int, default=1000, help=( "Minimal anchor class images. If there are not enough images already present in" " class_data_dir, additional images will be sampled with class_prompt." ), ) parser.add_argument( "--num_class_prompts", type=int, default=200, help=( "Minimal prompts used to generate anchor class images" ), ) parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") parser.add_argument( "--resolution", type=int, default=512, help=( "The resolution for input images, all the images in the train/validation dataset will be resized to this" " resolution" ), ) parser.add_argument( "--center_crop", default=False, action="store_true", help=( "Whether to center crop the input images to the resolution. If not set, the images will be randomly" " cropped. The images will be resized to the resolution first before cropping." ), ) parser.add_argument( "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." ) parser.add_argument( "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." ) parser.add_argument("--num_train_epochs", type=int, default=1) parser.add_argument( "--max_train_steps", type=int, default=None, help="Total number of training steps to perform. If provided, overrides num_train_epochs.", ) parser.add_argument( "--checkpointing_steps", type=int, default=250, help=( "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" " training using `--resume_from_checkpoint`." ), ) parser.add_argument( "--checkpoints_total_limit", type=int, default=None, help=( "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`." " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state" " for more docs" ), ) parser.add_argument( "--resume_from_checkpoint", type=str, default=None, help=( "Whether training should be resumed from a previous checkpoint. Use a path saved by" ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' ), ) parser.add_argument( "--gradient_accumulation_steps", type=int, default=1, help="Number of updates steps to accumulate before performing a backward/update pass.", ) parser.add_argument( "--gradient_checkpointing", action="store_true", help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", ) parser.add_argument( "--learning_rate", type=float, default=1e-5, help="Initial learning rate (after the potential warmup period) to use.", ) parser.add_argument( "--scale_lr", action="store_true", default=False, help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", ) parser.add_argument( "--dataloader_num_workers", type=int, default=2, help=( "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." ), ) parser.add_argument( "--parameter_group", type=str, default='cross-attn', choices=['full-weight', 'cross-attn', 'embedding'], help='parameter groups to finetune. Default: full-weight for memorization and cross-attn for others' ) parser.add_argument( "--loss_type_reverse", type=str, default='model-based', help="loss type for reverse fine-tuning", ) parser.add_argument( "--lr_scheduler", type=str, default="constant", help=( 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' ' "constant", "constant_with_warmup"]' ), ) parser.add_argument( "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." ) parser.add_argument( "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." ) parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") parser.add_argument( "--hub_model_id", type=str, default=None, help="The name of the repository to keep in sync with the local `output_dir`.", ) parser.add_argument( "--logging_dir", type=str, default="logs", help=( "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." ), ) parser.add_argument( "--allow_tf32", action="store_true", help=( "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" ), ) parser.add_argument( "--report_to", type=str, default="tensorboard", help=( 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' ), ) parser.add_argument( "--mixed_precision", type=str, default=None, choices=["no", "fp16", "bf16"], help=( "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." ), ) parser.add_argument( "--prior_generation_precision", type=str, default=None, choices=["no", "fp32", "fp16", "bf16"], help=( "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." ), ) parser.add_argument( "--concepts_list", type=str, default=None, help="Path to json containing multiple concepts, will overwrite parameters like instance_prompt, class_prompt, etc.", ) parser.add_argument( "--openai_key", type=str, default="", help=( "OPENAI API key. required for ablating objects and memorized images." ), ) parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") parser.add_argument( "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." ) parser.add_argument("--hflip", action="store_true", help="Apply horizontal flip data augmentation.") parser.add_argument("--noaug", action="store_true", help="Dont apply augmentation during data augmentation when this flag is enabled.") if input_args is not None: args = parser.parse_args(input_args) else: args = parser.parse_args() env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) if env_local_rank != -1 and env_local_rank != args.local_rank: args.local_rank = env_local_rank if args.with_prior_preservation: if args.concepts_list is None: if args.class_data_dir is None: raise ValueError( "You must specify a data directory for class images.") if args.class_prompt is None: raise ValueError("You must specify prompt for class images.") else: # logger is not available yet if args.class_data_dir is not None: warnings.warn( "You need not use --class_data_dir without --with_prior_preservation.") if args.class_prompt is not None: warnings.warn( "You need not use --class_prompt without --with_prior_preservation.") return args def main(args): logging_dir = Path(args.output_dir, args.logging_dir) accelerator_project_config = ProjectConfiguration( total_limit=args.checkpoints_total_limit) accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision, log_with=args.report_to, project_dir=logging_dir, project_config=accelerator_project_config, ) if args.report_to == "wandb": if not is_wandb_available(): raise ImportError( "Make sure to install wandb if you want to use it for logging during training.") import wandb # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate. # Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO, ) logger.info(accelerator.state, main_process_only=False) if accelerator.is_local_main_process: transformers.utils.logging.set_verbosity_warning() diffusers.utils.logging.set_verbosity_info() else: transformers.utils.logging.set_verbosity_error() diffusers.utils.logging.set_verbosity_error() # We need to initialize the trackers we use, and also store our configuration. # The trackers initializes automatically on the main process. if accelerator.is_main_process: print(vars(args)) accelerator.init_trackers("custom-diffusion", config=vars(args)) # If passed along, set the training seed now. if args.seed is not None: set_seed(args.seed) if args.concepts_list is None: args.concepts_list = [ { "instance_prompt": args.instance_prompt, "class_prompt": args.class_prompt, "instance_data_dir": args.instance_data_dir, "class_data_dir": args.class_data_dir, "caption_target": args.caption_target, } ] else: with open(args.concepts_list, "r") as f: args.concepts_list = json.load(f) # Generate class images if prior preservation is enabled. for i, concept in enumerate(args.concepts_list): # directly path to ablation images and its corresponding prompts is provided. if (concept['instance_prompt'] is not None and concept['instance_data_dir'] is not None): break class_images_dir = Path(concept['class_data_dir']) if not class_images_dir.exists(): class_images_dir.mkdir(parents=True, exist_ok=True) os.makedirs(f'{class_images_dir}/images', exist_ok=True) # we need to generate training images if len(list(Path(os.path.join(class_images_dir, 'images')).iterdir())) < args.num_class_images: torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 if args.prior_generation_precision == "fp32": torch_dtype = torch.float32 elif args.prior_generation_precision == "fp16": torch_dtype = torch.float16 elif args.prior_generation_precision == "bf16": torch_dtype = torch.bfloat16 pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, torch_dtype=torch_dtype, safety_checker=None, revision=args.revision, ) pipeline.scheduler = DPMSolverMultistepScheduler.from_config( pipeline.scheduler.config) pipeline.set_progress_bar_config(disable=True) pipeline.to(accelerator.device) # need to create prompts using class_prompt. if not os.path.isfile(concept['class_prompt']): # style based prompts are retrieved from laion dataset if args.concept_type == 'style': with open(os.path.join(class_images_dir, 'painting.txt')) as f: class_prompt_collection = [ x.strip() for x in f.readlines()] # LLM based prompt collection. else: class_prompt = concept['class_prompt'] # in case of object query chatGPT to generate captions containing the anchor category if args.concept_type == 'object': class_prompt_collection, _ = getanchorprompts( pipeline, accelerator, class_prompt, args.concept_type, class_images_dir, args.openai_key, args.num_class_prompts) with open(class_images_dir / 'caption_anchor.txt', 'w') as f: for prompt in class_prompt_collection: f.write(prompt + '\n') # in case of memorization query chatGPT to generate different captions that can be paraphrase of the origianl caption elif args.concept_type == 'memorization': class_prompt_collection, caption_target = getanchorprompts( pipeline, accelerator, class_prompt, args.concept_type, class_images_dir, args.openai_key, args.num_class_prompts, mem_impath=args.mem_impath) concept['caption_target'] += f';*+{caption_target}' with open(class_images_dir / 'caption_target.txt', 'w') as f: f.write(concept['caption_target']) print(class_prompt_collection, concept['caption_target']) # class_prompt is filepath to prompts. else: with open(concept['class_prompt']) as f: class_prompt_collection = [ x.strip() for x in f.readlines()] num_new_images = args.num_class_images logger.info( f"Number of class images to sample: {num_new_images}.") sample_dataset = PromptDataset( class_prompt_collection, num_new_images) sample_dataloader = torch.utils.data.DataLoader( sample_dataset, batch_size=args.sample_batch_size) sample_dataloader = accelerator.prepare(sample_dataloader) if os.path.exists(f'{class_images_dir}/caption.txt'): os.remove(f'{class_images_dir}/caption.txt') if os.path.exists(f'{class_images_dir}/images.txt'): os.remove(f'{class_images_dir}/images.txt') for example in tqdm( sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process ): accelerator.wait_for_everyone() with open(f'{class_images_dir}/caption.txt', 'a') as f1, open(f'{class_images_dir}/images.txt', 'a') as f2: images = pipeline(example["prompt"], num_inference_steps=25, guidance_scale=6., eta=1.).images for i, image in enumerate(images): hash_image = hashlib.sha1( image.tobytes()).hexdigest() image_filename = class_images_dir / \ f"images/{example['index'][i]}-{hash_image}.jpg" image.save(image_filename) f2.write(str(image_filename)+'\n') f1.write('\n'.join(example["prompt"]) + '\n') accelerator.wait_for_everyone() del pipeline if args.concept_type == 'memorization': filter(class_images_dir, args.mem_impath, outpath=str(class_images_dir / 'filtered')) if os.path.exists(class_images_dir / 'caption_target.txt'): with open(class_images_dir / 'caption_target.txt', 'r') as f: concept['caption_target'] = f.readlines()[0].strip() class_images_dir = class_images_dir / 'filtered' concept['class_prompt'] = os.path.join( class_images_dir, 'caption.txt') concept['class_data_dir'] = os.path.join( class_images_dir, 'images.txt') concept['instance_prompt'] = os.path.join( class_images_dir, 'caption.txt') concept['instance_data_dir'] = os.path.join( class_images_dir, 'images.txt') if torch.cuda.is_available(): torch.cuda.empty_cache() # Handle the repository creation if accelerator.is_main_process: if args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) if args.push_to_hub: print(args.hub_model_id or Path(args.output_dir).name) repo_id = create_repo( repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token ) print(repo_id) repo_id = args.hub_model_id # Load the tokenizer if args.tokenizer_name: tokenizer = AutoTokenizer.from_pretrained( args.tokenizer_name, revision=args.revision, use_fast=False, ) elif args.pretrained_model_name_or_path: tokenizer = AutoTokenizer.from_pretrained( args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False, ) # import correct text encoder class text_encoder_cls = import_model_class_from_model_name_or_path( args.pretrained_model_name_or_path, args.revision) # Load scheduler and models noise_scheduler = DDPMScheduler.from_pretrained( args.pretrained_model_name_or_path, subfolder="scheduler") text_encoder = text_encoder_cls.from_pretrained( args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision ) vae = AutoencoderKL.from_pretrained( args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) unet = UNet2DConditionModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision ) vae.requires_grad_(False) if args.parameter_group != 'embedding': text_encoder.requires_grad_(False) unet = create_custom_diffusion(unet, args.parameter_group) # For mixed precision training we cast the text_encoder and vae weights to half-precision # as these models are only used for inference, keeping weights in full precision is not required. weight_dtype = torch.float32 if accelerator.mixed_precision == "fp16": weight_dtype = torch.float16 elif accelerator.mixed_precision == "bf16": weight_dtype = torch.bfloat16 # Move unet, vae and text_encoder to device and cast to weight_dtype if accelerator.mixed_precision != "fp16": unet.to(accelerator.device, dtype=weight_dtype) text_encoder.to(accelerator.device, dtype=weight_dtype) vae.to(accelerator.device, dtype=weight_dtype) if args.enable_xformers_memory_efficient_attention: if is_xformers_available(): import xformers xformers_version = version.parse(xformers.__version__) if xformers_version == version.parse("0.0.16"): logger.warn( "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." ) unet.enable_xformers_memory_efficient_attention() else: raise ValueError( "xformers is not available. Make sure it is installed correctly") if args.gradient_checkpointing: unet.enable_gradient_checkpointing() if args.parameter_group == 'embedding': text_encoder.gradient_checkpointing_enable() # Enable TF32 for faster training on Ampere GPUs, # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices if args.allow_tf32: torch.backends.cuda.matmul.allow_tf32 = True if args.scale_lr: args.learning_rate = ( args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes ) if args.with_prior_preservation: args.learning_rate = args.learning_rate * 2. # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs if args.use_8bit_adam: try: import bitsandbytes as bnb except ImportError: raise ImportError( "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." ) optimizer_class = bnb.optim.AdamW8bit else: optimizer_class = torch.optim.AdamW # Adding a modifier token which is optimized #### # Code taken from https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py modifier_token_id = [] if args.parameter_group == 'embedding': assert args.concept_type != 'memorization', "embedding finetuning is not supported for memorization" for concept in args.concept_list: # Convert the caption_target to ids token_ids = tokenizer.encode( [concept['caption_target']], add_special_tokens=False) print(token_ids) # Check if initializer_token is a single token or a sequence of tokens modifier_token_id += token_ids # Freeze all parameters except for the token embeddings in text encoder params_to_freeze = itertools.chain( text_encoder.text_model.encoder.parameters(), text_encoder.text_model.final_layer_norm.parameters(), text_encoder.text_model.embeddings.position_embedding.parameters(), ) freeze_params(params_to_freeze) params_to_optimize = itertools.chain( text_encoder.get_input_embeddings().parameters()) else: if args.parameter_group == 'cross-attn': params_to_optimize = itertools.chain([x[1] for x in unet.named_parameters() if ( 'attn2.to_k' in x[0] or 'attn2.to_v' in x[0])]) if args.parameter_group == 'full-weight': params_to_optimize = itertools.chain(unet.parameters()) # Optimizer creation optimizer = optimizer_class( params_to_optimize, lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay, eps=args.adam_epsilon, ) # Dataset and DataLoaders creation: train_dataset = CustomDiffusionDataset( concepts_list=args.concepts_list, concept_type=args.concept_type, tokenizer=tokenizer, with_prior_preservation=args.with_prior_preservation, size=args.resolution, center_crop=args.center_crop, num_class_images=args.num_class_images, hflip=args.hflip, aug=not args.noaug, ) train_dataloader = torch.utils.data.DataLoader( train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=lambda examples: collate_fn( examples, args.with_prior_preservation), num_workers=args.dataloader_num_workers, ) # Scheduler and math around the number of training steps. overrode_max_train_steps = False num_update_steps_per_epoch = math.ceil( len(train_dataloader) / args.gradient_accumulation_steps) if args.max_train_steps is None: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch overrode_max_train_steps = True lr_scheduler = get_scheduler( args.lr_scheduler, optimizer=optimizer, num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, ) # Prepare everything with our `accelerator`. if args.parameter_group == 'embedding': text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( text_encoder, optimizer, train_dataloader, lr_scheduler ) else: unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( unet, optimizer, train_dataloader, lr_scheduler ) # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil( len(train_dataloader) / args.gradient_accumulation_steps) if overrode_max_train_steps: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch # Afterwards we recalculate our number of training epochs args.num_train_epochs = math.ceil( args.max_train_steps / num_update_steps_per_epoch) # Train! total_batch_size = args.train_batch_size * \ accelerator.num_processes * args.gradient_accumulation_steps logger.info("***** Running training *****") logger.info(f" Num examples = {len(train_dataset)}") logger.info(f" Num batches each epoch = {len(train_dataloader)}") logger.info(f" Num Epochs = {args.num_train_epochs}") logger.info( f" Instantaneous batch size per device = {args.train_batch_size}") logger.info( f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") logger.info( f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") logger.info(f" Total optimization steps = {args.max_train_steps}") global_step = 0 first_epoch = 0 # Potentially load in the weights and states from a previous save if args.resume_from_checkpoint: if args.resume_from_checkpoint != "latest": path = os.path.basename(args.resume_from_checkpoint) else: # Get the mos recent checkpoint dirs = os.listdir(args.output_dir) dirs = [d for d in dirs if d.startswith("checkpoint")] dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) path = dirs[-1] if len(dirs) > 0 else None if path is None: accelerator.print( f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." ) args.resume_from_checkpoint = None else: accelerator.print(f"Resuming from checkpoint {path}") accelerator.load_state(os.path.join(args.output_dir, path)) global_step = int(path.split("-")[1]) resume_global_step = global_step * args.gradient_accumulation_steps first_epoch = global_step // num_update_steps_per_epoch resume_step = resume_global_step % ( num_update_steps_per_epoch * args.gradient_accumulation_steps) # Only show the progress bar once on each machine. progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) progress_bar.set_description("Steps") for epoch in range(first_epoch, args.num_train_epochs): if args.parameter_group == 'embedding': text_encoder.train() else: unet.train() for step, batch in enumerate(train_dataloader): # Skip steps until we reach the resumed step if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: if step % args.gradient_accumulation_steps == 0: progress_bar.update(1) continue with accelerator.accumulate(unet) if args.parameter_group != 'embedding' else accelerator.accumulate(text_encoder): # Convert images to latent space latents = vae.encode(batch["pixel_values"].to( dtype=weight_dtype)).latent_dist.sample() latents = latents * vae.config.scaling_factor # Sample noise that we'll add to the latents noise = torch.randn_like(latents) bsz = latents.shape[0] # Sample a random timestep for each image timesteps = torch.randint( 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) timesteps = timesteps.long() # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) noisy_latents = noise_scheduler.add_noise( latents, noise, timesteps) # Get the text embedding for conditioning encoder_hidden_states = text_encoder(batch["input_ids"])[0] encoder_anchor_hidden_states = text_encoder( batch["input_anchor_ids"])[0] # Predict the noise residual model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample with torch.no_grad(): model_pred_anchor = unet(noisy_latents[:encoder_anchor_hidden_states.size( 0)], timesteps[:encoder_anchor_hidden_states.size(0)], encoder_anchor_hidden_states).sample # Get the target for loss depending on the prediction type if args.loss_type_reverse == 'model-based': if args.with_prior_preservation: target_prior = torch.chunk(noise, 2, dim=0)[1] target = model_pred_anchor else: if noise_scheduler.config.prediction_type == "epsilon": target = noise elif noise_scheduler.config.prediction_type == "v_prediction": target = noise_scheduler.get_velocity( latents, noise, timesteps) else: raise ValueError( f"Unknown prediction type {noise_scheduler.config.prediction_type}") if args.with_prior_preservation: target, target_prior = torch.chunk(target, 2, dim=0) if args.with_prior_preservation: # Chunk the noise and model_pred into two parts and compute the loss on each part separately. model_pred, model_pred_prior = torch.chunk( model_pred, 2, dim=0) mask = torch.chunk(batch["mask"], 2, dim=0)[0] # Compute instance loss loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") loss = ( (loss * mask).sum([1, 2, 3]) / mask.sum([1, 2, 3])).mean() # Compute prior loss prior_loss = F.mse_loss( model_pred_prior.float(), target_prior.float(), reduction="mean") # Add the prior loss to the instance loss. loss = loss + args.prior_loss_weight * prior_loss else: mask = batch["mask"] loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") loss = ( (loss * mask).sum([1, 2, 3]) / mask.sum([1, 2, 3])).mean() accelerator.backward(loss) # Zero out the gradients for all token embeddings except the newly added # embeddings for the concept, as we only want to optimize the concept embeddings if args.parameter_group == 'embedding': if accelerator.num_processes > 1: grads_text_encoder = text_encoder.module.get_input_embeddings().weight.grad else: grads_text_encoder = text_encoder.get_input_embeddings().weight.grad # Get the index for tokens that we want to zero the grads for index_grads_to_zero = torch.arange( len(tokenizer)) != modifier_token_id[0] for i in range(len(modifier_token_id[1:])): index_grads_to_zero = index_grads_to_zero & ( torch.arange(len(tokenizer)) != modifier_token_id[i]) grads_text_encoder.data[index_grads_to_zero, :] = grads_text_encoder.data[index_grads_to_zero, :].fill_(0) if accelerator.sync_gradients: params_to_clip = ( itertools.chain(text_encoder.parameters()) if args.parameter_group == 'embedding' else itertools.chain([x[1] for x in unet.named_parameters() if ('attn2' in x[0])]) if args.parameter_group == 'cross-attn' else itertools.chain(unet.parameters()) ) accelerator.clip_grad_norm_( params_to_clip, args.max_grad_norm) optimizer.step() lr_scheduler.step() optimizer.zero_grad() # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: progress_bar.update(1) global_step += 1 if global_step % args.checkpointing_steps == 0: if accelerator.is_main_process: pipeline = CustomDiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, unet=accelerator.unwrap_model(unet), text_encoder=accelerator.unwrap_model( text_encoder), tokenizer=tokenizer, revision=args.revision, modifier_token_id=modifier_token_id, ) save_path = os.path.join( args.output_dir, f"delta-{global_step}") pipeline.save_pretrained( save_path, parameter_group=args.parameter_group) logger.info(f"Saved state to {save_path}") logs = {"loss": loss.detach().item( ), "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: break if accelerator.is_main_process: if args.validation_prompt is not None and global_step % args.validation_steps == 0: logger.info( f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f" {args.validation_prompt}." ) # create pipeline pipeline = CustomDiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, unet=accelerator.unwrap_model(unet), text_encoder=accelerator.unwrap_model(text_encoder), tokenizer=tokenizer, revision=args.revision, modifier_token_id=modifier_token_id, ) pipeline.scheduler = DPMSolverMultistepScheduler.from_config( pipeline.scheduler.config) pipeline = pipeline.to(accelerator.device) pipeline.set_progress_bar_config(disable=True) # run inference generator = torch.Generator( device=accelerator.device).manual_seed(args.seed) images = [ pipeline(args.validation_prompt, num_inference_steps=25, generator=generator, eta=1.).images[0] for _ in range(args.num_validation_images) ] for tracker in accelerator.trackers: if tracker.name == "tensorboard": np_images = np.stack([np.asarray(img) for img in images]) tracker.writer.add_images( "validation", np_images, epoch, dataformats="NHWC") if tracker.name == "wandb": tracker.log( { "validation": [ wandb.Image( image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) ] } ) del pipeline torch.cuda.empty_cache() accelerator.wait_for_everyone() if accelerator.is_main_process: unet = unet.to(torch.float32) pipeline = CustomDiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, unet=accelerator.unwrap_model(unet), text_encoder=accelerator.unwrap_model(text_encoder), tokenizer=tokenizer, revision=args.revision, modifier_token_id=modifier_token_id, ) save_path = os.path.join(args.output_dir, "delta.bin") pipeline.save_pretrained( save_path, parameter_group=args.parameter_group) # run inference if args.validation_prompt and args.num_validation_images > 0: pipeline.scheduler = DPMSolverMultistepScheduler.from_config( pipeline.scheduler.config) pipeline = pipeline.to(accelerator.device) pipeline.set_progress_bar_config(disable=True) # run inference generator = torch.Generator( device=accelerator.device).manual_seed(args.seed) images = [ pipeline(args.validation_prompt, num_inference_steps=25, generator=generator, eta=1.).images[0] for _ in range(args.num_validation_images) ] for tracker in accelerator.trackers: if tracker.name == "tensorboard": np_images = np.stack([np.asarray(img) for img in images]) tracker.writer.add_images( "test", np_images, epoch, dataformats="NHWC") if tracker.name == "wandb": tracker.log( { "test": [ wandb.Image( image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) ] } ) if args.push_to_hub: save_model_card( repo_id, images=images, base_model=args.pretrained_model_name_or_path, prompt=args.instance_prompt, repo_folder=args.output_dir, ) api = HfApi(token=args.hub_token) api.upload_folder( repo_id=repo_id, folder_path=args.output_dir, path_in_repo='.', repo_type='model' ) accelerator.end_training() if __name__ == "__main__": args = parse_args() main(args)