Spaces:
Running
on
Zero
Running
on
Zero
import os | |
# os.system("pip uninstall -y gradio") | |
# #os.system('pip install gradio==3.43.1') | |
import torch | |
import torchvision | |
import torchvision.transforms as transforms | |
from torch.utils.data import Dataset, DataLoader | |
import gradio as gr | |
import sys | |
import os | |
import tqdm | |
sys.path.append(os.path.abspath(os.path.join("", ".."))) | |
import torch | |
import gc | |
import warnings | |
warnings.filterwarnings("ignore") | |
from PIL import Image | |
from utils import load_models, save_model_w2w, save_model_for_diffusers | |
from editing import get_direction, debias | |
from sampling import sample_weights | |
from lora_w2w import LoRAw2w | |
from huggingface_hub import snapshot_download | |
import numpy as np | |
# Shared runtime state used by the Gradio callbacks below. (`global` statements
# at module scope have no effect, so the names are simply bound here.)
device = "cuda:0"
generator = torch.Generator(device=device)

from gradio_imageslider import ImageSlider

# Download the released w2w assets once at startup.
models_path = snapshot_download(repo_id="Snapchat/w2w")


def _load_w2w_file(name):
    """Return ``torch.load`` of ``files/<name>`` from the downloaded w2w repo."""
    return torch.load(f"{models_path}/files/{name}")


def _load_w2w_tensor(name):
    """Load ``files/<name>`` and move it to ``device`` as a bfloat16 tensor."""
    return _load_w2w_file(name).bfloat16().to(device)


mean = _load_w2w_tensor("mean.pt")
std = _load_w2w_tensor("std.pt")
v = _load_w2w_tensor("V.pt")
proj = _load_w2w_tensor("proj_1000pc.pt")
df = _load_w2w_file("identity_df.pt")  # kept as loaded (no dtype/device change)
weight_dimensions = _load_w2w_file("weight_dimensions.pt")
pinverse = _load_w2w_tensor("pinverse_1000pc.pt")

unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)
def sample_model():
    """Reload a clean UNet and sample a new identity network onto it.

    Side effects: rebinds the globals ``unet`` (fresh copy from ``load_models``)
    and ``network`` (result of ``sample_weights`` over the first 1000 columns
    of ``v``).
    """
    global unet, network
    del unet  # drop the current module-level reference before loading a replacement
    unet, _vae, _text_encoder, _tokenizer, _scheduler = load_models(device)
    network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor=1.00)
@torch.no_grad()
def inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed):
    """Sample one image from the current identity network with classifier-free guidance.

    Sampling runs under ``torch.no_grad()``: nothing here is trained, so no
    autograd graph is built through the text encoder, UNet or VAE.

    Args:
        prompt: conditional prompt (should contain "sks person").
        negative_prompt: prompt for the unconditional branch of guidance.
        guidance_scale: classifier-free guidance scale.
        ddim_steps: number of scheduler steps (cast to int; UI widgets may send floats).
        seed: seed for the initial latents (cast to int).

    Returns:
        The decoded sample as a ``PIL.Image.Image``.
    """
    global generator

    generator = generator.manual_seed(int(seed))
    # Latents at 1/8 of the 512x512 output resolution.
    latents = torch.randn(
        (1, unet.in_channels, 512 // 8, 512 // 8),
        generator=generator,
        device=device,
    ).bfloat16()

    # Conditional and unconditional embeddings, batched as [uncond, cond].
    text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length,
                           truncation=True, return_tensors="pt")
    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer(
        [negative_prompt], padding="max_length", max_length=max_length, return_tensors="pt"
    )
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    noise_scheduler.set_timesteps(int(ddim_steps))
    latents = latents * noise_scheduler.init_noise_sigma

    for t in tqdm.tqdm(noise_scheduler.timesteps):
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)
        with network:
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings,
                              timestep_cond=None).sample
        # Classifier-free guidance.
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

    # Undo the latent scaling factor (0.18215, as also used in `invert`), then decode.
    latents = 1 / 0.18215 * latents
    image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
    return Image.fromarray((image * 255).round().astype("uint8"))
