williamberman committed
Commit e4ea387 · 1 Parent(s): 9abdf02
Files changed (3):
  1. app.py +20 -11
  2. sdxl.py +70 -37
  3. sdxl_models.py +72 -29
app.py CHANGED
@@ -1,7 +1,7 @@
 import gradio as gr
 import torch

-from diffusers import AutoPipelineForInpainting
+from diffusers import AutoPipelineForInpainting, StableDiffusionXLPipeline
 import diffusers
 from share_btn import community_icon_html, loading_icon_html, share_js
 from sdxl import gen_sdxl_simplified_interface
@@ -10,9 +10,14 @@ from sdxl_models import SDXLUNet, SDXLVae, SDXLControlNetPreEncodedControlnetCon
 device = "cuda" if torch.cuda.is_available() else "cpu"
 pipe = AutoPipelineForInpainting.from_pretrained("diffusers/stable-diffusion-xl-1.0-inpainting-0.1", torch_dtype=torch.float16, variant="fp16").to(device)

-comparing_unet = SDXLUNet.load_fp16(device=device)
-comparing_vae = SDXLVae.load_fp16_fix(device=device)
-comparing_controlnet = SDXLControlNetPreEncodedControlnetCond.load("", device="cuda") # TODO - upload checkpoint
+# TODO - just download individual files
+# StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", variant="fp16") # download weights
+comparing_unet = SDXLUNet.load("/admin/home/william/.cache/huggingface/hub/models--stabilityai--stable-diffusion-xl-base-1.0/snapshots/76d28af79639c28a79fa5c6c6468febd3490a37e/unet/diffusion_pytorch_model.fp16.safetensors", device=device)
+# comparing_vae = SDXLVae.load("/admin/home/william/.cache/huggingface/hub/models--stabilityai--stable-diffusion-xl-base-1.0/snapshots/76d28af79639c28a79fa5c6c6468febd3490a37e/vae/diffusion_pytorch_model.fp16.safetensors", device=device)
+comparing_vae = SDXLVae.load("/admin/home/william/.cache/huggingface/hub/models--madebyollin--sdxl-vae-fp16-fix/snapshots/4df413ca49271c25289a6482ab97a433f8117d15/diffusion_pytorch_model.safetensors", device=device)
+comparing_vae.to(torch.float16)
+# comparing_controlnet = SDXLControlNetPreEncodedControlnetCond.load("/fsx/william/diffusers-utils/output/sdxl_controlnet_inpaint_pre_encoded_controlnet_cond/checkpoint-200000/controlnet/diffusion_pytorch_model.safetensors", device="cuda") # TODO - upload checkpoint
+comparing_controlnet = SDXLControlNetPreEncodedControlnetCond.load("./controlnet_vae.safetensors", device="cuda") # TODO - upload checkpoint
 comparing_controlnet.to(torch.float16)

 def read_content(file_path: str) -> str:
@@ -40,13 +45,15 @@ def predict(dict, prompt="", negative_prompt="", guidance_scale=7.5, steps=20, s
     init_image = dict["image"].convert("RGB").resize((1024, 1024))
     mask = dict["mask"].convert("RGB").resize((1024, 1024))

-    output = pipe(prompt = prompt, negative_prompt=negative_prompt, image=init_image, mask_image=mask, guidance_scale=guidance_scale, num_inference_steps=int(steps), strength=strength)
+    # output = pipe(prompt = prompt, negative_prompt=negative_prompt, image=init_image, mask_image=mask, guidance_scale=guidance_scale, num_inference_steps=int(steps), strength=strength)
     output_controlnet_vae_encoding = gen_sdxl_simplified_interface(
-        prompt=prompt, negative_prompt=negative_prompt, images=init_image, mask_image=mask, guidance_scale=guidance_scale, num_inference_steps=int(steps),
+        prompts=prompt, negative_prompts=negative_prompt, images=init_image, masks=mask, guidance_scale=guidance_scale, num_inference_steps=int(steps),
         text_encoder_one=pipe.text_encoder, text_encoder_two=pipe.text_encoder_2, unet=comparing_unet, vae=comparing_vae, controlnet=comparing_controlnet, device=device
     )

-    return output.images[0], output_controlnet_vae_encoding[0], gr.update(visible=True)
+    # return output.images[0], output_controlnet_vae_encoding[0], gr.update(visible=True)
+
+    return output_controlnet_vae_encoding[0], gr.update(visible=True)


 css = '''
@@ -108,16 +115,18 @@ with image_blocks as demo:
 scheduler = gr.Dropdown(label="Schedulers", choices=schedulers, value="EulerDiscreteScheduler")

 with gr.Column():
-    image_out = gr.Image(label="Output", elem_id="output-img", height=400)
-    image_out_comparing = gr.Image(label="Output", elem_id="output-img-comparing", height=400)
+    image_out = gr.Image(label="Output diffusers full finetune 0.1", elem_id="output-img", height=400)
+    image_out_comparing = gr.Image(label="Output controlnet + vae", elem_id="output-img-comparing", height=400)
     with gr.Group(elem_id="share-btn-container", visible=False) as share_btn_container:
         community_icon = gr.HTML(community_icon_html)
         loading_icon = gr.HTML(loading_icon_html)
         share_button = gr.Button("Share to community", elem_id="share-btn",visible=True)


-btn.click(fn=predict, inputs=[image, prompt, negative_prompt, guidance_scale, steps, strength, scheduler], outputs=[image_out, image_out_comparing, share_btn_container], api_name='run')
-prompt.submit(fn=predict, inputs=[image, prompt, negative_prompt, guidance_scale, steps, strength, scheduler], outputs=[image_out, image_out_comparing, share_btn_container])
+# btn.click(fn=predict, inputs=[image, prompt, negative_prompt, guidance_scale, steps, strength, scheduler], outputs=[image_out, image_out_comparing, share_btn_container], api_name='run')
+# prompt.submit(fn=predict, inputs=[image, prompt, negative_prompt, guidance_scale, steps, strength, scheduler], outputs=[image_out, image_out_comparing, share_btn_container])
+btn.click(fn=predict, inputs=[image, prompt, negative_prompt, guidance_scale, steps, strength, scheduler], outputs=[image_out_comparing, share_btn_container], api_name='run')
+prompt.submit(fn=predict, inputs=[image, prompt, negative_prompt, guidance_scale, steps, strength, scheduler], outputs=[image_out_comparing, share_btn_container])
 share_button.click(None, [], [], _js=share_js)

 gr.Examples(
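Note: the new loader lines above point at hard-coded local HuggingFace cache snapshots, and the in-line TODO says to "just download individual files". A minimal sketch of what that could look like, assuming SDXLUNet.load and SDXLVae.load accept any local safetensors path (as they do above); the repo ids and filenames are taken from the cache paths in the diff:

    from huggingface_hub import hf_hub_download

    # Fetch only the two weight files instead of relying on a pre-populated local snapshot.
    unet_path = hf_hub_download("stabilityai/stable-diffusion-xl-base-1.0", "unet/diffusion_pytorch_model.fp16.safetensors")
    vae_path = hf_hub_download("madebyollin/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors")

    comparing_unet = SDXLUNet.load(unet_path, device=device)
    comparing_vae = SDXLVae.load(vae_path, device=device)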
sdxl.py CHANGED
@@ -388,9 +388,9 @@ def make_sample(d, proportion_empty_prompts, get_sdxl_conditioning_images=None):

     micro_conditioning = torch.tensor([original_width, original_height, c_top, c_left, 1024, 1024])

-    text_input_ids_one = sdxl_tokenize_one(text)
+    text_input_ids_one = sdxl_tokenize_one(text)[0]

-    text_input_ids_two = sdxl_tokenize_two(text)
+    text_input_ids_two = sdxl_tokenize_two(text)[0]

     image = image.convert("RGB")

@@ -517,7 +517,7 @@ def sdxl_tokenize_one(prompts):
         max_length=tokenizer_one.model_max_length,
         truncation=True,
         return_tensors="pt",
-    ).input_ids[0]
+    ).input_ids


 def sdxl_tokenize_two(prompts):
@@ -527,7 +527,7 @@ def sdxl_tokenize_two(prompts):
         max_length=tokenizer_one.model_max_length,
         truncation=True,
         return_tensors="pt",
-    ).input_ids[0]
+    ).input_ids


 def sdxl_text_conditioning(text_encoder_one, text_encoder_two, text_input_ids_one, text_input_ids_two):
@@ -667,7 +667,7 @@ def apply_padding(mask, coord):

 @torch.no_grad()
 def sdxl_diffusion_loop(
-    prompts,
+    prompts: List[str],
     unet,
     text_encoder_one,
     text_encoder_two,
@@ -683,8 +683,10 @@
     negative_prompts=None,
     diffusion_loop=euler_ode_solver_diffusion_loop,
 ):
+    batch_size = len(prompts)
+
     if negative_prompts is None:
-        negative_prompts = [""] * len(prompts)
+        negative_prompts = [""] * batch_size

     prompts += negative_prompts

@@ -694,27 +696,30 @@
         sdxl_tokenize_one(prompts).to(text_encoder_one.device),
         sdxl_tokenize_two(prompts).to(text_encoder_two.device),
     )
-
-    if x_T is None:
-        x_T = torch.randn((1, 4, 1024 // 8, 1024 // 8), dtype=torch.float32, device=unet.device, generator=generator)
-        x_T = x_T * ((sigmas.max() ** 2 + 1) ** 0.5)
+    encoder_hidden_states = encoder_hidden_states.to(unet.dtype)
+    pooled_encoder_hidden_states = pooled_encoder_hidden_states.to(unet.dtype)

     if sigmas is None:
         sigmas = make_sigmas(device=unet.device)

+    if x_T is None:
+        x_T = torch.randn((batch_size, 4, 1024 // 8, 1024 // 8), dtype=unet.dtype, device=unet.device, generator=generator)
+        x_T = x_T * ((sigmas.max() ** 2 + 1) ** 0.5)
+
     if timesteps is None:
         timesteps = torch.linspace(0, sigmas.numel(), 50, dtype=torch.long, device=unet.device)

     if micro_conditioning is None:
-        micro_conditioning = torch.tensor([1024, 1024, 0, 0, 1024, 1024], dtype=torch.long, device=unet.device)
+        micro_conditioning = torch.tensor([[1024, 1024, 0, 0, 1024, 1024]], dtype=torch.long, device=unet.device)
+        micro_conditioning = micro_conditioning.expand(batch_size, -1)

     if adapter is not None:
-        down_block_additional_residuals = adapter(images)
+        down_block_additional_residuals = adapter(images.to(dtype=adapter.dtype, device=adapter.device))
     else:
         down_block_additional_residuals = None

     if controlnet is not None:
-        controlnet_cond = images
+        controlnet_cond = images.to(dtype=controlnet.dtype, device=controlnet.device)
     else:
         controlnet_cond = None

@@ -756,21 +761,28 @@ def sdxl_eps_theta(

     if guidance_scale > 1.0:
         scaled_x_t = torch.concat([scaled_x_t, scaled_x_t])
+        micro_conditioning = torch.concat([micro_conditioning, micro_conditioning])
+        if controlnet_cond is not None:
+            controlnet_cond = torch.concat([controlnet_cond, controlnet_cond])

     if controlnet is not None:
         controlnet_out = controlnet(
             x_t=scaled_x_t,
             t=t,
-            encoder_hidden_states=encoder_hidden_states,
-            micro_conditioning=micro_conditioning,
-            pooled_encoder_hidden_states=pooled_encoder_hidden_states,
+            encoder_hidden_states=encoder_hidden_states.to(controlnet.dtype),
+            micro_conditioning=micro_conditioning.to(controlnet.dtype),
+            pooled_encoder_hidden_states=pooled_encoder_hidden_states.to(controlnet.dtype),
             controlnet_cond=controlnet_cond,
         )

-        down_block_additional_residuals = controlnet_out["down_block_res_samples"]
-        mid_block_additional_residual = controlnet_out["mid_block_res_sample"]
+        down_block_additional_residuals = [x.to(unet.dtype) for x in controlnet_out["down_block_res_samples"]]
+        mid_block_additional_residual = controlnet_out["mid_block_res_sample"].to(unet.dtype)
         add_to_down_block_inputs = controlnet_out.get("add_to_down_block_inputs", None)
+        if add_to_down_block_inputs is not None:
+            add_to_down_block_inputs = [x.to(unet.dtype) for x in add_to_down_block_inputs]
         add_to_output = controlnet_out.get("add_to_output", None)
+        if add_to_output is not None:
+            add_to_output = add_to_output.to(unet.dtype)
     else:
         mid_block_additional_residual = None
         add_to_down_block_inputs = None
@@ -795,20 +807,24 @@

     return eps_hat

+
 known_negative_prompt = "text, watermark, low-quality, signature, moiré pattern, downsampling, aliasing, distorted, blurry, glossy, blur, jpeg artifacts, compression artifacts, poorly drawn, low-resolution, bad, distortion, twisted, excessive, exaggerated pose, exaggerated limbs, grainy, symmetrical, duplicate, error, pattern, beginner, pixelated, fake, hyper, glitch, overexposed, high-contrast, bad-contrast"

+
+# TODO probably just combine with sdxl_diffusion_loop
 def gen_sdxl_simplified_interface(
-    prompt:str,
-    negative_prompt: Optional[str] = None,
-    controlnet_checkpoint: Optional[str]=None,
-    controlnet: Optional[Literal["SDXLControlNet", "SDXLContolNetFull", "SDXLControlNetPreEncodedControlnetCond"]]=None,
-    adapter_checkpoint: Optional[str]=None,
+    prompts: Union[str, List[str]],
+    negative_prompts: Optional[Union[str, List[str]]] = None,
+    controlnet_checkpoint: Optional[str] = None,
+    controlnet: Optional[Literal["SDXLControlNet", "SDXLContolNetFull", "SDXLControlNetPreEncodedControlnetCond"]] = None,
+    adapter_checkpoint: Optional[str] = None,
     num_inference_steps=50,
     images=None,
     masks=None,
-    apply_conditioning: Optional[Literal["canny"]]=None,
-    num_images: int=1,
-    device: Optional[str]=None,
+    apply_conditioning: Optional[Literal["canny"]] = None,
+    num_images: int = 1,
+    guidance_scale=5.0,
+    device: Optional[str] = None,
     text_encoder_one=None,
     text_encoder_two=None,
     unet=None,
@@ -886,22 +902,23 @@ def gen_sdxl_simplified_interface(
         mask = masks[image_idx]
         if isinstance(mask, str):
             mask = Image.open(mask)
-            mask = mask.convert("L")
-            mask = mask.resize((1024, 1024))
         elif isinstance(mask, Image.Image):
             ...
         else:
             assert False
+        mask = mask.convert("L")
+        mask = mask.resize((1024, 1024))
         mask = TF.to_tensor(mask)

-        if controlnet == "SDXLControlNetPreEncodedControlnetCond":
+        if isinstance(controlnet, SDXLControlNetPreEncodedControlnetCond):
             image = image * (mask < 0.5)
-            image = TF.normalized(image, [0.5], [0.5])
-            image = vae.encode(image)
-            mask = TF.resize(mask, (1024 // 8, 1024 // 8))
-            image = torch.concat((image, mask))
+            image = TF.normalize(image, [0.5], [0.5])
+            image = vae.encode(image[None, :, :, :].to(dtype=vae.dtype, device=vae.device)).to(dtype=unet.dtype, device=unet.device)
+            mask = TF.resize(mask, (1024 // 8, 1024 // 8))[None, :, :, :].to(dtype=image.dtype, device=image.device)
+            image = torch.concat((image, mask), dim=1)
         else:
-            image = image * (mask < 0.5) + -1.0 * (mask >= 0.5)
+            image = (image * (mask < 0.5) + -1.0 * (mask >= 0.5)).to(dtype=unet.dtype, device=unet.device)
+            image = image[None, :, :, :]

         images_.append(image)

@@ -909,9 +926,24 @@
     else:
         images_ = None

+    if isinstance(prompts, str):
+        prompts = [prompts]
+    prompts_ = []
+    for prompt in prompts:
+        prompts_ += [prompt] * num_images
+
+    if negative_prompts is not None:
+        if isinstance(negative_prompts, str):
+            negative_prompts = [negative_prompts]
+        negative_prompts_ = []
+        for negative_prompt in negative_prompts:
+            negative_prompts_ += [negative_prompt] * num_images
+    else:
+        negative_prompts_ = None
+
     x_0 = sdxl_diffusion_loop(
-        prompts=[prompt] * num_images,
-        negative_prompts=[negative_prompt] * num_images,
+        prompts=prompts_,
+        negative_prompts=negative_prompts_,
         unet=unet,
         text_encoder_one=text_encoder_one,
         text_encoder_two=text_encoder_two,
@@ -920,9 +952,10 @@
         controlnet=controlnet,
         adapter=adapter,
         images=images_,
+        guidance_scale=guidance_scale,
     )

-    x_0 = vae.decode(x_0)
+    x_0 = vae.decode(x_0.to(vae.dtype))
     x_0 = vae.output_tensor_to_pil(x_0)

     return x_0
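Note: the guidance-scale hunk above doubles micro_conditioning and controlnet_cond so every per-sample input matches the doubled scaled_x_t batch used for classifier-free guidance. The combine step itself is not part of these hunks; as a self-contained reference, a sketch assuming the standard formulation, with the conditional half first because of the `prompts += negative_prompts` line above:

    import torch

    def cfg_combine(eps_hat_doubled: torch.Tensor, guidance_scale: float) -> torch.Tensor:
        # First half of the batch is the conditional prediction, second half the unconditional one.
        eps_cond, eps_uncond = eps_hat_doubled.chunk(2)
        return eps_uncond + guidance_scale * (eps_cond - eps_uncond)

    eps_hat = torch.randn(4, 4, 128, 128)  # stand-in for a UNet output on a doubled batch of 2
    print(cfg_combine(eps_hat, guidance_scale=5.0).shape)  # torch.Size([2, 4, 128, 128])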
sdxl_models.py CHANGED
@@ -26,7 +26,8 @@ class ModelUtils:

         load_from = [load_from]

-        load_from += overrides
+        if overrides is not None:
+            load_from += overrides

         state_dict = {}

@@ -79,7 +80,7 @@ class SDXLVae(nn.Module, ModelUtils):

         # 512 -> 512
         mid_block=nn.ModuleDict(dict(
-            attentions=nn.ModuleList([Attention(512, 512, qkv_bias=True)]),
+            attentions=nn.ModuleList([VaeMidBlockAttention(512)]),
             resnets=nn.ModuleList([ResnetBlock2D(512, 512, eps=1e-6), ResnetBlock2D(512, 512, eps=1e-6)]),
         )),

@@ -95,7 +96,7 @@ class SDXLVae(nn.Module, ModelUtils):
         # 8 -> 4 from sampling mean and std

         # 4 -> 4
-        self.post_quant_conv = nn.Conv2d(4, 4, 1)
+        self.post_quant_conv = nn.Conv2d(4, 4, kernel_size=1)

         self.decoder = nn.ModuleDict(dict(
             # 4 -> 512
@@ -103,7 +104,7 @@ class SDXLVae(nn.Module, ModelUtils):

         # 512 -> 512
         mid_block=nn.ModuleDict(dict(
-            attentions=nn.ModuleList([Attention(512, 512, qkv_bias=True)]),
+            attentions=nn.ModuleList([VaeMidBlockAttention(512)]),
             resnets=nn.ModuleList([ResnetBlock2D(512, 512, eps=1e-6), ResnetBlock2D(512, 512, eps=1e-6)]),
         )),

@@ -179,15 +180,18 @@ class SDXLVae(nn.Module, ModelUtils):

         h = self.post_quant_conv(h)

+        h = self.decoder["conv_in"](h)
+
         h = self.decoder["mid_block"]["resnets"][0](h)
         h = self.decoder["mid_block"]["attentions"][0](h)
         h = self.decoder["mid_block"]["resnets"][1](h)

-        for up_block in self.encoder["up_blocks"]:
+        for up_block in self.decoder["up_blocks"]:
             for resnet in up_block["resnets"]:
                 h = resnet(h)

             if "upsamplers" in up_block:
+                h = F.interpolate(h, scale_factor=2.0, mode="nearest")
                 h = up_block["upsamplers"][0]["conv"](h)

         h = self.decoder["conv_norm_out"](h)
@@ -208,9 +212,7 @@ class SDXLVae(nn.Module, ModelUtils):

     @classmethod
     def output_tensor_to_pil(self, x_pred):
-        x_pred = ((x_pred * 0.5 + 0.5).clamp(0, 1) * 255).to(torch.uint8).permute(0, 2, 3, 1)
-
-        x_pred = x_pred.permute(0, 2, 3, 1).cpu().numpy()
+        x_pred = ((x_pred * 0.5 + 0.5).clamp(0, 1) * 255).to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy()

         x_pred = [Image.fromarray(x) for x in x_pred]

@@ -1323,42 +1325,83 @@ class TransformerDecoderBlock(nn.Module):


 class Attention(nn.Module):
-    def __init__(self, channels, encoder_hidden_states_dim, qkv_bias=False):
+    def __init__(self, channels, encoder_hidden_states_dim):
         super().__init__()
-        self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
-        self.to_k = nn.Linear(encoder_hidden_states_dim, channels, bias=qkv_bias)
-        self.to_v = nn.Linear(encoder_hidden_states_dim, channels, bias=qkv_bias)
+        self.to_q = nn.Linear(channels, channels, bias=False)
+        self.to_k = nn.Linear(encoder_hidden_states_dim, channels, bias=False)
+        self.to_v = nn.Linear(encoder_hidden_states_dim, channels, bias=False)
         self.to_out = nn.Sequential(nn.Linear(channels, channels), nn.Dropout(0.0))

     def forward(self, hidden_states, encoder_hidden_states=None):
-        batch_size, q_seq_len, channels = hidden_states.shape
-        head_dim = 64
-
-        if encoder_hidden_states is not None:
-            kv = encoder_hidden_states
-        else:
-            kv = hidden_states
-
-        kv_seq_len = kv.shape[1]
-
-        query = self.to_q(hidden_states)
-        key = self.to_k(kv)
-        value = self.to_v(kv)
-
-        query = query.reshape(batch_size, q_seq_len, channels // head_dim, head_dim).contiguous()
-        key = key.reshape(batch_size, kv_seq_len, channels // head_dim, head_dim).contiguous()
-        value = value.reshape(batch_size, kv_seq_len, channels // head_dim, head_dim).contiguous()
-
-        hidden_states = xformers.ops.memory_efficient_attention(query, key, value)
-
-        hidden_states = hidden_states.to(query.dtype)
-        hidden_states = hidden_states.reshape(batch_size, q_seq_len, channels).contiguous()
-
-        hidden_states = self.to_out(hidden_states)
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channels, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channels, height * width).transpose(1, 2)
+
+        hidden_states = attention(self.to_q, self.to_k, self.to_v, self.to_out, hidden_states, encoder_hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(1, 2).view(batch_size, channels, height, width)
+
+        return hidden_states
+
+
+class VaeMidBlockAttention(nn.Module):
+    def __init__(self, channels):
+        super().__init__()
+        self.group_norm = nn.GroupNorm(32, channels, eps=1e-06)
+        self.to_q = nn.Linear(channels, channels, bias=True)
+        self.to_k = nn.Linear(channels, channels, bias=True)
+        self.to_v = nn.Linear(channels, channels, bias=True)
+        self.to_out = nn.Sequential(nn.Linear(channels, channels), nn.Dropout(0.0))
+
+    def forward(self, hidden_states):
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channels, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channels, height * width).transpose(1, 2)
+
+        hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        hidden_states = attention(self.to_q, self.to_k, self.to_v, self.to_out, hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(1, 2).view(batch_size, channels, height, width)

         return hidden_states


+def attention(to_q, to_k, to_v, to_out, hidden_states, encoder_hidden_states=None):
+    batch_size, q_seq_len, channels = hidden_states.shape
+    head_dim = 64
+
+    if encoder_hidden_states is not None:
+        kv = encoder_hidden_states
+    else:
+        kv = hidden_states
+
+    kv_seq_len = kv.shape[1]
+
+    query = to_q(hidden_states)
+    key = to_k(kv)
+    value = to_v(kv)
+
+    query = query.reshape(batch_size, q_seq_len, channels // head_dim, head_dim).contiguous()
+    key = key.reshape(batch_size, kv_seq_len, channels // head_dim, head_dim).contiguous()
+    value = value.reshape(batch_size, kv_seq_len, channels // head_dim, head_dim).contiguous()
+
+    hidden_states = xformers.ops.memory_efficient_attention(query, key, value)
+
+    hidden_states = hidden_states.to(query.dtype)
+    hidden_states = hidden_states.reshape(batch_size, q_seq_len, channels).contiguous()
+
+    hidden_states = to_out(hidden_states)
+
+    return hidden_states
+
+
 class GEGLU(nn.Module):
     def __init__(self, dim_in: int, dim_out: int):
         super().__init__()
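Note: the refactor above routes both Attention and the new VaeMidBlockAttention through one shared xformers-based attention() helper, flattening 4D feature maps to (batch, seq, channels) first. For environments without xformers, a sketch of an equivalent helper using PyTorch's built-in scaled_dot_product_attention (an assumption, not part of this commit); xformers takes (batch, seq, heads, head_dim) while SDPA takes (batch, heads, seq, head_dim), hence the extra transposes:

    import torch
    import torch.nn.functional as F

    def attention_sdpa(to_q, to_k, to_v, to_out, hidden_states, encoder_hidden_states=None, head_dim=64):
        batch_size, q_seq_len, channels = hidden_states.shape
        kv = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        kv_seq_len = kv.shape[1]
        num_heads = channels // head_dim

        # (batch, seq, channels) -> (batch, heads, seq, head_dim)
        query = to_q(hidden_states).reshape(batch_size, q_seq_len, num_heads, head_dim).transpose(1, 2)
        key = to_k(kv).reshape(batch_size, kv_seq_len, num_heads, head_dim).transpose(1, 2)
        value = to_v(kv).reshape(batch_size, kv_seq_len, num_heads, head_dim).transpose(1, 2)

        out = F.scaled_dot_product_attention(query, key, value)
        out = out.transpose(1, 2).reshape(batch_size, q_seq_len, channels)
        return to_out(out)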