from typing import Optional
import torch
from diffusers.pipelines import StableDiffusionPipeline
from safetensors import safe_open
from .convert_from_ckpt import convert_ldm_clip_checkpoint, convert_ldm_unet_checkpoint, convert_ldm_vae_checkpoint
from .convert_lora_safetensor_to_diffusers import convert_lora_model_level


def load_third_party_checkpoints(
    pipeline: StableDiffusionPipeline,
    third_party_dict: dict,
    dreambooth_path: Optional[str] = None,
):
    """
    Load third-party weights (DreamBooth, VAE, LoRA, and textual-inversion
    embeddings) from `third_party_dict` into an existing pipeline.

    Modified from https://github.com/open-mmlab/PIA/blob/4b1ee136542e807a13c1adfe52f4e8e5fcc65cdb/animatediff/pipelines/i2v_pipeline.py#L165
    """
    vae = third_party_dict.get("vae", None)
    lora_list = third_party_dict.get("lora_list", [])
    dreambooth = dreambooth_path or third_party_dict.get("dreambooth", None)
    text_embedding_dict = third_party_dict.get("text_embedding_dict", {})
    if dreambooth is not None:
        dreambooth_state_dict = {}
        if dreambooth.endswith(".safetensors"):
            with safe_open(dreambooth, framework="pt", device="cpu") as f:
                for key in f.keys():
                    dreambooth_state_dict[key] = f.get_tensor(key)
        else:
            dreambooth_state_dict = torch.load(dreambooth, map_location="cpu")
            if "state_dict" in dreambooth_state_dict:
                dreambooth_state_dict = dreambooth_state_dict["state_dict"]

        # load the UNet weights
        converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, pipeline.unet.config)
        pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)

        # load the VAE weights bundled with the DreamBooth checkpoint (if any)
        converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, pipeline.vae.config)
        # add the "_orig_mod." prefix expected by torch.compile-d models
        if "_orig_mod" in list(pipeline.vae.state_dict().keys())[0]:
            converted_vae_checkpoint = {f"_orig_mod.{k}": v for k, v in converted_vae_checkpoint.items()}
        pipeline.vae.load_state_dict(converted_vae_checkpoint, strict=True)

        # load the text encoder weights (if present in the checkpoint)
        text_encoder_checkpoint = convert_ldm_clip_checkpoint(dreambooth_state_dict)
        if text_encoder_checkpoint:
            pipeline.text_encoder.load_state_dict(text_encoder_checkpoint, strict=False)
    if vae is not None:
        vae_state_dict = {}
        if vae.endswith(".safetensors"):
            with safe_open(vae, framework="pt", device="cpu") as f:
                for key in f.keys():
                    vae_state_dict[key] = f.get_tensor(key)
        elif vae.endswith(".ckpt") or vae.endswith(".pt"):
            vae_state_dict = torch.load(vae, map_location="cpu")
            if "state_dict" in vae_state_dict:
                vae_state_dict = vae_state_dict["state_dict"]

        # the converter expects LDM-style keys, so restore the "first_stage_model." prefix
        vae_state_dict = {f"first_stage_model.{k}": v for k, v in vae_state_dict.items()}
        converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_state_dict, pipeline.vae.config)
        # add the "_orig_mod." prefix expected by torch.compile-d models
        if "_orig_mod" in list(pipeline.vae.state_dict().keys())[0]:
            converted_vae_checkpoint = {f"_orig_mod.{k}": v for k, v in converted_vae_checkpoint.items()}
        pipeline.vae.load_state_dict(converted_vae_checkpoint, strict=True)
    if lora_list:
        for lora_dict in lora_list:
            lora, lora_alpha = lora_dict["lora"], lora_dict["lora_alpha"]
            lora_state_dict = {}
            with safe_open(lora, framework="pt", device="cpu") as file:
                for k in file.keys():
                    lora_state_dict[k] = file.get_tensor(k)
            # merge the LoRA weights into the UNet and text encoder at model level
            pipeline.unet, pipeline.text_encoder = convert_lora_model_level(
                lora_state_dict,
                pipeline.unet,
                pipeline.text_encoder,
                alpha=lora_alpha,
            )
            print(f'Added LoRA "{lora}" (alpha={lora_alpha}) to the pipeline.')
    # `text_embedding_dict` defaults to {}, so test truthiness rather than `is not None`
    if text_embedding_dict:
        from diffusers.loaders import TextualInversionLoaderMixin

        assert isinstance(
            pipeline, TextualInversionLoaderMixin
        ), "Pipeline must inherit from TextualInversionLoaderMixin to load textual-inversion embeddings."
        for token, embedding_path in text_embedding_dict.items():
            pipeline.load_textual_inversion(embedding_path, token=token)

    return pipeline
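
# Example (a minimal sketch): the model ID and file paths below are
# placeholder assumptions for illustration, not assets shipped with this module.
#
#     pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
#     third_party = {
#         "vae": "weights/custom_vae.safetensors",
#         "lora_list": [{"lora": "weights/style_lora.safetensors", "lora_alpha": 0.8}],
#         "text_embedding_dict": {"<my-token>": "weights/my_token_embedding.pt"},
#     }
#     pipeline = load_third_party_checkpoints(pipeline, third_party)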


def load_third_party_unet(unet, third_party_dict: dict, dreambooth_path: Optional[str] = None):
    """
    Variant of `load_third_party_checkpoints` that patches a bare UNet:
    loads DreamBooth UNet weights and merges UNet-side LoRA weights only.
    """
    lora_list = third_party_dict.get("lora_list", [])
    dreambooth = dreambooth_path or third_party_dict.get("dreambooth", None)

    if dreambooth is not None:
        dreambooth_state_dict = {}
        if dreambooth.endswith(".safetensors"):
            with safe_open(dreambooth, framework="pt", device="cpu") as f:
                for key in f.keys():
                    dreambooth_state_dict[key] = f.get_tensor(key)
        else:
            dreambooth_state_dict = torch.load(dreambooth, map_location="cpu")
            if "state_dict" in dreambooth_state_dict:
                dreambooth_state_dict = dreambooth_state_dict["state_dict"]

        # load the UNet weights
        converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, unet.config)
        unet.load_state_dict(converted_unet_checkpoint, strict=False)
    if lora_list:
        for lora_dict in lora_list:
            lora, lora_alpha = lora_dict["lora"], lora_dict["lora_alpha"]
            lora_state_dict = {}
            with safe_open(lora, framework="pt", device="cpu") as file:
                for k in file.keys():
                    # keep only UNet weights; skip text-encoder LoRA keys
                    if "text" not in k:
                        lora_state_dict[k] = file.get_tensor(k)
            unet, _ = convert_lora_model_level(
                lora_state_dict,
                unet,
                None,
                alpha=lora_alpha,
            )
            print(f'Added LoRA "{lora}" (alpha={lora_alpha}) to the warmup UNet.')

    return unet
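

# A minimal usage sketch for the warmup-UNet path. The base model ID and LoRA
# path are placeholder assumptions for illustration, not files provided here.
if __name__ == "__main__":
    from diffusers import UNet2DConditionModel

    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="unet"  # placeholder base model
    )
    third_party = {
        "lora_list": [{"lora": "weights/style_lora.safetensors", "lora_alpha": 0.8}],  # placeholder
    }
    # merges only UNet-side LoRA weights; text-encoder keys are filtered out
    unet = load_third_party_unet(unet, third_party)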