File size: 5,868 Bytes
d16b52d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
from typing import Optional

import torch
from diffusers.pipelines import StableDiffusionPipeline
from safetensors import safe_open

from .convert_from_ckpt import convert_ldm_clip_checkpoint, convert_ldm_unet_checkpoint, convert_ldm_vae_checkpoint
from .convert_lora_safetensor_to_diffusers import convert_lora_model_level


def load_third_party_checkpoints(
    pipeline: StableDiffusionPipeline,
    third_party_dict: dict,
    dreambooth_path: Optional[str] = None,
):
    """
    Modified from https://github.com/open-mmlab/PIA/blob/4b1ee136542e807a13c1adfe52f4e8e5fcc65cdb/animatediff/pipelines/i2v_pipeline.py#L165
    """
    vae = third_party_dict.get("vae", None)
    lora_list = third_party_dict.get("lora_list", [])

    dreambooth = dreambooth_path or third_party_dict.get("dreambooth", None)

    text_embedding_dict = third_party_dict.get("text_embedding_dict", {})

    if dreambooth is not None:
        dreambooth_state_dict = {}
        if dreambooth.endswith(".safetensors"):
            with safe_open(dreambooth, framework="pt", device="cpu") as f:
                for key in f.keys():
                    dreambooth_state_dict[key] = f.get_tensor(key)
        else:
            dreambooth_state_dict = torch.load(dreambooth, map_location="cpu")
            if "state_dict" in dreambooth_state_dict:
                dreambooth_state_dict = dreambooth_state_dict["state_dict"]
        # load unet
        converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, pipeline.unet.config)
        pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)

        # load vae from dreambooth (if need)
        converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, pipeline.vae.config)
        # add prefix for compiled model
        if "_orig_mod" in list(pipeline.vae.state_dict().keys())[0]:
            converted_vae_checkpoint = {f"_orig_mod.{k}": v for k, v in converted_vae_checkpoint.items()}
        pipeline.vae.load_state_dict(converted_vae_checkpoint, strict=True)

        # load text encoder (if need)
        text_encoder_checkpoint = convert_ldm_clip_checkpoint(dreambooth_state_dict)
        if text_encoder_checkpoint:
            pipeline.text_encoder.load_state_dict(text_encoder_checkpoint, strict=False)

    if vae is not None:
        vae_state_dict = {}
        if vae.endswith("safetensors"):
            with safe_open(vae, framework="pt", device="cpu") as f:
                for key in f.keys():
                    vae_state_dict[key] = f.get_tensor(key)
        elif vae.endswith("ckpt") or vae.endswith("pt"):
            vae_state_dict = torch.load(vae, map_location="cpu")
        if "state_dict" in vae_state_dict:
            vae_state_dict = vae_state_dict["state_dict"]

        vae_state_dict = {f"first_stage_model.{k}": v for k, v in vae_state_dict.items()}

        converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_state_dict, pipeline.vae.config)
        # add prefix for compiled model
        if "_orig_mod" in list(pipeline.vae.state_dict().keys())[0]:
            converted_vae_checkpoint = {f"_orig_mod.{k}": v for k, v in converted_vae_checkpoint.items()}
        pipeline.vae.load_state_dict(converted_vae_checkpoint, strict=True)

    if lora_list:
        for lora_dict in lora_list:
            lora, lora_alpha = lora_dict["lora"], lora_dict["lora_alpha"]
            lora_state_dict = {}
            with safe_open(lora, framework="pt", device="cpu") as file:
                for k in file.keys():
                    lora_state_dict[k] = file.get_tensor(k)
            pipeline.unet, pipeline.text_encoder = convert_lora_model_level(
                lora_state_dict,
                pipeline.unet,
                pipeline.text_encoder,
                alpha=lora_alpha,
            )
            print(f'Add LoRA "{lora}":{lora_alpha} to pipeline.')

    if text_embedding_dict is not None:
        from diffusers.loaders import TextualInversionLoaderMixin

        assert isinstance(
            pipeline, TextualInversionLoaderMixin
        ), "Pipeline must inherit from TextualInversionLoaderMixin."

        for token, embedding_path in text_embedding_dict.items():
            pipeline.load_textual_inversion(embedding_path, token)

    return pipeline


def load_third_party_unet(unet, third_party_dict: dict, dreambooth_path: Optional[str] = None):
    """Apply an optional DreamBooth checkpoint and any LoRA weights to a standalone UNet.

    Args:
        unet: Model to update. It must expose ``config`` and ``load_state_dict``.
        third_party_dict: May contain ``"dreambooth"`` (checkpoint path) and ``"lora_list"``
            (list of ``{"lora": <.safetensors path>, "lora_alpha": <scale>}``).
        dreambooth_path: When truthy, used instead of ``third_party_dict["dreambooth"]``.

    Returns:
        The object returned by the last ``convert_lora_model_level`` call. If no LoRA is
        listed, this is the input ``unet`` itself.
    """
    loras = third_party_dict.get("lora_list", [])
    checkpoint_path = dreambooth_path or third_party_dict.get("dreambooth", None)

    if checkpoint_path is not None:
        if checkpoint_path.endswith(".safetensors"):
            with safe_open(checkpoint_path, framework="pt", device="cpu") as reader:
                base_weights = {name: reader.get_tensor(name) for name in reader.keys()}
        else:
            # NOTE(review): torch.load unpickles arbitrary objects -- only use trusted checkpoints.
            base_weights = torch.load(checkpoint_path, map_location="cpu")
            base_weights = base_weights["state_dict"] if "state_dict" in base_weights else base_weights
        # Non-strict load: UNet keys absent from the converted checkpoint keep their weights.
        unet.load_state_dict(convert_ldm_unet_checkpoint(base_weights, unet.config), strict=False)

    for entry in loras or ():
        lora, lora_alpha = entry["lora"], entry["lora_alpha"]
        with safe_open(lora, framework="pt", device="cpu") as reader:
            # Drop every tensor whose key contains "text". These are presumably the
            # text-encoder half of the LoRA, which has no target here.
            unet_only = {name: reader.get_tensor(name) for name in reader.keys() if "text" not in name}
        unet, _ = convert_lora_model_level(unet_only, unet, None, alpha=lora_alpha)
        print(f'Add LoRA "{lora}":{lora_alpha} to Warmup UNet.')

    return unet