@torch.no_grad()
def edit_inference(input_image, prompt, negative_prompt, guidance_scale, ddim_steps, seed, start_noise, a1, a2, a3, a4):
    """Render the current identity with attribute edits injected part-way through sampling.

    The global edit directions ``young``, ``pointy``, ``wavy`` and ``large``
    (the last drives the "Thick Eyebrows" slider) are scaled by ``a1``..``a4``
    and added to ``network.proj``. Denoising starts with the unedited weights.
    From the first timestep with ``t <= start_noise`` onward, the edited
    weights are used. The original weights are always restored afterwards,
    even if sampling raises.

    Args:
        input_image: ImageEditor value; its ``"background"`` is returned for comparison.
        prompt: conditional prompt (should contain "sks person").
        negative_prompt: prompt for the unconditional branch of guidance.
        guidance_scale: classifier-free guidance scale.
        ddim_steps: number of scheduler steps (cast to int).
        seed: seed for the initial latents (cast to int).
        start_noise: timestep threshold at which the edits are injected.
        a1, a2, a3, a4: strengths of the young / pointy-nose / wavy-hair /
            thick-eyebrows directions.

    Returns:
        ``(edited PIL image, input_image["background"])`` for the ImageSlider.
    """
    global generator

    original_weights = network.proj.detach().clone()

    # Bring each edit direction to the model's number of PCs. Zero-pad when the
    # model has more PCs than the directions; truncate when it has fewer
    # (e.g. a model inverted with a small "# Principal Components").
    pcs_original = original_weights.shape[1]
    pcs_edits = young.shape[1]
    padding = torch.zeros((1, max(pcs_original - pcs_edits, 0))).to(device)

    def _fit(direction):
        return torch.cat((direction[:, :pcs_original], padding), 1)

    edited_weights = (
        original_weights
        + a1 * 1e6 * _fit(young)
        + a2 * 1e6 * _fit(pointy)
        + a3 * 1e6 * _fit(wavy)
        + a4 * 8e5 * _fit(large)
    )

    generator = generator.manual_seed(int(seed))
    latents = torch.randn(
        (1, unet.in_channels, 512 // 8, 512 // 8),
        generator=generator,
        device=device,
    ).bfloat16()

    # Conditional and unconditional embeddings, batched as [uncond, cond].
    text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length,
                           truncation=True, return_tensors="pt")
    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer(
        [negative_prompt], padding="max_length", max_length=max_length, return_tensors="pt"
    )
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    noise_scheduler.set_timesteps(int(ddim_steps))
    latents = latents * noise_scheduler.init_noise_sigma

    edits_applied = False
    try:
        for t in tqdm.tqdm(noise_scheduler.timesteps):
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)
            # Install the edited weights once, at the first timestep at/below the threshold.
            if not edits_applied and t <= start_noise:
                network.proj = torch.nn.Parameter(edited_weights)
                network.reset()
                edits_applied = True
            with network:
                noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings,
                                  timestep_cond=None).sample
            # Classifier-free guidance.
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
    finally:
        # Restore the unedited identity so later calls start from the same model.
        network.proj = torch.nn.Parameter(original_weights)
        network.reset()

    latents = 1 / 0.18215 * latents
    image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
    image = Image.fromarray((image * 255).round().astype("uint8"))
    return (image, input_image["background"])
def sample_then_run():
    """Sample a fresh identity model, render a preview, and save its weights.

    Returns:
        ``(preview PIL image, "model.pt")``. The file holds ``network.proj``
        saved with ``torch.save``.
    """
    sample_model()
    image = inference(
        "sks person",
        "low quality, blurry, unfinished, nudity, weapon",
        3.0,   # guidance scale
        50,    # sampling steps
        5,     # seed
    )
    out_path = "model.pt"
    torch.save(network.proj, out_path)
    return image, out_path
def _debiased_direction(attribute, confounders):
    """Return ``get_direction`` for *attribute*, passed through ``debias`` once per confounder, in order."""
    direction = get_direction(df, attribute, pinverse, 1000, device)
    for confounder in confounders:
        direction = debias(direction, confounder, df, pinverse, device)
    return direction


