from typing import Literal, Union, Optional

import torch, gc, os
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection, T5TokenizerFast
from transformers import (
    AutoModel,
    CLIPModel,
    CLIPProcessor,
)
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file  # used below to load distilled .safetensors UNet weights
from diffusers import (
    UNet2DConditionModel,
    SchedulerMixin,
    StableDiffusionPipeline,
    StableDiffusionXLPipeline,
    FluxPipeline,
    AutoencoderKL,
    FluxTransformer2DModel,
)
import copy
from diffusers.schedulers import (
    DDIMScheduler,
    DDPMScheduler,
    LMSDiscreteScheduler,
    EulerAncestralDiscreteScheduler,
    FlowMatchEulerDiscreteScheduler,
)
from diffusers import LCMScheduler, AutoencoderTiny

import sys

sys.path.append('.')
from .flux_utils import *  # expected to provide import_model_class_from_model_name_or_path and load_text_encoders

TOKENIZER_V1_MODEL_NAME = "CompVis/stable-diffusion-v1-4"
TOKENIZER_V2_MODEL_NAME = "stabilityai/stable-diffusion-2-1"

AVAILABLE_SCHEDULERS = Literal["ddim", "ddpm", "lms", "euler_a"]

SDXL_TEXT_ENCODER_TYPE = Union[CLIPTextModel, CLIPTextModelWithProjection]

DIFFUSERS_CACHE_DIR = None  # if you want to change the cache dir, change this


def load_diffusers_model(
    pretrained_model_name_or_path: str,
    v2: bool = False,
    clip_skip: Optional[int] = None,
    weight_dtype: torch.dtype = torch.float32,
) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel]:
    # The VAE is not needed here.
    if v2:
        tokenizer = CLIPTokenizer.from_pretrained(
            TOKENIZER_V2_MODEL_NAME,
            subfolder="tokenizer",
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
        )
        text_encoder = CLIPTextModel.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="text_encoder",
            # default is clip skip 2
            num_hidden_layers=24 - (clip_skip - 1) if clip_skip is not None else 23,
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
        )
    else:
        tokenizer = CLIPTokenizer.from_pretrained(
            TOKENIZER_V1_MODEL_NAME,
            subfolder="tokenizer",
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
        )
        text_encoder = CLIPTextModel.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="text_encoder",
            num_hidden_layers=12 - (clip_skip - 1) if clip_skip is not None else 12,
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
        )

    unet = UNet2DConditionModel.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="unet",
        torch_dtype=weight_dtype,
        cache_dir=DIFFUSERS_CACHE_DIR,
    )

    return tokenizer, text_encoder, unet


def load_checkpoint_model(
    checkpoint_path: str,
    v2: bool = False,
    clip_skip: Optional[int] = None,
    weight_dtype: torch.dtype = torch.float32,
) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel]:
    # `from_ckpt` was removed from recent diffusers releases; `from_single_file`
    # is the supported way to load a single .ckpt/.safetensors checkpoint.
    pipe = StableDiffusionPipeline.from_single_file(
        checkpoint_path,
        upcast_attention=True if v2 else False,
        torch_dtype=weight_dtype,
        cache_dir=DIFFUSERS_CACHE_DIR,
    )

    unet = pipe.unet
    tokenizer = pipe.tokenizer
    text_encoder = pipe.text_encoder
    if clip_skip is not None:
        if v2:
            text_encoder.config.num_hidden_layers = 24 - (clip_skip - 1)
        else:
            text_encoder.config.num_hidden_layers = 12 - (clip_skip - 1)

    del pipe

    return tokenizer, text_encoder, unet
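
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, kept as a comment so nothing runs at import
# time). The model id below is just an example path:
#
#   tokenizer, text_encoder, unet = load_diffusers_model(
#       "CompVis/stable-diffusion-v1-4",   # any SD v1.x diffusers repo or local path
#       v2=False,
#       clip_skip=2,                       # keeps 12 - (2 - 1) = 11 hidden layers
#       weight_dtype=torch.float16,
#   )
# ---------------------------------------------------------------------------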


def load_models(
    pretrained_model_name_or_path: str,
    scheduler_name: AVAILABLE_SCHEDULERS,
    v2: bool = False,
    v_pred: bool = False,
    weight_dtype: torch.dtype = torch.float32,
) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel, SchedulerMixin]:
    if pretrained_model_name_or_path.endswith(
        ".ckpt"
    ) or pretrained_model_name_or_path.endswith(".safetensors"):
        tokenizer, text_encoder, unet = load_checkpoint_model(
            pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype
        )
    else:  # diffusers
        tokenizer, text_encoder, unet = load_diffusers_model(
            pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype
        )  # The VAE is not needed here.

    scheduler = create_noise_scheduler(
        scheduler_name,
        prediction_type="v_prediction" if v_pred else "epsilon",
    )

    return tokenizer, text_encoder, unet, scheduler


def load_diffusers_model_xl(
    pretrained_model_name_or_path: str,
    weight_dtype: torch.dtype = torch.float32,
) -> tuple[list[CLIPTokenizer], list[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel]:
    # returns [tokenizer, tokenizer_2], [text_encoder, text_encoder_2], unet
    tokenizers = [
        CLIPTokenizer.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="tokenizer",
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
        ),
        CLIPTokenizer.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="tokenizer_2",
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
            pad_token_id=0,  # same as open clip
        ),
    ]

    text_encoders = [
        CLIPTextModel.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="text_encoder",
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
        ),
        CLIPTextModelWithProjection.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="text_encoder_2",
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
        ),
    ]

    unet = UNet2DConditionModel.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="unet",
        torch_dtype=weight_dtype,
        cache_dir=DIFFUSERS_CACHE_DIR,
    )

    return tokenizers, text_encoders, unet


def load_checkpoint_model_xl(
    checkpoint_path: str,
    weight_dtype: torch.dtype = torch.float32,
) -> tuple[list[CLIPTokenizer], list[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel]:
    pipe = StableDiffusionXLPipeline.from_single_file(
        checkpoint_path,
        torch_dtype=weight_dtype,
        cache_dir=DIFFUSERS_CACHE_DIR,
    )

    unet = pipe.unet
    tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
    text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
    if len(text_encoders) == 2:
        text_encoders[1].pad_token_id = 0

    del pipe

    return tokenizers, text_encoders, unet


def load_models_xl_(
    pretrained_model_name_or_path: str,
    scheduler_name: AVAILABLE_SCHEDULERS,
    weight_dtype: torch.dtype = torch.float32,
) -> tuple[
    list[CLIPTokenizer],
    list[SDXL_TEXT_ENCODER_TYPE],
    UNet2DConditionModel,
    SchedulerMixin,
]:
    if pretrained_model_name_or_path.endswith(
        ".ckpt"
    ) or pretrained_model_name_or_path.endswith(".safetensors"):
        (
            tokenizers,
            text_encoders,
            unet,
        ) = load_checkpoint_model_xl(pretrained_model_name_or_path, weight_dtype)
    else:  # diffusers
        (
            tokenizers,
            text_encoders,
            unet,
        ) = load_diffusers_model_xl(pretrained_model_name_or_path, weight_dtype)

    scheduler = create_noise_scheduler(scheduler_name)

    return tokenizers, text_encoders, unet, scheduler
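
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; the SDXL repo id is an assumption):
#
#   tokenizers, text_encoders, unet, scheduler = load_models_xl_(
#       "stabilityai/stable-diffusion-xl-base-1.0",
#       scheduler_name="ddim",
#       weight_dtype=torch.float16,
#   )
#   # A local .safetensors checkpoint also works; it is routed through
#   # load_checkpoint_model_xl instead of load_diffusers_model_xl.
# ---------------------------------------------------------------------------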


def create_noise_scheduler(
    scheduler_name: AVAILABLE_SCHEDULERS = "ddpm",
    prediction_type: Literal["epsilon", "v_prediction"] = "epsilon",
) -> SchedulerMixin:
    # Honestly, it is unclear which scheduler is best. The original implementation
    # allowed choosing between DDIM, DDPM, and LMS.
    name = scheduler_name.lower().replace(" ", "_")
    if name == "ddim":
        # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddim
        scheduler = DDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            num_train_timesteps=1000,
            clip_sample=False,
            prediction_type=prediction_type,  # is this the right choice?
        )
    elif name == "ddpm":
        # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddpm
        scheduler = DDPMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            num_train_timesteps=1000,
            clip_sample=False,
            prediction_type=prediction_type,
        )
    elif name == "lms":
        # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/lms_discrete
        scheduler = LMSDiscreteScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            num_train_timesteps=1000,
            prediction_type=prediction_type,
        )
    elif name == "euler_a":
        # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/euler_ancestral
        scheduler = EulerAncestralDiscreteScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            num_train_timesteps=1000,
            # clip_sample=False,
            prediction_type=prediction_type,
        )
    else:
        raise ValueError(f"Unknown scheduler name: {name}")

    return scheduler


def load_models_xl(params):
    """
    Load all required SDXL models for training.

    Args:
        params: Dictionary containing model parameters and configurations

    Returns:
        tuple: (dict of loaded models and tokenizers, StableDiffusionXLPipeline)
    """
    device = params['device']
    weight_dtype = params['weight_dtype']

    # Load SDXL components (UNet, text encoders, tokenizers)
    scheduler_name = 'ddim'
    tokenizers, text_encoders, unet, noise_scheduler = load_models_xl_(
        params['pretrained_model_name_or_path'],
        scheduler_name=scheduler_name,
    )

    # Move text encoders to device and set to eval mode
    for text_encoder in text_encoders:
        text_encoder.to(device, dtype=weight_dtype)
        text_encoder.requires_grad_(False)
        text_encoder.eval()

    # Set up UNet
    unet.to(device, dtype=weight_dtype)
    unet.requires_grad_(False)
    unet.eval()

    # Load tiny VAE for efficiency
    vae = AutoencoderTiny.from_pretrained(
        "madebyollin/taesdxl",
        torch_dtype=weight_dtype,
    )
    vae = vae.to(device, dtype=weight_dtype)
    vae.requires_grad_(False)

    # Load appropriate encoder (CLIP or DINOv2)
    if params['encoder'] == 'dinov2-small':
        clip_model = AutoModel.from_pretrained(
            'facebook/dinov2-small',
            torch_dtype=weight_dtype,
        )
        clip_processor = None
    else:
        clip_model = CLIPModel.from_pretrained(
            "wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M",
            torch_dtype=weight_dtype,
        )
        clip_processor = CLIPProcessor.from_pretrained("wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M")
    clip_model = clip_model.to(device, dtype=weight_dtype)
    clip_model.requires_grad_(False)

    # If using a distilled (DMD) checkpoint, load it
    if params['distilled'] != 'None':
        if '.safetensors' in params['distilled']:
            unet.load_state_dict(load_file(params['distilled'], device=device))
        elif 'dmd2' in params['distilled']:
            repo_name = "tianweiy/DMD2"
            ckpt_name = "dmd2_sdxl_4step_unet_fp16.bin"
            unet.load_state_dict(torch.load(hf_hub_download(repo_name, ckpt_name)))
        else:
            unet.load_state_dict(torch.load(params['distilled']))

        # Set up LCM scheduler for DMD
        noise_scheduler = LCMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            num_train_timesteps=1000,
            prediction_type="epsilon",
            original_inference_steps=1000,
        )
        noise_scheduler.set_timesteps(params['max_denoising_steps'])

    pipe = StableDiffusionXLPipeline(
        vae=vae,
        text_encoder=text_encoders[0],
        text_encoder_2=text_encoders[1],
        tokenizer=tokenizers[0],
        tokenizer_2=tokenizers[1],
        unet=unet,
        scheduler=noise_scheduler,
    )
    pipe.set_progress_bar_config(disable=True)

    return {
        'unet': unet,
        'vae': vae,
        'clip_model': clip_model,
        'clip_processor': clip_processor,
        'tokenizers': tokenizers,
        'text_encoders': text_encoders,
        'noise_scheduler': noise_scheduler,
    }, pipe
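
# ---------------------------------------------------------------------------
# Usage sketch for load_models_xl (illustrative only). `params` is assumed to
# be a plain dict; the values below are examples, not defaults of this module:
#
#   params = {
#       'device': 'cuda',
#       'weight_dtype': torch.float16,
#       'pretrained_model_name_or_path': 'stabilityai/stable-diffusion-xl-base-1.0',
#       'encoder': 'clip',          # anything other than 'dinov2-small' selects TinyCLIP
#       'distilled': 'dmd2',        # or 'None' to keep the base UNet
#       'max_denoising_steps': 4,
#   }
#   models, pipe = load_models_xl(params)
#   unet = models['unet']
# ---------------------------------------------------------------------------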


def load_models_flux(params):
    """
    Load all required Flux models for training.

    Args:
        params: Dictionary containing model parameters and configurations

    Returns:
        tuple: (dict of loaded models and tokenizers, FluxPipeline)
    """
    # Load the tokenizers
    tokenizer_one = CLIPTokenizer.from_pretrained(
        params['pretrained_model_name_or_path'],
        subfolder="tokenizer",
        torch_dtype=params['weight_dtype'],
        device_map=params['device'],
    )
    tokenizer_two = T5TokenizerFast.from_pretrained(
        params['pretrained_model_name_or_path'],
        subfolder="tokenizer_2",
        torch_dtype=params['weight_dtype'],
        device_map=params['device'],
    )

    # Load scheduler and models
    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
        params['pretrained_model_name_or_path'],
        subfolder="scheduler",
        torch_dtype=params['weight_dtype'],
        device=params['device'],
    )
    noise_scheduler_copy = copy.deepcopy(noise_scheduler)

    # Import the correct text encoder classes (helpers expected from flux_utils)
    text_encoder_cls_one = import_model_class_from_model_name_or_path(
        params['pretrained_model_name_or_path'],
    )
    text_encoder_cls_two = import_model_class_from_model_name_or_path(
        params['pretrained_model_name_or_path'],
        subfolder="text_encoder_2",
    )

    # Load the text encoders
    text_encoder_one, text_encoder_two = load_text_encoders(
        params['pretrained_model_name_or_path'],
        text_encoder_cls_one,
        text_encoder_cls_two,
        params['weight_dtype'],
    )

    # Load VAE
    vae = AutoencoderKL.from_pretrained(
        params['pretrained_model_name_or_path'],
        subfolder="vae",
        torch_dtype=params['weight_dtype'],
        device_map='auto',
    )

    transformer = FluxTransformer2DModel.from_pretrained(
        params['pretrained_model_name_or_path'],
        subfolder="transformer",
        torch_dtype=params['weight_dtype'],
    )

    # We only train the additional adapter LoRA layers
    transformer.requires_grad_(False)
    vae.requires_grad_(False)
    text_encoder_one.requires_grad_(False)
    text_encoder_two.requires_grad_(False)

    vae.to(params['device'])
    transformer.to(params['device'])
    text_encoder_one.to(params['device'])
    text_encoder_two.to(params['device'])

    # Load appropriate encoder (CLIP or DINOv2)
    if params['encoder'] == 'dinov2-small':
        clip_model = AutoModel.from_pretrained(
            'facebook/dinov2-small',
            torch_dtype=params['weight_dtype'],
        )
        clip_processor = None
    else:
        clip_model = CLIPModel.from_pretrained(
            "wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M",
            torch_dtype=params['weight_dtype'],
        )
        clip_processor = CLIPProcessor.from_pretrained("wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M")
    clip_model = clip_model.to(params['device'], dtype=params['weight_dtype'])
    clip_model.requires_grad_(False)

    pipe = FluxPipeline(
        noise_scheduler,
        vae,
        text_encoder_one,
        tokenizer_one,
        text_encoder_two,
        tokenizer_two,
        transformer,
    )
    pipe.set_progress_bar_config(disable=True)

    return {
        'transformer': transformer,
        'vae': vae,
        'clip_model': clip_model,
        'clip_processor': clip_processor,
        'tokenizers': [tokenizer_one, tokenizer_two],
        'text_encoders': [text_encoder_one, text_encoder_two],
        'noise_scheduler': noise_scheduler,
    }, pipe
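
# ---------------------------------------------------------------------------
# Usage sketch for load_models_flux (illustrative only; the Flux repo id is an
# assumption):
#
#   params = {
#       'device': 'cuda',
#       'weight_dtype': torch.bfloat16,
#       'pretrained_model_name_or_path': 'black-forest-labs/FLUX.1-dev',
#       'encoder': 'dinov2-small',
#   }
#   models, pipe = load_models_flux(params)
#   transformer = models['transformer']
# ---------------------------------------------------------------------------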


def save_checkpoint(networks, save_path, weight_dtype):
    """
    Save network weights and perform cleanup.

    Args:
        networks: Dictionary of LoRA networks to save
        save_path: Path to save the checkpoints
        weight_dtype: Data type for the weights
    """
    print("Saving checkpoint...")
    try:
        # Create save directory if it doesn't exist
        os.makedirs(save_path, exist_ok=True)

        # Save each network's weights
        for net_idx, network in networks.items():
            save_name = f"{save_path}/slider_{net_idx}.pt"
            try:
                network.save_weights(
                    save_name,
                    dtype=weight_dtype,
                )
            except Exception as e:
                print(f"Error saving network {net_idx}: {str(e)}")
                continue

        # Cleanup
        torch.cuda.empty_cache()
        gc.collect()
        print("Checkpoint saved successfully.")

    except Exception as e:
        print(f"Error during checkpoint saving: {str(e)}")
    finally:
        # Ensure memory is cleaned up even if saving fails
        torch.cuda.empty_cache()
        gc.collect()
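
# ---------------------------------------------------------------------------
# Usage sketch for save_checkpoint (illustrative only). `networks` is assumed
# to be a dict of LoRA-style objects exposing save_weights(path, dtype=...),
# as used above; `slider_network` is a hypothetical instance:
#
#   save_checkpoint(
#       networks={0: slider_network},
#       save_path="./checkpoints/step_1000",
#       weight_dtype=torch.float16,
#   )
# ---------------------------------------------------------------------------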