abc
commited on
Commit
·
8999a58
1
Parent(s):
497a25a
Delete locon
Browse files- locon/__init__.py +0 -0
- locon/kohya_model_utils.py +0 -1184
- locon/kohya_utils.py +0 -48
- locon/locon.py +0 -53
- locon/locon_kohya.py +0 -243
- locon/utils.py +0 -148
locon/__init__.py
DELETED
File without changes
|
locon/kohya_model_utils.py
DELETED
@@ -1,1184 +0,0 @@
|
|
1 |
-
'''
|
2 |
-
https://github.com/kohya-ss/sd-scripts/blob/main/library/model_util.py
|
3 |
-
'''
|
4 |
-
# v1: split from train_db_fixed.py.
|
5 |
-
# v2: support safetensors
|
6 |
-
|
7 |
-
import math
|
8 |
-
import os
|
9 |
-
import torch
|
10 |
-
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig
|
11 |
-
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
|
12 |
-
from safetensors.torch import load_file, save_file
|
13 |
-
|
14 |
-
# DiffUsers版StableDiffusionのモデルパラメータ
|
15 |
-
NUM_TRAIN_TIMESTEPS = 1000
|
16 |
-
BETA_START = 0.00085
|
17 |
-
BETA_END = 0.0120
|
18 |
-
|
19 |
-
UNET_PARAMS_MODEL_CHANNELS = 320
|
20 |
-
UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
|
21 |
-
UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
|
22 |
-
UNET_PARAMS_IMAGE_SIZE = 32 # unused
|
23 |
-
UNET_PARAMS_IN_CHANNELS = 4
|
24 |
-
UNET_PARAMS_OUT_CHANNELS = 4
|
25 |
-
UNET_PARAMS_NUM_RES_BLOCKS = 2
|
26 |
-
UNET_PARAMS_CONTEXT_DIM = 768
|
27 |
-
UNET_PARAMS_NUM_HEADS = 8
|
28 |
-
|
29 |
-
VAE_PARAMS_Z_CHANNELS = 4
|
30 |
-
VAE_PARAMS_RESOLUTION = 256
|
31 |
-
VAE_PARAMS_IN_CHANNELS = 3
|
32 |
-
VAE_PARAMS_OUT_CH = 3
|
33 |
-
VAE_PARAMS_CH = 128
|
34 |
-
VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
|
35 |
-
VAE_PARAMS_NUM_RES_BLOCKS = 2
|
36 |
-
|
37 |
-
# V2
|
38 |
-
V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
|
39 |
-
V2_UNET_PARAMS_CONTEXT_DIM = 1024
|
40 |
-
|
41 |
-
# Diffusersの設定を読み込むための参照モデル
|
42 |
-
DIFFUSERS_REF_MODEL_ID_V1 = "runwayml/stable-diffusion-v1-5"
|
43 |
-
DIFFUSERS_REF_MODEL_ID_V2 = "stabilityai/stable-diffusion-2-1"
|
44 |
-
|
45 |
-
|
46 |
-
# region StableDiffusion->Diffusersの変換コード
|
47 |
-
# convert_original_stable_diffusion_to_diffusers をコピーして修正している(ASL 2.0)
|
48 |
-
|
49 |
-
|
50 |
-
def shave_segments(path, n_shave_prefix_segments=1):
|
51 |
-
"""
|
52 |
-
Removes segments. Positive values shave the first segments, negative shave the last segments.
|
53 |
-
"""
|
54 |
-
if n_shave_prefix_segments >= 0:
|
55 |
-
return ".".join(path.split(".")[n_shave_prefix_segments:])
|
56 |
-
else:
|
57 |
-
return ".".join(path.split(".")[:n_shave_prefix_segments])
|
58 |
-
|
59 |
-
|
60 |
-
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
|
61 |
-
"""
|
62 |
-
Updates paths inside resnets to the new naming scheme (local renaming)
|
63 |
-
"""
|
64 |
-
mapping = []
|
65 |
-
for old_item in old_list:
|
66 |
-
new_item = old_item.replace("in_layers.0", "norm1")
|
67 |
-
new_item = new_item.replace("in_layers.2", "conv1")
|
68 |
-
|
69 |
-
new_item = new_item.replace("out_layers.0", "norm2")
|
70 |
-
new_item = new_item.replace("out_layers.3", "conv2")
|
71 |
-
|
72 |
-
new_item = new_item.replace("emb_layers.1", "time_emb_proj")
|
73 |
-
new_item = new_item.replace("skip_connection", "conv_shortcut")
|
74 |
-
|
75 |
-
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
76 |
-
|
77 |
-
mapping.append({"old": old_item, "new": new_item})
|
78 |
-
|
79 |
-
return mapping
|
80 |
-
|
81 |
-
|
82 |
-
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
|
83 |
-
"""
|
84 |
-
Updates paths inside resnets to the new naming scheme (local renaming)
|
85 |
-
"""
|
86 |
-
mapping = []
|
87 |
-
for old_item in old_list:
|
88 |
-
new_item = old_item
|
89 |
-
|
90 |
-
new_item = new_item.replace("nin_shortcut", "conv_shortcut")
|
91 |
-
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
92 |
-
|
93 |
-
mapping.append({"old": old_item, "new": new_item})
|
94 |
-
|
95 |
-
return mapping
|
96 |
-
|
97 |
-
|
98 |
-
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
|
99 |
-
"""
|
100 |
-
Updates paths inside attentions to the new naming scheme (local renaming)
|
101 |
-
"""
|
102 |
-
mapping = []
|
103 |
-
for old_item in old_list:
|
104 |
-
new_item = old_item
|
105 |
-
|
106 |
-
# new_item = new_item.replace('norm.weight', 'group_norm.weight')
|
107 |
-
# new_item = new_item.replace('norm.bias', 'group_norm.bias')
|
108 |
-
|
109 |
-
# new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
|
110 |
-
# new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
|
111 |
-
|
112 |
-
# new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
113 |
-
|
114 |
-
mapping.append({"old": old_item, "new": new_item})
|
115 |
-
|
116 |
-
return mapping
|
117 |
-
|
118 |
-
|
119 |
-
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
|
120 |
-
"""
|
121 |
-
Updates paths inside attentions to the new naming scheme (local renaming)
|
122 |
-
"""
|
123 |
-
mapping = []
|
124 |
-
for old_item in old_list:
|
125 |
-
new_item = old_item
|
126 |
-
|
127 |
-
new_item = new_item.replace("norm.weight", "group_norm.weight")
|
128 |
-
new_item = new_item.replace("norm.bias", "group_norm.bias")
|
129 |
-
|
130 |
-
new_item = new_item.replace("q.weight", "query.weight")
|
131 |
-
new_item = new_item.replace("q.bias", "query.bias")
|
132 |
-
|
133 |
-
new_item = new_item.replace("k.weight", "key.weight")
|
134 |
-
new_item = new_item.replace("k.bias", "key.bias")
|
135 |
-
|
136 |
-
new_item = new_item.replace("v.weight", "value.weight")
|
137 |
-
new_item = new_item.replace("v.bias", "value.bias")
|
138 |
-
|
139 |
-
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
|
140 |
-
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
|
141 |
-
|
142 |
-
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
143 |
-
|
144 |
-
mapping.append({"old": old_item, "new": new_item})
|
145 |
-
|
146 |
-
return mapping
|
147 |
-
|
148 |
-
|
149 |
-
def assign_to_checkpoint(
|
150 |
-
paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
|
151 |
-
):
|
152 |
-
"""
|
153 |
-
This does the final conversion step: take locally converted weights and apply a global renaming
|
154 |
-
to them. It splits attention layers, and takes into account additional replacements
|
155 |
-
that may arise.
|
156 |
-
|
157 |
-
Assigns the weights to the new checkpoint.
|
158 |
-
"""
|
159 |
-
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
|
160 |
-
|
161 |
-
# Splits the attention layers into three variables.
|
162 |
-
if attention_paths_to_split is not None:
|
163 |
-
for path, path_map in attention_paths_to_split.items():
|
164 |
-
old_tensor = old_checkpoint[path]
|
165 |
-
channels = old_tensor.shape[0] // 3
|
166 |
-
|
167 |
-
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
|
168 |
-
|
169 |
-
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
|
170 |
-
|
171 |
-
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
|
172 |
-
query, key, value = old_tensor.split(channels // num_heads, dim=1)
|
173 |
-
|
174 |
-
checkpoint[path_map["query"]] = query.reshape(target_shape)
|
175 |
-
checkpoint[path_map["key"]] = key.reshape(target_shape)
|
176 |
-
checkpoint[path_map["value"]] = value.reshape(target_shape)
|
177 |
-
|
178 |
-
for path in paths:
|
179 |
-
new_path = path["new"]
|
180 |
-
|
181 |
-
# These have already been assigned
|
182 |
-
if attention_paths_to_split is not None and new_path in attention_paths_to_split:
|
183 |
-
continue
|
184 |
-
|
185 |
-
# Global renaming happens here
|
186 |
-
new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
|
187 |
-
new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
|
188 |
-
new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
|
189 |
-
|
190 |
-
if additional_replacements is not None:
|
191 |
-
for replacement in additional_replacements:
|
192 |
-
new_path = new_path.replace(replacement["old"], replacement["new"])
|
193 |
-
|
194 |
-
# proj_attn.weight has to be converted from conv 1D to linear
|
195 |
-
if "proj_attn.weight" in new_path:
|
196 |
-
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
|
197 |
-
else:
|
198 |
-
checkpoint[new_path] = old_checkpoint[path["old"]]
|
199 |
-
|
200 |
-
|
201 |
-
def conv_attn_to_linear(checkpoint):
|
202 |
-
keys = list(checkpoint.keys())
|
203 |
-
attn_keys = ["query.weight", "key.weight", "value.weight"]
|
204 |
-
for key in keys:
|
205 |
-
if ".".join(key.split(".")[-2:]) in attn_keys:
|
206 |
-
if checkpoint[key].ndim > 2:
|
207 |
-
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
208 |
-
elif "proj_attn.weight" in key:
|
209 |
-
if checkpoint[key].ndim > 2:
|
210 |
-
checkpoint[key] = checkpoint[key][:, :, 0]
|
211 |
-
|
212 |
-
|
213 |
-
def linear_transformer_to_conv(checkpoint):
|
214 |
-
keys = list(checkpoint.keys())
|
215 |
-
tf_keys = ["proj_in.weight", "proj_out.weight"]
|
216 |
-
for key in keys:
|
217 |
-
if ".".join(key.split(".")[-2:]) in tf_keys:
|
218 |
-
if checkpoint[key].ndim == 2:
|
219 |
-
checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2)
|
220 |
-
|
221 |
-
|
222 |
-
def convert_ldm_unet_checkpoint(v2, checkpoint, config):
|
223 |
-
"""
|
224 |
-
Takes a state dict and a config, and returns a converted checkpoint.
|
225 |
-
"""
|
226 |
-
|
227 |
-
# extract state_dict for UNet
|
228 |
-
unet_state_dict = {}
|
229 |
-
unet_key = "model.diffusion_model."
|
230 |
-
keys = list(checkpoint.keys())
|
231 |
-
for key in keys:
|
232 |
-
if key.startswith(unet_key):
|
233 |
-
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
|
234 |
-
|
235 |
-
new_checkpoint = {}
|
236 |
-
|
237 |
-
new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
|
238 |
-
new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
|
239 |
-
new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
|
240 |
-
new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
|
241 |
-
|
242 |
-
new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
|
243 |
-
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
|
244 |
-
|
245 |
-
new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
|
246 |
-
new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
|
247 |
-
new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
|
248 |
-
new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
|
249 |
-
|
250 |
-
# Retrieves the keys for the input blocks only
|
251 |
-
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
|
252 |
-
input_blocks = {
|
253 |
-
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key]
|
254 |
-
for layer_id in range(num_input_blocks)
|
255 |
-
}
|
256 |
-
|
257 |
-
# Retrieves the keys for the middle blocks only
|
258 |
-
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
|
259 |
-
middle_blocks = {
|
260 |
-
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key]
|
261 |
-
for layer_id in range(num_middle_blocks)
|
262 |
-
}
|
263 |
-
|
264 |
-
# Retrieves the keys for the output blocks only
|
265 |
-
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
|
266 |
-
output_blocks = {
|
267 |
-
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key]
|
268 |
-
for layer_id in range(num_output_blocks)
|
269 |
-
}
|
270 |
-
|
271 |
-
for i in range(1, num_input_blocks):
|
272 |
-
block_id = (i - 1) // (config["layers_per_block"] + 1)
|
273 |
-
layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
|
274 |
-
|
275 |
-
resnets = [
|
276 |
-
key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
|
277 |
-
]
|
278 |
-
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
|
279 |
-
|
280 |
-
if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
|
281 |
-
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
|
282 |
-
f"input_blocks.{i}.0.op.weight"
|
283 |
-
)
|
284 |
-
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
|
285 |
-
f"input_blocks.{i}.0.op.bias"
|
286 |
-
)
|
287 |
-
|
288 |
-
paths = renew_resnet_paths(resnets)
|
289 |
-
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
290 |
-
assign_to_checkpoint(
|
291 |
-
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
292 |
-
)
|
293 |
-
|
294 |
-
if len(attentions):
|
295 |
-
paths = renew_attention_paths(attentions)
|
296 |
-
meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
|
297 |
-
assign_to_checkpoint(
|
298 |
-
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
299 |
-
)
|
300 |
-
|
301 |
-
resnet_0 = middle_blocks[0]
|
302 |
-
attentions = middle_blocks[1]
|
303 |
-
resnet_1 = middle_blocks[2]
|
304 |
-
|
305 |
-
resnet_0_paths = renew_resnet_paths(resnet_0)
|
306 |
-
assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
|
307 |
-
|
308 |
-
resnet_1_paths = renew_resnet_paths(resnet_1)
|
309 |
-
assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
|
310 |
-
|
311 |
-
attentions_paths = renew_attention_paths(attentions)
|
312 |
-
meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
|
313 |
-
assign_to_checkpoint(
|
314 |
-
attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
315 |
-
)
|
316 |
-
|
317 |
-
for i in range(num_output_blocks):
|
318 |
-
block_id = i // (config["layers_per_block"] + 1)
|
319 |
-
layer_in_block_id = i % (config["layers_per_block"] + 1)
|
320 |
-
output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
|
321 |
-
output_block_list = {}
|
322 |
-
|
323 |
-
for layer in output_block_layers:
|
324 |
-
layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
|
325 |
-
if layer_id in output_block_list:
|
326 |
-
output_block_list[layer_id].append(layer_name)
|
327 |
-
else:
|
328 |
-
output_block_list[layer_id] = [layer_name]
|
329 |
-
|
330 |
-
if len(output_block_list) > 1:
|
331 |
-
resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
|
332 |
-
attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
|
333 |
-
|
334 |
-
resnet_0_paths = renew_resnet_paths(resnets)
|
335 |
-
paths = renew_resnet_paths(resnets)
|
336 |
-
|
337 |
-
meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
338 |
-
assign_to_checkpoint(
|
339 |
-
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
340 |
-
)
|
341 |
-
|
342 |
-
# オリジナル:
|
343 |
-
# if ["conv.weight", "conv.bias"] in output_block_list.values():
|
344 |
-
# index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
|
345 |
-
|
346 |
-
# biasとweightの順番に依存しないようにする:もっといいやり方がありそうだが
|
347 |
-
for l in output_block_list.values():
|
348 |
-
l.sort()
|
349 |
-
|
350 |
-
if ["conv.bias", "conv.weight"] in output_block_list.values():
|
351 |
-
index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
|
352 |
-
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
|
353 |
-
f"output_blocks.{i}.{index}.conv.bias"
|
354 |
-
]
|
355 |
-
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
|
356 |
-
f"output_blocks.{i}.{index}.conv.weight"
|
357 |
-
]
|
358 |
-
|
359 |
-
# Clear attentions as they have been attributed above.
|
360 |
-
if len(attentions) == 2:
|
361 |
-
attentions = []
|
362 |
-
|
363 |
-
if len(attentions):
|
364 |
-
paths = renew_attention_paths(attentions)
|
365 |
-
meta_path = {
|
366 |
-
"old": f"output_blocks.{i}.1",
|
367 |
-
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
|
368 |
-
}
|
369 |
-
assign_to_checkpoint(
|
370 |
-
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
371 |
-
)
|
372 |
-
else:
|
373 |
-
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
|
374 |
-
for path in resnet_0_paths:
|
375 |
-
old_path = ".".join(["output_blocks", str(i), path["old"]])
|
376 |
-
new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
|
377 |
-
|
378 |
-
new_checkpoint[new_path] = unet_state_dict[old_path]
|
379 |
-
|
380 |
-
# SDのv2では1*1のconv2dがlinearに変わっているので、linear->convに変換する
|
381 |
-
if v2:
|
382 |
-
linear_transformer_to_conv(new_checkpoint)
|
383 |
-
|
384 |
-
return new_checkpoint
|
385 |
-
|
386 |
-
|
387 |
-
def convert_ldm_vae_checkpoint(checkpoint, config):
|
388 |
-
# extract state dict for VAE
|
389 |
-
vae_state_dict = {}
|
390 |
-
vae_key = "first_stage_model."
|
391 |
-
keys = list(checkpoint.keys())
|
392 |
-
for key in keys:
|
393 |
-
if key.startswith(vae_key):
|
394 |
-
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
|
395 |
-
# if len(vae_state_dict) == 0:
|
396 |
-
# # 渡されたcheckpointは.ckptから読み込んだcheckpointではなくvaeのstate_dict
|
397 |
-
# vae_state_dict = checkpoint
|
398 |
-
|
399 |
-
new_checkpoint = {}
|
400 |
-
|
401 |
-
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
|
402 |
-
new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
|
403 |
-
new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
|
404 |
-
new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
|
405 |
-
new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
|
406 |
-
new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
|
407 |
-
|
408 |
-
new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
|
409 |
-
new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
|
410 |
-
new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
|
411 |
-
new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
|
412 |
-
new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
|
413 |
-
new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
|
414 |
-
|
415 |
-
new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
|
416 |
-
new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
|
417 |
-
new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
|
418 |
-
new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
|
419 |
-
|
420 |
-
# Retrieves the keys for the encoder down blocks only
|
421 |
-
num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
|
422 |
-
down_blocks = {
|
423 |
-
layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
|
424 |
-
}
|
425 |
-
|
426 |
-
# Retrieves the keys for the decoder up blocks only
|
427 |
-
num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
|
428 |
-
up_blocks = {
|
429 |
-
layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
|
430 |
-
}
|
431 |
-
|
432 |
-
for i in range(num_down_blocks):
|
433 |
-
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
|
434 |
-
|
435 |
-
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
|
436 |
-
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
|
437 |
-
f"encoder.down.{i}.downsample.conv.weight"
|
438 |
-
)
|
439 |
-
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
|
440 |
-
f"encoder.down.{i}.downsample.conv.bias"
|
441 |
-
)
|
442 |
-
|
443 |
-
paths = renew_vae_resnet_paths(resnets)
|
444 |
-
meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
|
445 |
-
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
446 |
-
|
447 |
-
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
|
448 |
-
num_mid_res_blocks = 2
|
449 |
-
for i in range(1, num_mid_res_blocks + 1):
|
450 |
-
resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
|
451 |
-
|
452 |
-
paths = renew_vae_resnet_paths(resnets)
|
453 |
-
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
454 |
-
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
455 |
-
|
456 |
-
mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
|
457 |
-
paths = renew_vae_attention_paths(mid_attentions)
|
458 |
-
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
459 |
-
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
460 |
-
conv_attn_to_linear(new_checkpoint)
|
461 |
-
|
462 |
-
for i in range(num_up_blocks):
|
463 |
-
block_id = num_up_blocks - 1 - i
|
464 |
-
resnets = [
|
465 |
-
key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
|
466 |
-
]
|
467 |
-
|
468 |
-
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
|
469 |
-
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
|
470 |
-
f"decoder.up.{block_id}.upsample.conv.weight"
|
471 |
-
]
|
472 |
-
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
|
473 |
-
f"decoder.up.{block_id}.upsample.conv.bias"
|
474 |
-
]
|
475 |
-
|
476 |
-
paths = renew_vae_resnet_paths(resnets)
|
477 |
-
meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
|
478 |
-
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
479 |
-
|
480 |
-
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
|
481 |
-
num_mid_res_blocks = 2
|
482 |
-
for i in range(1, num_mid_res_blocks + 1):
|
483 |
-
resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
|
484 |
-
|
485 |
-
paths = renew_vae_resnet_paths(resnets)
|
486 |
-
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
487 |
-
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
488 |
-
|
489 |
-
mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
|
490 |
-
paths = renew_vae_attention_paths(mid_attentions)
|
491 |
-
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
492 |
-
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
493 |
-
conv_attn_to_linear(new_checkpoint)
|
494 |
-
return new_checkpoint
|
495 |
-
|
496 |
-
|
497 |
-
def create_unet_diffusers_config(v2):
|
498 |
-
"""
|
499 |
-
Creates a config for the diffusers based on the config of the LDM model.
|
500 |
-
"""
|
501 |
-
# unet_params = original_config.model.params.unet_config.params
|
502 |
-
|
503 |
-
block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT]
|
504 |
-
|
505 |
-
down_block_types = []
|
506 |
-
resolution = 1
|
507 |
-
for i in range(len(block_out_channels)):
|
508 |
-
block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D"
|
509 |
-
down_block_types.append(block_type)
|
510 |
-
if i != len(block_out_channels) - 1:
|
511 |
-
resolution *= 2
|
512 |
-
|
513 |
-
up_block_types = []
|
514 |
-
for i in range(len(block_out_channels)):
|
515 |
-
block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D"
|
516 |
-
up_block_types.append(block_type)
|
517 |
-
resolution //= 2
|
518 |
-
|
519 |
-
config = dict(
|
520 |
-
sample_size=UNET_PARAMS_IMAGE_SIZE,
|
521 |
-
in_channels=UNET_PARAMS_IN_CHANNELS,
|
522 |
-
out_channels=UNET_PARAMS_OUT_CHANNELS,
|
523 |
-
down_block_types=tuple(down_block_types),
|
524 |
-
up_block_types=tuple(up_block_types),
|
525 |
-
block_out_channels=tuple(block_out_channels),
|
526 |
-
layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
|
527 |
-
cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM,
|
528 |
-
attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
|
529 |
-
)
|
530 |
-
|
531 |
-
return config
|
532 |
-
|
533 |
-
|
534 |
-
def create_vae_diffusers_config():
|
535 |
-
"""
|
536 |
-
Creates a config for the diffusers based on the config of the LDM model.
|
537 |
-
"""
|
538 |
-
# vae_params = original_config.model.params.first_stage_config.params.ddconfig
|
539 |
-
# _ = original_config.model.params.first_stage_config.params.embed_dim
|
540 |
-
block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT]
|
541 |
-
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
|
542 |
-
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
|
543 |
-
|
544 |
-
config = dict(
|
545 |
-
sample_size=VAE_PARAMS_RESOLUTION,
|
546 |
-
in_channels=VAE_PARAMS_IN_CHANNELS,
|
547 |
-
out_channels=VAE_PARAMS_OUT_CH,
|
548 |
-
down_block_types=tuple(down_block_types),
|
549 |
-
up_block_types=tuple(up_block_types),
|
550 |
-
block_out_channels=tuple(block_out_channels),
|
551 |
-
latent_channels=VAE_PARAMS_Z_CHANNELS,
|
552 |
-
layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS,
|
553 |
-
)
|
554 |
-
return config
|
555 |
-
|
556 |
-
|
557 |
-
def convert_ldm_clip_checkpoint_v1(checkpoint):
|
558 |
-
keys = list(checkpoint.keys())
|
559 |
-
text_model_dict = {}
|
560 |
-
for key in keys:
|
561 |
-
if key.startswith("cond_stage_model.transformer"):
|
562 |
-
text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key]
|
563 |
-
return text_model_dict
|
564 |
-
|
565 |
-
|
566 |
-
def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
|
567 |
-
# 嫌になるくらい違うぞ!
|
568 |
-
def convert_key(key):
|
569 |
-
if not key.startswith("cond_stage_model"):
|
570 |
-
return None
|
571 |
-
|
572 |
-
# common conversion
|
573 |
-
key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.")
|
574 |
-
key = key.replace("cond_stage_model.model.", "text_model.")
|
575 |
-
|
576 |
-
if "resblocks" in key:
|
577 |
-
# resblocks conversion
|
578 |
-
key = key.replace(".resblocks.", ".layers.")
|
579 |
-
if ".ln_" in key:
|
580 |
-
key = key.replace(".ln_", ".layer_norm")
|
581 |
-
elif ".mlp." in key:
|
582 |
-
key = key.replace(".c_fc.", ".fc1.")
|
583 |
-
key = key.replace(".c_proj.", ".fc2.")
|
584 |
-
elif '.attn.out_proj' in key:
|
585 |
-
key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
|
586 |
-
elif '.attn.in_proj' in key:
|
587 |
-
key = None # 特殊なので後で処理する
|
588 |
-
else:
|
589 |
-
raise ValueError(f"unexpected key in SD: {key}")
|
590 |
-
elif '.positional_embedding' in key:
|
591 |
-
key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
|
592 |
-
elif '.text_projection' in key:
|
593 |
-
key = None # 使われない???
|
594 |
-
elif '.logit_scale' in key:
|
595 |
-
key = None # 使われない???
|
596 |
-
elif '.token_embedding' in key:
|
597 |
-
key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
|
598 |
-
elif '.ln_final' in key:
|
599 |
-
key = key.replace(".ln_final", ".final_layer_norm")
|
600 |
-
return key
|
601 |
-
|
602 |
-
keys = list(checkpoint.keys())
|
603 |
-
new_sd = {}
|
604 |
-
for key in keys:
|
605 |
-
# remove resblocks 23
|
606 |
-
if '.resblocks.23.' in key:
|
607 |
-
continue
|
608 |
-
new_key = convert_key(key)
|
609 |
-
if new_key is None:
|
610 |
-
continue
|
611 |
-
new_sd[new_key] = checkpoint[key]
|
612 |
-
|
613 |
-
# attnの変換
|
614 |
-
for key in keys:
|
615 |
-
if '.resblocks.23.' in key:
|
616 |
-
continue
|
617 |
-
if '.resblocks' in key and '.attn.in_proj_' in key:
|
618 |
-
# 三つに分割
|
619 |
-
values = torch.chunk(checkpoint[key], 3)
|
620 |
-
|
621 |
-
key_suffix = ".weight" if "weight" in key else ".bias"
|
622 |
-
key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.")
|
623 |
-
key_pfx = key_pfx.replace("_weight", "")
|
624 |
-
key_pfx = key_pfx.replace("_bias", "")
|
625 |
-
key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
|
626 |
-
new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
|
627 |
-
new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
|
628 |
-
new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
|
629 |
-
|
630 |
-
# rename or add position_ids
|
631 |
-
ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids"
|
632 |
-
if ANOTHER_POSITION_IDS_KEY in new_sd:
|
633 |
-
# waifu diffusion v1.4
|
634 |
-
position_ids = new_sd[ANOTHER_POSITION_IDS_KEY]
|
635 |
-
del new_sd[ANOTHER_POSITION_IDS_KEY]
|
636 |
-
else:
|
637 |
-
position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
|
638 |
-
|
639 |
-
new_sd["text_model.embeddings.position_ids"] = position_ids
|
640 |
-
return new_sd
|
641 |
-
|
642 |
-
# endregion
|
643 |
-
|
644 |
-
|
645 |
-
# region Diffusers->StableDiffusion の変換コード
|
646 |
-
# convert_diffusers_to_original_stable_diffusion をコピーして修正している(ASL 2.0)
|
647 |
-
|
648 |
-
def conv_transformer_to_linear(checkpoint):
|
649 |
-
keys = list(checkpoint.keys())
|
650 |
-
tf_keys = ["proj_in.weight", "proj_out.weight"]
|
651 |
-
for key in keys:
|
652 |
-
if ".".join(key.split(".")[-2:]) in tf_keys:
|
653 |
-
if checkpoint[key].ndim > 2:
|
654 |
-
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
655 |
-
|
656 |
-
|
657 |
-
def convert_unet_state_dict_to_sd(v2, unet_state_dict):
|
658 |
-
unet_conversion_map = [
|
659 |
-
# (stable-diffusion, HF Diffusers)
|
660 |
-
("time_embed.0.weight", "time_embedding.linear_1.weight"),
|
661 |
-
("time_embed.0.bias", "time_embedding.linear_1.bias"),
|
662 |
-
("time_embed.2.weight", "time_embedding.linear_2.weight"),
|
663 |
-
("time_embed.2.bias", "time_embedding.linear_2.bias"),
|
664 |
-
("input_blocks.0.0.weight", "conv_in.weight"),
|
665 |
-
("input_blocks.0.0.bias", "conv_in.bias"),
|
666 |
-
("out.0.weight", "conv_norm_out.weight"),
|
667 |
-
("out.0.bias", "conv_norm_out.bias"),
|
668 |
-
("out.2.weight", "conv_out.weight"),
|
669 |
-
("out.2.bias", "conv_out.bias"),
|
670 |
-
]
|
671 |
-
|
672 |
-
unet_conversion_map_resnet = [
|
673 |
-
# (stable-diffusion, HF Diffusers)
|
674 |
-
("in_layers.0", "norm1"),
|
675 |
-
("in_layers.2", "conv1"),
|
676 |
-
("out_layers.0", "norm2"),
|
677 |
-
("out_layers.3", "conv2"),
|
678 |
-
("emb_layers.1", "time_emb_proj"),
|
679 |
-
("skip_connection", "conv_shortcut"),
|
680 |
-
]
|
681 |
-
|
682 |
-
unet_conversion_map_layer = []
|
683 |
-
for i in range(4):
|
684 |
-
# loop over downblocks/upblocks
|
685 |
-
|
686 |
-
for j in range(2):
|
687 |
-
# loop over resnets/attentions for downblocks
|
688 |
-
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
689 |
-
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
690 |
-
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
691 |
-
|
692 |
-
if i < 3:
|
693 |
-
# no attention layers in down_blocks.3
|
694 |
-
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
695 |
-
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
696 |
-
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
697 |
-
|
698 |
-
for j in range(3):
|
699 |
-
# loop over resnets/attentions for upblocks
|
700 |
-
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
701 |
-
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
|
702 |
-
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
703 |
-
|
704 |
-
if i > 0:
|
705 |
-
# no attention layers in up_blocks.0
|
706 |
-
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
|
707 |
-
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
|
708 |
-
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
709 |
-
|
710 |
-
if i < 3:
|
711 |
-
# no downsample in down_blocks.3
|
712 |
-
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
713 |
-
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
714 |
-
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
715 |
-
|
716 |
-
# no upsample in up_blocks.3
|
717 |
-
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
718 |
-
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
|
719 |
-
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
720 |
-
|
721 |
-
hf_mid_atn_prefix = "mid_block.attentions.0."
|
722 |
-
sd_mid_atn_prefix = "middle_block.1."
|
723 |
-
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
724 |
-
|
725 |
-
for j in range(2):
|
726 |
-
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
727 |
-
sd_mid_res_prefix = f"middle_block.{2*j}."
|
728 |
-
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
729 |
-
|
730 |
-
# buyer beware: this is a *brittle* function,
|
731 |
-
# and correct output requires that all of these pieces interact in
|
732 |
-
# the exact order in which I have arranged them.
|
733 |
-
mapping = {k: k for k in unet_state_dict.keys()}
|
734 |
-
for sd_name, hf_name in unet_conversion_map:
|
735 |
-
mapping[hf_name] = sd_name
|
736 |
-
for k, v in mapping.items():
|
737 |
-
if "resnets" in k:
|
738 |
-
for sd_part, hf_part in unet_conversion_map_resnet:
|
739 |
-
v = v.replace(hf_part, sd_part)
|
740 |
-
mapping[k] = v
|
741 |
-
for k, v in mapping.items():
|
742 |
-
for sd_part, hf_part in unet_conversion_map_layer:
|
743 |
-
v = v.replace(hf_part, sd_part)
|
744 |
-
mapping[k] = v
|
745 |
-
new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
|
746 |
-
|
747 |
-
if v2:
|
748 |
-
conv_transformer_to_linear(new_state_dict)
|
749 |
-
|
750 |
-
return new_state_dict
|
751 |
-
|
752 |
-
|
753 |
-
# ================#
|
754 |
-
# VAE Conversion #
|
755 |
-
# ================#
|
756 |
-
|
757 |
-
def reshape_weight_for_sd(w):
|
758 |
-
# convert HF linear weights to SD conv2d weights
|
759 |
-
return w.reshape(*w.shape, 1, 1)
|
760 |
-
|
761 |
-
|
762 |
-
def convert_vae_state_dict(vae_state_dict):
|
763 |
-
vae_conversion_map = [
|
764 |
-
# (stable-diffusion, HF Diffusers)
|
765 |
-
("nin_shortcut", "conv_shortcut"),
|
766 |
-
("norm_out", "conv_norm_out"),
|
767 |
-
("mid.attn_1.", "mid_block.attentions.0."),
|
768 |
-
]
|
769 |
-
|
770 |
-
for i in range(4):
|
771 |
-
# down_blocks have two resnets
|
772 |
-
for j in range(2):
|
773 |
-
hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
|
774 |
-
sd_down_prefix = f"encoder.down.{i}.block.{j}."
|
775 |
-
vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
|
776 |
-
|
777 |
-
if i < 3:
|
778 |
-
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
|
779 |
-
sd_downsample_prefix = f"down.{i}.downsample."
|
780 |
-
vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
|
781 |
-
|
782 |
-
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
783 |
-
sd_upsample_prefix = f"up.{3-i}.upsample."
|
784 |
-
vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
|
785 |
-
|
786 |
-
# up_blocks have three resnets
|
787 |
-
# also, up blocks in hf are numbered in reverse from sd
|
788 |
-
for j in range(3):
|
789 |
-
hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
|
790 |
-
sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
|
791 |
-
vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
|
792 |
-
|
793 |
-
# this part accounts for mid blocks in both the encoder and the decoder
|
794 |
-
for i in range(2):
|
795 |
-
hf_mid_res_prefix = f"mid_block.resnets.{i}."
|
796 |
-
sd_mid_res_prefix = f"mid.block_{i+1}."
|
797 |
-
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
798 |
-
|
799 |
-
vae_conversion_map_attn = [
|
800 |
-
# (stable-diffusion, HF Diffusers)
|
801 |
-
("norm.", "group_norm."),
|
802 |
-
("q.", "query."),
|
803 |
-
("k.", "key."),
|
804 |
-
("v.", "value."),
|
805 |
-
("proj_out.", "proj_attn."),
|
806 |
-
]
|
807 |
-
|
808 |
-
mapping = {k: k for k in vae_state_dict.keys()}
|
809 |
-
for k, v in mapping.items():
|
810 |
-
for sd_part, hf_part in vae_conversion_map:
|
811 |
-
v = v.replace(hf_part, sd_part)
|
812 |
-
mapping[k] = v
|
813 |
-
for k, v in mapping.items():
|
814 |
-
if "attentions" in k:
|
815 |
-
for sd_part, hf_part in vae_conversion_map_attn:
|
816 |
-
v = v.replace(hf_part, sd_part)
|
817 |
-
mapping[k] = v
|
818 |
-
new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
|
819 |
-
weights_to_convert = ["q", "k", "v", "proj_out"]
|
820 |
-
for k, v in new_state_dict.items():
|
821 |
-
for weight_name in weights_to_convert:
|
822 |
-
if f"mid.attn_1.{weight_name}.weight" in k:
|
823 |
-
# print(f"Reshaping {k} for SD format")
|
824 |
-
new_state_dict[k] = reshape_weight_for_sd(v)
|
825 |
-
|
826 |
-
return new_state_dict
|
827 |
-
|
828 |
-
|
829 |
-
# endregion
|
830 |
-
|
831 |
-
# region 自作のモデル読み書きなど
|
832 |
-
|
833 |
-
def is_safetensors(path):
|
834 |
-
return os.path.splitext(path)[1].lower() == '.safetensors'
|
835 |
-
|
836 |
-
|
837 |
-
def load_checkpoint_with_text_encoder_conversion(ckpt_path):
|
838 |
-
# text encoderの格納形式が違うモデルに対応する ('text_model'がない)
|
839 |
-
TEXT_ENCODER_KEY_REPLACEMENTS = [
|
840 |
-
('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'),
|
841 |
-
('cond_stage_model.transformer.encoder.', 'cond_stage_model.transformer.text_model.encoder.'),
|
842 |
-
('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.')
|
843 |
-
]
|
844 |
-
|
845 |
-
if is_safetensors(ckpt_path):
|
846 |
-
checkpoint = None
|
847 |
-
state_dict = load_file(ckpt_path, "cpu")
|
848 |
-
else:
|
849 |
-
checkpoint = torch.load(ckpt_path, map_location="cpu")
|
850 |
-
if "state_dict" in checkpoint:
|
851 |
-
state_dict = checkpoint["state_dict"]
|
852 |
-
else:
|
853 |
-
state_dict = checkpoint
|
854 |
-
checkpoint = None
|
855 |
-
|
856 |
-
key_reps = []
|
857 |
-
for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
|
858 |
-
for key in state_dict.keys():
|
859 |
-
if key.startswith(rep_from):
|
860 |
-
new_key = rep_to + key[len(rep_from):]
|
861 |
-
key_reps.append((key, new_key))
|
862 |
-
|
863 |
-
for key, new_key in key_reps:
|
864 |
-
state_dict[new_key] = state_dict[key]
|
865 |
-
del state_dict[key]
|
866 |
-
|
867 |
-
return checkpoint, state_dict
|
868 |
-
|
869 |
-
|
870 |
-
# TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認
|
871 |
-
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
|
872 |
-
_, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
|
873 |
-
if dtype is not None:
|
874 |
-
for k, v in state_dict.items():
|
875 |
-
if type(v) is torch.Tensor:
|
876 |
-
state_dict[k] = v.to(dtype)
|
877 |
-
|
878 |
-
# Convert the UNet2DConditionModel model.
|
879 |
-
unet_config = create_unet_diffusers_config(v2)
|
880 |
-
converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)
|
881 |
-
|
882 |
-
unet = UNet2DConditionModel(**unet_config)
|
883 |
-
info = unet.load_state_dict(converted_unet_checkpoint)
|
884 |
-
print("loading u-net:", info)
|
885 |
-
|
886 |
-
# Convert the VAE model.
|
887 |
-
vae_config = create_vae_diffusers_config()
|
888 |
-
converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)
|
889 |
-
|
890 |
-
vae = AutoencoderKL(**vae_config)
|
891 |
-
info = vae.load_state_dict(converted_vae_checkpoint)
|
892 |
-
print("loading vae:", info)
|
893 |
-
|
894 |
-
# convert text_model
|
895 |
-
if v2:
|
896 |
-
converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77)
|
897 |
-
cfg = CLIPTextConfig(
|
898 |
-
vocab_size=49408,
|
899 |
-
hidden_size=1024,
|
900 |
-
intermediate_size=4096,
|
901 |
-
num_hidden_layers=23,
|
902 |
-
num_attention_heads=16,
|
903 |
-
max_position_embeddings=77,
|
904 |
-
hidden_act="gelu",
|
905 |
-
layer_norm_eps=1e-05,
|
906 |
-
dropout=0.0,
|
907 |
-
attention_dropout=0.0,
|
908 |
-
initializer_range=0.02,
|
909 |
-
initializer_factor=1.0,
|
910 |
-
pad_token_id=1,
|
911 |
-
bos_token_id=0,
|
912 |
-
eos_token_id=2,
|
913 |
-
model_type="clip_text_model",
|
914 |
-
projection_dim=512,
|
915 |
-
torch_dtype="float32",
|
916 |
-
transformers_version="4.25.0.dev0",
|
917 |
-
)
|
918 |
-
text_model = CLIPTextModel._from_config(cfg)
|
919 |
-
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
|
920 |
-
else:
|
921 |
-
converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
|
922 |
-
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
|
923 |
-
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
|
924 |
-
print("loading text encoder:", info)
|
925 |
-
|
926 |
-
return text_model, vae, unet
|
927 |
-
|
928 |
-
|
929 |
-
def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False):
|
930 |
-
def convert_key(key):
|
931 |
-
# position_idsの除去
|
932 |
-
if ".position_ids" in key:
|
933 |
-
return None
|
934 |
-
|
935 |
-
# common
|
936 |
-
key = key.replace("text_model.encoder.", "transformer.")
|
937 |
-
key = key.replace("text_model.", "")
|
938 |
-
if "layers" in key:
|
939 |
-
# resblocks conversion
|
940 |
-
key = key.replace(".layers.", ".resblocks.")
|
941 |
-
if ".layer_norm" in key:
|
942 |
-
key = key.replace(".layer_norm", ".ln_")
|
943 |
-
elif ".mlp." in key:
|
944 |
-
key = key.replace(".fc1.", ".c_fc.")
|
945 |
-
key = key.replace(".fc2.", ".c_proj.")
|
946 |
-
elif '.self_attn.out_proj' in key:
|
947 |
-
key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
|
948 |
-
elif '.self_attn.' in key:
|
949 |
-
key = None # 特殊なので後で処理する
|
950 |
-
else:
|
951 |
-
raise ValueError(f"unexpected key in DiffUsers model: {key}")
|
952 |
-
elif '.position_embedding' in key:
|
953 |
-
key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
|
954 |
-
elif '.token_embedding' in key:
|
955 |
-
key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
|
956 |
-
elif 'final_layer_norm' in key:
|
957 |
-
key = key.replace("final_layer_norm", "ln_final")
|
958 |
-
return key
|
959 |
-
|
960 |
-
keys = list(checkpoint.keys())
|
961 |
-
new_sd = {}
|
962 |
-
for key in keys:
|
963 |
-
new_key = convert_key(key)
|
964 |
-
if new_key is None:
|
965 |
-
continue
|
966 |
-
new_sd[new_key] = checkpoint[key]
|
967 |
-
|
968 |
-
# attnの変換
|
969 |
-
for key in keys:
|
970 |
-
if 'layers' in key and 'q_proj' in key:
|
971 |
-
# 三つを結合
|
972 |
-
key_q = key
|
973 |
-
key_k = key.replace("q_proj", "k_proj")
|
974 |
-
key_v = key.replace("q_proj", "v_proj")
|
975 |
-
|
976 |
-
value_q = checkpoint[key_q]
|
977 |
-
value_k = checkpoint[key_k]
|
978 |
-
value_v = checkpoint[key_v]
|
979 |
-
value = torch.cat([value_q, value_k, value_v])
|
980 |
-
|
981 |
-
new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
|
982 |
-
new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
|
983 |
-
new_sd[new_key] = value
|
984 |
-
|
985 |
-
# 最後の層などを捏造するか
|
986 |
-
if make_dummy_weights:
|
987 |
-
print("make dummy weights for resblock.23, text_projection and logit scale.")
|
988 |
-
keys = list(new_sd.keys())
|
989 |
-
for key in keys:
|
990 |
-
if key.startswith("transformer.resblocks.22."):
|
991 |
-
new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone() # copyしないとsafetensorsの保存で落ちる
|
992 |
-
|
993 |
-
# Diffusersに含まれない重みを作っておく
|
994 |
-
new_sd['text_projection'] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device)
|
995 |
-
new_sd['logit_scale'] = torch.tensor(1)
|
996 |
-
|
997 |
-
return new_sd
|
998 |
-
|
999 |
-
|
1000 |
-
def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None):
|
1001 |
-
if ckpt_path is not None:
|
1002 |
-
# epoch/stepを参照する。またVAEがメモリ上にないときなど、もう一度VAEを含めて読み込む
|
1003 |
-
checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
|
1004 |
-
if checkpoint is None: # safetensors または state_dictのckpt
|
1005 |
-
checkpoint = {}
|
1006 |
-
strict = False
|
1007 |
-
else:
|
1008 |
-
strict = True
|
1009 |
-
if "state_dict" in state_dict:
|
1010 |
-
del state_dict["state_dict"]
|
1011 |
-
else:
|
1012 |
-
# 新しく作る
|
1013 |
-
assert vae is not None, "VAE is required to save a checkpoint without a given checkpoint"
|
1014 |
-
checkpoint = {}
|
1015 |
-
state_dict = {}
|
1016 |
-
strict = False
|
1017 |
-
|
1018 |
-
def update_sd(prefix, sd):
|
1019 |
-
for k, v in sd.items():
|
1020 |
-
key = prefix + k
|
1021 |
-
assert not strict or key in state_dict, f"Illegal key in save SD: {key}"
|
1022 |
-
if save_dtype is not None:
|
1023 |
-
v = v.detach().clone().to("cpu").to(save_dtype)
|
1024 |
-
state_dict[key] = v
|
1025 |
-
|
1026 |
-
# Convert the UNet model
|
1027 |
-
unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict())
|
1028 |
-
update_sd("model.diffusion_model.", unet_state_dict)
|
1029 |
-
|
1030 |
-
# Convert the text encoder model
|
1031 |
-
if v2:
|
1032 |
-
make_dummy = ckpt_path is None # 参照元のcheckpoint���ない場合は最後の層を前の層から複製して作るなどダミーの重みを入れる
|
1033 |
-
text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy)
|
1034 |
-
update_sd("cond_stage_model.model.", text_enc_dict)
|
1035 |
-
else:
|
1036 |
-
text_enc_dict = text_encoder.state_dict()
|
1037 |
-
update_sd("cond_stage_model.transformer.", text_enc_dict)
|
1038 |
-
|
1039 |
-
# Convert the VAE
|
1040 |
-
if vae is not None:
|
1041 |
-
vae_dict = convert_vae_state_dict(vae.state_dict())
|
1042 |
-
update_sd("first_stage_model.", vae_dict)
|
1043 |
-
|
1044 |
-
# Put together new checkpoint
|
1045 |
-
key_count = len(state_dict.keys())
|
1046 |
-
new_ckpt = {'state_dict': state_dict}
|
1047 |
-
|
1048 |
-
if 'epoch' in checkpoint:
|
1049 |
-
epochs += checkpoint['epoch']
|
1050 |
-
if 'global_step' in checkpoint:
|
1051 |
-
steps += checkpoint['global_step']
|
1052 |
-
|
1053 |
-
new_ckpt['epoch'] = epochs
|
1054 |
-
new_ckpt['global_step'] = steps
|
1055 |
-
|
1056 |
-
if is_safetensors(output_file):
|
1057 |
-
# TODO Tensor以外のdictの値を削除したほうがいいか
|
1058 |
-
save_file(state_dict, output_file)
|
1059 |
-
else:
|
1060 |
-
torch.save(new_ckpt, output_file)
|
1061 |
-
|
1062 |
-
return key_count
|
1063 |
-
|
1064 |
-
|
1065 |
-
def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False):
|
1066 |
-
if pretrained_model_name_or_path is None:
|
1067 |
-
# load default settings for v1/v2
|
1068 |
-
if v2:
|
1069 |
-
pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2
|
1070 |
-
else:
|
1071 |
-
pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1
|
1072 |
-
|
1073 |
-
scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
|
1074 |
-
tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
|
1075 |
-
if vae is None:
|
1076 |
-
vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
|
1077 |
-
|
1078 |
-
pipeline = StableDiffusionPipeline(
|
1079 |
-
unet=unet,
|
1080 |
-
text_encoder=text_encoder,
|
1081 |
-
vae=vae,
|
1082 |
-
scheduler=scheduler,
|
1083 |
-
tokenizer=tokenizer,
|
1084 |
-
safety_checker=None,
|
1085 |
-
feature_extractor=None,
|
1086 |
-
requires_safety_checker=None,
|
1087 |
-
)
|
1088 |
-
pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors)
|
1089 |
-
|
1090 |
-
|
1091 |
-
VAE_PREFIX = "first_stage_model."
|
1092 |
-
|
1093 |
-
|
1094 |
-
def load_vae(vae_id, dtype):
|
1095 |
-
print(f"load VAE: {vae_id}")
|
1096 |
-
if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
|
1097 |
-
# Diffusers local/remote
|
1098 |
-
try:
|
1099 |
-
vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype)
|
1100 |
-
except EnvironmentError as e:
|
1101 |
-
print(f"exception occurs in loading vae: {e}")
|
1102 |
-
print("retry with subfolder='vae'")
|
1103 |
-
vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype)
|
1104 |
-
return vae
|
1105 |
-
|
1106 |
-
# local
|
1107 |
-
vae_config = create_vae_diffusers_config()
|
1108 |
-
|
1109 |
-
if vae_id.endswith(".bin"):
|
1110 |
-
# SD 1.5 VAE on Huggingface
|
1111 |
-
converted_vae_checkpoint = torch.load(vae_id, map_location="cpu")
|
1112 |
-
else:
|
1113 |
-
# StableDiffusion
|
1114 |
-
vae_model = (load_file(vae_id, "cpu") if is_safetensors(vae_id)
|
1115 |
-
else torch.load(vae_id, map_location="cpu"))
|
1116 |
-
vae_sd = vae_model['state_dict'] if 'state_dict' in vae_model else vae_model
|
1117 |
-
|
1118 |
-
# vae only or full model
|
1119 |
-
full_model = False
|
1120 |
-
for vae_key in vae_sd:
|
1121 |
-
if vae_key.startswith(VAE_PREFIX):
|
1122 |
-
full_model = True
|
1123 |
-
break
|
1124 |
-
if not full_model:
|
1125 |
-
sd = {}
|
1126 |
-
for key, value in vae_sd.items():
|
1127 |
-
sd[VAE_PREFIX + key] = value
|
1128 |
-
vae_sd = sd
|
1129 |
-
del sd
|
1130 |
-
|
1131 |
-
# Convert the VAE model.
|
1132 |
-
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
|
1133 |
-
|
1134 |
-
vae = AutoencoderKL(**vae_config)
|
1135 |
-
vae.load_state_dict(converted_vae_checkpoint)
|
1136 |
-
return vae
|
1137 |
-
|
1138 |
-
# endregion
|
1139 |
-
|
1140 |
-
|
1141 |
-
def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
|
1142 |
-
max_width, max_height = max_reso
|
1143 |
-
max_area = (max_width // divisible) * (max_height // divisible)
|
1144 |
-
|
1145 |
-
resos = set()
|
1146 |
-
|
1147 |
-
size = int(math.sqrt(max_area)) * divisible
|
1148 |
-
resos.add((size, size))
|
1149 |
-
|
1150 |
-
size = min_size
|
1151 |
-
while size <= max_size:
|
1152 |
-
width = size
|
1153 |
-
height = min(max_size, (max_area // (width // divisible)) * divisible)
|
1154 |
-
resos.add((width, height))
|
1155 |
-
resos.add((height, width))
|
1156 |
-
|
1157 |
-
# # make additional resos
|
1158 |
-
# if width >= height and width - divisible >= min_size:
|
1159 |
-
# resos.add((width - divisible, height))
|
1160 |
-
# resos.add((height, width - divisible))
|
1161 |
-
# if height >= width and height - divisible >= min_size:
|
1162 |
-
# resos.add((width, height - divisible))
|
1163 |
-
# resos.add((height - divisible, width))
|
1164 |
-
|
1165 |
-
size += divisible
|
1166 |
-
|
1167 |
-
resos = list(resos)
|
1168 |
-
resos.sort()
|
1169 |
-
|
1170 |
-
aspect_ratios = [w / h for w, h in resos]
|
1171 |
-
return resos, aspect_ratios
|
1172 |
-
|
1173 |
-
|
1174 |
-
if __name__ == '__main__':
|
1175 |
-
resos, aspect_ratios = make_bucket_resolutions((512, 768))
|
1176 |
-
print(len(resos))
|
1177 |
-
print(resos)
|
1178 |
-
print(aspect_ratios)
|
1179 |
-
|
1180 |
-
ars = set()
|
1181 |
-
for ar in aspect_ratios:
|
1182 |
-
if ar in ars:
|
1183 |
-
print("error! duplicate ar:", ar)
|
1184 |
-
ars.add(ar)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
locon/kohya_utils.py
DELETED
@@ -1,48 +0,0 @@
|
|
1 |
-
# part of https://github.com/kohya-ss/sd-scripts/blob/main/library/train_util.py
|
2 |
-
|
3 |
-
import hashlib
|
4 |
-
import safetensors
|
5 |
-
from io import BytesIO
|
6 |
-
|
7 |
-
|
8 |
-
def addnet_hash_legacy(b):
|
9 |
-
"""Old model hash used by sd-webui-additional-networks for .safetensors format files"""
|
10 |
-
m = hashlib.sha256()
|
11 |
-
|
12 |
-
b.seek(0x100000)
|
13 |
-
m.update(b.read(0x10000))
|
14 |
-
return m.hexdigest()[0:8]
|
15 |
-
|
16 |
-
|
17 |
-
def addnet_hash_safetensors(b):
|
18 |
-
"""New model hash used by sd-webui-additional-networks for .safetensors format files"""
|
19 |
-
hash_sha256 = hashlib.sha256()
|
20 |
-
blksize = 1024 * 1024
|
21 |
-
|
22 |
-
b.seek(0)
|
23 |
-
header = b.read(8)
|
24 |
-
n = int.from_bytes(header, "little")
|
25 |
-
|
26 |
-
offset = n + 8
|
27 |
-
b.seek(offset)
|
28 |
-
for chunk in iter(lambda: b.read(blksize), b""):
|
29 |
-
hash_sha256.update(chunk)
|
30 |
-
|
31 |
-
return hash_sha256.hexdigest()
|
32 |
-
|
33 |
-
|
34 |
-
def precalculate_safetensors_hashes(tensors, metadata):
|
35 |
-
"""Precalculate the model hashes needed by sd-webui-additional-networks to
|
36 |
-
save time on indexing the model later."""
|
37 |
-
|
38 |
-
# Because writing user metadata to the file can change the result of
|
39 |
-
# sd_models.model_hash(), only retain the training metadata for purposes of
|
40 |
-
# calculating the hash, as they are meant to be immutable
|
41 |
-
metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}
|
42 |
-
|
43 |
-
bytes = safetensors.torch.save(tensors, metadata)
|
44 |
-
b = BytesIO(bytes)
|
45 |
-
|
46 |
-
model_hash = addnet_hash_safetensors(b)
|
47 |
-
legacy_hash = addnet_hash_legacy(b)
|
48 |
-
return model_hash, legacy_hash
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
locon/locon.py
DELETED
@@ -1,53 +0,0 @@
|
|
1 |
-
import math
|
2 |
-
|
3 |
-
import torch
|
4 |
-
import torch.nn as nn
|
5 |
-
import torch.nn.functional as F
|
6 |
-
|
7 |
-
|
8 |
-
class LoConModule(nn.Module):
|
9 |
-
"""
|
10 |
-
modifed from kohya-ss/sd-scripts/networks/lora:LoRAModule
|
11 |
-
"""
|
12 |
-
|
13 |
-
def __init__(self, lora_name, org_module: nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
|
14 |
-
""" if alpha == 0 or None, alpha is rank (no scaling). """
|
15 |
-
super().__init__()
|
16 |
-
self.lora_name = lora_name
|
17 |
-
self.lora_dim = lora_dim
|
18 |
-
|
19 |
-
if org_module.__class__.__name__ == 'Conv2d':
|
20 |
-
# For general LoCon
|
21 |
-
in_dim = org_module.in_channels
|
22 |
-
k_size = org_module.kernel_size
|
23 |
-
stride = org_module.stride
|
24 |
-
padding = org_module.padding
|
25 |
-
out_dim = org_module.out_channels
|
26 |
-
self.lora_down = nn.Conv2d(in_dim, lora_dim, k_size, stride, padding, bias=False)
|
27 |
-
self.lora_up = nn.Conv2d(lora_dim, out_dim, (1, 1), bias=False)
|
28 |
-
else:
|
29 |
-
in_dim = org_module.in_features
|
30 |
-
out_dim = org_module.out_features
|
31 |
-
self.lora_down = nn.Linear(in_dim, lora_dim, bias=False)
|
32 |
-
self.lora_up = nn.Linear(lora_dim, out_dim, bias=False)
|
33 |
-
|
34 |
-
if type(alpha) == torch.Tensor:
|
35 |
-
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
|
36 |
-
alpha = lora_dim if alpha is None or alpha == 0 else alpha
|
37 |
-
self.scale = alpha / self.lora_dim
|
38 |
-
self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える
|
39 |
-
|
40 |
-
# same as microsoft's
|
41 |
-
torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
|
42 |
-
torch.nn.init.zeros_(self.lora_up.weight)
|
43 |
-
|
44 |
-
self.multiplier = multiplier
|
45 |
-
self.org_module = org_module # remove in applying
|
46 |
-
|
47 |
-
def apply_to(self):
|
48 |
-
self.org_forward = self.org_module.forward
|
49 |
-
self.org_module.forward = self.forward
|
50 |
-
del self.org_module
|
51 |
-
|
52 |
-
def forward(self, x):
|
53 |
-
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
locon/locon_kohya.py
DELETED
@@ -1,243 +0,0 @@
|
|
1 |
-
# LoCon network module
|
2 |
-
# reference:
|
3 |
-
# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
|
4 |
-
# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
|
5 |
-
# https://github.com/kohya-ss/sd-scripts/blob/main/networks/lora.py
|
6 |
-
|
7 |
-
import math
|
8 |
-
import os
|
9 |
-
from typing import List
|
10 |
-
import torch
|
11 |
-
|
12 |
-
from .kohya_utils import *
|
13 |
-
from .locon import LoConModule
|
14 |
-
|
15 |
-
|
16 |
-
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
|
17 |
-
if network_dim is None:
|
18 |
-
network_dim = 4 # default
|
19 |
-
conv_dim = kwargs.get('conv_dim', network_dim)
|
20 |
-
conv_alpha = kwargs.get('conv_alpha', network_alpha)
|
21 |
-
network = LoRANetwork(
|
22 |
-
text_encoder, unet,
|
23 |
-
multiplier=multiplier,
|
24 |
-
lora_dim=network_dim, conv_lora_dim=conv_dim,
|
25 |
-
alpha=network_alpha, conv_alpha=conv_alpha
|
26 |
-
)
|
27 |
-
return network
|
28 |
-
|
29 |
-
|
30 |
-
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwargs):
|
31 |
-
if os.path.splitext(file)[1] == '.safetensors':
|
32 |
-
from safetensors.torch import load_file, safe_open
|
33 |
-
weights_sd = load_file(file)
|
34 |
-
else:
|
35 |
-
weights_sd = torch.load(file, map_location='cpu')
|
36 |
-
|
37 |
-
# get dim (rank)
|
38 |
-
network_alpha = None
|
39 |
-
network_dim = None
|
40 |
-
for key, value in weights_sd.items():
|
41 |
-
if network_alpha is None and 'alpha' in key:
|
42 |
-
network_alpha = value
|
43 |
-
if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
|
44 |
-
network_dim = value.size()[0]
|
45 |
-
|
46 |
-
if network_alpha is None:
|
47 |
-
network_alpha = network_dim
|
48 |
-
|
49 |
-
network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha)
|
50 |
-
network.weights_sd = weights_sd
|
51 |
-
return network
|
52 |
-
|
53 |
-
torch.nn.Conv2d
|
54 |
-
class LoRANetwork(torch.nn.Module):
|
55 |
-
'''
|
56 |
-
LoRA + LoCon
|
57 |
-
'''
|
58 |
-
# Ignore proj_in or proj_out, their channels is only a few.
|
59 |
-
UNET_TARGET_REPLACE_MODULE = [
|
60 |
-
"Transformer2DModel",
|
61 |
-
"Attention",
|
62 |
-
"ResnetBlock2D",
|
63 |
-
"Downsample2D",
|
64 |
-
"Upsample2D"
|
65 |
-
]
|
66 |
-
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
67 |
-
LORA_PREFIX_UNET = 'lora_unet'
|
68 |
-
LORA_PREFIX_TEXT_ENCODER = 'lora_te'
|
69 |
-
|
70 |
-
def __init__(
|
71 |
-
self,
|
72 |
-
text_encoder, unet,
|
73 |
-
multiplier=1.0,
|
74 |
-
lora_dim=4, conv_lora_dim=4,
|
75 |
-
alpha=1, conv_alpha=1
|
76 |
-
) -> None:
|
77 |
-
super().__init__()
|
78 |
-
self.multiplier = multiplier
|
79 |
-
self.lora_dim = lora_dim
|
80 |
-
self.conv_lora_dim = int(conv_lora_dim)
|
81 |
-
if self.conv_lora_dim != self.lora_dim:
|
82 |
-
print('Apply different lora dim for conv layer')
|
83 |
-
print(f'LoCon Dim: {conv_lora_dim}, LoRA Dim: {lora_dim}')
|
84 |
-
self.alpha = alpha
|
85 |
-
self.conv_alpha = float(conv_alpha)
|
86 |
-
if self.alpha != self.conv_alpha:
|
87 |
-
print('Apply different alpha value for conv layer')
|
88 |
-
print(f'LoCon alpha: {conv_alpha}, LoRA alpha: {alpha}')
|
89 |
-
|
90 |
-
# create module instances
|
91 |
-
def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[LoConModule]:
|
92 |
-
print('Create LoCon Module')
|
93 |
-
loras = []
|
94 |
-
for name, module in root_module.named_modules():
|
95 |
-
if module.__class__.__name__ in target_replace_modules:
|
96 |
-
for child_name, child_module in module.named_modules():
|
97 |
-
lora_name = prefix + '.' + name + '.' + child_name
|
98 |
-
lora_name = lora_name.replace('.', '_')
|
99 |
-
if child_module.__class__.__name__ == 'Linear':
|
100 |
-
lora = LoConModule(lora_name, child_module, self.multiplier, self.lora_dim, self.alpha)
|
101 |
-
elif child_module.__class__.__name__ == 'Conv2d':
|
102 |
-
k_size, *_ = child_module.kernel_size
|
103 |
-
if k_size==1:
|
104 |
-
lora = LoConModule(lora_name, child_module, self.multiplier, self.lora_dim, self.alpha)
|
105 |
-
else:
|
106 |
-
lora = LoConModule(lora_name, child_module, self.multiplier, self.conv_lora_dim, self.conv_alpha)
|
107 |
-
else:
|
108 |
-
continue
|
109 |
-
loras.append(lora)
|
110 |
-
return loras
|
111 |
-
|
112 |
-
self.text_encoder_loras = create_modules(
|
113 |
-
LoRANetwork.LORA_PREFIX_TEXT_ENCODER,
|
114 |
-
text_encoder,
|
115 |
-
LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
|
116 |
-
)
|
117 |
-
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
118 |
-
|
119 |
-
self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, LoRANetwork.UNET_TARGET_REPLACE_MODULE)
|
120 |
-
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
121 |
-
|
122 |
-
self.weights_sd = None
|
123 |
-
|
124 |
-
# assertion
|
125 |
-
names = set()
|
126 |
-
for lora in self.text_encoder_loras + self.unet_loras:
|
127 |
-
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
|
128 |
-
names.add(lora.lora_name)
|
129 |
-
|
130 |
-
def set_multiplier(self, multiplier):
|
131 |
-
self.multiplier = multiplier
|
132 |
-
for lora in self.text_encoder_loras + self.unet_loras:
|
133 |
-
lora.multiplier = self.multiplier
|
134 |
-
|
135 |
-
def load_weights(self, file):
|
136 |
-
if os.path.splitext(file)[1] == '.safetensors':
|
137 |
-
from safetensors.torch import load_file, safe_open
|
138 |
-
self.weights_sd = load_file(file)
|
139 |
-
else:
|
140 |
-
self.weights_sd = torch.load(file, map_location='cpu')
|
141 |
-
|
142 |
-
def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None):
|
143 |
-
if self.weights_sd:
|
144 |
-
weights_has_text_encoder = weights_has_unet = False
|
145 |
-
for key in self.weights_sd.keys():
|
146 |
-
if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
|
147 |
-
weights_has_text_encoder = True
|
148 |
-
elif key.startswith(LoRANetwork.LORA_PREFIX_UNET):
|
149 |
-
weights_has_unet = True
|
150 |
-
|
151 |
-
if apply_text_encoder is None:
|
152 |
-
apply_text_encoder = weights_has_text_encoder
|
153 |
-
else:
|
154 |
-
assert apply_text_encoder == weights_has_text_encoder, f"text encoder weights: {weights_has_text_encoder} but text encoder flag: {apply_text_encoder} / 重みとText Encoderのフラグが矛盾しています"
|
155 |
-
|
156 |
-
if apply_unet is None:
|
157 |
-
apply_unet = weights_has_unet
|
158 |
-
else:
|
159 |
-
assert apply_unet == weights_has_unet, f"u-net weights: {weights_has_unet} but u-net flag: {apply_unet} / 重みとU-Netのフラグが矛盾しています"
|
160 |
-
else:
|
161 |
-
assert apply_text_encoder is not None and apply_unet is not None, f"internal error: flag not set"
|
162 |
-
|
163 |
-
if apply_text_encoder:
|
164 |
-
print("enable LoRA for text encoder")
|
165 |
-
else:
|
166 |
-
self.text_encoder_loras = []
|
167 |
-
|
168 |
-
if apply_unet:
|
169 |
-
print("enable LoRA for U-Net")
|
170 |
-
else:
|
171 |
-
self.unet_loras = []
|
172 |
-
|
173 |
-
for lora in self.text_encoder_loras + self.unet_loras:
|
174 |
-
lora.apply_to()
|
175 |
-
self.add_module(lora.lora_name, lora)
|
176 |
-
|
177 |
-
if self.weights_sd:
|
178 |
-
# if some weights are not in state dict, it is ok because initial LoRA does nothing (lora_up is initialized by zeros)
|
179 |
-
info = self.load_state_dict(self.weights_sd, False)
|
180 |
-
print(f"weights are loaded: {info}")
|
181 |
-
|
182 |
-
def enable_gradient_checkpointing(self):
|
183 |
-
# not supported
|
184 |
-
pass
|
185 |
-
|
186 |
-
def prepare_optimizer_params(self, text_encoder_lr, unet_lr):
|
187 |
-
def enumerate_params(loras):
|
188 |
-
params = []
|
189 |
-
for lora in loras:
|
190 |
-
params.extend(lora.parameters())
|
191 |
-
return params
|
192 |
-
|
193 |
-
self.requires_grad_(True)
|
194 |
-
all_params = []
|
195 |
-
|
196 |
-
if self.text_encoder_loras:
|
197 |
-
param_data = {'params': enumerate_params(self.text_encoder_loras)}
|
198 |
-
if text_encoder_lr is not None:
|
199 |
-
param_data['lr'] = text_encoder_lr
|
200 |
-
all_params.append(param_data)
|
201 |
-
|
202 |
-
if self.unet_loras:
|
203 |
-
param_data = {'params': enumerate_params(self.unet_loras)}
|
204 |
-
if unet_lr is not None:
|
205 |
-
param_data['lr'] = unet_lr
|
206 |
-
all_params.append(param_data)
|
207 |
-
|
208 |
-
return all_params
|
209 |
-
|
210 |
-
def prepare_grad_etc(self, text_encoder, unet):
|
211 |
-
self.requires_grad_(True)
|
212 |
-
|
213 |
-
def on_epoch_start(self, text_encoder, unet):
|
214 |
-
self.train()
|
215 |
-
|
216 |
-
def get_trainable_params(self):
|
217 |
-
return self.parameters()
|
218 |
-
|
219 |
-
def save_weights(self, file, dtype, metadata):
|
220 |
-
if metadata is not None and len(metadata) == 0:
|
221 |
-
metadata = None
|
222 |
-
|
223 |
-
state_dict = self.state_dict()
|
224 |
-
|
225 |
-
if dtype is not None:
|
226 |
-
for key in list(state_dict.keys()):
|
227 |
-
v = state_dict[key]
|
228 |
-
v = v.detach().clone().to("cpu").to(dtype)
|
229 |
-
state_dict[key] = v
|
230 |
-
|
231 |
-
if os.path.splitext(file)[1] == '.safetensors':
|
232 |
-
from safetensors.torch import save_file
|
233 |
-
|
234 |
-
# Precalculate model hashes to save time on indexing
|
235 |
-
if metadata is None:
|
236 |
-
metadata = {}
|
237 |
-
model_hash, legacy_hash = precalculate_safetensors_hashes(state_dict, metadata)
|
238 |
-
metadata["sshs_model_hash"] = model_hash
|
239 |
-
metadata["sshs_legacy_hash"] = legacy_hash
|
240 |
-
|
241 |
-
save_file(state_dict, file, metadata)
|
242 |
-
else:
|
243 |
-
torch.save(state_dict, file)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
locon/utils.py
DELETED
@@ -1,148 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
import torch.nn as nn
|
3 |
-
import torch.nn.functional as F
|
4 |
-
|
5 |
-
import torch.linalg as linalg
|
6 |
-
|
7 |
-
from tqdm import tqdm
|
8 |
-
|
9 |
-
|
10 |
-
def extract_conv(
|
11 |
-
weight: nn.Parameter|torch.Tensor,
|
12 |
-
lora_rank = 8
|
13 |
-
) -> tuple[nn.Parameter, nn.Parameter]:
|
14 |
-
out_ch, in_ch, kernel_size, _ = weight.shape
|
15 |
-
lora_rank = min(out_ch, in_ch, lora_rank)
|
16 |
-
|
17 |
-
U, S, Vh = linalg.svd(weight.reshape(out_ch, -1))
|
18 |
-
|
19 |
-
U = U[:, :lora_rank]
|
20 |
-
S = S[:lora_rank]
|
21 |
-
U = U @ torch.diag(S)
|
22 |
-
Vh = Vh[:lora_rank, :]
|
23 |
-
|
24 |
-
extract_weight_A = Vh.reshape(lora_rank, in_ch, kernel_size, kernel_size).cpu()
|
25 |
-
extract_weight_B = U.reshape(out_ch, lora_rank, 1, 1).cpu()
|
26 |
-
del U, S, Vh, weight
|
27 |
-
return extract_weight_A, extract_weight_B
|
28 |
-
|
29 |
-
|
30 |
-
def merge_conv(
|
31 |
-
weight_a: nn.Parameter|torch.Tensor,
|
32 |
-
weight_b: nn.Parameter|torch.Tensor,
|
33 |
-
):
|
34 |
-
rank, in_ch, kernel_size, k_ = weight_a.shape
|
35 |
-
out_ch, rank_, _, _ = weight_b.shape
|
36 |
-
|
37 |
-
assert rank == rank_ and kernel_size == k_
|
38 |
-
|
39 |
-
merged = weight_b.reshape(out_ch, -1) @ weight_a.reshape(rank, -1)
|
40 |
-
weight = merged.reshape(out_ch, in_ch, kernel_size, kernel_size)
|
41 |
-
return weight
|
42 |
-
|
43 |
-
|
44 |
-
def extract_linear(
|
45 |
-
weight: nn.Parameter|torch.Tensor,
|
46 |
-
lora_rank = 8
|
47 |
-
) -> tuple[nn.Parameter, nn.Parameter]:
|
48 |
-
out_ch, in_ch = weight.shape
|
49 |
-
lora_rank = min(out_ch, in_ch, lora_rank)
|
50 |
-
|
51 |
-
U, S, Vh = linalg.svd(weight)
|
52 |
-
|
53 |
-
U = U[:, :lora_rank]
|
54 |
-
S = S[:lora_rank]
|
55 |
-
U = U @ torch.diag(S)
|
56 |
-
Vh = Vh[:lora_rank, :]
|
57 |
-
|
58 |
-
extract_weight_A = Vh.reshape(lora_rank, in_ch).cpu()
|
59 |
-
extract_weight_B = U.reshape(out_ch, lora_rank).cpu()
|
60 |
-
del U, S, Vh, weight
|
61 |
-
return extract_weight_A, extract_weight_B
|
62 |
-
|
63 |
-
|
64 |
-
def merge_linear(
|
65 |
-
weight_a: nn.Parameter|torch.Tensor,
|
66 |
-
weight_b: nn.Parameter|torch.Tensor,
|
67 |
-
):
|
68 |
-
rank, in_ch = weight_a.shape
|
69 |
-
out_ch, rank_ = weight_b.shape
|
70 |
-
|
71 |
-
assert rank == rank_
|
72 |
-
|
73 |
-
weight = weight_b @ weight_a
|
74 |
-
return weight
|
75 |
-
|
76 |
-
|
77 |
-
def extract_diff(
|
78 |
-
base_model,
|
79 |
-
db_model,
|
80 |
-
lora_dim=4,
|
81 |
-
conv_lora_dim=4,
|
82 |
-
extract_device = 'cuda',
|
83 |
-
):
|
84 |
-
UNET_TARGET_REPLACE_MODULE = [
|
85 |
-
"Transformer2DModel",
|
86 |
-
"Attention",
|
87 |
-
"ResnetBlock2D",
|
88 |
-
"Downsample2D",
|
89 |
-
"Upsample2D"
|
90 |
-
]
|
91 |
-
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
92 |
-
LORA_PREFIX_UNET = 'lora_unet'
|
93 |
-
LORA_PREFIX_TEXT_ENCODER = 'lora_te'
|
94 |
-
def make_state_dict(
|
95 |
-
prefix,
|
96 |
-
root_module: torch.nn.Module,
|
97 |
-
target_module: torch.nn.Module,
|
98 |
-
target_replace_modules
|
99 |
-
):
|
100 |
-
loras = {}
|
101 |
-
temp = {}
|
102 |
-
|
103 |
-
for name, module in root_module.named_modules():
|
104 |
-
if module.__class__.__name__ in target_replace_modules:
|
105 |
-
temp[name] = {}
|
106 |
-
for child_name, child_module in module.named_modules():
|
107 |
-
if child_module.__class__.__name__ not in {'Linear', 'Conv2d'}:
|
108 |
-
continue
|
109 |
-
temp[name][child_name] = child_module.weight
|
110 |
-
|
111 |
-
for name, module in tqdm(list(target_module.named_modules())):
|
112 |
-
if name in temp:
|
113 |
-
weights = temp[name]
|
114 |
-
for child_name, child_module in module.named_modules():
|
115 |
-
lora_name = prefix + '.' + name + '.' + child_name
|
116 |
-
lora_name = lora_name.replace('.', '_')
|
117 |
-
if child_module.__class__.__name__ == 'Linear':
|
118 |
-
extract_a, extract_b = extract_linear(
|
119 |
-
(child_module.weight - weights[child_name]),
|
120 |
-
lora_dim
|
121 |
-
)
|
122 |
-
elif child_module.__class__.__name__ == 'Conv2d':
|
123 |
-
extract_a, extract_b = extract_conv(
|
124 |
-
(child_module.weight - weights[child_name]),
|
125 |
-
conv_lora_dim
|
126 |
-
)
|
127 |
-
else:
|
128 |
-
continue
|
129 |
-
|
130 |
-
loras[f'{lora_name}.lora_down.weight'] = extract_a.detach().cpu().half()
|
131 |
-
loras[f'{lora_name}.lora_up.weight'] = extract_b.detach().cpu().half()
|
132 |
-
loras[f'{lora_name}.alpha'] = torch.Tensor([int(extract_a.shape[0])]).detach().cpu().half()
|
133 |
-
del extract_a, extract_b
|
134 |
-
return loras
|
135 |
-
|
136 |
-
text_encoder_loras = make_state_dict(
|
137 |
-
LORA_PREFIX_TEXT_ENCODER,
|
138 |
-
base_model[0], db_model[0],
|
139 |
-
TEXT_ENCODER_TARGET_REPLACE_MODULE
|
140 |
-
)
|
141 |
-
|
142 |
-
unet_loras = make_state_dict(
|
143 |
-
LORA_PREFIX_UNET,
|
144 |
-
base_model[2], db_model[2],
|
145 |
-
UNET_TARGET_REPLACE_MODULE
|
146 |
-
)
|
147 |
-
print(len(text_encoder_loras), len(unet_loras))
|
148 |
-
return text_encoder_loras|unet_loras
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|