Linoy Tsaban
commited on
Commit
•
6255790
1
Parent(s):
6908973
Update app.py
Browse files
app.py
CHANGED
@@ -4,11 +4,88 @@ import requests
|
|
4 |
from io import BytesIO
|
5 |
from diffusers import StableDiffusionPipeline
|
6 |
from diffusers import DDIMScheduler
|
7 |
-
from utils import
|
|
|
8 |
|
9 |
-
|
10 |
-
|
11 |
-
|
|
|
|
|
12 |
|
13 |
-
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
from io import BytesIO
|
5 |
from diffusers import StableDiffusionPipeline
|
6 |
from diffusers import DDIMScheduler
|
7 |
+
from utils import *
|
8 |
+
from inversion_utils import *
|
9 |
|
10 |
+
model_id = "CompVis/stable-diffusion-v1-4"
|
11 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
12 |
+
sd_pipe = StableDiffusionPipeline.from_pretrained(model_id).to(device)
|
13 |
+
sd_pipe.scheduler = DDIMScheduler.from_config(model_id, subfolder = "scheduler")
|
14 |
+
from torch import autocast, inference_mode
|
15 |
|
16 |
+
def invert(x0, prompt_src="", num_diffusion_steps=100, cfg_scale_src = 3.5, eta = 1):
|
17 |
+
|
18 |
+
# inverts a real image according to Algorihm 1 in https://arxiv.org/pdf/2304.06140.pdf,
|
19 |
+
# based on the code in https://github.com/inbarhub/DDPM_inversion
|
20 |
+
|
21 |
+
# returns wt, zs, wts:
|
22 |
+
# wt - inverted latent
|
23 |
+
# wts - intermediate inverted latents
|
24 |
+
# zs - noise maps
|
25 |
+
|
26 |
+
sd_pipe.scheduler.set_timesteps(num_diffusion_steps)
|
27 |
+
|
28 |
+
# vae encode image
|
29 |
+
with autocast("cuda"), inference_mode():
|
30 |
+
w0 = (sd_pipe.vae.encode(x0).latent_dist.mode() * 0.18215).float()
|
31 |
+
|
32 |
+
# find Zs and wts - forward process
|
33 |
+
wt, zs, wts = inversion_forward_process(sd_pipe, w0, etas=eta, prompt=prompt_src, cfg_scale=cfg_scale_src, prog_bar=True, num_inference_steps=num_diffusion_steps)
|
34 |
+
return wt, zs, wts
|
35 |
+
|
36 |
+
|
37 |
+
|
38 |
+
def sample(wt, zs, wts, prompt_tar="", cfg_scale_tar=15, skip=36, eta = 1):
|
39 |
+
|
40 |
+
# reverse process (via Zs and wT)
|
41 |
+
w0, _ = inversion_reverse_process(sd_pipe, xT=wts[skip], etas=eta, prompts=[prompt_tar], cfg_scales=[cfg_scale_tar], prog_bar=True, zs=zs[skip:])
|
42 |
+
|
43 |
+
# vae decode image
|
44 |
+
with autocast("cuda"), inference_mode():
|
45 |
+
x0_dec = sd_pipe.vae.decode(1 / 0.18215 * w0).sample
|
46 |
+
if x0_dec.dim()<4:
|
47 |
+
x0_dec = x0_dec[None,:,:,:]
|
48 |
+
img = image_grid(x0_dec)
|
49 |
+
return img
|
50 |
+
|
51 |
+
|
52 |
+
|
53 |
+
|
54 |
+
def edit(input_image, input_image_prompt, target_prompt, guidance_scale=15, skip=36, num_diffusion_steps=100):
|
55 |
+
offsets=(0,0,0,0)
|
56 |
+
x0 = load_512(input_image, *offsets, device)
|
57 |
+
|
58 |
+
|
59 |
+
# invert
|
60 |
+
wt, zs, wts = invert(x0 =x0 , prompt_src=input_image_prompt, num_diffusion_steps=num_diffusion_steps)
|
61 |
+
latnets = wts[skip].expand(1, -1, -1, -1)
|
62 |
+
|
63 |
+
eta = 1
|
64 |
+
#pure DDPM output
|
65 |
+
pure_ddpm_out = sample(wt, zs, wts, prompt_tar=target_prompt,
|
66 |
+
cfg_scale_tar=guidance_scale, skip=skip,
|
67 |
+
eta = eta)
|
68 |
+
return pure_ddpm_out
|
69 |
+
|
70 |
+
|
71 |
+
# See the gradio docs for the types of inputs and outputs available
|
72 |
+
inputs = [
|
73 |
+
gr.Image(label="input image", shape=(512, 512)),
|
74 |
+
gr.Textbox(label="input prompt"),
|
75 |
+
gr.Textbox(label="target prompt"),
|
76 |
+
gr.Slider(label="guidance_scale", minimum=7, maximum=18, value=15),
|
77 |
+
gr.Slider(label="skip", minimum=0, maximum=40, value=36),
|
78 |
+
gr.Slider(label="num_diffusion_steps", minimum=0, maximum=300, value=100),
|
79 |
+
|
80 |
+
|
81 |
+
]
|
82 |
+
outputs = gr.Image(label="result")
|
83 |
+
|
84 |
+
# And the minimal interface
|
85 |
+
demo = gr.Interface(
|
86 |
+
fn=edit,
|
87 |
+
inputs=inputs,
|
88 |
+
outputs=outputs,
|
89 |
+
)
|
90 |
+
|
91 |
+
demo.launch()
|