File size: 2,947 Bytes
fb4fac3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import torch
from .sd_unet import SDUNetStateDictConverter, SDUNet
from .sd_text_encoder import SDTextEncoderStateDictConverter, SDTextEncoder


class SDLoRA:
    def __init__(self):
        pass

    def convert_state_dict(self, state_dict, lora_prefix="lora_unet_", alpha=1.0, device="cuda"):
        special_keys = {
            "down.blocks": "down_blocks",
            "up.blocks": "up_blocks",
            "mid.block": "mid_block",
            "proj.in": "proj_in",
            "proj.out": "proj_out",
            "transformer.blocks": "transformer_blocks",
            "to.q": "to_q",
            "to.k": "to_k",
            "to.v": "to_v",
            "to.out": "to_out",
        }
        state_dict_ = {}
        for key in state_dict:
            if ".lora_up" not in key:
                continue
            if not key.startswith(lora_prefix):
                continue
            weight_up = state_dict[key].to(device="cuda", dtype=torch.float16)
            weight_down = state_dict[key.replace(".lora_up", ".lora_down")].to(device="cuda", dtype=torch.float16)
            if len(weight_up.shape) == 4:
                weight_up = weight_up.squeeze(3).squeeze(2).to(torch.float32)
                weight_down = weight_down.squeeze(3).squeeze(2).to(torch.float32)
                lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
            else:
                lora_weight = alpha * torch.mm(weight_up, weight_down)
            target_name = key.split(".")[0].replace("_", ".")[len(lora_prefix):] + ".weight"
            for special_key in special_keys:
                target_name = target_name.replace(special_key, special_keys[special_key])
            state_dict_[target_name] = lora_weight.cpu()
        return state_dict_
    
    def add_lora_to_unet(self, unet: SDUNet, state_dict_lora, alpha=1.0, device="cuda"):
        state_dict_unet = unet.state_dict()
        state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix="lora_unet_", alpha=alpha, device=device)
        state_dict_lora = SDUNetStateDictConverter().from_diffusers(state_dict_lora)
        if len(state_dict_lora) > 0:
            for name in state_dict_lora:
                state_dict_unet[name] += state_dict_lora[name].to(device=device)
            unet.load_state_dict(state_dict_unet)

    def add_lora_to_text_encoder(self, text_encoder: SDTextEncoder, state_dict_lora, alpha=1.0, device="cuda"):
        state_dict_text_encoder = text_encoder.state_dict()
        state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix="lora_te_", alpha=alpha, device=device)
        state_dict_lora = SDTextEncoderStateDictConverter().from_diffusers(state_dict_lora)
        if len(state_dict_lora) > 0:
            for name in state_dict_lora:
                state_dict_text_encoder[name] += state_dict_lora[name].to(device=device)
            text_encoder.load_state_dict(state_dict_text_encoder)