# leoxing1996
# add demo
# d16b52d
# raw
# history blame
# 5.87 kB
from typing import Optional
import torch
from diffusers.pipelines import StableDiffusionPipeline
from safetensors import safe_open
from .convert_from_ckpt import convert_ldm_clip_checkpoint, convert_ldm_unet_checkpoint, convert_ldm_vae_checkpoint
from .convert_lora_safetensor_to_diffusers import convert_lora_model_level
def load_third_party_checkpoints(
pipeline: StableDiffusionPipeline,
third_party_dict: dict,
dreambooth_path: Optional[str] = None,
):
"""
Modified from https://github.com/open-mmlab/PIA/blob/4b1ee136542e807a13c1adfe52f4e8e5fcc65cdb/animatediff/pipelines/i2v_pipeline.py#L165
"""
vae = third_party_dict.get("vae", None)
lora_list = third_party_dict.get("lora_list", [])
dreambooth = dreambooth_path or third_party_dict.get("dreambooth", None)
text_embedding_dict = third_party_dict.get("text_embedding_dict", {})
if dreambooth is not None:
dreambooth_state_dict = {}
if dreambooth.endswith(".safetensors"):
with safe_open(dreambooth, framework="pt", device="cpu") as f:
for key in f.keys():
dreambooth_state_dict[key] = f.get_tensor(key)
else:
dreambooth_state_dict = torch.load(dreambooth, map_location="cpu")
if "state_dict" in dreambooth_state_dict:
dreambooth_state_dict = dreambooth_state_dict["state_dict"]
# load unet
converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, pipeline.unet.config)
pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)
# load vae from dreambooth (if need)
converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, pipeline.vae.config)
# add prefix for compiled model
if "_orig_mod" in list(pipeline.vae.state_dict().keys())[0]:
converted_vae_checkpoint = {f"_orig_mod.{k}": v for k, v in converted_vae_checkpoint.items()}
pipeline.vae.load_state_dict(converted_vae_checkpoint, strict=True)
# load text encoder (if need)
text_encoder_checkpoint = convert_ldm_clip_checkpoint(dreambooth_state_dict)
if text_encoder_checkpoint:
pipeline.text_encoder.load_state_dict(text_encoder_checkpoint, strict=False)
if vae is not None:
vae_state_dict = {}
if vae.endswith("safetensors"):
with safe_open(vae, framework="pt", device="cpu") as f:
for key in f.keys():
vae_state_dict[key] = f.get_tensor(key)
elif vae.endswith("ckpt") or vae.endswith("pt"):
vae_state_dict = torch.load(vae, map_location="cpu")
if "state_dict" in vae_state_dict:
vae_state_dict = vae_state_dict["state_dict"]
vae_state_dict = {f"first_stage_model.{k}": v for k, v in vae_state_dict.items()}
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_state_dict, pipeline.vae.config)
# add prefix for compiled model
if "_orig_mod" in list(pipeline.vae.state_dict().keys())[0]:
converted_vae_checkpoint = {f"_orig_mod.{k}": v for k, v in converted_vae_checkpoint.items()}
pipeline.vae.load_state_dict(converted_vae_checkpoint, strict=True)
if lora_list:
for lora_dict in lora_list:
lora, lora_alpha = lora_dict["lora"], lora_dict["lora_alpha"]
lora_state_dict = {}
with safe_open(lora, framework="pt", device="cpu") as file:
for k in file.keys():
lora_state_dict[k] = file.get_tensor(k)
pipeline.unet, pipeline.text_encoder = convert_lora_model_level(
lora_state_dict,
pipeline.unet,
pipeline.text_encoder,
alpha=lora_alpha,
)
print(f'Add LoRA "{lora}":{lora_alpha} to pipeline.')
if text_embedding_dict is not None:
from diffusers.loaders import TextualInversionLoaderMixin
assert isinstance(
pipeline, TextualInversionLoaderMixin
), "Pipeline must inherit from TextualInversionLoaderMixin."
for token, embedding_path in text_embedding_dict.items():
pipeline.load_textual_inversion(embedding_path, token)
return pipeline
def load_third_party_unet(unet, third_party_dict: dict, dreambooth_path: Optional[str] = None):
    """Apply third-party DreamBooth and LoRA weights to a standalone UNet.

    Only UNet weights are used. The DreamBooth checkpoint is converted with
    ``convert_ldm_unet_checkpoint`` and loaded non-strictly. LoRA tensors whose
    key contains ``"text"`` are discarded before merging, and no text encoder
    is passed to the LoRA converter.

    Args:
        unet: UNet module to update; its ``.config`` is read for conversion.
        third_party_dict: May contain ``"dreambooth"`` (checkpoint path) and
            ``"lora_list"`` (list of ``{"lora": path, "lora_alpha": alpha}``).
        dreambooth_path: When truthy, takes precedence over
            ``third_party_dict["dreambooth"]``.

    Returns:
        The UNet after LoRA merging. This is the object returned by
        ``convert_lora_model_level``, or ``unet`` itself when no LoRA is given.
    """
    lora_entries = third_party_dict.get("lora_list", [])
    checkpoint_path = dreambooth_path or third_party_dict.get("dreambooth", None)

    if checkpoint_path is not None:
        if checkpoint_path.endswith(".safetensors"):
            with safe_open(checkpoint_path, framework="pt", device="cpu") as reader:
                full_state = {name: reader.get_tensor(name) for name in reader.keys()}
        else:
            # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
            full_state = torch.load(checkpoint_path, map_location="cpu")
        # Some checkpoints nest their weights under a top-level "state_dict" key.
        if "state_dict" in full_state:
            full_state = full_state["state_dict"]
        unet.load_state_dict(convert_ldm_unet_checkpoint(full_state, unet.config), strict=False)

    for entry in lora_entries or []:
        lora_path, alpha = entry["lora"], entry["lora_alpha"]
        with safe_open(lora_path, framework="pt", device="cpu") as reader:
            # Drop text-encoder tensors: only the UNet is being patched here.
            unet_only = {name: reader.get_tensor(name) for name in reader.keys() if "text" not in name}
        unet, _ = convert_lora_model_level(unet_only, unet, None, alpha=alpha)
        print(f'Add LoRA "{lora_path}":{alpha} to Warmup UNet.')

    return unet