# SliderSpace/utils/model_util.py
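"""Model-loading utilities for SliderSpace sliders.

Helpers for loading Stable Diffusion v1/v2, SDXL, and Flux components
(tokenizers, text encoders, UNet/transformer, VAE, noise schedulers),
an image feature encoder (TinyCLIP or DINOv2), and for saving trained
slider (LoRA) network checkpoints.
"""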
from typing import Literal, Union, Optional
import torch, gc, os
from safetensors.torch import load_file
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection, T5TokenizerFast
from transformers import (
AutoModel,
CLIPModel,
CLIPProcessor,
)
from huggingface_hub import hf_hub_download
from diffusers import (
UNet2DConditionModel,
SchedulerMixin,
StableDiffusionPipeline,
StableDiffusionXLPipeline,
FluxPipeline,
AutoencoderKL,
FluxTransformer2DModel,
)
import copy
from diffusers.schedulers import (
DDIMScheduler,
DDPMScheduler,
LMSDiscreteScheduler,
EulerAncestralDiscreteScheduler,
FlowMatchEulerDiscreteScheduler,
)
from diffusers import LCMScheduler, AutoencoderTiny
import sys
sys.path.append('.')
from .flux_utils import *
TOKENIZER_V1_MODEL_NAME = "CompVis/stable-diffusion-v1-4"
TOKENIZER_V2_MODEL_NAME = "stabilityai/stable-diffusion-2-1"
AVAILABLE_SCHEDULERS = Literal["ddim", "ddpm", "lms", "euler_a"]
SDXL_TEXT_ENCODER_TYPE = Union[CLIPTextModel, CLIPTextModelWithProjection]
DIFFUSERS_CACHE_DIR = None # if you want to change the cache dir, change this
def load_diffusers_model(
pretrained_model_name_or_path: str,
v2: bool = False,
clip_skip: Optional[int] = None,
weight_dtype: torch.dtype = torch.float32,
) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel,]:
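    """Load the CLIP tokenizer, CLIP text encoder, and UNet for an SD v1.x/v2.x model in diffusers format.

    If clip_skip is given, the text encoder is loaded with its hidden layers trimmed accordingly.
    """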
    # The VAE is not needed here
if v2:
tokenizer = CLIPTokenizer.from_pretrained(
TOKENIZER_V2_MODEL_NAME,
subfolder="tokenizer",
torch_dtype=weight_dtype,
cache_dir=DIFFUSERS_CACHE_DIR,
)
text_encoder = CLIPTextModel.from_pretrained(
pretrained_model_name_or_path,
subfolder="text_encoder",
# default is clip skip 2
num_hidden_layers=24 - (clip_skip - 1) if clip_skip is not None else 23,
torch_dtype=weight_dtype,
cache_dir=DIFFUSERS_CACHE_DIR,
)
else:
tokenizer = CLIPTokenizer.from_pretrained(
TOKENIZER_V1_MODEL_NAME,
subfolder="tokenizer",
torch_dtype=weight_dtype,
cache_dir=DIFFUSERS_CACHE_DIR,
)
text_encoder = CLIPTextModel.from_pretrained(
pretrained_model_name_or_path,
subfolder="text_encoder",
num_hidden_layers=12 - (clip_skip - 1) if clip_skip is not None else 12,
torch_dtype=weight_dtype,
cache_dir=DIFFUSERS_CACHE_DIR,
)
unet = UNet2DConditionModel.from_pretrained(
pretrained_model_name_or_path,
subfolder="unet",
torch_dtype=weight_dtype,
cache_dir=DIFFUSERS_CACHE_DIR,
)
return tokenizer, text_encoder, unet
def load_checkpoint_model(
checkpoint_path: str,
v2: bool = False,
clip_skip: Optional[int] = None,
weight_dtype: torch.dtype = torch.float32,
) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel,]:
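    """Load the tokenizer, text encoder, and UNet from a single-file SD checkpoint (.ckpt/.safetensors)."""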
    pipe = StableDiffusionPipeline.from_single_file(
        checkpoint_path,
        upcast_attention=v2,
torch_dtype=weight_dtype,
cache_dir=DIFFUSERS_CACHE_DIR,
)
unet = pipe.unet
tokenizer = pipe.tokenizer
text_encoder = pipe.text_encoder
if clip_skip is not None:
if v2:
text_encoder.config.num_hidden_layers = 24 - (clip_skip - 1)
else:
text_encoder.config.num_hidden_layers = 12 - (clip_skip - 1)
del pipe
return tokenizer, text_encoder, unet
def load_models(
pretrained_model_name_or_path: str,
scheduler_name: AVAILABLE_SCHEDULERS,
v2: bool = False,
v_pred: bool = False,
weight_dtype: torch.dtype = torch.float32,
) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel, SchedulerMixin,]:
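    """Load tokenizer, text encoder, UNet, and a noise scheduler from either a single-file checkpoint or a diffusers folder."""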
if pretrained_model_name_or_path.endswith(
".ckpt"
) or pretrained_model_name_or_path.endswith(".safetensors"):
tokenizer, text_encoder, unet = load_checkpoint_model(
pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype
)
else: # diffusers
tokenizer, text_encoder, unet = load_diffusers_model(
pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype
)
    # The VAE is not needed here
scheduler = create_noise_scheduler(
scheduler_name,
prediction_type="v_prediction" if v_pred else "epsilon",
)
return tokenizer, text_encoder, unet, scheduler
def load_diffusers_model_xl(
pretrained_model_name_or_path: str,
weight_dtype: torch.dtype = torch.float32,
) -> tuple[list[CLIPTokenizer], list[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel,]:
    """Return ([tokenizer, tokenizer_2], [text_encoder, text_encoder_2], unet) for an SDXL model in diffusers format."""
tokenizers = [
CLIPTokenizer.from_pretrained(
pretrained_model_name_or_path,
subfolder="tokenizer",
torch_dtype=weight_dtype,
cache_dir=DIFFUSERS_CACHE_DIR,
),
CLIPTokenizer.from_pretrained(
pretrained_model_name_or_path,
subfolder="tokenizer_2",
torch_dtype=weight_dtype,
cache_dir=DIFFUSERS_CACHE_DIR,
pad_token_id=0, # same as open clip
),
]
text_encoders = [
CLIPTextModel.from_pretrained(
pretrained_model_name_or_path,
subfolder="text_encoder",
torch_dtype=weight_dtype,
cache_dir=DIFFUSERS_CACHE_DIR,
),
CLIPTextModelWithProjection.from_pretrained(
pretrained_model_name_or_path,
subfolder="text_encoder_2",
torch_dtype=weight_dtype,
cache_dir=DIFFUSERS_CACHE_DIR,
),
]
unet = UNet2DConditionModel.from_pretrained(
pretrained_model_name_or_path,
subfolder="unet",
torch_dtype=weight_dtype,
cache_dir=DIFFUSERS_CACHE_DIR,
)
return tokenizers, text_encoders, unet
def load_checkpoint_model_xl(
checkpoint_path: str,
weight_dtype: torch.dtype = torch.float32,
) -> tuple[list[CLIPTokenizer], list[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel,]:
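    """Load SDXL tokenizers, text encoders, and UNet from a single-file checkpoint (.ckpt/.safetensors)."""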
pipe = StableDiffusionXLPipeline.from_single_file(
checkpoint_path,
torch_dtype=weight_dtype,
cache_dir=DIFFUSERS_CACHE_DIR,
)
unet = pipe.unet
tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
if len(text_encoders) == 2:
text_encoders[1].pad_token_id = 0
del pipe
return tokenizers, text_encoders, unet
def load_models_xl_(
pretrained_model_name_or_path: str,
scheduler_name: AVAILABLE_SCHEDULERS,
weight_dtype: torch.dtype = torch.float32,
) -> tuple[
list[CLIPTokenizer],
list[SDXL_TEXT_ENCODER_TYPE],
UNet2DConditionModel,
SchedulerMixin,
]:
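    """Load SDXL tokenizers, text encoders, UNet, and a noise scheduler from a checkpoint file or a diffusers folder."""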
if pretrained_model_name_or_path.endswith(
".ckpt"
) or pretrained_model_name_or_path.endswith(".safetensors"):
(
tokenizers,
text_encoders,
unet,
) = load_checkpoint_model_xl(pretrained_model_name_or_path, weight_dtype)
else: # diffusers
(
tokenizers,
text_encoders,
unet,
) = load_diffusers_model_xl(pretrained_model_name_or_path, weight_dtype)
scheduler = create_noise_scheduler(scheduler_name)
return tokenizers, text_encoders, unet, scheduler
def create_noise_scheduler(
scheduler_name: AVAILABLE_SCHEDULERS = "ddpm",
prediction_type: Literal["epsilon", "v_prediction"] = "epsilon",
) -> SchedulerMixin:
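    """Build a DDIM/DDPM/LMS/Euler-Ancestral scheduler with the standard SD scaled_linear beta schedule."""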
    # Honestly, not sure which one is best. The original implementation let you pick DDIM, DDPM, or LMS,
    # but it is unclear which works best.
name = scheduler_name.lower().replace(" ", "_")
if name == "ddim":
# https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddim
scheduler = DDIMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
num_train_timesteps=1000,
clip_sample=False,
            prediction_type=prediction_type,  # is this right?
)
elif name == "ddpm":
# https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddpm
scheduler = DDPMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
num_train_timesteps=1000,
clip_sample=False,
prediction_type=prediction_type,
)
elif name == "lms":
# https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/lms_discrete
scheduler = LMSDiscreteScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
num_train_timesteps=1000,
prediction_type=prediction_type,
)
elif name == "euler_a":
# https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/euler_ancestral
scheduler = EulerAncestralDiscreteScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
num_train_timesteps=1000,
# clip_sample=False,
prediction_type=prediction_type,
)
else:
raise ValueError(f"Unknown scheduler name: {name}")
return scheduler
def load_models_xl(params):
"""
Load all required models for training
Args:
params: Dictionary containing model parameters and configurations
Returns:
dict: Dictionary containing all loaded models and tokenizers
"""
device = params['device']
weight_dtype = params['weight_dtype']
# Load SDXL components (UNet, text encoders, tokenizers)
scheduler_name = 'ddim'
tokenizers, text_encoders, unet, noise_scheduler = load_models_xl_(
params['pretrained_model_name_or_path'],
scheduler_name=scheduler_name,
)
# Move text encoders to device and set to eval mode
for text_encoder in text_encoders:
text_encoder.to(device, dtype=weight_dtype)
text_encoder.requires_grad_(False)
text_encoder.eval()
# Set up UNet
unet.to(device, dtype=weight_dtype)
unet.requires_grad_(False)
unet.eval()
# Load tiny VAE for efficiency
vae = AutoencoderTiny.from_pretrained(
"madebyollin/taesdxl",
torch_dtype=weight_dtype
)
vae = vae.to(device, dtype=weight_dtype)
vae.requires_grad_(False)
# Load appropriate encoder (CLIP or DinoV2)
if params['encoder'] == 'dinov2-small':
clip_model = AutoModel.from_pretrained(
'facebook/dinov2-small',
torch_dtype=weight_dtype
)
        clip_processor = None
else:
clip_model = CLIPModel.from_pretrained(
"wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M",
torch_dtype=weight_dtype
)
clip_processor = CLIPProcessor.from_pretrained("wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M")
clip_model = clip_model.to(device, dtype=weight_dtype)
clip_model.requires_grad_(False)
# If using DMD checkpoint, load it
if params['distilled'] != 'None':
if '.safetensors' in params['distilled']:
unet.load_state_dict(load_file(params['distilled'], device=device))
elif 'dmd2' in params['distilled']:
repo_name = "tianweiy/DMD2"
ckpt_name = "dmd2_sdxl_4step_unet_fp16.bin"
unet.load_state_dict(torch.load(hf_hub_download(repo_name, ckpt_name)))
else:
unet.load_state_dict(torch.load(params['distilled']))
# Set up LCM scheduler for DMD
noise_scheduler = LCMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
num_train_timesteps=1000,
prediction_type="epsilon",
original_inference_steps=1000
)
noise_scheduler.set_timesteps(params['max_denoising_steps'])
    pipe = StableDiffusionXLPipeline(
        vae=vae,
        text_encoder=text_encoders[0],
        text_encoder_2=text_encoders[1],
        tokenizer=tokenizers[0],
        tokenizer_2=tokenizers[1],
        unet=unet,
        scheduler=noise_scheduler,
    )
pipe.set_progress_bar_config(disable=True)
return {
'unet': unet,
'vae': vae,
'clip_model': clip_model,
'clip_processor': clip_processor,
'tokenizers': tokenizers,
'text_encoders': text_encoders,
'noise_scheduler': noise_scheduler
}, pipe
def load_models_flux(params):
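    """Load Flux components (tokenizers, text encoders, VAE, transformer, flow-matching scheduler) plus a
    CLIP/DINOv2 feature encoder, and return them alongside an assembled FluxPipeline."""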
# Load the tokenizers
tokenizer_one = CLIPTokenizer.from_pretrained(
params['pretrained_model_name_or_path'],
subfolder="tokenizer",
torch_dtype=params['weight_dtype'], device_map=params['device']
)
tokenizer_two = T5TokenizerFast.from_pretrained(
params['pretrained_model_name_or_path'],
subfolder="tokenizer_2",
torch_dtype=params['weight_dtype'], device_map=params['device']
)
# Load scheduler and models
noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
params['pretrained_model_name_or_path'],
subfolder="scheduler",
torch_dtype=params['weight_dtype'], device=params['device']
)
noise_scheduler_copy = copy.deepcopy(noise_scheduler)
# import correct text encoder classes
text_encoder_cls_one = import_model_class_from_model_name_or_path(
params['pretrained_model_name_or_path'],
)
text_encoder_cls_two = import_model_class_from_model_name_or_path(
params['pretrained_model_name_or_path'], subfolder="text_encoder_2"
)
# Load the text encoders
text_encoder_one, text_encoder_two = load_text_encoders(params['pretrained_model_name_or_path'], text_encoder_cls_one, text_encoder_cls_two, params['weight_dtype'])
# Load VAE
vae = AutoencoderKL.from_pretrained(
params['pretrained_model_name_or_path'],
subfolder="vae",
torch_dtype=params['weight_dtype'], device_map='auto'
)
transformer = FluxTransformer2DModel.from_pretrained(
params['pretrained_model_name_or_path'],
subfolder="transformer",
torch_dtype=params['weight_dtype']
)
# We only train the additional adapter LoRA layers
transformer.requires_grad_(False)
vae.requires_grad_(False)
text_encoder_one.requires_grad_(False)
text_encoder_two.requires_grad_(False)
vae.to(params['device'])
transformer.to(params['device'])
text_encoder_one.to(params['device'])
text_encoder_two.to(params['device'])
# Load appropriate encoder (CLIP or DinoV2)
if params['encoder'] == 'dinov2-small':
clip_model = AutoModel.from_pretrained(
'facebook/dinov2-small',
torch_dtype=params['weight_dtype']
)
        clip_processor = None
else:
clip_model = CLIPModel.from_pretrained(
"wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M",
torch_dtype=params['weight_dtype']
)
clip_processor = CLIPProcessor.from_pretrained("wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M")
clip_model = clip_model.to(params['device'], dtype=params['weight_dtype'])
clip_model.requires_grad_(False)
    pipe = FluxPipeline(
        scheduler=noise_scheduler,
        vae=vae,
        text_encoder=text_encoder_one,
        tokenizer=tokenizer_one,
        text_encoder_2=text_encoder_two,
        tokenizer_2=tokenizer_two,
        transformer=transformer,
    )
pipe.set_progress_bar_config(disable=True)
return {
'transformer': transformer,
'vae': vae,
'clip_model': clip_model,
'clip_processor': clip_processor,
'tokenizers': [tokenizer_one, tokenizer_two],
'text_encoders': [text_encoder_one,text_encoder_two],
'noise_scheduler': noise_scheduler
}, pipe
def save_checkpoint(networks, save_path, weight_dtype):
"""
Save network weights and perform cleanup
Args:
networks: Dictionary of LoRA networks to save
save_path: Path to save the checkpoints
weight_dtype: Data type for the weights
"""
print("Saving checkpoint...")
try:
# Create save directory if it doesn't exist
os.makedirs(save_path, exist_ok=True)
# Save each network's weights
for net_idx, network in networks.items():
save_name = f"{save_path}/slider_{net_idx}.pt"
try:
network.save_weights(
save_name,
dtype=weight_dtype,
)
except Exception as e:
print(f"Error saving network {net_idx}: {str(e)}")
continue
# Cleanup
torch.cuda.empty_cache()
gc.collect()
print("Checkpoint saved successfully.")
except Exception as e:
print(f"Error during checkpoint saving: {str(e)}")
finally:
# Ensure memory is cleaned up even if save fails
torch.cuda.empty_cache()
gc.collect()
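

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): the params keys below are
    # assumptions inferred from how load_models_xl reads its `params` dictionary above.
    example_params = {
        "device": "cuda" if torch.cuda.is_available() else "cpu",
        "weight_dtype": torch.float16,
        "pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl-base-1.0",
        "encoder": "clip",        # anything other than 'dinov2-small' falls back to TinyCLIP
        "distilled": "None",      # set to a .safetensors path or 'dmd2' to load a distilled UNet
        "max_denoising_steps": 4, # only used when a distilled checkpoint is loaded
    }
    models, pipe = load_models_xl(example_params)
    print({name: type(module).__name__ for name, module in models.items()})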