williamberman committed · Commit e4ea387 · Parent(s): 9abdf02

fixes

Files changed:
- app.py +20 -11
- sdxl.py +70 -37
- sdxl_models.py +72 -29
app.py
CHANGED
@@ -1,7 +1,7 @@
import gradio as gr
import torch

-from diffusers import AutoPipelineForInpainting
+from diffusers import AutoPipelineForInpainting, StableDiffusionXLPipeline
import diffusers
from share_btn import community_icon_html, loading_icon_html, share_js
from sdxl import gen_sdxl_simplified_interface
@@ -10,9 +10,14 @@ from sdxl_models import SDXLUNet, SDXLVae, SDXLControlNetPreEncodedControlnetCon
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = AutoPipelineForInpainting.from_pretrained("diffusers/stable-diffusion-xl-1.0-inpainting-0.1", torch_dtype=torch.float16, variant="fp16").to(device)

-
-
-
+# TODO - just download individual files
+# StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", variant="fp16") # download weights
+comparing_unet = SDXLUNet.load("/admin/home/william/.cache/huggingface/hub/models--stabilityai--stable-diffusion-xl-base-1.0/snapshots/76d28af79639c28a79fa5c6c6468febd3490a37e/unet/diffusion_pytorch_model.fp16.safetensors", device=device)
+# comparing_vae = SDXLVae.load("/admin/home/william/.cache/huggingface/hub/models--stabilityai--stable-diffusion-xl-base-1.0/snapshots/76d28af79639c28a79fa5c6c6468febd3490a37e/vae/diffusion_pytorch_model.fp16.safetensors", device=device)
+comparing_vae = SDXLVae.load("/admin/home/william/.cache/huggingface/hub/models--madebyollin--sdxl-vae-fp16-fix/snapshots/4df413ca49271c25289a6482ab97a433f8117d15/diffusion_pytorch_model.safetensors", device=device)
+comparing_vae.to(torch.float16)
+# comparing_controlnet = SDXLControlNetPreEncodedControlnetCond.load("/fsx/william/diffusers-utils/output/sdxl_controlnet_inpaint_pre_encoded_controlnet_cond/checkpoint-200000/controlnet/diffusion_pytorch_model.safetensors", device="cuda") # TODO - upload checkpoint
+comparing_controlnet = SDXLControlNetPreEncodedControlnetCond.load("./controlnet_vae.safetensors", device="cuda") # TODO - upload checkpoint
comparing_controlnet.to(torch.float16)

def read_content(file_path: str) -> str:
@@ -40,13 +45,15 @@ def predict(dict, prompt="", negative_prompt="", guidance_scale=7.5, steps=20, s
init_image = dict["image"].convert("RGB").resize((1024, 1024))
mask = dict["mask"].convert("RGB").resize((1024, 1024))

-output = pipe(prompt = prompt, negative_prompt=negative_prompt, image=init_image, mask_image=mask, guidance_scale=guidance_scale, num_inference_steps=int(steps), strength=strength)
+# output = pipe(prompt = prompt, negative_prompt=negative_prompt, image=init_image, mask_image=mask, guidance_scale=guidance_scale, num_inference_steps=int(steps), strength=strength)
output_controlnet_vae_encoding = gen_sdxl_simplified_interface(
-
+prompts=prompt, negative_prompts=negative_prompt, images=init_image, masks=mask, guidance_scale=guidance_scale, num_inference_steps=int(steps),
text_encoder_one=pipe.text_encoder, text_encoder_two=pipe.text_encoder_2, unet=comparing_unet, vae=comparing_vae, controlnet=comparing_controlnet, device=device
)

-return output.images[0], output_controlnet_vae_encoding[0], gr.update(visible=True)
+# return output.images[0], output_controlnet_vae_encoding[0], gr.update(visible=True)
+
+return output_controlnet_vae_encoding[0], gr.update(visible=True)


css = '''
@@ -108,16 +115,18 @@ with image_blocks as demo:
scheduler = gr.Dropdown(label="Schedulers", choices=schedulers, value="EulerDiscreteScheduler")

with gr.Column():
-image_out = gr.Image(label="Output", elem_id="output-img", height=400)
-image_out_comparing = gr.Image(label="Output", elem_id="output-img-comparing", height=400)
+image_out = gr.Image(label="Output diffusers full finetune 0.1", elem_id="output-img", height=400)
+image_out_comparing = gr.Image(label="Output controlnet + vae", elem_id="output-img-comparing", height=400)
with gr.Group(elem_id="share-btn-container", visible=False) as share_btn_container:
community_icon = gr.HTML(community_icon_html)
loading_icon = gr.HTML(loading_icon_html)
share_button = gr.Button("Share to community", elem_id="share-btn",visible=True)


-btn.click(fn=predict, inputs=[image, prompt, negative_prompt, guidance_scale, steps, strength, scheduler], outputs=[image_out, image_out_comparing, share_btn_container], api_name='run')
-prompt.submit(fn=predict, inputs=[image, prompt, negative_prompt, guidance_scale, steps, strength, scheduler], outputs=[image_out, image_out_comparing, share_btn_container])
+# btn.click(fn=predict, inputs=[image, prompt, negative_prompt, guidance_scale, steps, strength, scheduler], outputs=[image_out, image_out_comparing, share_btn_container], api_name='run')
+# prompt.submit(fn=predict, inputs=[image, prompt, negative_prompt, guidance_scale, steps, strength, scheduler], outputs=[image_out, image_out_comparing, share_btn_container])
+btn.click(fn=predict, inputs=[image, prompt, negative_prompt, guidance_scale, steps, strength, scheduler], outputs=[image_out_comparing, share_btn_container], api_name='run')
+prompt.submit(fn=predict, inputs=[image, prompt, negative_prompt, guidance_scale, steps, strength, scheduler], outputs=[image_out_comparing, share_btn_container])
share_button.click(None, [], [], _js=share_js)

gr.Examples(
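
Note on the hard-coded snapshot paths in app.py above: the "TODO - just download individual files" could be addressed with huggingface_hub.hf_hub_download, which fetches a single file from a repo and returns its local cache path. A minimal sketch, not part of this commit; SDXLUNet, SDXLVae, and device are the names app.py already defines.

from huggingface_hub import hf_hub_download

# Download only the two files the comparison path needs instead of a full snapshot.
unet_path = hf_hub_download("stabilityai/stable-diffusion-xl-base-1.0", "unet/diffusion_pytorch_model.fp16.safetensors")
vae_path = hf_hub_download("madebyollin/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors")

comparing_unet = SDXLUNet.load(unet_path, device=device)
comparing_vae = SDXLVae.load(vae_path, device=device)
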
sdxl.py
CHANGED
@@ -388,9 +388,9 @@ def make_sample(d, proportion_empty_prompts, get_sdxl_conditioning_images=None):

micro_conditioning = torch.tensor([original_width, original_height, c_top, c_left, 1024, 1024])

-text_input_ids_one = sdxl_tokenize_one(text)
+text_input_ids_one = sdxl_tokenize_one(text)[0]

-text_input_ids_two = sdxl_tokenize_two(text)
+text_input_ids_two = sdxl_tokenize_two(text)[0]

image = image.convert("RGB")

@@ -517,7 +517,7 @@ def sdxl_tokenize_one(prompts):
max_length=tokenizer_one.model_max_length,
truncation=True,
return_tensors="pt",
-).input_ids
+).input_ids


def sdxl_tokenize_two(prompts):
@@ -527,7 +527,7 @@ def sdxl_tokenize_two(prompts):
max_length=tokenizer_one.model_max_length,
truncation=True,
return_tensors="pt",
-).input_ids
+).input_ids


def sdxl_text_conditioning(text_encoder_one, text_encoder_two, text_input_ids_one, text_input_ids_two):
@@ -667,7 +667,7 @@ def apply_padding(mask, coord):

@torch.no_grad()
def sdxl_diffusion_loop(
-prompts,
+prompts: List[str],
unet,
text_encoder_one,
text_encoder_two,
@@ -683,8 +683,10 @@
negative_prompts=None,
diffusion_loop=euler_ode_solver_diffusion_loop,
):
+batch_size = len(prompts)
+
if negative_prompts is None:
-negative_prompts = [""] *
+negative_prompts = [""] * batch_size

prompts += negative_prompts

@@ -694,27 +696,30 @@
sdxl_tokenize_one(prompts).to(text_encoder_one.device),
sdxl_tokenize_two(prompts).to(text_encoder_two.device),
)
-
-
-x_T = torch.randn((1, 4, 1024 // 8, 1024 // 8), dtype=torch.float32, device=unet.device, generator=generator)
-x_T = x_T * ((sigmas.max() ** 2 + 1) ** 0.5)
+encoder_hidden_states = encoder_hidden_states.to(unet.dtype)
+pooled_encoder_hidden_states = pooled_encoder_hidden_states.to(unet.dtype)

if sigmas is None:
sigmas = make_sigmas(device=unet.device)

+if x_T is None:
+x_T = torch.randn((batch_size, 4, 1024 // 8, 1024 // 8), dtype=unet.dtype, device=unet.device, generator=generator)
+x_T = x_T * ((sigmas.max() ** 2 + 1) ** 0.5)
+
if timesteps is None:
timesteps = torch.linspace(0, sigmas.numel(), 50, dtype=torch.long, device=unet.device)

if micro_conditioning is None:
-micro_conditioning = torch.tensor([1024, 1024, 0, 0, 1024, 1024], dtype=torch.long, device=unet.device)
+micro_conditioning = torch.tensor([[1024, 1024, 0, 0, 1024, 1024]], dtype=torch.long, device=unet.device)
+micro_conditioning = micro_conditioning.expand(batch_size, -1)

if adapter is not None:
-down_block_additional_residuals = adapter(images)
+down_block_additional_residuals = adapter(images.to(dtype=adapter.dtype, device=adapter.device))
else:
down_block_additional_residuals = None

if controlnet is not None:
-controlnet_cond = images
+controlnet_cond = images.to(dtype=controlnet.dtype, device=controlnet.device)
else:
controlnet_cond = None

@@ -756,21 +761,28 @@ def sdxl_eps_theta(

if guidance_scale > 1.0:
scaled_x_t = torch.concat([scaled_x_t, scaled_x_t])
+micro_conditioning = torch.concat([micro_conditioning, micro_conditioning])
+if controlnet_cond is not None:
+controlnet_cond = torch.concat([controlnet_cond, controlnet_cond])

if controlnet is not None:
controlnet_out = controlnet(
x_t=scaled_x_t,
t=t,
-encoder_hidden_states=encoder_hidden_states,
-micro_conditioning=micro_conditioning,
-pooled_encoder_hidden_states=pooled_encoder_hidden_states,
+encoder_hidden_states=encoder_hidden_states.to(controlnet.dtype),
+micro_conditioning=micro_conditioning.to(controlnet.dtype),
+pooled_encoder_hidden_states=pooled_encoder_hidden_states.to(controlnet.dtype),
controlnet_cond=controlnet_cond,
)

-down_block_additional_residuals = controlnet_out["down_block_res_samples"]
-mid_block_additional_residual = controlnet_out["mid_block_res_sample"]
+down_block_additional_residuals = [x.to(unet.dtype) for x in controlnet_out["down_block_res_samples"]]
+mid_block_additional_residual = controlnet_out["mid_block_res_sample"].to(unet.dtype)
add_to_down_block_inputs = controlnet_out.get("add_to_down_block_inputs", None)
+if add_to_down_block_inputs is not None:
+add_to_down_block_inputs = [x.to(unet.dtype) for x in add_to_down_block_inputs]
add_to_output = controlnet_out.get("add_to_output", None)
+if add_to_output is not None:
+add_to_output = add_to_output.to(unet.dtype)
else:
mid_block_additional_residual = None
add_to_down_block_inputs = None
@@ -795,20 +807,24 @@

return eps_hat

+
known_negative_prompt = "text, watermark, low-quality, signature, moiré pattern, downsampling, aliasing, distorted, blurry, glossy, blur, jpeg artifacts, compression artifacts, poorly drawn, low-resolution, bad, distortion, twisted, excessive, exaggerated pose, exaggerated limbs, grainy, symmetrical, duplicate, error, pattern, beginner, pixelated, fake, hyper, glitch, overexposed, high-contrast, bad-contrast"

+
+# TODO probably just combine with sdxl_diffusion_loop
def gen_sdxl_simplified_interface(
-
-
-controlnet_checkpoint: Optional[str]=None,
-controlnet: Optional[Literal["SDXLControlNet", "SDXLContolNetFull", "SDXLControlNetPreEncodedControlnetCond"]]=None,
-adapter_checkpoint: Optional[str]=None,
+prompts: Union[str, List[str]],
+negative_prompts: Optional[Union[str, List[str]]] = None,
+controlnet_checkpoint: Optional[str] = None,
+controlnet: Optional[Literal["SDXLControlNet", "SDXLContolNetFull", "SDXLControlNetPreEncodedControlnetCond"]] = None,
+adapter_checkpoint: Optional[str] = None,
num_inference_steps=50,
images=None,
masks=None,
-apply_conditioning: Optional[Literal["canny"]]=None,
-num_images: int=1,
-
+apply_conditioning: Optional[Literal["canny"]] = None,
+num_images: int = 1,
+guidance_scale=5.0,
+device: Optional[str] = None,
text_encoder_one=None,
text_encoder_two=None,
unet=None,
@@ -886,22 +902,23 @@
mask = masks[image_idx]
if isinstance(mask, str):
mask = Image.open(mask)
-mask = mask.convert("L")
-mask = mask.resize((1024, 1024))
elif isinstance(mask, Image.Image):
...
else:
assert False
+mask = mask.convert("L")
+mask = mask.resize((1024, 1024))
mask = TF.to_tensor(mask)

-if controlnet
+if isinstance(controlnet, SDXLControlNetPreEncodedControlnetCond):
image = image * (mask < 0.5)
-image = TF.
-image = vae.encode(image)
-mask = TF.resize(mask, (1024 // 8, 1024 // 8))
-image = torch.concat((image, mask))
+image = TF.normalize(image, [0.5], [0.5])
+image = vae.encode(image[None, :, :, :].to(dtype=vae.dtype, device=vae.device)).to(dtype=unet.dtype, device=unet.device)
+mask = TF.resize(mask, (1024 // 8, 1024 // 8))[None, :, :, :].to(dtype=image.dtype, device=image.device)
+image = torch.concat((image, mask), dim=1)
else:
-image = image * (mask < 0.5) + -1.0 * (mask >= 0.5)
+image = (image * (mask < 0.5) + -1.0 * (mask >= 0.5)).to(dtype=unet.dtype, device=unet.device)
+image = image[None, :, :, :]

images_.append(image)

@@ -909,9 +926,24 @@
else:
images_ = None

+if isinstance(prompts, str):
+prompts = [prompts]
+prompts_ = []
+for prompt in prompts:
+prompts_ += [prompt] * num_images
+
+if negative_prompts is not None:
+if isinstance(negative_prompts, str):
+negative_prompts = [negative_prompts]
+negative_prompts_ = []
+for negative_prompt in negative_prompts:
+negative_prompts_ += [negative_prompt] * num_images
+else:
+negative_prompts_ = None
+
x_0 = sdxl_diffusion_loop(
-prompts=
-negative_prompts=
+prompts=prompts_,
+negative_prompts=negative_prompts_,
unet=unet,
text_encoder_one=text_encoder_one,
text_encoder_two=text_encoder_two,
@@ -920,9 +952,10 @@
controlnet=controlnet,
adapter=adapter,
images=images_,
+guidance_scale=guidance_scale,
)

-x_0 = vae.decode(x_0)
+x_0 = vae.decode(x_0.to(vae.dtype))
x_0 = vae.output_tensor_to_pil(x_0)

return x_0
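
The batching changes in sdxl_diffusion_loop and sdxl_eps_theta above come down to two things: micro-conditioning is now a (batch_size, 6) tensor expanded per sample, and every batch-dependent tensor is doubled when classifier-free guidance is enabled. A standalone sketch of that shape bookkeeping (plain torch; names mirror the diff, not part of the commit):

import torch

batch_size = 2  # e.g. len(prompts), after negative_prompts defaults to [""] * batch_size

# micro-conditioning is built as a single row and expanded over the batch
micro_conditioning = torch.tensor([[1024, 1024, 0, 0, 1024, 1024]], dtype=torch.long)
micro_conditioning = micro_conditioning.expand(batch_size, -1)  # (batch_size, 6)

# with classifier-free guidance the conditional and unconditional halves are stacked,
# so batch-dependent tensors are doubled the same way scaled_x_t is in sdxl_eps_theta
guidance_scale = 5.0
if guidance_scale > 1.0:
    micro_conditioning = torch.concat([micro_conditioning, micro_conditioning])

print(micro_conditioning.shape)  # torch.Size([4, 6])
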
sdxl_models.py
CHANGED
@@ -26,7 +26,8 @@ class ModelUtils:

load_from = [load_from]

-
+if overrides is not None:
+load_from += overrides

state_dict = {}

@@ -79,7 +80,7 @@ class SDXLVae(nn.Module, ModelUtils):

# 512 -> 512
mid_block=nn.ModuleDict(dict(
-attentions=nn.ModuleList([
+attentions=nn.ModuleList([VaeMidBlockAttention(512)]),
resnets=nn.ModuleList([ResnetBlock2D(512, 512, eps=1e-6), ResnetBlock2D(512, 512, eps=1e-6)]),
)),

@@ -95,7 +96,7 @@ class SDXLVae(nn.Module, ModelUtils):
# 8 -> 4 from sampling mean and std

# 4 -> 4
-self.post_quant_conv = nn.Conv2d(4, 4, 1)
+self.post_quant_conv = nn.Conv2d(4, 4, kernel_size=1)

self.decoder = nn.ModuleDict(dict(
# 4 -> 512
@@ -103,7 +104,7 @@ class SDXLVae(nn.Module, ModelUtils):

# 512 -> 512
mid_block=nn.ModuleDict(dict(
-attentions=nn.ModuleList([
+attentions=nn.ModuleList([VaeMidBlockAttention(512)]),
resnets=nn.ModuleList([ResnetBlock2D(512, 512, eps=1e-6), ResnetBlock2D(512, 512, eps=1e-6)]),
)),

@@ -179,15 +180,18 @@ class SDXLVae(nn.Module, ModelUtils):

h = self.post_quant_conv(h)

+h = self.decoder["conv_in"](h)
+
h = self.decoder["mid_block"]["resnets"][0](h)
h = self.decoder["mid_block"]["attentions"][0](h)
h = self.decoder["mid_block"]["resnets"][1](h)

-for up_block in self.
+for up_block in self.decoder["up_blocks"]:
for resnet in up_block["resnets"]:
h = resnet(h)

if "upsamplers" in up_block:
+h = F.interpolate(h, scale_factor=2.0, mode="nearest")
h = up_block["upsamplers"][0]["conv"](h)

h = self.decoder["conv_norm_out"](h)
@@ -208,9 +212,7 @@ class SDXLVae(nn.Module, ModelUtils):

@classmethod
def output_tensor_to_pil(self, x_pred):
-x_pred = ((x_pred * 0.5 + 0.5).clamp(0, 1) * 255).to(torch.uint8).permute(0, 2, 3, 1)
-
-x_pred = x_pred.permute(0, 2, 3, 1).cpu().numpy()
+x_pred = ((x_pred * 0.5 + 0.5).clamp(0, 1) * 255).to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy()

x_pred = [Image.fromarray(x) for x in x_pred]

@@ -1323,42 +1325,83 @@ class TransformerDecoderBlock(nn.Module):


class Attention(nn.Module):
-def __init__(self, channels, encoder_hidden_states_dim
+def __init__(self, channels, encoder_hidden_states_dim):
super().__init__()
-self.to_q = nn.Linear(channels, channels, bias=
-self.to_k = nn.Linear(encoder_hidden_states_dim, channels, bias=
-self.to_v = nn.Linear(encoder_hidden_states_dim, channels, bias=
+self.to_q = nn.Linear(channels, channels, bias=False)
+self.to_k = nn.Linear(encoder_hidden_states_dim, channels, bias=False)
+self.to_v = nn.Linear(encoder_hidden_states_dim, channels, bias=False)
self.to_out = nn.Sequential(nn.Linear(channels, channels), nn.Dropout(0.0))

def forward(self, hidden_states, encoder_hidden_states=None):
-
-head_dim = 64
-if
-
-
-
-
-
-
-value = self.to_v(kv)
-
-
-
-hidden_states =
-hidden_states =
-hidden_states = hidden_states.reshape(batch_size, q_seq_len, channels).contiguous()
-
+input_ndim = hidden_states.ndim
+
+if input_ndim == 4:
+batch_size, channels, height, width = hidden_states.shape
+hidden_states = hidden_states.view(batch_size, channels, height * width).transpose(1, 2)
+
+hidden_states = attention(self.to_q, self.to_k, self.to_v, self.to_out, hidden_states, encoder_hidden_states)
+
+if input_ndim == 4:
+hidden_states = hidden_states.transpose(1, 2).view(batch_size, channels, height, width)
+
+return hidden_states
+
+
+class VaeMidBlockAttention(nn.Module):
+def __init__(self, channels):
+super().__init__()
+self.group_norm = nn.GroupNorm(32, channels, eps=1e-06)
+self.to_q = nn.Linear(channels, channels, bias=True)
+self.to_k = nn.Linear(channels, channels, bias=True)
+self.to_v = nn.Linear(channels, channels, bias=True)
+self.to_out = nn.Sequential(nn.Linear(channels, channels), nn.Dropout(0.0))
+
+def forward(self, hidden_states):
+input_ndim = hidden_states.ndim
+
+if input_ndim == 4:
+batch_size, channels, height, width = hidden_states.shape
+hidden_states = hidden_states.view(batch_size, channels, height * width).transpose(1, 2)
+
+hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+hidden_states = attention(self.to_q, self.to_k, self.to_v, self.to_out, hidden_states)
+
+if input_ndim == 4:
+hidden_states = hidden_states.transpose(1, 2).view(batch_size, channels, height, width)

return hidden_states


+def attention(to_q, to_k, to_v, to_out, hidden_states, encoder_hidden_states=None):
+batch_size, q_seq_len, channels = hidden_states.shape
+head_dim = 64
+
+if encoder_hidden_states is not None:
+kv = encoder_hidden_states
+else:
+kv = hidden_states
+
+kv_seq_len = kv.shape[1]
+
+query = to_q(hidden_states)
+key = to_k(kv)
+value = to_v(kv)
+
+query = query.reshape(batch_size, q_seq_len, channels // head_dim, head_dim).contiguous()
+key = key.reshape(batch_size, kv_seq_len, channels // head_dim, head_dim).contiguous()
+value = value.reshape(batch_size, kv_seq_len, channels // head_dim, head_dim).contiguous()
+
+hidden_states = xformers.ops.memory_efficient_attention(query, key, value)
+
+hidden_states = hidden_states.to(query.dtype)
+hidden_states = hidden_states.reshape(batch_size, q_seq_len, channels).contiguous()
+
+hidden_states = to_out(hidden_states)
+
+return hidden_states
+
+
class GEGLU(nn.Module):
def __init__(self, dim_in: int, dim_out: int):
super().__init__()
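
The new attention helper above runs everything through xformers.ops.memory_efficient_attention with a (batch, seq_len, heads, head_dim) layout. For reference, a dependency-free sketch of the same computation with torch.nn.functional.scaled_dot_product_attention, which expects heads before the sequence axis, hence the transposes; this is an illustrative equivalent, not part of the commit.

import torch
import torch.nn.functional as F

def attention_sdpa(to_q, to_k, to_v, to_out, hidden_states, encoder_hidden_states=None, head_dim=64):
    # hidden_states: (batch, q_seq_len, channels), matching the helper in the diff
    batch_size, q_seq_len, channels = hidden_states.shape
    kv = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
    kv_seq_len = kv.shape[1]

    query = to_q(hidden_states).reshape(batch_size, q_seq_len, channels // head_dim, head_dim)
    key = to_k(kv).reshape(batch_size, kv_seq_len, channels // head_dim, head_dim)
    value = to_v(kv).reshape(batch_size, kv_seq_len, channels // head_dim, head_dim)

    # xformers takes (batch, seq, heads, head_dim); SDPA wants (batch, heads, seq, head_dim)
    out = F.scaled_dot_product_attention(query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2))
    out = out.transpose(1, 2).reshape(batch_size, q_seq_len, channels)

    return to_out(out)
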