abc
commited on
Commit
·
ac2ea1d
1
Parent(s):
0380cfd
Upload 10 files
Browse files- lycoris/dylora.py +175 -0
- lycoris/ia3.py +68 -0
- lycoris/kohya.py +262 -28
- lycoris/kohya_model_utils.py +977 -996
- lycoris/locon.py +2 -1
- lycoris/loha.py +2 -1
- lycoris/lokr.py +220 -0
- lycoris/utils.py +241 -185
lycoris/dylora.py
ADDED
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import random
|
3 |
+
|
4 |
+
from collections import OrderedDict, abc as container_abcs
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
|
11 |
+
class DyLoraModule(nn.Module):
|
12 |
+
"""
|
13 |
+
Hadamard product Implementaion for Dynamic Low Rank adaptation
|
14 |
+
"""
|
15 |
+
|
16 |
+
def __init__(
|
17 |
+
self,
|
18 |
+
lora_name,
|
19 |
+
org_module: nn.Module,
|
20 |
+
multiplier=1.0,
|
21 |
+
lora_dim=4, alpha=1,
|
22 |
+
dropout=0.,
|
23 |
+
use_cp=False,
|
24 |
+
block_size=1,
|
25 |
+
**kwargs,
|
26 |
+
):
|
27 |
+
""" if alpha == 0 or None, alpha is rank (no scaling). """
|
28 |
+
super().__init__()
|
29 |
+
self.lora_name = lora_name
|
30 |
+
self.lora_dim = lora_dim
|
31 |
+
assert lora_dim % block_size == 0, 'lora_dim must be a multiple of block_size'
|
32 |
+
self.block_count = lora_dim//block_size
|
33 |
+
self.block_size = block_size
|
34 |
+
|
35 |
+
self.shape = org_module.weight.shape
|
36 |
+
if org_module.__class__.__name__ == 'Conv2d':
|
37 |
+
in_dim = org_module.in_channels
|
38 |
+
k_size = org_module.kernel_size
|
39 |
+
out_dim = org_module.out_channels
|
40 |
+
shape = (out_dim, in_dim*k_size[0]*k_size[1])
|
41 |
+
self.op = F.conv2d
|
42 |
+
self.extra_args = {
|
43 |
+
"stride": org_module.stride,
|
44 |
+
"padding": org_module.padding,
|
45 |
+
"dilation": org_module.dilation,
|
46 |
+
"groups": org_module.groups
|
47 |
+
}
|
48 |
+
else:
|
49 |
+
in_dim = org_module.in_features
|
50 |
+
out_dim = org_module.out_features
|
51 |
+
shape = (out_dim, in_dim)
|
52 |
+
self.op = F.linear
|
53 |
+
self.extra_args = {}
|
54 |
+
|
55 |
+
self.lora_dim = lora_dim
|
56 |
+
self.up_list = nn.ParameterList([
|
57 |
+
torch.empty(shape[0], 1)
|
58 |
+
for i in range(lora_dim)
|
59 |
+
])
|
60 |
+
self.up_list.requires_grad_(False)
|
61 |
+
self.up_update = [
|
62 |
+
torch.zeros_like(self.up_list[i])
|
63 |
+
for i in range(lora_dim)
|
64 |
+
]
|
65 |
+
|
66 |
+
self.down_list = nn.ParameterList([
|
67 |
+
torch.empty(1, shape[1])
|
68 |
+
for i in range(lora_dim)
|
69 |
+
])
|
70 |
+
self.down_list.requires_grad_(False)
|
71 |
+
self.down_update = [
|
72 |
+
torch.zeros_like(self.down_list[i])
|
73 |
+
for i in range(lora_dim)
|
74 |
+
]
|
75 |
+
|
76 |
+
self.index = 0
|
77 |
+
|
78 |
+
if type(alpha) == torch.Tensor:
|
79 |
+
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
|
80 |
+
alpha = lora_dim if alpha is None or alpha == 0 else alpha
|
81 |
+
self.scale = alpha / self.lora_dim
|
82 |
+
self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える
|
83 |
+
|
84 |
+
# Need more experiences on init method
|
85 |
+
|
86 |
+
for v in self.down_list:
|
87 |
+
torch.nn.init.kaiming_uniform_(v, a=math.sqrt(5))
|
88 |
+
for v in self.up_list:
|
89 |
+
torch.nn.init.zeros_(v)
|
90 |
+
for i, v in enumerate(self.up_update):
|
91 |
+
v.copy_(self.up_list[i])
|
92 |
+
for i, v in enumerate(self.down_update):
|
93 |
+
v.copy_(self.down_list[i])
|
94 |
+
|
95 |
+
self.multiplier = multiplier
|
96 |
+
self.org_module = [org_module] # remove in applying
|
97 |
+
self.grad_ckpt = False
|
98 |
+
|
99 |
+
self.apply_train(0)
|
100 |
+
|
101 |
+
def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
|
102 |
+
# TODO: Remove `args` and the parsing logic when BC allows.
|
103 |
+
if len(args) > 0:
|
104 |
+
if destination is None:
|
105 |
+
destination = args[0]
|
106 |
+
if len(args) > 1 and prefix == '':
|
107 |
+
prefix = args[1]
|
108 |
+
if len(args) > 2 and keep_vars is False:
|
109 |
+
keep_vars = args[2]
|
110 |
+
# DeprecationWarning is ignored by default
|
111 |
+
|
112 |
+
if destination is None:
|
113 |
+
destination = OrderedDict()
|
114 |
+
destination._metadata = OrderedDict()
|
115 |
+
|
116 |
+
local_metadata = dict(version=self._version)
|
117 |
+
if hasattr(destination, "_metadata"):
|
118 |
+
destination._metadata[prefix[:-1]] = local_metadata
|
119 |
+
|
120 |
+
destination[f'{prefix}alpha'] = self.alpha
|
121 |
+
destination[f'{prefix}lora_up.weight'] = nn.Parameter(
|
122 |
+
torch.concat(self.up_update, dim=1)
|
123 |
+
)
|
124 |
+
destination[f'{prefix}lora_down.weight'] = nn.Parameter(
|
125 |
+
torch.concat(self.down_update)
|
126 |
+
)
|
127 |
+
return destination
|
128 |
+
|
129 |
+
def apply_to(self):
|
130 |
+
self.org_module[0].forward = self.forward
|
131 |
+
|
132 |
+
def apply_train(self, b:int):
|
133 |
+
self.up_list.requires_grad_(False)
|
134 |
+
self.down_list.requires_grad_(False)
|
135 |
+
|
136 |
+
for i in range(self.index*self.block_size, (self.index+1)*self.block_size):
|
137 |
+
self.up_update[i].copy_(self.up_list[i])
|
138 |
+
self.down_update[i].copy_(self.down_list[i])
|
139 |
+
|
140 |
+
for i in range(b*self.block_size, (b+1)*self.block_size):
|
141 |
+
self.up_list[i].copy_(self.up_update[i])
|
142 |
+
self.down_list[i].copy_(self.down_update[i])
|
143 |
+
|
144 |
+
self.up_list.requires_grad_(True)
|
145 |
+
self.down_list.requires_grad_(True)
|
146 |
+
self.index = b
|
147 |
+
|
148 |
+
@torch.enable_grad()
|
149 |
+
def forward(self, x):
|
150 |
+
b = random.randint(0, self.block_count-1)
|
151 |
+
if self.up_update[b].device != self.up_list[b].device:
|
152 |
+
device = self.up_list[b].device
|
153 |
+
for i in range(self.lora_dim):
|
154 |
+
self.up_update[i] = self.up_update[i].to(device)
|
155 |
+
self.down_update[i] = self.down_update[i].to(device)
|
156 |
+
|
157 |
+
if self.training:
|
158 |
+
self.apply_train(b)
|
159 |
+
down = torch.concat(
|
160 |
+
list(self.down_update[:b*self.block_size])
|
161 |
+
+ list(self.down_list[b*self.block_size:(b+1)*self.block_size])
|
162 |
+
)
|
163 |
+
up = torch.concat(
|
164 |
+
list(self.up_update[:b*self.block_size])
|
165 |
+
+ list(self.up_list[b*self.block_size:(b+1)*self.block_size]),
|
166 |
+
dim=1
|
167 |
+
)
|
168 |
+
|
169 |
+
bias = None if self.org_module[0].bias is None else self.org_module[0].bias.data
|
170 |
+
return self.op(
|
171 |
+
x,
|
172 |
+
self.org_module[0].weight + (up@down).view(self.shape) * self.alpha/(b+1),
|
173 |
+
bias,
|
174 |
+
**self.extra_args
|
175 |
+
)
|
lycoris/ia3.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
|
8 |
+
class IA3Module(nn.Module):
|
9 |
+
"""
|
10 |
+
Hadamard product Implementaion for Low Rank Adaptation
|
11 |
+
"""
|
12 |
+
|
13 |
+
def __init__(
|
14 |
+
self,
|
15 |
+
lora_name,
|
16 |
+
org_module: nn.Module,
|
17 |
+
multiplier=1.0,
|
18 |
+
train_on_input=False,
|
19 |
+
**kwargs
|
20 |
+
):
|
21 |
+
""" if alpha == 0 or None, alpha is rank (no scaling). """
|
22 |
+
super().__init__()
|
23 |
+
self.lora_name = lora_name
|
24 |
+
self.cp=False
|
25 |
+
|
26 |
+
self.shape = org_module.weight.shape
|
27 |
+
if org_module.__class__.__name__ == 'Conv2d':
|
28 |
+
in_dim = org_module.in_channels
|
29 |
+
out_dim = org_module.out_channels
|
30 |
+
if train_on_input:
|
31 |
+
train_dim = in_dim
|
32 |
+
else:
|
33 |
+
train_dim = out_dim
|
34 |
+
self.weight = nn.Parameter(torch.empty(1, train_dim, 1, 1))
|
35 |
+
else:
|
36 |
+
in_dim = org_module.in_features
|
37 |
+
out_dim = org_module.out_features
|
38 |
+
if train_on_input:
|
39 |
+
train_dim = in_dim
|
40 |
+
else:
|
41 |
+
train_dim = out_dim
|
42 |
+
|
43 |
+
self.weight = nn.Parameter(torch.empty(train_dim))
|
44 |
+
|
45 |
+
# Need more experiences on init method
|
46 |
+
torch.nn.init.constant_(self.weight, 0)
|
47 |
+
|
48 |
+
self.multiplier = multiplier
|
49 |
+
self.org_forward = None
|
50 |
+
self.org_module = [org_module] # remove in applying
|
51 |
+
self.grad_ckpt = False
|
52 |
+
self.train_input = train_on_input
|
53 |
+
self.register_buffer('on_input', torch.tensor(int(train_on_input)))
|
54 |
+
|
55 |
+
def apply_to(self):
|
56 |
+
self.org_forward = self.org_module[0].forward
|
57 |
+
self.org_module[0].forward = self.forward
|
58 |
+
|
59 |
+
@torch.enable_grad()
|
60 |
+
def forward(self, x):
|
61 |
+
if self.train_input:
|
62 |
+
x = x * (1 + self.weight * self.multiplier)
|
63 |
+
out = self.org_forward(x)
|
64 |
+
dtype = out.dtype
|
65 |
+
if not self.train_input:
|
66 |
+
out = out * (1 + self.weight * self.multiplier)
|
67 |
+
out = out.to(dtype)
|
68 |
+
return out
|
lycoris/kohya.py
CHANGED
@@ -13,6 +13,9 @@ import torch
|
|
13 |
from .kohya_utils import *
|
14 |
from .locon import LoConModule
|
15 |
from .loha import LohaModule
|
|
|
|
|
|
|
16 |
|
17 |
|
18 |
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
|
@@ -21,39 +24,55 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
|
|
21 |
conv_dim = int(kwargs.get('conv_dim', network_dim))
|
22 |
conv_alpha = float(kwargs.get('conv_alpha', network_alpha))
|
23 |
dropout = float(kwargs.get('dropout', 0.))
|
24 |
-
algo = kwargs.get('algo', 'lora')
|
25 |
-
|
|
|
|
|
26 |
network_module = {
|
27 |
'lora': LoConModule,
|
|
|
28 |
'loha': LohaModule,
|
|
|
|
|
|
|
29 |
}[algo]
|
30 |
|
31 |
print(f'Using rank adaptation algo: {algo}')
|
32 |
|
33 |
-
if (algo == 'loha'
|
34 |
and not kwargs.get('no_dim_warn', False)
|
35 |
and (network_dim>64 or conv_dim>64)):
|
36 |
print('='*20 + 'WARNING' + '='*20)
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
|
|
46 |
print('='*20 + 'WARNING' + '='*20)
|
47 |
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
|
58 |
return network
|
59 |
|
@@ -86,8 +105,9 @@ class LycorisNetwork(torch.nn.Module):
|
|
86 |
multiplier=1.0,
|
87 |
lora_dim=4, conv_lora_dim=4,
|
88 |
alpha=1, conv_alpha=1,
|
89 |
-
use_cp =
|
90 |
dropout = 0, network_module = LoConModule,
|
|
|
91 |
) -> None:
|
92 |
super().__init__()
|
93 |
self.multiplier = multiplier
|
@@ -124,19 +144,25 @@ class LycorisNetwork(torch.nn.Module):
|
|
124 |
if child_module.__class__.__name__ == 'Linear' and lora_dim>0:
|
125 |
lora = network_module(
|
126 |
lora_name, child_module, self.multiplier,
|
127 |
-
self.lora_dim, self.alpha,
|
|
|
|
|
128 |
)
|
129 |
elif child_module.__class__.__name__ == 'Conv2d':
|
130 |
k_size, *_ = child_module.kernel_size
|
131 |
if k_size==1 and lora_dim>0:
|
132 |
lora = network_module(
|
133 |
lora_name, child_module, self.multiplier,
|
134 |
-
self.lora_dim, self.alpha,
|
|
|
|
|
135 |
)
|
136 |
elif conv_lora_dim>0:
|
137 |
lora = network_module(
|
138 |
lora_name, child_module, self.multiplier,
|
139 |
-
self.conv_lora_dim, self.conv_alpha,
|
|
|
|
|
140 |
)
|
141 |
else:
|
142 |
continue
|
@@ -149,19 +175,25 @@ class LycorisNetwork(torch.nn.Module):
|
|
149 |
if module.__class__.__name__ == 'Linear' and lora_dim>0:
|
150 |
lora = network_module(
|
151 |
lora_name, module, self.multiplier,
|
152 |
-
self.lora_dim, self.alpha,
|
|
|
|
|
153 |
)
|
154 |
elif module.__class__.__name__ == 'Conv2d':
|
155 |
k_size, *_ = module.kernel_size
|
156 |
if k_size==1 and lora_dim>0:
|
157 |
lora = network_module(
|
158 |
lora_name, module, self.multiplier,
|
159 |
-
self.lora_dim, self.alpha,
|
|
|
|
|
160 |
)
|
161 |
elif conv_lora_dim>0:
|
162 |
lora = network_module(
|
163 |
lora_name, module, self.multiplier,
|
164 |
-
self.conv_lora_dim, self.conv_alpha,
|
|
|
|
|
165 |
)
|
166 |
else:
|
167 |
continue
|
@@ -306,3 +338,205 @@ class LycorisNetwork(torch.nn.Module):
|
|
306 |
save_file(state_dict, file, metadata)
|
307 |
else:
|
308 |
torch.save(state_dict, file)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
from .kohya_utils import *
|
14 |
from .locon import LoConModule
|
15 |
from .loha import LohaModule
|
16 |
+
from .ia3 import IA3Module
|
17 |
+
from .lokr import LokrModule
|
18 |
+
from .dylora import DyLoraModule
|
19 |
|
20 |
|
21 |
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
|
|
|
24 |
conv_dim = int(kwargs.get('conv_dim', network_dim))
|
25 |
conv_alpha = float(kwargs.get('conv_alpha', network_alpha))
|
26 |
dropout = float(kwargs.get('dropout', 0.))
|
27 |
+
algo = kwargs.get('algo', 'lora').lower()
|
28 |
+
use_cp = (not kwargs.get('disable_conv_cp', True)
|
29 |
+
or kwargs.get('use_conv_cp', False))
|
30 |
+
block_size = int(kwargs.get('block_size', 4))
|
31 |
network_module = {
|
32 |
'lora': LoConModule,
|
33 |
+
'locon': LoConModule,
|
34 |
'loha': LohaModule,
|
35 |
+
'ia3': IA3Module,
|
36 |
+
'lokr': LokrModule,
|
37 |
+
'dylora': DyLoraModule,
|
38 |
}[algo]
|
39 |
|
40 |
print(f'Using rank adaptation algo: {algo}')
|
41 |
|
42 |
+
if ((algo == 'loha' or algo == 'lokr')
|
43 |
and not kwargs.get('no_dim_warn', False)
|
44 |
and (network_dim>64 or conv_dim>64)):
|
45 |
print('='*20 + 'WARNING' + '='*20)
|
46 |
+
warning_type ={
|
47 |
+
'loha': "Hadamard Product representation",
|
48 |
+
'lokr': "Kronecker Product representation"
|
49 |
+
}
|
50 |
+
warning_msg = f"""You are not supposed to use dim>64 (64*64 = 4096, it already has enough rank)\n
|
51 |
+
in {warning_type[algo]}!\n
|
52 |
+
Please consider use lower dim or disable this warning with --network_args no_dim_warn=True\n
|
53 |
+
If you just want to use high dim {algo}, please consider use lower lr.
|
54 |
+
"""
|
55 |
+
warn(warning_msg, stacklevel=2)
|
56 |
print('='*20 + 'WARNING' + '='*20)
|
57 |
|
58 |
+
if algo == 'ia3':
|
59 |
+
network = IA3Network(
|
60 |
+
text_encoder, unet,
|
61 |
+
multiplier = multiplier,
|
62 |
+
)
|
63 |
+
else:
|
64 |
+
network = LycorisNetwork(
|
65 |
+
text_encoder, unet,
|
66 |
+
multiplier=multiplier,
|
67 |
+
lora_dim=network_dim, conv_lora_dim=conv_dim,
|
68 |
+
alpha=network_alpha, conv_alpha=conv_alpha,
|
69 |
+
dropout=dropout,
|
70 |
+
use_cp=use_cp,
|
71 |
+
network_module=network_module,
|
72 |
+
decompose_both=kwargs.get('decompose_both', False),
|
73 |
+
factor=kwargs.get('factor', -1),
|
74 |
+
block_size = block_size
|
75 |
+
)
|
76 |
|
77 |
return network
|
78 |
|
|
|
105 |
multiplier=1.0,
|
106 |
lora_dim=4, conv_lora_dim=4,
|
107 |
alpha=1, conv_alpha=1,
|
108 |
+
use_cp = False,
|
109 |
dropout = 0, network_module = LoConModule,
|
110 |
+
**kwargs,
|
111 |
) -> None:
|
112 |
super().__init__()
|
113 |
self.multiplier = multiplier
|
|
|
144 |
if child_module.__class__.__name__ == 'Linear' and lora_dim>0:
|
145 |
lora = network_module(
|
146 |
lora_name, child_module, self.multiplier,
|
147 |
+
self.lora_dim, self.alpha,
|
148 |
+
self.dropout, use_cp,
|
149 |
+
**kwargs
|
150 |
)
|
151 |
elif child_module.__class__.__name__ == 'Conv2d':
|
152 |
k_size, *_ = child_module.kernel_size
|
153 |
if k_size==1 and lora_dim>0:
|
154 |
lora = network_module(
|
155 |
lora_name, child_module, self.multiplier,
|
156 |
+
self.lora_dim, self.alpha,
|
157 |
+
self.dropout, use_cp,
|
158 |
+
**kwargs
|
159 |
)
|
160 |
elif conv_lora_dim>0:
|
161 |
lora = network_module(
|
162 |
lora_name, child_module, self.multiplier,
|
163 |
+
self.conv_lora_dim, self.conv_alpha,
|
164 |
+
self.dropout, use_cp,
|
165 |
+
**kwargs
|
166 |
)
|
167 |
else:
|
168 |
continue
|
|
|
175 |
if module.__class__.__name__ == 'Linear' and lora_dim>0:
|
176 |
lora = network_module(
|
177 |
lora_name, module, self.multiplier,
|
178 |
+
self.lora_dim, self.alpha,
|
179 |
+
self.dropout, use_cp,
|
180 |
+
**kwargs
|
181 |
)
|
182 |
elif module.__class__.__name__ == 'Conv2d':
|
183 |
k_size, *_ = module.kernel_size
|
184 |
if k_size==1 and lora_dim>0:
|
185 |
lora = network_module(
|
186 |
lora_name, module, self.multiplier,
|
187 |
+
self.lora_dim, self.alpha,
|
188 |
+
self.dropout, use_cp,
|
189 |
+
**kwargs
|
190 |
)
|
191 |
elif conv_lora_dim>0:
|
192 |
lora = network_module(
|
193 |
lora_name, module, self.multiplier,
|
194 |
+
self.conv_lora_dim, self.conv_alpha,
|
195 |
+
self.dropout, use_cp,
|
196 |
+
**kwargs
|
197 |
)
|
198 |
else:
|
199 |
continue
|
|
|
338 |
save_file(state_dict, file, metadata)
|
339 |
else:
|
340 |
torch.save(state_dict, file)
|
341 |
+
|
342 |
+
|
343 |
+
class IA3Network(torch.nn.Module):
|
344 |
+
'''
|
345 |
+
IA3 network
|
346 |
+
'''
|
347 |
+
# Ignore proj_in or proj_out, their channels is only a few.
|
348 |
+
UNET_TARGET_REPLACE_MODULE = []
|
349 |
+
UNET_TARGET_REPLACE_NAME = ["to_k", "to_v", "ff.net.2"]
|
350 |
+
TEXT_ENCODER_TARGET_REPLACE_MODULE = []
|
351 |
+
TEXT_ENCODER_TARGET_REPLACE_NAME= ["k_proj", "v_proj", "mlp.fc2"]
|
352 |
+
TRAIN_INPUT = ["mlp.fc2", "ff.net.2"]
|
353 |
+
LORA_PREFIX_UNET = 'lora_unet'
|
354 |
+
LORA_PREFIX_TEXT_ENCODER = 'lora_te'
|
355 |
+
|
356 |
+
def __init__(
|
357 |
+
self,
|
358 |
+
text_encoder, unet,
|
359 |
+
multiplier=1.0,
|
360 |
+
**kwargs,
|
361 |
+
) -> None:
|
362 |
+
super().__init__()
|
363 |
+
self.multiplier = multiplier
|
364 |
+
|
365 |
+
# create module instances
|
366 |
+
def create_modules(
|
367 |
+
prefix,
|
368 |
+
root_module: torch.nn.Module,
|
369 |
+
target_replace_modules,
|
370 |
+
target_replace_names = [],
|
371 |
+
target_train_input = []
|
372 |
+
) -> List[IA3Module]:
|
373 |
+
print('Create LyCORIS Module')
|
374 |
+
loras = []
|
375 |
+
for name, module in root_module.named_modules():
|
376 |
+
if module.__class__.__name__ in target_replace_modules:
|
377 |
+
for child_name, child_module in module.named_modules():
|
378 |
+
lora_name = prefix + '.' + name + '.' + child_name
|
379 |
+
lora_name = lora_name.replace('.', '_')
|
380 |
+
if child_module.__class__.__name__ in {'Linear', 'Conv2d'}:
|
381 |
+
lora = IA3Module(
|
382 |
+
lora_name, child_module, self.multiplier,
|
383 |
+
name in target_train_input,
|
384 |
+
**kwargs,
|
385 |
+
)
|
386 |
+
loras.append(lora)
|
387 |
+
elif any(i in name for i in target_replace_names):
|
388 |
+
lora_name = prefix + '.' + name
|
389 |
+
lora_name = lora_name.replace('.', '_')
|
390 |
+
if module.__class__.__name__ in {'Linear', 'Conv2d'}:
|
391 |
+
lora = IA3Module(
|
392 |
+
lora_name, module, self.multiplier,
|
393 |
+
name in target_train_input,
|
394 |
+
**kwargs,
|
395 |
+
)
|
396 |
+
loras.append(lora)
|
397 |
+
return loras
|
398 |
+
|
399 |
+
self.text_encoder_loras = create_modules(
|
400 |
+
IA3Network.LORA_PREFIX_TEXT_ENCODER,
|
401 |
+
text_encoder,
|
402 |
+
IA3Network.TEXT_ENCODER_TARGET_REPLACE_MODULE,
|
403 |
+
IA3Network.TEXT_ENCODER_TARGET_REPLACE_NAME,
|
404 |
+
IA3Network.TRAIN_INPUT
|
405 |
+
)
|
406 |
+
print(f"create LyCORIS for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
407 |
+
|
408 |
+
self.unet_loras = create_modules(
|
409 |
+
IA3Network.LORA_PREFIX_UNET,
|
410 |
+
unet,
|
411 |
+
IA3Network.UNET_TARGET_REPLACE_MODULE,
|
412 |
+
IA3Network.UNET_TARGET_REPLACE_NAME,
|
413 |
+
IA3Network.TRAIN_INPUT
|
414 |
+
)
|
415 |
+
print(f"create LyCORIS for U-Net: {len(self.unet_loras)} modules.")
|
416 |
+
|
417 |
+
self.weights_sd = None
|
418 |
+
|
419 |
+
# assertion
|
420 |
+
names = set()
|
421 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
422 |
+
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
|
423 |
+
names.add(lora.lora_name)
|
424 |
+
|
425 |
+
def set_multiplier(self, multiplier):
|
426 |
+
self.multiplier = multiplier
|
427 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
428 |
+
lora.multiplier = self.multiplier
|
429 |
+
|
430 |
+
def load_weights(self, file):
|
431 |
+
if os.path.splitext(file)[1] == '.safetensors':
|
432 |
+
from safetensors.torch import load_file, safe_open
|
433 |
+
self.weights_sd = load_file(file)
|
434 |
+
else:
|
435 |
+
self.weights_sd = torch.load(file, map_location='cpu')
|
436 |
+
|
437 |
+
def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None):
|
438 |
+
if self.weights_sd:
|
439 |
+
weights_has_text_encoder = weights_has_unet = False
|
440 |
+
for key in self.weights_sd.keys():
|
441 |
+
if key.startswith(LycorisNetwork.LORA_PREFIX_TEXT_ENCODER):
|
442 |
+
weights_has_text_encoder = True
|
443 |
+
elif key.startswith(LycorisNetwork.LORA_PREFIX_UNET):
|
444 |
+
weights_has_unet = True
|
445 |
+
|
446 |
+
if apply_text_encoder is None:
|
447 |
+
apply_text_encoder = weights_has_text_encoder
|
448 |
+
else:
|
449 |
+
assert apply_text_encoder == weights_has_text_encoder, f"text encoder weights: {weights_has_text_encoder} but text encoder flag: {apply_text_encoder} / 重みとText Encoderのフラグが矛盾しています"
|
450 |
+
|
451 |
+
if apply_unet is None:
|
452 |
+
apply_unet = weights_has_unet
|
453 |
+
else:
|
454 |
+
assert apply_unet == weights_has_unet, f"u-net weights: {weights_has_unet} but u-net flag: {apply_unet} / 重みとU-Netのフラグが矛盾しています"
|
455 |
+
else:
|
456 |
+
assert apply_text_encoder is not None and apply_unet is not None, f"internal error: flag not set"
|
457 |
+
|
458 |
+
if apply_text_encoder:
|
459 |
+
print("enable LyCORIS for text encoder")
|
460 |
+
else:
|
461 |
+
self.text_encoder_loras = []
|
462 |
+
|
463 |
+
if apply_unet:
|
464 |
+
print("enable LyCORIS for U-Net")
|
465 |
+
else:
|
466 |
+
self.unet_loras = []
|
467 |
+
|
468 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
469 |
+
lora.apply_to()
|
470 |
+
self.add_module(lora.lora_name, lora)
|
471 |
+
|
472 |
+
if self.weights_sd:
|
473 |
+
# if some weights are not in state dict, it is ok because initial LoRA does nothing (lora_up is initialized by zeros)
|
474 |
+
info = self.load_state_dict(self.weights_sd, False)
|
475 |
+
print(f"weights are loaded: {info}")
|
476 |
+
|
477 |
+
def enable_gradient_checkpointing(self):
|
478 |
+
# not supported
|
479 |
+
def make_ckpt(module):
|
480 |
+
if isinstance(module, torch.nn.Module):
|
481 |
+
module.grad_ckpt = True
|
482 |
+
self.apply(make_ckpt)
|
483 |
+
pass
|
484 |
+
|
485 |
+
def prepare_optimizer_params(self, text_encoder_lr, unet_lr):
|
486 |
+
def enumerate_params(loras):
|
487 |
+
params = []
|
488 |
+
for lora in loras:
|
489 |
+
params.extend(lora.parameters())
|
490 |
+
return params
|
491 |
+
|
492 |
+
self.requires_grad_(True)
|
493 |
+
all_params = []
|
494 |
+
|
495 |
+
if self.text_encoder_loras:
|
496 |
+
param_data = {'params': enumerate_params(self.text_encoder_loras)}
|
497 |
+
if text_encoder_lr is not None:
|
498 |
+
param_data['lr'] = text_encoder_lr
|
499 |
+
all_params.append(param_data)
|
500 |
+
|
501 |
+
if self.unet_loras:
|
502 |
+
param_data = {'params': enumerate_params(self.unet_loras)}
|
503 |
+
if unet_lr is not None:
|
504 |
+
param_data['lr'] = unet_lr
|
505 |
+
all_params.append(param_data)
|
506 |
+
|
507 |
+
return all_params
|
508 |
+
|
509 |
+
def prepare_grad_etc(self, text_encoder, unet):
|
510 |
+
self.requires_grad_(True)
|
511 |
+
|
512 |
+
def on_epoch_start(self, text_encoder, unet):
|
513 |
+
self.train()
|
514 |
+
|
515 |
+
def get_trainable_params(self):
|
516 |
+
return self.parameters()
|
517 |
+
|
518 |
+
def save_weights(self, file, dtype, metadata):
|
519 |
+
if metadata is not None and len(metadata) == 0:
|
520 |
+
metadata = None
|
521 |
+
|
522 |
+
state_dict = self.state_dict()
|
523 |
+
|
524 |
+
if dtype is not None:
|
525 |
+
for key in list(state_dict.keys()):
|
526 |
+
v = state_dict[key]
|
527 |
+
v = v.detach().clone().to("cpu").to(dtype)
|
528 |
+
state_dict[key] = v
|
529 |
+
|
530 |
+
if os.path.splitext(file)[1] == '.safetensors':
|
531 |
+
from safetensors.torch import save_file
|
532 |
+
|
533 |
+
# Precalculate model hashes to save time on indexing
|
534 |
+
if metadata is None:
|
535 |
+
metadata = {}
|
536 |
+
model_hash, legacy_hash = precalculate_safetensors_hashes(state_dict, metadata)
|
537 |
+
metadata["sshs_model_hash"] = model_hash
|
538 |
+
metadata["sshs_legacy_hash"] = legacy_hash
|
539 |
+
|
540 |
+
save_file(state_dict, file, metadata)
|
541 |
+
else:
|
542 |
+
torch.save(state_dict, file)
|
lycoris/kohya_model_utils.py
CHANGED
@@ -1,13 +1,10 @@
|
|
1 |
-
'''
|
2 |
-
https://github.com/kohya-ss/sd-scripts/blob/main/library/model_util.py
|
3 |
-
'''
|
4 |
# v1: split from train_db_fixed.py.
|
5 |
# v2: support safetensors
|
6 |
|
7 |
import math
|
8 |
import os
|
9 |
import torch
|
10 |
-
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig
|
11 |
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
|
12 |
from safetensors.torch import load_file, save_file
|
13 |
|
@@ -19,7 +16,7 @@ BETA_END = 0.0120
|
|
19 |
UNET_PARAMS_MODEL_CHANNELS = 320
|
20 |
UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
|
21 |
UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
|
22 |
-
UNET_PARAMS_IMAGE_SIZE =
|
23 |
UNET_PARAMS_IN_CHANNELS = 4
|
24 |
UNET_PARAMS_OUT_CHANNELS = 4
|
25 |
UNET_PARAMS_NUM_RES_BLOCKS = 2
|
@@ -48,596 +45,574 @@ DIFFUSERS_REF_MODEL_ID_V2 = "stabilityai/stable-diffusion-2-1"
|
|
48 |
|
49 |
|
50 |
def shave_segments(path, n_shave_prefix_segments=1):
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
|
59 |
|
60 |
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
|
69 |
-
|
70 |
-
|
71 |
|
72 |
-
|
73 |
-
|
74 |
|
75 |
-
|
76 |
|
77 |
-
|
78 |
|
79 |
-
|
80 |
|
81 |
|
82 |
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
|
90 |
-
|
91 |
-
|
92 |
|
93 |
-
|
94 |
|
95 |
-
|
96 |
|
97 |
|
98 |
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
|
106 |
-
|
107 |
-
|
108 |
|
109 |
-
|
110 |
-
|
111 |
|
112 |
-
|
113 |
|
114 |
-
|
115 |
|
116 |
-
|
117 |
|
118 |
|
119 |
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
|
127 |
-
|
128 |
-
|
129 |
|
130 |
-
|
131 |
-
|
132 |
|
133 |
-
|
134 |
-
|
135 |
|
136 |
-
|
137 |
-
|
138 |
|
139 |
-
|
140 |
-
|
141 |
|
142 |
-
|
143 |
|
144 |
-
|
145 |
|
146 |
-
|
147 |
|
148 |
|
149 |
def assign_to_checkpoint(
|
150 |
paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
|
151 |
):
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
|
167 |
-
|
168 |
|
169 |
-
|
170 |
|
171 |
-
|
172 |
-
|
173 |
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
|
178 |
-
|
179 |
-
|
180 |
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
|
200 |
|
201 |
def conv_attn_to_linear(checkpoint):
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
|
212 |
|
213 |
def linear_transformer_to_conv(checkpoint):
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
|
221 |
|
222 |
def convert_ldm_unet_checkpoint(v2, checkpoint, config):
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
-
|
365 |
-
|
366 |
-
"old": f"output_blocks.{i}.1",
|
367 |
-
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
|
368 |
-
}
|
369 |
-
assign_to_checkpoint(
|
370 |
-
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
371 |
-
)
|
372 |
-
else:
|
373 |
-
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
|
374 |
-
for path in resnet_0_paths:
|
375 |
-
old_path = ".".join(["output_blocks", str(i), path["old"]])
|
376 |
-
new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
|
377 |
-
|
378 |
-
new_checkpoint[new_path] = unet_state_dict[old_path]
|
379 |
-
|
380 |
-
# SDのv2では1*1のconv2dがlinearに変わっているので、linear->convに変換する
|
381 |
-
if v2:
|
382 |
-
linear_transformer_to_conv(new_checkpoint)
|
383 |
|
384 |
-
|
385 |
|
386 |
|
387 |
def convert_ldm_vae_checkpoint(checkpoint, config):
|
388 |
-
|
389 |
-
|
390 |
-
|
391 |
-
|
392 |
-
|
393 |
-
|
394 |
-
|
395 |
-
|
396 |
-
|
397 |
-
|
398 |
-
|
399 |
-
|
400 |
-
|
401 |
-
|
402 |
-
|
403 |
-
|
404 |
-
|
405 |
-
|
406 |
-
|
407 |
-
|
408 |
-
|
409 |
-
|
410 |
-
|
411 |
-
|
412 |
-
|
413 |
-
|
414 |
-
|
415 |
-
|
416 |
-
|
417 |
-
|
418 |
-
|
419 |
-
|
420 |
-
|
421 |
-
|
422 |
-
|
423 |
-
|
424 |
-
|
425 |
-
|
426 |
-
|
427 |
-
|
428 |
-
|
429 |
-
|
430 |
-
|
431 |
-
|
432 |
-
|
433 |
-
|
434 |
-
|
435 |
-
|
436 |
-
|
437 |
-
|
438 |
-
|
439 |
-
|
440 |
-
|
441 |
-
|
442 |
-
|
443 |
-
|
444 |
-
|
445 |
-
|
446 |
-
|
447 |
-
|
448 |
-
|
449 |
-
|
450 |
-
|
451 |
-
|
452 |
-
|
453 |
-
|
|
|
454 |
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
455 |
-
|
456 |
-
|
457 |
-
|
458 |
-
|
459 |
-
|
460 |
-
|
461 |
-
|
462 |
-
|
463 |
-
|
464 |
-
|
465 |
-
|
466 |
-
|
467 |
-
|
468 |
-
|
469 |
-
|
470 |
-
|
471 |
-
|
472 |
-
|
473 |
-
|
474 |
-
|
475 |
-
|
476 |
-
|
477 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
478 |
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
479 |
-
|
480 |
-
|
481 |
-
num_mid_res_blocks = 2
|
482 |
-
for i in range(1, num_mid_res_blocks + 1):
|
483 |
-
resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
|
484 |
-
|
485 |
-
paths = renew_vae_resnet_paths(resnets)
|
486 |
-
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
487 |
-
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
488 |
-
|
489 |
-
mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
|
490 |
-
paths = renew_vae_attention_paths(mid_attentions)
|
491 |
-
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
492 |
-
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
493 |
-
conv_attn_to_linear(new_checkpoint)
|
494 |
-
return new_checkpoint
|
495 |
|
496 |
|
497 |
def create_unet_diffusers_config(v2):
|
498 |
-
|
499 |
-
|
500 |
-
|
501 |
-
|
502 |
-
|
503 |
-
|
504 |
-
|
505 |
-
|
506 |
-
|
507 |
-
|
508 |
-
|
509 |
-
|
510 |
-
|
511 |
-
|
512 |
-
|
513 |
-
|
514 |
-
|
515 |
-
|
516 |
-
|
517 |
-
|
518 |
-
|
519 |
-
|
520 |
-
|
521 |
-
|
522 |
-
|
523 |
-
|
524 |
-
|
525 |
-
|
526 |
-
|
527 |
-
|
528 |
-
|
529 |
-
|
530 |
-
|
531 |
-
|
532 |
|
533 |
|
534 |
def create_vae_diffusers_config():
|
535 |
-
|
536 |
-
|
537 |
-
|
538 |
-
|
539 |
-
|
540 |
-
|
541 |
-
|
542 |
-
|
543 |
-
|
544 |
-
|
545 |
-
|
546 |
-
|
547 |
-
|
548 |
-
|
549 |
-
|
550 |
-
|
551 |
-
|
552 |
-
|
553 |
-
|
554 |
-
|
555 |
|
556 |
|
557 |
def convert_ldm_clip_checkpoint_v1(checkpoint):
|
558 |
-
|
559 |
-
|
560 |
-
|
561 |
-
|
562 |
-
|
563 |
-
|
564 |
|
565 |
|
566 |
def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
|
567 |
-
|
568 |
-
|
569 |
-
|
570 |
-
|
571 |
-
|
572 |
-
|
573 |
-
|
574 |
-
|
575 |
-
|
576 |
-
|
577 |
-
|
578 |
-
|
579 |
-
|
580 |
-
|
581 |
-
|
582 |
-
|
583 |
-
|
584 |
-
|
585 |
-
|
586 |
-
|
587 |
-
|
588 |
-
|
589 |
-
|
590 |
-
|
591 |
-
|
592 |
-
|
593 |
-
|
594 |
-
|
595 |
-
|
596 |
-
|
597 |
-
|
598 |
-
|
599 |
-
|
600 |
-
|
601 |
-
|
602 |
-
|
603 |
-
|
604 |
-
|
605 |
-
|
606 |
-
|
607 |
-
|
608 |
-
|
609 |
-
|
610 |
-
|
611 |
-
|
612 |
-
|
613 |
-
|
614 |
-
|
615 |
-
|
616 |
-
|
617 |
-
|
618 |
-
|
619 |
-
|
620 |
-
|
621 |
-
|
622 |
-
|
623 |
-
|
624 |
-
|
625 |
-
|
626 |
-
|
627 |
-
|
628 |
-
|
629 |
-
|
630 |
-
|
631 |
-
|
632 |
-
|
633 |
-
|
634 |
-
|
635 |
-
|
636 |
-
|
637 |
-
|
638 |
-
|
639 |
-
|
640 |
-
|
|
|
641 |
|
642 |
# endregion
|
643 |
|
@@ -645,540 +620,546 @@ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
|
|
645 |
# region Diffusers->StableDiffusion の変換コード
|
646 |
# convert_diffusers_to_original_stable_diffusion をコピーして修正している(ASL 2.0)
|
647 |
|
|
|
648 |
def conv_transformer_to_linear(checkpoint):
|
649 |
-
|
650 |
-
|
651 |
-
|
652 |
-
|
653 |
-
|
654 |
-
|
655 |
|
656 |
|
657 |
def convert_unet_state_dict_to_sd(v2, unet_state_dict):
|
658 |
-
|
659 |
-
|
660 |
-
|
661 |
-
|
662 |
-
|
663 |
-
|
664 |
-
|
665 |
-
|
666 |
-
|
667 |
-
|
668 |
-
|
669 |
-
|
670 |
-
|
671 |
-
|
672 |
-
|
673 |
-
|
674 |
-
|
675 |
-
|
676 |
-
|
677 |
-
|
678 |
-
|
679 |
-
|
680 |
-
|
681 |
-
|
682 |
-
|
683 |
-
|
684 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
685 |
|
686 |
for j in range(2):
|
687 |
-
|
688 |
-
|
689 |
-
|
690 |
-
|
691 |
-
|
692 |
-
|
693 |
-
|
694 |
-
|
695 |
-
|
696 |
-
|
697 |
-
|
698 |
-
|
699 |
-
|
700 |
-
|
701 |
-
|
702 |
-
|
703 |
-
|
704 |
-
|
705 |
-
|
706 |
-
|
707 |
-
|
708 |
-
|
709 |
-
|
710 |
-
|
711 |
-
|
712 |
-
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
713 |
-
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
714 |
-
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
715 |
-
|
716 |
-
# no upsample in up_blocks.3
|
717 |
-
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
718 |
-
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
|
719 |
-
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
720 |
-
|
721 |
-
hf_mid_atn_prefix = "mid_block.attentions.0."
|
722 |
-
sd_mid_atn_prefix = "middle_block.1."
|
723 |
-
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
724 |
-
|
725 |
-
for j in range(2):
|
726 |
-
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
727 |
-
sd_mid_res_prefix = f"middle_block.{2*j}."
|
728 |
-
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
729 |
-
|
730 |
-
# buyer beware: this is a *brittle* function,
|
731 |
-
# and correct output requires that all of these pieces interact in
|
732 |
-
# the exact order in which I have arranged them.
|
733 |
-
mapping = {k: k for k in unet_state_dict.keys()}
|
734 |
-
for sd_name, hf_name in unet_conversion_map:
|
735 |
-
mapping[hf_name] = sd_name
|
736 |
-
for k, v in mapping.items():
|
737 |
-
if "resnets" in k:
|
738 |
-
for sd_part, hf_part in unet_conversion_map_resnet:
|
739 |
-
v = v.replace(hf_part, sd_part)
|
740 |
-
mapping[k] = v
|
741 |
-
for k, v in mapping.items():
|
742 |
-
for sd_part, hf_part in unet_conversion_map_layer:
|
743 |
-
v = v.replace(hf_part, sd_part)
|
744 |
-
mapping[k] = v
|
745 |
-
new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
|
746 |
-
|
747 |
-
if v2:
|
748 |
-
conv_transformer_to_linear(new_state_dict)
|
749 |
-
|
750 |
-
return new_state_dict
|
751 |
|
752 |
|
753 |
# ================#
|
754 |
# VAE Conversion #
|
755 |
# ================#
|
756 |
|
|
|
757 |
def reshape_weight_for_sd(w):
|
758 |
# convert HF linear weights to SD conv2d weights
|
759 |
-
|
760 |
|
761 |
|
762 |
def convert_vae_state_dict(vae_state_dict):
|
763 |
-
|
764 |
-
|
765 |
-
|
766 |
-
|
767 |
-
|
768 |
-
|
769 |
-
|
770 |
-
|
771 |
-
|
772 |
-
|
773 |
-
|
774 |
-
|
775 |
-
|
776 |
-
|
777 |
-
|
778 |
-
|
779 |
-
|
780 |
-
|
781 |
-
|
782 |
-
|
783 |
-
|
784 |
-
|
785 |
-
|
786 |
-
|
787 |
-
|
788 |
-
|
789 |
-
|
790 |
-
|
791 |
-
|
792 |
-
|
793 |
-
|
794 |
-
|
795 |
-
|
796 |
-
|
797 |
-
|
798 |
-
|
799 |
-
|
800 |
-
|
801 |
-
|
802 |
-
|
803 |
-
|
804 |
-
|
805 |
-
|
806 |
-
|
807 |
-
|
808 |
-
|
809 |
-
|
810 |
-
|
811 |
-
|
812 |
-
|
813 |
-
|
814 |
-
|
815 |
-
|
816 |
-
|
817 |
-
|
818 |
-
|
819 |
-
|
820 |
-
|
821 |
-
|
822 |
-
|
823 |
-
|
824 |
-
|
825 |
-
|
826 |
-
|
827 |
|
828 |
|
829 |
# endregion
|
830 |
|
831 |
# region 自作のモデル読み書きなど
|
832 |
|
|
|
833 |
def is_safetensors(path):
|
834 |
-
|
835 |
-
|
836 |
-
|
837 |
-
def load_checkpoint_with_text_encoder_conversion(ckpt_path):
|
838 |
-
|
839 |
-
|
840 |
-
|
841 |
-
|
842 |
-
|
843 |
-
|
844 |
-
|
845 |
-
|
846 |
-
|
847 |
-
|
848 |
-
else:
|
849 |
-
checkpoint = torch.load(ckpt_path, map_location="cpu")
|
850 |
-
if "state_dict" in checkpoint:
|
851 |
-
state_dict = checkpoint["state_dict"]
|
852 |
else:
|
853 |
-
|
854 |
-
|
|
|
|
|
|
|
|
|
855 |
|
856 |
-
|
857 |
-
|
858 |
-
|
859 |
-
|
860 |
-
|
861 |
-
|
862 |
|
863 |
-
|
864 |
-
|
865 |
-
|
866 |
|
867 |
-
|
868 |
|
869 |
|
870 |
# TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認
|
871 |
-
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
|
872 |
-
|
873 |
-
if dtype is not None:
|
874 |
-
for k, v in state_dict.items():
|
875 |
-
if type(v) is torch.Tensor:
|
876 |
-
state_dict[k] = v.to(dtype)
|
877 |
-
|
878 |
-
# Convert the UNet2DConditionModel model.
|
879 |
-
unet_config = create_unet_diffusers_config(v2)
|
880 |
-
converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)
|
881 |
-
|
882 |
-
unet = UNet2DConditionModel(**unet_config)
|
883 |
-
info = unet.load_state_dict(converted_unet_checkpoint)
|
884 |
-
print("loading u-net:", info)
|
885 |
-
|
886 |
-
# Convert the VAE model.
|
887 |
-
vae_config = create_vae_diffusers_config()
|
888 |
-
converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)
|
889 |
-
|
890 |
-
vae = AutoencoderKL(**vae_config)
|
891 |
-
info = vae.load_state_dict(converted_vae_checkpoint)
|
892 |
-
print("loading vae:", info)
|
893 |
-
|
894 |
-
# convert text_model
|
895 |
-
if v2:
|
896 |
-
converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77)
|
897 |
-
cfg = CLIPTextConfig(
|
898 |
-
vocab_size=49408,
|
899 |
-
hidden_size=1024,
|
900 |
-
intermediate_size=4096,
|
901 |
-
num_hidden_layers=23,
|
902 |
-
num_attention_heads=16,
|
903 |
-
max_position_embeddings=77,
|
904 |
-
hidden_act="gelu",
|
905 |
-
layer_norm_eps=1e-05,
|
906 |
-
dropout=0.0,
|
907 |
-
attention_dropout=0.0,
|
908 |
-
initializer_range=0.02,
|
909 |
-
initializer_factor=1.0,
|
910 |
-
pad_token_id=1,
|
911 |
-
bos_token_id=0,
|
912 |
-
eos_token_id=2,
|
913 |
-
model_type="clip_text_model",
|
914 |
-
projection_dim=512,
|
915 |
-
torch_dtype="float32",
|
916 |
-
transformers_version="4.25.0.dev0",
|
917 |
-
)
|
918 |
-
text_model = CLIPTextModel._from_config(cfg)
|
919 |
-
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
|
920 |
-
else:
|
921 |
-
converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
|
922 |
-
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
|
923 |
-
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
|
924 |
-
print("loading text encoder:", info)
|
925 |
|
926 |
-
|
|
|
|
|
927 |
|
|
|
|
|
|
|
928 |
|
929 |
-
|
930 |
-
|
931 |
-
|
932 |
-
if ".position_ids" in key:
|
933 |
-
return None
|
934 |
-
|
935 |
-
# common
|
936 |
-
key = key.replace("text_model.encoder.", "transformer.")
|
937 |
-
key = key.replace("text_model.", "")
|
938 |
-
if "layers" in key:
|
939 |
-
# resblocks conversion
|
940 |
-
key = key.replace(".layers.", ".resblocks.")
|
941 |
-
if ".layer_norm" in key:
|
942 |
-
key = key.replace(".layer_norm", ".ln_")
|
943 |
-
elif ".mlp." in key:
|
944 |
-
key = key.replace(".fc1.", ".c_fc.")
|
945 |
-
key = key.replace(".fc2.", ".c_proj.")
|
946 |
-
elif '.self_attn.out_proj' in key:
|
947 |
-
key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
|
948 |
-
elif '.self_attn.' in key:
|
949 |
-
key = None # 特殊なので後で処理する
|
950 |
-
else:
|
951 |
-
raise ValueError(f"unexpected key in DiffUsers model: {key}")
|
952 |
-
elif '.position_embedding' in key:
|
953 |
-
key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
|
954 |
-
elif '.token_embedding' in key:
|
955 |
-
key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
|
956 |
-
elif 'final_layer_norm' in key:
|
957 |
-
key = key.replace("final_layer_norm", "ln_final")
|
958 |
-
return key
|
959 |
-
|
960 |
-
keys = list(checkpoint.keys())
|
961 |
-
new_sd = {}
|
962 |
-
for key in keys:
|
963 |
-
new_key = convert_key(key)
|
964 |
-
if new_key is None:
|
965 |
-
continue
|
966 |
-
new_sd[new_key] = checkpoint[key]
|
967 |
-
|
968 |
-
# attnの変換
|
969 |
-
for key in keys:
|
970 |
-
if 'layers' in key and 'q_proj' in key:
|
971 |
-
# 三つを結合
|
972 |
-
key_q = key
|
973 |
-
key_k = key.replace("q_proj", "k_proj")
|
974 |
-
key_v = key.replace("q_proj", "v_proj")
|
975 |
-
|
976 |
-
value_q = checkpoint[key_q]
|
977 |
-
value_k = checkpoint[key_k]
|
978 |
-
value_v = checkpoint[key_v]
|
979 |
-
value = torch.cat([value_q, value_k, value_v])
|
980 |
-
|
981 |
-
new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
|
982 |
-
new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
|
983 |
-
new_sd[new_key] = value
|
984 |
-
|
985 |
-
# 最後の層などを捏造するか
|
986 |
-
if make_dummy_weights:
|
987 |
-
print("make dummy weights for resblock.23, text_projection and logit scale.")
|
988 |
-
keys = list(new_sd.keys())
|
989 |
-
for key in keys:
|
990 |
-
if key.startswith("transformer.resblocks.22."):
|
991 |
-
new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone() # copyしないとsafetensorsの保存で落ちる
|
992 |
|
993 |
-
|
994 |
-
|
995 |
-
|
996 |
|
997 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
998 |
|
|
|
|
|
|
|
999 |
|
1000 |
-
|
1001 |
-
|
1002 |
-
# epoch/stepを参照する。またVAEがメモリ上にないときなど、もう一度VAEを含めて読み込む
|
1003 |
-
checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
|
1004 |
-
if checkpoint is None: # safetensors または state_dictのckpt
|
1005 |
-
checkpoint = {}
|
1006 |
-
strict = False
|
1007 |
-
else:
|
1008 |
-
strict = True
|
1009 |
-
if "state_dict" in state_dict:
|
1010 |
-
del state_dict["state_dict"]
|
1011 |
-
else:
|
1012 |
-
# 新しく作る
|
1013 |
-
assert vae is not None, "VAE is required to save a checkpoint without a given checkpoint"
|
1014 |
-
checkpoint = {}
|
1015 |
-
state_dict = {}
|
1016 |
-
strict = False
|
1017 |
-
|
1018 |
-
def update_sd(prefix, sd):
|
1019 |
-
for k, v in sd.items():
|
1020 |
-
key = prefix + k
|
1021 |
-
assert not strict or key in state_dict, f"Illegal key in save SD: {key}"
|
1022 |
-
if save_dtype is not None:
|
1023 |
-
v = v.detach().clone().to("cpu").to(save_dtype)
|
1024 |
-
state_dict[key] = v
|
1025 |
-
|
1026 |
-
# Convert the UNet model
|
1027 |
-
unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict())
|
1028 |
-
update_sd("model.diffusion_model.", unet_state_dict)
|
1029 |
-
|
1030 |
-
# Convert the text encoder model
|
1031 |
-
if v2:
|
1032 |
-
make_dummy = ckpt_path is None # 参照元のcheckpointがない場合は最後の層を前の層から複製して作るなどダミーの重みを入れる
|
1033 |
-
text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy)
|
1034 |
-
update_sd("cond_stage_model.model.", text_enc_dict)
|
1035 |
-
else:
|
1036 |
-
text_enc_dict = text_encoder.state_dict()
|
1037 |
-
update_sd("cond_stage_model.transformer.", text_enc_dict)
|
1038 |
-
|
1039 |
-
# Convert the VAE
|
1040 |
-
if vae is not None:
|
1041 |
-
vae_dict = convert_vae_state_dict(vae.state_dict())
|
1042 |
-
update_sd("first_stage_model.", vae_dict)
|
1043 |
-
|
1044 |
-
# Put together new checkpoint
|
1045 |
-
key_count = len(state_dict.keys())
|
1046 |
-
new_ckpt = {'state_dict': state_dict}
|
1047 |
-
|
1048 |
-
if 'epoch' in checkpoint:
|
1049 |
-
epochs += checkpoint['epoch']
|
1050 |
-
if 'global_step' in checkpoint:
|
1051 |
-
steps += checkpoint['global_step']
|
1052 |
-
|
1053 |
-
new_ckpt['epoch'] = epochs
|
1054 |
-
new_ckpt['global_step'] = steps
|
1055 |
-
|
1056 |
-
if is_safetensors(output_file):
|
1057 |
-
# TODO Tensor以外のdictの値を削除したほうがいいか
|
1058 |
-
save_file(state_dict, output_file)
|
1059 |
-
else:
|
1060 |
-
torch.save(new_ckpt, output_file)
|
1061 |
-
|
1062 |
-
return key_count
|
1063 |
|
|
|
1064 |
|
1065 |
-
def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False):
|
1066 |
-
if pretrained_model_name_or_path is None:
|
1067 |
-
# load default settings for v1/v2
|
1068 |
-
if v2:
|
1069 |
-
pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2
|
1070 |
-
else:
|
1071 |
-
pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1
|
1072 |
-
|
1073 |
-
scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
|
1074 |
-
tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
|
1075 |
-
if vae is None:
|
1076 |
-
vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
|
1077 |
-
|
1078 |
-
pipeline = StableDiffusionPipeline(
|
1079 |
-
unet=unet,
|
1080 |
-
text_encoder=text_encoder,
|
1081 |
-
vae=vae,
|
1082 |
-
scheduler=scheduler,
|
1083 |
-
tokenizer=tokenizer,
|
1084 |
-
safety_checker=None,
|
1085 |
-
feature_extractor=None,
|
1086 |
-
requires_safety_checker=None,
|
1087 |
-
)
|
1088 |
-
pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors)
|
1089 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1090 |
|
1091 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
1092 |
|
|
|
|
|
|
|
|
|
1093 |
|
1094 |
-
|
1095 |
-
|
1096 |
-
|
1097 |
-
# Diffusers local/remote
|
1098 |
-
try:
|
1099 |
-
vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype)
|
1100 |
-
except EnvironmentError as e:
|
1101 |
-
print(f"exception occurs in loading vae: {e}")
|
1102 |
-
print("retry with subfolder='vae'")
|
1103 |
-
vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype)
|
1104 |
-
return vae
|
1105 |
|
1106 |
-
|
1107 |
-
|
1108 |
-
|
1109 |
-
|
1110 |
-
|
1111 |
-
|
1112 |
-
|
1113 |
-
# StableDiffusion
|
1114 |
-
vae_model = (load_file(vae_id, "cpu") if is_safetensors(vae_id)
|
1115 |
-
else torch.load(vae_id, map_location="cpu"))
|
1116 |
-
vae_sd = vae_model['state_dict'] if 'state_dict' in vae_model else vae_model
|
1117 |
-
|
1118 |
-
# vae only or full model
|
1119 |
-
full_model = False
|
1120 |
-
for vae_key in vae_sd:
|
1121 |
-
if vae_key.startswith(VAE_PREFIX):
|
1122 |
-
full_model = True
|
1123 |
-
break
|
1124 |
-
if not full_model:
|
1125 |
-
sd = {}
|
1126 |
-
for key, value in vae_sd.items():
|
1127 |
-
sd[VAE_PREFIX + key] = value
|
1128 |
-
vae_sd = sd
|
1129 |
-
del sd
|
1130 |
|
1131 |
-
|
1132 |
-
|
|
|
1133 |
|
1134 |
-
|
1135 |
-
vae.load_state_dict(converted_vae_checkpoint)
|
1136 |
-
return vae
|
1137 |
|
1138 |
-
# endregion
|
1139 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1140 |
|
1141 |
-
|
1142 |
-
|
1143 |
-
|
|
|
1144 |
|
1145 |
-
|
|
|
|
|
1146 |
|
1147 |
-
|
1148 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1149 |
|
1150 |
-
|
1151 |
-
while size <= max_size:
|
1152 |
-
width = size
|
1153 |
-
height = min(max_size, (max_area // (width // divisible)) * divisible)
|
1154 |
-
resos.add((width, height))
|
1155 |
-
resos.add((height, width))
|
1156 |
|
1157 |
-
# # make additional resos
|
1158 |
-
# if width >= height and width - divisible >= min_size:
|
1159 |
-
# resos.add((width - divisible, height))
|
1160 |
-
# resos.add((height, width - divisible))
|
1161 |
-
# if height >= width and height - divisible >= min_size:
|
1162 |
-
# resos.add((width, height - divisible))
|
1163 |
-
# resos.add((height - divisible, width))
|
1164 |
|
1165 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1166 |
|
1167 |
-
resos = list(resos)
|
1168 |
-
resos.sort()
|
1169 |
|
1170 |
-
|
1171 |
-
return resos, aspect_ratios
|
1172 |
|
1173 |
|
1174 |
-
|
1175 |
-
|
1176 |
-
|
1177 |
-
|
1178 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1179 |
|
1180 |
-
|
1181 |
-
|
1182 |
-
|
1183 |
-
|
1184 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
# v1: split from train_db_fixed.py.
|
2 |
# v2: support safetensors
|
3 |
|
4 |
import math
|
5 |
import os
|
6 |
import torch
|
7 |
+
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
|
8 |
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
|
9 |
from safetensors.torch import load_file, save_file
|
10 |
|
|
|
16 |
UNET_PARAMS_MODEL_CHANNELS = 320
|
17 |
UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
|
18 |
UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
|
19 |
+
UNET_PARAMS_IMAGE_SIZE = 64 # fixed from old invalid value `32`
|
20 |
UNET_PARAMS_IN_CHANNELS = 4
|
21 |
UNET_PARAMS_OUT_CHANNELS = 4
|
22 |
UNET_PARAMS_NUM_RES_BLOCKS = 2
|
|
|
45 |
|
46 |
|
47 |
def shave_segments(path, n_shave_prefix_segments=1):
|
48 |
+
"""
|
49 |
+
Removes segments. Positive values shave the first segments, negative shave the last segments.
|
50 |
+
"""
|
51 |
+
if n_shave_prefix_segments >= 0:
|
52 |
+
return ".".join(path.split(".")[n_shave_prefix_segments:])
|
53 |
+
else:
|
54 |
+
return ".".join(path.split(".")[:n_shave_prefix_segments])
|
55 |
|
56 |
|
57 |
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
|
58 |
+
"""
|
59 |
+
Updates paths inside resnets to the new naming scheme (local renaming)
|
60 |
+
"""
|
61 |
+
mapping = []
|
62 |
+
for old_item in old_list:
|
63 |
+
new_item = old_item.replace("in_layers.0", "norm1")
|
64 |
+
new_item = new_item.replace("in_layers.2", "conv1")
|
65 |
|
66 |
+
new_item = new_item.replace("out_layers.0", "norm2")
|
67 |
+
new_item = new_item.replace("out_layers.3", "conv2")
|
68 |
|
69 |
+
new_item = new_item.replace("emb_layers.1", "time_emb_proj")
|
70 |
+
new_item = new_item.replace("skip_connection", "conv_shortcut")
|
71 |
|
72 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
73 |
|
74 |
+
mapping.append({"old": old_item, "new": new_item})
|
75 |
|
76 |
+
return mapping
|
77 |
|
78 |
|
79 |
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
|
80 |
+
"""
|
81 |
+
Updates paths inside resnets to the new naming scheme (local renaming)
|
82 |
+
"""
|
83 |
+
mapping = []
|
84 |
+
for old_item in old_list:
|
85 |
+
new_item = old_item
|
86 |
|
87 |
+
new_item = new_item.replace("nin_shortcut", "conv_shortcut")
|
88 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
89 |
|
90 |
+
mapping.append({"old": old_item, "new": new_item})
|
91 |
|
92 |
+
return mapping
|
93 |
|
94 |
|
95 |
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
|
96 |
+
"""
|
97 |
+
Updates paths inside attentions to the new naming scheme (local renaming)
|
98 |
+
"""
|
99 |
+
mapping = []
|
100 |
+
for old_item in old_list:
|
101 |
+
new_item = old_item
|
102 |
|
103 |
+
# new_item = new_item.replace('norm.weight', 'group_norm.weight')
|
104 |
+
# new_item = new_item.replace('norm.bias', 'group_norm.bias')
|
105 |
|
106 |
+
# new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
|
107 |
+
# new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
|
108 |
|
109 |
+
# new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
110 |
|
111 |
+
mapping.append({"old": old_item, "new": new_item})
|
112 |
|
113 |
+
return mapping
|
114 |
|
115 |
|
116 |
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
|
117 |
+
"""
|
118 |
+
Updates paths inside attentions to the new naming scheme (local renaming)
|
119 |
+
"""
|
120 |
+
mapping = []
|
121 |
+
for old_item in old_list:
|
122 |
+
new_item = old_item
|
123 |
|
124 |
+
new_item = new_item.replace("norm.weight", "group_norm.weight")
|
125 |
+
new_item = new_item.replace("norm.bias", "group_norm.bias")
|
126 |
|
127 |
+
new_item = new_item.replace("q.weight", "query.weight")
|
128 |
+
new_item = new_item.replace("q.bias", "query.bias")
|
129 |
|
130 |
+
new_item = new_item.replace("k.weight", "key.weight")
|
131 |
+
new_item = new_item.replace("k.bias", "key.bias")
|
132 |
|
133 |
+
new_item = new_item.replace("v.weight", "value.weight")
|
134 |
+
new_item = new_item.replace("v.bias", "value.bias")
|
135 |
|
136 |
+
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
|
137 |
+
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
|
138 |
|
139 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
140 |
|
141 |
+
mapping.append({"old": old_item, "new": new_item})
|
142 |
|
143 |
+
return mapping
|
144 |
|
145 |
|
146 |
def assign_to_checkpoint(
|
147 |
paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
|
148 |
):
|
149 |
+
"""
|
150 |
+
This does the final conversion step: take locally converted weights and apply a global renaming
|
151 |
+
to them. It splits attention layers, and takes into account additional replacements
|
152 |
+
that may arise.
|
153 |
|
154 |
+
Assigns the weights to the new checkpoint.
|
155 |
+
"""
|
156 |
+
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
|
157 |
|
158 |
+
# Splits the attention layers into three variables.
|
159 |
+
if attention_paths_to_split is not None:
|
160 |
+
for path, path_map in attention_paths_to_split.items():
|
161 |
+
old_tensor = old_checkpoint[path]
|
162 |
+
channels = old_tensor.shape[0] // 3
|
163 |
|
164 |
+
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
|
165 |
|
166 |
+
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
|
167 |
|
168 |
+
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
|
169 |
+
query, key, value = old_tensor.split(channels // num_heads, dim=1)
|
170 |
|
171 |
+
checkpoint[path_map["query"]] = query.reshape(target_shape)
|
172 |
+
checkpoint[path_map["key"]] = key.reshape(target_shape)
|
173 |
+
checkpoint[path_map["value"]] = value.reshape(target_shape)
|
174 |
|
175 |
+
for path in paths:
|
176 |
+
new_path = path["new"]
|
177 |
|
178 |
+
# These have already been assigned
|
179 |
+
if attention_paths_to_split is not None and new_path in attention_paths_to_split:
|
180 |
+
continue
|
181 |
|
182 |
+
# Global renaming happens here
|
183 |
+
new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
|
184 |
+
new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
|
185 |
+
new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
|
186 |
|
187 |
+
if additional_replacements is not None:
|
188 |
+
for replacement in additional_replacements:
|
189 |
+
new_path = new_path.replace(replacement["old"], replacement["new"])
|
190 |
|
191 |
+
# proj_attn.weight has to be converted from conv 1D to linear
|
192 |
+
if "proj_attn.weight" in new_path:
|
193 |
+
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
|
194 |
+
else:
|
195 |
+
checkpoint[new_path] = old_checkpoint[path["old"]]
|
196 |
|
197 |
|
198 |
def conv_attn_to_linear(checkpoint):
|
199 |
+
keys = list(checkpoint.keys())
|
200 |
+
attn_keys = ["query.weight", "key.weight", "value.weight"]
|
201 |
+
for key in keys:
|
202 |
+
if ".".join(key.split(".")[-2:]) in attn_keys:
|
203 |
+
if checkpoint[key].ndim > 2:
|
204 |
+
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
205 |
+
elif "proj_attn.weight" in key:
|
206 |
+
if checkpoint[key].ndim > 2:
|
207 |
+
checkpoint[key] = checkpoint[key][:, :, 0]
|
208 |
|
209 |
|
210 |
def linear_transformer_to_conv(checkpoint):
|
211 |
+
keys = list(checkpoint.keys())
|
212 |
+
tf_keys = ["proj_in.weight", "proj_out.weight"]
|
213 |
+
for key in keys:
|
214 |
+
if ".".join(key.split(".")[-2:]) in tf_keys:
|
215 |
+
if checkpoint[key].ndim == 2:
|
216 |
+
checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2)
|
217 |
|
218 |
|
219 |
def convert_ldm_unet_checkpoint(v2, checkpoint, config):
|
220 |
+
"""
|
221 |
+
Takes a state dict and a config, and returns a converted checkpoint.
|
222 |
+
"""
|
223 |
+
|
224 |
+
# extract state_dict for UNet
|
225 |
+
unet_state_dict = {}
|
226 |
+
unet_key = "model.diffusion_model."
|
227 |
+
keys = list(checkpoint.keys())
|
228 |
+
for key in keys:
|
229 |
+
if key.startswith(unet_key):
|
230 |
+
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
|
231 |
+
|
232 |
+
new_checkpoint = {}
|
233 |
+
|
234 |
+
new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
|
235 |
+
new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
|
236 |
+
new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
|
237 |
+
new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
|
238 |
+
|
239 |
+
new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
|
240 |
+
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
|
241 |
+
|
242 |
+
new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
|
243 |
+
new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
|
244 |
+
new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
|
245 |
+
new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
|
246 |
+
|
247 |
+
# Retrieves the keys for the input blocks only
|
248 |
+
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
|
249 |
+
input_blocks = {
|
250 |
+
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key] for layer_id in range(num_input_blocks)
|
251 |
+
}
|
252 |
+
|
253 |
+
# Retrieves the keys for the middle blocks only
|
254 |
+
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
|
255 |
+
middle_blocks = {
|
256 |
+
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key] for layer_id in range(num_middle_blocks)
|
257 |
+
}
|
258 |
+
|
259 |
+
# Retrieves the keys for the output blocks only
|
260 |
+
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
|
261 |
+
output_blocks = {
|
262 |
+
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key] for layer_id in range(num_output_blocks)
|
263 |
+
}
|
264 |
+
|
265 |
+
for i in range(1, num_input_blocks):
|
266 |
+
block_id = (i - 1) // (config["layers_per_block"] + 1)
|
267 |
+
layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
|
268 |
+
|
269 |
+
resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key]
|
270 |
+
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
|
271 |
+
|
272 |
+
if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
|
273 |
+
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
|
274 |
+
f"input_blocks.{i}.0.op.weight"
|
275 |
+
)
|
276 |
+
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(f"input_blocks.{i}.0.op.bias")
|
277 |
+
|
278 |
+
paths = renew_resnet_paths(resnets)
|
279 |
+
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
280 |
+
assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
|
281 |
+
|
282 |
+
if len(attentions):
|
283 |
+
paths = renew_attention_paths(attentions)
|
284 |
+
meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
|
285 |
+
assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
|
286 |
+
|
287 |
+
resnet_0 = middle_blocks[0]
|
288 |
+
attentions = middle_blocks[1]
|
289 |
+
resnet_1 = middle_blocks[2]
|
290 |
+
|
291 |
+
resnet_0_paths = renew_resnet_paths(resnet_0)
|
292 |
+
assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
|
293 |
+
|
294 |
+
resnet_1_paths = renew_resnet_paths(resnet_1)
|
295 |
+
assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
|
296 |
+
|
297 |
+
attentions_paths = renew_attention_paths(attentions)
|
298 |
+
meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
|
299 |
+
assign_to_checkpoint(attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
|
300 |
+
|
301 |
+
for i in range(num_output_blocks):
|
302 |
+
block_id = i // (config["layers_per_block"] + 1)
|
303 |
+
layer_in_block_id = i % (config["layers_per_block"] + 1)
|
304 |
+
output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
|
305 |
+
output_block_list = {}
|
306 |
+
|
307 |
+
for layer in output_block_layers:
|
308 |
+
layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
|
309 |
+
if layer_id in output_block_list:
|
310 |
+
output_block_list[layer_id].append(layer_name)
|
311 |
+
else:
|
312 |
+
output_block_list[layer_id] = [layer_name]
|
313 |
+
|
314 |
+
if len(output_block_list) > 1:
|
315 |
+
resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
|
316 |
+
attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
|
317 |
+
|
318 |
+
resnet_0_paths = renew_resnet_paths(resnets)
|
319 |
+
paths = renew_resnet_paths(resnets)
|
320 |
+
|
321 |
+
meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
322 |
+
assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
|
323 |
+
|
324 |
+
# オリジナル:
|
325 |
+
# if ["conv.weight", "conv.bias"] in output_block_list.values():
|
326 |
+
# index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
|
327 |
+
|
328 |
+
# biasとweightの順番に依存しないようにする:もっといいやり方がありそうだが
|
329 |
+
for l in output_block_list.values():
|
330 |
+
l.sort()
|
331 |
+
|
332 |
+
if ["conv.bias", "conv.weight"] in output_block_list.values():
|
333 |
+
index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
|
334 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
|
335 |
+
f"output_blocks.{i}.{index}.conv.bias"
|
336 |
+
]
|
337 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
|
338 |
+
f"output_blocks.{i}.{index}.conv.weight"
|
339 |
+
]
|
340 |
+
|
341 |
+
# Clear attentions as they have been attributed above.
|
342 |
+
if len(attentions) == 2:
|
343 |
+
attentions = []
|
344 |
+
|
345 |
+
if len(attentions):
|
346 |
+
paths = renew_attention_paths(attentions)
|
347 |
+
meta_path = {
|
348 |
+
"old": f"output_blocks.{i}.1",
|
349 |
+
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
|
350 |
+
}
|
351 |
+
assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
|
352 |
+
else:
|
353 |
+
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
|
354 |
+
for path in resnet_0_paths:
|
355 |
+
old_path = ".".join(["output_blocks", str(i), path["old"]])
|
356 |
+
new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
|
357 |
+
|
358 |
+
new_checkpoint[new_path] = unet_state_dict[old_path]
|
359 |
+
|
360 |
+
# SDのv2では1*1のconv2dがlinearに変わっているので、linear->convに変換する
|
361 |
+
if v2:
|
362 |
+
linear_transformer_to_conv(new_checkpoint)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
363 |
|
364 |
+
return new_checkpoint
|
365 |
|
366 |
|
367 |
def convert_ldm_vae_checkpoint(checkpoint, config):
|
368 |
+
# extract state dict for VAE
|
369 |
+
vae_state_dict = {}
|
370 |
+
vae_key = "first_stage_model."
|
371 |
+
keys = list(checkpoint.keys())
|
372 |
+
for key in keys:
|
373 |
+
if key.startswith(vae_key):
|
374 |
+
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
|
375 |
+
# if len(vae_state_dict) == 0:
|
376 |
+
# # 渡されたcheckpointは.ckptから読み込んだcheckpointではなくvaeのstate_dict
|
377 |
+
# vae_state_dict = checkpoint
|
378 |
+
|
379 |
+
new_checkpoint = {}
|
380 |
+
|
381 |
+
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
|
382 |
+
new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
|
383 |
+
new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
|
384 |
+
new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
|
385 |
+
new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
|
386 |
+
new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
|
387 |
+
|
388 |
+
new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
|
389 |
+
new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
|
390 |
+
new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
|
391 |
+
new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
|
392 |
+
new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
|
393 |
+
new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
|
394 |
+
|
395 |
+
new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
|
396 |
+
new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
|
397 |
+
new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
|
398 |
+
new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
|
399 |
+
|
400 |
+
# Retrieves the keys for the encoder down blocks only
|
401 |
+
num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
|
402 |
+
down_blocks = {layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)}
|
403 |
+
|
404 |
+
# Retrieves the keys for the decoder up blocks only
|
405 |
+
num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
|
406 |
+
up_blocks = {layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)}
|
407 |
+
|
408 |
+
for i in range(num_down_blocks):
|
409 |
+
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
|
410 |
+
|
411 |
+
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
|
412 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
|
413 |
+
f"encoder.down.{i}.downsample.conv.weight"
|
414 |
+
)
|
415 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
|
416 |
+
f"encoder.down.{i}.downsample.conv.bias"
|
417 |
+
)
|
418 |
+
|
419 |
+
paths = renew_vae_resnet_paths(resnets)
|
420 |
+
meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
|
421 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
422 |
+
|
423 |
+
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
|
424 |
+
num_mid_res_blocks = 2
|
425 |
+
for i in range(1, num_mid_res_blocks + 1):
|
426 |
+
resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
|
427 |
+
|
428 |
+
paths = renew_vae_resnet_paths(resnets)
|
429 |
+
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
430 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
431 |
+
|
432 |
+
mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
|
433 |
+
paths = renew_vae_attention_paths(mid_attentions)
|
434 |
+
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
435 |
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
436 |
+
conv_attn_to_linear(new_checkpoint)
|
437 |
+
|
438 |
+
for i in range(num_up_blocks):
|
439 |
+
block_id = num_up_blocks - 1 - i
|
440 |
+
resnets = [key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key]
|
441 |
+
|
442 |
+
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
|
443 |
+
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
|
444 |
+
f"decoder.up.{block_id}.upsample.conv.weight"
|
445 |
+
]
|
446 |
+
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
|
447 |
+
f"decoder.up.{block_id}.upsample.conv.bias"
|
448 |
+
]
|
449 |
+
|
450 |
+
paths = renew_vae_resnet_paths(resnets)
|
451 |
+
meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
|
452 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
453 |
+
|
454 |
+
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
|
455 |
+
num_mid_res_blocks = 2
|
456 |
+
for i in range(1, num_mid_res_blocks + 1):
|
457 |
+
resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
|
458 |
+
|
459 |
+
paths = renew_vae_resnet_paths(resnets)
|
460 |
+
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
461 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
462 |
+
|
463 |
+
mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
|
464 |
+
paths = renew_vae_attention_paths(mid_attentions)
|
465 |
+
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
466 |
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
467 |
+
conv_attn_to_linear(new_checkpoint)
|
468 |
+
return new_checkpoint
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
469 |
|
470 |
|
471 |
def create_unet_diffusers_config(v2):
|
472 |
+
"""
|
473 |
+
Creates a config for the diffusers based on the config of the LDM model.
|
474 |
+
"""
|
475 |
+
# unet_params = original_config.model.params.unet_config.params
|
476 |
+
|
477 |
+
block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT]
|
478 |
+
|
479 |
+
down_block_types = []
|
480 |
+
resolution = 1
|
481 |
+
for i in range(len(block_out_channels)):
|
482 |
+
block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D"
|
483 |
+
down_block_types.append(block_type)
|
484 |
+
if i != len(block_out_channels) - 1:
|
485 |
+
resolution *= 2
|
486 |
+
|
487 |
+
up_block_types = []
|
488 |
+
for i in range(len(block_out_channels)):
|
489 |
+
block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D"
|
490 |
+
up_block_types.append(block_type)
|
491 |
+
resolution //= 2
|
492 |
+
|
493 |
+
config = dict(
|
494 |
+
sample_size=UNET_PARAMS_IMAGE_SIZE,
|
495 |
+
in_channels=UNET_PARAMS_IN_CHANNELS,
|
496 |
+
out_channels=UNET_PARAMS_OUT_CHANNELS,
|
497 |
+
down_block_types=tuple(down_block_types),
|
498 |
+
up_block_types=tuple(up_block_types),
|
499 |
+
block_out_channels=tuple(block_out_channels),
|
500 |
+
layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
|
501 |
+
cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM,
|
502 |
+
attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
|
503 |
+
)
|
504 |
+
|
505 |
+
return config
|
506 |
|
507 |
|
508 |
def create_vae_diffusers_config():
|
509 |
+
"""
|
510 |
+
Creates a config for the diffusers based on the config of the LDM model.
|
511 |
+
"""
|
512 |
+
# vae_params = original_config.model.params.first_stage_config.params.ddconfig
|
513 |
+
# _ = original_config.model.params.first_stage_config.params.embed_dim
|
514 |
+
block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT]
|
515 |
+
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
|
516 |
+
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
|
517 |
+
|
518 |
+
config = dict(
|
519 |
+
sample_size=VAE_PARAMS_RESOLUTION,
|
520 |
+
in_channels=VAE_PARAMS_IN_CHANNELS,
|
521 |
+
out_channels=VAE_PARAMS_OUT_CH,
|
522 |
+
down_block_types=tuple(down_block_types),
|
523 |
+
up_block_types=tuple(up_block_types),
|
524 |
+
block_out_channels=tuple(block_out_channels),
|
525 |
+
latent_channels=VAE_PARAMS_Z_CHANNELS,
|
526 |
+
layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS,
|
527 |
+
)
|
528 |
+
return config
|
529 |
|
530 |
|
531 |
def convert_ldm_clip_checkpoint_v1(checkpoint):
|
532 |
+
keys = list(checkpoint.keys())
|
533 |
+
text_model_dict = {}
|
534 |
+
for key in keys:
|
535 |
+
if key.startswith("cond_stage_model.transformer"):
|
536 |
+
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
|
537 |
+
return text_model_dict
|
538 |
|
539 |
|
540 |
def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
|
541 |
+
# 嫌になるくらい違うぞ!
|
542 |
+
def convert_key(key):
|
543 |
+
if not key.startswith("cond_stage_model"):
|
544 |
+
return None
|
545 |
+
|
546 |
+
# common conversion
|
547 |
+
key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.")
|
548 |
+
key = key.replace("cond_stage_model.model.", "text_model.")
|
549 |
+
|
550 |
+
if "resblocks" in key:
|
551 |
+
# resblocks conversion
|
552 |
+
key = key.replace(".resblocks.", ".layers.")
|
553 |
+
if ".ln_" in key:
|
554 |
+
key = key.replace(".ln_", ".layer_norm")
|
555 |
+
elif ".mlp." in key:
|
556 |
+
key = key.replace(".c_fc.", ".fc1.")
|
557 |
+
key = key.replace(".c_proj.", ".fc2.")
|
558 |
+
elif ".attn.out_proj" in key:
|
559 |
+
key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
|
560 |
+
elif ".attn.in_proj" in key:
|
561 |
+
key = None # 特殊なので後で処理する
|
562 |
+
else:
|
563 |
+
raise ValueError(f"unexpected key in SD: {key}")
|
564 |
+
elif ".positional_embedding" in key:
|
565 |
+
key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
|
566 |
+
elif ".text_projection" in key:
|
567 |
+
key = None # 使われない???
|
568 |
+
elif ".logit_scale" in key:
|
569 |
+
key = None # 使われない???
|
570 |
+
elif ".token_embedding" in key:
|
571 |
+
key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
|
572 |
+
elif ".ln_final" in key:
|
573 |
+
key = key.replace(".ln_final", ".final_layer_norm")
|
574 |
+
return key
|
575 |
+
|
576 |
+
keys = list(checkpoint.keys())
|
577 |
+
new_sd = {}
|
578 |
+
for key in keys:
|
579 |
+
# remove resblocks 23
|
580 |
+
if ".resblocks.23." in key:
|
581 |
+
continue
|
582 |
+
new_key = convert_key(key)
|
583 |
+
if new_key is None:
|
584 |
+
continue
|
585 |
+
new_sd[new_key] = checkpoint[key]
|
586 |
+
|
587 |
+
# attnの変換
|
588 |
+
for key in keys:
|
589 |
+
if ".resblocks.23." in key:
|
590 |
+
continue
|
591 |
+
if ".resblocks" in key and ".attn.in_proj_" in key:
|
592 |
+
# 三つに分割
|
593 |
+
values = torch.chunk(checkpoint[key], 3)
|
594 |
+
|
595 |
+
key_suffix = ".weight" if "weight" in key else ".bias"
|
596 |
+
key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.")
|
597 |
+
key_pfx = key_pfx.replace("_weight", "")
|
598 |
+
key_pfx = key_pfx.replace("_bias", "")
|
599 |
+
key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
|
600 |
+
new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
|
601 |
+
new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
|
602 |
+
new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
|
603 |
+
|
604 |
+
# rename or add position_ids
|
605 |
+
ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids"
|
606 |
+
if ANOTHER_POSITION_IDS_KEY in new_sd:
|
607 |
+
# waifu diffusion v1.4
|
608 |
+
position_ids = new_sd[ANOTHER_POSITION_IDS_KEY]
|
609 |
+
del new_sd[ANOTHER_POSITION_IDS_KEY]
|
610 |
+
else:
|
611 |
+
position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
|
612 |
+
|
613 |
+
new_sd["text_model.embeddings.position_ids"] = position_ids
|
614 |
+
return new_sd
|
615 |
+
|
616 |
|
617 |
# endregion
|
618 |
|
|
|
620 |
# region Diffusers->StableDiffusion の変換コード
|
621 |
# convert_diffusers_to_original_stable_diffusion をコピーして修正している(ASL 2.0)
|
622 |
|
623 |
+
|
624 |
def conv_transformer_to_linear(checkpoint):
|
625 |
+
keys = list(checkpoint.keys())
|
626 |
+
tf_keys = ["proj_in.weight", "proj_out.weight"]
|
627 |
+
for key in keys:
|
628 |
+
if ".".join(key.split(".")[-2:]) in tf_keys:
|
629 |
+
if checkpoint[key].ndim > 2:
|
630 |
+
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
631 |
|
632 |
|
633 |
def convert_unet_state_dict_to_sd(v2, unet_state_dict):
|
634 |
+
unet_conversion_map = [
|
635 |
+
# (stable-diffusion, HF Diffusers)
|
636 |
+
("time_embed.0.weight", "time_embedding.linear_1.weight"),
|
637 |
+
("time_embed.0.bias", "time_embedding.linear_1.bias"),
|
638 |
+
("time_embed.2.weight", "time_embedding.linear_2.weight"),
|
639 |
+
("time_embed.2.bias", "time_embedding.linear_2.bias"),
|
640 |
+
("input_blocks.0.0.weight", "conv_in.weight"),
|
641 |
+
("input_blocks.0.0.bias", "conv_in.bias"),
|
642 |
+
("out.0.weight", "conv_norm_out.weight"),
|
643 |
+
("out.0.bias", "conv_norm_out.bias"),
|
644 |
+
("out.2.weight", "conv_out.weight"),
|
645 |
+
("out.2.bias", "conv_out.bias"),
|
646 |
+
]
|
647 |
+
|
648 |
+
unet_conversion_map_resnet = [
|
649 |
+
# (stable-diffusion, HF Diffusers)
|
650 |
+
("in_layers.0", "norm1"),
|
651 |
+
("in_layers.2", "conv1"),
|
652 |
+
("out_layers.0", "norm2"),
|
653 |
+
("out_layers.3", "conv2"),
|
654 |
+
("emb_layers.1", "time_emb_proj"),
|
655 |
+
("skip_connection", "conv_shortcut"),
|
656 |
+
]
|
657 |
+
|
658 |
+
unet_conversion_map_layer = []
|
659 |
+
for i in range(4):
|
660 |
+
# loop over downblocks/upblocks
|
661 |
+
|
662 |
+
for j in range(2):
|
663 |
+
# loop over resnets/attentions for downblocks
|
664 |
+
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
665 |
+
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
666 |
+
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
667 |
+
|
668 |
+
if i < 3:
|
669 |
+
# no attention layers in down_blocks.3
|
670 |
+
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
671 |
+
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
672 |
+
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
673 |
+
|
674 |
+
for j in range(3):
|
675 |
+
# loop over resnets/attentions for upblocks
|
676 |
+
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
677 |
+
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
|
678 |
+
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
679 |
+
|
680 |
+
if i > 0:
|
681 |
+
# no attention layers in up_blocks.0
|
682 |
+
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
|
683 |
+
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
|
684 |
+
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
685 |
+
|
686 |
+
if i < 3:
|
687 |
+
# no downsample in down_blocks.3
|
688 |
+
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
689 |
+
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
690 |
+
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
691 |
+
|
692 |
+
# no upsample in up_blocks.3
|
693 |
+
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
694 |
+
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
|
695 |
+
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
696 |
+
|
697 |
+
hf_mid_atn_prefix = "mid_block.attentions.0."
|
698 |
+
sd_mid_atn_prefix = "middle_block.1."
|
699 |
+
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
700 |
|
701 |
for j in range(2):
|
702 |
+
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
703 |
+
sd_mid_res_prefix = f"middle_block.{2*j}."
|
704 |
+
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
705 |
+
|
706 |
+
# buyer beware: this is a *brittle* function,
|
707 |
+
# and correct output requires that all of these pieces interact in
|
708 |
+
# the exact order in which I have arranged them.
|
709 |
+
mapping = {k: k for k in unet_state_dict.keys()}
|
710 |
+
for sd_name, hf_name in unet_conversion_map:
|
711 |
+
mapping[hf_name] = sd_name
|
712 |
+
for k, v in mapping.items():
|
713 |
+
if "resnets" in k:
|
714 |
+
for sd_part, hf_part in unet_conversion_map_resnet:
|
715 |
+
v = v.replace(hf_part, sd_part)
|
716 |
+
mapping[k] = v
|
717 |
+
for k, v in mapping.items():
|
718 |
+
for sd_part, hf_part in unet_conversion_map_layer:
|
719 |
+
v = v.replace(hf_part, sd_part)
|
720 |
+
mapping[k] = v
|
721 |
+
new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
|
722 |
+
|
723 |
+
if v2:
|
724 |
+
conv_transformer_to_linear(new_state_dict)
|
725 |
+
|
726 |
+
return new_state_dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
727 |
|
728 |
|
729 |
# ================#
|
730 |
# VAE Conversion #
|
731 |
# ================#
|
732 |
|
733 |
+
|
734 |
def reshape_weight_for_sd(w):
|
735 |
# convert HF linear weights to SD conv2d weights
|
736 |
+
return w.reshape(*w.shape, 1, 1)
|
737 |
|
738 |
|
739 |
def convert_vae_state_dict(vae_state_dict):
|
740 |
+
vae_conversion_map = [
|
741 |
+
# (stable-diffusion, HF Diffusers)
|
742 |
+
("nin_shortcut", "conv_shortcut"),
|
743 |
+
("norm_out", "conv_norm_out"),
|
744 |
+
("mid.attn_1.", "mid_block.attentions.0."),
|
745 |
+
]
|
746 |
+
|
747 |
+
for i in range(4):
|
748 |
+
# down_blocks have two resnets
|
749 |
+
for j in range(2):
|
750 |
+
hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
|
751 |
+
sd_down_prefix = f"encoder.down.{i}.block.{j}."
|
752 |
+
vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
|
753 |
+
|
754 |
+
if i < 3:
|
755 |
+
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
|
756 |
+
sd_downsample_prefix = f"down.{i}.downsample."
|
757 |
+
vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
|
758 |
+
|
759 |
+
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
760 |
+
sd_upsample_prefix = f"up.{3-i}.upsample."
|
761 |
+
vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
|
762 |
+
|
763 |
+
# up_blocks have three resnets
|
764 |
+
# also, up blocks in hf are numbered in reverse from sd
|
765 |
+
for j in range(3):
|
766 |
+
hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
|
767 |
+
sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
|
768 |
+
vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
|
769 |
+
|
770 |
+
# this part accounts for mid blocks in both the encoder and the decoder
|
771 |
+
for i in range(2):
|
772 |
+
hf_mid_res_prefix = f"mid_block.resnets.{i}."
|
773 |
+
sd_mid_res_prefix = f"mid.block_{i+1}."
|
774 |
+
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
775 |
+
|
776 |
+
vae_conversion_map_attn = [
|
777 |
+
# (stable-diffusion, HF Diffusers)
|
778 |
+
("norm.", "group_norm."),
|
779 |
+
("q.", "query."),
|
780 |
+
("k.", "key."),
|
781 |
+
("v.", "value."),
|
782 |
+
("proj_out.", "proj_attn."),
|
783 |
+
]
|
784 |
+
|
785 |
+
mapping = {k: k for k in vae_state_dict.keys()}
|
786 |
+
for k, v in mapping.items():
|
787 |
+
for sd_part, hf_part in vae_conversion_map:
|
788 |
+
v = v.replace(hf_part, sd_part)
|
789 |
+
mapping[k] = v
|
790 |
+
for k, v in mapping.items():
|
791 |
+
if "attentions" in k:
|
792 |
+
for sd_part, hf_part in vae_conversion_map_attn:
|
793 |
+
v = v.replace(hf_part, sd_part)
|
794 |
+
mapping[k] = v
|
795 |
+
new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
|
796 |
+
weights_to_convert = ["q", "k", "v", "proj_out"]
|
797 |
+
for k, v in new_state_dict.items():
|
798 |
+
for weight_name in weights_to_convert:
|
799 |
+
if f"mid.attn_1.{weight_name}.weight" in k:
|
800 |
+
# print(f"Reshaping {k} for SD format")
|
801 |
+
new_state_dict[k] = reshape_weight_for_sd(v)
|
802 |
+
|
803 |
+
return new_state_dict
|
804 |
|
805 |
|
806 |
# endregion
|
807 |
|
808 |
# region 自作のモデル読み書きなど
|
809 |
|
810 |
+
|
811 |
def is_safetensors(path):
|
812 |
+
return os.path.splitext(path)[1].lower() == ".safetensors"
|
813 |
+
|
814 |
+
|
815 |
+
def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"):
|
816 |
+
# text encoderの格納形式が違うモデルに対応する ('text_model'がない)
|
817 |
+
TEXT_ENCODER_KEY_REPLACEMENTS = [
|
818 |
+
("cond_stage_model.transformer.embeddings.", "cond_stage_model.transformer.text_model.embeddings."),
|
819 |
+
("cond_stage_model.transformer.encoder.", "cond_stage_model.transformer.text_model.encoder."),
|
820 |
+
("cond_stage_model.transformer.final_layer_norm.", "cond_stage_model.transformer.text_model.final_layer_norm."),
|
821 |
+
]
|
822 |
+
|
823 |
+
if is_safetensors(ckpt_path):
|
824 |
+
checkpoint = None
|
825 |
+
state_dict = load_file(ckpt_path) # , device) # may causes error
|
|
|
|
|
|
|
|
|
826 |
else:
|
827 |
+
checkpoint = torch.load(ckpt_path, map_location=device)
|
828 |
+
if "state_dict" in checkpoint:
|
829 |
+
state_dict = checkpoint["state_dict"]
|
830 |
+
else:
|
831 |
+
state_dict = checkpoint
|
832 |
+
checkpoint = None
|
833 |
|
834 |
+
key_reps = []
|
835 |
+
for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
|
836 |
+
for key in state_dict.keys():
|
837 |
+
if key.startswith(rep_from):
|
838 |
+
new_key = rep_to + key[len(rep_from) :]
|
839 |
+
key_reps.append((key, new_key))
|
840 |
|
841 |
+
for key, new_key in key_reps:
|
842 |
+
state_dict[new_key] = state_dict[key]
|
843 |
+
del state_dict[key]
|
844 |
|
845 |
+
return checkpoint, state_dict
|
846 |
|
847 |
|
848 |
# TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認
|
849 |
+
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None):
|
850 |
+
_, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
851 |
|
852 |
+
# Convert the UNet2DConditionModel model.
|
853 |
+
unet_config = create_unet_diffusers_config(v2)
|
854 |
+
converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)
|
855 |
|
856 |
+
unet = UNet2DConditionModel(**unet_config).to(device)
|
857 |
+
info = unet.load_state_dict(converted_unet_checkpoint)
|
858 |
+
print("loading u-net:", info)
|
859 |
|
860 |
+
# Convert the VAE model.
|
861 |
+
vae_config = create_vae_diffusers_config()
|
862 |
+
converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
863 |
|
864 |
+
vae = AutoencoderKL(**vae_config).to(device)
|
865 |
+
info = vae.load_state_dict(converted_vae_checkpoint)
|
866 |
+
print("loading vae:", info)
|
867 |
|
868 |
+
# convert text_model
|
869 |
+
if v2:
|
870 |
+
converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77)
|
871 |
+
cfg = CLIPTextConfig(
|
872 |
+
vocab_size=49408,
|
873 |
+
hidden_size=1024,
|
874 |
+
intermediate_size=4096,
|
875 |
+
num_hidden_layers=23,
|
876 |
+
num_attention_heads=16,
|
877 |
+
max_position_embeddings=77,
|
878 |
+
hidden_act="gelu",
|
879 |
+
layer_norm_eps=1e-05,
|
880 |
+
dropout=0.0,
|
881 |
+
attention_dropout=0.0,
|
882 |
+
initializer_range=0.02,
|
883 |
+
initializer_factor=1.0,
|
884 |
+
pad_token_id=1,
|
885 |
+
bos_token_id=0,
|
886 |
+
eos_token_id=2,
|
887 |
+
model_type="clip_text_model",
|
888 |
+
projection_dim=512,
|
889 |
+
torch_dtype="float32",
|
890 |
+
transformers_version="4.25.0.dev0",
|
891 |
+
)
|
892 |
+
text_model = CLIPTextModel._from_config(cfg)
|
893 |
+
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
|
894 |
+
else:
|
895 |
+
converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
|
896 |
|
897 |
+
logging.set_verbosity_error() # don't show annoying warning
|
898 |
+
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
|
899 |
+
logging.set_verbosity_warning()
|
900 |
|
901 |
+
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
|
902 |
+
print("loading text encoder:", info)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
903 |
|
904 |
+
return text_model, vae, unet
|
905 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
906 |
|
907 |
+
def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False):
|
908 |
+
def convert_key(key):
|
909 |
+
# position_idsの除去
|
910 |
+
if ".position_ids" in key:
|
911 |
+
return None
|
912 |
+
|
913 |
+
# common
|
914 |
+
key = key.replace("text_model.encoder.", "transformer.")
|
915 |
+
key = key.replace("text_model.", "")
|
916 |
+
if "layers" in key:
|
917 |
+
# resblocks conversion
|
918 |
+
key = key.replace(".layers.", ".resblocks.")
|
919 |
+
if ".layer_norm" in key:
|
920 |
+
key = key.replace(".layer_norm", ".ln_")
|
921 |
+
elif ".mlp." in key:
|
922 |
+
key = key.replace(".fc1.", ".c_fc.")
|
923 |
+
key = key.replace(".fc2.", ".c_proj.")
|
924 |
+
elif ".self_attn.out_proj" in key:
|
925 |
+
key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
|
926 |
+
elif ".self_attn." in key:
|
927 |
+
key = None # 特殊なので後で処理する
|
928 |
+
else:
|
929 |
+
raise ValueError(f"unexpected key in DiffUsers model: {key}")
|
930 |
+
elif ".position_embedding" in key:
|
931 |
+
key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
|
932 |
+
elif ".token_embedding" in key:
|
933 |
+
key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
|
934 |
+
elif "final_layer_norm" in key:
|
935 |
+
key = key.replace("final_layer_norm", "ln_final")
|
936 |
+
return key
|
937 |
+
|
938 |
+
keys = list(checkpoint.keys())
|
939 |
+
new_sd = {}
|
940 |
+
for key in keys:
|
941 |
+
new_key = convert_key(key)
|
942 |
+
if new_key is None:
|
943 |
+
continue
|
944 |
+
new_sd[new_key] = checkpoint[key]
|
945 |
|
946 |
+
# attnの変換
|
947 |
+
for key in keys:
|
948 |
+
if "layers" in key and "q_proj" in key:
|
949 |
+
# 三つを結合
|
950 |
+
key_q = key
|
951 |
+
key_k = key.replace("q_proj", "k_proj")
|
952 |
+
key_v = key.replace("q_proj", "v_proj")
|
953 |
|
954 |
+
value_q = checkpoint[key_q]
|
955 |
+
value_k = checkpoint[key_k]
|
956 |
+
value_v = checkpoint[key_v]
|
957 |
+
value = torch.cat([value_q, value_k, value_v])
|
958 |
|
959 |
+
new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
|
960 |
+
new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
|
961 |
+
new_sd[new_key] = value
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
962 |
|
963 |
+
# 最後の層などを捏造するか
|
964 |
+
if make_dummy_weights:
|
965 |
+
print("make dummy weights for resblock.23, text_projection and logit scale.")
|
966 |
+
keys = list(new_sd.keys())
|
967 |
+
for key in keys:
|
968 |
+
if key.startswith("transformer.resblocks.22."):
|
969 |
+
new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone() # copyしないとsafetensorsの保存で落ちる
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
970 |
|
971 |
+
# Diffusersに含まれない重みを作っておく
|
972 |
+
new_sd["text_projection"] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device)
|
973 |
+
new_sd["logit_scale"] = torch.tensor(1)
|
974 |
|
975 |
+
return new_sd
|
|
|
|
|
976 |
|
|
|
977 |
|
978 |
+
def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None):
|
979 |
+
if ckpt_path is not None:
|
980 |
+
# epoch/stepを参照する。またVAEがメモリ上にないときなど、もう一度VAEを含めて読み込む
|
981 |
+
checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
|
982 |
+
if checkpoint is None: # safetensors または state_dictのckpt
|
983 |
+
checkpoint = {}
|
984 |
+
strict = False
|
985 |
+
else:
|
986 |
+
strict = True
|
987 |
+
if "state_dict" in state_dict:
|
988 |
+
del state_dict["state_dict"]
|
989 |
+
else:
|
990 |
+
# 新しく作る
|
991 |
+
assert vae is not None, "VAE is required to save a checkpoint without a given checkpoint"
|
992 |
+
checkpoint = {}
|
993 |
+
state_dict = {}
|
994 |
+
strict = False
|
995 |
+
|
996 |
+
def update_sd(prefix, sd):
|
997 |
+
for k, v in sd.items():
|
998 |
+
key = prefix + k
|
999 |
+
assert not strict or key in state_dict, f"Illegal key in save SD: {key}"
|
1000 |
+
if save_dtype is not None:
|
1001 |
+
v = v.detach().clone().to("cpu").to(save_dtype)
|
1002 |
+
state_dict[key] = v
|
1003 |
+
|
1004 |
+
# Convert the UNet model
|
1005 |
+
unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict())
|
1006 |
+
update_sd("model.diffusion_model.", unet_state_dict)
|
1007 |
+
|
1008 |
+
# Convert the text encoder model
|
1009 |
+
if v2:
|
1010 |
+
make_dummy = ckpt_path is None # 参照元のcheckpointがない場合は最後の層を前の層から複製して作るなどダミーの重みを入れる
|
1011 |
+
text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy)
|
1012 |
+
update_sd("cond_stage_model.model.", text_enc_dict)
|
1013 |
+
else:
|
1014 |
+
text_enc_dict = text_encoder.state_dict()
|
1015 |
+
update_sd("cond_stage_model.transformer.", text_enc_dict)
|
1016 |
|
1017 |
+
# Convert the VAE
|
1018 |
+
if vae is not None:
|
1019 |
+
vae_dict = convert_vae_state_dict(vae.state_dict())
|
1020 |
+
update_sd("first_stage_model.", vae_dict)
|
1021 |
|
1022 |
+
# Put together new checkpoint
|
1023 |
+
key_count = len(state_dict.keys())
|
1024 |
+
new_ckpt = {"state_dict": state_dict}
|
1025 |
|
1026 |
+
# epoch and global_step are sometimes not int
|
1027 |
+
try:
|
1028 |
+
if "epoch" in checkpoint:
|
1029 |
+
epochs += checkpoint["epoch"]
|
1030 |
+
if "global_step" in checkpoint:
|
1031 |
+
steps += checkpoint["global_step"]
|
1032 |
+
except:
|
1033 |
+
pass
|
1034 |
+
|
1035 |
+
new_ckpt["epoch"] = epochs
|
1036 |
+
new_ckpt["global_step"] = steps
|
1037 |
+
|
1038 |
+
if is_safetensors(output_file):
|
1039 |
+
# TODO Tensor以外のdictの値を削除したほうがいいか
|
1040 |
+
save_file(state_dict, output_file)
|
1041 |
+
else:
|
1042 |
+
torch.save(new_ckpt, output_file)
|
1043 |
|
1044 |
+
return key_count
|
|
|
|
|
|
|
|
|
|
|
1045 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1046 |
|
1047 |
+
def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False):
|
1048 |
+
if pretrained_model_name_or_path is None:
|
1049 |
+
# load default settings for v1/v2
|
1050 |
+
if v2:
|
1051 |
+
pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2
|
1052 |
+
else:
|
1053 |
+
pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1
|
1054 |
+
|
1055 |
+
scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
|
1056 |
+
tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
|
1057 |
+
if vae is None:
|
1058 |
+
vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
|
1059 |
+
|
1060 |
+
pipeline = StableDiffusionPipeline(
|
1061 |
+
unet=unet,
|
1062 |
+
text_encoder=text_encoder,
|
1063 |
+
vae=vae,
|
1064 |
+
scheduler=scheduler,
|
1065 |
+
tokenizer=tokenizer,
|
1066 |
+
safety_checker=None,
|
1067 |
+
feature_extractor=None,
|
1068 |
+
requires_safety_checker=None,
|
1069 |
+
)
|
1070 |
+
pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors)
|
1071 |
|
|
|
|
|
1072 |
|
1073 |
+
VAE_PREFIX = "first_stage_model."
|
|
|
1074 |
|
1075 |
|
1076 |
+
def load_vae(vae_id, dtype):
|
1077 |
+
print(f"load VAE: {vae_id}")
|
1078 |
+
if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
|
1079 |
+
# Diffusers local/remote
|
1080 |
+
try:
|
1081 |
+
vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype)
|
1082 |
+
except EnvironmentError as e:
|
1083 |
+
print(f"exception occurs in loading vae: {e}")
|
1084 |
+
print("retry with subfolder='vae'")
|
1085 |
+
vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype)
|
1086 |
+
return vae
|
1087 |
+
|
1088 |
+
# local
|
1089 |
+
vae_config = create_vae_diffusers_config()
|
1090 |
+
|
1091 |
+
if vae_id.endswith(".bin"):
|
1092 |
+
# SD 1.5 VAE on Huggingface
|
1093 |
+
converted_vae_checkpoint = torch.load(vae_id, map_location="cpu")
|
1094 |
+
else:
|
1095 |
+
# StableDiffusion
|
1096 |
+
vae_model = load_file(vae_id, "cpu") if is_safetensors(vae_id) else torch.load(vae_id, map_location="cpu")
|
1097 |
+
vae_sd = vae_model["state_dict"] if "state_dict" in vae_model else vae_model
|
1098 |
+
|
1099 |
+
# vae only or full model
|
1100 |
+
full_model = False
|
1101 |
+
for vae_key in vae_sd:
|
1102 |
+
if vae_key.startswith(VAE_PREFIX):
|
1103 |
+
full_model = True
|
1104 |
+
break
|
1105 |
+
if not full_model:
|
1106 |
+
sd = {}
|
1107 |
+
for key, value in vae_sd.items():
|
1108 |
+
sd[VAE_PREFIX + key] = value
|
1109 |
+
vae_sd = sd
|
1110 |
+
del sd
|
1111 |
+
|
1112 |
+
# Convert the VAE model.
|
1113 |
+
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
|
1114 |
+
|
1115 |
+
vae = AutoencoderKL(**vae_config)
|
1116 |
+
vae.load_state_dict(converted_vae_checkpoint)
|
1117 |
+
return vae
|
1118 |
|
1119 |
+
|
1120 |
+
# endregion
|
1121 |
+
|
1122 |
+
|
1123 |
+
def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
|
1124 |
+
max_width, max_height = max_reso
|
1125 |
+
max_area = (max_width // divisible) * (max_height // divisible)
|
1126 |
+
|
1127 |
+
resos = set()
|
1128 |
+
|
1129 |
+
size = int(math.sqrt(max_area)) * divisible
|
1130 |
+
resos.add((size, size))
|
1131 |
+
|
1132 |
+
size = min_size
|
1133 |
+
while size <= max_size:
|
1134 |
+
width = size
|
1135 |
+
height = min(max_size, (max_area // (width // divisible)) * divisible)
|
1136 |
+
resos.add((width, height))
|
1137 |
+
resos.add((height, width))
|
1138 |
+
|
1139 |
+
# # make additional resos
|
1140 |
+
# if width >= height and width - divisible >= min_size:
|
1141 |
+
# resos.add((width - divisible, height))
|
1142 |
+
# resos.add((height, width - divisible))
|
1143 |
+
# if height >= width and height - divisible >= min_size:
|
1144 |
+
# resos.add((width, height - divisible))
|
1145 |
+
# resos.add((height - divisible, width))
|
1146 |
+
|
1147 |
+
size += divisible
|
1148 |
+
|
1149 |
+
resos = list(resos)
|
1150 |
+
resos.sort()
|
1151 |
+
return resos
|
1152 |
+
|
1153 |
+
|
1154 |
+
if __name__ == "__main__":
|
1155 |
+
resos = make_bucket_resolutions((512, 768))
|
1156 |
+
print(len(resos))
|
1157 |
+
print(resos)
|
1158 |
+
aspect_ratios = [w / h for w, h in resos]
|
1159 |
+
print(aspect_ratios)
|
1160 |
+
|
1161 |
+
ars = set()
|
1162 |
+
for ar in aspect_ratios:
|
1163 |
+
if ar in ars:
|
1164 |
+
print("error! duplicate ar:", ar)
|
1165 |
+
ars.add(ar)
|
lycoris/locon.py
CHANGED
@@ -16,7 +16,8 @@ class LoConModule(nn.Module):
|
|
16 |
multiplier=1.0,
|
17 |
lora_dim=4, alpha=1,
|
18 |
dropout=0.,
|
19 |
-
use_cp=
|
|
|
20 |
):
|
21 |
""" if alpha == 0 or None, alpha is rank (no scaling). """
|
22 |
super().__init__()
|
|
|
16 |
multiplier=1.0,
|
17 |
lora_dim=4, alpha=1,
|
18 |
dropout=0.,
|
19 |
+
use_cp=False,
|
20 |
+
**kwargs,
|
21 |
):
|
22 |
""" if alpha == 0 or None, alpha is rank (no scaling). """
|
23 |
super().__init__()
|
lycoris/loha.py
CHANGED
@@ -92,7 +92,8 @@ class LohaModule(nn.Module):
|
|
92 |
lora_name,
|
93 |
org_module: nn.Module,
|
94 |
multiplier=1.0, lora_dim=4, alpha=1, dropout=0.,
|
95 |
-
use_cp=
|
|
|
96 |
):
|
97 |
""" if alpha == 0 or None, alpha is rank (no scaling). """
|
98 |
super().__init__()
|
|
|
92 |
lora_name,
|
93 |
org_module: nn.Module,
|
94 |
multiplier=1.0, lora_dim=4, alpha=1, dropout=0.,
|
95 |
+
use_cp=False,
|
96 |
+
**kwargs,
|
97 |
):
|
98 |
""" if alpha == 0 or None, alpha is rank (no scaling). """
|
99 |
super().__init__()
|
lycoris/lokr.py
ADDED
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
# 4, build custom backward function
|
8 |
+
# -
|
9 |
+
|
10 |
+
|
11 |
+
def factorization(dimension: int, factor:int=-1) -> tuple[int, int]:
|
12 |
+
'''
|
13 |
+
return a tuple of two value of input dimension decomposed by the number closest to factor
|
14 |
+
second value is higher or equal than first value.
|
15 |
+
|
16 |
+
In LoRA with Kroneckor Product, first value is a value for weight scale.
|
17 |
+
secon value is a value for weight.
|
18 |
+
|
19 |
+
Becuase of non-commutative property, A⊗B ≠ B⊗A. Meaning of two matrices is slightly different.
|
20 |
+
|
21 |
+
examples)
|
22 |
+
factor
|
23 |
+
-1 2 4 8 16 ...
|
24 |
+
127 -> 127, 1 127 -> 127, 1 127 -> 127, 1 127 -> 127, 1 127 -> 127, 1
|
25 |
+
128 -> 16, 8 128 -> 64, 2 128 -> 32, 4 128 -> 16, 8 128 -> 16, 8
|
26 |
+
250 -> 125, 2 250 -> 125, 2 250 -> 125, 2 250 -> 125, 2 250 -> 125, 2
|
27 |
+
360 -> 45, 8 360 -> 180, 2 360 -> 90, 4 360 -> 45, 8 360 -> 45, 8
|
28 |
+
512 -> 32, 16 512 -> 256, 2 512 -> 128, 4 512 -> 64, 8 512 -> 32, 16
|
29 |
+
1024 -> 32, 32 1024 -> 512, 2 1024 -> 256, 4 1024 -> 128, 8 1024 -> 64, 16
|
30 |
+
'''
|
31 |
+
|
32 |
+
if factor > 0 and (dimension % factor) == 0:
|
33 |
+
m = factor
|
34 |
+
n = dimension // factor
|
35 |
+
return m, n
|
36 |
+
if factor == -1:
|
37 |
+
factor = dimension
|
38 |
+
m, n = 1, dimension
|
39 |
+
length = m + n
|
40 |
+
while m<n:
|
41 |
+
new_m = m + 1
|
42 |
+
while dimension%new_m != 0:
|
43 |
+
new_m += 1
|
44 |
+
new_n = dimension // new_m
|
45 |
+
if new_m + new_n > length or new_m>factor:
|
46 |
+
break
|
47 |
+
else:
|
48 |
+
m, n = new_m, new_n
|
49 |
+
if m > n:
|
50 |
+
n, m = m, n
|
51 |
+
return m, n
|
52 |
+
|
53 |
+
|
54 |
+
def make_weight_cp(t, wa, wb):
|
55 |
+
rebuild2 = torch.einsum('i j k l, i p, j r -> p r k l', t, wa, wb) # [c, d, k1, k2]
|
56 |
+
return rebuild2
|
57 |
+
|
58 |
+
|
59 |
+
def make_kron(orig_weight, w1, w2, scale):
|
60 |
+
if len(w2.shape) == 4:
|
61 |
+
w1 = w1.unsqueeze(2).unsqueeze(2)
|
62 |
+
w2 = w2.contiguous()
|
63 |
+
return orig_weight + torch.kron(w1, w2).reshape(orig_weight.shape)*scale
|
64 |
+
|
65 |
+
|
66 |
+
class LokrModule(nn.Module):
|
67 |
+
"""
|
68 |
+
modifed from kohya-ss/sd-scripts/networks/lora:LoRAModule
|
69 |
+
and from KohakuBlueleaf/LyCORIS/lycoris:loha:LoHaModule
|
70 |
+
and from KohakuBlueleaf/LyCORIS/lycoris:locon:LoconModule
|
71 |
+
"""
|
72 |
+
|
73 |
+
def __init__(
|
74 |
+
self,
|
75 |
+
lora_name, org_module: nn.Module,
|
76 |
+
multiplier=1.0,
|
77 |
+
lora_dim=4, alpha=1,
|
78 |
+
dropout=0.,
|
79 |
+
use_cp=False,
|
80 |
+
decompose_both = False,
|
81 |
+
factor:int=-1, # factorization factor
|
82 |
+
**kwargs,
|
83 |
+
):
|
84 |
+
""" if alpha == 0 or None, alpha is rank (no scaling). """
|
85 |
+
super().__init__()
|
86 |
+
factor = int(factor)
|
87 |
+
self.lora_name = lora_name
|
88 |
+
self.lora_dim = lora_dim
|
89 |
+
self.cp = False
|
90 |
+
self.use_w1 = False
|
91 |
+
self.use_w2 = False
|
92 |
+
|
93 |
+
self.shape = org_module.weight.shape
|
94 |
+
if org_module.__class__.__name__ == 'Conv2d':
|
95 |
+
in_dim = org_module.in_channels
|
96 |
+
k_size = org_module.kernel_size
|
97 |
+
out_dim = org_module.out_channels
|
98 |
+
|
99 |
+
in_m, in_n = factorization(in_dim, factor)
|
100 |
+
out_l, out_k = factorization(out_dim, factor)
|
101 |
+
shape = ((out_l, out_k), (in_m, in_n), *k_size) # ((a, b), (c, d), *k_size)
|
102 |
+
|
103 |
+
self.cp = use_cp and k_size!=(1, 1)
|
104 |
+
if decompose_both and lora_dim < max(shape[0][0], shape[1][0])/2:
|
105 |
+
self.lokr_w1_a = nn.Parameter(torch.empty(shape[0][0], lora_dim))
|
106 |
+
self.lokr_w1_b = nn.Parameter(torch.empty(lora_dim, shape[1][0]))
|
107 |
+
else:
|
108 |
+
self.use_w1 = True
|
109 |
+
self.lokr_w1 = nn.Parameter(torch.empty(shape[0][0], shape[1][0])) # a*c, 1-mode
|
110 |
+
|
111 |
+
if lora_dim >= max(shape[0][1], shape[1][1])/2:
|
112 |
+
self.use_w2 = True
|
113 |
+
self.lokr_w2 = nn.Parameter(torch.empty(shape[0][1], shape[1][1], *k_size))
|
114 |
+
elif self.cp:
|
115 |
+
self.lokr_t2 = nn.Parameter(torch.empty(lora_dim, lora_dim, shape[2], shape[3]))
|
116 |
+
self.lokr_w2_a = nn.Parameter(torch.empty(lora_dim, shape[0][1])) # b, 1-mode
|
117 |
+
self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, shape[1][1])) # d, 2-mode
|
118 |
+
else: # Conv2d not cp
|
119 |
+
# bigger part. weight and LoRA. [b, dim] x [dim, d*k1*k2]
|
120 |
+
self.lokr_w2_a = nn.Parameter(torch.empty(shape[0][1], lora_dim))
|
121 |
+
self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, shape[1][1]*shape[2]*shape[3]))
|
122 |
+
# w1 ⊗ (w2_a x w2_b) = (a, b)⊗((c, dim)x(dim, d*k1*k2)) = (a, b)⊗(c, d*k1*k2) = (ac, bd*k1*k2)
|
123 |
+
|
124 |
+
self.op = F.conv2d
|
125 |
+
self.extra_args = {
|
126 |
+
"stride": org_module.stride,
|
127 |
+
"padding": org_module.padding,
|
128 |
+
"dilation": org_module.dilation,
|
129 |
+
"groups": org_module.groups
|
130 |
+
}
|
131 |
+
|
132 |
+
else: # Linear
|
133 |
+
in_dim = org_module.in_features
|
134 |
+
out_dim = org_module.out_features
|
135 |
+
|
136 |
+
in_m, in_n = factorization(in_dim, factor)
|
137 |
+
out_l, out_k = factorization(out_dim, factor)
|
138 |
+
shape = ((out_l, out_k), (in_m, in_n)) # ((a, b), (c, d)), out_dim = a*c, in_dim = b*d
|
139 |
+
|
140 |
+
# smaller part. weight scale
|
141 |
+
if decompose_both and lora_dim < max(shape[0][0], shape[1][0])/2:
|
142 |
+
self.lokr_w1_a = nn.Parameter(torch.empty(shape[0][0], lora_dim))
|
143 |
+
self.lokr_w1_b = nn.Parameter(torch.empty(lora_dim, shape[1][0]))
|
144 |
+
else:
|
145 |
+
self.use_w1 = True
|
146 |
+
self.lokr_w1 = nn.Parameter(torch.empty(shape[0][0], shape[1][0])) # a*c, 1-mode
|
147 |
+
|
148 |
+
if lora_dim < max(shape[0][1], shape[1][1])/2:
|
149 |
+
# bigger part. weight and LoRA. [b, dim] x [dim, d]
|
150 |
+
self.lokr_w2_a = nn.Parameter(torch.empty(shape[0][1], lora_dim))
|
151 |
+
self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, shape[1][1]))
|
152 |
+
# w1 ⊗ (w2_a x w2_b) = (a, b)⊗((c, dim)x(dim, d)) = (a, b)⊗(c, d) = (ac, bd)
|
153 |
+
else:
|
154 |
+
self.use_w2 = True
|
155 |
+
self.lokr_w2 = nn.Parameter(torch.empty(shape[0][1], shape[1][1]))
|
156 |
+
|
157 |
+
self.op = F.linear
|
158 |
+
self.extra_args = {}
|
159 |
+
|
160 |
+
if dropout:
|
161 |
+
self.dropout = nn.Dropout(dropout)
|
162 |
+
else:
|
163 |
+
self.dropout = nn.Identity()
|
164 |
+
|
165 |
+
if isinstance(alpha, torch.Tensor):
|
166 |
+
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
|
167 |
+
alpha = lora_dim if alpha is None or alpha == 0 else alpha
|
168 |
+
if self.use_w2 and self.use_w1:
|
169 |
+
#use scale = 1
|
170 |
+
alpha = lora_dim
|
171 |
+
self.scale = alpha / self.lora_dim
|
172 |
+
self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える
|
173 |
+
|
174 |
+
if self.use_w2:
|
175 |
+
torch.nn.init.constant_(self.lokr_w2, 0)
|
176 |
+
else:
|
177 |
+
if self.cp:
|
178 |
+
torch.nn.init.normal_(self.lokr_t2, std=0.1)
|
179 |
+
torch.nn.init.normal_(self.lokr_w2_a, std=1)
|
180 |
+
torch.nn.init.constant_(self.lokr_w2_b, 0)
|
181 |
+
|
182 |
+
if self.use_w1:
|
183 |
+
torch.nn.init.normal_(self.lokr_w1, std=1)
|
184 |
+
else:
|
185 |
+
torch.nn.init.normal_(self.lokr_w1_a, std=1)
|
186 |
+
torch.nn.init.normal_(self.lokr_w1_b, std=0.1)
|
187 |
+
|
188 |
+
self.multiplier = multiplier
|
189 |
+
self.org_module = [org_module]
|
190 |
+
weight = make_kron(
|
191 |
+
self.org_module[0].weight.data,
|
192 |
+
self.lokr_w1 if self.use_w1 else [email protected]_w1_b,
|
193 |
+
(self.lokr_w2 if self.use_w2
|
194 |
+
else make_weight_cp(self.lokr_t2, self.lokr_w2_a, self.lokr_w2_b) if self.cp
|
195 |
+
else [email protected]_w2_b),
|
196 |
+
torch.tensor(self.multiplier * self.scale)
|
197 |
+
)
|
198 |
+
assert torch.sum(torch.isnan(weight)) == 0, "weight is nan"
|
199 |
+
|
200 |
+
# Same as locon.py
|
201 |
+
def apply_to(self):
|
202 |
+
self.org_forward = self.org_module[0].forward
|
203 |
+
self.org_module[0].forward = self.forward
|
204 |
+
|
205 |
+
def forward(self, x):
|
206 |
+
weight = make_kron(
|
207 |
+
self.org_module[0].weight.data,
|
208 |
+
self.lokr_w1 if self.use_w1 else [email protected]_w1_b,
|
209 |
+
(self.lokr_w2 if self.use_w2
|
210 |
+
else make_weight_cp(self.lokr_t2, self.lokr_w2_a, self.lokr_w2_b) if self.cp
|
211 |
+
else [email protected]_w2_b),
|
212 |
+
torch.tensor(self.multiplier * self.scale)
|
213 |
+
)
|
214 |
+
bias = None if self.org_module[0].bias is None else self.org_module[0].bias.data
|
215 |
+
return self.op(
|
216 |
+
x,
|
217 |
+
weight.view(self.shape),
|
218 |
+
bias,
|
219 |
+
**self.extra_args
|
220 |
+
)
|
lycoris/utils.py
CHANGED
@@ -24,6 +24,7 @@ def extract_conv(
|
|
24 |
mode = 'fixed',
|
25 |
mode_param = 0,
|
26 |
device = 'cpu',
|
|
|
27 |
) -> Tuple[nn.Parameter, nn.Parameter]:
|
28 |
weight = weight.to(device)
|
29 |
out_ch, in_ch, kernel_size, _ = weight.shape
|
@@ -48,6 +49,8 @@ def extract_conv(
|
|
48 |
raise NotImplementedError('Extract mode should be "fixed", "threshold", "ratio" or "quantile"')
|
49 |
lora_rank = max(1, lora_rank)
|
50 |
lora_rank = min(out_ch, in_ch, lora_rank)
|
|
|
|
|
51 |
|
52 |
U = U[:, :lora_rank]
|
53 |
S = S[:lora_rank]
|
@@ -58,29 +61,7 @@ def extract_conv(
|
|
58 |
extract_weight_A = Vh.reshape(lora_rank, in_ch, kernel_size, kernel_size).detach()
|
59 |
extract_weight_B = U.reshape(out_ch, lora_rank, 1, 1).detach()
|
60 |
del U, S, Vh, weight
|
61 |
-
return extract_weight_A, extract_weight_B, diff
|
62 |
-
|
63 |
-
|
64 |
-
def merge_conv(
|
65 |
-
weight_a: Union[torch.Tensor, nn.Parameter],
|
66 |
-
weight_b: Union[torch.Tensor, nn.Parameter],
|
67 |
-
device = 'cpu'
|
68 |
-
):
|
69 |
-
rank, in_ch, kernel_size, k_ = weight_a.shape
|
70 |
-
out_ch, rank_, _, _ = weight_b.shape
|
71 |
-
assert rank == rank_ and kernel_size == k_
|
72 |
-
|
73 |
-
wa = weight_a.to(device)
|
74 |
-
wb = weight_b.to(device)
|
75 |
-
|
76 |
-
if device == 'cpu':
|
77 |
-
wa = wa.float()
|
78 |
-
wb = wb.float()
|
79 |
-
|
80 |
-
merged = wb.reshape(out_ch, -1) @ wa.reshape(rank, -1)
|
81 |
-
weight = merged.reshape(out_ch, in_ch, kernel_size, kernel_size)
|
82 |
-
del wb, wa
|
83 |
-
return weight
|
84 |
|
85 |
|
86 |
def extract_linear(
|
@@ -112,6 +93,8 @@ def extract_linear(
|
|
112 |
raise NotImplementedError('Extract mode should be "fixed", "threshold", "ratio" or "quantile"')
|
113 |
lora_rank = max(1, lora_rank)
|
114 |
lora_rank = min(out_ch, in_ch, lora_rank)
|
|
|
|
|
115 |
|
116 |
U = U[:, :lora_rank]
|
117 |
S = S[:lora_rank]
|
@@ -122,28 +105,7 @@ def extract_linear(
|
|
122 |
extract_weight_A = Vh.reshape(lora_rank, in_ch).detach()
|
123 |
extract_weight_B = U.reshape(out_ch, lora_rank).detach()
|
124 |
del U, S, Vh, weight
|
125 |
-
return extract_weight_A, extract_weight_B, diff
|
126 |
-
|
127 |
-
|
128 |
-
def merge_linear(
|
129 |
-
weight_a: Union[torch.Tensor, nn.Parameter],
|
130 |
-
weight_b: Union[torch.Tensor, nn.Parameter],
|
131 |
-
device = 'cpu'
|
132 |
-
):
|
133 |
-
rank, in_ch = weight_a.shape
|
134 |
-
out_ch, rank_ = weight_b.shape
|
135 |
-
assert rank == rank_
|
136 |
-
|
137 |
-
wa = weight_a.to(device)
|
138 |
-
wb = weight_b.to(device)
|
139 |
-
|
140 |
-
if device == 'cpu':
|
141 |
-
wa = wa.float()
|
142 |
-
wb = wb.float()
|
143 |
-
|
144 |
-
weight = wb @ wa
|
145 |
-
del wb, wa
|
146 |
-
return weight
|
147 |
|
148 |
|
149 |
def extract_diff(
|
@@ -200,30 +162,38 @@ def extract_diff(
|
|
200 |
for child_name, child_module in module.named_modules():
|
201 |
lora_name = prefix + '.' + name + '.' + child_name
|
202 |
lora_name = lora_name.replace('.', '_')
|
203 |
-
|
204 |
layer = child_module.__class__.__name__
|
|
|
|
|
|
|
|
|
|
|
205 |
if layer == 'Linear':
|
206 |
-
|
207 |
(child_module.weight - weights[child_name]),
|
208 |
mode,
|
209 |
linear_mode_param,
|
210 |
device = extract_device,
|
211 |
)
|
|
|
|
|
212 |
elif layer == 'Conv2d':
|
213 |
is_linear = (child_module.weight.shape[2] == 1
|
214 |
and child_module.weight.shape[3] == 1)
|
215 |
-
|
216 |
(child_module.weight - weights[child_name]),
|
217 |
mode,
|
218 |
linear_mode_param if is_linear else conv_mode_param,
|
219 |
device = extract_device,
|
220 |
)
|
221 |
-
if
|
|
|
|
|
222 |
dim = extract_a.size(0)
|
223 |
-
extract_c, extract_a, _ = extract_conv(
|
224 |
extract_a.transpose(0, 1),
|
225 |
'fixed', dim,
|
226 |
-
extract_device
|
227 |
)
|
228 |
extract_a = extract_a.transpose(0, 1)
|
229 |
extract_c = extract_c.transpose(0, 1)
|
@@ -235,77 +205,92 @@ def extract_diff(
|
|
235 |
del extract_c
|
236 |
else:
|
237 |
continue
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
|
|
|
|
|
|
|
|
252 |
elif name in temp_name:
|
253 |
-
|
254 |
lora_name = prefix + '.' + name
|
255 |
lora_name = lora_name.replace('.', '_')
|
|
|
256 |
|
257 |
-
if
|
258 |
-
|
259 |
-
|
|
|
260 |
|
261 |
-
layer = module.__class__.__name__
|
262 |
if layer == 'Linear':
|
263 |
-
|
264 |
-
(
|
265 |
mode,
|
266 |
linear_mode_param,
|
267 |
device = extract_device,
|
268 |
)
|
|
|
|
|
269 |
elif layer == 'Conv2d':
|
270 |
-
is_linear = (
|
271 |
-
|
272 |
-
|
273 |
-
|
|
|
|
|
274 |
mode,
|
275 |
linear_mode_param if is_linear else conv_mode_param,
|
276 |
device = extract_device,
|
277 |
)
|
278 |
-
if
|
|
|
|
|
279 |
dim = extract_a.size(0)
|
280 |
-
extract_c, extract_a, _ = extract_conv(
|
281 |
extract_a.transpose(0, 1),
|
282 |
'fixed', dim,
|
283 |
-
extract_device
|
284 |
)
|
285 |
extract_a = extract_a.transpose(0, 1)
|
286 |
extract_c = extract_c.transpose(0, 1)
|
287 |
loras[f'{lora_name}.lora_mid.weight'] = extract_c.detach().cpu().contiguous().half()
|
288 |
-
diff =
|
289 |
'i j k l, j r, p i -> p r k l',
|
290 |
extract_c, extract_a.flatten(1, -1), extract_b.flatten(1, -1)
|
291 |
).detach().cpu().contiguous()
|
292 |
del extract_c
|
293 |
else:
|
294 |
continue
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
|
|
|
|
|
|
|
|
309 |
return loras
|
310 |
|
311 |
text_encoder_loras = make_state_dict(
|
@@ -324,70 +309,125 @@ def extract_diff(
|
|
324 |
return text_encoder_loras|unet_loras
|
325 |
|
326 |
|
327 |
-
def
|
328 |
-
|
329 |
-
|
330 |
-
scale: float = 1.0,
|
331 |
-
device = 'cpu'
|
332 |
):
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
-
|
365 |
-
|
366 |
-
|
367 |
-
|
368 |
-
|
369 |
-
|
370 |
-
|
371 |
-
|
372 |
-
|
373 |
-
|
374 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
375 |
|
376 |
-
|
377 |
-
LORA_PREFIX_TEXT_ENCODER,
|
378 |
-
base_model[0],
|
379 |
-
TEXT_ENCODER_TARGET_REPLACE_MODULE
|
380 |
-
)
|
381 |
-
merge(
|
382 |
-
LORA_PREFIX_UNET,
|
383 |
-
base_model[2],
|
384 |
-
UNET_TARGET_REPLACE_MODULE
|
385 |
-
)
|
386 |
|
387 |
|
388 |
-
def
|
389 |
base_model,
|
390 |
-
|
391 |
scale: float = 1.0,
|
392 |
device = 'cpu'
|
393 |
):
|
@@ -398,51 +438,67 @@ def merge_loha(
|
|
398 |
"Downsample2D",
|
399 |
"Upsample2D"
|
400 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
401 |
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
402 |
LORA_PREFIX_UNET = 'lora_unet'
|
403 |
LORA_PREFIX_TEXT_ENCODER = 'lora_te'
|
404 |
-
|
|
|
405 |
prefix,
|
406 |
root_module: torch.nn.Module,
|
407 |
-
|
|
|
|
|
408 |
):
|
409 |
-
|
410 |
-
|
411 |
-
for name, module in tqdm(list(root_module.named_modules())):
|
412 |
if module.__class__.__name__ in target_replace_modules:
|
413 |
-
temp[name] = {}
|
414 |
for child_name, child_module in module.named_modules():
|
415 |
-
|
416 |
-
if layer not in {'Linear', 'Conv2d'}:
|
417 |
continue
|
418 |
lora_name = prefix + '.' + name + '.' + child_name
|
419 |
lora_name = lora_name.replace('.', '_')
|
420 |
|
421 |
-
|
422 |
-
|
423 |
-
|
424 |
-
|
425 |
-
|
426 |
-
|
427 |
-
|
428 |
-
|
429 |
-
|
|
|
430 |
|
431 |
-
|
432 |
-
|
433 |
-
|
434 |
-
|
435 |
-
|
436 |
-
|
437 |
-
|
438 |
|
439 |
-
|
440 |
-
|
441 |
-
|
442 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
443 |
)
|
444 |
-
|
445 |
LORA_PREFIX_UNET,
|
446 |
-
base_model[2],
|
447 |
-
|
448 |
-
|
|
|
|
|
|
|
|
24 |
mode = 'fixed',
|
25 |
mode_param = 0,
|
26 |
device = 'cpu',
|
27 |
+
is_cp = False,
|
28 |
) -> Tuple[nn.Parameter, nn.Parameter]:
|
29 |
weight = weight.to(device)
|
30 |
out_ch, in_ch, kernel_size, _ = weight.shape
|
|
|
49 |
raise NotImplementedError('Extract mode should be "fixed", "threshold", "ratio" or "quantile"')
|
50 |
lora_rank = max(1, lora_rank)
|
51 |
lora_rank = min(out_ch, in_ch, lora_rank)
|
52 |
+
if lora_rank>=out_ch/2 and not is_cp:
|
53 |
+
return weight, 'full'
|
54 |
|
55 |
U = U[:, :lora_rank]
|
56 |
S = S[:lora_rank]
|
|
|
61 |
extract_weight_A = Vh.reshape(lora_rank, in_ch, kernel_size, kernel_size).detach()
|
62 |
extract_weight_B = U.reshape(out_ch, lora_rank, 1, 1).detach()
|
63 |
del U, S, Vh, weight
|
64 |
+
return (extract_weight_A, extract_weight_B, diff), 'low rank'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
|
66 |
|
67 |
def extract_linear(
|
|
|
93 |
raise NotImplementedError('Extract mode should be "fixed", "threshold", "ratio" or "quantile"')
|
94 |
lora_rank = max(1, lora_rank)
|
95 |
lora_rank = min(out_ch, in_ch, lora_rank)
|
96 |
+
if lora_rank>=out_ch/2:
|
97 |
+
return weight, 'full'
|
98 |
|
99 |
U = U[:, :lora_rank]
|
100 |
S = S[:lora_rank]
|
|
|
105 |
extract_weight_A = Vh.reshape(lora_rank, in_ch).detach()
|
106 |
extract_weight_B = U.reshape(out_ch, lora_rank).detach()
|
107 |
del U, S, Vh, weight
|
108 |
+
return (extract_weight_A, extract_weight_B, diff), 'low rank'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
109 |
|
110 |
|
111 |
def extract_diff(
|
|
|
162 |
for child_name, child_module in module.named_modules():
|
163 |
lora_name = prefix + '.' + name + '.' + child_name
|
164 |
lora_name = lora_name.replace('.', '_')
|
|
|
165 |
layer = child_module.__class__.__name__
|
166 |
+
if layer in {'Linear', 'Conv2d'}:
|
167 |
+
root_weight = child_module.weight
|
168 |
+
if torch.allclose(root_weight, weights[child_name]):
|
169 |
+
continue
|
170 |
+
|
171 |
if layer == 'Linear':
|
172 |
+
weight, decompose_mode = extract_linear(
|
173 |
(child_module.weight - weights[child_name]),
|
174 |
mode,
|
175 |
linear_mode_param,
|
176 |
device = extract_device,
|
177 |
)
|
178 |
+
if decompose_mode == 'low rank':
|
179 |
+
extract_a, extract_b, diff = weight
|
180 |
elif layer == 'Conv2d':
|
181 |
is_linear = (child_module.weight.shape[2] == 1
|
182 |
and child_module.weight.shape[3] == 1)
|
183 |
+
weight, decompose_mode = extract_conv(
|
184 |
(child_module.weight - weights[child_name]),
|
185 |
mode,
|
186 |
linear_mode_param if is_linear else conv_mode_param,
|
187 |
device = extract_device,
|
188 |
)
|
189 |
+
if decompose_mode == 'low rank':
|
190 |
+
extract_a, extract_b, diff = weight
|
191 |
+
if small_conv and not is_linear and decompose_mode == 'low rank':
|
192 |
dim = extract_a.size(0)
|
193 |
+
(extract_c, extract_a, _), _ = extract_conv(
|
194 |
extract_a.transpose(0, 1),
|
195 |
'fixed', dim,
|
196 |
+
extract_device, True
|
197 |
)
|
198 |
extract_a = extract_a.transpose(0, 1)
|
199 |
extract_c = extract_c.transpose(0, 1)
|
|
|
205 |
del extract_c
|
206 |
else:
|
207 |
continue
|
208 |
+
if decompose_mode == 'low rank':
|
209 |
+
loras[f'{lora_name}.lora_down.weight'] = extract_a.detach().cpu().contiguous().half()
|
210 |
+
loras[f'{lora_name}.lora_up.weight'] = extract_b.detach().cpu().contiguous().half()
|
211 |
+
loras[f'{lora_name}.alpha'] = torch.Tensor([extract_a.shape[0]]).half()
|
212 |
+
if use_bias:
|
213 |
+
diff = diff.detach().cpu().reshape(extract_b.size(0), -1)
|
214 |
+
sparse_diff = make_sparse(diff, sparsity).to_sparse().coalesce()
|
215 |
+
|
216 |
+
indices = sparse_diff.indices().to(torch.int16)
|
217 |
+
values = sparse_diff.values().half()
|
218 |
+
loras[f'{lora_name}.bias_indices'] = indices
|
219 |
+
loras[f'{lora_name}.bias_values'] = values
|
220 |
+
loras[f'{lora_name}.bias_size'] = torch.tensor(diff.shape).to(torch.int16)
|
221 |
+
del extract_a, extract_b, diff
|
222 |
+
elif decompose_mode == 'full':
|
223 |
+
loras[f'{lora_name}.diff'] = weight.detach().cpu().contiguous().half()
|
224 |
+
else:
|
225 |
+
raise NotImplementedError
|
226 |
elif name in temp_name:
|
227 |
+
weights = temp_name[name]
|
228 |
lora_name = prefix + '.' + name
|
229 |
lora_name = lora_name.replace('.', '_')
|
230 |
+
layer = module.__class__.__name__
|
231 |
|
232 |
+
if layer in {'Linear', 'Conv2d'}:
|
233 |
+
root_weight = module.weight
|
234 |
+
if torch.allclose(root_weight, weights):
|
235 |
+
continue
|
236 |
|
|
|
237 |
if layer == 'Linear':
|
238 |
+
weight, decompose_mode = extract_linear(
|
239 |
+
(root_weight - weights),
|
240 |
mode,
|
241 |
linear_mode_param,
|
242 |
device = extract_device,
|
243 |
)
|
244 |
+
if decompose_mode == 'low rank':
|
245 |
+
extract_a, extract_b, diff = weight
|
246 |
elif layer == 'Conv2d':
|
247 |
+
is_linear = (
|
248 |
+
root_weight.shape[2] == 1
|
249 |
+
and root_weight.shape[3] == 1
|
250 |
+
)
|
251 |
+
weight, decompose_mode = extract_conv(
|
252 |
+
(root_weight - weights),
|
253 |
mode,
|
254 |
linear_mode_param if is_linear else conv_mode_param,
|
255 |
device = extract_device,
|
256 |
)
|
257 |
+
if decompose_mode == 'low rank':
|
258 |
+
extract_a, extract_b, diff = weight
|
259 |
+
if small_conv and not is_linear and decompose_mode == 'low rank':
|
260 |
dim = extract_a.size(0)
|
261 |
+
(extract_c, extract_a, _), _ = extract_conv(
|
262 |
extract_a.transpose(0, 1),
|
263 |
'fixed', dim,
|
264 |
+
extract_device, True
|
265 |
)
|
266 |
extract_a = extract_a.transpose(0, 1)
|
267 |
extract_c = extract_c.transpose(0, 1)
|
268 |
loras[f'{lora_name}.lora_mid.weight'] = extract_c.detach().cpu().contiguous().half()
|
269 |
+
diff = root_weight - torch.einsum(
|
270 |
'i j k l, j r, p i -> p r k l',
|
271 |
extract_c, extract_a.flatten(1, -1), extract_b.flatten(1, -1)
|
272 |
).detach().cpu().contiguous()
|
273 |
del extract_c
|
274 |
else:
|
275 |
continue
|
276 |
+
if decompose_mode == 'low rank':
|
277 |
+
loras[f'{lora_name}.lora_down.weight'] = extract_a.detach().cpu().contiguous().half()
|
278 |
+
loras[f'{lora_name}.lora_up.weight'] = extract_b.detach().cpu().contiguous().half()
|
279 |
+
loras[f'{lora_name}.alpha'] = torch.Tensor([extract_a.shape[0]]).half()
|
280 |
+
if use_bias:
|
281 |
+
diff = diff.detach().cpu().reshape(extract_b.size(0), -1)
|
282 |
+
sparse_diff = make_sparse(diff, sparsity).to_sparse().coalesce()
|
283 |
+
|
284 |
+
indices = sparse_diff.indices().to(torch.int16)
|
285 |
+
values = sparse_diff.values().half()
|
286 |
+
loras[f'{lora_name}.bias_indices'] = indices
|
287 |
+
loras[f'{lora_name}.bias_values'] = values
|
288 |
+
loras[f'{lora_name}.bias_size'] = torch.tensor(diff.shape).to(torch.int16)
|
289 |
+
del extract_a, extract_b, diff
|
290 |
+
elif decompose_mode == 'full':
|
291 |
+
loras[f'{lora_name}.diff'] = weight.detach().cpu().contiguous().half()
|
292 |
+
else:
|
293 |
+
raise NotImplementedError
|
294 |
return loras
|
295 |
|
296 |
text_encoder_loras = make_state_dict(
|
|
|
309 |
return text_encoder_loras|unet_loras
|
310 |
|
311 |
|
312 |
+
def get_module(
|
313 |
+
lyco_state_dict: Dict,
|
314 |
+
lora_name
|
|
|
|
|
315 |
):
|
316 |
+
if f'{lora_name}.lora_up.weight' in lyco_state_dict:
|
317 |
+
up = lyco_state_dict[f'{lora_name}.lora_up.weight']
|
318 |
+
down = lyco_state_dict[f'{lora_name}.lora_down.weight']
|
319 |
+
mid = lyco_state_dict.get(f'{lora_name}.lora_mid.weight', None)
|
320 |
+
alpha = lyco_state_dict.get(f'{lora_name}.alpha', None)
|
321 |
+
return 'locon', (up, down, mid, alpha)
|
322 |
+
elif f'{lora_name}.hada_w1_a' in lyco_state_dict:
|
323 |
+
w1a = lyco_state_dict[f'{lora_name}.hada_w1_a']
|
324 |
+
w1b = lyco_state_dict[f'{lora_name}.hada_w1_b']
|
325 |
+
w2a = lyco_state_dict[f'{lora_name}.hada_w2_a']
|
326 |
+
w2b = lyco_state_dict[f'{lora_name}.hada_w2_b']
|
327 |
+
t1 = lyco_state_dict.get(f'{lora_name}.hada_t1', None)
|
328 |
+
t2 = lyco_state_dict.get(f'{lora_name}.hada_t2', None)
|
329 |
+
alpha = lyco_state_dict.get(f'{lora_name}.alpha', None)
|
330 |
+
return 'hada', (w1a, w1b, w2a, w2b, t1, t2, alpha)
|
331 |
+
elif f'{lora_name}.weight' in lyco_state_dict:
|
332 |
+
weight = lyco_state_dict[f'{lora_name}.weight']
|
333 |
+
on_input = lyco_state_dict.get(f'{lora_name}.on_input', False)
|
334 |
+
return 'ia3', (weight, on_input)
|
335 |
+
elif (f'{lora_name}.lokr_w1' in lyco_state_dict
|
336 |
+
or f'{lora_name}.lokr_w1_a' in lyco_state_dict):
|
337 |
+
w1 = lyco_state_dict.get(f'{lora_name}.lokr_w1', None)
|
338 |
+
w1a = lyco_state_dict.get(f'{lora_name}.lokr_w1_a', None)
|
339 |
+
w1b = lyco_state_dict.get(f'{lora_name}.lokr_w1_b', None)
|
340 |
+
w2 = lyco_state_dict.get(f'{lora_name}.lokr_w2', None)
|
341 |
+
w2a = lyco_state_dict.get(f'{lora_name}.lokr_w2_a', None)
|
342 |
+
w2b = lyco_state_dict.get(f'{lora_name}.lokr_w2_b', None)
|
343 |
+
t1 = lyco_state_dict.get(f'{lora_name}.lokr_t1', None)
|
344 |
+
t2 = lyco_state_dict.get(f'{lora_name}.lokr_t2', None)
|
345 |
+
alpha = lyco_state_dict.get(f'{lora_name}.alpha', None)
|
346 |
+
return 'kron', (w1, w1a, w1b, w2, w2a, w2b, t1, t2, alpha)
|
347 |
+
elif f'{lora_name}.diff' in lyco_state_dict:
|
348 |
+
return 'full', lyco_state_dict[f'{lora_name}.diff']
|
349 |
+
else:
|
350 |
+
return 'None', ()
|
351 |
+
|
352 |
+
|
353 |
+
def cp_weight_from_conv(
|
354 |
+
up, down, mid
|
355 |
+
):
|
356 |
+
up = up.reshape(up.size(0), up.size(1))
|
357 |
+
down = down.reshape(down.size(0), down.size(1))
|
358 |
+
return torch.einsum('m n w h, i m, n j -> i j w h', mid, up, down)
|
359 |
+
|
360 |
+
def cp_weight(
|
361 |
+
wa, wb, t
|
362 |
+
):
|
363 |
+
temp = torch.einsum('i j k l, j r -> i r k l', t, wb)
|
364 |
+
return torch.einsum('i j k l, i r -> r j k l', temp, wa)
|
365 |
+
|
366 |
+
|
367 |
+
@torch.no_grad()
|
368 |
+
def rebuild_weight(module_type, params, orig_weight, scale=1):
|
369 |
+
if orig_weight is None:
|
370 |
+
return orig_weight
|
371 |
+
merged = orig_weight
|
372 |
+
if module_type == 'locon':
|
373 |
+
up, down, mid, alpha = params
|
374 |
+
if alpha is not None:
|
375 |
+
scale *= alpha/up.size(1)
|
376 |
+
if mid is not None:
|
377 |
+
rebuild = cp_weight_from_conv(up, down, mid)
|
378 |
+
else:
|
379 |
+
rebuild = up.reshape(up.size(0),-1) @ down.reshape(down.size(0), -1)
|
380 |
+
merged = orig_weight + rebuild.reshape(orig_weight.shape) * scale
|
381 |
+
del up, down, mid, alpha, params, rebuild
|
382 |
+
elif module_type == 'hada':
|
383 |
+
w1a, w1b, w2a, w2b, t1, t2, alpha = params
|
384 |
+
if alpha is not None:
|
385 |
+
scale *= alpha / w1b.size(0)
|
386 |
+
if t1 is not None:
|
387 |
+
rebuild1 = cp_weight(w1a, w1b, t1)
|
388 |
+
else:
|
389 |
+
rebuild1 = w1a @ w1b
|
390 |
+
if t2 is not None:
|
391 |
+
rebuild2 = cp_weight(w2a, w2b, t2)
|
392 |
+
else:
|
393 |
+
rebuild2 = w2a @ w2b
|
394 |
+
rebuild = (rebuild1 * rebuild2).reshape(orig_weight.shape)
|
395 |
+
merged = orig_weight + rebuild * scale
|
396 |
+
del w1a, w1b, w2a, w2b, t1, t2, alpha, params, rebuild, rebuild1, rebuild2
|
397 |
+
elif module_type == 'ia3':
|
398 |
+
weight, on_input = params
|
399 |
+
if not on_input:
|
400 |
+
weight = weight.reshape(-1, 1)
|
401 |
+
merged = orig_weight + weight * orig_weight * scale
|
402 |
+
del weight, on_input, params
|
403 |
+
elif module_type == 'kron':
|
404 |
+
w1, w1a, w1b, w2, w2a, w2b, t1, t2, alpha = params
|
405 |
+
if alpha is not None and (w1b is not None or w2b is not None):
|
406 |
+
scale *= alpha / (w1b.size(0) if w1b else w2b.size(0))
|
407 |
+
if w1a is not None and w1b is not None:
|
408 |
+
if t1:
|
409 |
+
w1 = cp_weight(w1a, w1b, t1)
|
410 |
+
else:
|
411 |
+
w1 = w1a @ w1b
|
412 |
+
if w2a is not None and w2b is not None:
|
413 |
+
if t2:
|
414 |
+
w2 = cp_weight(w2a, w2b, t2)
|
415 |
+
else:
|
416 |
+
w2 = w2a @ w2b
|
417 |
+
rebuild = torch.kron(w1, w2).reshape(orig_weight.shape)
|
418 |
+
merged = orig_weight + rebuild* scale
|
419 |
+
del w1, w1a, w1b, w2, w2a, w2b, t1, t2, alpha, params, rebuild
|
420 |
+
elif module_type == 'full':
|
421 |
+
rebuild = params.reshape(orig_weight.shape)
|
422 |
+
merged = orig_weight + rebuild * scale
|
423 |
+
del params, rebuild
|
424 |
|
425 |
+
return merged
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
426 |
|
427 |
|
428 |
+
def merge(
|
429 |
base_model,
|
430 |
+
lyco_state_dict,
|
431 |
scale: float = 1.0,
|
432 |
device = 'cpu'
|
433 |
):
|
|
|
438 |
"Downsample2D",
|
439 |
"Upsample2D"
|
440 |
]
|
441 |
+
UNET_TARGET_REPLACE_NAME = [
|
442 |
+
"conv_in",
|
443 |
+
"conv_out",
|
444 |
+
"time_embedding.linear_1",
|
445 |
+
"time_embedding.linear_2",
|
446 |
+
]
|
447 |
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
448 |
LORA_PREFIX_UNET = 'lora_unet'
|
449 |
LORA_PREFIX_TEXT_ENCODER = 'lora_te'
|
450 |
+
merged = 0
|
451 |
+
def merge_state_dict(
|
452 |
prefix,
|
453 |
root_module: torch.nn.Module,
|
454 |
+
lyco_state_dict: Dict[str,torch.Tensor],
|
455 |
+
target_replace_modules,
|
456 |
+
target_replace_names = []
|
457 |
):
|
458 |
+
nonlocal merged
|
459 |
+
for name, module in tqdm(list(root_module.named_modules()), desc=f'Merging {prefix}'):
|
|
|
460 |
if module.__class__.__name__ in target_replace_modules:
|
|
|
461 |
for child_name, child_module in module.named_modules():
|
462 |
+
if child_module.__class__.__name__ not in {'Linear', 'Conv2d'}:
|
|
|
463 |
continue
|
464 |
lora_name = prefix + '.' + name + '.' + child_name
|
465 |
lora_name = lora_name.replace('.', '_')
|
466 |
|
467 |
+
result = rebuild_weight(*get_module(
|
468 |
+
lyco_state_dict, lora_name
|
469 |
+
), getattr(child_module, 'weight'), scale)
|
470 |
+
if result is not None:
|
471 |
+
merged += 1
|
472 |
+
child_module.requires_grad_(False)
|
473 |
+
child_module.weight.copy_(result)
|
474 |
+
elif name in target_replace_names:
|
475 |
+
lora_name = prefix + '.' + name
|
476 |
+
lora_name = lora_name.replace('.', '_')
|
477 |
|
478 |
+
result = rebuild_weight(*get_module(
|
479 |
+
lyco_state_dict, lora_name
|
480 |
+
), getattr(module, 'weight'), scale)
|
481 |
+
if result is not None:
|
482 |
+
merged += 1
|
483 |
+
module.requires_grad_(False)
|
484 |
+
module.weight.copy_(result)
|
485 |
|
486 |
+
if device == 'cpu':
|
487 |
+
for k, v in tqdm(list(lyco_state_dict.items()), desc='Converting Dtype'):
|
488 |
+
lyco_state_dict[k] = v.float()
|
489 |
+
|
490 |
+
merge_state_dict(
|
491 |
+
LORA_PREFIX_TEXT_ENCODER,
|
492 |
+
base_model[0],
|
493 |
+
lyco_state_dict,
|
494 |
+
TEXT_ENCODER_TARGET_REPLACE_MODULE,
|
495 |
+
UNET_TARGET_REPLACE_NAME
|
496 |
)
|
497 |
+
merge_state_dict(
|
498 |
LORA_PREFIX_UNET,
|
499 |
+
base_model[2],
|
500 |
+
lyco_state_dict,
|
501 |
+
UNET_TARGET_REPLACE_MODULE,
|
502 |
+
UNET_TARGET_REPLACE_NAME
|
503 |
+
)
|
504 |
+
print(f'{merged} Modules been merged')
|