"""Gradio demo for weights2weights (w2w).

Downloads the statistics stored in the ``Snapchat/w2w`` Hugging Face Hub repo,
lets the user sample a new set of weights (``network``) for the UNet via
``sample_weights``, and generates an image from a text prompt with that
network active.

NOTE(review): the "upload a photo to invert" widgets are built but not wired
to any callback; only "Sample New Model" and "Submit" do anything.
"""

import gc
import os
import sys
import warnings

import gradio as gr
import torch
import tqdm

# Make the parent directory importable so the local `utils` and `sampling`
# modules imported below can be found.
sys.path.append(os.path.abspath(os.path.join("", "..")))

warnings.filterwarnings("ignore")

from PIL import Image
from huggingface_hub import snapshot_download
from sampling import sample_weights
from utils import load_models, save_model_w2w, save_model_for_diffusers

device = "cuda:0"
generator = torch.Generator(device=device)

# Local directory containing the downloaded (or cached) repo snapshot.
models_path = snapshot_download(repo_id="Snapchat/w2w")

# Tensors consumed by `sample_weights`, converted to bfloat16 on `device`.
mean = torch.load(f"{models_path}/mean.pt").bfloat16().to(device)
std = torch.load(f"{models_path}/std.pt").bfloat16().to(device)
v = torch.load(f"{models_path}/V.pt").bfloat16().to(device)
proj = torch.load(f"{models_path}/proj_1000pc.pt").bfloat16().to(device)
# Loaded but not used anywhere in this script.
df = torch.load(f"{models_path}/identity_df.pt")
weight_dimensions = torch.load(f"{models_path}/weight_dimensions.pt")

unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)

# Weights produced by `sample_model`; None until the user clicks
# "Sample New Model". `inference` refuses to run while this is None.
network = None

# Divisor applied to latents before VAE decoding (same 0.18215 constant the
# original code used; presumably the VAE's latent scaling factor — confirm
# against `vae.config.scaling_factor`).
LATENT_SCALE = 0.18215

# No custom stylesheet. The previous version passed an undefined `css` name to
# gr.Blocks, which raised NameError at startup.
css = None


def sample_model():
    """Reload the UNet and sample a fresh weight network for it.

    The previous UNet and network are released first so their GPU memory can
    be reclaimed before `load_models` allocates new models. The new network
    is drawn with `sample_weights` from the first 1000 columns of `v` and is
    stored in the module-level `network` used by `inference`.
    """
    global unet, network

    # Drop every reference to the old UNet (the old network may hold one too)
    # and actually return the memory to the CUDA allocator.
    del unet
    network = None
    gc.collect()
    torch.cuda.empty_cache()

    # load_models also builds a VAE, text encoder, tokenizer and scheduler;
    # only the UNet is kept, so free the discarded copies right away.
    unet, _, _, _, _ = load_models(device)
    gc.collect()
    torch.cuda.empty_cache()

    network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor=1.00)


def _encode_prompts(prompt, negative_prompt):
    """Return negative- and positive-prompt embeddings concatenated on dim 0.

    The negative (unconditional) embedding comes first, matching the
    `chunk(2)` order used for classifier-free guidance in `inference`.
    """
    text_input = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

    max_length = text_input.input_ids.shape[-1]
    # truncation=True: without it an over-long negative prompt produces a
    # longer token sequence than the positive one and torch.cat below fails.
    uncond_input = tokenizer(
        [negative_prompt],
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    )
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]

    return torch.cat([uncond_embeddings, text_embeddings])


def _latents_to_image(latents):
    """Decode the first latent in the batch with the VAE into a PIL image."""
    latents = 1 / LATENT_SCALE * latents
    image = vae.decode(latents).sample
    # Map the decoder output from [-1, 1] to [0, 1].
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
    return Image.fromarray((image * 255).round().astype("uint8"))


@torch.no_grad()
def inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed):
    """Generate one image from `prompt` using the currently sampled network.

    Args:
        prompt: positive text prompt.
        negative_prompt: text used for the unconditional guidance branch.
        guidance_scale: classifier-free guidance weight.
        ddim_steps: number of scheduler steps (cast to int; must be >= 1).
        seed: seed for the initial noise (cast to int).

    Returns:
        A one-element list containing the generated PIL image (the format the
        output gr.Gallery expects).

    Raises:
        gr.Error: if no network has been sampled yet, the seed is empty, or
            fewer than one step is requested. The message is shown in the UI.
    """
    global generator

    if network is None:
        raise gr.Error("No model sampled yet: click 'Sample New Model' first.")
    if seed is None:
        raise gr.Error("Please enter a seed.")
    # Gradio sliders may deliver floats; schedulers expect an int step count.
    ddim_steps = int(ddim_steps)
    if ddim_steps < 1:
        raise gr.Error("Inference Steps must be at least 1.")

    generator = generator.manual_seed(int(seed))
    latents = torch.randn(
        (1, unet.config.in_channels, 512 // 8, 512 // 8),
        generator=generator,
        device=device,
    ).bfloat16()

    text_embeddings = _encode_prompts(prompt, negative_prompt)

    noise_scheduler.set_timesteps(ddim_steps)
    latents = latents * noise_scheduler.init_noise_sigma

    for t in tqdm.tqdm(noise_scheduler.timesteps):
        # Duplicate the latents so the unconditional and conditional branches
        # run through the UNet as a single batch.
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)
        # Presumably activates the sampled weights on the UNet for this
        # forward pass — confirm in `sampling`.
        with network:
            noise_pred = unet(
                latent_model_input,
                t,
                encoder_hidden_states=text_embeddings,
                timestep_cond=None,
            ).sample

        # Classifier-free guidance.
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

    return [_latents_to_image(latents)]


with gr.Blocks(css=css) as demo:
    gr.Markdown("# weights2weights Demo")
    with gr.Row():
        with gr.Column():
            # NOTE(review): upload / clear widgets are not connected to any
            # callback below.
            files = gr.Files(
                label="Upload a photo of your face to invert, or sample a new model",
                file_types=["image"],
            )
            uploaded_files = gr.Gallery(
                label="Your images", visible=False, columns=5, rows=1, height=125
            )
            sample = gr.Button("Sample New Model")
            with gr.Column(visible=False) as clear_button:
                remove_and_reupload = gr.ClearButton(
                    value="Remove and upload new ones", components=files, size="sm"
                )
            prompt = gr.Textbox(
                label="Prompt",
                info="Make sure to include 'sks person'",
                placeholder="sks person",
                value="sks person",
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                placeholder="low quality, blurry, unfinished, cartoon",
                value="low quality, blurry, unfinished, cartoon",
            )
            seed = gr.Number(value=5, precision=0, label="Seed", interactive=True)
            cfg = gr.Slider(
                label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True
            )
            # minimum=1: zero inference steps is not a valid scheduler setting.
            # (`precision` is not a gr.Slider argument and was dropped.)
            steps = gr.Slider(
                label="Inference Steps",
                value=50,
                step=1,
                minimum=1,
                maximum=100,
                interactive=True,
            )
            submit = gr.Button("Submit")

        with gr.Column():
            gallery = gr.Gallery(label="Generated Images")

    sample.click(fn=sample_model)
    submit.click(
        fn=inference,
        inputs=[prompt, negative_prompt, cfg, steps, seed],
        outputs=gallery,
    )


if __name__ == "__main__":
    demo.launch(share=True)