# Edit directions consumed by `edit_inference` (sliders a1..a4).
young = _debiased_direction("Young", ("Male", "Pointy_Nose", "Wavy_Hair", "Chubby"))
pointy = _debiased_direction(
    "Pointy_Nose", ("Young", "Male", "Wavy_Hair", "Chubby", "Heavy_Makeup")
)
wavy = _debiased_direction(
    "Wavy_Hair", ("Young", "Male", "Pointy_Nose", "Chubby", "Heavy_Makeup")
)
# "large" is the Bushy_Eyebrows direction (the "Thick Eyebrows" slider).
large = _debiased_direction(
    "Bushy_Eyebrows",
    (
        "Male", "Young", "Pointy_Nose", "Wavy_Hair", "Mustache", "No_Beard",
        "Sideburns", "Big_Nose", "Big_Lips", "Black_Hair", "Brown_Hair",
        "Pale_Skin", "Heavy_Makeup",
    ),
)
class CustomImageDataset(Dataset):
    """Map-style dataset over an in-memory, indexable collection of images.

    Args:
        images: indexable sequence of samples (a list of PIL images in this file).
        transform: optional callable applied to a sample when it is accessed;
            skipped when falsy.
    """

    def __init__(self, images, transform=None):
        self.images, self.transform = images, transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        sample = self.images[idx]
        return self.transform(sample) if self.transform else sample
def invert(image, mask, pcs=10000, epochs=400, weight_decay = 1e-10, lr=1e-1):
    """Fit a fresh identity network in w2w space so the UNet reproduces ``image``.

    Reloads the UNet and wraps it in a ``LoRAw2w`` network. The network's
    coefficients start at zero and use the first ``pcs`` columns of ``v``.
    Only the network's parameters are trained, with Adam, on the masked
    noise-prediction MSE for the prompt "sks person".

    Args:
        image: list of PIL images used as the training set (one image in this demo).
        mask: PIL image. Channel 0, resized to 64x64, weights the loss
            spatially. An all-zero mask means "nothing drawn" and is replaced
            by all ones.
        pcs: number of principal components (cast to int; UI sliders may send floats).
        epochs: number of passes over the dataset (cast to int).
        weight_decay: Adam weight decay.
        lr: Adam learning rate.

    Returns:
        The trained network (also bound to the global ``network``).
    """
    global unet, network
    pcs = int(pcs)
    epochs = int(epochs)

    del unet
    unet, _, _, _, _ = load_models(device)
    proj = torch.zeros(1, pcs).bfloat16().to(device)
    network = LoRAw2w(
        proj, mean, std, v[:, :pcs],
        unet,
        rank=1,
        multiplier=1.0,
        alpha=27.0,
        train_method="xattn-strict",
    ).to(device, torch.bfloat16)

    # Loss mask at 64x64. pil_to_tensor yields 0..255, so scale a drawn mask
    # to [0, 1] to match the all-ones fallback used when nothing is drawn.
    mask = transforms.Resize((64, 64), interpolation=transforms.InterpolationMode.BILINEAR)(mask)
    mask = torchvision.transforms.functional.pil_to_tensor(mask).unsqueeze(0).to(device).bfloat16()[:, 0, :, :].unsqueeze(1)
    if torch.sum(mask) == 0:
        mask = torch.ones((1, 1, 64, 64)).to(device).bfloat16()
    else:
        mask = mask / 255.0

    # Single-image dataset with random crops.
    image_transforms = transforms.Compose([
        transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.RandomCrop(512),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])
    train_dataset = CustomImageDataset(image, transform=image_transforms)
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True)

    optim = torch.optim.Adam(network.parameters(), lr=lr, weight_decay=weight_decay)

    # The prompt is fixed and only `network` is optimised, so encode it once
    # without building a graph through the text encoder.
    with torch.no_grad():
        text_input = tokenizer("sks person", padding="max_length", max_length=tokenizer.model_max_length,
                               truncation=True, return_tensors="pt")
        text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

    unet.train()
    for _ in tqdm.tqdm(range(epochs)):
        for batch in train_dataloader:
            batch = batch.to(device).bfloat16()
            # The VAE only produces training targets; no gradients are needed through it.
            with torch.no_grad():
                latents = vae.encode(batch).latent_dist.sample() * 0.18215
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,),
                                      device=latents.device).long()
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            with network:
                model_pred = unet(noisy_latents, timesteps, text_embeddings).sample
            loss = torch.nn.functional.mse_loss(mask * model_pred.float(), mask * noise.float(), reduction="mean")
            optim.zero_grad()
            loss.backward()
            optim.step()

    return network
