import os

# This Space pins its Gradio version: the sketch-based mask tool below relies on the 3.x Image API.
os.system("pip uninstall -y gradio")
os.system("pip install gradio==3.43.1")

import sys
import gc
import warnings

import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import gradio as gr
import tqdm

sys.path.append(os.path.abspath(os.path.join("", "..")))
warnings.filterwarnings("ignore")

from PIL import Image
from utils import load_models, save_model_w2w, save_model_for_diffusers
from editing import get_direction, debias
from sampling import sample_weights
from lora_w2w import LoRAw2w
from huggingface_hub import snapshot_download

global device
global generator
global unet
global vae
global text_encoder
global tokenizer
global noise_scheduler
global network

device = "cuda:0"
generator = torch.Generator(device=device)

# Download the w2w statistics and PCA basis (mean/std/V/projections) from the Hub.
models_path = snapshot_download(repo_id="Snapchat/w2w")

mean = torch.load(f"{models_path}/files/mean.pt").bfloat16().to(device)
std = torch.load(f"{models_path}/files/std.pt").bfloat16().to(device)
v = torch.load(f"{models_path}/files/V.pt").bfloat16().to(device)
proj = torch.load(f"{models_path}/files/proj_1000pc.pt").bfloat16().to(device)
df = torch.load(f"{models_path}/files/identity_df.pt")
weight_dimensions = torch.load(f"{models_path}/files/weight_dimensions.pt")
pinverse = torch.load(f"{models_path}/files/pinverse_1000pc.pt").bfloat16().to(device)

unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)


def sample_model():
    """Sample a new identity-encoding model from the w2w weight distribution."""
    global unet
    del unet
    global network
    unet, _, _, _, _ = load_models(device)
    network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor=1.00)


@torch.no_grad()
def inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed):
    global device
    global generator
    global unet
    global vae
    global text_encoder
    global tokenizer
    global noise_scheduler

    generator = generator.manual_seed(seed)
    latents = torch.randn(
        (1, unet.in_channels, 512 // 8, 512 // 8),
        generator=generator,
        device=device,
    ).bfloat16()

    text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length,
                           truncation=True, return_tensors="pt")
    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer([negative_prompt], padding="max_length", max_length=max_length,
                             return_tensors="pt")
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    noise_scheduler.set_timesteps(ddim_steps)
    latents = latents * noise_scheduler.init_noise_sigma

    for i, t in enumerate(tqdm.tqdm(noise_scheduler.timesteps)):
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)
        with network:
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings,
                              timestep_cond=None).sample

        # classifier-free guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

    latents = 1 / 0.18215 * latents
    image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
    image = Image.fromarray((image * 255).round().astype("uint8"))
    return image
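# A minimal usage sketch (hypothetical prompt/seed values; assumes the CUDA
# device and the model downloads above succeeded):
#
#   sample_model()
#   img = inference("sks person", "low quality, blurry", 3.0, 50, 5)
#   img.save("sample.png")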
@torch.no_grad()
def edit_inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed, start_noise, a1, a2, a3, a4):
    global device
    global generator
    global unet
    global vae
    global text_encoder
    global tokenizer
    global noise_scheduler
    global young
    global pointy
    global wavy
    global large

    original_weights = network.proj.clone()

    # pad the edit directions to the same number of PCs as the identity weights
    pcs_original = original_weights.shape[1]
    pcs_edits = young.shape[1]
    padding = torch.zeros((1, pcs_original - pcs_edits)).to(device)
    young_pad = torch.cat((young, padding), 1)
    pointy_pad = torch.cat((pointy, padding), 1)
    wavy_pad = torch.cat((wavy, padding), 1)
    large_pad = torch.cat((large, padding), 1)

    edited_weights = (original_weights
                      + a1 * 1e6 * young_pad
                      + a2 * 1e6 * pointy_pad
                      + a3 * 1e6 * wavy_pad
                      + a4 * 8e5 * large_pad)

    generator = generator.manual_seed(seed)
    latents = torch.randn(
        (1, unet.in_channels, 512 // 8, 512 // 8),
        generator=generator,
        device=device,
    ).bfloat16()

    text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length,
                           truncation=True, return_tensors="pt")
    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer([negative_prompt], padding="max_length", max_length=max_length,
                             return_tensors="pt")
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    noise_scheduler.set_timesteps(ddim_steps)
    latents = latents * noise_scheduler.init_noise_sigma

    for i, t in enumerate(tqdm.tqdm(noise_scheduler.timesteps)):
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)

        # inject the edited weights only once the noise level drops to start_noise
        if t <= start_noise:
            network.proj = torch.nn.Parameter(edited_weights)
            network.reset()

        with network:
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings,
                              timestep_cond=None).sample

        # classifier-free guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

    latents = 1 / 0.18215 * latents
    image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
    image = Image.fromarray((image * 255).round().astype("uint8"))

    # reset weights back to original
    network.proj = torch.nn.Parameter(original_weights)
    network.reset()
    return image
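# The edit above is plain linear arithmetic in principal-component space. A toy
# illustration with made-up sizes (not the real 1000/10000-PC tensors):
#
#   w = torch.zeros(1, 5)                      # identity coefficients, 5 PCs
#   d = torch.ones(1, 3)                       # edit direction with fewer PCs
#   d = torch.cat((d, torch.zeros(1, 2)), 1)   # zero-pad to 5 PCs
#   w_edited = w + 0.5 * d                     # slider value scales the direction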
def sample_then_run():
    sample_model()
    prompt = "sks person"
    negative_prompt = "low quality, blurry, unfinished, nudity, weapon"
    seed = 5
    cfg = 3.0
    steps = 50
    image = inference(prompt, negative_prompt, cfg, steps, seed)
    torch.save(network.proj, "model.pt")
    return image, "model.pt"


global young
global pointy
global wavy
global large

# Semantic edit directions in PC space, debiased against correlated attributes
# so that, e.g., "Young" does not also shift gender or nose shape.
young = get_direction(df, "Young", pinverse, 1000, device)
young = debias(young, "Male", df, pinverse, device)
young = debias(young, "Pointy_Nose", df, pinverse, device)
young = debias(young, "Wavy_Hair", df, pinverse, device)
young = debias(young, "Chubby", df, pinverse, device)

pointy = get_direction(df, "Pointy_Nose", pinverse, 1000, device)
pointy = debias(pointy, "Young", df, pinverse, device)
pointy = debias(pointy, "Male", df, pinverse, device)
pointy = debias(pointy, "Wavy_Hair", df, pinverse, device)
pointy = debias(pointy, "Chubby", df, pinverse, device)
pointy = debias(pointy, "Heavy_Makeup", df, pinverse, device)

wavy = get_direction(df, "Wavy_Hair", pinverse, 1000, device)
wavy = debias(wavy, "Young", df, pinverse, device)
wavy = debias(wavy, "Male", df, pinverse, device)
wavy = debias(wavy, "Pointy_Nose", df, pinverse, device)
wavy = debias(wavy, "Chubby", df, pinverse, device)
wavy = debias(wavy, "Heavy_Makeup", df, pinverse, device)

large = get_direction(df, "Bushy_Eyebrows", pinverse, 1000, device)
large = debias(large, "Male", df, pinverse, device)
large = debias(large, "Young", df, pinverse, device)
large = debias(large, "Pointy_Nose", df, pinverse, device)
large = debias(large, "Wavy_Hair", df, pinverse, device)
large = debias(large, "Mustache", df, pinverse, device)
large = debias(large, "No_Beard", df, pinverse, device)
large = debias(large, "Sideburns", df, pinverse, device)
large = debias(large, "Big_Nose", df, pinverse, device)
large = debias(large, "Big_Lips", df, pinverse, device)
large = debias(large, "Black_Hair", df, pinverse, device)
large = debias(large, "Brown_Hair", df, pinverse, device)
large = debias(large, "Pale_Skin", df, pinverse, device)
large = debias(large, "Heavy_Makeup", df, pinverse, device)


class CustomImageDataset(Dataset):
    def __init__(self, images, transform=None):
        self.images = images
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        if self.transform:
            image = self.transform(image)
        return image


def invert(image, mask, pcs=10000, epochs=400, weight_decay=1e-10, lr=1e-1):
    global unet
    del unet
    global network
    unet, _, _, _, _ = load_models(device)

    proj = torch.zeros(1, pcs).bfloat16().to(device)
    network = LoRAw2w(
        proj, mean, std, v[:, :pcs],
        unet,
        rank=1,
        multiplier=1.0,
        alpha=27.0,
        train_method="xattn-strict",
    ).to(device, torch.bfloat16)

    ### load mask
    mask = transforms.Resize((64, 64), interpolation=transforms.InterpolationMode.BILINEAR)(mask)
    mask = torchvision.transforms.functional.pil_to_tensor(mask).unsqueeze(0).to(device).bfloat16()[:, 0, :, :].unsqueeze(1)
    ### check whether an actual mask was drawn; otherwise the mask is just all ones
    if torch.sum(mask) == 0:
        mask = torch.ones((1, 1, 64, 64)).to(device).bfloat16()

    ### single-image dataset
    image_transforms = transforms.Compose([
        transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.RandomCrop(512),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])
    train_dataset = CustomImageDataset(image, transform=image_transforms)
    train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)

    ### optimizer
    optim = torch.optim.Adam(network.parameters(), lr=lr, weight_decay=weight_decay)

    ### training loop
    unet.train()
    for epoch in tqdm.tqdm(range(epochs)):
        for batch in train_dataloader:
            ### prepare inputs
            batch = batch.to(device).bfloat16()
            latents = vae.encode(batch).latent_dist.sample()
            latents = latents * 0.18215
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]

            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
            timesteps = timesteps.long()
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            text_input = tokenizer("sks person", padding="max_length", max_length=tokenizer.model_max_length,
                                   truncation=True, return_tensors="pt")
            text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

            ### loss + sgd step
            with network:
                model_pred = unet(noisy_latents, timesteps, text_embeddings).sample
            loss = torch.nn.functional.mse_loss(mask * model_pred.float(), mask * noise.float(), reduction="mean")
            optim.zero_grad()
            loss.backward()
            optim.step()

    ### return the optimized network
    return network
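# Usage sketch for inversion (hypothetical file name; a list holding one PIL
# image plus a PIL mask, matching the call in run_inversion below):
#
#   img = Image.open("face.jpg").convert("RGB").resize((512, 512))
#   blank_mask = Image.new("RGB", (512, 512))   # all-zero mask -> treated as all ones
#   net = invert([img], blank_mask, pcs=10000, epochs=400)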
def run_inversion(dict, pcs, epochs, weight_decay, lr):
    global network
    init_image = dict["image"].convert("RGB").resize((512, 512))
    mask = dict["mask"].convert("RGB").resize((512, 512))
    network = invert([init_image], mask, pcs, epochs, weight_decay, lr)

    # sample an image with the inverted identity
    prompt = "sks person"
    negative_prompt = "low quality, blurry, unfinished, nudity, weapon"
    seed = 5
    cfg = 3.0
    steps = 50
    image = inference(prompt, negative_prompt, cfg, steps, seed)
    torch.save(network.proj, "model.pt")
    return image, "model.pt"


def file_upload(file):
    global unet
    del unet
    global network
    global device

    proj = torch.load(file.name).to(device)

    # pad to 10000 principal components to keep everything consistent
    pcs = proj.shape[1]
    padding = torch.zeros((1, 10000 - pcs)).to(device)
    proj = torch.cat((proj, padding), 1)

    unet, _, _, _, _ = load_models(device)
    network = LoRAw2w(
        proj, mean, std, v[:, :10000],
        unet,
        rank=1,
        multiplier=1.0,
        alpha=27.0,
        train_method="xattn-strict",
    ).to(device, torch.bfloat16)

    prompt = "sks person"
    negative_prompt = "low quality, blurry, unfinished, nudity, weapon"
    seed = 5
    cfg = 3.0
    steps = 50
    image = inference(prompt, negative_prompt, cfg, steps, seed)
    return image
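# The saved/uploaded checkpoint is just the PC-coefficient tensor written by
# torch.save(network.proj, "model.pt") above; a quick sanity check might be:
#
#   coeffs = torch.load("model.pt")
#   assert coeffs.ndim == 2 and coeffs.shape[0] == 1   # shape (1, num_pcs)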

intro = """
<div style="text-align: center;">
<h1>weights2weights</h1>
<h3>Interpreting the Weight Space of Customized Diffusion Models</h3>
<p>project page | paper | Duplicate Space</p>
</div>
"""

with gr.Blocks(css="style.css") as demo:
    gr.HTML(intro)
    gr.Markdown("""Click below to sample an identity-encoding model, or upload an image and click "Invert".
                You can also optionally draw over the face to define a mask.
                To use a model previously downloaded from this demo, see "Uploading a model" under Advanced Options.""")

    with gr.Column():
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(source='upload', elem_id="image_upload", tool='sketch', type='pil',
                                       label="Upload image and draw to define mask",
                                       height=512, width=512, brush_color='#00FFFF', mask_opacity=0.6)
                with gr.Row():
                    sample = gr.Button("Sample New Model")
                    invert_button = gr.Button("Invert")

            with gr.Column():
                gallery1 = gr.Image(label="Identity from Original Model", height=512, width=512, interactive=False)
                prompt1 = gr.Textbox(label="Prompt",
                                     info="Make sure to include 'sks person'",
                                     placeholder="sks person",
                                     value="sks person")

                # Editing
                with gr.Column():
                    # gallery2 = gr.Image(label="Identity from Edited Model", interactive=False, visible=False)
                    with gr.Row():
                        a1 = gr.Slider(label="- Young +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
                        a2 = gr.Slider(label="- Pointy Nose +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
                    with gr.Row():
                        a3 = gr.Slider(label="- Curly Hair +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
                        a4 = gr.Slider(label="- Thick Eyebrows +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)

                # prompt2 = gr.Textbox(label="Prompt",
                #                      info="Make sure to include 'sks person'",
                #                      placeholder="sks person",
                #                      value="sks person", visible=False)
                # seed2 = gr.Number(value=5, label="Seed", precision=0, interactive=True, visible=False)
                # submit2 = gr.Button("Generate", visible=False)

        with gr.Accordion("Advanced Options", open=False):
            with gr.Tab("Inversion"):
                with gr.Column():
                    lr = gr.Number(value=1e-1, label="Learning Rate", interactive=True)
                    pcs = gr.Slider(label="# Principal Components", value=10000, step=1, minimum=1, maximum=10000, interactive=True)
                    epochs = gr.Slider(label="Epochs", value=400, step=1, minimum=1, maximum=2000, interactive=True)
                    weight_decay = gr.Number(value=1e-10, label="Weight Decay", interactive=True)
            with gr.Tab("Sampling"):
                with gr.Column():
                    cfg1 = gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
                    steps1 = gr.Slider(label="Inference Steps", value=50, step=1, minimum=0, maximum=100, interactive=True)
                    negative_prompt1 = gr.Textbox(label="Negative Prompt",
                                                  placeholder="low quality, blurry, unfinished, nudity, weapon",
                                                  value="low quality, blurry, unfinished, nudity, weapon")
                    seed1 = gr.Number(value=5, label="Seed", precision=0, interactive=True)
                    injection_step = gr.Slider(label="Injection Step", value=800, step=1, minimum=0, maximum=1000, interactive=True)

            # with gr.Tab("Editing"):
            #     with gr.Column():
            #         cfg2 = gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
            #         steps2 = gr.Slider(label="Inference Steps", value=50, step=1, minimum=0, maximum=100, interactive=True)
            #         injection_step = gr.Slider(label="Injection Step", value=800, step=1, minimum=0, maximum=1000, interactive=True)
            #         negative_prompt2 = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, nudity, weapon", value="low quality, blurry, unfinished, nudity, weapon")

            with gr.Tab("Uploading a model"):
                gr.Markdown("""Upload a model downloaded from this demo below.""")
                file_input = gr.File(label="Upload Model", container=True)
                submit1 = gr.Button("Generate")

        # gr.Markdown("""After sampling a new model or inverting, you can download the model below.""")
        with gr.Row():
            file_output = gr.File(label="Download Sampled Model", container=True, interactive=False,
                                  info="After sampling a new model or inverting, you can download the model below")

    invert_button.click(fn=run_inversion,
                        inputs=[input_image, pcs, epochs, weight_decay, lr],
                        outputs=[gallery1, file_output])
    sample.click(fn=sample_then_run, outputs=[gallery1, file_output])
    # submit1.click(fn=inference,
    #               inputs=[prompt1, negative_prompt1, cfg1, steps1, seed1],
    #               outputs=gallery1)
    submit1.click(fn=edit_inference,
                  inputs=[prompt1, negative_prompt1, cfg1, steps1, seed1, injection_step, a1, a2, a3, a4],
                  outputs=gallery1)
    file_input.change(fn=file_upload, inputs=file_input, outputs=gallery1)

demo.queue().launch()
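# To run locally (assumptions: the repo's utils.py / editing.py / sampling.py /
# lora_w2w.py and style.css sit next to this script, and a CUDA GPU is available;
# "app.py" is a hypothetical name for this file):
#
#   python app.py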