# Bootstrapped from:
# https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py
import argparse
import hashlib
import inspect
import itertools
import math
import os
import random
import re
from pathlib import Path
from typing import Optional, List, Literal

import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.checkpoint
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from huggingface_hub import HfFolder, Repository, whoami
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

import wandb
import fire

from lora_diffusion import (
    PivotalTuningDatasetCapation,
    extract_lora_ups_down,
    inject_trainable_lora,
    inject_trainable_lora_extended,
    inspect_lora,
    save_lora_weight,
    save_all,
    prepare_clip_model_sets,
    evaluate_pipe,
    UNET_EXTENDED_TARGET_REPLACE,
)
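

# Pivotal Tuning Inversion (PTI) runs in two phases: first a textual-inversion
# phase that learns embeddings for newly added placeholder tokens
# (train_inversion), then a LoRA fine-tuning phase for the UNet and, optionally,
# the CLIP text encoder (perform_tuning). get_models loads the Stable Diffusion
# components and registers the placeholder tokens in the tokenizer.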
def get_models(
    pretrained_model_name_or_path,
    pretrained_vae_name_or_path,
    revision,
    placeholder_tokens: List[str],
    initializer_tokens: List[str],
    device="cuda:0",
):
    tokenizer = CLIPTokenizer.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="tokenizer",
        revision=revision,
    )

    text_encoder = CLIPTextModel.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=revision,
    )

    placeholder_token_ids = []

    for token, init_tok in zip(placeholder_tokens, initializer_tokens):
        num_added_tokens = tokenizer.add_tokens(token)
        if num_added_tokens == 0:
            raise ValueError(
                f"The tokenizer already contains the token {token}. Please pass a different"
                " `placeholder_token` that is not already in the tokenizer."
            )

        placeholder_token_id = tokenizer.convert_tokens_to_ids(token)
        placeholder_token_ids.append(placeholder_token_id)

        # Load models and create wrapper for stable diffusion
        text_encoder.resize_token_embeddings(len(tokenizer))
        token_embeds = text_encoder.get_input_embeddings().weight.data

        if init_tok.startswith("<rand"):
            # <rand-"sigma">, e.g. <rand-0.5>
            sigma_val = float(re.findall(r"<rand-(.*)>", init_tok)[0])

            token_embeds[placeholder_token_id] = (
                torch.randn_like(token_embeds[0]) * sigma_val
            )
            print(
                f"Initialized {token} with random noise (sigma={sigma_val}), empirically {token_embeds[placeholder_token_id].mean().item():.3f} +- {token_embeds[placeholder_token_id].std().item():.3f}"
            )
            print(f"Norm : {token_embeds[placeholder_token_id].norm():.4f}")
        elif init_tok == "<zero>":
            token_embeds[placeholder_token_id] = torch.zeros_like(token_embeds[0])
        else:
            token_ids = tokenizer.encode(init_tok, add_special_tokens=False)
            # Check if initializer_token is a single token or a sequence of tokens
            if len(token_ids) > 1:
                raise ValueError("The initializer token must be a single token.")

            initializer_token_id = token_ids[0]
            token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]

    vae = AutoencoderKL.from_pretrained(
        pretrained_vae_name_or_path or pretrained_model_name_or_path,
        subfolder=None if pretrained_vae_name_or_path else "vae",
        revision=None if pretrained_vae_name_or_path else revision,
    )
    unet = UNet2DConditionModel.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="unet",
        revision=revision,
    )

    return (
        text_encoder.to(device),
        vae.to(device),
        unet.to(device),
        tokenizer,
        placeholder_token_ids,
    )
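

# Text-to-image dataloader. With cached_latents, each training image is pushed
# through the VAE once here and its latent replaces the pixel tensor, so the
# VAE does not need to be re-run (or kept in memory) during training.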
def text2img_dataloader(
    train_dataset,
    train_batch_size,
    tokenizer,
    vae,
    text_encoder,
    cached_latents: bool = False,
):
    if cached_latents:
        cached_latents_dataset = []
        for idx in tqdm(range(len(train_dataset))):
            batch = train_dataset[idx]
            # print(batch)
            latents = vae.encode(
                batch["instance_images"].unsqueeze(0).to(dtype=vae.dtype).to(vae.device)
            ).latent_dist.sample()
            latents = latents * 0.18215
            batch["instance_images"] = latents.squeeze(0)
            cached_latents_dataset.append(batch)

    def collate_fn(examples):
        input_ids = [example["instance_prompt_ids"] for example in examples]
        pixel_values = [example["instance_images"] for example in examples]
        pixel_values = torch.stack(pixel_values)
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

        input_ids = tokenizer.pad(
            {"input_ids": input_ids},
            padding="max_length",
            max_length=tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids

        batch = {
            "input_ids": input_ids,
            "pixel_values": pixel_values,
        }

        if examples[0].get("mask", None) is not None:
            batch["mask"] = torch.stack([example["mask"] for example in examples])

        return batch

    if cached_latents:
        train_dataloader = torch.utils.data.DataLoader(
            cached_latents_dataset,
            batch_size=train_batch_size,
            shuffle=True,
            collate_fn=collate_fn,
        )
        print("PTI : Using cached latents.")
    else:
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=train_batch_size,
            shuffle=True,
            collate_fn=collate_fn,
        )

    return train_dataloader
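

# Dataloader for the inpainting variant: alongside the pixel values, every
# batch carries the binary masks and masked images that loss_step concatenates
# onto the noisy latents as the extra conditioning channels of the inpainting
# UNet.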
def inpainting_dataloader(
    train_dataset, train_batch_size, tokenizer, vae, text_encoder
):
    def collate_fn(examples):
        input_ids = [example["instance_prompt_ids"] for example in examples]
        pixel_values = [example["instance_images"] for example in examples]
        mask_values = [example["instance_masks"] for example in examples]
        masked_image_values = [
            example["instance_masked_images"] for example in examples
        ]

        # Concat class and instance examples for prior preservation.
        # We do this to avoid doing two forward passes.
        if examples[0].get("class_prompt_ids", None) is not None:
            input_ids += [example["class_prompt_ids"] for example in examples]
            pixel_values += [example["class_images"] for example in examples]
            mask_values += [example["class_masks"] for example in examples]
            masked_image_values += [
                example["class_masked_images"] for example in examples
            ]

        pixel_values = (
            torch.stack(pixel_values).to(memory_format=torch.contiguous_format).float()
        )
        mask_values = (
            torch.stack(mask_values).to(memory_format=torch.contiguous_format).float()
        )
        masked_image_values = (
            torch.stack(masked_image_values)
            .to(memory_format=torch.contiguous_format)
            .float()
        )

        input_ids = tokenizer.pad(
            {"input_ids": input_ids},
            padding="max_length",
            max_length=tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids

        batch = {
            "input_ids": input_ids,
            "pixel_values": pixel_values,
            "mask_values": mask_values,
            "masked_image_values": masked_image_values,
        }

        if examples[0].get("mask", None) is not None:
            batch["mask"] = torch.stack([example["mask"] for example in examples])

        return batch

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        shuffle=True,
        collate_fn=collate_fn,
    )

    return train_dataloader
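

# One denoising-loss step shared by both phases: encode pixels to latents
# (unless latents were cached), add noise at a random timestep, let the UNet
# predict, and regress against the scheduler's target ("epsilon" -> the noise,
# "v_prediction" -> the velocity). An optional per-example mask re-weights the
# loss spatially, sharpened by mask_temperature; t_mutliplier < 1 restricts
# sampling to the lower portion of the timestep range.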
def loss_step(
    batch,
    unet,
    vae,
    text_encoder,
    scheduler,
    train_inpainting=False,
    t_mutliplier=1.0,
    mixed_precision=False,
    mask_temperature=1.0,
    cached_latents: bool = False,
):
    weight_dtype = torch.float32

    if not cached_latents:
        latents = vae.encode(
            batch["pixel_values"].to(dtype=weight_dtype).to(unet.device)
        ).latent_dist.sample()
        latents = latents * 0.18215

        if train_inpainting:
            masked_image_latents = vae.encode(
                batch["masked_image_values"].to(dtype=weight_dtype).to(unet.device)
            ).latent_dist.sample()
            masked_image_latents = masked_image_latents * 0.18215
            mask = F.interpolate(
                batch["mask_values"].to(dtype=weight_dtype).to(unet.device),
                scale_factor=1 / 8,
            )
    else:
        latents = batch["pixel_values"]

        if train_inpainting:
            masked_image_latents = batch["masked_image_latents"]
            mask = batch["mask_values"]

    noise = torch.randn_like(latents)
    bsz = latents.shape[0]

    timesteps = torch.randint(
        0,
        int(scheduler.config.num_train_timesteps * t_mutliplier),
        (bsz,),
        device=latents.device,
    )
    timesteps = timesteps.long()

    noisy_latents = scheduler.add_noise(latents, noise, timesteps)

    if train_inpainting:
        latent_model_input = torch.cat(
            [noisy_latents, mask, masked_image_latents], dim=1
        )
    else:
        latent_model_input = noisy_latents

    if mixed_precision:
        with torch.cuda.amp.autocast():
            encoder_hidden_states = text_encoder(
                batch["input_ids"].to(text_encoder.device)
            )[0]

            model_pred = unet(
                latent_model_input, timesteps, encoder_hidden_states
            ).sample
    else:
        encoder_hidden_states = text_encoder(
            batch["input_ids"].to(text_encoder.device)
        )[0]

        model_pred = unet(latent_model_input, timesteps, encoder_hidden_states).sample

    if scheduler.config.prediction_type == "epsilon":
        target = noise
    elif scheduler.config.prediction_type == "v_prediction":
        target = scheduler.get_velocity(latents, noise, timesteps)
    else:
        raise ValueError(f"Unknown prediction type {scheduler.config.prediction_type}")

    if batch.get("mask", None) is not None:
        mask = (
            batch["mask"]
            .to(model_pred.device)
            .reshape(
                model_pred.shape[0], 1, model_pred.shape[2] * 8, model_pred.shape[3] * 8
            )
        )
        # resize to match model_pred
        mask = F.interpolate(
            mask.float(),
            size=model_pred.shape[-2:],
            mode="nearest",
        )

        mask = (mask + 0.01).pow(mask_temperature)
        mask = mask / mask.max()

        model_pred = model_pred * mask
        target = target * mask

    loss = (
        F.mse_loss(model_pred.float(), target.float(), reduction="none")
        .mean([1, 2, 3])
        .mean()
    )

    return loss
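

# Phase 1 - textual inversion. Only the embedding rows of the placeholder
# tokens receive updates; all other rows are restored from orig_embeds_params
# after every optimizer step. With clip_ti_decay, the learned embeddings are
# renormalized toward a target norm of 0.4 so they do not drift too far in
# magnitude.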
def train_inversion(
    unet,
    vae,
    text_encoder,
    dataloader,
    num_steps: int,
    scheduler,
    index_no_updates,
    optimizer,
    save_steps: int,
    placeholder_token_ids,
    placeholder_tokens,
    save_path: str,
    tokenizer,
    lr_scheduler,
    test_image_path: str,
    cached_latents: bool,
    accum_iter: int = 1,
    log_wandb: bool = False,
    wandb_log_prompt_cnt: int = 10,
    class_token: str = "person",
    train_inpainting: bool = False,
    mixed_precision: bool = False,
    clip_ti_decay: bool = True,
):
    progress_bar = tqdm(range(num_steps))
    progress_bar.set_description("Steps")
    global_step = 0

    # Original Emb for TI
    orig_embeds_params = text_encoder.get_input_embeddings().weight.data.clone()

    if log_wandb:
        preped_clip = prepare_clip_model_sets()

    index_updates = ~index_no_updates
    loss_sum = 0.0

    for epoch in range(math.ceil(num_steps / len(dataloader))):
        unet.eval()
        text_encoder.train()
        for batch in dataloader:
            lr_scheduler.step()

            with torch.set_grad_enabled(True):
                loss = (
                    loss_step(
                        batch,
                        unet,
                        vae,
                        text_encoder,
                        scheduler,
                        train_inpainting=train_inpainting,
                        mixed_precision=mixed_precision,
                        cached_latents=cached_latents,
                    )
                    / accum_iter
                )
                loss.backward()
                loss_sum += loss.detach().item()

                if global_step % accum_iter == 0:
                    # print gradient of text encoder embedding
                    print(
                        text_encoder.get_input_embeddings()
                        .weight.grad[index_updates, :]
                        .norm(dim=-1)
                        .mean()
                    )
                    optimizer.step()
                    optimizer.zero_grad()

                    with torch.no_grad():
                        # normalize embeddings
                        if clip_ti_decay:
                            pre_norm = (
                                text_encoder.get_input_embeddings()
                                .weight[index_updates, :]
                                .norm(dim=-1, keepdim=True)
                            )

                            lambda_ = min(1.0, 100 * lr_scheduler.get_last_lr()[0])
                            text_encoder.get_input_embeddings().weight[
                                index_updates
                            ] = F.normalize(
                                text_encoder.get_input_embeddings().weight[
                                    index_updates, :
                                ],
                                dim=-1,
                            ) * (
                                pre_norm + lambda_ * (0.4 - pre_norm)
                            )
                            print(pre_norm)

                        current_norm = (
                            text_encoder.get_input_embeddings()
                            .weight[index_updates, :]
                            .norm(dim=-1)
                        )

                        text_encoder.get_input_embeddings().weight[
                            index_no_updates
                        ] = orig_embeds_params[index_no_updates]

                        print(f"Current Norm : {current_norm}")

                global_step += 1
                progress_bar.update(1)

                logs = {
                    "loss": loss.detach().item(),
                    "lr": lr_scheduler.get_last_lr()[0],
                }
                progress_bar.set_postfix(**logs)

            if global_step % save_steps == 0:
                save_all(
                    unet=unet,
                    text_encoder=text_encoder,
                    placeholder_token_ids=placeholder_token_ids,
                    placeholder_tokens=placeholder_tokens,
                    save_path=os.path.join(
                        save_path, f"step_inv_{global_step}.safetensors"
                    ),
                    save_lora=False,
                )
                if log_wandb:
                    with torch.no_grad():
                        pipe = StableDiffusionPipeline(
                            vae=vae,
                            text_encoder=text_encoder,
                            tokenizer=tokenizer,
                            unet=unet,
                            scheduler=scheduler,
                            safety_checker=None,
                            feature_extractor=None,
                        )

                        # open all images in test_image_path
                        images = []
                        for file in os.listdir(test_image_path):
                            if (
                                file.lower().endswith(".png")
                                or file.lower().endswith(".jpg")
                                or file.lower().endswith(".jpeg")
                            ):
                                images.append(
                                    Image.open(os.path.join(test_image_path, file))
                                )

                        wandb.log({"loss": loss_sum / save_steps})
                        loss_sum = 0.0
                        wandb.log(
                            evaluate_pipe(
                                pipe,
                                target_images=images,
                                class_token=class_token,
                                learnt_token="".join(placeholder_tokens),
                                n_test=wandb_log_prompt_cnt,
                                n_step=50,
                                clip_model_sets=preped_clip,
                            )
                        )

            if global_step >= num_steps:
                return
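

# Phase 2 - LoRA fine-tuning. The injected low-rank adapters in the UNet (and,
# optionally, the text encoder) are trained with a truncated timestep range and
# mixed precision; checkpoints are written every save_steps via save_all, and
# inspect_lora reports how far the adapter weights have moved.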
def perform_tuning(
    unet,
    vae,
    text_encoder,
    dataloader,
    num_steps,
    scheduler,
    optimizer,
    save_steps: int,
    placeholder_token_ids,
    placeholder_tokens,
    save_path,
    lr_scheduler_lora,
    lora_unet_target_modules,
    lora_clip_target_modules,
    mask_temperature,
    out_name: str,
    tokenizer,
    test_image_path: str,
    cached_latents: bool,
    log_wandb: bool = False,
    wandb_log_prompt_cnt: int = 10,
    class_token: str = "person",
    train_inpainting: bool = False,
):
    progress_bar = tqdm(range(num_steps))
    progress_bar.set_description("Steps")
    global_step = 0

    weight_dtype = torch.float16

    unet.train()
    text_encoder.train()

    if log_wandb:
        preped_clip = prepare_clip_model_sets()

    loss_sum = 0.0

    for epoch in range(math.ceil(num_steps / len(dataloader))):
        for batch in dataloader:
            lr_scheduler_lora.step()
            optimizer.zero_grad()

            loss = loss_step(
                batch,
                unet,
                vae,
                text_encoder,
                scheduler,
                train_inpainting=train_inpainting,
                t_mutliplier=0.8,
                mixed_precision=True,
                mask_temperature=mask_temperature,
                cached_latents=cached_latents,
            )
            loss_sum += loss.detach().item()

            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                itertools.chain(unet.parameters(), text_encoder.parameters()), 1.0
            )
            optimizer.step()
            progress_bar.update(1)
            logs = {
                "loss": loss.detach().item(),
                "lr": lr_scheduler_lora.get_last_lr()[0],
            }
            progress_bar.set_postfix(**logs)

            global_step += 1

            if global_step % save_steps == 0:
                save_all(
                    unet,
                    text_encoder,
                    placeholder_token_ids=placeholder_token_ids,
                    placeholder_tokens=placeholder_tokens,
                    save_path=os.path.join(
                        save_path, f"step_{global_step}.safetensors"
                    ),
                    target_replace_module_text=lora_clip_target_modules,
                    target_replace_module_unet=lora_unet_target_modules,
                )
                moved = (
                    torch.tensor(list(itertools.chain(*inspect_lora(unet).values())))
                    .mean()
                    .item()
                )
                print("LORA Unet Moved", moved)

                moved = (
                    torch.tensor(
                        list(itertools.chain(*inspect_lora(text_encoder).values()))
                    )
                    .mean()
                    .item()
                )
                print("LORA CLIP Moved", moved)

                if log_wandb:
                    with torch.no_grad():
                        pipe = StableDiffusionPipeline(
                            vae=vae,
                            text_encoder=text_encoder,
                            tokenizer=tokenizer,
                            unet=unet,
                            scheduler=scheduler,
                            safety_checker=None,
                            feature_extractor=None,
                        )

                        # open all images in test_image_path
                        images = []
                        for file in os.listdir(test_image_path):
                            if file.endswith(".png") or file.endswith(".jpg"):
                                images.append(
                                    Image.open(os.path.join(test_image_path, file))
                                )

                        wandb.log({"loss": loss_sum / save_steps})
                        loss_sum = 0.0
                        wandb.log(
                            evaluate_pipe(
                                pipe,
                                target_images=images,
                                class_token=class_token,
                                learnt_token="".join(placeholder_tokens),
                                n_test=wandb_log_prompt_cnt,
                                n_step=50,
                                clip_model_sets=preped_clip,
                            )
                        )

            if global_step >= num_steps:
                break

    save_all(
        unet,
        text_encoder,
        placeholder_token_ids=placeholder_token_ids,
        placeholder_tokens=placeholder_tokens,
        save_path=os.path.join(save_path, f"{out_name}.safetensors"),
        target_replace_module_text=lora_clip_target_modules,
        target_replace_module_unet=lora_unet_target_modules,
    )
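

# CLI entry point (exposed through fire in main() below). Parses the
# "|"-separated placeholder / initializer token strings, builds the dataset and
# dataloader, then runs the inversion phase followed by the LoRA tuning phase.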
def train(
    instance_data_dir: str,
    pretrained_model_name_or_path: str,
    output_dir: str,
    train_text_encoder: bool = True,
    pretrained_vae_name_or_path: Optional[str] = None,
    revision: Optional[str] = None,
    perform_inversion: bool = True,
    use_template: Literal[None, "object", "style"] = None,
    train_inpainting: bool = False,
    placeholder_tokens: str = "",
    placeholder_token_at_data: Optional[str] = None,
    initializer_tokens: Optional[str] = None,
    seed: int = 42,
    resolution: int = 512,
    color_jitter: bool = True,
    train_batch_size: int = 1,
    sample_batch_size: int = 1,
    max_train_steps_tuning: int = 1000,
    max_train_steps_ti: int = 1000,
    save_steps: int = 100,
    gradient_accumulation_steps: int = 4,
    gradient_checkpointing: bool = False,
    lora_rank: int = 4,
    lora_unet_target_modules={"CrossAttention", "Attention", "GEGLU"},
    lora_clip_target_modules={"CLIPAttention"},
    lora_dropout_p: float = 0.0,
    lora_scale: float = 1.0,
    use_extended_lora: bool = False,
    clip_ti_decay: bool = True,
    learning_rate_unet: float = 1e-4,
    learning_rate_text: float = 1e-5,
    learning_rate_ti: float = 5e-4,
    continue_inversion: bool = False,
    continue_inversion_lr: Optional[float] = None,
    use_face_segmentation_condition: bool = False,
    cached_latents: bool = True,
    use_mask_captioned_data: bool = False,
    mask_temperature: float = 1.0,
    scale_lr: bool = False,
    lr_scheduler: str = "linear",
    lr_warmup_steps: int = 0,
    lr_scheduler_lora: str = "linear",
    lr_warmup_steps_lora: int = 0,
    weight_decay_ti: float = 0.00,
    weight_decay_lora: float = 0.001,
    use_8bit_adam: bool = False,
    device="cuda:0",
    extra_args: Optional[dict] = None,
    log_wandb: bool = False,
    wandb_log_prompt_cnt: int = 10,
    wandb_project_name: str = "new_pti_project",
    wandb_entity: str = "new_pti_entity",
    proxy_token: str = "person",
    enable_xformers_memory_efficient_attention: bool = False,
    out_name: str = "final_lora",
):
    torch.manual_seed(seed)

    if log_wandb:
        wandb.init(
            project=wandb_project_name,
            entity=wandb_entity,
            name=f"steps_{max_train_steps_ti}_lr_{learning_rate_ti}_{instance_data_dir.split('/')[-1]}",
            reinit=True,
            config={
                **(extra_args if extra_args is not None else {}),
            },
        )

    if output_dir is not None:
        os.makedirs(output_dir, exist_ok=True)

    # print(placeholder_tokens, initializer_tokens)
    if len(placeholder_tokens) == 0:
        placeholder_tokens = []
        print("PTI : Placeholder Tokens not given, using null token")
    else:
        placeholder_tokens = placeholder_tokens.split("|")

        assert (
            sorted(placeholder_tokens) == placeholder_tokens
        ), f"Placeholder tokens should be sorted. Use something like {'|'.join(sorted(placeholder_tokens))}"
    if initializer_tokens is None:
        print("PTI : Initializer Tokens not given, doing random inits")
        initializer_tokens = ["<rand-0.017>"] * len(placeholder_tokens)
    else:
        initializer_tokens = initializer_tokens.split("|")

    assert len(initializer_tokens) == len(
        placeholder_tokens
    ), "Number of initializer tokens must match the number of placeholder tokens."

    # Use the proxy token as the class token for evaluation prompts when given;
    # otherwise fall back to the concatenated initializer tokens.
    if proxy_token is not None:
        class_token = proxy_token
    else:
        class_token = "".join(initializer_tokens)

    if placeholder_token_at_data is not None:
        tok, pat = placeholder_token_at_data.split("|")
        token_map = {tok: pat}
    else:
        token_map = {"DUMMY": "".join(placeholder_tokens)}
    print("PTI : Placeholder Tokens", placeholder_tokens)
    print("PTI : Initializer Tokens", initializer_tokens)

    # get the models
    text_encoder, vae, unet, tokenizer, placeholder_token_ids = get_models(
        pretrained_model_name_or_path,
        pretrained_vae_name_or_path,
        revision,
        placeholder_tokens,
        initializer_tokens,
        device=device,
    )

    noise_scheduler = DDPMScheduler.from_config(
        pretrained_model_name_or_path, subfolder="scheduler"
    )

    if gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    if enable_xformers_memory_efficient_attention:
        from diffusers.utils.import_utils import is_xformers_available

        if is_xformers_available():
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError(
                "xformers is not available. Make sure it is installed correctly"
            )

    if scale_lr:
        unet_lr = learning_rate_unet * gradient_accumulation_steps * train_batch_size
        text_encoder_lr = (
            learning_rate_text * gradient_accumulation_steps * train_batch_size
        )
        ti_lr = learning_rate_ti * gradient_accumulation_steps * train_batch_size
    else:
        unet_lr = learning_rate_unet
        text_encoder_lr = learning_rate_text
        ti_lr = learning_rate_ti

    train_dataset = PivotalTuningDatasetCapation(
        instance_data_root=instance_data_dir,
        token_map=token_map,
        use_template=use_template,
        tokenizer=tokenizer,
        size=resolution,
        color_jitter=color_jitter,
        use_face_segmentation_condition=use_face_segmentation_condition,
        use_mask_captioned_data=use_mask_captioned_data,
        train_inpainting=train_inpainting,
    )

    train_dataset.blur_amount = 200

    if train_inpainting:
        assert not cached_latents, "Cached latents not supported for inpainting"

        train_dataloader = inpainting_dataloader(
            train_dataset, train_batch_size, tokenizer, vae, text_encoder
        )
    else:
        train_dataloader = text2img_dataloader(
            train_dataset,
            train_batch_size,
            tokenizer,
            vae,
            text_encoder,
            cached_latents=cached_latents,
        )
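
    # Freeze the UNet, VAE and the bulk of the text encoder; only the embedding
    # rows of the newly added placeholder tokens (marked False in
    # index_no_updates) are meant to change during the inversion phase.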
index_no_updates = torch.arange(len(tokenizer)) != -1 | |
for tok_id in placeholder_token_ids: | |
index_no_updates[tok_id] = False | |
unet.requires_grad_(False) | |
vae.requires_grad_(False) | |
params_to_freeze = itertools.chain( | |
text_encoder.text_model.encoder.parameters(), | |
text_encoder.text_model.final_layer_norm.parameters(), | |
text_encoder.text_model.embeddings.position_embedding.parameters(), | |
) | |
for param in params_to_freeze: | |
param.requires_grad = False | |
if cached_latents: | |
vae = None | |
# STEP 1 : Perform Inversion | |
if perform_inversion: | |
ti_optimizer = optim.AdamW( | |
text_encoder.get_input_embeddings().parameters(), | |
lr=ti_lr, | |
betas=(0.9, 0.999), | |
eps=1e-08, | |
weight_decay=weight_decay_ti, | |
) | |
lr_scheduler = get_scheduler( | |
lr_scheduler, | |
optimizer=ti_optimizer, | |
num_warmup_steps=lr_warmup_steps, | |
num_training_steps=max_train_steps_ti, | |
) | |
train_inversion( | |
unet, | |
vae, | |
text_encoder, | |
train_dataloader, | |
max_train_steps_ti, | |
cached_latents=cached_latents, | |
accum_iter=gradient_accumulation_steps, | |
scheduler=noise_scheduler, | |
index_no_updates=index_no_updates, | |
optimizer=ti_optimizer, | |
lr_scheduler=lr_scheduler, | |
save_steps=save_steps, | |
placeholder_tokens=placeholder_tokens, | |
placeholder_token_ids=placeholder_token_ids, | |
save_path=output_dir, | |
test_image_path=instance_data_dir, | |
log_wandb=log_wandb, | |
wandb_log_prompt_cnt=wandb_log_prompt_cnt, | |
class_token=class_token, | |
train_inpainting=train_inpainting, | |
mixed_precision=False, | |
tokenizer=tokenizer, | |
clip_ti_decay=clip_ti_decay, | |
) | |
del ti_optimizer | |
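
    # inject_trainable_lora (from lora_diffusion) swaps the nn.Linear layers
    # inside the module classes named in the target_replace sets for
    # LoRA-augmented versions and returns the new low-rank parameters; those
    # become the only UNet / text-encoder weights optimized during tuning.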

    # Next perform Tuning with LoRA:
    if not use_extended_lora:
        unet_lora_params, _ = inject_trainable_lora(
            unet,
            r=lora_rank,
            target_replace_module=lora_unet_target_modules,
            dropout_p=lora_dropout_p,
            scale=lora_scale,
        )
    else:
        print("PTI : USING EXTENDED UNET!!!")
        lora_unet_target_modules = (
            lora_unet_target_modules | UNET_EXTENDED_TARGET_REPLACE
        )
        print("PTI : Will replace modules: ", lora_unet_target_modules)

        unet_lora_params, _ = inject_trainable_lora_extended(
            unet, r=lora_rank, target_replace_module=lora_unet_target_modules
        )
    print(f"PTI : has {len(unet_lora_params)} lora")

    print("PTI : Before training:")
    inspect_lora(unet)

    params_to_optimize = [
        {"params": itertools.chain(*unet_lora_params), "lr": unet_lr},
    ]

    text_encoder.requires_grad_(False)

    if continue_inversion:
        params_to_optimize += [
            {
                "params": text_encoder.get_input_embeddings().parameters(),
                "lr": continue_inversion_lr
                if continue_inversion_lr is not None
                else ti_lr,
            }
        ]
        text_encoder.requires_grad_(True)
        params_to_freeze = itertools.chain(
            text_encoder.text_model.encoder.parameters(),
            text_encoder.text_model.final_layer_norm.parameters(),
            text_encoder.text_model.embeddings.position_embedding.parameters(),
        )
        for param in params_to_freeze:
            param.requires_grad = False
    else:
        text_encoder.requires_grad_(False)

    if train_text_encoder:
        text_encoder_lora_params, _ = inject_trainable_lora(
            text_encoder,
            target_replace_module=lora_clip_target_modules,
            r=lora_rank,
        )
        params_to_optimize += [
            {
                "params": itertools.chain(*text_encoder_lora_params),
                "lr": text_encoder_lr,
            }
        ]
        inspect_lora(text_encoder)

    lora_optimizers = optim.AdamW(params_to_optimize, weight_decay=weight_decay_lora)

    unet.train()
    if train_text_encoder:
        text_encoder.train()

    train_dataset.blur_amount = 70

    lr_scheduler_lora = get_scheduler(
        lr_scheduler_lora,
        optimizer=lora_optimizers,
        num_warmup_steps=lr_warmup_steps_lora,
        num_training_steps=max_train_steps_tuning,
    )

    perform_tuning(
        unet,
        vae,
        text_encoder,
        train_dataloader,
        max_train_steps_tuning,
        cached_latents=cached_latents,
        scheduler=noise_scheduler,
        optimizer=lora_optimizers,
        save_steps=save_steps,
        placeholder_tokens=placeholder_tokens,
        placeholder_token_ids=placeholder_token_ids,
        save_path=output_dir,
        lr_scheduler_lora=lr_scheduler_lora,
        lora_unet_target_modules=lora_unet_target_modules,
        lora_clip_target_modules=lora_clip_target_modules,
        mask_temperature=mask_temperature,
        tokenizer=tokenizer,
        out_name=out_name,
        test_image_path=instance_data_dir,
        log_wandb=log_wandb,
        wandb_log_prompt_cnt=wandb_log_prompt_cnt,
        class_token=class_token,
        train_inpainting=train_inpainting,
    )


def main():
    fire.Fire(train)
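

# Standard executable-script guard: the packaged CLI would typically call
# main() via a console entry point, but this also allows running the file
# directly. Example invocation with illustrative, hypothetical paths/values:
#   python train_pti.py \
#       --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
#       --instance_data_dir="./data/my_subject" \
#       --output_dir="./output" \
#       --use_template="object" \
#       --placeholder_tokens="<s1>|<s2>" \
#       --max_train_steps_ti=1000 \
#       --max_train_steps_tuning=1000
if __name__ == "__main__":
    main()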