def run_inversion(input_image, pcs, epochs, weight_decay, lr):
    """Invert the uploaded image into an identity model, preview it and save it.

    Args:
        input_image: ImageEditor value, a dict with ``"background"`` (the
            uploaded image) and ``"layers"`` (drawn mask layers).
        pcs, epochs, weight_decay, lr: forwarded to ``invert``.

    Returns:
        ``(preview PIL image, "model.pt")``. The trained ``network.proj`` is
        saved to that file.

    Raises:
        gr.Error: if no image has been uploaded.
    """
    global network
    background = input_image.get("background") if input_image else None
    if background is None:
        raise gr.Error("Please upload an image before clicking Invert.")
    init_image = background.convert("RGB").resize((512, 512))

    layers = input_image.get("layers") or []
    if layers:
        mask = layers[0].convert("RGB").resize((512, 512))
    else:
        # Nothing drawn: an all-black mask makes `invert` fall back to an all-ones mask.
        mask = Image.new("RGB", (512, 512))

    network = invert([init_image], mask, pcs, epochs, weight_decay, lr)

    # Preview the inverted identity with the demo's default sampling settings.
    prompt = "sks person"
    negative_prompt = "low quality, blurry, unfinished, nudity, weapon"
    seed = 5
    cfg = 3.0
    steps = 50
    image = inference(prompt, negative_prompt, cfg, steps, seed)
    torch.save(network.proj, "model.pt")
    return image, "model.pt"
def file_upload(file):
    """Load an uploaded model file (a saved ``network.proj``) and preview it.

    The coefficients are zero-padded to 10000 components. A fresh UNet is
    then loaded and a new ``LoRAw2w`` network is built over the first 10000
    columns of ``v``, rebinding the globals ``unet`` and ``network``. The
    current model is only torn down after the upload has been loaded and
    validated.

    Args:
        file: value of the ``gr.File`` component; ``None`` when the upload is cleared.

    Returns:
        The preview image, or ``gr.update()`` (output left unchanged) when
        ``file`` is ``None``.

    Raises:
        gr.Error: if the uploaded tensor has more than 10000 components.
    """
    global unet, network
    if file is None:
        # gr.File also fires `change` when cleared; keep the current model intact.
        return gr.update()

    # Gradio may pass a path string or an object exposing the path as `.name`.
    path = getattr(file, "name", file)
    # This is an untrusted user upload: never unpickle arbitrary objects.
    proj = torch.load(path, map_location=device, weights_only=True)

    # Pad to 10000 principal components to keep everything consistent.
    pcs = proj.shape[1]
    if pcs > 10000:
        raise gr.Error(f"Uploaded model has {pcs} principal components; at most 10000 are supported.")
    padding = torch.zeros((1, 10000 - pcs)).to(device)
    proj = torch.cat((proj, padding), 1)

    del unet
    unet, _, _, _, _ = load_models(device)
    network = LoRAw2w(
        proj, mean, std, v[:, :10000],
        unet,
        rank=1,
        multiplier=1.0,
        alpha=27.0,
        train_method="xattn-strict",
    ).to(device, torch.bfloat16)

    # Preview with the demo's default sampling settings.
    prompt = "sks person"
    negative_prompt = "low quality, blurry, unfinished, nudity, weapon"
    seed = 5
    cfg = 3.0
    steps = 50
    return inference(prompt, negative_prompt, cfg, steps, seed)
