# weights2weights / app.py
# multimodalart's picture
# Create app.py
# eb710fe verified
# raw
# history blame
# 4.97 kB
import gradio as gr
import sys
import os
import tqdm
sys.path.append(os.path.abspath(os.path.join("", "..")))
import torch
import gc
import warnings
warnings.filterwarnings("ignore")
from PIL import Image
from utils import load_models, save_model_w2w, save_model_for_diffusers
from sampling import sample_weights
from huggingface_hub import snapshot_download
# Module-level state shared by the Gradio callbacks defined below.
# Everything here is created once at import time; `sample_model` later
# rebinds `unet` and `network`, and `inference` reseeds `generator`.
# (`global` statements are meaningless at module scope, so plain assignments
# are used instead.)
device = "cuda:0"
generator = torch.Generator(device=device)

# Fetch the w2w assets from the Hugging Face Hub and get their local directory.
models_path = snapshot_download(repo_id="Snapchat/w2w")

# NOTE(review): torch.load unpickles these files. That is acceptable only
# because they come from the project's own "Snapchat/w2w" repository; never
# point these loads at untrusted data.
# mean / std / proj / v are the tensors handed to `sample_weights`.
mean = torch.load(f"{models_path}/mean.pt").bfloat16().to(device)
std = torch.load(f"{models_path}/std.pt").bfloat16().to(device)
v = torch.load(f"{models_path}/V.pt").bfloat16().to(device)
proj = torch.load(f"{models_path}/proj_1000pc.pt").bfloat16().to(device)

# Loaded but not referenced by the code visible in this file.
df = torch.load(f"{models_path}/identity_df.pt")
weight_dimensions = torch.load(f"{models_path}/weight_dimensions.pt")

unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)

# No model has been sampled yet. Binding the name explicitly means callbacks
# see `None` rather than an undefined global until `sample_model` has run.
network = None
def sample_model():
    """Load a fresh UNet and attach a newly sampled w2w network to it.

    Rebinds the module-level ``unet`` and ``network``. Returns nothing, so it
    can be wired to a Gradio button with no outputs.
    """
    global unet
    global network

    # Release the previous models *before* loading replacements. Rebinding to
    # None (rather than `del`) keeps the names defined, so a failed load can be
    # retried without hitting a NameError on the next click. The previous
    # `network` was built from the previous `unet` and may still reference it,
    # so both are dropped.
    unet = None
    network = None
    gc.collect()
    torch.cuda.empty_cache()

    # `load_models` returns five objects; only the UNet is replaced here.
    unet, _, _, _, _ = load_models(device)
    # Only the first 1000 columns of `v` are passed on.
    network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor=1.00)
@torch.no_grad()
def inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed):
    """Generate one image from ``prompt`` using the currently sampled model.

    Runs entirely under ``torch.no_grad()`` so that no autograd graph is kept
    across the denoising loop.

    Args:
        prompt: Text prompt to condition on.
        negative_prompt: Text whose embedding is used as the unconditional
            branch of the guidance formula.
        guidance_scale: Weight applied to the (text - unconditional) noise
            prediction difference.
        ddim_steps: Number of scheduler timesteps; must be at least 1.
        seed: Integer seed for the shared random generator.

    Returns:
        A one-element list containing the decoded ``PIL.Image``.

    Raises:
        gr.Error: If no model has been sampled yet, or ``ddim_steps`` < 1.
    """
    global generator

    # `network` only becomes usable after `sample_model` has run. Look it up
    # defensively so an early "Submit" gives a readable message instead of a
    # NameError, and keep a local reference for the whole run.
    active_network = globals().get("network")
    if active_network is None:
        raise gr.Error("No model has been sampled yet. Click 'Sample New Model' first.")

    # UI widgets may deliver numbers as floats; the scheduler and the
    # generator both need real integers.
    ddim_steps = int(ddim_steps)
    if ddim_steps < 1:
        raise gr.Error("Inference Steps must be at least 1.")

    generator = generator.manual_seed(int(seed))

    # Initial noise: one sample at 1/8 of 512 in each spatial dimension.
    # `config.in_channels` is the supported accessor; reading `in_channels`
    # directly off the model is deprecated in diffusers.
    latents = torch.randn(
        (1, unet.config.in_channels, 512 // 8, 512 // 8),
        generator=generator,
        device=device,
    ).bfloat16()

    text_input = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

    # The negative prompt must be tokenized to exactly the same length as the
    # prompt; truncation guards against an over-long negative prompt.
    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer(
        [negative_prompt],
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    )
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]

    # Order matters: the unconditional half comes first, matching `chunk(2)` below.
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    noise_scheduler.set_timesteps(ddim_steps)
    latents = latents * noise_scheduler.init_noise_sigma

    for t in tqdm.tqdm(noise_scheduler.timesteps):
        # Duplicate the latents so both guidance branches run in one forward pass.
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)

        # The sampled network is only active inside this context.
        with active_network:
            noise_pred = unet(
                latent_model_input,
                t,
                encoder_hidden_states=text_embeddings,
                timestep_cond=None,
            ).sample

        # Classifier-free guidance.
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

    # Undo the latent scaling before decoding (0.18215 is presumably the VAE's
    # scaling factor -- TODO confirm against the loaded VAE config).
    latents = 1 / 0.18215 * latents
    image = vae.decode(latents).sample

    # Map from [-1, 1] to [0, 1], move channels last, and convert to 8-bit.
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
    image = Image.fromarray((image * 255).round().astype("uint8"))

    return [image]
# Gradio UI: left column holds the controls, right column shows the result.
# NOTE: the original passed `css=css`, but no `css` is defined anywhere in this
# file, which raised a NameError at startup; the default styling is used.
with gr.Blocks() as demo:
    gr.Markdown("# <em>weights2weights</em> Demo")
    with gr.Row():
        with gr.Column():
            # Upload widgets are displayed but not yet wired to any callback.
            files = gr.Files(
                label="Upload a photo of your face to invert, or sample a new model",
                file_types=["image"],
            )
            uploaded_files = gr.Gallery(label="Your images", visible=False, columns=5, rows=1, height=125)
            sample = gr.Button("Sample New Model")
            # Hidden container for the "remove and re-upload" button.
            with gr.Column(visible=False) as clear_button:
                remove_and_reupload = gr.ClearButton(value="Remove and upload new ones", components=files, size="sm")
            prompt = gr.Textbox(
                label="Prompt",
                info="Make sure to include 'sks person'",
                placeholder="sks person",
                value="sks person",
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                placeholder="low quality, blurry, unfinished, cartoon",
                value="low quality, blurry, unfinished, cartoon",
            )
            seed = gr.Number(value=5, precision=0, label="Seed", interactive=True)
            cfg = gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
            # `precision` is a gr.Number option, not a gr.Slider one, so it is
            # dropped here; `step=1` already restricts the slider to whole
            # numbers. The minimum is 1 because zero denoising steps cannot
            # produce an image.
            steps = gr.Slider(label="Inference Steps", value=50, step=1, minimum=1, maximum=100, interactive=True)
            submit = gr.Button("Submit")
        with gr.Column():
            gallery = gr.Gallery(label="Generated Images")

    # Event wiring must happen inside the Blocks context.
    sample.click(fn=sample_model)
    submit.click(
        fn=inference,
        inputs=[prompt, negative_prompt, cfg, steps, seed],
        outputs=gallery,
    )

demo.launch(share=True)