Commit eb710fe by multimodalart
1 Parent(s): 8483373

Create app.py

Files changed (1):
  1. app.py +128 -0
app.py ADDED
@@ -0,0 +1,128 @@
import gradio as gr
import sys
import os
import tqdm
sys.path.append(os.path.abspath(os.path.join("", "..")))
import torch
import gc
import warnings
warnings.filterwarnings("ignore")
from PIL import Image
from utils import load_models, save_model_w2w, save_model_for_diffusers
from sampling import sample_weights
from huggingface_hub import snapshot_download

global device
global generator
global unet
global vae
global text_encoder
global tokenizer
global noise_scheduler
device = "cuda:0"
generator = torch.Generator(device=device)

# Download the w2w weight-space statistics, PCA basis, and projection matrices.
models_path = snapshot_download(repo_id="Snapchat/w2w")

mean = torch.load(f"{models_path}/mean.pt").bfloat16().to(device)
std = torch.load(f"{models_path}/std.pt").bfloat16().to(device)
v = torch.load(f"{models_path}/V.pt").bfloat16().to(device)
proj = torch.load(f"{models_path}/proj_1000pc.pt").bfloat16().to(device)
df = torch.load(f"{models_path}/identity_df.pt")
weight_dimensions = torch.load(f"{models_path}/weight_dimensions.pt")

unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)

global network

def sample_model():
    # Reload a fresh UNet, then sample a new identity from the w2w weight space.
    global unet
    del unet
    global network
    unet, _, _, _, _ = load_models(device)
    network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor=1.00)

def inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed):
    global device
    global generator
    global unet
    global vae
    global text_encoder
    global tokenizer
    global noise_scheduler
    generator = generator.manual_seed(seed)
    latents = torch.randn(
        (1, unet.config.in_channels, 512 // 8, 512 // 8),
        generator=generator,
        device=device,
    ).bfloat16()

    # Encode the prompt and the negative prompt for classifier-free guidance.
    text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer(
        [negative_prompt], padding="max_length", max_length=max_length, return_tensors="pt"
    )
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    noise_scheduler.set_timesteps(int(ddim_steps))
    latents = latents * noise_scheduler.init_noise_sigma

    for i, t in enumerate(tqdm.tqdm(noise_scheduler.timesteps)):
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)
        # The sampled identity weights are active only inside the network context.
        with network:
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond=None).sample
        # Classifier-free guidance: blend unconditional and conditional predictions.
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

    # Decode latents to pixel space (0.18215 is the Stable Diffusion VAE scaling factor).
    latents = 1 / 0.18215 * latents
    image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]

    image = Image.fromarray((image * 255).round().astype("uint8"))

    return [image]

with gr.Blocks() as demo:
    gr.Markdown("# <em>weights2weights</em> Demo")
    with gr.Row():
        with gr.Column():
            files = gr.Files(
                label="Upload a photo of your face to invert, or sample a new model",
                file_types=["image"]
            )
            uploaded_files = gr.Gallery(label="Your images", visible=False, columns=5, rows=1, height=125)

            sample = gr.Button("Sample New Model")

            with gr.Column(visible=False) as clear_button:
                remove_and_reupload = gr.ClearButton(value="Remove and upload new ones", components=files, size="sm")
            prompt = gr.Textbox(label="Prompt",
                                info="Make sure to include 'sks person'",
                                placeholder="sks person",
                                value="sks person")
            negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, cartoon", value="low quality, blurry, unfinished, cartoon")
            seed = gr.Number(value=5, precision=0, label="Seed", interactive=True)
            cfg = gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
            steps = gr.Slider(label="Inference Steps", value=50, step=1, minimum=0, maximum=100, interactive=True)

            submit = gr.Button("Submit")

        with gr.Column():
            gallery = gr.Gallery(label="Generated Images")

    sample.click(fn=sample_model)

    submit.click(fn=inference,
                 inputs=[prompt, negative_prompt, cfg, steps, seed],
                 outputs=gallery)

demo.launch(share=True)
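For completeness, here is a hypothetical smoke test of the two entry points, not part of the commit. It assumes a CUDA device and that the utils and sampling modules from the w2w repository are importable, and it reuses the defaults that the Gradio controls expose:

# Hypothetical smoke test (not part of the commit): sample an identity,
# then generate a single image with the same defaults as the UI controls.
sample_model()                    # populates the global `network` context
images = inference(
    prompt="sks person",
    negative_prompt="low quality, blurry, unfinished, cartoon",
    guidance_scale=3.0,           # CFG slider default
    ddim_steps=50,                # Inference Steps slider default
    seed=5,                       # Seed default
)
images[0].save("sks_person.png")  # inference returns a list with one PIL image

Note that sample_model must run before inference: the denoising loop enters the network context that sampling creates, so calling inference first would raise a NameError.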