import os
# os.system('pip install gradio==3.43.1')
import sys
import gc
import warnings

import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import tqdm
import gradio as gr
from PIL import Image
from huggingface_hub import snapshot_download

warnings.filterwarnings("ignore")

# Make the repository root importable for the local w2w modules.
sys.path.append(os.path.abspath(os.path.join("", "..")))
from utils import load_models, save_model_w2w, save_model_for_diffusers
from editing import get_direction, debias
from sampling import sample_weights
from lora_w2w import LoRAw2w
# Module-level state shared by all handlers: device, RNG, diffusion components,
# and the currently active w2w LoRA network.
device = "cuda:0"
generator = torch.Generator(device=device)
# Download the precomputed w2w weight-space files and load them in bfloat16.
models_path = snapshot_download(repo_id="Snapchat/w2w")
mean = torch.load(f"{models_path}/files/mean.pt").bfloat16().to(device)
std = torch.load(f"{models_path}/files/std.pt").bfloat16().to(device)
v = torch.load(f"{models_path}/files/V.pt").bfloat16().to(device)
proj = torch.load(f"{models_path}/files/proj_1000pc.pt").bfloat16().to(device)
df = torch.load(f"{models_path}/files/identity_df.pt")
weight_dimensions = torch.load(f"{models_path}/files/weight_dimensions.pt")
pinverse = torch.load(f"{models_path}/files/pinverse_1000pc.pt").bfloat16().to(device)

# Base Stable Diffusion components.
unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)
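# NOTE on the loaded files (semantics inferred from how they are used below,
# per the weights2weights setup): identities are points in a PCA basis fit over
# many per-identity LoRA weight deltas. `mean`/`std` describe that distribution,
# `v` holds the principal directions, `proj` stacks per-identity coefficients,
# and `pinverse` maps attribute directions back into coefficient space.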
def sample_model():
    """Reload a fresh UNet and sample a new identity-encoding w2w network."""
    global unet
    del unet
    global network
    unet, _, _, _, _ = load_models(device)
    network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor=1.00)
@torch.no_grad()
def inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed):
    global device, generator, unet, vae, text_encoder, tokenizer, noise_scheduler
    generator = generator.manual_seed(seed)
    latents = torch.randn(
        (1, unet.in_channels, 512 // 8, 512 // 8),
        generator=generator,
        device=device,
    ).bfloat16()

    # Encode the prompt and negative prompt for classifier-free guidance.
    text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer(
        [negative_prompt], padding="max_length", max_length=max_length, return_tensors="pt"
    )
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    noise_scheduler.set_timesteps(ddim_steps)
    latents = latents * noise_scheduler.init_noise_sigma

    # Denoising loop; the w2w LoRA network is active inside the UNet call.
    for i, t in enumerate(tqdm.tqdm(noise_scheduler.timesteps)):
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)
        with network:
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond=None).sample
        # Classifier-free guidance.
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

    # Decode latents and convert to a PIL image.
    latents = 1 / 0.18215 * latents
    image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
    image = Image.fromarray((image * 255).round().astype("uint8"))
    return image
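# Example usage (hypothetical values; assumes sample_model() has populated
# `network`):
#   sample_model()
#   img = inference("sks person", "low quality, blurry, unfinished, cartoon", 3.0, 50, 5)
#   img.save("sampled_identity.png")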
@torch.no_grad()
def edit_inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed, start_noise, a1, a2, a3, a4):
    global device, generator, unet, vae, text_encoder, tokenizer, noise_scheduler
    global young, pointy, wavy, large

    # Apply the four attribute edits as linear offsets in w2w coefficient space.
    original_weights = network.proj.clone()
    edited_weights = original_weights + a1*1e6*young + a2*1e6*pointy + a3*1e6*wavy + a4*2e6*large

    generator = generator.manual_seed(seed)
    latents = torch.randn(
        (1, unet.in_channels, 512 // 8, 512 // 8),
        generator=generator,
        device=device,
    ).bfloat16()

    text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer(
        [negative_prompt], padding="max_length", max_length=max_length, return_tensors="pt"
    )
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    noise_scheduler.set_timesteps(ddim_steps)
    latents = latents * noise_scheduler.init_noise_sigma

    for i, t in enumerate(tqdm.tqdm(noise_scheduler.timesteps)):
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)
        # Inject the edited weights only once the timestep drops to `start_noise`;
        # earlier (noisier) steps keep the original identity weights.
        if t <= start_noise:
            network.proj = torch.nn.Parameter(edited_weights)
            network.reset()
        with network:
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond=None).sample
        # Classifier-free guidance.
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

    latents = 1 / 0.18215 * latents
    image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
    image = Image.fromarray((image * 255).round().astype("uint8"))

    # Reset weights back to the original identity.
    network.proj = torch.nn.Parameter(original_weights)
    network.reset()
    return image
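# Example usage (hypothetical values): with start_noise=800, timesteps above
# 800 run with the original identity weights and the edited weights are only
# injected below it, trading edit strength for identity preservation:
#   img = edit_inference("sks person", "low quality, blurry, unfinished, cartoon",
#                        3.0, 50, 5, 800, 0.5, 0.0, 0.0, 0.0)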
def sample_then_run():
    sample_model()
    prompt = "sks person"
    negative_prompt = "low quality, blurry, unfinished, cartoon"
    seed = 5
    cfg = 3.0
    steps = 50
    image = inference(prompt, negative_prompt, cfg, steps, seed)
    return image
# Attribute edit directions in w2w space. Each direction is debiased against
# correlated attributes so that one slider changes (mostly) a single trait.
# The *_max/*_min extremes of each direction's coefficient over the identity
# dataset are computed here but not used by the UI below.
young = get_direction(df, "Young", pinverse, 1000, device)
young = debias(young, "Male", df, pinverse, device)
young = debias(young, "Pointy_Nose", df, pinverse, device)
young = debias(young, "Wavy_Hair", df, pinverse, device)
young = debias(young, "Chubby", df, pinverse, device)
young_max = torch.max(proj @ young[0] / (torch.norm(young)) ** 2).item()
young_min = torch.min(proj @ young[0] / (torch.norm(young)) ** 2).item()

pointy = get_direction(df, "Pointy_Nose", pinverse, 1000, device)
pointy = debias(pointy, "Young", df, pinverse, device)
pointy = debias(pointy, "Male", df, pinverse, device)
pointy = debias(pointy, "Wavy_Hair", df, pinverse, device)
pointy = debias(pointy, "Chubby", df, pinverse, device)
pointy = debias(pointy, "Heavy_Makeup", df, pinverse, device)
pointy_max = torch.max(proj @ pointy[0] / (torch.norm(pointy)) ** 2).item()
pointy_min = torch.min(proj @ pointy[0] / (torch.norm(pointy)) ** 2).item()

wavy = get_direction(df, "Wavy_Hair", pinverse, 1000, device)
wavy = debias(wavy, "Young", df, pinverse, device)
wavy = debias(wavy, "Male", df, pinverse, device)
wavy = debias(wavy, "Pointy_Nose", df, pinverse, device)
wavy = debias(wavy, "Chubby", df, pinverse, device)
wavy = debias(wavy, "Heavy_Makeup", df, pinverse, device)
wavy_max = torch.max(proj @ wavy[0] / (torch.norm(wavy)) ** 2).item()
wavy_min = torch.min(proj @ wavy[0] / (torch.norm(wavy)) ** 2).item()

large = get_direction(df, "Chubby", pinverse, 1000, device)
large = debias(large, "Male", df, pinverse, device)
large = debias(large, "Young", df, pinverse, device)
large = debias(large, "Pointy_Nose", df, pinverse, device)
large = debias(large, "Wavy_Hair", df, pinverse, device)
large_max = torch.max(proj @ large[0] / (torch.norm(large)) ** 2).item()
large_min = torch.min(proj @ large[0] / (torch.norm(large)) ** 2).item()
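# Minimal sketch of what the slider math in edit_inference() amounts to: an
# attribute edit is a linear offset of the identity's w2w coefficients along a
# debiased direction. This helper is hypothetical (not used by the app); the
# default scale mirrors the 1e6/2e6 constants used in edit_inference().
def apply_edit_sketch(coeffs, direction, strength, scale=1e6):
    """Return `coeffs` shifted by `strength * scale` along `direction`."""
    return coeffs + strength * scale * direction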
class CustomImageDataset(Dataset):
    """Minimal dataset wrapping an in-memory list of PIL images."""

    def __init__(self, images, transform=None):
        self.images = images
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        if self.transform:
            image = self.transform(image)
        return image
def invert(image, mask, pcs=10000, epochs=400, weight_decay=1e-10, lr=1e-1):
    global unet
    del unet
    global network
    unet, _, _, _, _ = load_models(device)

    # Optimize only the w2w coefficients (one per principal component).
    proj = torch.zeros(1, pcs).bfloat16().to(device)
    network = LoRAw2w(proj, mean, std, v[:, :pcs],
                      unet,
                      rank=1,
                      multiplier=1.0,
                      alpha=27.0,
                      train_method="xattn-strict").to(device, torch.bfloat16)

    ### load mask (resized to the 64x64 latent resolution)
    mask = transforms.Resize((64, 64), interpolation=transforms.InterpolationMode.BILINEAR)(mask)
    mask = torchvision.transforms.functional.pil_to_tensor(mask).unsqueeze(0).to(device).bfloat16()[:, 0, :, :].unsqueeze(1)

    ### check if an actual mask was drawn; otherwise use an all-ones mask
    if torch.sum(mask) == 0:
        mask = torch.ones((1, 1, 64, 64)).to(device).bfloat16()

    ### single-image dataset
    image_transforms = transforms.Compose([transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
                                           transforms.RandomCrop(512),
                                           transforms.ToTensor(),
                                           transforms.Normalize([0.5], [0.5])])
    train_dataset = CustomImageDataset(image, transform=image_transforms)
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True)

    ### optimizer
    optim = torch.optim.Adam(network.parameters(), lr=lr, weight_decay=weight_decay)

    ### training loop
    unet.train()
    for epoch in tqdm.tqdm(range(epochs)):
        for batch in train_dataloader:
            ### prepare inputs: encode to latents and add noise at a random timestep
            batch = batch.to(device).bfloat16()
            latents = vae.encode(batch).latent_dist.sample()
            latents = latents * 0.18215
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
            timesteps = timesteps.long()
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
            text_input = tokenizer("sks person", padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
            text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

            ### masked denoising loss + optimizer step
            with network:
                model_pred = unet(noisy_latents, timesteps, text_embeddings).sample
            loss = torch.nn.functional.mse_loss(mask * model_pred.float(), mask * noise.float(), reduction="mean")
            optim.zero_grad()
            loss.backward()
            optim.step()

    ### return optimized network
    return network
def run_inversion(input_dict, pcs, epochs, weight_decay, lr):
    global network
    init_image = input_dict["image"].convert("RGB").resize((512, 512))
    mask = input_dict["mask"].convert("RGB").resize((512, 512))
    network = invert([init_image], mask, pcs, epochs, weight_decay, lr)

    # Sample an image from the inverted model.
    prompt = "sks person"
    negative_prompt = "low quality, blurry, unfinished, cartoon"
    seed = 5
    cfg = 3.0
    steps = 50
    image = inference(prompt, negative_prompt, cfg, steps, seed)
    torch.save(network.proj, "model.pt")
    return image, "model.pt"
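# Minimal sketch (hypothetical, not wired into the UI): the downloadable
# "model.pt" holds only the w2w coefficients, so reloading an inverted identity
# presumably means rebuilding a LoRAw2w network around a fresh UNet with the
# same hyperparameters used in invert().
def load_inverted_network_sketch(path="model.pt", pcs=10000):
    saved_proj = torch.load(path).to(device)
    fresh_unet, _, _, _, _ = load_models(device)
    return LoRAw2w(saved_proj, mean, std, v[:, :pcs],
                   fresh_unet, rank=1, multiplier=1.0, alpha=27.0,
                   train_method="xattn-strict").to(device, torch.bfloat16)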
intro = """
<div style="display: flex;align-items: center;justify-content: center">
<h1 style="margin-left: 12px;text-align: center;margin-bottom: 7px;display: inline-block">weights2weights</h1>
<h3 style="display: inline-block;margin-left: 10px;margin-top: 6px;font-weight: 500">Interpreting the Weight Space of Customized Diffusion Models</h3>
</div>
<p style="font-size: 0.95rem;margin: 0rem;line-height: 1.2em;margin-top:1em;display: inline-block">
<a href="https://snap-research.github.io/weights2weights/" target="_blank">project page</a> | <a href="https://arxiv.org/abs/2406.09413" target="_blank">paper</a>
|
<a href="https://huggingface.co/spaces/Snapchat/w2w-demo?duplicate=true" target="_blank" style="
display: inline-block;
">
<img style="margin-top: -1em;margin-bottom: 0em;position: absolute;" src="https://bit.ly/3CWLGkA" alt="Duplicate Space"></a>
</p>
"""
with gr.Blocks(css="style.css") as demo:
    gr.HTML(intro)

    with gr.Tab("Sampling Models + Editing"):
        with gr.Row():
            with gr.Column():
                gallery1 = gr.Image(label="Identity from Sampled Model")
                sample = gr.Button("Sample New Model")
                gallery2 = gr.Image(label="Identity from Edited Model")
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Prompt",
                                    info="Make sure to include 'sks person'",
                                    placeholder="sks person",
                                    value="sks person")
                negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, cartoon", value="low quality, blurry, unfinished, cartoon")
                with gr.Row():
                    a1 = gr.Slider(label="- Young +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
                    a2 = gr.Slider(label="- Pointy Nose +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
                with gr.Row():
                    a3 = gr.Slider(label="- Curly Hair +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
                    a4 = gr.Slider(label="- Chubby +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
                with gr.Accordion("Advanced Options", open=False):
                    with gr.Column():
                        seed = gr.Number(value=5, label="Seed", precision=0, interactive=True)
                        cfg = gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
                        steps = gr.Slider(label="Inference Steps", value=50, step=1, minimum=0, maximum=100, interactive=True)
                        injection_step = gr.Slider(label="Injection Step", value=800, step=1, minimum=0, maximum=1000, interactive=True)
                submit = gr.Button("Generate")

        sample.click(fn=sample_then_run, outputs=gallery1)
        submit.click(fn=edit_inference,
                     inputs=[prompt, negative_prompt, cfg, steps, seed, injection_step, a1, a2, a3, a4],
                     outputs=gallery2)

    with gr.Tab("Inversion"):
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(sources='upload', elem_id="image_upload", tool='sketch', type='pil', label="Upload image and draw to define mask",
                                       height=512, width=512, brush_color='#00FFFF', mask_opacity=0.6)
                # input_image = gr.ImageEditor(sources='upload', elem_id="image_upload", type='pil', label="Upload image and draw to define mask",
                #                              height=512, width=512)
                lr = gr.Number(value=1e-1, label="Learning Rate", interactive=True)
                weight_decay = gr.Number(value=1e-10, label="Weight Decay", interactive=True)
                pcs = gr.Slider(label="# Principal Components", value=10000, step=1, minimum=1, maximum=10000, interactive=True)
                epochs = gr.Slider(label="Epochs", value=400, step=1, minimum=1, maximum=2000, interactive=True)
                invert_button = gr.Button("Invert")
            with gr.Column():
                gallery = gr.Image(label="Sample from Inverted Model", height=512, width=512)
                prompt = gr.Textbox(label="Prompt",
                                    info="Make sure to include 'sks person'",
                                    placeholder="sks person",
                                    value="sks person")
                negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, cartoon", value="low quality, blurry, unfinished, cartoon")
                seed = gr.Number(value=5, label="Seed", precision=0, interactive=True)
                cfg = gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
                steps = gr.Slider(label="Inference Steps", value=50, step=1, minimum=0, maximum=100, interactive=True)
                submit = gr.Button("Generate")
                file_output = gr.File(label="Download Model", container=False)

        invert_button.click(fn=run_inversion,
                            inputs=[input_image, pcs, epochs, weight_decay, lr],
                            outputs=[gallery, file_output])
        submit.click(fn=inference,
                     inputs=[prompt, negative_prompt, cfg, steps, seed],
                     outputs=gallery)

demo.queue().launch(share=True)