williamberman
commited on
Commit
·
f280910
1
Parent(s):
aa67e5e
working
Browse files- app.py +30 -22
- diffusion.py +5 -5
- sdxl.py +36 -16
- sdxl_models.py +53 -39
app.py
CHANGED
@@ -1,23 +1,22 @@
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
|
4 |
-
from diffusers import AutoPipelineForInpainting
|
5 |
import diffusers
|
6 |
from share_btn import community_icon_html, loading_icon_html, share_js
|
7 |
-
from sdxl import
|
8 |
from sdxl_models import SDXLUNet, SDXLVae, SDXLControlNetPreEncodedControlnetCond
|
|
|
|
|
|
|
9 |
|
10 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
11 |
pipe = AutoPipelineForInpainting.from_pretrained("diffusers/stable-diffusion-xl-1.0-inpainting-0.1", torch_dtype=torch.float16, variant="fp16").to(device)
|
12 |
|
13 |
-
|
14 |
-
|
15 |
-
comparing_unet = SDXLUNet.load("/admin/home/william/.cache/huggingface/hub/models--stabilityai--stable-diffusion-xl-base-1.0/snapshots/76d28af79639c28a79fa5c6c6468febd3490a37e/unet/diffusion_pytorch_model.fp16.safetensors", device=device)
|
16 |
-
# comparing_vae = SDXLVae.load("/admin/home/william/.cache/huggingface/hub/models--stabilityai--stable-diffusion-xl-base-1.0/snapshots/76d28af79639c28a79fa5c6c6468febd3490a37e/vae/diffusion_pytorch_model.fp16.safetensors", device=device)
|
17 |
-
comparing_vae = SDXLVae.load("/admin/home/william/.cache/huggingface/hub/models--madebyollin--sdxl-vae-fp16-fix/snapshots/4df413ca49271c25289a6482ab97a433f8117d15/diffusion_pytorch_model.safetensors", device=device)
|
18 |
comparing_vae.to(torch.float16)
|
19 |
-
|
20 |
-
comparing_controlnet = SDXLControlNetPreEncodedControlnetCond.load("./controlnet_vae.safetensors", device="cuda") # TODO - upload checkpoint
|
21 |
comparing_controlnet.to(torch.float16)
|
22 |
|
23 |
def read_content(file_path: str) -> str:
|
@@ -45,15 +44,26 @@ def predict(dict, prompt="", negative_prompt="", guidance_scale=7.5, steps=20, s
|
|
45 |
init_image = dict["image"].convert("RGB").resize((1024, 1024))
|
46 |
mask = dict["mask"].convert("RGB").resize((1024, 1024))
|
47 |
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
)
|
|
|
53 |
|
54 |
-
|
55 |
-
|
56 |
-
return output_controlnet_vae_encoding[0], gr.update(visible=True)
|
57 |
|
58 |
|
59 |
css = '''
|
@@ -107,7 +117,7 @@ with image_blocks as demo:
|
|
107 |
with gr.Accordion(label="Advanced Settings", open=False):
|
108 |
with gr.Row(mobile_collapse=False, equal_height=True):
|
109 |
guidance_scale = gr.Number(value=7.5, minimum=1.0, maximum=20.0, step=0.1, label="guidance_scale")
|
110 |
-
steps = gr.Number(value=20, minimum=
|
111 |
strength = gr.Number(value=0.99, minimum=0.01, maximum=1.0, step=0.01, label="strength")
|
112 |
negative_prompt = gr.Textbox(label="negative_prompt", placeholder="Your negative prompt", info="what you don't want to see in the image")
|
113 |
with gr.Row(mobile_collapse=False, equal_height=True):
|
@@ -123,10 +133,8 @@ with image_blocks as demo:
|
|
123 |
share_button = gr.Button("Share to community", elem_id="share-btn",visible=True)
|
124 |
|
125 |
|
126 |
-
|
127 |
-
|
128 |
-
btn.click(fn=predict, inputs=[image, prompt, negative_prompt, guidance_scale, steps, strength, scheduler], outputs=[image_out_comparing, share_btn_container], api_name='run')
|
129 |
-
prompt.submit(fn=predict, inputs=[image, prompt, negative_prompt, guidance_scale, steps, strength, scheduler], outputs=[image_out_comparing, share_btn_container])
|
130 |
share_button.click(None, [], [], _js=share_js)
|
131 |
|
132 |
gr.Examples(
|
@@ -155,4 +163,4 @@ with image_blocks as demo:
|
|
155 |
"""
|
156 |
)
|
157 |
|
158 |
-
image_blocks.queue(max_size=25).launch()
|
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
|
4 |
+
from diffusers import AutoPipelineForInpainting
|
5 |
import diffusers
|
6 |
from share_btn import community_icon_html, loading_icon_html, share_js
|
7 |
+
from sdxl import sdxl_diffusion_loop
|
8 |
from sdxl_models import SDXLUNet, SDXLVae, SDXLControlNetPreEncodedControlnetCond
|
9 |
+
import torchvision.transforms.functional as TF
|
10 |
+
from diffusion import make_sigmas
|
11 |
+
from huggingface_hub import hf_hub_download
|
12 |
|
13 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
14 |
pipe = AutoPipelineForInpainting.from_pretrained("diffusers/stable-diffusion-xl-1.0-inpainting-0.1", torch_dtype=torch.float16, variant="fp16").to(device)
|
15 |
|
16 |
+
comparing_unet = SDXLUNet.load(hf_hub_download("stabilityai/stable-diffusion-xl-base-1.0", "unet/diffusion_pytorch_model.fp16.safetensors"), device=device)
|
17 |
+
comparing_vae = SDXLVae.load(hf_hub_download("madebyollin/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors"), device=device)
|
|
|
|
|
|
|
18 |
comparing_vae.to(torch.float16)
|
19 |
+
comparing_controlnet = SDXLControlNetPreEncodedControlnetCond.load(hf_hub_download("williamberman/sdxl_controlnet_inpainting", "sdxl_controlnet_inpaint_pre_encoded_controlnet_cond_checkpoint_200000.safetensors"), device=device)
|
|
|
20 |
comparing_controlnet.to(torch.float16)
|
21 |
|
22 |
def read_content(file_path: str) -> str:
|
|
|
44 |
init_image = dict["image"].convert("RGB").resize((1024, 1024))
|
45 |
mask = dict["mask"].convert("RGB").resize((1024, 1024))
|
46 |
|
47 |
+
output = pipe(prompt = prompt, negative_prompt=negative_prompt, image=init_image, mask_image=mask, guidance_scale=guidance_scale, num_inference_steps=int(steps), strength=strength)
|
48 |
+
|
49 |
+
image = TF.to_tensor(dict["image"].convert("RGB").resize((1024, 1024)))
|
50 |
+
mask = TF.to_tensor(dict["mask"].convert("L").resize((1024, 1024)))
|
51 |
+
image = image * (mask < 0.5)
|
52 |
+
image = TF.normalize(image, [0.5], [0.5])
|
53 |
+
image = comparing_vae.encode(image[None, :, :, :].to(dtype=comparing_vae.dtype, device=comparing_vae.device)).to(dtype=comparing_controlnet.dtype, device=comparing_controlnet.device)
|
54 |
+
mask = TF.resize(mask, (1024 // 8, 1024 // 8))[None, :, :, :].to(dtype=image.dtype, device=image.device)
|
55 |
+
image = torch.concat((image, mask), dim=1)
|
56 |
+
|
57 |
+
sigmas = make_sigmas(device=comparing_unet.device).to(dtype=comparing_unet.dtype)
|
58 |
+
timesteps = torch.linspace(0, sigmas.numel() - 1, int(steps), dtype=torch.long, device=comparing_unet.device)
|
59 |
+
|
60 |
+
out = sdxl_diffusion_loop(
|
61 |
+
prompts=prompt, negative_prompts=negative_prompt, images=image, guidance_scale=guidance_scale, sigmas=sigmas, timesteps=timesteps,
|
62 |
+
text_encoder_one=pipe.text_encoder, text_encoder_two=pipe.text_encoder_2, unet=comparing_unet, controlnet=comparing_controlnet
|
63 |
)
|
64 |
+
out = comparing_vae.output_tensor_to_pil(comparing_vae.decode(out))
|
65 |
|
66 |
+
return output.images[0], out[0], gr.update(visible=True)
|
|
|
|
|
67 |
|
68 |
|
69 |
css = '''
|
|
|
117 |
with gr.Accordion(label="Advanced Settings", open=False):
|
118 |
with gr.Row(mobile_collapse=False, equal_height=True):
|
119 |
guidance_scale = gr.Number(value=7.5, minimum=1.0, maximum=20.0, step=0.1, label="guidance_scale")
|
120 |
+
steps = gr.Number(value=20, minimum=1, maximum=1000, step=1, label="steps")
|
121 |
strength = gr.Number(value=0.99, minimum=0.01, maximum=1.0, step=0.01, label="strength")
|
122 |
negative_prompt = gr.Textbox(label="negative_prompt", placeholder="Your negative prompt", info="what you don't want to see in the image")
|
123 |
with gr.Row(mobile_collapse=False, equal_height=True):
|
|
|
133 |
share_button = gr.Button("Share to community", elem_id="share-btn",visible=True)
|
134 |
|
135 |
|
136 |
+
btn.click(fn=predict, inputs=[image, prompt, negative_prompt, guidance_scale, steps, strength, scheduler], outputs=[image_out, image_out_comparing, share_btn_container], api_name='run')
|
137 |
+
prompt.submit(fn=predict, inputs=[image, prompt, negative_prompt, guidance_scale, steps, strength, scheduler], outputs=[image_out, image_out_comparing, share_btn_container])
|
|
|
|
|
138 |
share_button.click(None, [], [], _js=share_js)
|
139 |
|
140 |
gr.Examples(
|
|
|
163 |
"""
|
164 |
)
|
165 |
|
166 |
+
image_blocks.queue(max_size=25).launch(share=True)
|
diffusion.py
CHANGED
@@ -21,15 +21,14 @@ def rk_ode_solver_diffusion_loop(eps_theta, timesteps, sigmas, x_T, rk_steps_wei
|
|
21 |
x_t = x_T
|
22 |
|
23 |
for i in range(len(timesteps) - 1, -1, -1):
|
24 |
-
t = timesteps[i]
|
25 |
-
|
26 |
-
sigma = sigmas[i]
|
27 |
|
28 |
if i == 0:
|
29 |
eps_hat = eps_theta(x_t=x_t, t=t, sigma=sigma)
|
30 |
x_0_hat = x_t - sigma * eps_hat
|
31 |
else:
|
32 |
-
dt = sigmas[i - 1] - sigma
|
33 |
|
34 |
dx_by_dt = torch.zeros_like(x_t)
|
35 |
dx_by_dt_cur = torch.zeros_like(x_t)
|
@@ -41,7 +40,8 @@ def rk_ode_solver_diffusion_loop(eps_theta, timesteps, sigmas, x_T, rk_steps_wei
|
|
41 |
eps_hat = eps_theta(x_t=x_t_, t=t_, sigma=sigma)
|
42 |
# TODO - note which specific ode this is the solution to and
|
43 |
# how input scaling does/doesn't effect the solution
|
44 |
-
dx_by_dt_cur = (x_t_ - sigma * eps_hat) / sigma
|
|
|
45 |
dx_by_dt += dx_by_dt_cur * rk_weight
|
46 |
|
47 |
x_t_minus_1 = x_t + dx_by_dt * dt
|
|
|
21 |
x_t = x_T
|
22 |
|
23 |
for i in range(len(timesteps) - 1, -1, -1):
|
24 |
+
t = timesteps[i].unsqueeze(0)
|
25 |
+
sigma = sigmas[t]
|
|
|
26 |
|
27 |
if i == 0:
|
28 |
eps_hat = eps_theta(x_t=x_t, t=t, sigma=sigma)
|
29 |
x_0_hat = x_t - sigma * eps_hat
|
30 |
else:
|
31 |
+
dt = sigmas[timesteps[i - 1]] - sigma
|
32 |
|
33 |
dx_by_dt = torch.zeros_like(x_t)
|
34 |
dx_by_dt_cur = torch.zeros_like(x_t)
|
|
|
40 |
eps_hat = eps_theta(x_t=x_t_, t=t_, sigma=sigma)
|
41 |
# TODO - note which specific ode this is the solution to and
|
42 |
# how input scaling does/doesn't effect the solution
|
43 |
+
# dx_by_dt_cur = (x_t_ - sigma * eps_hat) / sigma
|
44 |
+
dx_by_dt_cur = eps_hat
|
45 |
dx_by_dt += dx_by_dt_cur * rk_weight
|
46 |
|
47 |
x_t_minus_1 = x_t + dx_by_dt * dt
|
sdxl.py
CHANGED
@@ -667,7 +667,7 @@ def apply_padding(mask, coord):
|
|
667 |
|
668 |
@torch.no_grad()
|
669 |
def sdxl_diffusion_loop(
|
670 |
-
prompts: List[str],
|
671 |
unet,
|
672 |
text_encoder_one,
|
673 |
text_encoder_two,
|
@@ -683,12 +683,13 @@ def sdxl_diffusion_loop(
|
|
683 |
negative_prompts=None,
|
684 |
diffusion_loop=euler_ode_solver_diffusion_loop,
|
685 |
):
|
686 |
-
|
|
|
687 |
|
688 |
-
|
689 |
-
negative_prompts = [""] * batch_size
|
690 |
|
691 |
-
|
|
|
692 |
|
693 |
encoder_hidden_states, pooled_encoder_hidden_states = sdxl_text_conditioning(
|
694 |
text_encoder_one,
|
@@ -699,15 +700,26 @@ def sdxl_diffusion_loop(
|
|
699 |
encoder_hidden_states = encoder_hidden_states.to(unet.dtype)
|
700 |
pooled_encoder_hidden_states = pooled_encoder_hidden_states.to(unet.dtype)
|
701 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
702 |
if sigmas is None:
|
703 |
sigmas = make_sigmas(device=unet.device)
|
704 |
|
|
|
|
|
|
|
705 |
if x_T is None:
|
706 |
x_T = torch.randn((batch_size, 4, 1024 // 8, 1024 // 8), dtype=unet.dtype, device=unet.device, generator=generator)
|
707 |
-
x_T = x_T * ((sigmas
|
708 |
-
|
709 |
-
if timesteps is None:
|
710 |
-
timesteps = torch.linspace(0, sigmas.numel(), 50, dtype=torch.long, device=unet.device)
|
711 |
|
712 |
if micro_conditioning is None:
|
713 |
micro_conditioning = torch.tensor([[1024, 1024, 0, 0, 1024, 1024]], dtype=torch.long, device=unet.device)
|
@@ -723,13 +735,14 @@ def sdxl_diffusion_loop(
|
|
723 |
else:
|
724 |
controlnet_cond = None
|
725 |
|
726 |
-
eps_theta = lambda
|
727 |
-
|
728 |
-
|
729 |
-
sigma=sigma,
|
730 |
unet=unet,
|
731 |
encoder_hidden_states=encoder_hidden_states,
|
732 |
pooled_encoder_hidden_states=pooled_encoder_hidden_states,
|
|
|
|
|
733 |
micro_conditioning=micro_conditioning,
|
734 |
guidance_scale=guidance_scale,
|
735 |
controlnet=controlnet,
|
@@ -750,6 +763,8 @@ def sdxl_eps_theta(
|
|
750 |
unet,
|
751 |
encoder_hidden_states,
|
752 |
pooled_encoder_hidden_states,
|
|
|
|
|
753 |
micro_conditioning,
|
754 |
guidance_scale,
|
755 |
controlnet=None,
|
@@ -761,13 +776,18 @@ def sdxl_eps_theta(
|
|
761 |
|
762 |
if guidance_scale > 1.0:
|
763 |
scaled_x_t = torch.concat([scaled_x_t, scaled_x_t])
|
|
|
|
|
|
|
|
|
764 |
micro_conditioning = torch.concat([micro_conditioning, micro_conditioning])
|
|
|
765 |
if controlnet_cond is not None:
|
766 |
controlnet_cond = torch.concat([controlnet_cond, controlnet_cond])
|
767 |
|
768 |
if controlnet is not None:
|
769 |
controlnet_out = controlnet(
|
770 |
-
x_t=scaled_x_t,
|
771 |
t=t,
|
772 |
encoder_hidden_states=encoder_hidden_states.to(controlnet.dtype),
|
773 |
micro_conditioning=micro_conditioning.to(controlnet.dtype),
|
@@ -801,7 +821,7 @@ def sdxl_eps_theta(
|
|
801 |
)
|
802 |
|
803 |
if guidance_scale > 1.0:
|
804 |
-
|
805 |
|
806 |
eps_hat = eps_hat_uncond + guidance_scale * (eps_hat - eps_hat_uncond)
|
807 |
|
@@ -867,7 +887,7 @@ def gen_sdxl_simplified_interface(
|
|
867 |
|
868 |
sigmas = make_sigmas()
|
869 |
|
870 |
-
timesteps = torch.linspace(0, sigmas.numel(), num_inference_steps, dtype=torch.long, device=unet.device)
|
871 |
|
872 |
if images is not None:
|
873 |
if not isinstance(images, list):
|
|
|
667 |
|
668 |
@torch.no_grad()
|
669 |
def sdxl_diffusion_loop(
|
670 |
+
prompts: Union[str, List[str]],
|
671 |
unet,
|
672 |
text_encoder_one,
|
673 |
text_encoder_two,
|
|
|
683 |
negative_prompts=None,
|
684 |
diffusion_loop=euler_ode_solver_diffusion_loop,
|
685 |
):
|
686 |
+
if isinstance(prompts, str):
|
687 |
+
prompts = [prompts]
|
688 |
|
689 |
+
batch_size = len(prompts)
|
|
|
690 |
|
691 |
+
if negative_prompts is not None and guidance_scale > 1.0:
|
692 |
+
prompts += negative_prompts
|
693 |
|
694 |
encoder_hidden_states, pooled_encoder_hidden_states = sdxl_text_conditioning(
|
695 |
text_encoder_one,
|
|
|
700 |
encoder_hidden_states = encoder_hidden_states.to(unet.dtype)
|
701 |
pooled_encoder_hidden_states = pooled_encoder_hidden_states.to(unet.dtype)
|
702 |
|
703 |
+
if guidance_scale > 1.0:
|
704 |
+
if negative_prompts is None:
|
705 |
+
negative_encoder_hidden_states = torch.zeros_like(encoder_hidden_states)
|
706 |
+
negative_pooled_encoder_hidden_states = torch.zeros_like(pooled_encoder_hidden_states)
|
707 |
+
else:
|
708 |
+
encoder_hidden_states, negative_encoder_hidden_states = torch.chunk(encoder_hidden_states, 2)
|
709 |
+
pooled_encoder_hidden_states, negative_pooled_encoder_hidden_states = torch.chunk(pooled_encoder_hidden_states, 2)
|
710 |
+
else:
|
711 |
+
negative_encoder_hidden_states = None
|
712 |
+
negative_pooled_encoder_hidden_states = None
|
713 |
+
|
714 |
if sigmas is None:
|
715 |
sigmas = make_sigmas(device=unet.device)
|
716 |
|
717 |
+
if timesteps is None:
|
718 |
+
timesteps = torch.linspace(0, sigmas.numel() - 1, 50, dtype=torch.long, device=unet.device)
|
719 |
+
|
720 |
if x_T is None:
|
721 |
x_T = torch.randn((batch_size, 4, 1024 // 8, 1024 // 8), dtype=unet.dtype, device=unet.device, generator=generator)
|
722 |
+
x_T = x_T * ((sigmas[timesteps[-1]] ** 2 + 1) ** 0.5)
|
|
|
|
|
|
|
723 |
|
724 |
if micro_conditioning is None:
|
725 |
micro_conditioning = torch.tensor([[1024, 1024, 0, 0, 1024, 1024]], dtype=torch.long, device=unet.device)
|
|
|
735 |
else:
|
736 |
controlnet_cond = None
|
737 |
|
738 |
+
eps_theta = lambda *args, **kwargs: sdxl_eps_theta(
|
739 |
+
*args,
|
740 |
+
**kwargs,
|
|
|
741 |
unet=unet,
|
742 |
encoder_hidden_states=encoder_hidden_states,
|
743 |
pooled_encoder_hidden_states=pooled_encoder_hidden_states,
|
744 |
+
negative_encoder_hidden_states=negative_encoder_hidden_states,
|
745 |
+
negative_pooled_encoder_hidden_states=negative_pooled_encoder_hidden_states,
|
746 |
micro_conditioning=micro_conditioning,
|
747 |
guidance_scale=guidance_scale,
|
748 |
controlnet=controlnet,
|
|
|
763 |
unet,
|
764 |
encoder_hidden_states,
|
765 |
pooled_encoder_hidden_states,
|
766 |
+
negative_encoder_hidden_states,
|
767 |
+
negative_pooled_encoder_hidden_states,
|
768 |
micro_conditioning,
|
769 |
guidance_scale,
|
770 |
controlnet=None,
|
|
|
776 |
|
777 |
if guidance_scale > 1.0:
|
778 |
scaled_x_t = torch.concat([scaled_x_t, scaled_x_t])
|
779 |
+
|
780 |
+
encoder_hidden_states = torch.concat((encoder_hidden_states, negative_encoder_hidden_states))
|
781 |
+
pooled_encoder_hidden_states = torch.concat((pooled_encoder_hidden_states, negative_pooled_encoder_hidden_states))
|
782 |
+
|
783 |
micro_conditioning = torch.concat([micro_conditioning, micro_conditioning])
|
784 |
+
|
785 |
if controlnet_cond is not None:
|
786 |
controlnet_cond = torch.concat([controlnet_cond, controlnet_cond])
|
787 |
|
788 |
if controlnet is not None:
|
789 |
controlnet_out = controlnet(
|
790 |
+
x_t=scaled_x_t.to(controlnet.dtype),
|
791 |
t=t,
|
792 |
encoder_hidden_states=encoder_hidden_states.to(controlnet.dtype),
|
793 |
micro_conditioning=micro_conditioning.to(controlnet.dtype),
|
|
|
821 |
)
|
822 |
|
823 |
if guidance_scale > 1.0:
|
824 |
+
eps_hat, eps_hat_uncond = eps_hat.chunk(2)
|
825 |
|
826 |
eps_hat = eps_hat_uncond + guidance_scale * (eps_hat - eps_hat_uncond)
|
827 |
|
|
|
887 |
|
888 |
sigmas = make_sigmas()
|
889 |
|
890 |
+
timesteps = torch.linspace(0, sigmas.numel() - 1, num_inference_steps, dtype=torch.long, device=unet.device)
|
891 |
|
892 |
if images is not None:
|
893 |
if not isinstance(images, list):
|
sdxl_models.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
import math
|
2 |
import os
|
3 |
-
from typing import List, Optional
|
4 |
|
5 |
import safetensors.torch
|
6 |
import torch
|
@@ -1246,16 +1246,14 @@ class ResnetBlock2D(nn.Module):
|
|
1246 |
def forward(self, hidden_states, temb=None):
|
1247 |
residual = hidden_states
|
1248 |
|
1249 |
-
if self.time_emb_proj is not None:
|
1250 |
-
assert temb is not None
|
1251 |
-
temb = self.nonlinearity(temb)
|
1252 |
-
temb = self.time_emb_proj(temb)[:, :, None, None]
|
1253 |
-
|
1254 |
hidden_states = self.norm1(hidden_states)
|
1255 |
hidden_states = self.nonlinearity(hidden_states)
|
1256 |
hidden_states = self.conv1(hidden_states)
|
1257 |
|
1258 |
-
if
|
|
|
|
|
|
|
1259 |
hidden_states = hidden_states + temb
|
1260 |
|
1261 |
hidden_states = self.norm2(hidden_states)
|
@@ -1325,7 +1323,51 @@ class TransformerDecoderBlock(nn.Module):
|
|
1325 |
return hidden_states
|
1326 |
|
1327 |
|
1328 |
-
class
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1329 |
def __init__(self, channels, encoder_hidden_states_dim):
|
1330 |
super().__init__()
|
1331 |
self.to_q = nn.Linear(channels, channels, bias=False)
|
@@ -1334,10 +1376,10 @@ class Attention(nn.Module):
|
|
1334 |
self.to_out = nn.Sequential(nn.Linear(channels, channels), nn.Dropout(0.0))
|
1335 |
|
1336 |
def forward(self, hidden_states, encoder_hidden_states=None):
|
1337 |
-
return attention(self.to_q, self.to_k, self.to_v, self.to_out, 64, hidden_states, encoder_hidden_states)
|
1338 |
|
1339 |
|
1340 |
-
class VaeMidBlockAttention(nn.Module):
|
1341 |
def __init__(self, channels):
|
1342 |
super().__init__()
|
1343 |
self.group_norm = nn.GroupNorm(32, channels, eps=1e-06)
|
@@ -1355,7 +1397,7 @@ class VaeMidBlockAttention(nn.Module):
|
|
1355 |
|
1356 |
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
1357 |
|
1358 |
-
hidden_states = attention(self.to_q, self.to_k, self.to_v, self.to_out, self.head_dim, hidden_states)
|
1359 |
|
1360 |
hidden_states = hidden_states.transpose(1, 2).view(batch_size, channels, height, width)
|
1361 |
|
@@ -1364,34 +1406,6 @@ class VaeMidBlockAttention(nn.Module):
|
|
1364 |
return hidden_states
|
1365 |
|
1366 |
|
1367 |
-
def attention(to_q, to_k, to_v, to_out, head_dim, hidden_states, encoder_hidden_states=None):
|
1368 |
-
batch_size, q_seq_len, channels = hidden_states.shape
|
1369 |
-
|
1370 |
-
if encoder_hidden_states is not None:
|
1371 |
-
kv = encoder_hidden_states
|
1372 |
-
else:
|
1373 |
-
kv = hidden_states
|
1374 |
-
|
1375 |
-
kv_seq_len = kv.shape[1]
|
1376 |
-
|
1377 |
-
query = to_q(hidden_states)
|
1378 |
-
key = to_k(kv)
|
1379 |
-
value = to_v(kv)
|
1380 |
-
|
1381 |
-
query = query.reshape(batch_size, q_seq_len, channels // head_dim, head_dim).contiguous()
|
1382 |
-
key = key.reshape(batch_size, kv_seq_len, channels // head_dim, head_dim).contiguous()
|
1383 |
-
value = value.reshape(batch_size, kv_seq_len, channels // head_dim, head_dim).contiguous()
|
1384 |
-
|
1385 |
-
hidden_states = xformers.ops.memory_efficient_attention(query, key, value)
|
1386 |
-
|
1387 |
-
hidden_states = hidden_states.to(query.dtype)
|
1388 |
-
hidden_states = hidden_states.reshape(batch_size, q_seq_len, channels).contiguous()
|
1389 |
-
|
1390 |
-
hidden_states = to_out(hidden_states)
|
1391 |
-
|
1392 |
-
return hidden_states
|
1393 |
-
|
1394 |
-
|
1395 |
class GEGLU(nn.Module):
|
1396 |
def __init__(self, dim_in: int, dim_out: int):
|
1397 |
super().__init__()
|
|
|
1 |
import math
|
2 |
import os
|
3 |
+
from typing import List, Literal, Optional
|
4 |
|
5 |
import safetensors.torch
|
6 |
import torch
|
|
|
1246 |
def forward(self, hidden_states, temb=None):
|
1247 |
residual = hidden_states
|
1248 |
|
|
|
|
|
|
|
|
|
|
|
1249 |
hidden_states = self.norm1(hidden_states)
|
1250 |
hidden_states = self.nonlinearity(hidden_states)
|
1251 |
hidden_states = self.conv1(hidden_states)
|
1252 |
|
1253 |
+
if self.time_emb_proj is not None:
|
1254 |
+
assert temb is not None
|
1255 |
+
temb = self.nonlinearity(temb)
|
1256 |
+
temb = self.time_emb_proj(temb)[:, :, None, None]
|
1257 |
hidden_states = hidden_states + temb
|
1258 |
|
1259 |
hidden_states = self.norm2(hidden_states)
|
|
|
1323 |
return hidden_states
|
1324 |
|
1325 |
|
1326 |
+
class AttentionMixin:
|
1327 |
+
attention_implementation: Literal["xformers", "torch_2.0_scaled_dot_product"] = "xformers"
|
1328 |
+
|
1329 |
+
@classmethod
|
1330 |
+
def attention(cls, to_q, to_k, to_v, to_out, head_dim, hidden_states, encoder_hidden_states=None):
|
1331 |
+
batch_size, q_seq_len, channels = hidden_states.shape
|
1332 |
+
|
1333 |
+
if encoder_hidden_states is not None:
|
1334 |
+
kv = encoder_hidden_states
|
1335 |
+
else:
|
1336 |
+
kv = hidden_states
|
1337 |
+
|
1338 |
+
kv_seq_len = kv.shape[1]
|
1339 |
+
|
1340 |
+
query = to_q(hidden_states)
|
1341 |
+
key = to_k(kv)
|
1342 |
+
value = to_v(kv)
|
1343 |
+
|
1344 |
+
if AttentionMixin.attention_implementation == "xformers":
|
1345 |
+
query = query.reshape(batch_size, q_seq_len, channels // head_dim, head_dim).contiguous()
|
1346 |
+
key = key.reshape(batch_size, kv_seq_len, channels // head_dim, head_dim).contiguous()
|
1347 |
+
value = value.reshape(batch_size, kv_seq_len, channels // head_dim, head_dim).contiguous()
|
1348 |
+
|
1349 |
+
hidden_states = xformers.ops.memory_efficient_attention(query, key, value)
|
1350 |
+
|
1351 |
+
hidden_states = hidden_states.to(query.dtype)
|
1352 |
+
hidden_states = hidden_states.reshape(batch_size, q_seq_len, channels).contiguous()
|
1353 |
+
elif AttentionMixin.attention_implementation == "torch_2.0_scaled_dot_product":
|
1354 |
+
query = query.reshape(batch_size, q_seq_len, channels // head_dim, head_dim).transpose(1, 2).contiguous()
|
1355 |
+
key = key.reshape(batch_size, kv_seq_len, channels // head_dim, head_dim).transpose(1, 2).contiguous()
|
1356 |
+
value = value.reshape(batch_size, kv_seq_len, channels // head_dim, head_dim).transpose(1, 2).contiguous()
|
1357 |
+
|
1358 |
+
hidden_states = F.scaled_dot_product_attention(query, key, value)
|
1359 |
+
|
1360 |
+
hidden_states = hidden_states.to(query.dtype)
|
1361 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, q_seq_len, channels).contiguous()
|
1362 |
+
else:
|
1363 |
+
assert False
|
1364 |
+
|
1365 |
+
hidden_states = to_out(hidden_states)
|
1366 |
+
|
1367 |
+
return hidden_states
|
1368 |
+
|
1369 |
+
|
1370 |
+
class Attention(nn.Module, AttentionMixin):
|
1371 |
def __init__(self, channels, encoder_hidden_states_dim):
|
1372 |
super().__init__()
|
1373 |
self.to_q = nn.Linear(channels, channels, bias=False)
|
|
|
1376 |
self.to_out = nn.Sequential(nn.Linear(channels, channels), nn.Dropout(0.0))
|
1377 |
|
1378 |
def forward(self, hidden_states, encoder_hidden_states=None):
|
1379 |
+
return self.attention(self.to_q, self.to_k, self.to_v, self.to_out, 64, hidden_states, encoder_hidden_states)
|
1380 |
|
1381 |
|
1382 |
+
class VaeMidBlockAttention(nn.Module, AttentionMixin):
|
1383 |
def __init__(self, channels):
|
1384 |
super().__init__()
|
1385 |
self.group_norm = nn.GroupNorm(32, channels, eps=1e-06)
|
|
|
1397 |
|
1398 |
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
1399 |
|
1400 |
+
hidden_states = self.attention(self.to_q, self.to_k, self.to_v, self.to_out, self.head_dim, hidden_states)
|
1401 |
|
1402 |
hidden_states = hidden_states.transpose(1, 2).view(batch_size, channels, height, width)
|
1403 |
|
|
|
1406 |
return hidden_states
|
1407 |
|
1408 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1409 |
class GEGLU(nn.Module):
|
1410 |
def __init__(self, dim_in: int, dim_out: int):
|
1411 |
super().__init__()
|