# HTML header for the demo page, rendered with gr.HTML(intro): title, subtitle,
# links to the project page and paper, and a "Duplicate Space" badge link.
intro = """
<div style="display: flex;align-items: center;justify-content: center">
<h1 style="margin-left: 12px;text-align: center;margin-bottom: 7px;display: inline-block">weights2weights</h1>
<h3 style="display: inline-block;margin-left: 10px;margin-top: 6px;font-weight: 500">Interpreting the Weight Space of Customized Diffusion Models</h3>
</div>
<p style="font-size: 0.95rem;margin: 0rem;line-height: 1.2em;margin-top:1em;display: inline-block">
<a href="https://snap-research.github.io/weights2weights/" target="_blank">project page</a> | <a href="https://arxiv.org/abs/2406.09413" target="_blank">paper</a>

<a href="https://huggingface.co/spaces/Snapchat/w2w-demo?duplicate=true" target="_blank" style="
display: inline-block;
">
<img style="margin-top: -1em;margin-bottom: 0em;position: absolute;" src="https://bit.ly/3CWLGkA" alt="Duplicate Space"></a>
</p>
"""
# NOTE(review): the copy this was reviewed from had lost its indentation, so the
# layout nesting below is a best-effort reconstruction; confirm it against the
# deployed Space before merging.
with gr.Blocks(css="style.css") as demo:
    gr.HTML(intro)
    gr.Markdown("""<div style="text-align: justify;"> Click below to sample an identity-encoding model, or upload an image below and click \"invert\". You can also optionally draw over the face to define a mask. To use a model previously downloaded from this demo see \"Uploading a model\" in the Advanced options.</div>""")

    with gr.Column():
        with gr.Row():
            # Left: input image / mask editor and the model-creation buttons.
            with gr.Column():
                input_image = gr.ImageEditor(elem_id="image_upload", type='pil', label="Upload image and draw to define mask",
                                             height=512, width=512, brush=gr.Brush(), layers=False)
                with gr.Row():
                    sample = gr.Button("Sample New Model")
                    invert_button = gr.Button("Invert")

            # Right: original-vs-edited comparison, prompt and edit controls.
            with gr.Column():
                image_slider = ImageSlider(position=1., type="pil", height=512, width=512)
                prompt1 = gr.Textbox(label="Prompt",
                                     info="Make sure to include 'sks person'",
                                     placeholder="sks person",
                                     value="sks person")

                # Editing: strengths a1..a4 of the edit directions used by edit_inference.
                with gr.Column():
                    with gr.Row():
                        a1 = gr.Slider(label="- Young +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
                        a2 = gr.Slider(label="- Pointy Nose +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
                    with gr.Row():
                        a3 = gr.Slider(label="- Curly Hair +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
                        a4 = gr.Slider(label="- Thick Eyebrows +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)

                with gr.Accordion("Advanced Options", open=False):
                    with gr.Tab("Inversion"):
                        with gr.Row():
                            lr = gr.Number(value=1e-1, label="Learning Rate", interactive=True)
                            pcs = gr.Slider(label="# Principal Components", value=10000, step=1, minimum=1, maximum=10000, interactive=True)
                        with gr.Row():
                            epochs = gr.Slider(label="Epochs", value=400, step=1, minimum=1, maximum=2000, interactive=True)
                            weight_decay = gr.Number(value=1e-10, label="Weight Decay", interactive=True)
                    with gr.Tab("Sampling"):
                        with gr.Row():
                            cfg1 = gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
                            # Minimum 1: zero steps would leave nothing to denoise.
                            steps1 = gr.Slider(label="Inference Steps", value=50, step=1, minimum=1, maximum=100, interactive=True)
                            seed1 = gr.Number(value=5, label="Seed", precision=0, interactive=True)
                        with gr.Row():
                            negative_prompt1 = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, nudity, weapon", value="low quality, blurry, unfinished, nudity, weapon")
                            injection_step = gr.Slider(label="Injection Step", value=800, step=1, minimum=0, maximum=1000, interactive=True)
                    with gr.Tab("Uploading a model"):
                        gr.Markdown("""<div style="text-align: justify;">Upload a model below downloaded from this demo.</div>""")
                        file_input = gr.File(label="Upload Model", container=True)

                submit1 = gr.Button("Generate")

        gr.Markdown("""<div style="text-align: justify;"> After sampling a new model or inverting, you can download the model below.</div>""")
        with gr.Row():
            file_output = gr.File(label="Download Sampled Model", container=True, interactive=False)

    # Event wiring (must be declared inside the Blocks context).
    invert_button.click(fn=run_inversion,
                        inputs=[input_image, pcs, epochs, weight_decay, lr],
                        outputs=[image_slider, file_output])
    sample.click(fn=sample_then_run, outputs=[input_image, file_output])
    submit1.click(
        fn=edit_inference,
        inputs=[input_image, prompt1, negative_prompt1, cfg1, steps1, seed1, injection_step, a1, a2, a3, a4],
        outputs=[image_slider],
    )
    file_input.change(fn=file_upload, inputs=file_input, outputs=input_image)

demo.queue().launch()